Source code for spateo.alignment.utils

from typing import List, Optional, Tuple, Union

import anndata as ad
import numpy as np
import pandas as pd
from anndata import AnnData
from scipy.spatial import cKDTree

from spateo.logging import logger_manager as lm

####################
# Before Alignment #
####################


def _iteration(n: int, progress_name: str, verbose: bool = True, start_n: int = 0, indent_level=1):
    """Return an iterable over ``range(start_n, n)``, optionally wrapped in a progress logger.

    Args:
        n: Exclusive upper bound of the iteration.
        progress_name: Forwarded to ``lm.progress_logger`` as ``progress_name``.
        verbose: If True, the range is wrapped with ``lm.progress_logger``; otherwise the plain
            ``range`` object is returned.
        start_n: First index of the iteration.
        indent_level: Forwarded to ``lm.progress_logger`` as ``indent_level``.

    Returns:
        The iterable to loop over.
    """
    indices = range(start_n, n)
    if not verbose:
        # The silent path must cover the same indices as the logged path, so ``start_n`` is
        # honoured here as well (it used to be dropped, restarting the iteration from 0).
        return indices
    return lm.progress_logger(indices, progress_name=progress_name, indent_level=indent_level)
def downsampling(
    models: Union[List[AnnData], AnnData],
    n_sampling: Optional[int] = 2000,
    sampling_method: str = "trn",
    spatial_key: str = "spatial",
) -> Union[List[AnnData], AnnData]:
    """Down-sample every model to at most ``n_sampling`` observations.

    Args:
        models: A single AnnData object or a list of them. A single object is wrapped in a list.
        n_sampling: Requested number of observations per model. For each model it is capped at
            that model's own number of observations.
        sampling_method: Forwarded to ``dynamo.tools.sampling.sample`` as ``method``.
        spatial_key: Key in ``.obsm`` whose matrix is forwarded to ``sample`` as ``X``.

    Returns:
        A list with one sub-sampled model per input model (a list is returned even when a single
        model was passed in).
    """
    from dynamo.tools.sampling import sample

    models = models if isinstance(models, list) else [models]
    sampling_models = []
    for m in models:
        sampling_model = m.copy()
        # Cap the request per model without touching ``n_sampling`` itself; overwriting it would
        # make one small model shrink the sample size of every model processed after it.
        n_for_model = min(n_sampling, sampling_model.shape[0])
        sampling = sample(
            arr=np.asarray(sampling_model.obs_names),
            n=n_for_model,
            method=sampling_method,
            X=sampling_model.obsm[spatial_key],
        )
        sampling_model = sampling_model[sampling, :]
        sampling_models.append(sampling_model)
    return sampling_models
## generate label transfer prior
def generate_label_transfer_prior(cat1, cat2, positive_pairs=None, negative_pairs=None):
    """Build a normalized label-transfer prior table between two label sets.

    Every ``(c2, c1)`` entry starts at 1. Each pair dict then overwrites the entries
    ``[right][left]`` with its ``"value"``; positive pairs are applied first, negative pairs
    afterwards. Finally each ``c2`` row is divided by its sum over ``cat1``.

    Args:
        cat1: Iterable of labels used as the inner keys of the result.
        cat2: Iterable of labels used as the outer keys of the result.
        positive_pairs: Optional list of dicts with keys ``"left"`` (labels from ``cat1``),
            ``"right"`` (labels from ``cat2``) and ``"value"``.
        negative_pairs: Optional list of dicts with the same layout as ``positive_pairs``.
            When both lists are empty, every label present in both ``cat1`` and ``cat2`` is
            paired with itself with value 10.

    Returns:
        A dict ``{c2: {c1: weight}}`` in which the weights of every ``c2`` sum to 1.

    Raises:
        KeyError: If a pair lists a ``"right"`` label that is not in ``cat2``.
    """
    # Copy the inputs: the identity pairs added below must never leak into a list that the
    # caller passed in (e.g. an explicitly empty ``positive_pairs``).
    positive_pairs = [] if positive_pairs is None else list(positive_pairs)
    negative_pairs = [] if negative_pairs is None else list(negative_pairs)

    # Let identical annotations have a high value when no pairs were supplied.
    if not positive_pairs and not negative_pairs:
        positive_pairs.extend({"left": [c], "right": [c], "value": 10} for c in cat1 if c in cat2)

    label_transfer_prior = {c2: {c1: 1 for c1 in cat1} for c2 in cat2}

    # Negative pairs come last so that they win when both lists mention the same entry.
    for pair in (*positive_pairs, *negative_pairs):
        for left in pair["left"]:
            for right in pair["right"]:
                label_transfer_prior[right][left] = pair["value"]

    norm_label_transfer_prior = {}
    for c2 in cat2:
        row = label_transfer_prior[c2]
        norm_c = np.sum([row[c1] for c1 in cat1])
        norm_label_transfer_prior[c2] = {c1: row[c1] / norm_c for c1 in cat1}
    return norm_label_transfer_prior
# group pca
def group_pca(
    adatas: List[ad.AnnData],
    batch_key: str = "batch",
    pca_key: str = "X_pca",
    use_hvg: bool = True,
    hvg_key: str = "highly_variable",
    **args,
) -> None:
    """Run one joint PCA over several AnnData objects and write the result back to each of them.

    The objects are concatenated with ``ad.concat(..., label=batch_key)``, PCA is computed on
    the concatenated object, and the rows belonging to dataset ``i`` (batch label ``str(i)``)
    are copied into ``adatas[i].obsm[pca_key]``.

    Args:
        adatas: AnnData objects to concatenate and process. They are modified in place.
        batch_key: Column name used for the batch label of the concatenated object. It must not
            already exist in ``.obs`` of any input.
        pca_key: Key in ``.obsm`` of each input under which its PCA coordinates are stored.
        use_hvg: If True, ``sc.pp.highly_variable_genes`` is run with ``batch_key`` and PCA is
            called with ``use_highly_variable=True``; otherwise PCA is run without that flag.
        hvg_key: Column in ``.var`` of the concatenated object that is checked for at least one
            highly variable gene.
        **args: Extra keyword arguments forwarded to ``sc.tl.pca``.

    Raises:
        ValueError: If ``batch_key`` already exists in ``.obs`` of an input, or if ``use_hvg`` is
            True and no gene is flagged in ``.var[hvg_key]``.
    """
    import scanpy as sc

    # Refuse to overwrite an existing obs column with the batch label.
    for i, single in enumerate(adatas):
        if batch_key in single.obs.columns:
            raise ValueError(
                f"batch_key '{batch_key}' already exists in adata.obs for dataset {i}. Please choose a different key."
            )

    # One joint object; ``batch_key`` records which input each observation came from.
    merged = ad.concat(adatas, label=batch_key)

    if not use_hvg:
        # PCA on all genes.
        sc.tl.pca(merged, **args)
    else:
        sc.pp.highly_variable_genes(merged, batch_key=batch_key)
        if not merged.var[hvg_key].any():
            raise ValueError(
                "No highly variable genes were found. Please check your data or parameters for highly variable gene selection."
            )
        # PCA restricted to the highly variable genes.
        sc.tl.pca(merged, **args, use_highly_variable=True)

    # Hand every input its own rows of the joint embedding.
    for i, single in enumerate(adatas):
        in_batch = merged.obs[batch_key] == str(i)
        single.obsm[pca_key] = merged[in_batch].obsm["X_pca"].copy()
###################
# After Alignment #
###################
def _break_ties_by_distance(pairs, key_col, query_coords, candidate_coords):
    """Keep exactly one ``(x, y)`` index pair per distinct value in column ``key_col``.

    Pairs whose key occurs once are kept as they are. For a key that occurs several times, the
    candidate (indexed by the other column) closest to the key's own coordinates is kept; these
    resolved pairs are appended after the untouched ones.
    """
    other_col = 1 - key_col
    keys, counts = np.unique(pairs[:, key_col], return_counts=True)
    kept = [pairs[np.isin(pairs[:, key_col], keys[counts == 1])]]
    for key in keys[counts != 1]:
        tied = pairs[pairs[:, key_col] == key]
        tree = cKDTree(candidate_coords[tied[:, other_col]])
        _, nearest = tree.query(query_coords[key], k=1)
        kept.append(tied[nearest].reshape(1, 2))
    return np.concatenate(kept, axis=0)


def get_optimal_mapping_relationship(
    X: np.ndarray,
    Y: np.ndarray,
    pi: np.ndarray,
    keep_all: bool = False,
):
    """Extract the best-matching index pairs from a mapping matrix, per row and per column.

    Args:
        X: Coordinates indexed by the rows of ``pi``.
        Y: Coordinates indexed by the columns of ``pi``.
        pi: Mapping matrix; entry ``[i, j]`` relates ``X[i]`` to ``Y[j]``.
        keep_all: If True, every entry equal to its row (resp. column) maximum is returned. If
            False, ties are resolved by keeping the spatially nearest candidate.

    Returns:
        ``(X_max_index, X_pi_value, Y_max_index, Y_pi_value)``: the ``(i, j)`` pairs holding the
        row maxima with their ``pi`` values as a column vector, followed by the same for the
        column maxima.
    """
    # Entries that equal the maximum of their row (best Y for each X) ...
    X_max_index = np.argwhere(pi == pi.max(axis=1, keepdims=True))
    # ... and entries that equal the maximum of their column (best X for each Y).
    Y_max_index = np.argwhere(pi == pi.max(axis=0))

    if not keep_all:
        X_max_index = _break_ties_by_distance(X_max_index, 0, query_coords=X, candidate_coords=Y)
        Y_max_index = _break_ties_by_distance(Y_max_index, 1, query_coords=Y, candidate_coords=X)

    X_pi_value = pi[X_max_index[:, 0], X_max_index[:, 1]].reshape(-1, 1)
    Y_pi_value = pi[Y_max_index[:, 0], Y_max_index[:, 1]].reshape(-1, 1)
    return X_max_index, X_pi_value, Y_max_index, Y_pi_value
def mapping_aligned_coords(
    X: np.ndarray,
    Y: np.ndarray,
    pi: np.ndarray,
    keep_all: bool = False,
) -> Tuple[dict, dict]:
    """
    Optimal mapping coordinates between X and Y.

    Args:
        X: Aligned spatial coordinates.
        Y: Aligned spatial coordinates.
        pi: Mapping matrix between the two layers (e.g. as output by PASTE).
        keep_all: Forwarded to ``get_optimal_mapping_relationship``. If True, all optimal
            relationships derived from ``pi`` alone are considered; if False, ties are resolved
            there using the nearest coordinates.

    Returns:
        Two dicts, the first built from the best match of every X point and the second from the
        best match of every Y point. Each has the keys ``mapping_X`` (X coordinates of the
        matched pairs), ``mapping_Y`` (Y coordinates of the matched pairs), ``pi_index`` (the
        ``(index_x, index_y)`` positions in ``pi``) and ``pi_value`` (the ``pi`` values there).
    """
    # Work on private copies of the inputs.
    X, Y, pi = X.copy(), Y.copy(), pi.copy()

    # Obtain the optimal mapping between points.
    X_max_index, X_pi_value, Y_max_index, Y_pi_value = get_optimal_mapping_relationship(
        X=X, Y=Y, pi=pi, keep_all=keep_all
    )

    column_dtypes = {"index_x": np.int32, "index_y": np.int32, "pi_value": np.float64}

    def _best_pairs(max_index, pi_value, subset):
        """Keep, for every distinct value of ``subset``, the pair with the largest ``pi_value``."""
        table = pd.DataFrame(
            np.concatenate([max_index, pi_value], axis=1),
            columns=["index_x", "index_y", "pi_value"],
        ).astype(dtype=column_dtypes)
        # After sorting by descending pi_value within each index, the first row per index wins.
        table = table.sort_values(by=[subset, "pi_value"], ascending=[True, False])
        table = table.drop_duplicates(subset=[subset], keep="first")
        return {
            "mapping_X": X[table["index_x"].values],
            "mapping_Y": Y[table["index_y"].values],
            "pi_index": table[["index_x", "index_y"]].values,
            "pi_value": table["pi_value"].values,
        }

    from_x = _best_pairs(X_max_index, X_pi_value, "index_x")
    from_y = _best_pairs(Y_max_index, Y_pi_value, "index_y")
    return from_x, from_y
def _center_mapping_frame(info: dict, tag: str):
    """Tabulate one model's center-alignment info, one row per mapped pair.

    Returns the frame plus the names of its ``mapping_*`` and ``raw_*`` columns.
    """
    mapping_cols = [f"mapping_{tag}_{i}" for i in range(info["mapping_Y"].shape[1])]
    raw_cols = [f"raw_{tag}_{i}" for i in range(info["raw_Y"].shape[1])]
    frame = pd.DataFrame(
        # The data blocks are stacked in the same order as the column labels below, so the
        # ``mapping_*`` columns hold ``mapping_Y`` and the ``raw_*`` columns hold ``raw_Y``.
        np.concatenate(
            [
                info["mapping_Y"],
                info["raw_Y"],
                info["pi_index"][:, [0]],
            ],
            axis=1,
        ),
        columns=mapping_cols + raw_cols + ["mid"],
    )
    frame[f"pi_value_{tag}"] = info["pi_value"].astype(np.float64)
    return frame, mapping_cols, raw_cols


def mapping_center_coords(modelA: AnnData, modelB: AnnData, center_key: str) -> dict:
    """
    Optimal mapping coordinates between X and Y based on intermediate coordinates.

    The alignment info of both models is joined on the first column of their ``pi_index``
    (stored as ``mid``; presumably the index of the matched point in the center model - confirm
    against the code that writes ``.uns[center_key]``), so only points of modelA and modelB that
    share such an index are returned.

    Args:
        modelA: modelA aligned with center model.
        modelB: modelB aligned with center model.
        center_key: The key in ``.uns`` that corresponds to the alignment info between
            modelA/modelB and center model. The stored dict must provide ``raw_Y``,
            ``mapping_Y``, ``pi_index`` and ``pi_value``.

    Returns:
        A dict of raw_X, raw_Y, mapping_X, mapping_Y, pi_value.
        raw_X / mapping_X are the ``raw_Y`` / ``mapping_Y`` entries of modelA.
        raw_Y / mapping_Y are the ``raw_Y`` / ``mapping_Y`` entries of modelB.
        pi_value is the product of the two models' ``pi_value`` for each joined pair.
    """
    modelA_dict = modelA.uns[center_key].copy()
    modelB_dict = modelB.uns[center_key].copy()

    X_data, mapping_X_cols, raw_X_cols = _center_mapping_frame(modelA_dict, "X")
    Y_data, mapping_Y_cols, raw_Y_cols = _center_mapping_frame(modelB_dict, "Y")

    mapping_data = pd.merge(Y_data, X_data, on=["mid"], how="inner")
    # 1-D product, so that the assignment does not rely on pandas accepting an (n, 1) array.
    mapping_data["pi_value"] = mapping_data["pi_value_X"].values * mapping_data["pi_value_Y"].values

    return {
        "raw_X": mapping_data[raw_X_cols].values,
        "raw_Y": mapping_data[raw_Y_cols].values,
        "mapping_X": mapping_data[mapping_X_cols].values,
        "mapping_Y": mapping_data[mapping_Y_cols].values,
        "pi_value": mapping_data["pi_value"].astype(np.float64).values,
    }
def get_labels_based_on_coords(
    model: AnnData,
    coords: np.ndarray,
    labels_key: Union[str, List[str]],
    spatial_key: str = "align_spatial",
) -> pd.DataFrame:
    """Look up the ``.obs`` labels of the observations located exactly at ``coords``.

    Args:
        model: Object providing the coordinates in ``.obsm[spatial_key]`` and the labels in
            ``.obs``.
        coords: Query coordinates; three columns are named ``x, y, z``, anything else ``x, y``.
        labels_key: One column name or a list of column names of ``.obs`` to attach.
        spatial_key: Key in ``.obsm`` holding the model coordinates.

    Returns:
        The inner join of the query coordinates with the model coordinates on the coordinate
        columns. ``map_index`` holds the row position of each matched query in ``coords``; the
        requested label columns follow.
    """
    label_cols = labels_key if not isinstance(labels_key, str) else [labels_key]
    axis_names = ["x", "y", "z"] if coords.shape[1] == 3 else ["x", "y"]

    # Query side: remember where each coordinate sat in ``coords``.
    query = pd.DataFrame(coords.copy(), columns=axis_names)
    query["map_index"] = query.index

    # Reference side: model coordinates with their labels, identical rows collapsed.
    reference = pd.DataFrame(model.obsm[spatial_key], columns=axis_names)
    reference[label_cols] = model.obs[label_cols].values
    reference = reference.drop_duplicates(keep="first")

    return pd.merge(query, reference, on=axis_names, how="inner")
#########################
# Some helper functions #
#########################
def solve_RT_by_correspondence(
    X: np.ndarray, Y: np.ndarray, return_scale: bool = False
) -> Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, float]]:
    """
    Solve for the orthogonal matrix R and translation vector t that map the points in Y onto the
    corresponding points in X in the least-squares sense, i.e. ``X ~= Y @ R.T + t``.

    Args:
        X (np.ndarray): Points that the transformed Y should match, shape (N, D).
        Y (np.ndarray): Points to be transformed, shape (N, D); row ``i`` corresponds to row
            ``i`` of X.
        return_scale (bool, optional): Whether to return the scale factor. Defaults to False.

    Returns:
        Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, float]]:
        If return_scale is False, returns R with shape (D, D) and t with shape (D,).
        If return_scale is True, also returns the scale factor s.

    Raises:
        ValueError: If X and Y are not 2-D arrays of identical shape.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    if X.ndim != 2 or X.shape != Y.shape:
        raise ValueError(
            f"X and Y must be 2-D arrays of identical shape, got {X.shape} and {Y.shape}."
        )

    # Calculate centroids of X and Y
    tX = np.mean(X, axis=0)
    tY = np.mean(Y, axis=0)

    # Demean the points
    X_demean = X - tX
    Y_demean = Y - tY

    # Cross-covariance between the centred point sets, shape (D, D)
    H = np.dot(Y_demean.T, X_demean)

    # Singular Value Decomposition; only the singular vectors are needed
    U, _, Vt = np.linalg.svd(H)

    # Orthogonal factor. NOTE: no determinant correction is applied, so R can be a reflection
    # (det(R) = -1) rather than a proper rotation.
    R = np.dot(Vt.T, U.T)

    # Translation that carries the rotated centroid of Y onto the centroid of X
    t = tX - np.dot(tY, R.T)

    if not return_scale:
        return R, t

    # Scale factor, kept exactly as originally implemented.
    # NOTE(review): this does not look like the usual closed form trace(S) / trace(Y^T Y) -
    # confirm the intended definition before relying on ``s``.
    s = np.trace(np.dot(X_demean.T, X_demean) - np.dot(R.T, H)) / np.trace(np.dot(Y_demean.T, Y_demean))
    return R, t, s
def rigid_transformation(
    adata,
    spatial_key,
    key_added,
    theta=None,
    translation=None,
    inplace=True,
):
    """Rotate 2-D coordinates about their centroid and optionally translate them.

    Each centred point ``p`` (a row vector) becomes
    ``p @ [[cos(theta), -sin(theta)], [sin(theta), cos(theta)]]``; the centroid is then added
    back and ``translation`` is added on top.

    Args:
        adata: Object whose ``.obsm[spatial_key]`` holds the coordinates. The rotation matrix is
            2x2, so the coordinates must have two columns.
        spatial_key: Key in ``.obsm`` to read the coordinates from.
        key_added: Key in ``.obsm`` under which the transformed coordinates are stored.
        theta: Rotation angle in radians. If None, an angle in ``[0, 2*pi)`` is drawn with
            ``np.random.rand``.
        translation: Optional offset added after the rotation.
        inplace: If True, ``adata`` itself is modified and None is returned. Otherwise a copy is
            modified and returned.

    Returns:
        None if ``inplace`` is True, else the transformed copy of ``adata``.
    """
    if not inplace:
        adata = adata.copy()

    spatial = adata.obsm[spatial_key]
    centroid = np.mean(spatial, axis=0)

    if theta is None:
        # random rotation
        theta = np.random.rand() * 2 * np.pi
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation_matrix = np.array([[cos_t, -sin_t], [sin_t, cos_t]])

    # Rotate about the centroid; the arithmetic yields new arrays, so the source coordinates
    # stored under ``spatial_key`` are left untouched.
    spatial = np.matmul(spatial - centroid, rotation_matrix) + centroid
    if translation is not None:
        spatial = spatial + translation

    adata.obsm[key_added] = spatial
    return None if inplace else adata
##############
# Simulation #
##############
def split_slice(
    adata,
    spatial_key,
    split_num=5,
    axis=2,
):
    """Cut a dataset into ``split_num`` equally sized slices along one coordinate axis.

    Observations are ordered by their coordinate along ``axis`` and consecutive runs of
    ``n_obs // split_num`` observations form the slices. When ``n_obs`` is not a multiple of
    ``split_num``, the ``n_obs % split_num`` observations with the largest coordinates are not
    assigned to any slice.

    Args:
        adata: Object holding the coordinates in ``.obsm[spatial_key]``; it must support
            row-indexing with an index array.
        spatial_key: Key in ``.obsm`` holding the coordinates.
        split_num: Number of slices to produce.
        axis: Column of the coordinates along which to cut (the default 2 needs a third column).

    Returns:
        A list of ``split_num`` copies, each with its slice number stored in ``.obs["slice"]``.

    Raises:
        ValueError: If ``split_num`` is smaller than 1 or larger than the number of points.
    """
    spatial_points = adata.obsm[spatial_key]
    n_points = spatial_points.shape[0]

    if split_num < 1:
        raise ValueError(f"Cannot split into {split_num} slices: split_num must be at least 1.")
    points_per_segment = n_points // split_num
    if points_per_segment == 0:
        # Fail with a clear message instead of a zero step size further down.
        raise ValueError(f"Cannot split {n_points} points into {split_num} non-empty slices.")

    order = np.argsort(spatial_points[:, axis])

    split_adata = []
    # Only the first ``split_num`` full segments are wanted, so a trailing remainder segment is
    # never materialised.
    for slice_id in range(split_num):
        start = slice_id * points_per_segment
        segment = adata[order[start : start + points_per_segment], :].copy()
        segment.obs["slice"] = slice_id
        split_adata.append(segment)
    return split_adata
# def tps_deformation( # adata, # spatial_key, # key_added, # grid_num=2, # tps_noise_scale=25, # add_corner_points=True, # alpha=0.1, # inplace=True, # ): # from tps import ThinPlateSpline # spatial = adata.obsm[spatial_key] # # get min max # x_min, x_max = np.min(spatial[:, 0]), np.max(spatial[:, 0]) # y_min, y_max = np.min(spatial[:, 1]), np.max(spatial[:, 1]) # # define the length of grid # grid_size_x = (x_max - x_min) / grid_num # grid_size_y = (y_max - y_min) / grid_num # # generate the grid # x_grid = np.linspace(x_min, x_max, grid_num + 1)[:-1] + grid_size_x / 2 # y_grid = np.linspace(y_min, y_max, grid_num + 1)[:-1] + grid_size_y / 2 # xx, yy = np.meshgrid(x_grid, y_grid) # # generate control points # src_points = [] # dst_points = [] # for i in range(xx.shape[0]): # for j in range(xx.shape[1]): # x_center, y_center = xx[i, j], yy[i, j] # x = x_center # y = y_center # src_points.append(np.column_stack([x, y])) # dst_points.append( # src_points[-1] + np.random.normal(scale=(grid_size_x + grid_size_y) * tps_noise_scale / 2, size=(1, 2)) # ) # src_points = np.concatenate(src_points, axis=0) # dst_points = np.concatenate(dst_points, axis=0) # if add_corner_points: # src_points = np.concatenate( # [np.array([[x_min, y_min], [x_min, y_max], [x_max, y_max], [x_max, y_min]]), src_points], 0 # ) # dst_points = np.concatenate( # [np.array([[x_min, y_min], [x_min, y_max], [x_max, y_max], [x_max, y_min]]), dst_points], 0 # ) # # calculate the TPS deformation # tps = ThinPlateSpline(alpha=alpha) # Regularization # tps.fit(src_points, dst_points) # # perform tps deformation # tps_spatial = tps.transform(spatial) # if inplace: # adata.obsm[key_added] = tps_spatial # return lambda x: tps.transform(x) # else: # adata_tps = adata.copy() # adata_tps.obsm[key_added] = tps_spatial # return adata_tps, lambda x: tps.transform(x)
def tps_deformation(
    adata,
    spatial_key,
    key_added,
    grid_num=2,
    tps_noise_scale=25,
    add_corner_points=True,
    alpha=0.1,
    inplace=True,
):
    """Warp 2-D coordinates with a randomly perturbed thin-plate spline.

    Control points are placed at the cell centres of a ``grid_num`` x ``grid_num`` grid over the
    bounding box of the first two coordinate columns. Each target point is its source point plus
    Gaussian noise; a ``ThinPlateSpline`` is fitted from sources to targets and applied to the
    coordinates.

    Args:
        adata: Object holding the coordinates in ``.obsm[spatial_key]``.
        spatial_key: Key in ``.obsm`` to read the coordinates from.
        key_added: Key in ``.obsm`` under which the warped coordinates are stored.
        grid_num: Number of grid cells per axis.
        tps_noise_scale: Multiplier of the noise; the standard deviation is the mean grid cell
            size times this value.
        add_corner_points: If True, the four bounding-box corners are added as control points
            that map onto themselves.
        alpha: Passed to ``ThinPlateSpline(alpha=...)``.
        inplace: If True, ``adata`` is modified; otherwise a copy is modified.

    Returns:
        If ``inplace`` is True, a callable applying the fitted spline to new points. Otherwise a
        tuple of the modified copy and that callable.
    """
    from tps import ThinPlateSpline

    spatial = adata.obsm[spatial_key]

    # Bounding box of the first two coordinate columns.
    x_min, x_max = np.min(spatial[:, 0]), np.max(spatial[:, 0])
    y_min, y_max = np.min(spatial[:, 1]), np.max(spatial[:, 1])

    # Size of one grid cell along each axis.
    grid_size_x = (x_max - x_min) / grid_num
    grid_size_y = (y_max - y_min) / grid_num

    # Cell centres along each axis, combined into a full grid.
    x_centers = np.linspace(x_min, x_max, grid_num + 1)[:-1] + grid_size_x / 2
    y_centers = np.linspace(y_min, y_max, grid_num + 1)[:-1] + grid_size_y / 2
    xx, yy = np.meshgrid(x_centers, y_centers)

    # One control point per cell centre, visited row by row; one noise draw per point.
    noise_std = (grid_size_x + grid_size_y) * tps_noise_scale / 2
    src_rows, dst_rows = [], []
    for x_center, y_center in zip(xx.ravel(), yy.ravel()):
        anchor = np.column_stack([x_center, y_center])
        src_rows.append(anchor)
        dst_rows.append(anchor + np.random.normal(scale=noise_std, size=(1, 2)))
    src_points = np.concatenate(src_rows, axis=0)
    dst_points = np.concatenate(dst_rows, axis=0)

    if add_corner_points:
        # The corners are fixed points of the deformation.
        corners = np.array([[x_min, y_min], [x_min, y_max], [x_max, y_max], [x_max, y_min]])
        src_points = np.concatenate([corners, src_points], 0)
        dst_points = np.concatenate([corners, dst_points], 0)

    # Fit the spline (alpha is its regularization argument) and warp the coordinates.
    spline = ThinPlateSpline(alpha=alpha)
    spline.fit(src_points, dst_points)
    tps_spatial = spline.transform(spatial)

    if not inplace:
        adata_tps = adata.copy()
        adata_tps.obsm[key_added] = tps_spatial
        return adata_tps, lambda x: spline.transform(x)

    adata.obsm[key_added] = tps_spatial
    return lambda x: spline.transform(x)