spateo.external.STAGATE_pyG.Train_STAGATE
Functions
- train_STAGATE: Train a graph attention auto-encoder.
Module Contents
- spateo.external.STAGATE_pyG.Train_STAGATE.train_STAGATE(adata, hidden_dims=[512, 30], n_epochs=1000, lr=0.001, key_added='STAGATE', gradient_clipping=5.0, weight_decay=0.0001, verbose=True, random_seed=0, save_loss=False, save_reconstrction=False, device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))
Train a graph attention auto-encoder (STAGATE) and store the resulting latent embeddings in adata.obsm[key_added].
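As a rough illustration of what hidden_dims controls, the sketch below lists the (input, output) size of each layer. It assumes, without confirmation from this page, that the number of input genes is prepended to hidden_dims to form the encoder path and that the decoder mirrors the encoder back to the input size. The helper stagate_layer_shapes and the gene count of 3000 are hypothetical, not part of spateo.

```python
from typing import List, Tuple


def stagate_layer_shapes(
    n_genes: int, hidden_dims: List[int]
) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
    """Hypothetical helper: (in, out) sizes of encoder and decoder layers.

    Assumption: encoder maps n_genes -> hidden_dims[0] -> ... -> hidden_dims[-1],
    and the decoder mirrors that path back to n_genes.
    """
    dims = [n_genes] + list(hidden_dims)          # full encoder path
    encoder = list(zip(dims[:-1], dims[1:]))      # consecutive (in, out) pairs
    # Mirror the encoder: reverse the layer order and swap in/out sizes.
    decoder = [(out, inp) for inp, out in reversed(encoder)]
    return encoder, decoder


encoder, decoder = stagate_layer_shapes(3000, [512, 30])
print(encoder)  # [(3000, 512), (512, 30)]
print(decoder)  # [(30, 512), (512, 3000)]
```

Under these assumptions, the default hidden_dims=[512, 30] compresses each spot to a 30-dimensional embedding, which is what ends up in adata.obsm[key_added].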
- Parameters:
- adata
AnnData object, as used by the scanpy package.
- hidden_dims
Output dimensions of the encoder layers; the last entry is the dimension of the latent embedding.
- n_epochs
Total number of training epochs.
- lr
Learning rate for the Adam optimizer.
- key_added
Key under which the latent embeddings are saved, i.e. adata.obsm[key_added].
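- gradient_clipping
Maximum gradient norm used for gradient clipping.
- weight_decay
Weight decay (L2 penalty) for the Adam optimizer.
- verbose
If True, print information about the run.
- random_seed
Seed for the random number generators, for reproducible results.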
- save_loss
If True, the training loss is saved in adata.uns['STAGATE_loss'].
- save_reconstrction
If True, the reconstructed expression profiles are saved in adata.layers['STAGATE_ReX'].
- device
The torch.device on which the model is trained; defaults to cuda:0 if CUDA is available, otherwise the CPU.
- Return type:
AnnData
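The lr, weight_decay and gradient_clipping parameters together shape each training step's weight update. The stdlib-only sketch below assumes the common PyTorch pattern, which this page does not confirm: gradients are first rescaled so their global L2 norm is at most gradient_clipping (the coefficient form used by torch.nn.utils.clip_grad_norm_), then an Adam step adds weight_decay * param to the gradient as an L2 penalty, as torch.optim.Adam does. The helpers clip_by_global_norm and adam_step are hypothetical and operate on plain lists of floats; they are not spateo functions.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[float], max_norm: float) -> List[float]:
    """Scale grads so their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Same coefficient form as clip_grad_norm_; 1e-6 avoids division by zero.
    coef = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * coef for g in grads]


def adam_step(params, grads, m, v, t, lr=0.001, weight_decay=0.0001,
              betas=(0.9, 0.999), eps=1e-8):
    """One Adam update (step number t >= 1) with L2 weight decay in the gradient."""
    b1, b2 = betas
    new_params, new_m, new_v = [], [], []
    for p, g, m_i, v_i in zip(params, grads, m, v):
        g = g + weight_decay * p                 # L2 penalty, as in torch.optim.Adam
        m_i = b1 * m_i + (1 - b1) * g            # first-moment (mean) estimate
        v_i = b2 * v_i + (1 - b2) * g * g        # second-moment estimate
        m_hat = m_i / (1 - b1 ** t)              # bias-corrected moments
        v_hat = v_i / (1 - b2 ** t)
        new_params.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(m_i)
        new_v.append(v_i)
    return new_params, new_m, new_v


# Toy example: two weights with a large raw gradient (global norm 50).
params = [1.0, -2.0]
raw_grads = [30.0, 40.0]
grads = clip_by_global_norm(raw_grads, max_norm=5.0)   # gradient_clipping=5.0
params, m, v = adam_step(params, grads, m=[0.0, 0.0], v=[0.0, 0.0], t=1)
print(grads)   # approximately [3.0, 4.0]
print(params)  # approximately [0.999, -2.001]
```

Clipping before the optimizer step keeps one unusually large gradient from producing an oversized update. On Adam's first step the bias-corrected ratio is close to the sign of the gradient, so each weight moves by roughly lr regardless of the gradient's scale.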