Source code for spateo.alignment.paste_alignment

from typing import List, Optional, Tuple, Union

import numpy as np
from anndata import AnnData

from spateo.configuration import SKM

from .methods import generalized_procrustes_analysis, paste_pairwise_align
from .transform import paste_transform
from .utils import _iteration, downsampling


@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "models")
def paste_align(
    models: List[AnnData],
    layer: str = "X",
    genes: Optional[Union[list, np.ndarray]] = None,
    spatial_key: str = "spatial",
    key_added: str = "align_spatial",
    mapping_key_added: str = "models_align",
    alpha: float = 0.1,
    numItermax: int = 200,
    numItermaxEmd: int = 100000,
    dtype: str = "float64",
    device: str = "cpu",
    verbose: bool = True,
    **kwargs,
) -> Tuple[List[AnnData], List[Union[np.ndarray, np.ndarray]]]:
    """
    Sequentially align the spatial coordinates of a list of models.

    Consecutive models ``(i, i + 1)`` are aligned pairwise with
    ``paste_pairwise_align``; the resulting ``pi`` matrix is then passed to
    ``generalized_procrustes_analysis`` to obtain the new coordinates.

    Note:
        The input ``models`` are modified in place: ``.obsm[key_added]`` of
        every input model is set to its ``.obsm[spatial_key]``. All further
        changes are applied to copies, which are the ones returned.

    Args:
        models: List of models (AnnData Object).
        layer: If ``'X'``, uses ``.X`` to calculate dissimilarity between spots, otherwise uses the representation
            given by ``.layers[layer]``.
        genes: Genes used for calculation. If None, use all common genes for calculation.
        spatial_key: The key in ``.obsm`` that corresponds to the raw spatial coordinates.
        key_added: ``.obsm`` key under which to add the aligned spatial coordinates.
        mapping_key_added: ``.uns`` key under which to add the alignment info.
        alpha: Alignment tuning parameter. Note: 0 <= alpha <= 1.

            When ``alpha = 0`` only the gene expression data is taken into account,
            while when ``alpha = 1`` only the spatial coordinates are taken into account.
        numItermax: Max number of iterations for cg during FGW-OT.
        numItermaxEmd: Max number of iterations for emd during FGW-OT.
        dtype: The floating-point number type. Only ``float32`` and ``float64``.
        device: Equipment used to run the program. A specific GPU can also be selected, e.g. ``'0'``.
        verbose: If ``True``, print progress updates.
        **kwargs: Additional parameters that will be passed to ``paste_pairwise_align``.

    Returns:
        align_models: List of models (AnnData Object) after alignment.
        pis: List of pi matrices, one per consecutive pair of models.
    """
    # Seed the output key on the caller's objects with the raw coordinates.
    for adata in models:
        adata.obsm[key_added] = adata.obsm[spatial_key]

    # Everything below operates on copies so only `key_added` leaks back to the inputs.
    align_models = [adata.copy() for adata in models]
    pis = []

    n_pairs = len(align_models) - 1
    for i in _iteration(n=n_pairs, progress_name="Models alignment", verbose=verbose):
        modelA, modelB = align_models[i], align_models[i + 1]

        # Optimal alignment between the two neighbouring models; computed on
        # throw-away copies so the pairwise step cannot touch the working models.
        pi, _unused = paste_pairwise_align(
            sampleA=modelA.copy(),
            sampleB=modelB.copy(),
            layer=layer,
            genes=genes,
            spatial_key=key_added,
            alpha=alpha,
            numItermax=numItermax,
            numItermaxEmd=numItermaxEmd,
            dtype=dtype,
            device=device,
            verbose=verbose,
            **kwargs,
        )
        pis.append(pi)

        # New coordinates for both models derived from the alignment.
        coords_a, coords_b, mapping = generalized_procrustes_analysis(
            X=modelA.obsm[key_added],
            Y=modelB.obsm[key_added],
            pi=pi,
        )

        # Only the very first model takes its coordinates/mapping from the "A"
        # side; every later model already received them as "B" of the previous pair.
        if i == 0:
            modelA.obsm[key_added] = coords_a
            modelA.uns[mapping_key_added] = mapping
        modelB.obsm[key_added] = coords_b
        modelB.uns[mapping_key_added] = mapping

    return align_models, pis
@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "models")
@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "models_ref", optional=True)
def paste_align_ref(
    models: List[AnnData],
    models_ref: Optional[List[AnnData]] = None,
    n_sampling: Optional[int] = 2000,
    sampling_method: str = "trn",
    layer: str = "X",
    genes: Optional[Union[list, np.ndarray]] = None,
    spatial_key: str = "spatial",
    key_added: str = "align_spatial",
    mapping_key_added: str = "models_align",
    alpha: float = 0.1,
    numItermax: int = 200,
    numItermaxEmd: int = 100000,
    dtype: str = "float64",
    device: str = "cpu",
    verbose: bool = True,
    **kwargs,
) -> Tuple[List[AnnData], List[AnnData], List[Union[np.ndarray, np.ndarray]]]:
    """
    Align the spatial coordinates of one model list through the affine transformation
    obtained from aligning another (reference) model list.

    The reference models are aligned with ``paste_align``; the mapping stored on each
    aligned reference is then applied to the corresponding model in ``models``.

    Note:
        The input ``models`` are modified in place: ``.obsm[key_added]`` of every input
        model is set to its ``.obsm[spatial_key]``. The returned models are copies.

    Args:
        models: List of models (AnnData Object).
        models_ref: Another list of models (AnnData Object) used as references.
            If None, references are generated by downsampling copies of ``models``.
        n_sampling: When ``models_ref`` is None, new data containing ``n_sampling`` coordinate points
            will be automatically generated for alignment.
        sampling_method: The method to sample data points, can be one of ``["trn", "kmeans", "random"]``.
        layer: If ``'X'``, uses ``.X`` to calculate dissimilarity between spots, otherwise uses the representation
            given by ``.layers[layer]``.
        genes: Genes used for calculation. If None, use all common genes for calculation.
        spatial_key: The key in ``.obsm`` that corresponds to the raw spatial coordinates.
        key_added: ``.obsm`` key under which to add the aligned spatial coordinates.
        mapping_key_added: ``.uns`` key under which to add the alignment info.
        alpha: Alignment tuning parameter. Note: 0 <= alpha <= 1.

            When ``alpha = 0`` only the gene expression data is taken into account,
            while when ``alpha = 1`` only the spatial coordinates are taken into account.
        numItermax: Max number of iterations for cg during FGW-OT.
        numItermaxEmd: Max number of iterations for emd during FGW-OT.
        dtype: The floating-point number type. Only ``float32`` and ``float64``.
        device: Equipment used to run the program. A specific GPU can also be selected, e.g. ``'0'``.
        verbose: If ``True``, print progress updates.
        **kwargs: Additional parameters that will be passed to ``paste_align``.

    Returns:
        align_models: List of models (AnnData Object) after alignment.
        align_models_ref: List of models_ref (AnnData Object) after alignment.
        pis: The list of pi matrices from align_models_ref.
    """
    # Seed the output key on the caller's objects with the raw coordinates.
    for adata in models:
        adata.obsm[key_added] = adata.obsm[spatial_key]

    # No references supplied: build them by downsampling copies of the models.
    if models_ref is None:
        models_ref = downsampling(
            models=[adata.copy() for adata in models],
            n_sampling=n_sampling,
            sampling_method=sampling_method,
            spatial_key=spatial_key,
        )

    # Align the (typically much smaller) reference models first.
    align_models_ref, pis = paste_align(
        models=models_ref,
        layer=layer,
        genes=genes,
        spatial_key=spatial_key,
        key_added=key_added,
        mapping_key_added=mapping_key_added,
        alpha=alpha,
        numItermax=numItermax,
        numItermaxEmd=numItermaxEmd,
        dtype=dtype,
        device=device,
        verbose=verbose,
        **kwargs,
    )

    # Carry each reference's mapping over to the matching full model.
    align_models = []
    for position, (ref, original) in enumerate(zip(align_models_ref, models)):
        moved = original.copy()
        if position != 0:
            moved = paste_transform(
                adata=moved,
                adata_ref=ref,
                spatial_key=spatial_key,
                key_added=key_added,
                mapping_key=mapping_key_added,
            )
        else:
            # The first model is only shifted by the "tX" entry of its reference's mapping.
            shift = ref.uns[mapping_key_added]["tX"]
            moved.obsm[key_added] = moved.obsm[spatial_key] - shift
        moved.uns[mapping_key_added] = ref.uns[mapping_key_added]
        align_models.append(moved)

    return align_models, align_models_ref, pis