from typing import List, Optional, Tuple, Union
import numpy as np
import ot
import torch
from anndata import AnnData
from sklearn.decomposition import NMF
from spateo.logging import logger_manager as lm
from .deprecated_utils import (
align_preprocess,
calc_exp_dissimilarity,
check_exp,
check_spatial_coords,
filter_common_genes,
intersect_lsts,
to_dense_matrix,
)
######################################
# Align spots across pairwise slices #
######################################
def paste_pairwise_align(
    sampleA: AnnData,
    sampleB: AnnData,
    layer: str = "X",
    genes: Optional[Union[list, np.ndarray]] = None,
    spatial_key: str = "spatial",
    alpha: float = 0.1,
    dissimilarity: str = "kl",
    G_init=None,
    a_distribution=None,
    b_distribution=None,
    norm: bool = False,
    numItermax: int = 200,
    numItermaxEmd: int = 100000,
    dtype: str = "float32",
    device: str = "cpu",
    verbose: bool = True,
) -> Tuple[np.ndarray, Optional[int]]:
    """
    Calculates and returns optimal alignment of two slices.

    Args:
        sampleA: Sample A to align.
        sampleB: Sample B to align.
        layer: If `'X'`, uses ``sample.X`` to calculate dissimilarity between spots,
            otherwise uses the representation given by ``sample.layers[layer]``.
        genes: Genes used for calculation. If None, use all common genes for calculation.
        spatial_key: The key in `.obsm` that corresponds to the raw spatial coordinates.
        alpha: Alignment tuning parameter. Note: 0 <= alpha <= 1.
            When alpha = 0 only the gene expression data is taken into account,
            while when alpha = 1 only the spatial coordinates are taken into account.
        dissimilarity: Expression dissimilarity measure: ``'kl'`` or ``'euclidean'``.
        G_init: Initial mapping to be used in FGW-OT, otherwise default is uniform mapping.
        a_distribution: Distribution of sampleA spots, otherwise default is uniform.
        b_distribution: Distribution of sampleB spots, otherwise default is uniform.
        norm: If ``True``, scales spatial distances such that neighboring spots are at
            distance 1. Otherwise, spatial distances remain unchanged.
        numItermax: Max number of iterations for cg during FGW-OT.
        numItermaxEmd: Max number of iterations for emd during FGW-OT.
        dtype: The floating-point number type. Only float32 and float64.
        device: Equipment used to run the program. You can also set the specified GPU
            for running. E.g.: '0'.
        verbose: If ``True``, print progress updates.

    Returns:
        pi: Alignment of spots.
        obj: Objective function output of FGW-OT.
    """
    # Preprocessing: move both samples onto the requested backend/device and extract
    # expression matrices plus spatial coordinates restricted to the selected genes.
    (
        nx,
        type_as,
        new_samples,
        exp_matrices,
        spatial_coords,
        normalize_scale,
        normalize_mean_list,
    ) = align_preprocess(
        samples=[sampleA, sampleB],
        genes=genes,
        spatial_key=spatial_key,
        layer=layer,
        normalize_c=False,
        normalize_g=False,
        select_high_exp_genes=False,
        dtype=dtype,
        device=device,
        verbose=verbose,
    )

    # Intra-slice spatial distance matrices (structure terms of the GW objective).
    coordsA, coordsB = spatial_coords[0], spatial_coords[1]
    D_A = ot.dist(coordsA, coordsA, metric="euclidean")
    D_B = ot.dist(coordsB, coordsB, metric="euclidean")

    # Inter-slice expression dissimilarity (feature term of the FGW objective).
    X_A, X_B = exp_matrices[0], exp_matrices[1]
    M = calc_exp_dissimilarity(X_A=X_A, X_B=X_B, dissimilarity=dissimilarity)

    # Marginal spot distributions: uniform unless explicitly provided.
    a = np.ones((sampleA.shape[0],)) / sampleA.shape[0] if a_distribution is None else np.asarray(a_distribution)
    b = np.ones((sampleB.shape[0],)) / sampleB.shape[0] if b_distribution is None else np.asarray(b_distribution)
    a = nx.from_numpy(a, type_as=type_as)
    b = nx.from_numpy(b, type_as=type_as)

    if norm:
        # Rescale so the smallest positive spot-to-spot distance equals 1.
        D_A /= nx.min(D_A[D_A > 0])
        D_B /= nx.min(D_B[D_B > 0])

    # Run FGW-OT via conditional gradient.
    constC, hC1, hC2 = ot.gromov.init_matrix(D_A, D_B, a, b, "square_loss")
    if G_init is None:
        # Product of marginals (uniform coupling).
        G0 = a[:, None] * b[None, :]
    else:
        # Normalize the user-supplied coupling so it sums to 1.
        G_init = nx.from_numpy(G_init, type_as=type_as)
        G0 = (1 / nx.sum(G_init)) * G_init

    # ``cg`` has moved between POT modules across releases.
    try:
        from ot.optim import cg
    except ImportError:
        from ot.gromov import cg

    # BUGFIX: call the version-robust ``cg`` imported above. The original called
    # ``ot.gromov.cg(...)``, which defeats the try/except fallback and fails on
    # POT releases where ``cg`` is not re-exported from ``ot.gromov``.
    pi, log = cg(
        a,
        b,
        (1 - alpha) * M,
        alpha,
        lambda G: ot.gromov.gwloss(constC, hC1, hC2, G),
        lambda G: ot.gromov.gwggrad(constC, hC1, hC2, G),
        G0,
        armijo=False,
        C1=D_A,
        C2=D_B,
        constC=constC,
        numItermax=numItermax,
        numItermaxEmd=numItermaxEmd,
        log=True,
    )

    pi = nx.to_numpy(pi)
    obj = nx.to_numpy(log["loss"][-1])
    # Free cached GPU memory when running on a CUDA device.
    if device != "cpu":
        torch.cuda.empty_cache()
    return pi, obj
###################################################
# Integrate multiple slices into one center slice #
###################################################
def center_NMF(n_components, random_seed, dissimilarity="kl"):
    """Build the NMF model used to factor the center slice's expression matrix.

    Args:
        n_components: Number of components in the NMF decomposition.
        random_seed: Random seed passed to sklearn's NMF for reproducibility.
        dissimilarity: Expression dissimilarity measure; ``'euclidean'``/``'euc'``
            selects the default (Frobenius) solver, anything else selects
            multiplicative updates with KL-divergence loss.

    Returns:
        An unfitted ``sklearn.decomposition.NMF`` instance.
    """
    if dissimilarity.lower() in ("euclidean", "euc"):
        return NMF(n_components=n_components, init="random", random_state=random_seed)
    # KL-divergence loss requires the multiplicative-update solver.
    return NMF(
        n_components=n_components,
        solver="mu",
        beta_loss="kullback-leibler",
        init="random",
        random_state=random_seed,
    )
def paste_center_align(
    init_center_sample: AnnData,
    samples: List[AnnData],
    layer: str = "X",
    genes: Optional[Union[list, np.ndarray]] = None,
    spatial_key: str = "spatial",
    lmbda: Optional[np.ndarray] = None,
    alpha: float = 0.1,
    n_components: int = 15,
    threshold: float = 0.001,
    max_iter: int = 10,
    numItermax: int = 200,
    numItermaxEmd: int = 100000,
    dissimilarity: str = "kl",
    norm: bool = False,
    random_seed: Optional[int] = None,
    pis_init: Optional[List[np.ndarray]] = None,
    distributions: Optional[List[np.ndarray]] = None,
    dtype: str = "float32",
    device: str = "cpu",
    verbose: bool = True,
) -> Tuple[AnnData, List[np.ndarray]]:
    """
    Computes center alignment of slices.

    Args:
        init_center_sample: Sample to use as the initialization for center alignment;
            Make sure to include gene expression and spatial information.
        samples: List of samples to use in the center alignment.
        layer: If `'X'`, uses ``sample.X`` to calculate dissimilarity between spots,
            otherwise uses the representation given by ``sample.layers[layer]``.
        genes: Genes used for calculation. If None, use all common genes for calculation.
        spatial_key: The key in `.obsm` that corresponds to the raw spatial coordinates.
        lmbda: List of probability weights assigned to each slice; If ``None``, use
            uniform weights.
        alpha: Alignment tuning parameter. Note: 0 <= alpha <= 1.
            When alpha = 0 only the gene expression data is taken into account,
            while when alpha = 1 only the spatial coordinates are taken into account.
        n_components: Number of components in NMF decomposition.
        threshold: Threshold for convergence of W and H during NMF decomposition.
        max_iter: Maximum number of iterations for our center alignment algorithm.
        numItermax: Max number of iterations for cg during FGW-OT.
        numItermaxEmd: Max number of iterations for emd during FGW-OT.
        dissimilarity: Expression dissimilarity measure: ``'kl'`` or ``'euclidean'``.
        norm: If ``True``, scales spatial distances such that neighboring spots are at
            distance 1. Otherwise, spatial distances remain unchanged.
        random_seed: Set random seed for reproducibility.
        pis_init: Initial list of mappings between 'A' and 'slices' to solver.
            Otherwise, default will automatically calculate mappings.
        distributions: Distributions of spots for each slice. Otherwise, default is uniform.
        dtype: The floating-point number type. Only float32 and float64.
        device: Equipment used to run the program. You can also set the specified GPU
            for running. E.g.: '0'.
        verbose: If ``True``, print progress updates.

    Returns:
        - Inferred center sample with full and low dimensional representations (W, H)
          of the gene expression matrix.
        - List of pairwise alignment mappings of the center sample (rows) to each
          input sample (columns).
    """

    def _generate_center_sample(W, H, genes, coords, layer):
        # Wrap the current center estimate (X = W @ H) in an AnnData so it can be
        # fed back into ``paste_pairwise_align``.
        center_sample = AnnData(np.dot(W, H))
        center_sample.var.index = genes
        center_sample.obsm[spatial_key] = coords
        if layer != "X":
            center_sample.layers[layer] = center_sample.X
        return center_sample

    if lmbda is None:
        # Uniform slice weights.
        lmbda = len(samples) * [1 / len(samples)]
    if distributions is None:
        distributions = len(samples) * [None]

    # Restrict all slices (and the initial center) to their common genes.
    # NOTE(review): ``s[0]`` slices a single observation only to reach
    # ``.var.index``; the gene index is the same for the whole slice.
    all_samples_genes = [s[0].var.index for s in samples]
    all_samples_genes.append(init_center_sample.var.index)
    common_genes = filter_common_genes(*all_samples_genes)
    common_genes = common_genes if genes is None else intersect_lsts(common_genes, genes)
    init_center_sample = init_center_sample[:, common_genes]
    samples = [sample[:, common_genes] for sample in samples]

    # Run initial NMF on the center expression matrix (or on the weighted
    # projection of the slices if initial mappings were supplied).
    if pis_init is None:
        pis = [None] * len(samples)
        B = check_exp(sample=init_center_sample, layer=layer)
    else:
        pis = pis_init
        B = init_center_sample.shape[0] * sum(
            lmbda[i] * np.dot(pis[i], to_dense_matrix(check_exp(samples[i], layer=layer)))
            for i in range(len(samples))
        )
    init_NMF_model = center_NMF(n_components=n_components, random_seed=random_seed, dissimilarity=dissimilarity)
    W = init_NMF_model.fit_transform(B)
    H = init_NMF_model.components_
    center_coords = check_spatial_coords(sample=init_center_sample, spatial_key=spatial_key)

    # Block coordinate descent: alternate pairwise FGW-OT alignments and NMF
    # updates until the objective R stabilizes or ``max_iter`` is reached.
    iteration_count = 0
    R = 0
    R_diff = 100
    while R_diff > threshold and iteration_count < max_iter:
        lm.main_info(message=f"{iteration_count} iteration of center alignment.", indent_level=1)

        new_pis = []
        r = []
        for i in range(len(samples)):
            p, r_q = paste_pairwise_align(
                sampleA=_generate_center_sample(W=W, H=H, genes=common_genes, coords=center_coords, layer=layer),
                sampleB=samples[i],
                layer=layer,
                spatial_key=spatial_key,
                alpha=alpha,
                dissimilarity=dissimilarity,
                norm=norm,
                G_init=pis[i],
                b_distribution=distributions[i],
                numItermax=numItermax,
                numItermaxEmd=numItermaxEmd,
                dtype=dtype,
                device=device,
                verbose=verbose,
            )
            new_pis.append(p)
            r.append(r_q)

        pis = new_pis.copy()
        # Re-factor the weighted projection of all slices onto the center.
        NMF_model = center_NMF(n_components=n_components, random_seed=random_seed, dissimilarity=dissimilarity)
        B = W.shape[0] * sum(
            lmbda[i] * np.dot(pis[i], to_dense_matrix(check_exp(samples[i], layer=layer)))
            for i in range(len(samples))
        )
        W = NMF_model.fit_transform(B)
        H = NMF_model.components_

        # Objective is the lmbda-weighted sum of the pairwise FGW objectives.
        R_new = np.dot(r, lmbda)
        iteration_count += 1
        R_diff = abs(R - R_new)
        R = R_new
        lm.main_info(message=f"Objective: {R_new}", indent_level=2)
        lm.main_info(message=f"Difference: {R_diff}", indent_level=2)

    # Assemble the final center sample with both low-rank (W, H) and full-rank
    # representations plus the final objective value.
    center_sample = init_center_sample.copy()
    center_sample.X = np.dot(W, H)
    center_sample.uns["paste_W"] = W
    center_sample.uns["paste_H"] = H
    center_sample.uns["full_rank"] = center_sample.shape[0] * sum(
        lmbda[i] * np.dot(pis[i], to_dense_matrix(samples[i].X)) for i in range(len(samples))
    )
    center_sample.uns["obj"] = R
    return center_sample, pis
########################################
# Generate aligned spatial coordinates #
########################################
def generalized_procrustes_analysis(X, Y, pi):
    """
    Finds and applies optimal rotation between spatial coordinates of two layers
    (may also do a reflection).

    Args:
        X: np array of spatial coordinates.
        Y: np array of spatial coordinates.
        pi: mapping between the two layers output by PASTE.

    Returns:
        Aligned spatial coordinates of X, Y and the mapping relations
        (``{"tX": ..., "tY": ..., "R": ...}``).
    """
    # Weighted centroids: the row/column marginals of pi weight each spot.
    tX = pi.sum(axis=1).dot(X)
    tY = pi.sum(axis=0).dot(Y)
    # Center both coordinate sets on their weighted centroids.
    X = X - tX
    Y = Y - tY
    # Cross-covariance weighted by the coupling, then the orthogonal Procrustes
    # solution via SVD (Kabsch-style; may include a reflection).
    H = Y.T.dot(pi.T.dot(X))
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T.dot(U.T)
    # Rotate Y into X's frame.
    Y = R.dot(Y.T).T
    mapping_dict = {"tX": tX, "tY": tY, "R": R}
    return X, Y, mapping_dict