Source code for spateo.tdr.models.utilities.label_utils

from typing import Optional, Tuple, Union

import matplotlib as mpl
import numpy as np
from pyvista import PolyData, UnstructuredGrid

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal


[docs]def add_model_labels( model: Union[PolyData, UnstructuredGrid], labels: np.ndarray, key_added: str = "groups", where: Literal["point_data", "cell_data"] = "cell_data", colormap: Union[str, list, dict, np.ndarray] = "rainbow", alphamap: Union[float, list, dict, np.ndarray] = 1.0, mask_color: Optional[str] = "gainsboro", mask_alpha: Optional[float] = 0.0, inplace: bool = False, ) -> Tuple[Optional[PolyData or UnstructuredGrid], Optional[Union[str]]]: """ Add rgba color to each point of model based on labels. Args: model: A reconstructed model. labels: An array of labels of interest. key_added: The key under which to add the labels. where: The location where the label information is recorded in the model. colormap: Colors to use for plotting data. alphamap: The opacity of the color to use for plotting data. mask_color: Color to use for plotting mask information. mask_alpha: The opacity of the color to use for plotting mask information. inplace: Updates model in-place. Returns: A model, which contains the following properties: ``model.cell_data[key_added]`` or ``model.point_data[key_added]``, the labels array; ``model.cell_data[f'{key_added}_rgba']`` or ``model.point_data[f'{key_added}_rgba']``, the rgba colors of the labels. plot_cmap: Recommended colormap parameter values for plotting. """ model = model.copy() if not inplace else model labels = np.asarray(labels).flatten() if not np.issubdtype(labels.dtype, np.number): cu_arr = np.sort(np.unique(labels), axis=0).astype(object) raw_labels_hex = labels.copy().astype(object) raw_labels_alpha = labels.copy().astype(object) raw_labels_hex[raw_labels_hex == "mask"] = mpl.colors.to_hex(mask_color) raw_labels_alpha[raw_labels_alpha == "mask"] = mask_alpha # Set raw hex. if isinstance(colormap, str): if colormap in list(mpl.colormaps()): lscmap = mpl.cm.get_cmap(colormap) raw_hex_list = [mpl.colors.to_hex(lscmap(i)) for i in np.linspace(0, 1, len(cu_arr))] for label, color in zip(cu_arr, raw_hex_list): raw_labels_hex[raw_labels_hex == label] = color else: raw_labels_hex[raw_labels_hex != "mask"] = mpl.colors.to_hex(colormap) elif isinstance(colormap, dict): for label, color in colormap.items(): raw_labels_hex[raw_labels_hex == label] = mpl.colors.to_hex(color) elif isinstance(colormap, list) or isinstance(colormap, np.ndarray): raw_hex_list = np.array([mpl.colors.to_hex(color) for color in colormap]).astype(object) for label, color in zip(cu_arr, raw_hex_list): raw_labels_hex[raw_labels_hex == label] = color else: raise ValueError("`colormap` value is wrong." "\nAvailable `colormap` types are: `str`, `list` and `dict`.") # Set raw alpha. if isinstance(alphamap, float) or isinstance(alphamap, int): raw_labels_alpha[raw_labels_alpha != "mask"] = alphamap elif isinstance(alphamap, dict): for label, alpha in alphamap.items(): raw_labels_alpha[raw_labels_alpha == label] = alpha elif isinstance(alphamap, list) or isinstance(alphamap, np.ndarray): raw_labels_alpha = np.asarray(alphamap).astype(object) else: raise ValueError( "`alphamap` value is wrong." "\nAvailable `alphamap` types are: `float`, `list` and `dict`." ) # Set rgba. labels_rgba = [mpl.colors.to_rgba(c, alpha=a) for c, a in zip(raw_labels_hex, raw_labels_alpha)] labels_rgba = np.array(labels_rgba).astype(np.float32) # Added rgba of the labels. if where == "point_data": model.point_data[f"{key_added}_rgba"] = labels_rgba else: model.cell_data[f"{key_added}_rgba"] = labels_rgba plot_cmap = None else: plot_cmap = colormap # Added labels. if where == "point_data": model.point_data[key_added] = labels else: model.cell_data[key_added] = labels return model if not inplace else None, plot_cmap