spateo.external.STAGATE_pyG.Train_STAGATE

Functions

train_STAGATE(adata[, hidden_dims, n_epochs, lr, ...])

Train the STAGATE graph attention auto-encoder.

Module Contents

spateo.external.STAGATE_pyG.Train_STAGATE.train_STAGATE(adata, hidden_dims=[512, 30], n_epochs=1000, lr=0.001, key_added='STAGATE', gradient_clipping=5.0, weight_decay=0.0001, verbose=True, random_seed=0, save_loss=False, save_reconstrction=False, device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))

Train the STAGATE graph attention auto-encoder and store the learned latent embeddings in adata.obsm[key_added].
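To make the role of hidden_dims concrete, the sketch below derives the auto-encoder's layer shapes. It assumes (this is not stated on this page) that the number of input genes is prepended to hidden_dims and that the decoder mirrors the encoder back to the input dimension. The helper autoencoder_layer_shapes and the gene count 3000 are hypothetical, chosen for illustration only; this is not spateo code.

```python
from typing import List, Tuple


def autoencoder_layer_shapes(
    n_genes: int, hidden_dims: List[int]
) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
    """Return (encoder, decoder) layer shapes as (in_features, out_features) pairs.

    Assumption: the input dimension is prepended to hidden_dims and the
    decoder mirrors the encoder back to the input dimension.
    """
    dims = [n_genes] + list(hidden_dims)  # e.g. [3000, 512, 30]
    encoder = [(dims[i], dims[i + 1]) for i in range(len(dims) - 1)]
    # Walk the encoder backwards and swap in/out sizes to mirror it.
    decoder = [(out_f, in_f) for (in_f, out_f) in reversed(encoder)]
    return encoder, decoder


# Hypothetical input of 3000 genes with the default hidden_dims=[512, 30].
enc, dec = autoencoder_layer_shapes(3000, [512, 30])
print(enc)  # [(3000, 512), (512, 30)]
print(dec)  # [(30, 512), (512, 3000)]
```

Under this assumption, the last entry of hidden_dims (30 by default) is the width of the embedding written to adata.obsm[key_added].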

Parameters:
adata

AnnData object (the data structure used by scanpy).

hidden_dims

Sizes of the encoder layers, [hidden, latent]; the last entry is the dimension of the latent embedding.

n_epochs

Total number of training epochs.

lr

Learning rate for the Adam optimizer.

key_added

The latent embeddings are saved in adata.obsm[key_added].

gradient_clipping

Maximum gradient norm; gradients are clipped to this norm at each training step.

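weight_decay

Weight decay (L2 penalty) for the Adam optimizer.

verbose

If True, print additional information during training.

random_seed

Seed for the random number generators, for reproducible results.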

save_loss

If True, the training loss is saved in adata.uns['STAGATE_loss'].

save_reconstrction

If True, the reconstructed expression profiles are saved in adata.layers['STAGATE_ReX'].

device

The torch.device used for training. Defaults to 'cuda:0' when CUDA is available, otherwise 'cpu'.

Return type:

AnnData
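The lr, weight_decay, and gradient_clipping parameters act on every optimization step. The sketch below is plain Python under two assumptions. First, gradients are clipped by their global L2 norm, following the rule of torch.nn.utils.clip_grad_norm_. Second, the Adam update adds weight_decay to the gradient as an L2 term, as torch.optim.Adam does. The names clip_grad_norm and Adam below are illustrative helpers, not spateo functions, and this is not the library's own code.

```python
import math
from typing import List, Tuple


def clip_grad_norm(grads: List[float], max_norm: float) -> List[float]:
    """Scale gradients so their global L2 norm is at most max_norm.

    Follows the torch.nn.utils.clip_grad_norm_ rule: rescale by
    max_norm / (total_norm + 1e-6) only when that factor is below 1.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        return [g * clip_coef for g in grads]
    return list(grads)


class Adam:
    """Plain-Python Adam with coupled L2 weight decay (as in torch.optim.Adam)."""

    def __init__(self, n_params: int, lr: float = 0.001, weight_decay: float = 0.0001,
                 betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8) -> None:
        self.lr = lr
        self.weight_decay = weight_decay
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.m = [0.0] * n_params  # first-moment (mean) estimates
        self.v = [0.0] * n_params  # second-moment (uncentered variance) estimates
        self.t = 0  # number of steps taken so far

    def step(self, params: List[float], grads: List[float]) -> List[float]:
        self.t += 1
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            g = g + self.weight_decay * p  # weight decay enters as an L2 term
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g * g
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)  # bias correction
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            new_params.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return new_params


# One step with the default lr, weight_decay, and gradient_clipping values.
params = [1.0, -2.0]
raw_grads = [30.0, 40.0]  # global L2 norm is 50
grads = clip_grad_norm(raw_grads, max_norm=5.0)  # rescaled to a norm of about 5
opt = Adam(len(params), lr=0.001, weight_decay=0.0001)
params = opt.step(params, grads)
```

Because of Adam's bias correction, the first update moves each parameter by roughly lr against the sign of its gradient, so the example parameters shift by about 0.001 regardless of the clipped gradient's magnitude.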