from typing import List, Optional, Tuple, Union
import numpy as np
import pandas as pd
from anndata import AnnData
from scipy.spatial import cKDTree
from spateo.logging import logger_manager as lm
####################
# Before Alignment #
####################
def _iteration(n: int, progress_name: str, verbose: bool = True, indent_level=1):
    """Build the iterable used to drive a loop of ``n`` steps.

    Args:
        n: Number of steps to iterate over.
        progress_name: Name forwarded to the progress logger.
        verbose: When True, wrap the range in ``lm.progress_logger``; when False, return the bare range.
        indent_level: Indentation level forwarded to the progress logger.

    Returns:
        ``range(n)`` if ``verbose`` is False, otherwise the object returned by ``lm.progress_logger``
        for that range.
    """
    steps = range(n)
    if not verbose:
        return steps
    return lm.progress_logger(steps, progress_name=progress_name, indent_level=indent_level)
def downsampling(
    models: Union[List[AnnData], AnnData],
    n_sampling: Optional[int] = 2000,
    sampling_method: str = "trn",
    spatial_key: str = "spatial",
) -> Union[List[AnnData], AnnData]:
    """Downsample every model to at most ``n_sampling`` observations.

    Args:
        models: A single model or a list of models.
        n_sampling: Requested number of observations per model. If a model has fewer observations, all of its
            observations are requested for that model only.
            NOTE(review): ``None`` is allowed by the annotation but is not handled; it raises ``TypeError``.
        sampling_method: Value passed as ``method`` to ``dynamo.tools.sampling.sample``.
        spatial_key: Key in ``.obsm`` whose array is passed as ``X`` to the sampler.

    Returns:
        A list with one downsampled model per input model. A list is returned even when a single model is given.
    """
    from dynamo.tools.sampling import sample

    model_list = models if isinstance(models, list) else [models]
    sampled_models = []
    for model in model_list:
        model_copy = model.copy()
        # Clamp per model. Overwriting ``n_sampling`` itself would let one small model
        # silently shrink the sample size of every model processed after it.
        n_obs = min(n_sampling, model_copy.shape[0])
        selected = sample(
            arr=np.asarray(model_copy.obs_names),
            n=n_obs,
            method=sampling_method,
            X=model_copy.obsm[spatial_key],
        )
        sampled_models.append(model_copy[selected, :])
    return sampled_models
## generate label transfer prior
def generate_label_transfer_prior(cat1, cat2, positive_pairs=None, negative_pairs=None):
    """Build a normalized label-transfer prior between two sets of categories.

    Every ``(c2, c1)`` combination starts with weight 1. Each pair entry then overwrites the weight of
    ``prior[right][left]`` with its ``"value"``; negative pairs are applied after positive pairs, so they win
    on conflicts. Finally the weights of each ``c2`` are divided by their sum over ``cat1``.

    Args:
        cat1: Iterable of categories used as the inner keys of the result.
        cat2: Iterable of categories used as the outer keys of the result.
        positive_pairs: Optional list of dicts with keys ``"left"`` (labels from ``cat1``), ``"right"``
            (labels from ``cat2``) and ``"value"`` (weight to assign).
        negative_pairs: Optional list of dicts with the same layout as ``positive_pairs``.
            If both lists are empty or None, every category present in both ``cat1`` and ``cat2`` is given
            weight 10 against itself.

    Returns:
        A nested dict ``{c2: {c1: weight}}`` where the weights of each ``c2`` sum to 1 over ``cat1``.

    Raises:
        KeyError: If a ``"right"`` label of a pair is not in ``cat2``.
    """
    # Work on copies so that inserting the default pairs never mutates the caller's lists.
    positive_pairs = [] if positive_pairs is None else list(positive_pairs)
    negative_pairs = [] if negative_pairs is None else list(negative_pairs)

    # With no user-supplied pairs, favour identical annotations.
    if not positive_pairs and not negative_pairs:
        positive_pairs = [{"left": [c], "right": [c], "value": 10} for c in cat1 if c in cat2]

    label_transfer_prior = {c2: {c1: 1 for c1 in cat1} for c2 in cat2}

    # Positive pairs first, negative pairs second: later assignments override earlier ones.
    for pair in positive_pairs + negative_pairs:
        for left in pair["left"]:
            for right in pair["right"]:
                label_transfer_prior[right][left] = pair["value"]

    norm_label_transfer_prior = {}
    for c2 in cat2:
        row = label_transfer_prior[c2]
        norm_c = np.array([row[c1] for c1 in cat1]).sum()
        norm_label_transfer_prior[c2] = {c1: row[c1] / norm_c for c1 in cat1}
    return norm_label_transfer_prior
###################
# After Alignment #
###################
def _resolve_tied_pairs(pairs, key_col, key_coords, other_coords):
    """Keep one ``(x, y)`` index pair per key index, breaking ties by nearest spatial partner."""
    other_col = 1 - key_col
    keys, counts = np.unique(pairs[:, key_col], return_counts=True)
    # Pairs whose key index occurs once are kept as they are (original order preserved).
    resolved = [pairs[np.isin(pairs[:, key_col], keys[counts == 1])]]
    for key in keys[counts != 1]:
        tied = pairs[pairs[:, key_col] == key]
        # Among the tied partners, keep the one spatially closest to the key point.
        _, nearest = cKDTree(other_coords[tied[:, other_col]]).query(key_coords[key], k=1)
        resolved.append(tied[nearest].reshape(1, 2))
    # Concatenate once instead of growing the array inside the loop.
    return np.concatenate(resolved, axis=0)


def get_optimal_mapping_relationship(
    X: np.ndarray,
    Y: np.ndarray,
    pi: np.ndarray,
    keep_all: bool = False,
):
    """Find, for every row and every column of ``pi``, the position(s) of its maximum.

    Rows of ``pi`` are used as indices into ``X`` and columns as indices into ``Y``.

    Args:
        X: Coordinates indexed by the rows of ``pi``.
        Y: Coordinates indexed by the columns of ``pi``.
        pi: 2D mapping matrix between ``X`` (rows) and ``Y`` (columns).
        keep_all: If True, keep every position that attains a row/column maximum. If False, when a row
            (or column) has several maxima, keep only the one whose partner point is spatially nearest.

    Returns:
        A tuple ``(X_max_index, X_pi_value, Y_max_index, Y_pi_value)``:
        ``X_max_index`` holds ``(row, col)`` pairs of row-wise maxima and ``Y_max_index`` holds ``(row, col)``
        pairs of column-wise maxima; the ``*_pi_value`` arrays are column vectors with the matching ``pi``
        entries. When ``keep_all`` is False, untied pairs come first, followed by the resolved ties.
    """
    X_max_index = np.argwhere(pi == pi.max(axis=1, keepdims=True))
    Y_max_index = np.argwhere(pi == pi.max(axis=0))

    if not keep_all:
        X_max_index = _resolve_tied_pairs(X_max_index, key_col=0, key_coords=X, other_coords=Y)
        Y_max_index = _resolve_tied_pairs(Y_max_index, key_col=1, key_coords=Y, other_coords=X)

    X_pi_value = pi[X_max_index[:, 0], X_max_index[:, 1]].reshape(-1, 1)
    Y_pi_value = pi[Y_max_index[:, 0], Y_max_index[:, 1]].reshape(-1, 1)
    return X_max_index, X_pi_value, Y_max_index, Y_pi_value
def mapping_aligned_coords(
    X: np.ndarray,
    Y: np.ndarray,
    pi: np.ndarray,
    keep_all: bool = False,
) -> Tuple[dict, dict]:
    """
    Optimal mapping coordinates between X and Y.

    Args:
        X: Aligned spatial coordinates.
        Y: Aligned spatial coordinates.
        pi: Mapping matrix between X and Y.
        keep_all: Passed through to ``get_optimal_mapping_relationship``. If True, all optimal relationships
            derived from the pi matrix alone are considered; if False, ties are resolved there using the
            nearest coordinates.

    Returns:
        Two dicts, the first de-duplicated per X index and the second per Y index (keeping the pair with the
        highest pi value). Each dict has the keys:

        * ``mapping_X``: rows of X selected by the kept pairs.
        * ``mapping_Y``: rows of Y selected by the kept pairs.
        * ``pi_index``: the kept ``(index_x, index_y)`` pairs.
        * ``pi_value``: the pi values of the kept pairs.
    """
    X, Y, pi = X.copy(), Y.copy(), pi.copy()

    # Obtain the optimal mapping between points.
    X_max_index, X_pi_value, Y_max_index, Y_pi_value = get_optimal_mapping_relationship(
        X=X, Y=Y, pi=pi, keep_all=keep_all
    )

    column_types = {"index_x": np.int32, "index_y": np.int32, "pi_value": np.float64}

    def _build_mapping(pair_index, pair_value, dedup_col):
        # One row per (index_x, index_y, pi_value); keep the best-scoring pair for each ``dedup_col`` index.
        table = pd.DataFrame(
            np.concatenate([pair_index, pair_value], axis=1),
            columns=["index_x", "index_y", "pi_value"],
        ).astype(dtype=column_types)
        table = table.sort_values(by=[dedup_col, "pi_value"], ascending=[True, False])
        table = table.drop_duplicates(subset=[dedup_col], keep="first")
        return {
            "mapping_X": X[table["index_x"].values],
            "mapping_Y": Y[table["index_y"].values],
            "pi_index": table[["index_x", "index_y"]].values,
            "pi_value": table["pi_value"].values,
        }

    mapping_for_x = _build_mapping(X_max_index, X_pi_value, "index_x")
    mapping_for_y = _build_mapping(Y_max_index, Y_pi_value, "index_y")
    return mapping_for_x, mapping_for_y
def mapping_center_coords(modelA: AnnData, modelB: AnnData, center_key: str) -> dict:
    """
    Optimal mapping coordinates between X and Y based on intermediate coordinates.

    The alignment info of both models is joined on the first column of their ``pi_index`` (the index into
    the shared center model), so only center points mapped by both models are kept.

    Args:
        modelA: modelA aligned with center model.
        modelB: modelB aligned with center model.
        center_key: The key in ``.uns`` that corresponds to the alignment info between modelA/modelB and center
            model. The stored dict must provide ``raw_Y``, ``mapping_Y``, ``pi_index`` and ``pi_value``.

    Returns:
        A dict of raw_X, raw_Y, mapping_X, mapping_Y, pi_value.
            raw_X / mapping_X come from modelA, raw_Y / mapping_Y come from modelB.
            pi_value is the product of the two models' pi values for each joined pair.
    """
    info_a = modelA.uns[center_key].copy()
    info_b = modelB.uns[center_key].copy()

    def _column_names(prefix, width):
        return [f"{prefix}_{i}" for i in range(width)]

    mapping_X_cols = _column_names("mapping_X", info_a["mapping_Y"].shape[1])
    raw_X_cols = _column_names("raw_X", info_a["raw_Y"].shape[1])
    mapping_Y_cols = _column_names("mapping_Y", info_b["mapping_Y"].shape[1])
    raw_Y_cols = _column_names("raw_Y", info_b["raw_Y"].shape[1])

    def _alignment_frame(info, mapping_cols, raw_cols, value_col):
        # NOTE(review): the labels are ordered mapping_* then raw_*, but the data is concatenated as
        # raw_Y then mapping_Y, so "mapping_*" columns carry ``raw_Y`` values and "raw_*" columns carry
        # ``mapping_Y`` values. Kept as-is to preserve behaviour; confirm whether this is intended.
        stacked = np.concatenate(
            [info["raw_Y"], info["mapping_Y"], info["pi_index"][:, [0]]],
            axis=1,
        )
        frame = pd.DataFrame(stacked, columns=mapping_cols + raw_cols + ["mid"])
        frame[value_col] = info["pi_value"].astype(np.float64)
        return frame

    X_data = _alignment_frame(info_a, mapping_X_cols, raw_X_cols, "pi_value_X")
    Y_data = _alignment_frame(info_b, mapping_Y_cols, raw_Y_cols, "pi_value_Y")

    joined = pd.merge(Y_data, X_data, on=["mid"], how="inner")
    joined["pi_value"] = joined["pi_value_X"] * joined["pi_value_Y"]

    result = {name: joined[cols].values for name, cols in (
        ("raw_X", raw_X_cols),
        ("raw_Y", raw_Y_cols),
        ("mapping_X", mapping_X_cols),
        ("mapping_Y", mapping_Y_cols),
    )}
    result["pi_value"] = joined["pi_value"].astype(np.float64).values
    return result
def get_labels_based_on_coords(
    model: AnnData,
    coords: np.ndarray,
    labels_key: Union[str, List[str]],
    spatial_key: str = "align_spatial",
) -> pd.DataFrame:
    """Look up the ``model.obs`` labels of the observations located exactly at ``coords``.

    Args:
        model: Model providing coordinates in ``.obsm[spatial_key]`` and labels in ``.obs``.
        coords: Query coordinates with 2 or 3 columns.
            NOTE(review): ``.obsm[spatial_key]`` presumably must have the same number of columns.
        labels_key: One column name or a list of column names of ``model.obs`` to attach.
        spatial_key: Key in ``.obsm`` holding the model coordinates.

    Returns:
        A DataFrame with the coordinate columns (``x``, ``y`` and, for 3 columns, ``z``), ``map_index``
        (the row position in ``coords``) and the requested label columns. Only query coordinates that match
        a model coordinate exactly are included.
    """
    label_cols = labels_key if not isinstance(labels_key, str) else [labels_key]
    coord_cols = ["x", "y"] if coords.shape[1] != 3 else ["x", "y", "z"]

    # Reference table: model coordinates plus labels, without fully duplicated rows.
    reference = pd.DataFrame(model.obsm[spatial_key], columns=coord_cols)
    reference[label_cols] = model.obs[label_cols].values
    reference = reference.drop_duplicates(keep="first")

    # Query table: remember each coordinate's position in ``coords``.
    query = pd.DataFrame(coords.copy(), columns=coord_cols)
    query["map_index"] = query.index

    return pd.merge(query, reference, on=coord_cols, how="inner")
#########################
# Some helper functions #
#########################
def solve_RT_by_correspondence(
    X: np.ndarray, Y: np.ndarray, return_scale: bool = False
) -> Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, float]]:
    """
    Solve for the matrix R and translation vector t such that ``Y @ R.T + t`` best matches ``X``.

    The points are paired row by row. R is obtained from the SVD of the cross-covariance of the centered
    points. No reflection correction is applied, so R is orthogonal but its determinant may be -1.

    Args:
        X (np.ndarray): Points that the transformed ``Y`` is fitted to, shape (N, D).
        Y (np.ndarray): Points that R and t are applied to, shape (N, D).
        return_scale (bool, optional): Whether to return the scale factor. Defaults to False.

    Returns:
        Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, float]]:
            If return_scale is False, returns R with shape (D, D) and t with shape (D,).
            If return_scale is True, also returns the scalar s.
    """
    # Centroids of both point sets.
    tX = np.mean(X, axis=0)
    tY = np.mean(Y, axis=0)

    # Centered coordinates.
    X_demean = X - tX
    Y_demean = Y - tY

    # Cross-covariance between Y and X, decomposed by SVD.
    H = Y_demean.T @ X_demean
    U, _, Vt = np.linalg.svd(H)

    # Orthogonal matrix mapping centered Y onto centered X.
    # The sign of the last singular vector is deliberately not flipped when det(R) < 0,
    # so a reflection is returned when it fits the data best.
    R = Vt.T @ U.T

    # Translation that moves the rotated Y centroid onto the X centroid.
    t = tX - tY @ R.T

    if not return_scale:
        return R, t

    # NOTE(review): s is trace(Xc^T Xc - R^T Yc^T Xc) / trace(Yc^T Yc); confirm that this is the
    # intended definition of the scale factor.
    s = np.trace(X_demean.T @ X_demean - R.T @ (Y_demean.T @ X_demean)) / np.trace(Y_demean.T @ Y_demean)
    return R, t, s
##############
# Simulation #
##############
def split_slice(
    adata,
    spatial_key,
    split_num=5,
    axis=2,
):
    """Split a model into ``split_num`` equally sized slices along one spatial axis.

    Observations are ordered by their coordinate along ``axis`` and cut into consecutive groups of
    ``n_obs // split_num`` observations. If the number of observations is not divisible by ``split_num``,
    the observations with the largest coordinates that do not fill a complete group are left out.

    Args:
        adata: Model providing coordinates in ``.obsm[spatial_key]``.
        spatial_key: Key in ``.obsm`` holding the spatial coordinates.
        split_num: Number of slices to produce.
        axis: Column of the coordinates to sort and split along.

    Returns:
        A list of ``split_num`` copies of sub-models, each with its slice number stored in ``.obs["slice"]``.

    Raises:
        ValueError: If there are fewer observations than ``split_num``, so slices would be empty.
    """
    spatial_points = adata.obsm[spatial_key]
    order = np.argsort(spatial_points[:, axis])
    points_per_segment = len(order) // split_num
    if points_per_segment == 0:
        raise ValueError(f"Cannot split {len(order)} points into {split_num} non-empty slices.")

    split_adata = []
    # Build exactly ``split_num`` segments; the incomplete remainder is never materialised.
    for slice_id in range(split_num):
        start = slice_id * points_per_segment
        segment = adata[order[start : start + points_per_segment], :].copy()
        segment.obs["slice"] = slice_id
        split_adata.append(segment)
    return split_adata