Source code for spateo.tdr.models.models_backbone.backbone_utils

from typing import Optional, Union

import numpy as np
import pandas as pd
from pyvista import PolyData, UnstructuredGrid
from scipy.spatial import KDTree

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal


def map_points_to_backbone(
    model: "Union[PolyData, UnstructuredGrid]",
    backbone_model: "PolyData",
    nodes_key: str = "nodes",
    key_added: Optional[str] = "nodes",
    inplace: bool = False,
    **kwargs,
):
    """
    Find the closest backbone node to every point in the model through a KDTree.

    Args:
        model: The reconstructed model.
        backbone_model: The constructed backbone model.
        nodes_key: The key in ``backbone_model.point_data`` holding the node label of each backbone point.
        key_added: The key under which to add the nodes labels.
        inplace: Updates model in-place.
        **kwargs: Additional parameters that will be passed to the ``scipy.spatial.KDTree`` constructor.

    Returns:
        If ``inplace`` is False, a copy of ``model`` whose ``point_data[key_added]`` holds, for every point,
        the label (value of ``backbone_model.point_data[nodes_key]``) of the nearest backbone node.
        If ``inplace`` is True, ``model`` itself is modified and ``None`` is returned.
    """
    model = model.copy() if not inplace else model

    nodes_data = pd.DataFrame(np.asarray(backbone_model.points), columns=["x", "y", "z"], dtype=float)
    nodes_data[nodes_key] = np.asarray(backbone_model.point_data[nodes_key]).astype(int)
    nodes_data = nodes_data.sort_values(by=nodes_key)

    backbone_nodes = nodes_data.loc[:, ["x", "y", "z"]].values
    node_labels = nodes_data[nodes_key].values

    nodes_kdtree = KDTree(np.asarray(backbone_nodes), **kwargs)
    _, ii = nodes_kdtree.query(np.asarray(model.points), k=1)

    # The KDTree returns row positions within the *sorted* node array, not node labels.
    # Translate them so non-contiguous labels (e.g. 5, 7, 9) are reported correctly;
    # for labels that are exactly 0..n-1 this yields the same values as the raw positions.
    model.point_data[key_added] = node_labels[ii]

    return model if not inplace else None
def map_gene_to_backbone(
    model: "Union[PolyData, UnstructuredGrid]",
    tree: "PolyData",
    key: Union[str, list],
    nodes_key: Optional[str] = "nodes",
    inplace: bool = False,
):
    """
    Sum the values of ``key`` over the model points assigned to each tree node and attach the sums to the tree.

    Args:
        model: A reconstructed model containing the gene expression arrays and, under ``nodes_key``,
            the node label assigned to each of its points.
        tree: A three-dims principal tree model containing the node label of each point under ``nodes_key``.
        key: The key (or list of keys) that corresponds to the gene expression.
        nodes_key: The key that holds the node labels in both ``model`` and ``tree``.
        inplace: Updates tree model in-place.

    Returns:
        If ``inplace`` is False, a copy of ``tree`` in which ``tree.point_data[k]`` (for every ``k`` in ``key``)
        holds the per-node sum of ``model[k]``; tree nodes with no assigned model points get 0.
        If ``inplace`` is True, ``tree`` itself is modified and ``None`` is returned.
    """
    keys = [key] if isinstance(key, str) else list(key)

    # ``model`` is only read, so no copy of it is required.
    model_data = pd.DataFrame(np.asarray(model[nodes_key]), columns=["nodes_id"])
    for sub_key in keys:
        model_data[sub_key] = np.asarray(model[sub_key])
    model_data = model_data.groupby(by="nodes_id").sum().reset_index()

    tree = tree.copy() if not inplace else tree
    tree_data = pd.DataFrame(np.asarray(tree[nodes_key]), columns=["nodes_id"])
    # A left merge keeps exactly one row per tree point, in tree point order. An outer merge
    # would sort the keys and append model-only node ids, misaligning the arrays with the points.
    tree_data = pd.merge(tree_data, model_data, how="left", on="nodes_id")
    tree_data = tree_data.fillna(value=0)

    for sub_key in keys:
        tree.point_data[sub_key] = tree_data[sub_key].values

    return tree if not inplace else None
def _euclidean_distance(N1, N2):
    """
    Return the Euclidean distance between two points.

    Both inputs are converted with ``np.asarray``; with ``d = N1 - N2`` the result is ``sqrt(d.T . d)``.
    """
    delta = np.subtract(np.asarray(N1), np.asarray(N2))
    squared_length = np.dot(delta.T, delta)
    return np.sqrt(squared_length)
def sort_nodes_of_curve(nodes, started_node):
    """
    Order nodes by a greedy nearest-neighbour walk.

    Starting from ``started_node``, repeatedly pick the closest not-yet-visited node
    (ties resolved by the first occurrence in ``nodes``), append it, and continue from it.

    Args:
        nodes: Iterable of point coordinates.
        started_node: Coordinates the walk starts from; it does not need to be one of ``nodes``.

    Returns:
        A numpy array of the visited nodes, in walk order.
    """

    def _dist(a, b):
        delta = np.asarray(a) - np.asarray(b)
        return np.sqrt(np.dot(delta.T, delta))

    unvisited = list(map(tuple, nodes))
    anchor = tuple(started_node)
    path = []
    while len(unvisited) > 0:
        nearest = min(unvisited, key=lambda candidate: _dist(anchor, candidate))
        unvisited.remove(nearest)
        path.append(nearest)
        anchor = nearest
    return np.asarray([list(point) for point in path])