# Source code for spateo.tools.coarse_align

"""
Todo:
    * @Xiaojieqiu: update with Google style documentation, function typings, tests
"""
import math
from typing import List, Optional, Tuple, Union

import anndata
import numpy as np
import pandas as pd
import shapely
import shapely.geometry as geometry
from sklearn.decomposition import PCA

from ..configuration import SKM
from ..io.bbs import alpha_shape
from ..logging import logger_manager as lm


def procrustes(
    X: np.ndarray,
    Y: np.ndarray,
    scaling: bool = True,
    reflection: Union[str, bool] = "best",
) -> Tuple[float, np.ndarray, dict]:
    """A port of MATLAB's `procrustes` function to Numpy.

    This function will need to be rewritten just with scipy.spatial.procrustes and
    scipy.linalg.orthogonal_procrustes later.

    Procrustes analysis determines a linear transformation (translation, reflection, orthogonal rotation and
    scaling) of the points in Y to best conform them to the points in matrix X, using the sum of squared errors as
    the goodness of fit criterion.

        d, Z, tform = procrustes(X, Y)

    Args:
        X: matrix of target coordinates, one point per row.
        Y: matrix of input coordinates. It must have the same number of points (rows) as `X`, but may have fewer
            dimensions (columns) than `X`; missing columns are padded with zeros.
        scaling: if False, the scaling component of the transformation is forced to 1.
        reflection: if 'best' (default), the transformation solution may or may not include a reflection
            component, depending on which fits the data best. Setting reflection to True or False forces a
            solution with reflection or no reflection respectively.

    Returns:
        d: the residual sum of squared errors, normalized according to a measure of the scale of X,
            ((X - X.mean(0))**2).sum()
        Z: the matrix of transformed Y-values.
        tform: a dict with keys "rotation" (T), "scale" (b) and "translation" (c) describing the map from Y onto
            X, i.e. `Z = b * Y @ T + c`.

    Raises:
        ValueError: if `X` and `Y` do not have the same number of rows.

    Note:
        No guard is applied for degenerate inputs whose points all coincide (zero centred norm); the
        normalisation below then divides by zero.
    """
    n, m = X.shape
    ny, my = Y.shape
    if ny != n:
        raise ValueError(f"X and Y must have the same number of rows, got {n} and {ny}.")

    muX = X.mean(0)
    muY = Y.mean(0)

    X0 = X - muX
    Y0 = Y - muY

    ssX = np.linalg.norm(X0, "fro") ** 2
    ssY = np.linalg.norm(Y0, "fro") ** 2

    # centred Frobenius norm
    normX = np.sqrt(ssX)
    normY = np.sqrt(ssY)

    # scale to equal (unit) norm
    X0 /= normX
    Y0 /= normY

    if my < m:
        # Pad Y with zero-valued *columns* so both point sets live in the same space.
        Y0 = np.concatenate((Y0, np.zeros((n, m - my))), axis=1)

    # optimum rotation matrix of Y
    A = np.dot(X0.T, Y0)
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    V = Vt.T
    T = np.dot(V, U.T)

    if reflection != "best":
        # does the current solution use a reflection?
        have_reflection = np.linalg.det(T) < 0

        # if that's not what was specified, force another reflection
        if reflection != have_reflection:
            V[:, -1] *= -1
            s[-1] *= -1
            T = np.dot(V, U.T)

    traceTA = s.sum()

    if scaling:
        # optimum scaling of Y
        b = traceTA * normX / normY

        # standardised distance between X and b*Y*T + c
        d = 1 - traceTA**2

        # transformed coords
        Z = normX * traceTA * np.dot(Y0, T) + muX
    else:
        b = 1
        d = 1 + ssY / ssX - 2 * traceTA * normY / normX
        Z = normY * np.dot(Y0, T) + muX

    # transformation matrix: drop the rows that correspond to the zero padding of Y
    if my < m:
        T = T[:my, :]
    c = muX - b * np.dot(muY, T)

    tform = {"rotation": T, "scale": b, "translation": c}

    return d, Z, tform


def AffineTrans(
    x: np.ndarray,
    y: np.ndarray,
    centroid_x: float,
    centroid_y: float,
    theta: Optional[float],
    R: Optional[np.ndarray],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Translate the x/y coordinates so that the centroid sits at the origin, then rotate them.

    Args:
        x: x coordinates for the data points (bins). 1D np.array.
        y: y coordinates for the data points (bins). 1D np.array.
        centroid_x: x coordinate for the centroid of data points (bins).
        centroid_y: y coordinate for the centroid of data points (bins).
        theta: the angle of rotation in radians (so 90 degrees is `np.pi / 2`). The rotation matrix built from it
            is `[[cos, -sin], [sin, cos]]`, which is counter-clockwise in a y-up frame and therefore appears
            clockwise when the y axis points down (image convention).
        R: the 2x2 rotation matrix. If `R` is provided, `theta` will be ignored.

    Returns:
        T_t: The 3x3 homogeneous translation matrix used in the affine transformation.
        T_r: The 3x3 homogeneous rotation matrix used in the affine transformation.
        trans_xy_coord: The `(len(x), 2)` matrix that stores the translated and rotated coordinates.
    """
    if theta is None and R is None:
        # Reported through the project logger; what happens next depends on `lm.EXCEPTION` (not visible here).
        lm.EXCEPTION("`theta` and `R` cannot be both None!")

    # Homogeneous translation that moves the centroid to the origin.
    T_t = np.eye(3)
    T_t[0, 2], T_t[1, 2] = -centroid_x, -centroid_y

    # Homogeneous rotation, either from the supplied matrix or from the angle.
    T_r = np.eye(3)
    if R is None:
        sin_theta, cos_theta = np.sin(theta), np.cos(theta)
        T_r[0, 0], T_r[0, 1] = cos_theta, -sin_theta
        T_r[1, 0], T_r[1, 1] = sin_theta, cos_theta
    else:
        T_r[:2, :2] = R

    # Apply both transforms to every point at once: each row is [x, y, 1], translated first and then rotated,
    # in the same order as a per-point `T_r @ (T_t @ point)`.
    homogeneous = np.column_stack((x, y, np.ones(len(x))))
    transformed = (homogeneous @ T_t.T) @ T_r.T
    trans_xy_coord = transformed[:, :2].copy()

    return T_t, T_r, trans_xy_coord


def pca_align(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Rotate a coordinate matrix onto its principal axes.

    A PCA with as many components as `X` has columns is fitted to `X`, and the fitted component matrix is applied
    to the coordinates. This can be used to `correct`, for example, embryo slices to the right orientation.

    Note that the component matrix is applied to `X` as given: the coordinates are not re-centred in this
    function before the product.

    Args:
        X: The input coordinate matrix (one point per row).

    Returns:
        Y: The transformed coordinate matrix, expressed along the fitted principal axes.
        R: The matrix (`pca.components_`) that was used to convert the input X matrix to the output Y matrix.
    """
    n_dims = X.shape[1]
    fitted = PCA(n_components=n_dims)
    fitted.fit(X)

    R = fitted.components_
    # Apply R to every point (column vectors), then return to one-point-per-row layout.
    Y = np.transpose(R @ np.transpose(X))

    return Y, R


@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE)
def align_slices_pca(
    adata: anndata.AnnData,
    spatial_key: str = "spatial",
    inplace: bool = False,
    result_key: Optional[str] = None,
) -> None:
    """Coarsely align the slices based on the major axis, identified via PCA.

    The coordinates are centred on the centroid of their alpha shape (or on the per-axis median when the alpha
    shape cannot be computed), rotated with the PCA component matrix and then rotated by a further 90 degrees.

    Args:
        adata: the input adata object that contains the spatial key in .obsm.
        spatial_key: the key in .obsm that points to the spatial information. Only the first two columns are
            transformed; the code assumes 2D coordinates.
        inplace: whether the spatial coordinates stored under `spatial_key` will be updated inplace, or the
            result stored under a separate key.
        result_key: when inplace is False, this points to the key in .obsm that stores the corrected spatial
            coordinates. Defaults to "spatial_corrected".

    Returns:
        Nothing but updates the spatial coordinates either inplace or with the `result_key` key based on the major
        axis identified via PCA. The outline and centroid used are stored in `adata.uns["bbs"]`.
    """
    coords = adata.obsm[spatial_key].copy()
    x, y = coords[:, 0], coords[:, 1]

    try:
        hull, _ = alpha_shape(x, y, alpha=1)
        if isinstance(hull, shapely.geometry.multipolygon.MultiPolygon):
            # Multi-part geometries are not indexable in Shapely 2.x; `.geoms` works on 1.x and 2.x alike.
            outline_x, outline_y = hull.geoms[0].exterior.xy
        else:
            outline_x, outline_y = hull.exterior.xy

        centroid_x, centroid_y = hull.centroid.coords.xy
        centroid_x, centroid_y = centroid_x[0], centroid_y[0]
        adata.uns["bbs"] = {"x": outline_x, "y": outline_y, "centroid_x": centroid_x, "centroid_y": centroid_y}
    except Exception:
        # Best-effort fallback: without a usable alpha shape, centre on the per-axis median instead.
        centroid_x, centroid_y = np.nanmedian(coords, 0)
        adata.uns["bbs"] = {"x": None, "y": None, "centroid_x": centroid_x, "centroid_y": centroid_y}

    _, R = pca_align(coords)

    # centre on the centroid and rotate onto the principal axes
    _, _, spatial_corrected = AffineTrans(
        coords[:, 0],
        coords[:, 1],
        centroid_x,
        centroid_y,
        None,
        R,
    )

    # rotate 90 degree
    _, _, coords_correct_processed = AffineTrans(
        spatial_corrected[:, 0],
        spatial_corrected[:, 1],
        0,
        0,
        np.pi / 2,
        None,
    )

    # NOTE: the original code reflected the y axis twice in a row here (once "vertically", once to account for
    # the mirror effect when plotting an image). The two reflections cancel exactly, so neither is applied.

    if inplace:
        # Write back to the key the coordinates were read from.
        adata.obsm[spatial_key] = coords_correct_processed
    else:
        key = "spatial_corrected" if result_key is None else result_key
        adata.obsm[key] = coords_correct_processed