spateo.external.STAGATE_pyG.Train_STAGATE
=========================================

.. py:module:: spateo.external.STAGATE_pyG.Train_STAGATE


Functions
---------

.. autoapisummary::

   spateo.external.STAGATE_pyG.Train_STAGATE.train_STAGATE


Module Contents
---------------

.. py:function:: train_STAGATE(adata, hidden_dims=[512, 30], n_epochs=1000, lr=0.001, key_added='STAGATE', gradient_clipping=5.0, weight_decay=0.0001, verbose=True, random_seed=0, save_loss=False, save_reconstrction=False, device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))

   Train the STAGATE graph attention auto-encoder.

   :param adata: AnnData object, as used by the scanpy package.
   :param hidden_dims: Layer sizes of the encoder, given as ``[hidden, latent]``; the last entry is the dimension of the latent embedding.
   :param n_epochs: Total number of training epochs.
   :param lr: Learning rate for the Adam optimizer.
   :param key_added: Key under which the latent embeddings are stored, i.e. ``adata.obsm[key_added]``.
   :param gradient_clipping: Maximum gradient norm used for gradient clipping.
   :param weight_decay: Weight decay (L2 penalty) for the Adam optimizer.
   :param verbose: Whether to print progress information.
   :param random_seed: Random seed used for reproducibility.
   :param save_loss: If True, the training loss is saved in ``adata.uns['STAGATE_loss']``.
   :param save_reconstrction: If True, the reconstructed expression profiles are saved in ``adata.layers['STAGATE_ReX']``. Note that the parameter name is spelled ``save_reconstrction`` in the function signature.
   :param device: The ``torch.device`` used for training; defaults to ``cuda:0`` if CUDA is available, otherwise the CPU.
   :returns: The input AnnData object with the results added.
   :rtype: AnnData
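
The ``hidden_dims`` argument of ``train_STAGATE`` determines the shape of the auto-encoder. The sketch below is a minimal illustration, not spateo's code: it assumes, as in the original STAGATE_pyG reference model, that the full dimension list is ``[n_genes] + hidden_dims`` (so ``hidden_dims`` has exactly two entries) and that the decoder mirrors the encoder back to gene space. The helper ``stagate_layer_shapes`` and the ``n_genes`` value are hypothetical and used only for illustration.

.. code-block:: python

   from typing import List, Tuple

   Shape = Tuple[int, int]


   def stagate_layer_shapes(n_genes: int, hidden_dims: List[int]) -> Tuple[List[Shape], List[Shape]]:
       """Hypothetical helper: (in, out) sizes of the encoder and decoder layers.

       Assumption: the model is built from [n_genes] + hidden_dims, which must
       contain exactly three entries (input, hidden, latent).
       """
       dims = [n_genes] + list(hidden_dims)
       if len(dims) != 3:
           raise ValueError("hidden_dims must have exactly two entries: [hidden, latent]")
       in_dim, num_hidden, out_dim = dims
       # Encoder: genes -> hidden -> latent embedding (stored in adata.obsm[key_added])
       encoder = [(in_dim, num_hidden), (num_hidden, out_dim)]
       # Decoder: mirrors the encoder back to gene space (the reconstruction)
       decoder = [(out_dim, num_hidden), (num_hidden, in_dim)]
       return encoder, decoder


   encoder, decoder = stagate_layer_shapes(n_genes=3000, hidden_dims=[512, 30])
   print(encoder)  # [(3000, 512), (512, 30)]
   print(decoder)  # [(30, 512), (512, 3000)]

Under this assumption, the default ``hidden_dims=[512, 30]`` yields a 30-dimensional latent embedding per observation.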