Source code for spateo.tdr.morphometrics.morphofield.gaussian_process

from typing import List, Optional, Tuple, Union

import numpy as np
from anndata import AnnData
from scipy.spatial.distance import cdist

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

from spateo.logging import logger_manager as lm
from spateo.tdr.interpolations import get_X_Y_grid


def _con_K(
    x: np.ndarray,
    y: np.ndarray,
    beta: float = 0.1,
    method: str = "cdist",
    return_d: bool = False,
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
    """Build the Gaussian kernel ``exp(-beta * ||x_i - y_j||^2)`` between two point sets.

    Args:
        x: Query points, shape ``(n, d)``. A 1-D array is treated as a single point ``(1, d)``.
        y: Reference points, shape ``(m, d)``.
        beta: Multiplier applied to the squared distance inside the exponential.
        method: ``"cdist"`` uses ``scipy.spatial.distance.cdist``; any other value (or
            ``return_d=True``) uses the dense pairwise-difference computation.
        return_d: If True, also return the pairwise difference tensor.

    Returns:
        ``K`` when ``return_d`` is False, otherwise the tuple ``(K, D)``.
        ``D`` has shape ``(n, d, m)`` with ``D[i, :, j] = x[i] - y[j]``.
        On the ``cdist`` path ``K`` is flattened to ``(m,)`` when there is a single query
        point; on the dense path ``K`` is passed through ``np.squeeze``.
    """
    if x.ndim == 1:
        x = x[None, :]

    # Fast path: the difference tensor is not needed, so let scipy compute the distances.
    if method == "cdist" and not return_d:
        sq_dist = cdist(x, y, "sqeuclidean")
        if len(sq_dist) == 1:
            sq_dist = sq_dist.flatten()
        return np.exp(-beta * sq_dist)

    # Dense path: broadcasting yields the same (n, d, m) tensor that was previously built
    # with ``np.matlib.tile``, without depending on the deprecated ``numpy.matlib`` module
    # (which this file never imports) and without materialising the tiled copies.
    D = x[:, :, None] - y.T[None, :, :]
    K = np.exp(-beta * np.squeeze(np.sum(D**2, 1)))

    if return_d:
        return K, D
    return K
def _con_K_geodist(
    x: np.ndarray,
    kernel_dict: dict,
    beta: float = 0.1,
    return_d: bool = False,
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
    """Evaluate a Gaussian kernel on an approximate graph distance to the inducing points.

    Each query point is snapped to its nearest entry of ``kernel_dict["X"]``. The distance
    to every inducing point is then taken as the stored graph distance of that nearest
    entry, plus the query's Euclidean distance to the first node on the path, minus the
    nearest entry's own Euclidean distance to that first node.

    Args:
        x: Query points; a 1-D array is treated as a single point.
        kernel_dict: Dictionary providing ``"X"``, ``"first_node_idx"`` and
            ``"kernel_graph_distance"``. A negative ``first_node_idx`` entry is treated as
            "no path" (see the mask below).
            NOTE(review): the last two are assumed to be 2-D, one column per inducing
            point -- confirm against the code that builds ``kernel_dict``.
        beta: Multiplier applied to the squared distance inside the exponential.
        return_d: If True, also return the direction-weighted distance tensor.

    Returns:
        ``K`` (squeezed) when ``return_d`` is False, otherwise ``(K, D)`` where ``D`` has
        the coordinate axis in the middle (axes transposed with ``[0, 2, 1]``).
    """
    if x.ndim == 1:
        x = x[None, :]

    anchors = kernel_dict["X"]

    # Snap every query point to its closest anchor.
    closest = np.argmin(cdist(x, anchors, "euclidean"), axis=1)

    # First node on the path from that anchor to each inducing point; negative entries
    # mark pairs without a path. Fancy indexing returns a copy, so overwriting the
    # negative entries does not touch ``kernel_dict``.
    first_nodes = kernel_dict["first_node_idx"][closest]
    unreachable = first_nodes < 0
    first_nodes[unreachable] = 0
    first_node_xyz = anchors[first_nodes]

    # Euclidean leg from the query point to the first node (broadcast over inducing points).
    query_offset = x[:, None, :] - first_node_xyz
    query_leg = np.sqrt(np.sum(query_offset**2, axis=2))

    # Euclidean leg from the snapped anchor to the same first node.
    anchor_offset = anchors[closest][:, None, :] - first_node_xyz
    anchor_leg = np.sqrt(np.sum(anchor_offset**2, axis=2))

    D = kernel_dict["kernel_graph_distance"][closest] + query_leg - anchor_leg
    # A very large distance drives the kernel value to zero for unreachable pairs.
    D[unreachable] = 10000

    K = np.squeeze(np.exp(-beta * D**2))
    if not return_d:
        return K

    query_offset[unreachable, :] = 0
    weighted = D[:, :, None] * query_offset / query_leg[:, :, None]
    return K, weighted.transpose([0, 2, 1])
# def _gp_velocity(X: np.ndarray, vf_dict: dict) -> np.ndarray: # # pre_scale = vf_dict["pre_norm_scale"] # norm_x = (X - vf_dict["norm_dict"]["mean_transformed"]) / vf_dict["norm_dict"]["scale_transformed"] # if vf_dict["kernel_type"] == "euc": # quary_kernel = _con_K(norm_x, vf_dict["inducing_variables"], vf_dict["beta"]) # elif vf_dict["kernel_type"] == "geodist": # pass # # not implemented yet # # quary_kernel = _con_K_geodist(norm_x, vf_dict["kernel_dict"], vf_dict["beta"]) # else: # raise ValueError(f"current only support cdist and geodist") # quary_velocities = np.dot(quary_kernel, vf_dict["Coff"]) # quary_rigid = np.dot(norm_x, vf_dict["R"].T) + vf_dict["t"] # quary_norm_x = quary_velocities + quary_rigid # quary_x = quary_norm_x * vf_dict["norm_dict"]["scale_fixed"] + vf_dict["norm_dict"]["mean_fixed"] # _velocities = quary_x - X # return _velocities / 10000
def _gp_velocity(
    X: np.ndarray,
    vf_dict: dict,
    nonrigid_only: bool = False,
) -> np.ndarray:
    """Predict displacement vectors at ``X`` from a fitted Gaussian-process vector field.

    Args:
        X: Coordinates at which to evaluate the field.
        vf_dict: Vector-field dictionary. This function reads ``"norm_dict"`` (with
            ``"mean_transformed"``, ``"scale_transformed"``, ``"scale_fixed"``,
            ``"mean_fixed"``), ``"kernel_type"``, ``"inducing_variables"``, ``"beta"``,
            ``"Coff"`` and, unless ``nonrigid_only`` is set, ``"R"`` and ``"t"``.
        nonrigid_only: If True, leave out the rigid part (``R``, ``t``) and the mean shift,
            keeping only the kernel term and the change of scale.

    Returns:
        The predicted displacement for each row of ``X``, divided by 10000.

    Raises:
        NotImplementedError: If ``vf_dict["kernel_type"]`` is ``"geodist"``.
        ValueError: If ``vf_dict["kernel_type"]`` is any other unsupported value.
    """
    norm_dict = vf_dict["norm_dict"]
    # Bring the query points into the normalized space of the transformed coordinates.
    norm_x = (X - norm_dict["mean_transformed"]) / norm_dict["scale_transformed"]

    kernel_type = vf_dict["kernel_type"]
    if kernel_type == "euc":
        quary_kernel = _con_K(norm_x, vf_dict["inducing_variables"], vf_dict["beta"])
    elif kernel_type == "geodist":
        raise NotImplementedError("geodist is not implemented yet")
    else:
        # The message previously advertised "cdist and geodist", neither of which is a
        # usable ``kernel_type`` here; name the real option and echo the offending value.
        raise ValueError(
            f"Unsupported kernel_type {kernel_type!r}: currently only 'euc' is supported "
            f"('geodist' is not implemented yet)."
        )

    quary_velocities = np.dot(quary_kernel, vf_dict["Coff"])

    if nonrigid_only:
        _velocities = (
            quary_velocities * norm_dict["scale_fixed"]
            + (norm_dict["scale_fixed"] - norm_dict["scale_transformed"]) * norm_x
        )
    else:
        # Kernel term plus rigid transform, mapped back to the fixed coordinates' scale.
        quary_rigid = np.dot(norm_x, vf_dict["R"].T) + vf_dict["t"]
        quary_norm_x = quary_velocities + quary_rigid
        quary_x = quary_norm_x * norm_dict["scale_fixed"] + norm_dict["mean_fixed"]
        _velocities = quary_x - X

    # NOTE(review): the 1/10000 factor is applied unconditionally; its rationale is not
    # visible in this file -- confirm with downstream consumers before changing it.
    return _velocities / 10000
# def _gp_velocity(X: np.ndarray, vf_dict: dict) -> np.ndarray: # # pre_scale = vf_dict["pre_norm_scale"] # norm_x = (X - vf_dict["norm_dict"]["mean_transformed"]) / vf_dict["norm_dict"]["scale_transformed"] # if vf_dict["kernel_type"] == "euc": # quary_kernel = _con_K(norm_x, vf_dict["inducing_variables"], vf_dict["beta"]) # elif vf_dict["kernel_type"] == "geodist": # pass # # not implemented yet # # quary_kernel = _con_K_geodist(norm_x, vf_dict["kernel_dict"], vf_dict["beta"]) # else: # raise ValueError(f"current only support cdist and geodist") # quary_velocities = np.dot(quary_kernel, vf_dict["Coff"]) # quary_rigid = np.dot(norm_x, vf_dict["R"].T) + vf_dict["t"] # quary_norm_x = quary_velocities + quary_rigid # _velocities = ( # quary_norm_x * vf_dict["norm_dict"]["scale_fixed"] - norm_x * vf_dict["norm_dict"]["scale_transformed"] # ) # _velocities = quary_velocities * vf_dict["norm_dict"]["scale_fixed"] # return _velocities / 10000 # def _gp_velocity(X: np.ndarray, vf_dict: dict) -> np.ndarray: # # pre_scale = vf_dict["pre_norm_scale"] # # pre_scale = vf_dict["norm_dict"]["scale_fixed"] / vf_dict["norm_dict"]["scale_transformed"] # norm_x = (X - vf_dict["norm_dict"]["mean_transformed"]) / vf_dict["norm_dict"]["scale_transformed"] # if vf_dict["kernel_type"] == "euc": # quary_kernel = _con_K(norm_x, vf_dict["inducing_variables"], vf_dict["beta"]) # elif vf_dict["kernel_type"] == "geodist": # pass # # not implemented yet # # quary_kernel = _con_K_geodist(norm_x, vf_dict["kernel_dict"], vf_dict["beta"]) # else: # raise ValueError(f"current only support cdist and geodist") # quary_velocities = np.dot(quary_kernel, vf_dict["Coff"]) # quary_rigid = np.dot(norm_x, vf_dict["R"].T) + vf_dict["t"] # quary_norm_x = quary_velocities + quary_rigid # _velocities = ( # quary_norm_x * vf_dict["norm_dict"]["scale_fixed"] - norm_x * vf_dict["norm_dict"]["scale_transformed"] # ) # _velocities = quary_velocities * vf_dict["norm_dict"]["scale_fixed"] # # _velocities = 
# quary_velocities # return _velocities / 10000
def morphofield_gp(
    adata: AnnData,
    spatial_key: str = "align_spatial",
    vf_key: str = "VecFld_morpho",
    NX: Optional[np.ndarray] = None,
    grid_num: Optional[List[int]] = None,
    nonrigid_only: bool = False,
    inplace: bool = True,
) -> Optional[AnnData]:
    """
    Calculating and predicting the vector field during development by the Gaussian Process method.

    Args:
        adata: AnnData object that contains the cell coordinates of the two states after alignment.
        spatial_key: The key from the ``.obsm`` that corresponds to the spatial coordinates of each cell.
        vf_key: The key in ``.uns`` that corresponds to the reconstructed vector field. The dictionary stored
            under this key is updated with the results listed below.
        NX: The spatial coordinates of new data point. If NX is None, generate new points based on grid_num.
        grid_num: The number of grids in each dimension for generating the grid velocity. Only used when ``NX`` is
            None. Default is ``[50, 50, 50]``.
        nonrigid_only: If True, only the nonrigid part of the vector field will be calculated.
        inplace: Whether to modify ``adata`` in place (True) or work on a copy (False).

    Returns:
        ``None`` if ``inplace`` is True, otherwise the updated copy of ``adata``. In both cases the ``vf_key``
        dictionary in ``.uns`` gains the following entries:

            X: Cell coordinates of the current state.
            V: Developmental direction of the X.
            grid: ``NX`` if it was given, otherwise the generated grid coordinates.
            grid_V: Prediction of developmental direction of the grid.
            method: The method of learning vector field. Here method == 'gaussian_process'.

    Raises:
        KeyError: If ``vf_key`` is not present in ``adata.uns``.
    """
    adata = adata if inplace else adata.copy()

    # Guard clause: nothing can be predicted without a previously reconstructed vector field.
    if vf_key not in adata.uns.keys():
        raise KeyError(
            f"The {vf_key} that corresponds to the reconstructed vector field is not in ``anndata.uns``. "
            f"Please run ``st.align.morpho_align(adata, vecfld_key_added='{vf_key}')`` before running this function."
        )

    vf_dict = adata.uns[vf_key]
    vf_dict["X"] = np.asarray(adata.obsm[spatial_key], dtype=float)
    vf_dict["V"] = _gp_velocity(vf_dict["X"], vf_dict=vf_dict, nonrigid_only=nonrigid_only)

    if NX is not None:
        predict_X = NX
    else:
        if grid_num is None:
            grid_num = [50, 50, 50]
            lm.main_warning("grid_num and NX are both None, using `grid_num = [50,50,50]`.", indent_level=1)
        # Only the grid coordinates are needed; the other returned values are discarded.
        _, _, predict_X, _ = get_X_Y_grid(X=vf_dict["X"].copy(), Y=vf_dict["V"].copy(), grid_num=grid_num)

    vf_dict["grid"] = predict_X
    vf_dict["grid_V"] = _gp_velocity(predict_X, vf_dict=vf_dict, nonrigid_only=nonrigid_only)
    vf_dict["method"] = "gaussian_process"

    lm.main_finish_progress(progress_name="morphofield")
    return None if inplace else adata