def train_STAGATE(
    adata,
    hidden_dims=[512, 30],
    n_epochs=1000,
    lr=0.001,
    key_added="STAGATE",
    gradient_clipping=5.0,
    weight_decay=0.0001,
    verbose=True,
    random_seed=0,
    save_loss=False,
    save_reconstrction=False,
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
):
    """\
    Training graph attention auto-encoder.

    Parameters
    ----------
    adata
        AnnData object of scanpy package. ``adata.uns`` must contain the key
        ``'Spatial_Net'``. If ``adata.var`` has a ``'highly_variable'``
        column, only the variables flagged there are used as model input.
    hidden_dims
        The dimension of the encoder. The input feature dimension is
        prepended automatically. Any sequence of ints is accepted; the
        argument is copied and never modified.
    n_epochs
        Number of total epochs in training.
    lr
        Learning rate for AdamOptimizer.
    key_added
        The latent embeddings are saved in adata.obsm[key_added].
    gradient_clipping
        Gradient Clipping (maximum gradient norm applied before each
        optimizer step).
    weight_decay
        Weight decay for AdamOptimizer.
    verbose
        If True, the shape of the model input is printed.
    random_seed
        Seed applied to ``random``, ``numpy.random`` and ``torch`` (CPU and
        all CUDA devices) before training.
    save_loss
        If True, the loss of the last training epoch is saved in
        adata.uns['STAGATE_loss'] as a detached CPU tensor. Nothing is saved
        when no epoch was run (``n_epochs < 1``).
    save_reconstrction
        If True, the reconstructed expression profiles are saved in
        adata.layers['STAGATE_ReX'], with negative values set to 0.
    device
        See torch.device.

    Returns
    -------
    AnnData
        The same ``adata`` object, updated in place. Note that ``adata.X``
        is replaced by its ``scipy.sparse.csr_matrix`` form.

    Raises
    ------
    ValueError
        If ``'Spatial_Net'`` is missing from ``adata.uns``.
    """
    import random

    # Validate first so that a failing call neither rewrites adata.X nor
    # reseeds the global random number generators.
    if "Spatial_Net" not in adata.uns.keys():
        raise ValueError("Spatial_Net is not existed! Run Cal_Spatial_Net first!")

    # Seed every RNG in use so repeated runs are reproducible.
    seed = random_seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

    adata.X = sp.csr_matrix(adata.X)

    if "highly_variable" in adata.var.columns:
        adata_Vars = adata[:, adata.var["highly_variable"]]
    else:
        adata_Vars = adata

    if verbose:
        print("Size of Input: ", adata_Vars.shape)

    data = Transfer_pytorch_Data(adata_Vars)

    # list(...) accepts tuples as well and guarantees a fresh list, so the
    # (mutable) default argument is never aliased.
    layer_dims = [data.x.shape[1]] + list(hidden_dims)
    model = STAGATE(hidden_dims=layer_dims).to(device)
    data = data.to(device)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )

    # Stays None when n_epochs < 1, in which case there is no loss to save.
    loss = None
    for _ in tqdm(range(1, n_epochs + 1)):
        model.train()
        optimizer.zero_grad()
        z, out = model(data.x, data.edge_index)
        loss = F.mse_loss(data.x, out)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        optimizer.step()

    # Final inference pass: gradients are not needed, so skip building the
    # autograd graph.
    model.eval()
    with torch.no_grad():
        z, out = model(data.x, data.edge_index)

    adata.obsm[key_added] = z.to("cpu").detach().numpy()

    if save_loss and loss is not None:
        # Detach and move to CPU so the stored value keeps neither the
        # autograd graph nor device memory alive.
        adata.uns["STAGATE_loss"] = loss.detach().cpu()

    if save_reconstrction:
        ReX = out.to("cpu").detach().numpy()
        ReX[ReX < 0] = 0
        adata.layers["STAGATE_ReX"] = ReX

    return adata