import random
import numpy as np
import ot
import torch
from anndata import AnnData
try:
    from typing import Any, Dict, List, Literal, Optional, Tuple, Union
except ImportError:
    from typing_extensions import Literal
    from typing import Any, Dict, List, Optional, Tuple, Union
import pandas as pd
from spateo.logging import logger_manager as lm
from .morpho_sparse_utils import (
_init_guess_beta2,
_init_guess_sigma2,
calc_distance,
calc_P_related,
get_optimal_R_sparse,
)
from .utils import (
_data,
_dot,
_identity,
_linalg,
_mul,
_pi,
_pinv,
_power,
_prod,
_psi,
_randperm,
_roll,
_unique,
align_preprocess,
cal_dist,
coarse_rigid_alignment,
empty_cache,
guidance_pair_preprocess,
)
# construct kernel
def con_K(
X: Union[np.ndarray, torch.Tensor],
Y: Union[np.ndarray, torch.Tensor],
beta: Union[int, float] = 0.01,
use_chunk: bool = False,
) -> Union[np.ndarray, torch.Tensor]:
"""con_K constructs the Squared Exponential (SE) kernel, where K(i,j)=k(X_i,Y_j)=exp(-beta*||X_i-Y_j||^2).
Args:
X: The first vector X\in\mathbb{R}^{N\times d}
Y: The second vector X\in\mathbb{R}^{M\times d}
beta: The length-scale of the SE kernel.
use_chunk (bool, optional): Whether to use chunk to reduce the GPU memory usage. Note that if set to ``True'' it will slow down the calculation. Defaults to False.
Returns:
K: The kernel K\in\mathbb{R}^{N\times M}
"""
assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."
nx = ot.backend.get_backend(X, Y)
K = cal_dist(X, Y)
K = nx.exp(-beta * K)
return K
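# Illustrative sketch (added for exposition; not part of the original API): a plain
# NumPy version of the SE kernel above, useful for checking con_K on small inputs.
# The helper name _se_kernel_numpy is hypothetical.
def _se_kernel_numpy(X: np.ndarray, Y: np.ndarray, beta: float = 0.01) -> np.ndarray:
    """Compute K[i, j] = exp(-beta * ||X_i - Y_j||^2) via broadcasting."""
    sq_dist = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)  # N x M squared distances
    return np.exp(-beta * sq_dist)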
# get the assignment matrix P
def get_P_sparse(
XnAHat: Union[np.ndarray, torch.Tensor],
XnB: Union[np.ndarray, torch.Tensor],
X_A: Union[np.ndarray, torch.Tensor],
X_B: Union[np.ndarray, torch.Tensor],
sigma2: Union[int, float, np.ndarray, torch.Tensor],
beta2: Union[int, float, np.ndarray, torch.Tensor],
alpha: Union[np.ndarray, torch.Tensor],
gamma: Union[float, np.ndarray, torch.Tensor],
Sigma: Union[np.ndarray, torch.Tensor],
samples_s: Optional[List[float]] = None,
    outlier_variance: Optional[float] = None,
label_mask: Optional[np.ndarray] = None,
batch_capacity: int = 1,
labelA: Optional[pd.Series] = None,
labelB: Optional[pd.Series] = None,
label_transfer_prior: Optional[dict] = None,
top_k: int = 1024,
dissimilarity: str = "kl",
):
assert XnAHat.shape[1] == XnB.shape[1], "XnAHat and XnB do not have the same number of features."
assert XnAHat.shape[0] == alpha.shape[0], "XnAHat and alpha do not have the same length."
assert XnAHat.shape[0] == Sigma.shape[0], "XnAHat and Sigma do not have the same length."
nx = ot.backend.get_backend(XnAHat, XnB)
NA, NB, D = XnAHat.shape[0], XnB.shape[0], XnAHat.shape[1]
if samples_s is None:
samples_s = nx.maximum(
_prod(nx)(nx.max(XnAHat, axis=0) - nx.min(XnAHat, axis=0)),
_prod(nx)(nx.max(XnB, axis=0) - nx.min(XnB, axis=0)),
)
outlier_s = samples_s * nx.sum(label_mask, axis=0) if label_mask is not None else samples_s * NA
spatial_outlier = _power(nx)((2 * _pi(nx) * sigma2), _data(nx, D / 2, XnAHat)) * (1 - gamma) / (gamma * outlier_s)
K_NA_spatial, K_NA_sigma2, P, sigma2_temp = calc_P_related(
XnAHat=XnAHat,
XnB=XnB,
X_A=X_A,
X_B=X_B,
labelA=labelA,
labelB=labelB,
label_transfer_prior=label_transfer_prior,
sigma2=sigma2,
sigma2_robust=sigma2 / outlier_variance,
beta2=beta2,
spatial_outlier=spatial_outlier,
col_mul=(_mul(nx)(alpha, nx.exp(-Sigma / sigma2))),
batch_capacity=batch_capacity,
top_k=top_k,
dissimilarity=dissimilarity,
)
K_NA = P.sum(1).to_dense()
K_NB = P.sum(0).to_dense()
Sp = P.sum()
Sp_spatial = K_NA_spatial.sum()
Sp_sigma2 = K_NA_sigma2.sum()
assignment_results = {
"K_NA": K_NA,
"K_NB": K_NB,
"K_NA_spatial": K_NA_spatial,
"K_NA_sigma2": K_NA_sigma2,
"Sp": Sp,
"Sp_spatial": Sp_spatial,
"Sp_sigma2": Sp_sigma2,
"sigma2_temp": sigma2_temp,
}
return P, assignment_results
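# Illustrative sketch (added for exposition): the `spatial_outlier` term above appears to
# mirror the uniform-outlier normalization constant used in CPD-style mixture models,
# (2*pi*sigma2)^(D/2) * (1 - gamma) / (gamma * V), where V is the bounding-box volume
# assigned to the outlier component. A plain NumPy version, with hypothetical names:
def _spatial_outlier_numpy(sigma2: float, D: int, gamma: float, outlier_s: float) -> float:
    return (2 * np.pi * sigma2) ** (D / 2) * (1 - gamma) / (gamma * outlier_s)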
# morpho pairwise alignment
# TODO: compute the gene distance matrix once and keep it on the CPU; move chunks to the GPU only when they are used.
def BA_align_sparse(
sampleA: AnnData,
sampleB: AnnData,
genes: Optional[Union[List, torch.Tensor]] = None,
spatial_key: str = "spatial",
key_added: str = "align_spatial",
iter_key_added: Optional[str] = None,
vecfld_key_added: Optional[str] = "VecFld_morpho",
layer: str = "X",
use_rep: Optional[str] = None,
dissimilarity: str = "kl",
max_iter: int = 200,
lambdaVF: Union[int, float] = 1e2,
beta: Union[int, float] = 0.01,
K: Union[int, float] = 15,
beta2: Optional[Union[int, float]] = None,
beta2_end: Optional[Union[int, float]] = None,
sigma2_init: float = 0.1,
normalize_c: bool = True,
normalize_g: bool = True,
dtype: str = "float32",
device: str = "cpu",
inplace: bool = True,
verbose: bool = True,
nn_init: bool = False,
partial_robust_level: float = 25,
use_label_prior: bool = False,
label_key: Optional[str] = "cluster",
label_transfer_prior: Optional[dict] = None,
SVI_mode: bool = True,
batch_size: int = 1024,
use_sparse: bool = True,
pre_compute_dist: bool = False,
batch_capacity: int = 1,
guidance_pair: Optional[list] = None,
guidance_effect: Optional[Union[bool, str]] = False,
guidance_epsilon: float = 1,
) -> Tuple[Optional[Tuple[AnnData, AnnData]], np.ndarray, np.ndarray]:
empty_cache(device=device)
# Preprocessing and extract the spatial and expression information
normalize_g = False if dissimilarity == "kl" else normalize_g
sampleA, sampleB = (sampleA, sampleB) if inplace else (sampleA.copy(), sampleB.copy())
(
nx,
type_as,
new_samples,
exp_matrices,
spatial_coords,
normalize_scale_list,
normalize_mean_list,
) = align_preprocess(
samples=[sampleA, sampleB],
layer=layer,
genes=genes,
spatial_key=spatial_key,
normalize_c=normalize_c,
normalize_g=normalize_g,
dtype=dtype,
device=device,
verbose=verbose,
use_rep=use_rep,
)
    # normalize the guidance pairs and convert them to the correct data types
    if (guidance_effect is not False) and (guidance_pair is not None):
        guidance_pair = guidance_pair_preprocess(guidance_pair, normalize_scale_list, normalize_mean_list, nx, type_as)
        X_AI = guidance_pair[0]
        X_BI = guidance_pair[1]
    # note the swap: the second sample's coordinates are treated as the moving point set (A)
    # and the first sample's as the reference (B)
    coordsA, coordsB = spatial_coords[1], spatial_coords[0]
    X_A, X_B = exp_matrices[1], exp_matrices[0]
    del spatial_coords, exp_matrices
    NA, NB, D, G = coordsA.shape[0], coordsB.shape[0], coordsA.shape[1], X_A.shape[1]
# generate label mask by label consistency prior
    if use_label_prior:
        # check the label key
        if label_key not in sampleA.obs.keys():
            raise ValueError(f"sampleA does not have label key {label_key}.")
        if label_key not in sampleB.obs.keys():
            raise ValueError(f"sampleB does not have label key {label_key}.")
        # the A/B swap here mirrors the coordinate assignment above
        labelA = pd.Series(sampleB.obs[label_key].values)
        labelB = pd.Series(sampleA.obs[label_key].values)
else:
labelA, labelB = None, None
# perform coarse rigid alignment
if nn_init:
_cra_kwargs = dict(
coordsA=coordsA,
coordsB=coordsB,
X_A=X_A,
X_B=X_B,
transformed_points=None,
)
coordsA, inlier_A, inlier_B, inlier_P, init_R, init_t = coarse_rigid_alignment(
dissimilarity=dissimilarity, top_K=10, verbose=verbose, **_cra_kwargs
)
empty_cache(device=device)
init_R = _data(nx, init_R, type_as)
init_t = _data(nx, init_t, type_as)
coordsA = _data(nx, coordsA, type_as)
inlier_A = _data(nx, inlier_A, type_as)
inlier_B = _data(nx, inlier_B, type_as)
inlier_P = _data(nx, inlier_P, type_as)
inlier_P = inlier_P * batch_size / nx.sum(inlier_P)
inlier_R = inlier_A
inlier_V = nx.zeros(inlier_A.shape, type_as=type_as)
# TO-DO: integrate into one function
else:
init_R = nx.eye(D, type_as=type_as)
init_t = nx.zeros((D,), type_as=type_as)
inlier_A = []
inlier_B = []
inlier_P = []
if (guidance_effect is not False) and (guidance_pair is not None):
X_AI = X_AI @ init_R.T + init_t
if len(inlier_A) == 0:
inlier_A = X_AI
inlier_B = X_BI
inlier_P = nx.ones((X_AI.shape[0], 1), type_as=type_as)
else:
inlier_A = nx.concatenate([inlier_A, X_AI], axis=0)
inlier_B = nx.concatenate([inlier_B, X_BI], axis=0)
inlier_P = nx.concatenate([inlier_P, nx.ones((X_AI.shape[0], 1), type_as=type_as)], axis=0)
inlier_R = inlier_A
inlier_V = nx.zeros(inlier_A.shape, type_as=type_as)
inlier_AHat = inlier_A
coarse_alignment = coordsA
# construct kernel for inducing variables
Unique_coordsA = _unique(nx, coordsA, 0)
idx = random.sample(range(Unique_coordsA.shape[0]), min(K, Unique_coordsA.shape[0]))
ctrl_pts = Unique_coordsA[idx, :]
K = ctrl_pts.shape[0]
GammaSparse = con_K(ctrl_pts, ctrl_pts, beta)
U = con_K(coordsA, ctrl_pts, beta)
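    # Note (added): this is a low-rank (Nystrom-style) approximation of the full kernel.
    # GammaSparse is the K x K kernel among the sampled control (inducing) points and U is
    # the NA x K cross-kernel, so the non-rigid displacement is later evaluated cheaply as
    # VnA = U @ Coff instead of through the full NA x NA kernel.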
if guidance_effect == "nonrigid":
inlier_U = con_K(inlier_A, ctrl_pts, beta)
kernel_dict = {
"dist": "cdist",
"X": nx.to_numpy(coordsA),
"idx": idx,
"U": nx.to_numpy(U),
"GammaSparse": nx.to_numpy(GammaSparse),
"ctrl_pts": nx.to_numpy(ctrl_pts),
}
# Initialize optimization parameters
kappa = nx.ones((NA), type_as=type_as)
alpha = nx.ones((NA), type_as=type_as)
VnA = nx.zeros(coordsA.shape, type_as=type_as)
Coff = nx.zeros(ctrl_pts.shape, type_as=type_as)
gamma, gamma_a, gamma_b = (
_data(nx, 0.5, type_as),
_data(nx, 1.0, type_as),
_data(nx, 1.0, type_as),
)
minP, sigma2_terc, erc = (
_data(nx, 1e-5, type_as),
_data(nx, 1, type_as),
_data(nx, 1e-4, type_as),
)
SigmaDiag = nx.zeros((NA), type_as=type_as)
XAHat, RnA = coordsA, coordsA
s = _data(nx, 1, type_as)
R = _identity(nx, D, type_as)
# calculate the initial values of sigma2 and beta2
sigma2 = sigma2_init * _init_guess_sigma2(XAHat, coordsB)
beta2, beta2_end = _init_guess_beta2(nx, X_A, X_B, dissimilarity, partial_robust_level, beta2, beta2_end)
empty_cache(device=device)
    # initialize the sigma2 and beta2 annealing schedules for better performance
    outlier_variance = 1
    max_outlier_variance = partial_robust_level
    outlier_variance_decrease = _power(nx)(_data(nx, max_outlier_variance, type_as), 1 / (max_iter / 2))
    # geometric annealing factor that moves beta2 from its initial value to beta2_end over ~50 iterations
    beta2_decrease = _power(nx)(beta2_end / beta2, 1 / 50)
# Initial calculation of the gene and spatial similarity (distance) matrix
spatial_threshold = 6 * sigma2
    # if pre_compute_dist, compute and store the full NA x NB expression similarity matrix; otherwise it is recomputed in each iteration
if pre_compute_dist:
GeneDistMat = calc_distance(
X_A=X_A,
X_B=X_B,
metric=dissimilarity,
use_sparse=use_sparse,
sparse_method="topk",
threshold=1000,
)
    if SVI_mode:
        SVI_decay = _data(nx, 10.0, type_as)
        # select a random subset of the data for the first mini-batch
        batch_size = min(max(int(NB / 10), batch_size), NB)
        randomidx = _randperm(nx)(NB)
        randIdx = randomidx[:batch_size]
        randomidx = _roll(nx)(randomidx, batch_size)
        randcoordsB = coordsB[randIdx, :]  # batch_size x D
        randX_B = X_B[randIdx, :]  # batch_size x G
        randlabelB = labelB.iloc[np.array(randIdx)] if labelB is not None else None
Sp, Sp_spatial, Sp_sigma2 = 0, 0, 0
SigmaInv = nx.zeros((K, K), type_as=type_as) # K x K
PXB_term = nx.zeros((NA, D), type_as=type_as) # NA x D
iteration = (
lm.progress_logger(range(max_iter), progress_name="Start morpho alignment") if verbose else range(max_iter)
)
# intermediate results
if iter_key_added is not None:
sampleB.uns[iter_key_added] = dict()
sampleB.uns[iter_key_added][key_added] = {}
sampleB.uns[iter_key_added]["sigma2"] = {}
sampleB.uns[iter_key_added]["beta2"] = {}
sampleB.uns[iter_key_added]["scale"] = {}
# main iteration begin
for iter in iteration:
# save intermediate results
if iter_key_added is not None:
iter_XAHat = XAHat * normalize_scale_list[0] + normalize_mean_list[0] if normalize_c else XAHat
sampleB.uns[iter_key_added][key_added][iter] = nx.to_numpy(iter_XAHat)
sampleB.uns[iter_key_added]["sigma2"][iter] = nx.to_numpy(sigma2)
sampleB.uns[iter_key_added]["beta2"][iter] = nx.to_numpy(beta2)
sampleB.uns[iter_key_added]["scale"][iter] = nx.to_numpy(s)
# update the assignment matrix
if SVI_mode:
            step_size = nx.minimum(_data(nx, 1.0, type_as), SVI_decay / (iter + 1.0))
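            # Note (added): step_size = min(1, decay / (iter + 1)) is a standard
            # stochastic-approximation schedule that weights each mini-batch update
            # less as the iteration count grows.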
P, assignment_results = get_P_sparse(
XnAHat=XAHat,
XnB=randcoordsB,
X_A=X_A,
X_B=randX_B,
labelA=labelA,
labelB=randlabelB,
sigma2=sigma2,
beta2=beta2,
alpha=alpha,
gamma=gamma,
Sigma=SigmaDiag,
outlier_variance=outlier_variance,
label_transfer_prior=label_transfer_prior,
batch_capacity=batch_capacity,
dissimilarity=dissimilarity,
)
        else:
            P, assignment_results = get_P_sparse(
                XnAHat=XAHat,
                XnB=coordsB,
                X_A=X_A,
                X_B=X_B,
labelA=labelA,
labelB=labelB,
sigma2=sigma2,
beta2=beta2,
alpha=alpha,
gamma=gamma,
Sigma=SigmaDiag,
outlier_variance=outlier_variance,
label_transfer_prior=label_transfer_prior,
batch_capacity=batch_capacity,
dissimilarity=dissimilarity,
)
# update temperature
if iter > 5:
beta2 = (
nx.maximum(beta2 * beta2_decrease, beta2_end)
if beta2_decrease < 1
else nx.minimum(beta2 * beta2_decrease, beta2_end)
)
outlier_variance = nx.minimum(outlier_variance * outlier_variance_decrease, max_outlier_variance)
K_NA, K_NB = assignment_results["K_NA"], assignment_results["K_NB"]
K_NA_spatial = assignment_results["K_NA_spatial"]
K_NA_sigma2 = assignment_results["K_NA_sigma2"]
# Update gamma
if SVI_mode:
Sp = step_size * assignment_results["Sp"] + (1 - step_size) * Sp
Sp_spatial = step_size * assignment_results["Sp_spatial"] + (1 - step_size) * Sp_spatial
Sp_sigma2 = step_size * assignment_results["Sp_sigma2"] + (1 - step_size) * Sp_sigma2
gamma = nx.exp(_psi(nx)(gamma_a + Sp_spatial) - _psi(nx)(gamma_a + gamma_b + batch_size))
else:
Sp = assignment_results["Sp"]
Sp_spatial = assignment_results["Sp_spatial"]
Sp_sigma2 = assignment_results["Sp_sigma2"]
gamma = nx.exp(_psi(nx)(gamma_a + Sp_spatial) - _psi(nx)(gamma_a + gamma_b + NB))
gamma = _data(nx, 0.99, type_as) if gamma > 0.99 else gamma
gamma = _data(nx, 0.01, type_as) if gamma < 0.01 else gamma
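        # Note (added): the gamma update is the mean-field update for a Beta prior on the
        # inlier fraction, gamma = exp(E_q[log gamma]) with E_q[log gamma] =
        # psi(gamma_a + Sp_spatial) - psi(gamma_a + gamma_b + N); the clamping to
        # [0.01, 0.99] keeps the mixture from collapsing onto one component.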
# Update alpha
if SVI_mode:
alpha = (
step_size * nx.exp(_psi(nx)(kappa + K_NA_spatial) - _psi(nx)(kappa * NA + Sp_spatial))
+ (1 - step_size) * alpha
)
else:
alpha = nx.exp(_psi(nx)(kappa + K_NA_spatial) - _psi(nx)(kappa * NA + Sp_spatial))
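        # Note (added): the alpha update is the analogous mean-field update for the
        # Dirichlet-style prior (concentration kappa) on the per-point assignment
        # weights, blended with the previous value by step_size in SVI mode.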
# Update VnA
        # update the non-rigid field on every iteration (an earlier version gated this
        # on, e.g., sigma2 < 0.015 or iter > 80)
        if True:
if SVI_mode:
if (guidance_effect == "nonrigid") or (guidance_effect == "both"):
SigmaInv = (
step_size
* (
sigma2 * lambdaVF * GammaSparse
+ _dot(nx)(U.T, nx.einsum("ij,i->ij", U, K_NA))
+ (sigma2 / guidance_epsilon) * _dot(nx)(inlier_U.T, inlier_U * inlier_P)
)
+ (1 - step_size) * SigmaInv
)
else:
SigmaInv = (
step_size * (sigma2 * lambdaVF * GammaSparse + _dot(nx)(U.T, nx.einsum("ij,i->ij", U, K_NA)))
+ (1 - step_size) * SigmaInv
)
Sigma = _pinv(nx)(SigmaInv)
term1 = _dot(nx)(Sigma, U.T)
PXB_term = (
step_size * (_dot(nx)(P, randcoordsB) - nx.einsum("ij,i->ij", RnA, K_NA))
+ (1 - step_size) * PXB_term
)
if (guidance_effect == "nonrigid") or (guidance_effect == "both"):
term1_guide = _dot(nx)(Sigma, inlier_U.T)
XBRA_guide_term = (inlier_B - inlier_R) * inlier_P
Coff = _dot(nx)(term1, PXB_term) + (sigma2 / guidance_epsilon) * _dot(nx)(
term1_guide, XBRA_guide_term
)
inlier_V = _dot(nx)(inlier_U, Coff)
else:
Coff = _dot(nx)(term1, PXB_term)
VnA = _dot(nx)(
U,
Coff,
)
SigmaDiag = sigma2 * nx.einsum("ij->i", nx.einsum("ij,ji->ij", U, term1))
else:
if (guidance_effect == "nonrigid") or (guidance_effect == "both"):
SigmaInv = (
sigma2 * lambdaVF * GammaSparse
+ _dot(nx)(U.T, nx.einsum("ij,i->ij", U, K_NA))
+ (sigma2 / guidance_epsilon) * _dot(nx)(inlier_U.T, inlier_U * inlier_P)
)
else:
SigmaInv = sigma2 * lambdaVF * GammaSparse + _dot(nx)(U.T, nx.einsum("ij,i->ij", U, K_NA))
Sigma = _pinv(nx)(SigmaInv)
term1 = _dot(nx)(Sigma, U.T)
PXB_term = _dot(nx)(P, coordsB) - nx.einsum("ij,i->ij", RnA, K_NA)
if (guidance_effect == "nonrigid") or (guidance_effect == "both"):
term1_guide = _dot(nx)(Sigma, inlier_U.T)
XBRA_guide_term = (inlier_B - inlier_R) * inlier_P
Coff = _dot(nx)(term1, PXB_term) + (sigma2 / guidance_epsilon) * _dot(nx)(
term1_guide, XBRA_guide_term
)
inlier_V = _dot(nx)(inlier_U, Coff)
else:
Coff = _dot(nx)(term1, PXB_term)
VnA = _dot(nx)(U, Coff)
SigmaDiag = sigma2 * nx.einsum("ij->i", nx.einsum("ij,ji->ij", U, term1))
# Update rigid transformation R()
# Solve for the translation t
if SVI_mode:
PXA, PVA, PXB = (
_dot(nx)(K_NA, coordsA)[None, :],
_dot(nx)(K_NA, VnA)[None, :],
_dot(nx)(K_NB, randcoordsB)[None, :],
)
else:
PXA, PVA, PXB = (
_dot(nx)(K_NA, coordsA)[None, :],
_dot(nx)(K_NA, VnA)[None, :],
_dot(nx)(K_NB, coordsB)[None, :],
)
if SVI_mode and iter > 1:
if (guidance_effect == "rigid") or (guidance_effect == "both") or (nn_init == True):
t = (
step_size
* (
(
(PXB - PVA - _dot(nx)(PXA, R.T))
+ (sigma2 / guidance_epsilon)
* _dot(nx)(inlier_P.T, inlier_B - inlier_V - _dot(nx)(inlier_A, R.T))
)
/ (Sp + (sigma2 / guidance_epsilon) * nx.sum(inlier_P))
)
+ (1 - step_size) * t
)
else:
t = step_size * ((PXB - PVA - _dot(nx)(PXA, R.T)) / Sp) + (1 - step_size) * t
else:
if (guidance_effect == "rigid") or (guidance_effect == "both") or (nn_init == True):
t = (
(PXB - PVA - _dot(nx)(PXA, R.T))
+ (sigma2 / guidance_epsilon) * _dot(nx)(inlier_P.T, inlier_B - inlier_V - _dot(nx)(inlier_A, R.T))
) / (Sp + (sigma2 / guidance_epsilon) * nx.sum(inlier_P))
else:
t = (PXB - PVA - _dot(nx)(PXA, R.T)) / Sp
# Solve for the rotation
if (guidance_effect == "rigid") or (guidance_effect == "both") or (nn_init == True):
mu_XB = (PXB + (sigma2 / guidance_epsilon) * _dot(nx)(inlier_P.T, inlier_B)) / (
Sp + (sigma2 / guidance_epsilon) * nx.sum(inlier_P)
)
mu_XA = (PXA + (sigma2 / guidance_epsilon) * _dot(nx)(inlier_P.T, inlier_A)) / (
Sp + (sigma2 / guidance_epsilon) * nx.sum(inlier_P)
)
mu_Vn = (PVA + (sigma2 / guidance_epsilon) * _dot(nx)(inlier_P.T, inlier_V)) / (
Sp + (sigma2 / guidance_epsilon) * nx.sum(inlier_P)
)
XAI_hat = inlier_A - mu_XA
XBI_hat = inlier_B - mu_XB
fI_hat = inlier_V - mu_Vn
else:
mu_XB = PXB / Sp
mu_XA = PXA / Sp
mu_Vn = PVA / Sp
XA_hat = coordsA - mu_XA
f_hat = VnA - mu_Vn
if SVI_mode:
XB_hat = randcoordsB - mu_XB
else:
XB_hat = coordsB - mu_XB
if (guidance_effect == "rigid") or (guidance_effect == "both") or (nn_init == True):
A_guide = _dot(nx)((XAI_hat * inlier_P).T, (fI_hat - XBI_hat))
A = -(
_dot(nx)(XA_hat.T, nx.einsum("ij,i->ij", f_hat, K_NA))
- _dot(nx)(_dot(nx)(XA_hat.T, P), XB_hat)
+ (sigma2 / guidance_epsilon) * A_guide
).T
else:
A = -(_dot(nx)(XA_hat.T, nx.einsum("ij,i->ij", f_hat, K_NA)) - _dot(nx)(_dot(nx)(XA_hat.T, P), XB_hat)).T
svdU, svdS, svdV = _linalg(nx).svd(A)
C = _identity(nx, D, type_as)
C[-1, -1] = _linalg(nx).det(_dot(nx)(svdU, svdV))
if SVI_mode and iter > 1:
R = step_size * (_dot(nx)(_dot(nx)(svdU, C), svdV)) + (1 - step_size) * R
else:
R = _dot(nx)(_dot(nx)(svdU, C), svdV)
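        # Note (added): this is the weighted Procrustes/Kabsch solution: the rotation is
        # recovered from the SVD of the weighted cross-covariance A, and the correction
        # matrix C flips the last singular direction when needed so that det(R) = +1
        # (a proper rotation rather than a reflection).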
RnA = _dot(nx)(coordsA, R.T) + t
XAHat = RnA + VnA
        if guidance_effect is not False:
inlier_R = _dot(nx)(inlier_A, R.T) + t
inlier_AHat = inlier_R + inlier_V
# Update sigma2 and beta2 (optional)
sigma2_old = sigma2
sigma2 = nx.maximum(
(assignment_results["sigma2_temp"] + nx.einsum("i,i", K_NA_sigma2, SigmaDiag) / Sp_sigma2),
_data(nx, 1e-3, type_as),
)
if iter < 100:
sigma2 = nx.maximum(sigma2, _data(nx, 1e-2, type_as))
sigma2_terc = nx.abs((sigma2 - sigma2_old) / sigma2)
# SVI next batch
spatial_threshold = 6 * sigma2
if SVI_mode:
randIdx = randomidx[:batch_size]
randomidx = _roll(nx)(randomidx, batch_size)
randcoordsB = coordsB[randIdx, :]
randX_B = X_B[randIdx, :] # batch_size x G
randlabelB = labelB.iloc[np.array(randIdx)] if labelB is not None else None
empty_cache(device=device)
# end of the iteration
# get the full data assignment
if SVI_mode:
if not pre_compute_dist:
P, assignment_results = get_P_sparse(
XnAHat=XAHat,
XnB=coordsB,
X_A=X_A,
X_B=X_B,
labelA=labelA,
labelB=labelB,
sigma2=sigma2,
beta2=beta2,
alpha=alpha,
gamma=gamma,
Sigma=SigmaDiag,
outlier_variance=outlier_variance,
label_transfer_prior=label_transfer_prior,
top_k=32,
dissimilarity=dissimilarity,
batch_capacity=batch_capacity,
)
# Get optimal Rigid transformation
optimal_RnA, optimal_R, optimal_t = get_optimal_R_sparse(
coordsA=coordsA,
coordsB=coordsB,
P=P,
R_init=R,
)
# combine the initial rigid transformation and final rigid transformation
t = _dot(nx)(init_t, R.T) + t
R = _dot(nx)(R, init_R)
optimal_t = _dot(nx)(init_t, optimal_R.T) + optimal_t
optimal_R = _dot(nx)(optimal_R, init_R)
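    # Note (added): since coordsA was pre-transformed by (init_R, init_t), the two rigid
    # transforms compose as x @ (R @ init_R).T + (init_t @ R.T + t), which is exactly
    # what is stored in R/t and optimal_R/optimal_t above.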
# output optimization parameters
if verbose:
lm.main_info(f"Key Parameters: gamma: {gamma}; beta2: {beta2}; sigma2: {sigma2}")
# de-normalize
if normalize_c:
XAHat = XAHat * normalize_scale_list[0] + normalize_mean_list[0]
RnA = RnA * normalize_scale_list[0] + normalize_mean_list[0]
optimal_RnA = optimal_RnA * normalize_scale_list[0] + normalize_mean_list[0]
coarse_alignment = coarse_alignment * normalize_scale_list[0] + normalize_mean_list[0]
output_R = optimal_R
output_t = (
optimal_t * normalize_scale_list[0] + normalize_mean_list[0] - _dot(nx)(normalize_mean_list[1], optimal_R.T)
)
# Save aligned coordinates to adata
sampleB.obsm[key_added + "_nonrigid"] = nx.to_numpy(XAHat).copy()
sampleB.obsm[key_added + "_rigid"] = nx.to_numpy(optimal_RnA).copy()
# save vector field and other parameters
    if vecfld_key_added is not None:
norm_dict = {
"mean_transformed": nx.to_numpy(normalize_mean_list[1]),
"mean_fixed": nx.to_numpy(normalize_mean_list[0]),
"scale": nx.to_numpy(normalize_scale_list[1]),
"scale_transformed": nx.to_numpy(normalize_scale_list[1]),
"scale_fixed": nx.to_numpy(normalize_scale_list[0]),
}
sampleB.uns[vecfld_key_added] = {
"R": nx.to_numpy(R),
"t": nx.to_numpy(t),
"optimal_R": nx.to_numpy(optimal_R),
"optimal_t": nx.to_numpy(optimal_t),
"output_R": nx.to_numpy(output_R),
"output_t": nx.to_numpy(output_t),
"beta": beta,
"C": nx.to_numpy(Coff),
"X_ctrl": nx.to_numpy(ctrl_pts),
"norm_dict": norm_dict,
"kernel_dict": kernel_dict,
"dissimilarity": dissimilarity,
"beta2": nx.to_numpy(sigma2),
"sigma2": nx.to_numpy(sigma2),
"gamma": nx.to_numpy(gamma),
"NA": NA,
"outlier_variance": nx.to_numpy(outlier_variance),
"method": "morpho",
"pre_norm_scale": 1,
}
empty_cache(device=device)
return (
None if inplace else (sampleA, sampleB),
P.to("cpu").to_dense().numpy(),
nx.to_numpy(sigma2),
)
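# Minimal usage sketch (added for exposition; defined but never called, so nothing runs on
# import). It assumes two AnnData objects with spatial coordinates in .obsm["spatial"];
# the synthetic data and parameter choices below are illustrative only.
def _example_BA_align_sparse():
    rng = np.random.default_rng(0)
    adataA = AnnData(X=rng.random((200, 50)).astype(np.float32))
    adataA.obsm["spatial"] = rng.random((200, 2)).astype(np.float32)
    adataB = AnnData(X=rng.random((250, 50)).astype(np.float32))
    adataB.obsm["spatial"] = rng.random((250, 2)).astype(np.float32)
    # align in place; the non-rigid and rigid coordinates are written to
    # adataB.obsm["align_spatial_nonrigid"] and adataB.obsm["align_spatial_rigid"]
    _, P, sigma2 = BA_align_sparse(adataA, adataB, max_iter=50, device="cpu")
    return P, sigma2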