# Source code for spateo.tdr.models.models_migration.arrow_model

from typing import Optional, Tuple, Union

import numpy as np
import pyvista as pv
from pyvista import PolyData

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

from ....logging import logger_manager as lm
from ..utilities import add_model_labels


def _construct_arrow(
    start_point: Union[list, tuple, np.ndarray] = (0.0, 0.0, 0.0),
    direction: Union[list, tuple, np.ndarray] = (1.0, 0.0, 0.0),
    tip_length: float = 0.25,
    tip_radius: float = 0.1,
    tip_resolution: int = 20,
    shaft_radius: float = 0.05,
    shaft_resolution: int = 20,
    scale: Optional[Union[str, float]] = "auto",
) -> PolyData:
    """
    Build a single 3D arrow mesh by delegating to ``pyvista.Arrow``.

    Args:
        start_point: [x, y, z] location where the arrow starts (forwarded as ``start``).
        direction: [x, y, z] vector the arrow points along.
        tip_length: Length of the arrow tip.
        tip_radius: Radius of the arrow tip.
        tip_resolution: Number of faces around the tip.
        shaft_radius: Radius of the arrow shaft.
        shaft_resolution: Number of faces around the shaft.
        scale: Scale factor of the whole arrow; ``'auto'`` scales it to the length of ``direction``.

    Returns:
        The arrow mesh returned by ``pyvista.Arrow``.
    """
    # Local argument names differ from pyvista's (``start_point`` -> ``start``), so map them explicitly.
    arrow_options = dict(
        start=start_point,
        direction=direction,
        tip_length=tip_length,
        tip_radius=tip_radius,
        tip_resolution=tip_resolution,
        shaft_radius=shaft_radius,
        shaft_resolution=shaft_resolution,
        scale=scale,
    )
    return pv.Arrow(**arrow_options)
def construct_arrow(
    start_point: Union[list, tuple, np.ndarray],
    direction: Union[list, tuple, np.ndarray],
    arrow_scale: Optional[Union[int, float]] = None,
    key_added: Optional[str] = "arrow",
    label: str = "arrow",
    color: str = "gainsboro",
    alpha: float = 1.0,
    **kwargs,
) -> Tuple[PolyData, Optional[str]]:
    """
    Create a single 3D arrow model and optionally attach a label to every point.

    Args:
        start_point: [x, y, z] location where the arrow starts.
        direction: [x, y, z] vector the arrow points along.
        arrow_scale: Scale factor of the whole arrow. When ``None`` (default), ``'auto'`` is used, which scales
            the arrow to the length of ``direction``.
        key_added: Point-data key under which the labels are stored. When ``None``, no labels are added and
            the returned ``plot_cmap`` is ``None``.
        label: Label assigned to every point of the arrow model.
        color: Color to use for plotting the model.
        alpha: Opacity of the color to use for plotting the model.
        **kwargs: Extra keyword arguments forwarded to ``_construct_arrow``.

    Returns:
        model: Arrow model.
        plot_cmap: Recommended colormap parameter value for plotting, or ``None`` when ``key_added`` is ``None``.
    """
    scale = arrow_scale if arrow_scale is not None else "auto"
    arrow = _construct_arrow(start_point=start_point, direction=direction, scale=scale, **kwargs)

    # Without a key there is nothing to label, so no colormap is recommended.
    if key_added is None:
        return arrow, None

    point_labels = np.asarray([label] * arrow.n_points)
    _, plot_cmap = add_model_labels(
        model=arrow,
        key_added=key_added,
        labels=point_labels,
        where="point_data",
        colormap=color,
        alphamap=alpha,
        inplace=True,
    )
    return arrow, plot_cmap
def construct_arrows(
    start_points: np.ndarray,
    direction: np.ndarray = None,
    arrows_scale: Optional[np.ndarray] = None,
    n_sampling: Optional[int] = None,
    sampling_method: str = "trn",
    factor: float = 1.0,
    key_added: Optional[str] = "arrow",
    label: Union[str, list, np.ndarray] = "arrows",
    color: Union[str, list, dict, np.ndarray] = "gainsboro",
    alpha: Union[float, int, list, dict, np.ndarray] = 1.0,
    **kwargs,
) -> Tuple[PolyData, Optional[str]]:
    """
    Create multiple 3D arrows model.

    Args:
        start_points: Start locations in [x, y, z] of the arrows, one row per arrow.
        direction: Directions the arrows point to in [x, y, z], one row per start point. Required despite the
            ``None`` default (kept for backward compatibility of the signature).
        arrows_scale: Per-arrow scale factors, one per start point. When ``None``, each arrow is scaled by the
            Euclidean norm of its direction vector.
        n_sampling: Number of coordinates to keep after sampling. If there are too many coordinates in
            ``start_points``, the generated arrows model will be too complex and unsightly, so sampling is used to
            reduce the number of coordinates. Sampling requires ``dynamo`` to be installed.
        sampling_method: The method to sample data points, can be one of ``['trn', 'kmeans', 'random']``.
        factor: Scale factor applied to the scaling array.
        key_added: The key under which to add the labels. When ``None``, no labels are added.
        label: The label of the arrows. A string labels every arrow identically; otherwise one label per start
            point is expected.
        color: Color to use for plotting model.
        alpha: The opacity of the color to use for plotting model.
        **kwargs: Additional parameters that will be passed to ``_construct_arrow`` function.

    Returns:
        model: Arrows model.
        plot_cmap: Recommended colormap parameter values for plotting.

    Raises:
        ValueError: If ``direction`` is missing, or if ``direction``, ``arrows_scale`` or a non-string ``label``
            does not have exactly one entry per start point.
    """
    start_points = np.asarray(start_points)
    if direction is None:
        raise ValueError("`direction` must be provided: one [x, y, z] direction vector per start point.")
    direction = np.asarray(direction)
    n_points = start_points.shape[0]

    # Validate every per-point input BEFORE sampling/indexing: indexing with `index_arr` would otherwise
    # silently truncate over-long inputs or fail with an obscure IndexError on short ones.
    if direction.shape[0] != n_points:
        raise ValueError("The number of directions is not equal to the number of start points.")
    if arrows_scale is not None:
        arrows_scale = np.asarray(arrows_scale)
        if arrows_scale.shape[0] != n_points:
            raise ValueError("The number of arrow scales is not equal to the number of start points.")
    if not isinstance(label, str):
        label = np.asarray(label)
        if len(label) != n_points:
            raise ValueError("The number of labels is not equal to the number of start points.")

    index_arr = np.arange(0, n_points)
    if n_sampling is not None:
        # Imported lazily so `dynamo` is only required when sampling is actually requested.
        from dynamo.tools.sampling import sample

        index_arr = sample(
            arr=index_arr,
            n=n_sampling,
            method=sampling_method,
            X=start_points,
        )
    elif len(start_points) > 500:
        lm.main_warning(
            "The number of start_points is more than 500. You may want to "
            "lower the max number of arrows to draw."
        )

    start_points = start_points[index_arr, :].copy()
    direction = direction[index_arr, :].copy()

    model = pv.PolyData(start_points)
    model.point_data["direction"] = direction
    model.point_data["scale"] = (
        np.linalg.norm(direction, axis=1) if arrows_scale is None else arrows_scale[index_arr]
    )

    labels = np.asarray([label] * len(start_points)) if isinstance(label, str) else label[index_arr]

    plot_cmap = None
    if key_added is not None:
        _, plot_cmap = add_model_labels(
            model=model,
            key_added=key_added,
            labels=labels,
            where="point_data",
            colormap=color,
            alphamap=alpha,
            inplace=True,
        )

    # One arrow glyph per point, oriented by "direction" and sized by "scale" * factor.
    glyph = model.glyph(orient="direction", geom=_construct_arrow(**kwargs), scale="scale", factor=factor)
    return glyph, plot_cmap