# Source code for spateo.tdr.models.models_individual.point_clouds

from collections.abc import Hashable, Iterable
from typing import Optional, Tuple, Union

import numpy as np
import pyvista as pv
from anndata import AnnData
from pandas.core.frame import DataFrame
from pyvista import PolyData

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

from ..utilities import add_model_labels

###############################
# Construct point cloud model #
###############################


def construct_pc(
    adata: AnnData,
    layer: str = "X",
    spatial_key: str = "spatial",
    groupby: Union[str, tuple] = None,
    key_added: str = "groups",
    mask: Union[str, int, float, list] = None,
    colormap: Union[str, list, dict] = "rainbow",
    alphamap: Union[float, list, dict] = 1.0,
) -> Tuple[PolyData, Optional[str]]:
    """
    Construct a point cloud model based on 3D coordinate information.

    Args:
        adata: AnnData object. It is not modified by this function.
        layer: If ``'X'``, uses ``.X``, otherwise uses the representation given by ``.layers[layer]``.
            Only used when ``groupby`` refers to genes.
        spatial_key: The key in ``.obsm`` that corresponds to the spatial coordinate of each bucket.
        groupby: The key that stores clustering or annotation information in ``.obs``, a gene name, or a
            list/tuple of gene names in ``.var``. For genes, each bucket is labeled with the sum of the
            selected genes' values. If ``None``, every bucket is labeled ``'same'``.
        key_added: The key under which to add the labels.
        mask: The part that you don't want to be displayed. Values of ``.obs[groupby]`` equal to ``mask``
            (or contained in it, when ``mask`` is a list) are relabeled ``'mask'``. Only applied when
            ``groupby`` is a key of ``.obs``.
        colormap: Colors to use for plotting pc. The default colormap is ``'rainbow'``.
        alphamap: The opacity of the colors to use for plotting pc. The default alphamap is ``1.0``.

    Returns:
        pc: A point cloud, which contains the following properties:
            ``pc.point_data[key_added]``, the ``groupby`` information.
            ``pc.point_data[f'{key_added}_rgba']``, the rgba colors of the ``groupby`` information.
            ``pc.point_data['obs_index']``, the obs_index of each coordinate in the original adata.
        plot_cmap: Recommended colormap parameter values for plotting.

    Raises:
        ValueError: If ``groupby`` is neither a key of ``.obs``, a gene name, nor a collection of gene names.
    """
    # Create an initial pc. Coordinates may be stored as an ndarray or a DataFrame; ``np.array`` always
    # copies, so the point cloud never shares memory with ``adata.obsm``.
    bucket_xyz = adata.obsm[spatial_key]
    if isinstance(bucket_xyz, DataFrame):
        bucket_xyz = bucket_xyz.values
    pc = pv.PolyData(np.array(bucket_xyz, dtype=np.float64))

    # The `groupby` array from adata.obs, or from adata.X / adata.layers[layer].
    mask_list = mask if isinstance(mask, list) else [mask]
    obs_keys = set(adata.obs_keys())
    gene_names = set(adata.var_names.tolist())

    if groupby is None:
        groups = np.asarray(["same"] * adata.obs.shape[0], dtype=str)
    elif isinstance(groupby, Hashable) and groupby in obs_keys:
        # Unhashable inputs (e.g. a list of genes) skip this branch instead of raising TypeError
        # on the set lookup.
        groups = np.asarray(adata.obs[groupby].map(lambda x: "mask" if x in mask_list else x).values)
    else:
        # Normalize to a list of gene names; a lone string is one gene, not a sequence of characters.
        if isinstance(groupby, str) or not isinstance(groupby, Iterable):
            genes = [groupby]
        else:
            genes = list(groupby)
        if not set(genes) <= gene_names:
            raise ValueError(
                "`groupby` value is wrong."
                "\n`groupby` can be a string and one of adata.obs_keys() or adata.var_names."
                "\n`groupby` can also be a list or tuple and is a subset of adata.var_names."
            )
        # Reading from a view of the selected genes avoids copying the whole AnnData object.
        subset = adata[:, genes]
        matrix = subset.X if layer == "X" else subset.layers[layer]
        # scipy.sparse matrices return an (n, 1) np.matrix from ``.sum(axis=1)``; converting to an
        # ndarray and raveling yields a flat (n,) array for both sparse and dense inputs.
        groups = np.asarray(matrix.sum(axis=1)).ravel()

    _, plot_cmap = add_model_labels(
        model=pc,
        labels=groups,
        key_added=key_added,
        where="point_data",
        colormap=colormap,
        alphamap=alphamap,
        inplace=True,
    )

    # The obs_index of each coordinate in the original adata.
    pc.point_data["obs_index"] = np.array(adata.obs_names.tolist())
    return pc, plot_cmap