# Source code for spateo.alignment.methods.deprecated_morpho

import random

import numpy as np
import ot
import torch
from anndata import AnnData

try:
    from typing import Any, Dict, List, Literal, Optional, Tuple, Union
except ImportError:
    from typing_extensions import Literal

from typing import List, Optional, Tuple, Union

from spateo.logging import logger_manager as lm

from .utils import (
    _chunk,
    _data,
    _dot,
    _identity,
    _init_guess_beta2,
    _init_guess_sigma2,
    _linalg,
    _mul,
    _pi,
    _pinv,
    _power,
    _prod,
    _psi,
    _randperm,
    _roll,
    _unique,
    _unsqueeze,
    align_preprocess,
    cal_dist,
    calc_exp_dissimilarity,
    coarse_rigid_alignment,
    empty_cache,
    get_optimal_R,
    guidance_pair_preprocess,
)


def con_K(
    X: Union[np.ndarray, torch.Tensor],
    Y: Union[np.ndarray, torch.Tensor],
    beta: Union[int, float] = 0.01,
    use_chunk: bool = False,
) -> Union[np.ndarray, torch.Tensor]:
    r"""Construct the Squared Exponential (SE) kernel matrix between two point sets.

    The kernel is :math:`K_{ij} = k(X_i, Y_j) = \exp(-\beta \lVert X_i - Y_j \rVert^2)`.

    Args:
        X: The first set of points, :math:`X \in \mathbb{R}^{N \times d}`.
        Y: The second set of points, :math:`Y \in \mathbb{R}^{M \times d}`.
        beta: The length-scale parameter of the SE kernel.
        use_chunk: Reserved for computing the kernel in chunks to reduce GPU memory
            usage. NOTE: currently ignored by this implementation.

    Returns:
        The kernel matrix :math:`K \in \mathbb{R}^{N \times M}`, computed with the
        array backend (NumPy or PyTorch) of the inputs.

    Raises:
        AssertionError: If ``X`` and ``Y`` do not have the same number of columns.
    """
    assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."
    nx = ot.backend.get_backend(X, Y)
    # ``cal_dist`` is presumably the pairwise squared Euclidean distance (N x M);
    # exponentiating it below yields the SE kernel described above.
    K = cal_dist(X, Y)
    K = nx.exp(-beta * K)
    return K
############
# BioAlign #
############
def get_P(
    XnAHat: Union[np.ndarray, torch.Tensor],
    XnB: Union[np.ndarray, torch.Tensor],
    sigma2: Union[int, float, np.ndarray, torch.Tensor],
    beta2: Union[int, float, np.ndarray, torch.Tensor],
    alpha: Union[np.ndarray, torch.Tensor],
    gamma: Union[float, np.ndarray, torch.Tensor],
    Sigma: Union[np.ndarray, torch.Tensor],
    GeneDistMat: Union[np.ndarray, torch.Tensor],
    SpatialDistMat: Union[np.ndarray, torch.Tensor],
    samples_s: Optional[List[float]] = None,
    outlier_variance: Optional[float] = None,
) -> Tuple[Any, Any, Any]:
    """Calculate the generating probability matrices between sample A and sample B.

    Args:
        XnAHat: Current spatial coordinates of sample A. Shape: N x D.
        XnB: Spatial coordinates of sample B (reference sample). Shape: M x D.
        sigma2: The spatial coordinate noise.
        beta2: The gene expression noise.
        alpha: Per-spot weights of sample A, i.e. the probability that each spot of A
            generates an observation. Must be a length-N vector (it is combined with
            the rows of the N x M matrices via ``einsum("ij,i->ij")``).
        gamma: Inlier proportion of sample A.
        Sigma: Posterior variance of the Gaussian process for each spot of A. Must be
            a length-N vector, since it is multiplied elementwise with ``alpha``.
        GeneDistMat: Gene expression distance matrix between sample A and sample B.
            Shape: N x M.
        SpatialDistMat: Spatial distance matrix between sample A and sample B.
            Shape: N x M.
        samples_s: The space size of the samples (area for 2D, volume for 3D). If
            None, the larger bounding-box size of ``XnAHat`` and ``XnB`` is used.
        outlier_variance: If given, ``sigma2`` is divided by this factor in the
            spatial kernel used for the outlier / inlier terms.

    Returns:
        A tuple ``(P, spatial_P, sigma2_P)``, each of shape N x M:
            P: Assignment probabilities using both spatial and gene expression
                similarity, weighted by each B spot's spatial inlier probability.
            spatial_P: Spatial-only posterior, normalized per column together with
                the uniform outlier component.
            sigma2_P: Spatial-only assignment probabilities, weighted by each
                B spot's spatial inlier probability.
    """
    assert XnAHat.shape[1] == XnB.shape[1], "XnAHat and XnB do not have the same number of features."
    assert XnAHat.shape[0] == alpha.shape[0], "XnAHat and alpha do not have the same length."
    assert XnAHat.shape[0] == Sigma.shape[0], "XnAHat and Sigma do not have the same length."

    nx = ot.backend.get_backend(XnAHat, XnB)
    NA, D = XnAHat.shape[0], XnAHat.shape[1]
    if samples_s is None:
        samples_s = nx.maximum(
            _prod(nx)(nx.max(XnAHat, axis=0) - nx.min(XnAHat, axis=0)),
            _prod(nx)(nx.max(XnB, axis=0) - nx.min(XnB, axis=0)),
        )
    outlier_s = samples_s * NA

    # Terms shared by all three output matrices; computed once instead of per matrix.
    exp_spatial = nx.exp(-SpatialDistMat / (2 * sigma2))
    spot_weight = _mul(nx)(alpha, nx.exp(-Sigma / sigma2))

    if outlier_variance is None:
        exp_SpatialMat = exp_spatial
    else:
        exp_SpatialMat = nx.exp(-SpatialDistMat / (2 * sigma2 / outlier_variance))

    # Spatial-only posterior with a uniform outlier component in the denominator.
    spatial_term1 = nx.einsum("ij,i->ij", exp_SpatialMat, spot_weight)
    spatial_outlier = _power(nx)((2 * _pi(nx) * sigma2), _data(nx, D / 2, XnAHat)) * (1 - gamma) / (gamma * outlier_s)
    spatial_term2 = spatial_outlier + nx.einsum("ij->j", spatial_term1)
    spatial_P = spatial_term1 / _unsqueeze(nx)(spatial_term2, 0)
    # Probability (per B spot) of being a spatial inlier.
    spatial_inlier = 1 - spatial_outlier / (spatial_outlier + nx.einsum("ij->j", exp_SpatialMat))

    # Joint spatial + expression assignment; 1e-8 guards against empty columns.
    term1 = nx.einsum(
        "ij,i->ij",
        _mul(nx)(exp_spatial, nx.exp(-GeneDistMat / (2 * beta2))),
        spot_weight,
    )
    P = term1 / (_unsqueeze(nx)(nx.einsum("ij->j", term1), 0) + 1e-8)
    P = nx.einsum("j,ij->ij", spatial_inlier, P)

    # Spatial-only assignment, used for updating sigma2.
    term1 = nx.einsum("ij,i->ij", exp_spatial, spot_weight)
    sigma2_P = term1 / (_unsqueeze(nx)(nx.einsum("ij->j", term1), 0) + 1e-8)
    sigma2_P = nx.einsum("j,ij->ij", spatial_inlier, sigma2_P)
    return P, spatial_P, sigma2_P
def get_P_chunk(
    XnAHat: Union[np.ndarray, torch.Tensor],
    XnB: Union[np.ndarray, torch.Tensor],
    X_A: Union[np.ndarray, torch.Tensor],
    X_B: Union[np.ndarray, torch.Tensor],
    sigma2: Union[int, float, np.ndarray, torch.Tensor],
    beta2: Union[int, float, np.ndarray, torch.Tensor],
    alpha: Union[np.ndarray, torch.Tensor],
    gamma: Union[float, np.ndarray, torch.Tensor],
    Sigma: Union[np.ndarray, torch.Tensor],
    samples_s: Optional[List[float]] = None,
    outlier_variance: Optional[float] = None,
    chunk_size: int = 1000,
    dissimilarity: str = "kl",
) -> Union[np.ndarray, torch.Tensor]:
    """Calculate the generating probability matrix P, processing sample B in chunks.

    Uses the same formula as the first output of ``get_P``. The difference is that the
    spatial and gene expression distance matrices are built here, one block of
    sample B's rows at a time, which bounds the size of the intermediate matrices.

    Args:
        XnAHat: Current spatial coordinates of sample A. Shape: NA x D.
        XnB: Spatial coordinates of sample B (reference sample). Shape: NB x D.
        X_A: Expression representation of sample A (rows aligned with ``XnAHat``).
        X_B: Expression representation of sample B (rows aligned with ``XnB``).
        sigma2: The spatial coordinate noise.
        beta2: The gene expression noise.
        alpha: Per-spot weights of sample A. Length-NA vector.
        gamma: Inlier proportion of sample A.
        Sigma: Posterior variance of the Gaussian process per spot of A. Length-NA vector.
        samples_s: The space size of the samples (area for 2D, volume for 3D). If
            None, the larger bounding-box size of ``XnAHat`` and ``XnB`` is used.
        outlier_variance: If given, ``sigma2`` is divided by this factor in the
            spatial kernel used for the inlier probability.
        chunk_size: Approximate number of sample-B spots processed per chunk.
        dissimilarity: Expression dissimilarity passed to ``calc_exp_dissimilarity``.

    Returns:
        P: Generating probability matrix. Shape: NA x NB.
    """
    assert XnAHat.shape[1] == XnB.shape[1], "XnAHat and XnB do not have the same number of features."
    assert XnAHat.shape[0] == alpha.shape[0], "XnAHat and alpha do not have the same length."
    assert XnAHat.shape[0] == Sigma.shape[0], "XnAHat and Sigma do not have the same length."

    nx = ot.backend.get_backend(XnAHat, XnB)
    NA, NB, D = XnAHat.shape[0], XnB.shape[0], XnAHat.shape[1]

    # Chunks are taken along the rows of sample B (and concatenated along the columns
    # of P), so the number of chunks must be derived from NB, not NA.
    chunk_num = max(1, int(np.ceil(NB / chunk_size)))

    if samples_s is None:
        samples_s = nx.maximum(
            _prod(nx)(nx.max(XnAHat, axis=0) - nx.min(XnAHat, axis=0)),
            _prod(nx)(nx.max(XnB, axis=0) - nx.min(XnB, axis=0)),
        )
    outlier_s = samples_s * NA

    # Quantities that do not depend on the current chunk of B are computed once.
    spot_weight = _mul(nx)(alpha, nx.exp(-Sigma / sigma2))
    spatial_outlier = (
        _power(nx)((2 * _pi(nx) * sigma2), _data(nx, D / 2, XnAHat)) * (1 - gamma) / (gamma * outlier_s)
    )

    Ps = []
    for x_B, xn_B in zip(_chunk(nx, X_B, chunk_num, dim=0), _chunk(nx, XnB, chunk_num, dim=0)):
        SpatialDistMat = cal_dist(XnAHat, xn_B)
        GeneDistMat = calc_exp_dissimilarity(X_A=X_A, X_B=x_B, dissimilarity=dissimilarity)
        exp_spatial = nx.exp(-SpatialDistMat / (2 * sigma2))
        if outlier_variance is None:
            exp_SpatialMat = exp_spatial
        else:
            exp_SpatialMat = nx.exp(-SpatialDistMat / (2 * sigma2 / outlier_variance))
        # Probability (per B spot in this chunk) of being a spatial inlier.
        spatial_inlier = 1 - spatial_outlier / (spatial_outlier + nx.einsum("ij->j", exp_SpatialMat))
        term1 = nx.einsum(
            "ij,i->ij",
            _mul(nx)(exp_spatial, nx.exp(-GeneDistMat / (2 * beta2))),
            spot_weight,
        )
        # 1e-8 guards against columns with no support.
        P = term1 / (_unsqueeze(nx)(nx.einsum("ij->j", term1), 0) + 1e-8)
        Ps.append(nx.einsum("j,ij->ij", spatial_inlier, P))
    return nx.concatenate(Ps, axis=1)
# TODO: keep size
# TODO: genes could incorporate spatially variable genes (SVGs) stored in .var
def BA_align(
    sampleA: AnnData,
    sampleB: AnnData,
    rep_layer: Union[str, List[str]] = "X",
    rep_field: Union[str, List[str]] = "layer",
    genes: Optional[Union[List[str], torch.Tensor]] = None,
    spatial_key: str = "spatial",
    key_added: str = "align_spatial",
    iter_key_added: Optional[str] = None,
    save_concrete_iter: bool = False,
    vecfld_key_added: Optional[str] = None,
    dissimilarity: Union[str, List[str]] = "kl",
    probability_type: Union[str, List[str]] = "gauss",
    probability_parameters: Optional[Union[float, List[float]]] = None,
    label_transfer_dict: Optional[Union[dict, List[dict]]] = None,
    nn_init: bool = True,
    allow_flip: bool = False,
    init_layer: str = "X",
    init_field: str = "layer",
    max_iter: int = 200,
    SVI_mode: bool = True,
    batch_size: int = 1000,
    pre_compute_dist: bool = True,
    sparse_calculation_mode: bool = False,
    lambdaVF: Union[int, float] = 1e2,
    beta: Union[int, float] = 0.01,
    K: Union[int, float] = 15,
    sigma2_init_scale: Optional[Union[int, float]] = 0.1,
    partial_robust_level: float = 25,
    normalize_c: bool = True,
    normalize_g: bool = True,
    dtype: str = "float32",
    device: str = "cpu",
    # inplace: bool = True,
    verbose: bool = True,
    guidance_pair: Optional[Union[List[np.ndarray], np.ndarray]] = None,
    guidance_effect: Optional[Union[bool, str]] = False,
    guidance_epsilon: float = 1,
) -> Tuple[Tuple[AnnData, AnnData], np.ndarray]:
    """Align two spatial transcriptomics AnnData objects using the Spateo alignment algorithm.

    NOTE(review): this deprecated implementation cannot run to completion as written,
    because it reads several names that are never assigned. See the review notes at the
    top of the function body before relying on it.

    Args:
        sampleA (AnnData): The first AnnData object, which acts as the reference.
        sampleB (AnnData): The second AnnData object, which is aligned to ``sampleA``.
        rep_layer (Union[str, List[str]], optional): Representation layer(s) in AnnData to be used for
            alignment. Defaults to "X".
        rep_field (Union[str, List[str]], optional): Field(s) in AnnData holding the representation layer(s).
            "layer" means gene expression, "obsm" means an embedding such as PCA or VAE, and "obs" means a
            discrete label annotation. Note that Spateo only accepts one label annotation. Defaults to "layer".
        genes (Optional[Union[List[str], torch.Tensor]], optional): List or tensor of genes to be used for
            alignment, e.g. genes of interest or spatially variable genes. Defaults to None.
        spatial_key (str, optional): Key in `.obsm` of AnnData corresponding to the spatial coordinates.
            Defaults to "spatial".
        key_added (str, optional): Prefix of the `.obsm` keys on ``sampleB`` under which the aligned spatial
            coordinates are stored (``f"{key_added}_nonrigid"`` and ``f"{key_added}_rigid"``).
            Defaults to "align_spatial".
        iter_key_added (Optional[str], optional): Key in ``sampleB.uns`` under which to store intermediate
            iteration results. Defaults to None.
        save_concrete_iter (bool, optional): Whether to also create ``"matches"`` and ``"alpha"`` entries
            under ``iter_key_added``. Defaults to False.
        vecfld_key_added (Optional[str], optional): Key in ``sampleB.uns`` under which to store the
            transformation and model parameters. Defaults to None.
        dissimilarity (Union[str, List[str]], optional): Measure(s) of pairwise dissimilarity between
            observations. Defaults to "kl".
        probability_type (Union[str, List[str]], optional): Type(s) of probability distribution used.
            Defaults to "gauss".
        probability_parameters (Optional[Union[float, List[float]]], optional): Parameters for the probability
            distribution(s). Defaults to None.
        label_transfer_dict (Optional[Union[dict, List[dict]]], optional): Dictionary that stores the label
            transfer probability. Defaults to None.
        nn_init (bool, optional): Whether to use nearest-neighbor based coarse rigid initialization.
            Defaults to True.
        allow_flip (bool, optional): Whether to allow flipping of coordinates during initialization.
            Defaults to False.
        init_layer (str, optional): Layer used for the initial alignment. Defaults to "X".
        init_field (str, optional): Field of the layer used for the initial alignment. Defaults to 'layer'.
        max_iter (int, optional): Maximum number of iterations. Defaults to 200.
        SVI_mode (bool, optional): Whether to use Stochastic Variational Inference mode. Defaults to True.
        batch_size (int, optional): Size of the mini-batch for SVI. It is clamped to lie between
            ``int(NB / 10)`` and ``NB``. Defaults to 1000.
        pre_compute_dist (bool, optional): Whether to pre-compute the representation distance matrix.
            Defaults to True.
        sparse_calculation_mode (bool, optional): Whether to use sparse matrix calculations. Defaults to False.
        lambdaVF (Union[int, float], optional): Regularization parameter for the vector field. Defaults to 1e2.
        beta (Union[int, float], optional): Length-scale of the SE kernel. Defaults to 0.01.
        K (Union[int, float], optional): Number of sparse inducing points for the Nyström approximation.
            NOTE(review): currently not passed to ``get_kernel`` and overwritten by the number of inducing
            variables it returns. Defaults to 15.
        sigma2_init_scale (Optional[Union[int, float]], optional): Scale applied to the initial guess of the
            spatial dispersion ``sigma2``. Defaults to 0.1.
        partial_robust_level (float, optional): Robust level for partial alignment. It is used as the end value
            of the annealed ``sigma2_variance``. Defaults to 25.
        normalize_c (bool, optional): Whether to normalize spatial coordinates. Defaults to True.
        normalize_g (bool, optional): Whether to normalize gene expression. Defaults to True.
        dtype (str, optional): Data type for computations. Defaults to "float32".
        device (str, optional): Device for computation, e.g., "cpu" or "0" for GPU. Defaults to "cpu".
        verbose (bool, optional): Whether to print verbose messages. Defaults to True.
        guidance_pair (Optional[Union[List[np.ndarray], np.ndarray]], optional): Guidance pairs for alignment.
            Only used when given as a list and ``guidance_effect`` is not False. Defaults to None.
        guidance_effect (Optional[Union[bool, str]], optional): Effect of guidance. Defaults to False.
        guidance_epsilon (float, optional): Epsilon value for guidance. Defaults to 1.

    Returns:
        Tuple[Tuple[AnnData, AnnData], np.ndarray]: ``(sampleA, sampleB)``, with ``sampleB`` modified in place,
        and the transposed assignment matrix ``P.T`` as a NumPy array.
    """
    # NOTE(review): known problems in this deprecated, work-in-progress implementation,
    # all visible in the code below:
    #   * ``SVI_deacy`` is only assigned when ``SVI_mode`` is True, but ``step_size`` is
    #     computed from it on every iteration, so ``SVI_mode=False`` fails.
    #   * ``randcoordsB``, ``randIdx``, ``R_AI``, ``sigma2_variable_field``,
    #     ``coarse_alignment`` and ``t`` are read but never assigned in this function.
    #     ``rigid_variable_field`` is read before its first assignment, and ``V_AI`` is
    #     only assigned when a guidance pair is used.
    #   * ``XAHat`` is never reassigned inside the loop, so the "nonrigid" output is the
    #     (normalized) input coordinates of sampleB.
    #   * ``sigma2_variance_decress`` is computed but never used, and the
    #     ``iter_key_added`` dictionaries are created but never filled.
    #   * The saved parameters store ``sigma2`` under the ``"beta2"`` key.
    #   * ``get_kernel``, ``calc_distance``, ``_init_probability_parameters``,
    #     ``_get_anneling_factor``, ``update_*`` and ``denormalize`` are not in this
    #     module's visible imports. Presumably they are defined elsewhere — TODO confirm.

    # TODO: remove the type checking out
    # assert dissimilarity in [
    #     "kl",
    #     "euclidean",
    #     "euc",
    #     "cos",
    #     "cosine",
    # ], "``dissimilarity`` value is not valid. Available ``dissimilarity`` are: ``'kl'``, ``'euclidean'``, ``'euc'``, ``'cos'``, and ``'cosine'``."
    # normalize_g = False if dissimilarity == "kl" else normalize_g

    # if using GPU, empty the GPU memory
    empty_cache(device=device)

    # preprocessing
    (
        nx,
        type_as,
        exp_layers,
        spatial_coords,
        label_transfer,
        normalize_scales,
        normalize_means,
        genes,
    ) = align_preprocess(
        samples=[sampleA, sampleB],
        rep_layer=rep_layer,
        rep_field=rep_field,
        genes=genes,
        spatial_key=spatial_key,
        label_transfer_dict=label_transfer_dict,
        normalize_c=normalize_c,
        normalize_g=normalize_g,
        dtype=dtype,
        device=device,
        verbose=verbose,
    )
    # Index 1 / 0 select sampleB / sampleA, assuming align_preprocess keeps input order.
    # Internally "A" is therefore the moving sample (sampleB) and "B" the reference (sampleA).
    coordsA, coordsB = spatial_coords[1], spatial_coords[0]
    exp_layer_A, exp_layer_B = exp_layers[1], exp_layers[0]
    NA, NB, D = coordsA.shape[0], coordsB.shape[0], coordsA.shape[1]

    # normalize guidance pair and convert to correct data types
    if isinstance(guidance_pair, list) and (guidance_effect is not False):
        guidance_pair = guidance_pair_preprocess(
            nx=nx,
            type_as=type_as,
            guidance_pair=guidance_pair,
            normalize_scales=normalize_scales,
            normalize_means=normalize_means,
        )
        X_AI, X_BI = guidance_pair[1], guidance_pair[0]
        V_AI = nx.zeros(X_AI.shape, type_as=type_as)
    else:
        X_AI, X_BI = None, None

    # perform coarse rigid alignment
    # TODO: add downsampling in the coarse_rigid_alignment
    # TODO: coordsA should not be transformed here, because the inducing variable is in the same space
    if nn_init:
        inlier_A, inlier_B, inlier_P, init_R, init_t = coarse_rigid_alignment(
            nx=nx,
            type_as=type_as,
            coordsA=coordsA,
            coordsB=coordsB,
            init_layer=init_layer,
            init_field=init_field,
            genes=genes,
            samples=[sampleA, sampleB],
            top_K=10,
            allow_flip=allow_flip,
            verbose=verbose,
        )
    else:
        init_R = nx.eye(D, type_as=type_as)
        init_t = nx.zeros((D,), type_as=type_as)

    # # apply coarse alignment to guidance pair
    # if X_AI is not None:
    #     X_AI = X_AI @ init_R.T + init_t

    # construct the kernel for Gaussian processes
    (
        inducing_variables,  # K x D
        inducing_variables_index,  # K
        GammaSparse,  # K x K
        U,  # NA x K
        U_I,  # NI x K / None
    ) = get_kernel(
        spatial_coords=coordsA,
        inducing_variables="random",
        kernel_bandwidth=beta,
        kernel_type="euc",
        add_evaluation_points=X_AI if (guidance_effect == "nonrigid") else None,
    )
    # NOTE(review): this discards the user-supplied ``K`` argument.
    K = inducing_variables.shape[0]

    # initial guess for sigma2 and the probability parameters, plus the annealing
    # factor for sigma2_variance
    sigma2 = sigma2_init_scale * _init_guess_sigma2(coordsA, coordsB)
    probability_parameters = _init_probability_parameters(
        exp_layer_A=exp_layer_A,
        exp_layer_B=exp_layer_B,
        dissimilarity=dissimilarity,
        probability_type=probability_type,
        probability_parameters=probability_parameters,
    )
    sigma2_variance = 1
    sigma2_variance_end = partial_robust_level
    # NOTE(review): computed but never used below.
    sigma2_variance_decress = _get_anneling_factor(
        start=sigma2_variance, end=sigma2_variance_end, iter=(max_iter / 2), nx=nx
    )

    # initialize the variational variables
    kappa = nx.ones((NA), type_as=type_as)
    alpha = nx.ones((NA), type_as=type_as)
    gamma, gamma_a, gamma_b = (
        _data(nx, 0.5, type_as),
        _data(nx, 1.0, type_as),
        _data(nx, 1.0, type_as),
    )
    VnA = nx.zeros(coordsA.shape, type_as=type_as)  # nonrigid vector velocity
    XAHat, RnA = coordsA, coordsA  # initial transformed / rigid position
    Coff = nx.zeros(K, type_as=type_as)  # inducing variables coefficient
    SigmaDiag = nx.zeros((NA), type_as=type_as)  # Gaussian processes variance
    R = _identity(nx, D, type_as)  # rotation in rigid transformation
    nonrigid_flag = False  # indicate if to start nonrigid

    # initialize the SVI
    if SVI_mode:
        SVI_deacy = _data(nx, 10.0, type_as)
        # Select a random subset of data
        batch_size = min(max(int(NB / 10), batch_size), NB)
        batch_perm = _randperm(nx)(NB)
        batch_idx = batch_perm[:batch_size]
        batch_perm = _roll(nx)(batch_perm, batch_size)
        # NOTE(review): later code reads ``randcoordsB``, not ``batch_coordsB``.
        batch_coordsB = coordsB[batch_idx, :]  # batch_size x D
    Sp, Sp_spatial, Sp_sigma2 = 0, 0, 0
    SigmaInv = nx.zeros((K, K), type_as=type_as)  # K x K
    PXB_term = nx.zeros((NA, D), type_as=type_as)  # NA x D

    # calculate the representation(s) pairwise distance matrix if pre_compute_dist is True or not in SVI mode
    if (not SVI_mode) or (pre_compute_dist):
        exp_layer_dist = calc_distance(
            X=exp_layer_A, Y=exp_layer_B, metric=dissimilarity, label_transfer=label_transfer
        )

    # get the current batch representation(s) pairwise distance matrix
    # TODO: we can insert spatial_dist calculation this into get_P
    # get the current batch spatial proximity matrix
    # if sparse_calculation_mode is False:
    #     spatial_dist = calc_distance(
    #         X=XAHat,
    #         Y=randcoordsB if SVI_mode else coordsB,
    #         metric="euc",
    #     )  # NA x batch_size (SVI_mode) / NA x NB (not SVI_mode)

    # initialize the intermediate results
    if iter_key_added is not None:
        sampleB.uns[iter_key_added] = dict()
        sampleB.uns[iter_key_added][key_added] = {}
        sampleB.uns[iter_key_added]["sigma2"] = {}
        sampleB.uns[iter_key_added]["beta2"] = {}
        if save_concrete_iter:
            sampleB.uns[iter_key_added]["matches"] = {}
            sampleB.uns[iter_key_added]["alpha"] = {}

    # start iteration
    iteration = (
        lm.progress_logger(range(max_iter), progress_name="Start Spateo pairwise alignment")
        if verbose
        else range(max_iter)
    )
    for iter in iteration:
        # update the step size for SVI
        # NOTE(review): ``SVI_deacy`` only exists when SVI_mode is True.
        step_size = nx.minimum(_data(nx, 1.0, type_as), SVI_deacy / (iter + 1.0))

        # calculate the assignment matrix
        P, assignment_results = update_assignment_P(
            nx=nx,
            type_as=type_as,
            spatial_A=XAHat,
            spatial_B=coordsB,
            exp_layer_A=exp_layer_A,
            exp_layer_B=exp_layer_B,
            batch_idx=batch_idx if SVI_mode else None,
            exp_layer_dist=exp_layer_dist if pre_compute_dist else None,
            sigma2=sigma2,
            alpha=alpha,
            gamma=gamma,
            Sigma=SigmaDiag,
            sigma2_variance=sigma2_variance,
            probability_type=probability_type,
            probability_parameters=probability_parameters,
            sparse_calculation_mode=sparse_calculation_mode,
        )

        # update variational variables gamma and alpha
        K_NA, K_NB = assignment_results["K_NA"], assignment_results["K_NB"]
        K_NA_spatial = assignment_results["K_NA_spatial"]
        K_NA_sigma2 = assignment_results["K_NA_sigma2"]
        (Sp, Sp_spatial, Sp_sigma2) = update_Sp(
            step_size=step_size,
            SVI_mode=SVI_mode,
            Sp=Sp,
            Sp_spatial=Sp_spatial,
            Sp_sigma2=Sp_sigma2,
            assignment_results=assignment_results,
        )

        # update gamma
        gamma = update_gamma(
            nx=nx,
            type_as=type_as,
            gamma=gamma,
            step_size=step_size,
            batch_size=batch_size,
            gamma_a=gamma_a,
            gamma_b=gamma_b,
            Sp_spatial=Sp_spatial,
            SVI_mode=SVI_mode,
        )

        # update alpha
        alpha = update_alpha(alpha, step_size, kappa, assignment_results, SVI_mode)

        # update nonrigid vector field once sigma2 is small enough or after 80 iterations;
        # once started it stays on
        if (sigma2 < 0.015) or (iter > 80) or nonrigid_flag:
            nonrigid_flag = True
            # NOTE(review): ``randcoordsB`` and ``R_AI`` are never assigned before this call.
            (VnA, V_AI, SigmaDiag, SigmaInv, PXB_term, Coff) = update_nonrigid(
                nx=nx,
                type_as=type_as,
                SVI_mode=SVI_mode,
                guidance_effect=guidance_effect,
                SigmaInv=SigmaInv,
                step_size=step_size,
                sigma2=sigma2,
                lambdaVF=lambdaVF,
                GammaSparse=GammaSparse,
                U=U,
                K_NA=K_NA,
                PXB_term=PXB_term,
                P=P,
                coordsB=randcoordsB if SVI_mode else coordsB,
                RnA=RnA,
                guidance_epsilon=guidance_epsilon,
                U_I=U_I,
                R_AI=R_AI,
                X_BI=X_BI,
            )

        # update rigid transformation
        # NOTE(review): ``rigid_variable_field`` is read before it is ever assigned.
        rigid_variable_field = update_rigid(rigid_variable_field)

        # update sigma2 and beta2
        # SpatialDistMat = cal_dist(XAHat, randcoordsB) if SVI_mode else cal_dist(XAHat, coordsB)
        # NOTE(review): ``sigma2_variable_field`` is never assigned in this function.
        sigma2, sigma2_variance = update_sigma2(sigma2_variable_field, assignment_results)
        # beta2 = update_beta()

        # iterate to next batch
        if SVI_mode and iter < max_iter - 1:
            batch_idx = batch_perm[:batch_size]
            batch_perm = _roll(nx)(batch_perm, batch_size)
            # NOTE(review): ``randIdx`` is never assigned; ``batch_idx`` was probably intended.
            randcoordsB = coordsB[randIdx, :]

    # get the full cell-cell assignment
    if SVI_mode:
        P, assignment_results = update_assignment_P(
            nx=nx,
            type_as=type_as,
            spatial_A=XAHat,
            spatial_B=coordsB,
            exp_layer_A=exp_layer_A,
            exp_layer_B=exp_layer_B,
            exp_layer_dist=exp_layer_dist if pre_compute_dist else None,
            sigma2=sigma2,
            alpha=alpha,
            gamma=gamma,
            Sigma=SigmaDiag,
            sigma2_variance=sigma2_variance,
            probability_type=probability_type,
            probability_parameters=probability_parameters,
            sparse_calculation_mode=sparse_calculation_mode,
        )

    # Get optimal rigid transformation based on final mapping
    # TODO: make sure the R_init means
    optimal_RnA, optimal_R, optimal_t = get_optimal_R(
        coordsA=coordsA,
        coordsB=coordsB,
        P=P,
        R_init=R,
    )

    if verbose:
        lm.main_info(
            f"Key Parameters: gamma: {gamma}; sigma2: {sigma2}; probability_parameters: {probability_parameters}"
        )

    # denormalize
    if normalize_c:
        # NOTE(review): ``coarse_alignment`` is never assigned in this function.
        (XAHat, RnA, optimal_RnA, coarse_alignment) = denormalize(
            XAHat,
            RnA,
            optimal_RnA,
            coarse_alignment,
            normalize_scale=normalize_scales[0],
            normalize_mean=normalize_means[0],
        )
        # XAHat = XAHat * normalize_scale_list[0] + normalize_mean_list[0]
        # RnA = RnA * normalize_scale_list[0] + normalize_mean_list[0]
        # optimal_RnA = optimal_RnA * normalize_scale_list[0] + normalize_mean_list[0]
        # coarse_alignment = coarse_alignment * normalize_scale_list[0] + normalize_mean_list[0]

    # Save aligned coordinates
    sampleB.obsm[f"{key_added}_nonrigid"] = nx.to_numpy(XAHat).copy()
    sampleB.obsm[f"{key_added}_rigid"] = nx.to_numpy(optimal_RnA).copy()

    # save vector field and other parameters
    # NOTE(review): ``t`` is never assigned, and ``"beta2"`` below stores sigma2.
    if not (vecfld_key_added is None):
        sampleB.uns[vecfld_key_added] = {
            "R": nx.to_numpy(R),
            "t": nx.to_numpy(t),
            "optimal_R": nx.to_numpy(optimal_R),
            "optimal_t": nx.to_numpy(optimal_t),
            "init_R": init_R,
            "init_t": init_t,
            "beta": beta,
            "Coff": nx.to_numpy(Coff),
            "inducing_variables": nx.to_numpy(inducing_variables),
            "normalize_scales": nx.to_numpy(normalize_scales) if normalize_c else None,
            "normalize_means": nx.to_numpy(normalize_means) if normalize_c else None,
            "normalize_c": normalize_c,
            "dissimilarity": dissimilarity,
            "beta2": nx.to_numpy(sigma2),
            "sigma2": nx.to_numpy(sigma2),
            "gamma": nx.to_numpy(gamma),
            "NA": NA,
            "sigma2_variance": nx.to_numpy(sigma2_variance),
            "method": "Spateo",
        }
    empty_cache(device=device)
    return (
        (sampleA, sampleB),
        nx.to_numpy(P.T),
    )
# def BA_align( # sampleA: AnnData, # sampleB: AnnData, # genes: Optional[Union[List, torch.Tensor]] = None, # spatial_key: str = "spatial", # key_added: str = "align_spatial", # iter_key_added: Optional[str] = None, # vecfld_key_added: Optional[str] = "VecFld_morpho", # layer: str = "X", # dissimilarity: str = "kl", # use_rep: Optional[str] = None, # keep_size: bool = False, # max_iter: int = 200, # lambdaVF: Union[int, float] = 1e2, # beta: Union[int, float] = 0.01, # K: Union[int, float] = 15, # beta2: Optional[Union[int, float]] = None, # beta2_end: Optional[Union[int, float]] = None, # normalize_c: bool = True, # normalize_g: bool = True, # dtype: str = "float32", # device: str = "cpu", # inplace: bool = True, # verbose: bool = True, # nn_init: bool = True, # allow_flip: bool = False, # SVI_mode: bool = True, # batch_size: int = 1000, # partial_robust_level: float = 25, # pre_compute_dist: bool = True, # guidance_pair: Optional[list] = None, # guidance_effect: Optional[Union[bool, str]] = False, # guidance_epsilon: float = 1, # ) -> Tuple[Optional[Tuple[AnnData, AnnData]], np.ndarray, np.ndarray]: # """The core function of Spateo alignment # Args: # sampleA: Sample A that acts as reference. # sampleB: Sample B that performs alignment. # genes: Genes used for calculation. If None, use all common genes for calculation. # spatial_key: The key in ``.obsm`` that corresponds to the raw spatial coordinate. # key_added: ``.obsm`` key under which to add the aligned spatial coordinate. # iter_key_added: ``.uns`` key under which to add the result of each iteration of the iterative process. If ``iter_key_added`` is None, the results are not saved. # vecfld_key_added: The key that will be used for the vector field key in ``.uns``. If ``vecfld_key_added`` is None, the results are not saved. # layer: If ``'X'``, uses ``.X`` to calculate dissimilarity between spots, otherwise uses the representation given by ``.layers[layer]``. 
# dissimilarity: Expression dissimilarity measure: ``'kl'``, ``'euclidean'``, or ``'cos'``. # use_rep: Use the indicated representation. If use_rep is None, then use the given "layer", else use the key stored in .obsm. E.g., "X_pca". # max_iter: Max number of iterations for morpho alignment. # lambdaVF : Hyperparameter that controls the non-rigid distortion degree. Smaller means more flexibility. # beta: The length-scale of the SE kernel. Higher means more flexibility. # K: The number of sparse inducing points used for Nystr ̈om approximation. Smaller means faster but less accurate. # beta2: Manually assigned significance gene expression similarity. Smaller indicating greater significance. # beta2_end: Manually assigned significance gene expression similarity. Smaller indicating greater significance. # normalize_c: Whether to normalize spatial coordinates. # normalize_g: Whether to normalize gene expression. If ``dissimilarity`` == ``'kl'``, ``normalize_g`` must be False. # samples_s: The space size of each sample. Area size for 2D samples and volume size for 3D samples. # dtype: The floating-point number type. Only ``float32`` and ``float64``. # device: Equipment used to run the program. You can also set the specified GPU for running. ``E.g.: '0'``. # inplace: Whether to copy adata or modify it inplace. # verbose: If ``True``, print progress updates. # nn_init: If ``True``, use nearest neighbor matching to initialize the alignment. # SVI_mode: Whether to use stochastic variational inferential (SVI) optimization strategy. # batch_size: The size of the mini-batch of SVI. If set smaller, the calculation will be faster, but it will affect the accuracy, and vice versa. If not set, it is automatically set to one-tenth of the data size. # partial_robust_level: The robust level of partial alignment. The larger the value, the more robust the alignment to partial cases is. Recommended setting from 1 to 50. 
# pre_compute_dist: If ``True``, the gene similarity matrix is computed before the mini batch is performed. Otherwise, it is computed during the mini batch. This can be significantly faster, but can also require more GPU memory if using GPU. # """ # empty_cache(device=device) # # Preprocessing # normalize_g = False if dissimilarity == "kl" else normalize_g # sampleA, sampleB = (sampleA, sampleB) if inplace else (sampleA.copy(), sampleB.copy()) # ( # nx, # type_as, # new_samples, # exp_matrices, # spatial_coords, # normalize_scale_list, # normalize_mean_list, # ) = align_preprocess( # samples=[sampleA, sampleB], # layer=layer, # genes=genes, # spatial_key=spatial_key, # normalize_c=normalize_c, # normalize_g=normalize_g, # dtype=dtype, # device=device, # verbose=verbose, # use_rep=use_rep, # ) # # normalize guidance pair and convert to correct data types # if isinstance(guidance_pair, list) and (guidance_effect is not False): # guidance_pair = guidance_pair_preprocess(guidance_pair, normalize_scale_list, normalize_mean_list, nx, type_as) # X_AI = guidance_pair[0] # X_BI = guidance_pair[1] # coordsA, coordsB = spatial_coords[1], spatial_coords[0] # X_A, X_B = exp_matrices[1], exp_matrices[0] # del spatial_coords, exp_matrices # NA, NB, D, G = coordsA.shape[0], coordsB.shape[0], coordsA.shape[1], X_A.shape[1] # sub_sample = False # sub_sample_num = 20000 # if SVI_mode and (NA > sub_sample_num or NB > sub_sample_num) and (pre_compute_dist is False): # if NA > sub_sample_num: # sub_idx_A = np.random.choice(NA, sub_sample_num, replace=False) # sub_coordsA = coordsA[sub_idx_A, :] # sub_X_A = X_A[sub_idx_A, :] # else: # sub_coordsA = coordsA # sub_X_A = X_A # if NB > sub_sample_num: # sub_idx_B = np.random.choice(NB, sub_sample_num, replace=False) # sub_coordsB = coordsB[sub_idx_B, :] # sub_X_B = X_B[sub_idx_B, :] # else: # sub_coordsB = coordsB # sub_X_B = X_B # GeneDistMat = calc_exp_dissimilarity(X_A=sub_X_A, X_B=sub_X_B, dissimilarity=dissimilarity) # sub_sample = True 
# else: # GeneDistMat = calc_exp_dissimilarity(X_A=X_A, X_B=X_B, dissimilarity=dissimilarity) # area = _prod(nx)(nx.max(coordsA, axis=0) - nx.min(coordsA, axis=0)) # if nn_init: # # perform coarse rigid alignment # if sub_sample: # _cra_kwargs = dict( # coordsA=sub_coordsA, # coordsB=sub_coordsB, # X_A=sub_X_A, # X_B=sub_X_B, # transformed_points=coordsA, # allow_flip=allow_flip, # ) # else: # _cra_kwargs = dict( # coordsA=coordsA, # coordsB=coordsB, # X_A=X_A, # X_B=X_B, # transformed_points=None, # allow_flip=allow_flip, # ) # coordsA, inlier_A, inlier_B, inlier_P, init_R, init_t = coarse_rigid_alignment( # dissimilarity=dissimilarity, top_K=10, verbose=verbose, **_cra_kwargs # ) # empty_cache(device=device) # coordsA = _data(nx, coordsA, type_as) # inlier_A = _data(nx, inlier_A, type_as) # inlier_B = _data(nx, inlier_B, type_as) # inlier_P = _data(nx, inlier_P, type_as) # # inlier_P = inlier_P * M / nx.sum(inlier_P) # inlier_R = inlier_A # inlier_V = nx.zeros(inlier_A.shape, type_as=type_as) # else: # init_R = nx.eye(D, type_as=type_as) # init_t = nx.zeros((D,), type_as=type_as) # inlier_A = [] # inlier_B = [] # inlier_P = [] # if (guidance_effect is not False) and (guidance_pair is not None): # X_AI = X_AI @ init_R.T + init_t # if len(inlier_A) == 0: # inlier_A = X_AI # inlier_B = X_BI # inlier_P = nx.ones((X_AI.shape[0], 1), type_as=type_as) # else: # inlier_A = nx.concatenate([inlier_A, X_AI], axis=0) # inlier_B = nx.concatenate([inlier_B, X_BI], axis=0) # inlier_P = nx.concatenate([inlier_P, nx.ones((X_AI.shape[0], 1), type_as=type_as)], axis=0) # inlier_R = inlier_A # inlier_V = nx.zeros(inlier_A.shape, type_as=type_as) # inlier_AHat = inlier_A # coarse_alignment = coordsA # # Random select control points # Unique_coordsA = _unique(nx, coordsA, 0) # idx = random.sample(range(Unique_coordsA.shape[0]), min(K, Unique_coordsA.shape[0])) # ctrl_pts = Unique_coordsA[idx, :] # K = ctrl_pts.shape[0] # # construct the kernel # GammaSparse = con_K(ctrl_pts, ctrl_pts, 
beta) # U = con_K(coordsA, ctrl_pts, beta) # if guidance_effect == "nonrigid": # inlier_U = con_K(inlier_A, ctrl_pts, beta) # kappa = nx.ones((NA), type_as=type_as) # alpha = nx.ones((NA), type_as=type_as) # VnA = nx.zeros(coordsA.shape, type_as=type_as) # Coff = nx.zeros(ctrl_pts.shape, type_as=type_as) # gamma, gamma_a, gamma_b = ( # _data(nx, 0.5, type_as), # _data(nx, 1.0, type_as), # _data(nx, 1.0, type_as), # ) # minP, sigma2_terc, erc = ( # _data(nx, 1e-5, type_as), # _data(nx, 1, type_as), # _data(nx, 1e-4, type_as), # ) # SigmaDiag = nx.zeros((NA), type_as=type_as) # XAHat, RnA = coordsA, coordsA # if sub_sample: # SpatialDistMat = cal_dist(sub_coordsA, sub_coordsB) # del sub_coordsA, sub_coordsB # else: # SpatialDistMat = cal_dist(XAHat, coordsB) # # initial guess for sigma2 and beta2 # sigma2 = _init_guess_sigma2(XAHat, coordsB) # beta2, beta2_end = _init_guess_beta2( # nx, X_A, X_B, dissimilarity, partial_robust_level, beta2, beta2_end, verbose=verbose # ) # beta2_decrease = _power(nx)(beta2_end / beta2, 1 / (50)) # R = _identity(nx, D, type_as) # nonrigid_flag = False # # Use smaller spatial variance to reduce tails # outlier_variance = 1 # max_outlier_variance = partial_robust_level # outlier_variance_decrease = _power(nx)(_data(nx, max_outlier_variance, type_as), 1 / (max_iter / 2)) # if SVI_mode: # SVI_deacy = _data(nx, 10.0, type_as) # # Select a random subset of data # batch_size = min(max(int(NB / 10), batch_size), NB) # randomidx = _randperm(nx)(NB) # randIdx = randomidx[:batch_size] # randomIdx = _roll(nx)(randomidx, batch_size) # randcoordsB = coordsB[randIdx, :] # batch_size x D # if sub_sample: # randGeneDistMat = calc_exp_dissimilarity(X_A=X_A, X_B=X_B[randIdx, :], dissimilarity=dissimilarity) # SpatialDistMat = cal_dist(coordsA, randcoordsB) # else: # randGeneDistMat = GeneDistMat[:, randIdx] # NA x batch_size # SpatialDistMat = SpatialDistMat[:, randIdx] # NA x batch_size # Sp, Sp_spatial, Sp_sigma2 = 0, 0, 0 # SigmaInv = nx.zeros((K, K), 
type_as=type_as) # K x K # PXB_term = nx.zeros((NA, D), type_as=type_as) # NA x D # iteration = ( # lm.progress_logger(range(max_iter), progress_name="Start morpho alignment") if verbose else range(max_iter) # ) # if iter_key_added is not None: # sampleB.uns[iter_key_added] = dict() # sampleB.uns[iter_key_added][key_added] = {} # sampleB.uns[iter_key_added]["sigma2"] = {} # sampleB.uns[iter_key_added]["beta2"] = {} # # sampleB.uns[iter_key_added]["inlier_AHat"] = {} # # sampleB.uns[iter_key_added]["inlier_B"] = nx.to_numpy(X_BI * normalize_scale_list[0] + normalize_mean_list[0] if normalize_c else X_BI) # for iter in iteration: # if iter_key_added is not None: # iter_XAHat = XAHat * normalize_scale_list[0] + normalize_mean_list[0] if normalize_c else XAHat # # iter_inlier_AHat = inlier_AHat * normalize_scale_list[0] + normalize_mean_list[0] if normalize_c else inlier_AHat # sampleB.uns[iter_key_added][key_added][iter] = nx.to_numpy(iter_XAHat) # sampleB.uns[iter_key_added]["sigma2"][iter] = nx.to_numpy(sigma2) # sampleB.uns[iter_key_added]["beta2"][iter] = nx.to_numpy(beta2) # # sampleB.uns[iter_key_added]["inlier_AHat"][iter] = nx.to_numpy(iter_inlier_AHat) # if SVI_mode: # step_size = nx.minimum(_data(nx, 1.0, type_as), SVI_deacy / (iter + 1.0)) # P, spatial_P, sigma2_P = get_P( # XnAHat=XAHat, # XnB=randcoordsB, # sigma2=sigma2, # beta2=beta2, # alpha=alpha, # gamma=gamma, # Sigma=SigmaDiag, # GeneDistMat=randGeneDistMat, # SpatialDistMat=SpatialDistMat, # outlier_variance=outlier_variance, # ) # else: # P, spatial_P, sigma2_P = get_P( # XnAHat=XAHat, # XnB=coordsB, # sigma2=sigma2, # beta2=beta2, # alpha=alpha, # gamma=gamma, # Sigma=SigmaDiag, # GeneDistMat=GeneDistMat, # SpatialDistMat=SpatialDistMat, # outlier_variance=outlier_variance, # ) # if iter > 5: # beta2 = ( # nx.maximum(beta2 * beta2_decrease, beta2_end) # if beta2_decrease < 1 # else nx.minimum(beta2 * beta2_decrease, beta2_end) # ) # outlier_variance = nx.minimum(outlier_variance * 
outlier_variance_decrease, max_outlier_variance) # K_NA = nx.einsum("ij->i", P) # K_NB = nx.einsum("ij->j", P) # K_NA_spatial = nx.einsum("ij->i", spatial_P) # K_NB_spatial = nx.einsum("ij->j", spatial_P) # K_NA_sigma2 = nx.einsum("ij->i", sigma2_P) # K_NB_sigma2 = nx.einsum("ij->j", sigma2_P) # # Update gamma # if SVI_mode: # Sp = step_size * nx.einsum("ij->", P) + (1 - step_size) * Sp # Sp_spatial = step_size * nx.einsum("ij->", spatial_P) + (1 - step_size) * Sp_spatial # Sp_sigma2 = step_size * nx.einsum("ij->", sigma2_P) + (1 - step_size) * Sp_sigma2 # gamma = nx.exp(_psi(nx)(gamma_a + Sp_spatial) - _psi(nx)(gamma_a + gamma_b + batch_size)) # else: # Sp = nx.einsum("ij->", P) # Sp_spatial = nx.einsum("ij->", spatial_P) # Sp_sigma2 = nx.einsum("ij->", sigma2_P) # gamma = nx.exp(_psi(nx)(gamma_a + Sp_spatial) - _psi(nx)(gamma_a + gamma_b + NB)) # gamma = _data(nx, 0.99, type_as) if gamma > 0.99 else gamma # gamma = _data(nx, 0.01, type_as) if gamma < 0.01 else gamma # # Update alpha # if SVI_mode: # alpha = ( # step_size * nx.exp(_psi(nx)(kappa + K_NA_spatial) - _psi(nx)(kappa * NA + Sp_spatial)) # + (1 - step_size) * alpha # ) # else: # alpha = nx.exp(_psi(nx)(kappa + K_NA_spatial) - _psi(nx)(kappa * NA + Sp_spatial)) # # Update VnA # if (sigma2 < 0.015) or (iter > 80) or nonrigid_flag: # nonrigid_flag = True # if SVI_mode: # if (guidance_effect == "nonrigid") or (guidance_effect == "both"): # SigmaInv = ( # step_size # * ( # sigma2 * lambdaVF * GammaSparse # + _dot(nx)(U.T, nx.einsum("ij,i->ij", U, K_NA)) # + (sigma2 / guidance_epsilon) * _dot(nx)(inlier_U.T, inlier_U * inlier_P) # ) # + (1 - step_size) * SigmaInv # ) # else: # SigmaInv = ( # step_size * (sigma2 * lambdaVF * GammaSparse + _dot(nx)(U.T, nx.einsum("ij,i->ij", U, K_NA))) # + (1 - step_size) * SigmaInv # ) # Sigma = _pinv(nx)(SigmaInv) # term1 = _dot(nx)(Sigma, U.T) # PXB_term = ( # step_size * (_dot(nx)(P, randcoordsB) - nx.einsum("ij,i->ij", RnA, K_NA)) # + (1 - step_size) * PXB_term # ) # if 
(guidance_effect == "nonrigid") or (guidance_effect == "both"): # term1_guide = _dot(nx)(Sigma, inlier_U.T) # XBRA_guide_term = (inlier_B - inlier_R) * inlier_P # Coff = _dot(nx)(term1, PXB_term) + (sigma2 / guidance_epsilon) * _dot(nx)( # term1_guide, XBRA_guide_term # ) # inlier_V = _dot(nx)(inlier_U, Coff) # else: # Coff = _dot(nx)(term1, PXB_term) # VnA = _dot(nx)( # U, # Coff, # ) # SigmaDiag = sigma2 * nx.einsum("ij->i", nx.einsum("ij,ji->ij", U, term1)) # else: # if (guidance_effect == "nonrigid") or (guidance_effect == "both"): # SigmaInv = ( # sigma2 * lambdaVF * GammaSparse # + _dot(nx)(U.T, nx.einsum("ij,i->ij", U, K_NA)) # + (sigma2 / guidance_epsilon) * _dot(nx)(inlier_U.T, inlier_U * inlier_P) # ) # else: # SigmaInv = sigma2 * lambdaVF * GammaSparse + _dot(nx)(U.T, nx.einsum("ij,i->ij", U, K_NA)) # Sigma = _pinv(nx)(SigmaInv) # term1 = _dot(nx)(Sigma, U.T) # PXB_term = _dot(nx)(P, coordsB) - nx.einsum("ij,i->ij", RnA, K_NA) # if (guidance_effect == "nonrigid") or (guidance_effect == "both"): # term1_guide = _dot(nx)(Sigma, inlier_U.T) # XBRA_guide_term = (inlier_B - inlier_R) * inlier_P # Coff = _dot(nx)(term1, PXB_term) + (sigma2 / guidance_epsilon) * _dot(nx)( # term1_guide, XBRA_guide_term # ) # inlier_V = _dot(nx)(inlier_U, Coff) # else: # Coff = _dot(nx)(term1, PXB_term) # VnA = _dot(nx)(U, Coff) # SigmaDiag = sigma2 * nx.einsum("ij->i", nx.einsum("ij,ji->ij", U, term1)) # # Update rigid transformation R() # if SVI_mode: # PXA, PVA, PXB = ( # _dot(nx)(K_NA, coordsA)[None, :], # _dot(nx)(K_NA, VnA)[None, :], # _dot(nx)(K_NB, randcoordsB)[None, :], # ) # else: # PXA, PVA, PXB = ( # _dot(nx)(K_NA, coordsA)[None, :], # _dot(nx)(K_NA, VnA)[None, :], # _dot(nx)(K_NB, coordsB)[None, :], # ) # if SVI_mode and iter > 1: # if (guidance_effect == "rigid") or (guidance_effect == "both") or (nn_init == True): # t = ( # step_size # * ( # ( # (PXB - PVA - _dot(nx)(PXA, R.T)) # + (sigma2 / guidance_epsilon) # * _dot(nx)(inlier_P.T, inlier_B - inlier_V - 
_dot(nx)(inlier_A, R.T)) # ) # / (Sp + (sigma2 / guidance_epsilon) * nx.sum(inlier_P)) # ) # + (1 - step_size) * t # ) # else: # t = step_size * ((PXB - PVA - _dot(nx)(PXA, R.T)) / Sp) + (1 - step_size) * t # else: # if (guidance_effect == "rigid") or (guidance_effect == "both") or (nn_init == True): # t = ( # (PXB - PVA - _dot(nx)(PXA, R.T)) # + (sigma2 / guidance_epsilon) * _dot(nx)(inlier_P.T, inlier_B - inlier_V - _dot(nx)(inlier_A, R.T)) # ) / (Sp + (sigma2 / guidance_epsilon) * nx.sum(inlier_P)) # else: # t = (PXB - PVA - _dot(nx)(PXA, R.T)) / Sp # # Solve for the rotation # if (guidance_effect == "rigid") or (guidance_effect == "both") or (nn_init == True): # mu_XB = (PXB + (sigma2 / guidance_epsilon) * _dot(nx)(inlier_P.T, inlier_B)) / ( # Sp + (sigma2 / guidance_epsilon) * nx.sum(inlier_P) # ) # mu_XA = (PXA + (sigma2 / guidance_epsilon) * _dot(nx)(inlier_P.T, inlier_A)) / ( # Sp + (sigma2 / guidance_epsilon) * nx.sum(inlier_P) # ) # mu_Vn = (PVA + (sigma2 / guidance_epsilon) * _dot(nx)(inlier_P.T, inlier_V)) / ( # Sp + (sigma2 / guidance_epsilon) * nx.sum(inlier_P) # ) # XAI_hat = inlier_A - mu_XA # XBI_hat = inlier_B - mu_XB # fI_hat = inlier_V - mu_Vn # else: # mu_XB = PXB / Sp # mu_XA = PXA / Sp # mu_Vn = PVA / Sp # XA_hat = coordsA - mu_XA # f_hat = VnA - mu_Vn # if SVI_mode: # XB_hat = randcoordsB - mu_XB # else: # XB_hat = coordsB - mu_XB # if (guidance_effect == "rigid") or (guidance_effect == "both") or (nn_init == True): # A_guide = _dot(nx)((XAI_hat * inlier_P).T, (fI_hat - XBI_hat)) # A = -( # _dot(nx)(XA_hat.T, nx.einsum("ij,i->ij", f_hat, K_NA)) # - _dot(nx)(_dot(nx)(XA_hat.T, P), XB_hat) # + (sigma2 / guidance_epsilon) * A_guide # ).T # else: # A = -(_dot(nx)(XA_hat.T, nx.einsum("ij,i->ij", f_hat, K_NA)) - _dot(nx)(_dot(nx)(XA_hat.T, P), XB_hat)).T # svdU, svdS, svdV = _linalg(nx).svd(A) # C = _identity(nx, D, type_as) # C[-1, -1] = _linalg(nx).det(_dot(nx)(svdU, svdV)) # if SVI_mode and iter > 1: # R = step_size * (_dot(nx)(_dot(nx)(svdU, 
C), svdV)) + (1 - step_size) * R # else: # R = _dot(nx)(_dot(nx)(svdU, C), svdV) # RnA = _dot(nx)(coordsA, R.T) + t # inlier_R = _dot(nx)(inlier_A, R.T) + t # XAHat = RnA + VnA # inlier_AHat = inlier_R + inlier_V # # Update sigma2 and beta2 # if SVI_mode: # SpatialDistMat = cal_dist(XAHat, randcoordsB) # else: # SpatialDistMat = cal_dist(XAHat, coordsB) # sigma2_old = sigma2 # sigma2 = nx.maximum( # ( # nx.einsum("ij,ij", sigma2_P, SpatialDistMat) / (D * Sp_sigma2) # + nx.einsum("i,i", K_NA_sigma2, SigmaDiag) / Sp_sigma2 # ), # _data(nx, 1e-3, type_as), # ) # sigma2_terc = nx.abs((sigma2 - sigma2_old) / sigma2) # # Next batch # if SVI_mode and iter < max_iter - 1: # randIdx = randomidx[:batch_size] # randomidx = _roll(nx)(randomidx, batch_size) # randcoordsB = coordsB[randIdx, :] # if sub_sample: # randGeneDistMat = calc_exp_dissimilarity(X_A=X_A, X_B=X_B[randIdx, :], dissimilarity=dissimilarity) # else: # randGeneDistMat = GeneDistMat[:, randIdx] # NA x batch_size # SpatialDistMat = cal_dist(XAHat, randcoordsB) # empty_cache(device=device) # # full data # if SVI_mode: # P = get_P_chunk( # XnAHat=XAHat, # XnB=coordsB, # X_A=X_A, # X_B=X_B, # sigma2=sigma2, # beta2=beta2, # alpha=alpha, # gamma=gamma, # Sigma=SigmaDiag, # outlier_variance=outlier_variance, # ) # # Get optimal Rigid transformation # optimal_RnA, optimal_R, optimal_t = get_optimal_R( # coordsA=coordsA, # coordsB=coordsB, # P=P, # R_init=R, # ) # if verbose: # lm.main_info(f"Key Parameters: gamma: {gamma}; beta2: {beta2}; sigma2: {sigma2}") # if keep_size: # area_after = _prod(nx)(nx.max(XAHat, axis=0) - nx.min(XAHat, axis=0)) # XAHat = XAHat * (area / area_after) # if normalize_c: # XAHat = XAHat * normalize_scale_list[0] + normalize_mean_list[0] # RnA = RnA * normalize_scale_list[0] + normalize_mean_list[0] # optimal_RnA = optimal_RnA * normalize_scale_list[0] + normalize_mean_list[0] # coarse_alignment = coarse_alignment * normalize_scale_list[0] + normalize_mean_list[0] # # Save aligned coordinates 
# sampleB.obsm["Nonrigid_align_spatial"] = nx.to_numpy(XAHat).copy() # sampleB.obsm["Rigid_align_spatial"] = nx.to_numpy(optimal_RnA).copy() # # save vector field and other parameters # if not (vecfld_key_added is None): # sampleB.uns[vecfld_key_added] = { # "R": nx.to_numpy(R), # "t": nx.to_numpy(t), # "optimal_R": nx.to_numpy(optimal_R), # "optimal_t": nx.to_numpy(optimal_t), # "init_R": init_R, # "init_t": init_t, # "beta": beta, # "Coff": nx.to_numpy(Coff), # "ctrl_pts": nx.to_numpy(ctrl_pts), # "normalize_scale": nx.to_numpy(normalize_scale_list[0]) if normalize_c else None, # "normalize_mean_list": [nx.to_numpy(normalize_mean) for normalize_mean in normalize_mean_list] # if normalize_c # else None, # "normalize_c": normalize_c, # "dissimilarity": dissimilarity, # "beta2": nx.to_numpy(sigma2), # "sigma2": nx.to_numpy(sigma2), # "gamma": nx.to_numpy(gamma), # "NA": NA, # "outlier_variance": nx.to_numpy(outlier_variance), # } # empty_cache(device=device) # return ( # None if inplace else (sampleA, sampleB), # nx.to_numpy(P.T), # nx.to_numpy(sigma2), # )