import random

import numpy as np
import ot
import torch
from anndata import AnnData

try:
    from typing import Any, Dict, List, Literal, Optional, Tuple, Union
except ImportError:
    from typing_extensions import Literal

    # NOTE: Any and Dict are available from typing on all supported versions;
    # the fallback previously omitted them, breaking annotations such as
    # Tuple[Any, Any, Any] below.
    from typing import Any, Dict, List, Optional, Tuple, Union

from spateo.logging import logger_manager as lm

from .utils import (
    _chunk,
    _data,
    _dot,
    _identity,
    _linalg,
    _mul,
    _pi,
    _pinv,
    _power,
    _prod,
    _psi,
    _randperm,
    _roll,
    _unique,
    _unsqueeze,
    align_preprocess,
    cal_dist,
    calc_exp_dissimilarity,
    coarse_rigid_alignment,
    empty_cache,
    get_optimal_R,
)
def con_K(
    X: Union[np.ndarray, torch.Tensor],
    Y: Union[np.ndarray, torch.Tensor],
    beta: Union[int, float] = 0.01,
    use_chunk: bool = False,
) -> Union[np.ndarray, torch.Tensor]:
    r"""Construct the Squared Exponential (SE) kernel, where
    :math:`K(i,j) = k(X_i, Y_j) = \exp(-\beta \|X_i - Y_j\|^2)`.

    Args:
        X: The first vector :math:`X \in \mathbb{R}^{N \times d}`.
        Y: The second vector :math:`Y \in \mathbb{R}^{M \times d}`.
        beta: The length-scale of the SE kernel.
        use_chunk: Whether to use chunking to reduce GPU memory usage. Note
            that setting it to ``True`` slows down the calculation.
            NOTE(review): currently ignored — the kernel is always computed in
            a single shot; the parameter is kept for signature compatibility.

    Returns:
        K: The kernel :math:`K \in \mathbb{R}^{N \times M}`.
    """
    assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."
    nx = ot.backend.get_backend(X, Y)
    # Squared pairwise distances, then the SE kernel, fully vectorized.
    K = cal_dist(X, Y)
    K = nx.exp(-beta * K)
    return K
############
# BioAlign #
############
def get_P(
    XnAHat: Union[np.ndarray, torch.Tensor],
    XnB: Union[np.ndarray, torch.Tensor],
    sigma2: Union[int, float, np.ndarray, torch.Tensor],
    beta2: Union[int, float, np.ndarray, torch.Tensor],
    alpha: Union[np.ndarray, torch.Tensor],
    gamma: Union[float, np.ndarray, torch.Tensor],
    Sigma: Union[np.ndarray, torch.Tensor],
    GeneDistMat: Union[np.ndarray, torch.Tensor],
    SpatialDistMat: Union[np.ndarray, torch.Tensor],
    samples_s: Optional[List[float]] = None,
    outlier_variance: Optional[float] = None,
) -> Tuple[Any, Any, Any]:
    """Calculate the generating probability matrix P.

    Args:
        XnAHat: Current spatial coordinate of sample A. Shape: N x D.
        XnB: Spatial coordinate of sample B (reference sample). Shape: M x D.
        sigma2: The spatial coordinate noise.
        beta2: The gene expression noise.
        alpha: A vector encoding each probability generated by the spots of sample A. Shape: N x 1.
        gamma: Inlier proportion of sample A.
        Sigma: The posterior covariance matrix of the Gaussian process. Shape: N x N or N x 1.
        GeneDistMat: The gene expression distance matrix between sample A and sample B. Shape: N x M.
        SpatialDistMat: The spatial coordinate distance matrix between sample A and sample B. Shape: N x M.
        samples_s: The space size of each sample (area for 2D, volume for 3D). If None,
            it is estimated from the larger bounding box of the two samples.
        outlier_variance: Factor that sharpens the spatial kernel used for
            inlier/outlier detection. If None, the plain spatial kernel is used.

    Returns:
        P: Generating probability matrix combining spatial and expression terms. Shape: N x M.
        spatial_P: Assignment probabilities from the spatial term only. Shape: N x M.
        sigma2_P: Assignment probabilities used for the sigma2 update. Shape: N x M.
    """
    assert XnAHat.shape[1] == XnB.shape[1], "XnAHat and XnB do not have the same number of features."
    assert XnAHat.shape[0] == alpha.shape[0], "XnAHat and alpha do not have the same length."
    assert XnAHat.shape[0] == Sigma.shape[0], "XnAHat and Sigma do not have the same length."
    nx = ot.backend.get_backend(XnAHat, XnB)
    NA, NB, D = XnAHat.shape[0], XnB.shape[0], XnAHat.shape[1]
    if samples_s is None:
        # Estimate the sample space size from the larger bounding box.
        samples_s = nx.maximum(
            _prod(nx)(nx.max(XnAHat, axis=0) - nx.min(XnAHat, axis=0)),
            _prod(nx)(nx.max(XnB, axis=0) - nx.min(XnB, axis=0)),
        )
    outlier_s = samples_s * NA
    # Spatial kernel; optionally sharpened for more aggressive outlier rejection.
    if outlier_variance is None:
        exp_SpatialMat = nx.exp(-SpatialDistMat / (2 * sigma2))
    else:
        exp_SpatialMat = nx.exp(-SpatialDistMat / (2 * sigma2 / outlier_variance))
    # Row-scale by alpha and the GP posterior-variance correction.
    spatial_term1 = nx.einsum(
        "ij,i->ij",
        exp_SpatialMat,
        (_mul(nx)(alpha, nx.exp(-Sigma / sigma2))),
    )
    # Uniform-outlier mass added to the normalizer of each column.
    spatial_outlier = _power(nx)((2 * _pi(nx) * sigma2), _data(nx, D / 2, XnAHat)) * (1 - gamma) / (gamma * outlier_s)
    spatial_term2 = spatial_outlier + nx.einsum("ij->j", spatial_term1)
    spatial_P = spatial_term1 / _unsqueeze(nx)(spatial_term2, 0)
    # Per-column probability that a point of B is an inlier.
    spatial_inlier = 1 - spatial_outlier / (spatial_outlier + nx.einsum("ij->j", exp_SpatialMat))
    # Full responsibility: spatial kernel x expression kernel, column-normalized.
    term1 = nx.einsum(
        "ij,i->ij",
        _mul(nx)(nx.exp(-SpatialDistMat / (2 * sigma2)), nx.exp(-GeneDistMat / (2 * beta2))),
        (_mul(nx)(alpha, nx.exp(-Sigma / sigma2))),
    )
    P = term1 / (_unsqueeze(nx)(nx.einsum("ij->j", term1), 0) + 1e-8)
    P = nx.einsum("j,ij->ij", spatial_inlier, P)
    # Spatial-only responsibility used to re-estimate sigma2.
    term1 = nx.einsum(
        "ij,i->ij",
        nx.exp(-SpatialDistMat / (2 * sigma2)),
        (_mul(nx)(alpha, nx.exp(-Sigma / sigma2))),
    )
    sigma2_P = term1 / (_unsqueeze(nx)(nx.einsum("ij->j", term1), 0) + 1e-8)
    sigma2_P = nx.einsum("j,ij->ij", spatial_inlier, sigma2_P)
    return P, spatial_P, sigma2_P
def get_P_chunk(
    XnAHat: Union[np.ndarray, torch.Tensor],
    XnB: Union[np.ndarray, torch.Tensor],
    X_A: Union[np.ndarray, torch.Tensor],
    X_B: Union[np.ndarray, torch.Tensor],
    sigma2: Union[int, float, np.ndarray, torch.Tensor],
    beta2: Union[int, float, np.ndarray, torch.Tensor],
    alpha: Union[np.ndarray, torch.Tensor],
    gamma: Union[float, np.ndarray, torch.Tensor],
    Sigma: Union[np.ndarray, torch.Tensor],
    samples_s: Optional[List[float]] = None,
    outlier_variance: Optional[float] = None,
    chunk_size: int = 1000,
    dissimilarity: str = "kl",
) -> Union[np.ndarray, torch.Tensor]:
    """Calculate the generating probability matrix P in memory-bounded chunks.

    Sample B is split into chunks along its rows; the distance matrices and the
    responsibility matrix are computed per chunk and concatenated, avoiding the
    full N x M intermediate matrices of :func:`get_P`.

    Args:
        XnAHat: Current spatial coordinate of sample A. Shape: N x D.
        XnB: Spatial coordinate of sample B (reference sample). Shape: M x D.
        X_A: Gene expression of sample A. Shape: N x G.
        X_B: Gene expression of sample B. Shape: M x G.
        sigma2: The spatial coordinate noise.
        beta2: The gene expression noise.
        alpha: A vector encoding each probability generated by the spots of sample A. Shape: N x 1.
        gamma: Inlier proportion of sample A.
        Sigma: The posterior covariance matrix of the Gaussian process. Shape: N x N or N x 1.
        samples_s: The space size of each sample (area for 2D, volume for 3D). If None,
            it is estimated from the larger bounding box of the two samples.
        outlier_variance: Factor that sharpens the spatial kernel used for
            inlier/outlier detection. If None, the plain spatial kernel is used.
        chunk_size: Number of sample-A points used to derive the chunk count.
        dissimilarity: Expression dissimilarity measure, ``'kl'`` or ``'euclidean'``.

    Returns:
        P: Generating probability matrix. Shape: N x M.
    """
    # Get the number of cells in each sample
    NA, NB = XnAHat.shape[0], XnB.shape[0]
    # Get the number of genes
    G = X_A.shape[1]
    # Get the number of spatial dimensions
    D = XnAHat.shape[1]
    chunk_num = int(np.ceil(NA / chunk_size))
    assert XnAHat.shape[1] == XnB.shape[1], "XnAHat and XnB do not have the same number of features."
    assert XnAHat.shape[0] == alpha.shape[0], "XnAHat and alpha do not have the same length."
    assert XnAHat.shape[0] == Sigma.shape[0], "XnAHat and Sigma do not have the same length."
    nx = ot.backend.get_backend(XnAHat, XnB)
    if samples_s is None:
        samples_s = nx.maximum(
            _prod(nx)(nx.max(XnAHat, axis=0) - nx.min(XnAHat, axis=0)),
            _prod(nx)(nx.max(XnB, axis=0) - nx.min(XnB, axis=0)),
        )
    outlier_s = samples_s * NA
    # Split sample B (expression and coordinates) into chunks along axis 0.
    X_Bs = _chunk(nx, X_B, chunk_num, dim=0)
    XnBs = _chunk(nx, XnB, chunk_num, dim=0)
    Ps = []
    for x_Bs, xnBs in zip(X_Bs, XnBs):
        # Per-chunk distance matrices (same math as get_P, restricted to the chunk).
        SpatialDistMat = cal_dist(XnAHat, xnBs)
        GeneDistMat = calc_exp_dissimilarity(X_A=X_A, X_B=x_Bs, dissimilarity=dissimilarity)
        if outlier_variance is None:
            exp_SpatialMat = nx.exp(-SpatialDistMat / (2 * sigma2))
        else:
            exp_SpatialMat = nx.exp(-SpatialDistMat / (2 * sigma2 / outlier_variance))
        spatial_term1 = nx.einsum(
            "ij,i->ij",
            exp_SpatialMat,
            (_mul(nx)(alpha, nx.exp(-Sigma / sigma2))),
        )
        spatial_outlier = (
            _power(nx)((2 * _pi(nx) * sigma2), _data(nx, D / 2, XnAHat)) * (1 - gamma) / (gamma * outlier_s)
        )
        spatial_inlier = 1 - spatial_outlier / (spatial_outlier + nx.einsum("ij->j", exp_SpatialMat))
        term1 = nx.einsum(
            "ij,i->ij",
            _mul(nx)(nx.exp(-SpatialDistMat / (2 * sigma2)), nx.exp(-GeneDistMat / (2 * beta2))),
            (_mul(nx)(alpha, nx.exp(-Sigma / sigma2))),
        )
        P = term1 / (_unsqueeze(nx)(nx.einsum("ij->j", term1), 0) + 1e-8)
        P = nx.einsum("j,ij->ij", spatial_inlier, P)
        Ps.append(P)
    # Stitch the per-chunk columns back together.
    P = nx.concatenate(Ps, axis=1)
    return P
def BA_align(
    sampleA: AnnData,
    sampleB: AnnData,
    genes: Optional[Union[List, torch.Tensor]] = None,
    spatial_key: str = "spatial",
    key_added: str = "align_spatial",
    iter_key_added: Optional[str] = "iter_spatial",
    vecfld_key_added: Optional[str] = "VecFld_morpho",
    layer: str = "X",
    dissimilarity: str = "kl",
    keep_size: bool = False,
    max_iter: int = 200,
    lambdaVF: Union[int, float] = 1e2,
    beta: Union[int, float] = 0.01,
    K: Union[int, float] = 15,
    normalize_c: bool = True,
    normalize_g: bool = True,
    select_high_exp_genes: Union[bool, float, int] = False,
    dtype: str = "float32",
    device: str = "cpu",
    inplace: bool = True,
    verbose: bool = True,
    nn_init: bool = True,
    SVI_mode: bool = True,
    batch_size: int = 1000,
    partial_robust_level: float = 25,
) -> Tuple[Optional[Tuple[AnnData, AnnData]], np.ndarray, np.ndarray]:
    """Align sample B onto sample A with the morpho (BioAlign) rigid + non-rigid model.

    Args:
        sampleA: Sample A that acts as reference.
        sampleB: Sample B that performs alignment.
        genes: Genes used for calculation. If None, use all common genes for calculation.
        spatial_key: The key in ``.obsm`` that corresponds to the raw spatial coordinate.
        key_added: ``.obsm`` key under which to add the aligned spatial coordinate.
        iter_key_added: ``.uns`` key under which to add the result of each iteration of the iterative process. If ``iter_key_added`` is None, the results are not saved.
        vecfld_key_added: The key that will be used for the vector field key in ``.uns``. If ``vecfld_key_added`` is None, the results are not saved.
        layer: If ``'X'``, uses ``.X`` to calculate dissimilarity between spots, otherwise uses the representation given by ``.layers[layer]``.
        dissimilarity: Expression dissimilarity measure: ``'kl'`` or ``'euclidean'``.
        keep_size: Whether to rescale the aligned coordinates so the covered area matches the area before alignment.
        max_iter: Max number of iterations for morpho alignment.
        lambdaVF: Hyperparameter that controls the non-rigid distortion degree. Smaller means more flexibility.
        beta: The length-scale of the SE kernel. Higher means more flexibility.
        K: The number of sparse inducing points used for the Nystrom approximation. Smaller means faster but less accurate.
        normalize_c: Whether to normalize spatial coordinates.
        normalize_g: Whether to normalize gene expression. If ``dissimilarity`` == ``'kl'``, ``normalize_g`` must be False.
        select_high_exp_genes: Whether to select genes with high differences in gene expression.
        dtype: The floating-point number type. Only ``float32`` and ``float64``.
        device: Equipment used to run the program. You can also set the specified GPU for running. ``E.g.: '0'``.
        inplace: Whether to copy adata or modify it inplace.
        verbose: If ``True``, print progress updates.
        nn_init: If ``True``, use nearest neighbor matching to initialize the alignment.
        SVI_mode: Whether to use stochastic variational inference (SVI) optimization strategy.
        batch_size: The size of the mini-batch of SVI. If set smaller, the calculation will be faster, but it will affect the accuracy, and vice versa. If not set, it is automatically set to one-tenth of the data size.
        partial_robust_level: The robust level of partial alignment. The larger the value, the more robust the alignment to partial cases is. Recommended setting from 1 to 50.

    Returns:
        A tuple of:
            - ``None`` if ``inplace`` is True, otherwise ``(sampleA, sampleB)``;
            - the full assignment matrix ``P.T`` as a numpy array;
            - the final ``sigma2`` as a numpy value.
    """
    empty_cache(device=device)
    # Preprocessing: KL dissimilarity requires unnormalized expression.
    normalize_g = False if dissimilarity == "kl" else normalize_g
    sampleA, sampleB = (sampleA, sampleB) if inplace else (sampleA.copy(), sampleB.copy())
    (
        nx,
        type_as,
        new_samples,
        exp_matrices,
        spatial_coords,
        normalize_scale_list,
        normalize_mean_list,
    ) = align_preprocess(
        samples=[sampleA, sampleB],
        layer=layer,
        genes=genes,
        spatial_key=spatial_key,
        normalize_c=normalize_c,
        normalize_g=normalize_g,
        select_high_exp_genes=select_high_exp_genes,
        dtype=dtype,
        device=device,
        verbose=verbose,
    )
    # Sample B is the moving sample (index 1), sample A the reference (index 0).
    coordsA, coordsB = spatial_coords[1], spatial_coords[0]
    X_A, X_B = exp_matrices[1], exp_matrices[0]
    del spatial_coords, exp_matrices
    NA, NB, D, G = coordsA.shape[0], coordsB.shape[0], coordsA.shape[1], X_A.shape[1]
    # Sub-sample very large inputs so the full gene-distance matrix stays tractable
    # (only relevant in SVI mode, where mini-batches are used anyway).
    sub_sample = False
    sub_sample_num = 15000
    if SVI_mode and (NA > sub_sample_num or NB > sub_sample_num):
        if NA > sub_sample_num:
            sub_idx_A = np.random.choice(NA, sub_sample_num, replace=False)
            sub_coordsA = coordsA[sub_idx_A, :]
            sub_X_A = X_A[sub_idx_A, :]
        else:
            sub_coordsA = coordsA
            sub_X_A = X_A
        if NB > sub_sample_num:
            sub_idx_B = np.random.choice(NB, sub_sample_num, replace=False)
            sub_coordsB = coordsB[sub_idx_B, :]
            sub_X_B = X_B[sub_idx_B, :]
        else:
            sub_coordsB = coordsB
            sub_X_B = X_B
        GeneDistMat = calc_exp_dissimilarity(X_A=sub_X_A, X_B=sub_X_B, dissimilarity=dissimilarity)
        sub_sample = True
    else:
        GeneDistMat = calc_exp_dissimilarity(X_A=X_A, X_B=X_B, dissimilarity=dissimilarity)
    area = _prod(nx)(nx.max(coordsA, axis=0) - nx.min(coordsA, axis=0))
    if nn_init:
        # Perform coarse rigid alignment from expression-based point matches.
        if sub_sample:
            _cra_kwargs = dict(
                coordsA=sub_coordsA,
                coordsB=sub_coordsB,
                X_A=sub_X_A,
                X_B=sub_X_B,
                transformed_points=coordsA,
            )
        else:
            _cra_kwargs = dict(
                coordsA=coordsA,
                coordsB=coordsB,
                X_A=X_A,
                X_B=X_B,
                transformed_points=None,
            )
        coordsA, inlier_A, inlier_B, inlier_P, init_R, init_t = coarse_rigid_alignment(
            dissimilarity=dissimilarity, top_K=10, verbose=verbose, **_cra_kwargs
        )
        empty_cache(device=device)
        coordsA = _data(nx, coordsA, type_as)
        inlier_A = _data(nx, inlier_A, type_as)
        inlier_B = _data(nx, inlier_B, type_as)
        inlier_P = _data(nx, inlier_P, type_as)
    # NOTE(review): if ``nn_init`` is False, ``inlier_A/B/P`` and ``init_R/init_t``
    # are never assigned and their later uses raise NameError — confirm callers
    # always pass nn_init=True.
    coarse_alignment = coordsA
    # Randomly select up to K unique control points for the Nystrom approximation.
    Unique_coordsA = _unique(nx, coordsA, 0)
    idx = random.sample(range(Unique_coordsA.shape[0]), min(K, Unique_coordsA.shape[0]))
    ctrl_pts = Unique_coordsA[idx, :]
    K = ctrl_pts.shape[0]
    # Construct the SE kernel among / against the control points.
    GammaSparse = con_K(ctrl_pts, ctrl_pts, beta)
    U = con_K(coordsA, ctrl_pts, beta)
    # Initialize variational parameters.
    kappa = nx.ones((NA), type_as=type_as)
    alpha = nx.ones((NA), type_as=type_as)
    VnA = nx.zeros(coordsA.shape, type_as=type_as)
    Coff = nx.zeros(ctrl_pts.shape, type_as=type_as)
    gamma, gamma_a, gamma_b = (
        _data(nx, 0.5, type_as),
        _data(nx, 1.0, type_as),
        _data(nx, 1.0, type_as),
    )
    # NOTE(review): ``minP`` and ``erc`` are never read below; kept for parity.
    minP, sigma2_terc, erc = (
        _data(nx, 1e-5, type_as),
        _data(nx, 1, type_as),
        _data(nx, 1e-4, type_as),
    )
    SigmaDiag = nx.zeros((NA), type_as=type_as)
    XAHat, RnA = coordsA, coordsA
    if sub_sample:
        SpatialDistMat = cal_dist(sub_coordsA, sub_coordsB)
        del sub_coordsA, sub_coordsB
    else:
        SpatialDistMat = cal_dist(XAHat, coordsB)
    sigma2 = 0.1 * nx.sum(SpatialDistMat) / (D * NA * NB)  # 2 for 3D
    s = _data(nx, 1, type_as)
    R = _identity(nx, D, type_as)
    minGeneDistMat = nx.min(GeneDistMat, 1)
    # Automatically determine the value of beta2 from the 5th-percentile
    # nearest-neighbor expression distance.
    beta2 = minGeneDistMat[nx.argsort(minGeneDistMat)[int(GeneDistMat.shape[0] * 0.05)]] / 5
    beta2_end = nx.max(minGeneDistMat) / 5
    del minGeneDistMat
    if sub_sample:
        del sub_X_A, sub_X_B, GeneDistMat
    # The value of beta2 becomes progressively larger over the iterations.
    beta2 = nx.maximum(beta2, _data(nx, 1e-2, type_as))
    beta2_decrease = _power(nx)(beta2_end / beta2, 1 / (50))
    # Use smaller spatial variance at first to reduce tails; relax it gradually.
    outlier_variance = 1
    max_outlier_variance = partial_robust_level  # 20
    outlier_variance_decrease = _power(nx)(_data(nx, max_outlier_variance, type_as), 1 / (max_iter / 2))
    if SVI_mode:
        SVI_decay = _data(nx, 10.0, type_as)
        # Select a random subset of data as the first mini-batch.
        batch_size = min(max(int(NB / 10), batch_size), NB)
        randomidx = _randperm(nx)(NB)
        randIdx = randomidx[:batch_size]
        # BUGFIX: the rolled permutation must be written back to ``randomidx``
        # (it was previously assigned to a dead variable ``randomIdx``), otherwise
        # the first two mini-batches are identical.
        randomidx = _roll(nx)(randomidx, batch_size)
        randcoordsB = coordsB[randIdx, :]  # batch_size x D
        if sub_sample:
            randGeneDistMat = calc_exp_dissimilarity(X_A=X_A, X_B=X_B[randIdx, :], dissimilarity=dissimilarity)
            SpatialDistMat = cal_dist(coordsA, randcoordsB)
        else:
            randGeneDistMat = GeneDistMat[:, randIdx]  # NA x batch_size
            SpatialDistMat = SpatialDistMat[:, randIdx]  # NA x batch_size
        Sp, Sp_spatial, Sp_sigma2 = 0, 0, 0
        SigmaInv = nx.zeros((K, K), type_as=type_as)  # K x K
        PXB_term = nx.zeros((NA, D), type_as=type_as)  # NA x D
    iteration = (
        lm.progress_logger(range(max_iter), progress_name="Start morpho alignment") if verbose else range(max_iter)
    )
    if iter_key_added is not None:
        sampleB.uns[iter_key_added] = dict()
        sampleB.uns[iter_key_added][key_added] = {}
        sampleB.uns[iter_key_added]["sigma2"] = {}
        sampleB.uns[iter_key_added]["beta2"] = {}
        sampleB.uns[iter_key_added]["scale"] = {}
    for cur_iter in iteration:
        # Optionally record per-iteration state (in the original coordinate scale).
        if iter_key_added is not None:
            iter_XAHat = XAHat * normalize_scale_list[0] + normalize_mean_list[0] if normalize_c else XAHat
            sampleB.uns[iter_key_added][key_added][cur_iter] = nx.to_numpy(iter_XAHat)
            sampleB.uns[iter_key_added]["sigma2"][cur_iter] = nx.to_numpy(sigma2)
            sampleB.uns[iter_key_added]["beta2"][cur_iter] = nx.to_numpy(beta2)
            sampleB.uns[iter_key_added]["scale"][cur_iter] = nx.to_numpy(s)
        # E-step: compute the assignment matrices (on the mini-batch in SVI mode).
        if SVI_mode:
            step_size = nx.minimum(_data(nx, 1.0, type_as), SVI_decay / (cur_iter + 1.0))
            P, spatial_P, sigma2_P = get_P(
                XnAHat=XAHat,
                XnB=randcoordsB,
                sigma2=sigma2,
                beta2=beta2,
                alpha=alpha,
                gamma=gamma,
                Sigma=SigmaDiag,
                GeneDistMat=randGeneDistMat,
                SpatialDistMat=SpatialDistMat,
                outlier_variance=outlier_variance,
            )
        else:
            P, spatial_P, sigma2_P = get_P(
                XnAHat=XAHat,
                XnB=coordsB,
                sigma2=sigma2,
                beta2=beta2,
                alpha=alpha,
                gamma=gamma,
                Sigma=SigmaDiag,
                GeneDistMat=GeneDistMat,
                SpatialDistMat=SpatialDistMat,
                outlier_variance=outlier_variance,
            )
        # Anneal beta2 towards beta2_end and relax the outlier variance.
        if cur_iter > 5:
            beta2 = (
                nx.maximum(beta2 * beta2_decrease, beta2_end)
                if beta2_decrease < 1
                else nx.minimum(beta2 * beta2_decrease, beta2_end)
            )
            outlier_variance = nx.minimum(outlier_variance * outlier_variance_decrease, max_outlier_variance)
        # Row/column marginals of the three responsibility matrices.
        K_NA = nx.einsum("ij->i", P)
        K_NB = nx.einsum("ij->j", P)
        K_NA_spatial = nx.einsum("ij->i", spatial_P)
        K_NB_spatial = nx.einsum("ij->j", spatial_P)
        K_NA_sigma2 = nx.einsum("ij->i", sigma2_P)
        K_NB_sigma2 = nx.einsum("ij->j", sigma2_P)
        # Update gamma (inlier proportion), SVI-averaged in mini-batch mode.
        if SVI_mode:
            Sp = step_size * nx.einsum("ij->", P) + (1 - step_size) * Sp
            Sp_spatial = step_size * nx.einsum("ij->", spatial_P) + (1 - step_size) * Sp_spatial
            Sp_sigma2 = step_size * nx.einsum("ij->", sigma2_P) + (1 - step_size) * Sp_sigma2
            gamma = nx.exp(_psi(nx)(gamma_a + Sp_spatial) - _psi(nx)(gamma_a + gamma_b + batch_size))
        else:
            Sp = nx.einsum("ij->", P)
            Sp_spatial = nx.einsum("ij->", spatial_P)
            Sp_sigma2 = nx.einsum("ij->", sigma2_P)
            gamma = nx.exp(_psi(nx)(gamma_a + Sp_spatial) - _psi(nx)(gamma_a + gamma_b + NB))
        # Clamp gamma away from degenerate values.
        gamma = _data(nx, 0.99, type_as) if gamma > 0.99 else gamma
        gamma = _data(nx, 0.01, type_as) if gamma < 0.01 else gamma
        # Update alpha (per-spot generation probabilities).
        alpha = nx.exp(_psi(nx)(kappa + K_NA_spatial) - _psi(nx)(kappa * NA + Sp_spatial))
        # Update the non-rigid field VnA, but only once the rigid part has
        # roughly converged (or late in the optimization).
        if (sigma2 < 0.015 and s > 0.95) or (cur_iter > 80):
            if SVI_mode:
                SigmaInv = (
                    step_size * (sigma2 * lambdaVF * GammaSparse + _dot(nx)(U.T, nx.einsum("ij,i->ij", U, K_NA)))
                    + (1 - step_size) * SigmaInv
                )
                term1 = _dot(nx)(_pinv(nx)(SigmaInv), U.T)
                PXB_term = (
                    step_size * (_dot(nx)(P, randcoordsB) - nx.einsum("ij,i->ij", RnA, K_NA))
                    + (1 - step_size) * PXB_term
                )
                Coff = _dot(nx)(term1, PXB_term)
                VnA = _dot(nx)(
                    U,
                    Coff,
                )
                SigmaDiag = sigma2 * nx.einsum("ij->i", nx.einsum("ij,ji->ij", U, term1))
            else:
                term1 = _dot(nx)(
                    _pinv(nx)(sigma2 * lambdaVF * GammaSparse + _dot(nx)(U.T, nx.einsum("ij,i->ij", U, K_NA))),
                    U.T,
                )
                SigmaDiag = sigma2 * nx.einsum("ij->i", nx.einsum("ij,ji->ij", U, term1))
                Coff = _dot(nx)(term1, (_dot(nx)(P, coordsB) - nx.einsum("ij,i->ij", RnA, K_NA)))
                VnA = _dot(nx)(
                    U,
                    Coff,
                )
        # Update the rigid transformation (R, t), regularized by the coarse inliers.
        lambdaReg = 1e0 * Sp / nx.sum(inlier_P)
        if SVI_mode:
            PXA, PVA, PXB = (
                _dot(nx)(K_NA, coordsA)[None, :],
                _dot(nx)(K_NA, VnA)[None, :],
                _dot(nx)(K_NB, randcoordsB)[None, :],
            )
        else:
            PXA, PVA, PXB = (
                _dot(nx)(K_NA, coordsA)[None, :],
                _dot(nx)(K_NA, VnA)[None, :],
                _dot(nx)(K_NB, coordsB)[None, :],
            )
        PCYC, PCXC = _dot(nx)(inlier_P.T, inlier_B), _dot(nx)(inlier_P.T, inlier_A)
        if SVI_mode and cur_iter > 1:
            t = (
                step_size
                * (
                    ((PXB - PVA - _dot(nx)(PXA, R.T)) + 2 * lambdaReg * sigma2 * (PCYC - _dot(nx)(PCXC, R.T)))
                    / (Sp + 2 * lambdaReg * sigma2 * nx.sum(inlier_P))
                )
                + (1 - step_size) * t
            )
        else:
            t = ((PXB - PVA - _dot(nx)(PXA, R.T)) + 2 * lambdaReg * sigma2 * (PCYC - _dot(nx)(PCXC, R.T))) / (
                Sp + 2 * lambdaReg * sigma2 * nx.sum(inlier_P)
            )
        if SVI_mode:
            A = -(
                _dot(nx)(PXA.T, t)
                + _dot(nx)(coordsA.T, nx.einsum("ij,i->ij", VnA, K_NA) - _dot(nx)(P, randcoordsB))
                + 2
                * lambdaReg
                * sigma2
                * (_dot(nx)(PCXC.T, t) - _dot(nx)(nx.einsum("ij,i->ij", inlier_A, inlier_P[:, 0]).T, inlier_B))
            ).T
        else:
            A = -(
                _dot(nx)(PXA.T, t)
                + _dot(nx)(coordsA.T, nx.einsum("ij,i->ij", VnA, K_NA) - _dot(nx)(P, coordsB))
                + 2
                * lambdaReg
                * sigma2
                * (_dot(nx)(PCXC.T, t) - _dot(nx)(nx.einsum("ij,i->ij", inlier_A, inlier_P[:, 0]).T, inlier_B))
            ).T
        # Closest rotation via SVD, with determinant correction against reflections.
        svdU, svdS, svdV = _linalg(nx).svd(A)
        C = _identity(nx, D, type_as)
        C[-1, -1] = _linalg(nx).det(_dot(nx)(svdU, svdV))
        if SVI_mode and cur_iter > 1:
            R = step_size * (_dot(nx)(_dot(nx)(svdU, C), svdV)) + (1 - step_size) * R
        else:
            R = _dot(nx)(_dot(nx)(svdU, C), svdV)
        RnA = s * _dot(nx)(coordsA, R.T) + t
        XAHat = RnA + VnA
        # Update sigma2 from the current residuals.
        if SVI_mode:
            SpatialDistMat = cal_dist(XAHat, randcoordsB)
        else:
            SpatialDistMat = cal_dist(XAHat, coordsB)
        sigma2_old = sigma2
        sigma2 = nx.maximum(
            (
                nx.einsum("ij,ij", sigma2_P, SpatialDistMat) / (D * Sp_sigma2)
                + nx.einsum("i,i", K_NA_sigma2, SigmaDiag) / Sp_sigma2
            ),
            _data(nx, 1e-3, type_as),
        )
        sigma2_terc = nx.abs((sigma2 - sigma2_old) / sigma2)
        # Draw the next mini-batch for the following iteration.
        if SVI_mode and cur_iter < max_iter - 1:
            randIdx = randomidx[:batch_size]
            randomidx = _roll(nx)(randomidx, batch_size)
            randcoordsB = coordsB[randIdx, :]
            if sub_sample:
                randGeneDistMat = calc_exp_dissimilarity(X_A=X_A, X_B=X_B[randIdx, :], dissimilarity=dissimilarity)
            else:
                randGeneDistMat = GeneDistMat[:, randIdx]  # NA x batch_size
            SpatialDistMat = cal_dist(XAHat, randcoordsB)
        empty_cache(device=device)
    # Recompute P over the full data (SVI only ever saw mini-batches).
    if SVI_mode:
        P = get_P_chunk(
            XnAHat=XAHat,
            XnB=coordsB,
            X_A=X_A,
            X_B=X_B,
            sigma2=sigma2,
            beta2=beta2,
            alpha=alpha,
            gamma=gamma,
            Sigma=SigmaDiag,
            outlier_variance=outlier_variance,
        )
    # Get optimal rigid transformation.
    optimal_RnA, optimal_R, optimal_t = get_optimal_R(
        coordsA=coordsA,
        coordsB=coordsB,
        P=P,
        R_init=R,
    )
    if verbose:
        lm.main_info(f"Key Parameters: gamma: {gamma}; beta2: {beta2}; sigma2: {sigma2}")
    if keep_size:
        # Rescale so the aligned coordinates cover the pre-alignment area.
        area_after = _prod(nx)(nx.max(XAHat, axis=0) - nx.min(XAHat, axis=0))
        XAHat = XAHat * (area / area_after)
    if normalize_c:
        # Map all coordinates back to the original (unnormalized) scale.
        XAHat = XAHat * normalize_scale_list[0] + normalize_mean_list[0]
        RnA = RnA * normalize_scale_list[0] + normalize_mean_list[0]
        optimal_RnA = optimal_RnA * normalize_scale_list[0] + normalize_mean_list[0]
        coarse_alignment = coarse_alignment * normalize_scale_list[0] + normalize_mean_list[0]
    # Save aligned coordinates.
    sampleB.obsm["Nonrigid_align_spatial"] = nx.to_numpy(XAHat).copy()
    sampleB.obsm["Rigid_align_spatial"] = nx.to_numpy(optimal_RnA).copy()
    # Save vector field and other parameters.
    if vecfld_key_added is not None:
        sampleB.uns[vecfld_key_added] = {
            "s": nx.to_numpy(s),
            "R": nx.to_numpy(R),
            "t": nx.to_numpy(t),
            "optimal_R": nx.to_numpy(optimal_R),
            "optimal_t": nx.to_numpy(optimal_t),
            "init_R": init_R,
            "init_t": init_t,
            "beta": beta,
            "Coff": nx.to_numpy(Coff),
            "ctrl_pts": nx.to_numpy(ctrl_pts),
            "normalize_scale": nx.to_numpy(normalize_scale_list[0]) if normalize_c else None,
            "normalize_mean_list": [nx.to_numpy(normalize_mean) for normalize_mean in normalize_mean_list]
            if normalize_c
            else None,
            "normalize_c": normalize_c,
            "dissimilarity": dissimilarity,
            # BUGFIX: sigma2 was previously stored under the "beta2" key.
            "beta2": nx.to_numpy(beta2),
            "sigma2": nx.to_numpy(sigma2),
            "gamma": nx.to_numpy(gamma),
            "NA": NA,
            "outlier_variance": nx.to_numpy(outlier_variance),
        }
    empty_cache(device=device)
    return (
        None if inplace else (sampleA, sampleB),
        nx.to_numpy(P.T),
        nx.to_numpy(sigma2),
    )