[docs]defdeep_intepretation(source_adata:AnnData,target_points:Optional[ndarray]=None,keys:Union[str,list]=None,spatial_key:str="spatial",layer:str="X",max_iter:int=1000,data_batch_size:int=2000,autoencoder_batch_size:int=50,data_lr:float=1e-4,autoencoder_lr:float=1e-4,**kwargs,)->AnnData:"""Learn a continuous mapping from space to gene expression pattern with the deep neural net model. Args: source_adata: AnnData object that contains spatial (numpy.ndarray) in the `obsm` attribute. target_points: The spatial coordinates of new data point. If target_coords is None, generate new points based on grid_num. keys: Gene list or info list in the `obs` attribute whose interpolate expression across space needs to learned. spatial_key: The key in ``.obsm`` that corresponds to the spatial coordinate of each bucket. layer: If ``'X'``, uses ``.X``, otherwise uses the representation given by ``.layers[layer]``. max_iter: The maximum iteration the network will be trained. data_batch_size: The size of the data sample batches to be generated in each iteration. autoencoder_batch_size: The size of the auto-encoder training batches to be generated in each iteration. Must be no greater than batch_size. . data_lr: The learning rate for network training. autoencoder_lr: The learning rate for network training the auto-encoder. Will have no effect if network_dim equal data_dim. **kwargs: Additional parameters that will be passed to the training step of the deep neural net. Returns: interp_adata: an anndata object that has interpolated expression. """# Inferencesource_adata=source_adata.copy()source_adata.X=source_adata.Xiflayer=="X"elsesource_adata.layers[layer]source_spatial_data=source_adata.obsm[spatial_key]info_data=np.ones(shape=(source_spatial_data.shape[0],1))assertkeys!=None,"`keys` cannot be None."keys=[keys]ifisinstance(keys,str)elsekeysobs_keys=[keyforkeyinkeysifkeyinsource_adata.obs.keys()]iflen(obs_keys)!=0:obs_data=np.asarray(source_adata.obs[obs_keys].values)info_data=np.c_[info_data,obs_data]var_keys=[keyforkeyinkeysifkeyinsource_adata.var_names.tolist()]iflen(var_keys)!=0:var_data=source_adata[:,var_keys].Xifissparse(var_data):var_data=var_data.toarray()info_data=np.c_[info_data,var_data]info_data=info_data[:,1:]data_dict={"X":source_spatial_data,"Y":info_data}velocity_data_sampler=DataSampler(data=data_dict,normalize_data=False)NN_model=DeepInterpolation(model=interpolation_nn,data_sampler=velocity_data_sampler,enforce_positivity=False,)NN_model.train(max_iter=max_iter,data_batch_size=data_batch_size,autoencoder_batch_size=autoencoder_batch_size,data_lr=data_lr,autoencoder_lr=autoencoder_lr,**kwargs,)# Interpolationtarget_info_data=NN_model.predict(input_x=target_points)lm.main_info("Creating an adata object with the interpolated expression...")iflen(obs_keys)!=0:obs_data=target_info_data[:,:len(obs_keys)]obs_data=pd.DataFrame(obs_data,columns=obs_keys)iflen(var_keys)!=0:X=target_info_data[:,len(obs_keys):]var_data=pd.DataFrame(index=var_keys)interp_adata=AnnData(X=Xiflen(var_keys)!=0elseNone,obs=obs_dataiflen(obs_keys)!=0elseNone,obsm={spatial_key:np.asarray(target_points)},var=var_dataiflen(var_keys)!=0elseNone,)lm.main_finish_progress(progress_name="DeepLearnInterpolation")returninterp_adata