from typing import Optional, Tuple, Union
import matplotlib as mpl
import numpy as np
from pyvista import PolyData, UnstructuredGrid
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
[docs]def add_model_labels(
model: Union[PolyData, UnstructuredGrid],
labels: np.ndarray,
key_added: str = "groups",
where: Literal["point_data", "cell_data"] = "cell_data",
colormap: Union[str, list, dict, np.ndarray] = "rainbow",
alphamap: Union[float, list, dict, np.ndarray] = 1.0,
mask_color: Optional[str] = "gainsboro",
mask_alpha: Optional[float] = 0.0,
inplace: bool = False,
) -> Tuple[Optional[PolyData or UnstructuredGrid], Optional[Union[str]]]:
"""
Add rgba color to each point of model based on labels.
Args:
model: A reconstructed model.
labels: An array of labels of interest.
key_added: The key under which to add the labels.
where: The location where the label information is recorded in the model.
colormap: Colors to use for plotting data.
alphamap: The opacity of the color to use for plotting data.
mask_color: Color to use for plotting mask information.
mask_alpha: The opacity of the color to use for plotting mask information.
inplace: Updates model in-place.
Returns:
A model, which contains the following properties:
``model.cell_data[key_added]`` or ``model.point_data[key_added]``, the labels array;
``model.cell_data[f'{key_added}_rgba']`` or ``model.point_data[f'{key_added}_rgba']``, the rgba colors of the labels.
plot_cmap: Recommended colormap parameter values for plotting.
"""
model = model.copy() if not inplace else model
labels = np.asarray(labels).flatten()
if not np.issubdtype(labels.dtype, np.number):
cu_arr = np.sort(np.unique(labels), axis=0).astype(object)
raw_labels_hex = labels.copy().astype(object)
raw_labels_alpha = labels.copy().astype(object)
raw_labels_hex[raw_labels_hex == "mask"] = mpl.colors.to_hex(mask_color)
raw_labels_alpha[raw_labels_alpha == "mask"] = mask_alpha
# Set raw hex.
if isinstance(colormap, str):
if colormap in list(mpl.colormaps()):
lscmap = mpl.colormaps[colormap]
raw_hex_list = [mpl.colors.to_hex(lscmap(i)) for i in np.linspace(0, 1, len(cu_arr))]
for label, color in zip(cu_arr, raw_hex_list):
raw_labels_hex[raw_labels_hex == label] = color
else:
raw_labels_hex[raw_labels_hex != "mask"] = mpl.colors.to_hex(colormap)
elif isinstance(colormap, dict):
for label, color in colormap.items():
raw_labels_hex[raw_labels_hex == label] = mpl.colors.to_hex(color)
elif isinstance(colormap, list) or isinstance(colormap, np.ndarray):
raw_hex_list = np.array([mpl.colors.to_hex(color) for color in colormap]).astype(object)
for label, color in zip(cu_arr, raw_hex_list):
raw_labels_hex[raw_labels_hex == label] = color
else:
raise ValueError("`colormap` value is wrong." "\nAvailable `colormap` types are: `str`, `list` and `dict`.")
# Set raw alpha.
if isinstance(alphamap, float) or isinstance(alphamap, int):
raw_labels_alpha[raw_labels_alpha != "mask"] = alphamap
elif isinstance(alphamap, dict):
for label, alpha in alphamap.items():
raw_labels_alpha[raw_labels_alpha == label] = alpha
elif isinstance(alphamap, list) or isinstance(alphamap, np.ndarray):
raw_labels_alpha = np.asarray(alphamap).astype(object)
else:
raise ValueError(
"`alphamap` value is wrong." "\nAvailable `alphamap` types are: `float`, `list` and `dict`."
)
# Set rgba.
labels_rgba = [mpl.colors.to_rgba(c, alpha=a) for c, a in zip(raw_labels_hex, raw_labels_alpha)]
labels_rgba = np.array(labels_rgba).astype(np.float32)
# Added rgba of the labels.
if where == "point_data":
model.point_data[f"{key_added}_rgba"] = labels_rgba
else:
model.cell_data[f"{key_added}_rgba"] = labels_rgba
plot_cmap = None
else:
plot_cmap = colormap
# Added labels.
if where == "point_data":
model.point_data[key_added] = labels
else:
model.cell_data[key_added] = labels
return model if not inplace else None, plot_cmap