Source code for spateo.tools.cluster._stagate

"""Module providing a encapsulation of pySTAGATE."""

import torch
import torch.nn.functional as F
from anndata import AnnData
from scipy.sparse import issparse
from tqdm import tqdm


class pySTAGATE:
    """Wrapper that prepares data for, trains, and applies a STAGATE model.

    The constructor splits ``adata`` into spatial batches, builds the spatial
    neighbor networks, converts everything to PyTorch data objects and creates
    the model and its Adam optimizer. ``train`` fits the model on the batches,
    ``predicted`` writes the learned embedding and the reconstructed expression
    back into ``adata``, and ``cal_pSM`` derives a pseudo-spatial map from the
    embedding with scanpy's diffusion pseudotime.
    """

    def __init__(
        self,
        adata: "AnnData",
        num_batch_x,
        num_batch_y,
        basis="spatial",
        spatial_key: list = None,
        batch_size: int = 1,
        rad_cutoff: int = 200,
        num_epoch: int = 1000,
        lr: float = 0.001,
        weight_decay: float = 1e-4,
        hidden_dims: list = None,
        device: str = "cuda:0",
    ) -> None:
        """Initialize the pySTAGATE object.

        Args:
            adata: an Anndata object, after normalization.
            num_batch_x: Number of batches in the x direction.
            num_batch_y: Number of batches in the y direction.
            basis: Key in ``adata.obsm`` holding the spatial coordinates.
                Default is 'spatial'.
            spatial_key: Names of the two ``adata.obs`` columns handed to
                ``Batch_Data``. ``None`` (the default) means ['X', 'Y'].
            batch_size: The batch size for training. Default is 1.
            rad_cutoff: The radius cutoff for the spatial graph. Default is 200.
            num_epoch: The number of epochs for training. Default is 1000.
            lr: The learning rate for training. Default is 0.001.
            weight_decay: The weight decay for training. Default is 1e-4.
            hidden_dims: The hidden dimensions for the STAGATE model.
                ``None`` (the default) means [512, 30].
            device: The device for training. Default is 'cuda:0'; falls back
                to 'cpu' when CUDA is not available.
        """
        from ...external.STAGATE_pyG import (
            STAGATE,
            Batch_Data,
            Cal_Spatial_Net,
            Stats_Spatial_Net,
            Transfer_pytorch_Data,
        )

        # Resolve the list defaults here instead of using mutable default
        # arguments, and copy so later edits to the attribute cannot leak back
        # into the caller's list.
        spatial_key = ["X", "Y"] if spatial_key is None else list(spatial_key)
        hidden_dims = [512, 30] if hidden_dims is None else list(hidden_dims)

        if issparse(adata.obsm[basis]):
            adata.obsm[basis] = adata.obsm[basis].toarray()
        coords = adata.obsm[basis]
        adata.obs["X"] = coords[:, 0]
        adata.obs["Y"] = coords[:, 1]
        # Batch_Data is given `spatial_key`, so make sure those obs columns
        # exist; columns the caller already provided are left untouched.
        for axis, key in enumerate(spatial_key[:2]):
            if key not in adata.obs.columns:
                adata.obs[key] = coords[:, axis]

        # Use the requested device only when CUDA is usable, otherwise the CPU.
        device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.device = device

        # Create batches and build one spatial network per batch.
        batch_list = Batch_Data(
            adata,
            num_batch_x=num_batch_x,
            num_batch_y=num_batch_y,
            spatial_key=spatial_key,
            plot_Stats=True,
        )
        for batch_adata in batch_list:
            Cal_Spatial_Net(batch_adata, rad_cutoff=rad_cutoff)

        # Transfer to PyTorch data format; keep what `.to` returns rather than
        # relying on it moving the object in place.
        data_list = [Transfer_pytorch_Data(batch_adata).to(device) for batch_adata in batch_list]

        # The full (un-batched) network, used for prediction after training.
        Cal_Spatial_Net(adata, rad_cutoff=rad_cutoff)
        data = Transfer_pytorch_Data(adata)
        Stats_Spatial_Net(adata)

        from torch_geometric.loader import DataLoader

        self.loader = DataLoader(data_list, batch_size=batch_size, shuffle=True)

        # hyper-parameters
        self.num_epoch = num_epoch
        self.lr = lr
        self.weight_decay = weight_decay
        self.hidden_dims = hidden_dims

        self.adata = adata
        self.data = data

        # Model and optimizer; the input width is the feature count of the data.
        self.model = STAGATE(hidden_dims=[data_list[0].x.shape[1]] + self.hidden_dims).to(device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    def train(self):
        """Train the STAGATE model.

        Runs ``num_epoch`` passes over the batch loader, minimizing the mean
        squared error between each batch's input features and the model's
        reconstruction. Gradients are clipped to a maximum norm of 5.0 before
        every optimizer step. Afterwards the full data object is moved to the
        training device so that ``predicted`` can use it.
        """
        self.model.train()
        for _ in tqdm(range(1, self.num_epoch + 1)):
            for batch in self.loader:
                self.optimizer.zero_grad()
                z, out = self.model(batch.x, batch.edge_index)
                loss = F.mse_loss(batch.x, out)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
                self.optimizer.step()

        # The total network
        self.data = self.data.to(self.device)

    def predicted(self):
        """Predict the STAGATE representation and ReX values for all cells.

        Stores the embedding in ``adata.obsm["STAGATE"]`` and the
        reconstruction, with negative values set to 0, in
        ``adata.layers["STAGATE_ReX"]``.
        """
        self.model.eval()
        # Do not depend on train() having already moved the data to the device.
        self.data = self.data.to(self.device)
        # Inference only: skip building the autograd graph for the full data.
        with torch.no_grad():
            z, out = self.model(self.data.x, self.data.edge_index)

        stagate_rep = z.detach().to("cpu").numpy()
        self.adata.obsm["STAGATE"] = stagate_rep

        rex = out.detach().to("cpu").numpy()
        rex[rex < 0] = 0
        self.adata.layers["STAGATE_ReX"] = rex

        print('The STAGATE representation values are stored in adata.obsm["STAGATE"].')
        print('The rex values are stored in adata.layers["STAGATE_ReX"].')

    @staticmethod
    def _farthest_cell_index(embedding, max_cell_for_subsampling):
        """Return the row index of the cell with the largest summed distance to the others.

        When ``embedding`` has more rows than ``max_cell_for_subsampling``, the
        distances are computed on a random subsample of that size (drawn
        without replacement via ``np.random.choice``) and the winner is mapped
        back to its row index in the full ``embedding``.
        """
        import numpy as np
        from scipy.spatial import distance_matrix

        embedding = np.asarray(embedding)
        n_cells = embedding.shape[0]
        if n_cells <= max_cell_for_subsampling:
            selected_ind = np.arange(n_cells)
        else:
            selected_ind = np.random.choice(np.arange(n_cells), max_cell_for_subsampling, False)
        sub_embedding = embedding[selected_ind]
        sum_dists = distance_matrix(sub_embedding, sub_embedding).sum(axis=1)
        # argmax is a position inside the subsample; translate it to a row of
        # the full embedding so it is valid as an index into all cells.
        return int(selected_ind[np.argmax(sum_dists)])

    def cal_pSM(
        self,
        n_neighbors: int = 20,
        resolution: float = 1,
        max_cell_for_subsampling: int = 5000,
        psm_key="pSM_STAGATE",
    ):
        """Calculate the pseudo-spatial map using diffusion pseudotime (DPT) algorithm.

        Requires ``adata.obsm["STAGATE"]`` (written by ``predicted``). Also
        runs UMAP, Leiden clustering and PAGA on ``adata`` along the way.

        Args:
            n_neighbors: Number of neighbors for constructing the kNN graph.
            resolution: Resolution for clustering.
            max_cell_for_subsampling: Maximum number of cells used when picking
                the root cell. If the number of cells is larger than this
                value, the subsampling will be performed.
            psm_key: Name of the ``adata.obs`` column that receives the
                pseudo-spatial map values.

        Returns:
            numpy.ndarray: The pseudo-spatial map values.
        """
        import scanpy as sc

        sc.pp.neighbors(self.adata, n_neighbors=n_neighbors, use_rep="STAGATE")
        sc.tl.umap(self.adata)
        sc.tl.leiden(self.adata, resolution=resolution)
        sc.tl.paga(self.adata)

        # Root the pseudotime at the cell farthest (in summed embedding
        # distance) from the others; the index refers to the full adata.
        self.adata.uns["iroot"] = self._farthest_cell_index(
            self.adata.obsm["STAGATE"], max_cell_for_subsampling
        )
        sc.tl.diffmap(self.adata)
        sc.tl.dpt(self.adata)

        # Drop a stale result from an earlier call; otherwise the rename below
        # would leave two columns with the same name.
        if psm_key != "dpt_pseudotime" and psm_key in self.adata.obs.columns:
            self.adata.obs.drop(columns=[psm_key], inplace=True)
        self.adata.obs.rename({"dpt_pseudotime": psm_key}, axis=1, inplace=True)
        print(f'The pseudo-spatial map values are stored in adata.obs["{psm_key}"].')

        psm_values = self.adata.obs[psm_key].to_numpy()
        return psm_values
# End-of-file (EOF)