Source code for spateo.tdr.interpolations.interpolation_dl

from typing import Optional, Union

import numpy as np
import pandas as pd
from anndata import AnnData
from numpy import ndarray
from scipy.sparse import issparse

from ...logging import logger_manager as lm
from .interpolation_deeplearn import DataSampler, DeepInterpolation, interpolation_nn


def deep_intepretation(
    source_adata: AnnData,
    target_points: Optional[ndarray] = None,
    keys: Union[str, list] = None,
    spatial_key: str = "spatial",
    layer: str = "X",
    max_iter: int = 1000,
    data_batch_size: int = 2000,
    autoencoder_batch_size: int = 50,
    data_lr: float = 1e-4,
    autoencoder_lr: float = 1e-4,
    **kwargs,
) -> AnnData:
    """Learn a continuous mapping from space to gene expression pattern with a deep neural net model.

    A ``DeepInterpolation`` network is trained to map the coordinates in ``source_adata.obsm[spatial_key]`` to the
    values of ``keys`` (columns of ``.obs`` and/or genes in ``.var_names``), and is then evaluated at
    ``target_points``.

    NOTE: the misspelled function name is kept as-is for backward compatibility with existing callers.

    Args:
        source_adata: AnnData object that contains spatial coordinates (numpy.ndarray) in the ``obsm`` attribute.
        target_points: The spatial coordinates of the new data points at which values are interpolated, one row per
            point, with the same number of columns as ``source_adata.obsm[spatial_key]``. Required, despite the
            ``None`` default (kept only for signature compatibility).
        keys: Gene list or info list in the ``obs`` attribute whose expression across space needs to be learned.
            Keys found in neither ``.obs`` nor ``.var_names`` are ignored with a warning.
        spatial_key: The key in ``.obsm`` that corresponds to the spatial coordinate of each bucket.
        layer: If ``'X'``, uses ``.X``, otherwise uses the representation given by ``.layers[layer]`` for genes.
        max_iter: The maximum iteration the network will be trained.
        data_batch_size: The size of the data sample batches to be generated in each iteration.
        autoencoder_batch_size: The size of the auto-encoder training batches to be generated in each iteration.
            Must be no greater than ``data_batch_size``.
        data_lr: The learning rate for network training.
        autoencoder_lr: The learning rate for training the auto-encoder. Per the original notes it has no effect
            when the network dimension equals the data dimension (decided inside ``DeepInterpolation``).
        **kwargs: Additional parameters that will be passed to the training step of the deep neural net.

    Returns:
        interp_adata: an AnnData object with one observation per target point. Interpolated ``.obs`` keys are stored
        in ``.obs``, interpolated genes in ``.X`` (with ``.var_names`` set to those genes), and the target
        coordinates in ``.obsm[spatial_key]``.

    Raises:
        ValueError: if ``keys`` or ``target_points`` is None, if ``target_points`` does not have the same number of
            columns as the source coordinates, or if none of ``keys`` is found in ``source_adata``.
    """
    # Validate inputs before any expensive work (copying data, training the network).
    if keys is None:
        raise ValueError("`keys` cannot be None.")
    if target_points is None:
        raise ValueError("`target_points` cannot be None.")
    keys = [keys] if isinstance(keys, str) else list(keys)

    # Inference
    source_spatial_data = np.asarray(source_adata.obsm[spatial_key])
    target_points = np.asarray(target_points)
    if target_points.ndim != 2 or target_points.shape[1] != source_spatial_data.shape[1]:
        raise ValueError(
            f"`target_points` must be a 2-D array with {source_spatial_data.shape[1]} columns to match "
            f"`source_adata.obsm['{spatial_key}']`, got shape {target_points.shape}."
        )

    # Split the requested keys into `.obs` columns and genes; pandas Index membership is hash-based.
    obs_keys = [key for key in keys if key in source_adata.obs.columns]
    var_keys = [key for key in keys if key in source_adata.var_names]
    found_keys = set(obs_keys) | set(var_keys)
    missing_keys = [key for key in keys if key not in found_keys]
    if missing_keys:
        lm.main_warning(f"The following keys were not found in `.obs` or `.var_names` and are ignored: {missing_keys}")
    if not found_keys:
        raise ValueError("None of `keys` were found in `source_adata.obs` or `source_adata.var_names`.")

    # Assemble the regression targets: `.obs` columns first, then genes (the output is split in the same order).
    info_blocks = []
    if obs_keys:
        info_blocks.append(np.asarray(source_adata.obs[obs_keys].values))
    if var_keys:
        # Slice only the requested genes (a view), instead of copying the whole AnnData object.
        gene_view = source_adata[:, var_keys]
        var_data = gene_view.X if layer == "X" else gene_view.layers[layer]
        if issparse(var_data):
            var_data = var_data.toarray()
        info_blocks.append(np.asarray(var_data))
    # float64 matches the dtype the previous implementation obtained by stacking onto a float ones column.
    info_data = np.column_stack(info_blocks).astype(np.float64)

    data_sampler = DataSampler(data={"X": source_spatial_data, "Y": info_data}, normalize_data=False)
    NN_model = DeepInterpolation(
        model=interpolation_nn,
        data_sampler=data_sampler,
        enforce_positivity=False,
    )
    NN_model.train(
        max_iter=max_iter,
        data_batch_size=data_batch_size,
        autoencoder_batch_size=autoencoder_batch_size,
        data_lr=data_lr,
        autoencoder_lr=autoencoder_lr,
        **kwargs,
    )

    # Interpolation
    target_info_data = NN_model.predict(input_x=target_points)

    lm.main_info("Creating an adata object with the interpolated expression...")
    n_obs_keys = len(obs_keys)
    obs_data = None
    if obs_keys:
        # Explicit string obs names avoid AnnData's implicit int -> str index conversion.
        obs_names = np.arange(target_points.shape[0]).astype(str)
        obs_data = pd.DataFrame(target_info_data[:, :n_obs_keys], index=obs_names, columns=obs_keys)
    X = target_info_data[:, n_obs_keys:] if var_keys else None
    var_data = pd.DataFrame(index=var_keys) if var_keys else None

    interp_adata = AnnData(
        X=X,
        obs=obs_data,
        obsm={spatial_key: target_points},
        var=var_data,
    )

    lm.main_finish_progress(progress_name="DeepLearnInterpolation")
    return interp_adata