Source code for spateo.plotting.static.three_d_plot.three_dims_plotter

from typing import Optional, Tuple, Union

import matplotlib as mpl
import numpy as np
import pyvista as pv
from matplotlib.colors import LinearSegmentedColormap
from pyvista import MultiBlock, Plotter, PolyData, UnstructuredGrid

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal


def _get_default_cmap():
    """
    Ensure the package's default colormap is registered with matplotlib.

    The colormap is a linear blend through six anchor colors (indigo -> purple -> orange -> gold -> pale
    yellow). It is registered under the name ``"default_cmap"`` only if no colormap with that name exists yet.

    Returns:
        The registered colormap name, ``"default_cmap"``.
    """
    cmap_name = "default_cmap"
    if cmap_name in mpl.colormaps():
        # Already registered (e.g. by a previous call); nothing to do.
        return cmap_name

    palette = ["#4B0082", "#800080", "#F97306", "#FFA500", "#FFD700", "#FFFFCB"]
    anchors = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    blended = LinearSegmentedColormap.from_list(cmap_name, list(zip(anchors, palette)))
    mpl.colormaps.register(blended)
    return cmap_name
def create_plotter(
    jupyter: Union[bool, Literal["trame", "none", "static"]] = False,
    off_screen: bool = False,
    window_size: tuple = (512, 512),
    background: str = "white",
    shape: Union[str, list, tuple] = (1, 1),
    show_axes: bool = True,
) -> Plotter:
    """
    Create a plotting object to display pyvista/vtk model.

    Args:
        jupyter: Whether to plot in jupyter notebook. Any value other than ``False`` (``True`` or a jupyter
            backend name such as ``'trame'``, ``'none'`` or ``'static'``) creates the plotter with
            ``notebook=True``.
        off_screen: Renders off-screen when True. Useful for automated screenshots.
        window_size: Window size in pixels. The default window_size is ``(512, 512)``.
        background: The background color of the window.
        shape: Number of sub-render windows inside the main window, given as ``(n_rows, n_cols)``.
            By default, there is only one render window.

            * A two by two grid is ``shape=(2, 2)``.
            * ``shape`` can also accept a string descriptor as shape.
              ``E.g.: shape="3|1" means 3 plots on the left and 1 on the right,``
              ``E.g.: shape="4/2" means 4 plots on top and 2 at the bottom.``
        show_axes: Whether to add an orientation marker to the active renderer: a camera orientation widget
            when there is a single render window, plain axes otherwise. Never added when
            ``jupyter='trame'``.

    Returns:
        plotter: The plotting object to display pyvista/vtk model.
    """
    # Make sure the package's default colormap is registered before any model is added.
    _get_default_cmap()

    # Create an initial plotting object.
    plotter = pv.Plotter(
        off_screen=off_screen,
        window_size=window_size,
        notebook=jupyter is not False,
        lighting="light_kit",
        shape=shape,
    )
    # Set the background color of the active render window.
    plotter.background_color = background

    # The orientation widget is not compatible with the trame backend, so skip it there.
    if jupyter != "trame" and show_axes:
        # Treat ``[1, 1]`` (list) the same as ``(1, 1)`` (tuple): both describe a single render window.
        single_window = not isinstance(shape, str) and tuple(shape) == (1, 1)
        if single_window:
            plotter.add_camera_orientation_widget()
        else:
            plotter.add_axes()
    return plotter
def _set_jupyter(
    jupyter: Union[bool, Literal["trame", "none", "static"]] = False,
    off_screen: bool = False,
):
    """
    Resolve the ``jupyter`` option into two off-screen flags and a jupyter backend name.

    Args:
        jupyter: ``False``, ``True``, or one of the backend names ``'trame'``, ``'none'``, ``'static'``.
        off_screen: The caller's requested off-screen setting.

    Returns:
        A tuple ``(off_screen1, off_screen2, jupyter_backend)``:

        * ``jupyter=False``: ``(off_screen, True, "none")``.
        * ``jupyter=True``: ``(True, off_screen, "static")``.
        * a backend name: ``(True, off_screen, <that name>)``.

        NOTE(review): which plotter each off-screen flag is meant for is decided by the callers
        (not visible here).

    Raises:
        ValueError: If ``jupyter`` is none of the accepted values.
    """
    # Identity checks on purpose: only the real booleans select these branches.
    if jupyter is False:
        return off_screen, True, "none"
    if jupyter is True:
        return True, off_screen, "static"
    if jupyter in ("trame", "none", "static"):
        return True, off_screen, jupyter
    raise ValueError(
        "`jupyter` value is wrong."
        "\nAvailable `jupyter` value are: `True`, `False`, `'trame'`, `'none'`, `'static'`."
    )
def _broadcast_param(value, n_model: int, name: str) -> list:
    """
    Expand a per-model parameter into a list with exactly one entry per model.

    A non-list value, or a one-element list, is repeated ``n_model`` times. Any other list is used as-is and
    must contain exactly ``n_model`` entries.

    Raises:
        ValueError: If ``value`` is a list whose length is neither 1 nor ``n_model``.
    """
    values = value if isinstance(value, list) else [value]
    if len(values) == 1:
        return values * n_model
    if len(values) != n_model:
        raise ValueError(f"`{name}` has {len(values)} entries, but {n_model} models were given.")
    return values


def add_model(
    plotter: Plotter,
    model: Union[PolyData, UnstructuredGrid, MultiBlock, list],
    key: Union[str, list] = None,
    colormap: Optional[Union[str, list]] = None,
    ambient: Union[float, list] = 0.2,
    opacity: Union[float, list, np.ndarray] = 1.0,
    model_style: Union[Literal["points", "surface", "wireframe"], list] = "surface",
    model_size: Union[float, list] = 3.0,
):
    """
    Add model(s) to the plotter.

    Args:
        plotter: The plotting object to display pyvista/vtk model.
        model: A reconstructed model, or a MultiBlock/list of models.
        key: The key under which are the labels. For multiple models, a list with one key per model (or a
            single key shared by all of them).
        colormap: Name of the Matplotlib colormap to use when mapping the scalars.

                  When the colormap is None, use {key}_rgba to map the scalars, otherwise use the colormap to
                  map scalars. A value that is not a registered Matplotlib colormap name is used as a solid
                  color.
        ambient: When lighting is enabled, this is the amount of light in the range of 0 to 1 (default 0.2)
                 that reaches the actor when not directed at the light source emitted from the viewer.
        opacity: Opacity of the model.

                 If a single float value is given, it will be the global opacity of the model and uniformly
                 applied everywhere, elif a numpy.ndarray with single float values is given, it will be the
                 opacity of each point. - should be between 0 and 1.

                 A string can also be specified to map the scalars range to a predefined opacity transfer
                 function (options include: 'linear', 'linear_r', 'geom', 'geom_r').
        model_style: Visualization style of the model. One of the following:

                * ``model_style = 'surface'``,
                * ``model_style = 'wireframe'``,
                * ``model_style = 'points'``.
        model_size: If ``model_style = 'points'``, point size of any nodes in the dataset plotted.

                    If ``model_style = 'wireframe'``, thickness of lines.

    For multiple models, every list-valued argument must have either one entry (shared by all models) or
    exactly one entry per model.

    Raises:
        ValueError: If a per-model list has a length other than 1 or the number of models.
    """

    def _add_model(_p, _model, _key, _colormap, _style, _ambient, _opacity, _model_size):
        """Add a single PyVista/VTK model to the scene."""
        # Points are drawn as spheres and wireframe lines as tubes; wireframes skip smooth shading.
        if _style == "points":
            render_spheres, render_tubes, smooth_shading = True, False, True
        elif _style == "wireframe":
            render_spheres, render_tubes, smooth_shading = False, True, False
        else:
            render_spheres, render_tubes, smooth_shading = False, False, True

        mesh_kwargs = dict(
            style=_style,
            render_points_as_spheres=render_spheres,
            render_lines_as_tubes=render_tubes,
            point_size=_model_size,
            line_width=_model_size,
            ambient=_ambient,
            opacity=_opacity,
            smooth_shading=smooth_shading,
            show_scalar_bar=False,
        )
        if _colormap is None:
            # Color directly from the precomputed "{key}_rgba" array (falls back to the active scalars).
            added_kwargs = dict(
                scalars=f"{_key}_rgba" if _key in _model.array_names else _model.active_scalars_name,
                rgba=True,
            )
        elif _colormap in list(mpl.colormaps()):
            added_kwargs = dict(
                scalars=_key if _key in _model.array_names else _model.active_scalars_name,
                cmap=_colormap,
            )
        else:
            # Not a registered colormap name: treat it as a single solid color.
            added_kwargs = dict(color=_colormap)
        mesh_kwargs.update(added_kwargs)
        _p.add_mesh(_model, **mesh_kwargs)

    # Add model(s) to the plotter.
    if isinstance(model, (MultiBlock, list)):
        n_model = len(model)
        per_model_args = zip(
            model,
            _broadcast_param(key, n_model, "key"),
            _broadcast_param(colormap, n_model, "colormap"),
            _broadcast_param(model_style, n_model, "model_style"),
            _broadcast_param(model_size, n_model, "model_size"),
            _broadcast_param(ambient, n_model, "ambient"),
            _broadcast_param(opacity, n_model, "opacity"),
        )
        for sub_model, sub_key, sub_cmap, sub_mt, sub_ms, sub_am, sub_op in per_model_args:
            _add_model(
                _p=plotter,
                _model=sub_model,
                _key=sub_key,
                _colormap=sub_cmap,
                _style=sub_mt,
                _model_size=sub_ms,
                _ambient=sub_am,
                _opacity=sub_op,
            )
    else:
        _add_model(
            _p=plotter,
            _model=model,
            _key=key,
            _colormap=colormap,
            _style=model_style,
            _model_size=model_size,
            _ambient=ambient,
            _opacity=opacity,
        )
def add_str_legend(
    plotter: Plotter,
    labels: np.ndarray,
    colors: np.ndarray,
    legend_size: Optional[Tuple] = None,
    legend_loc: Literal[
        "upper right",
        "upper left",
        "lower left",
        "lower right",
        "center left",
        "center right",
        "lower center",
        "upper center",
        "center",
    ] = "lower right",
):
    """
    Add a legend of string labels (each with its own color) to the plotter.

    Entries whose label equals ``"mask"`` are dropped. The remaining entries are ordered by label, with ties
    broken by color.

    Args:
        plotter: The plotting object to display pyvista/vtk model.
        labels: Legend labels.
        colors: The colors corresponding to the legend labels.
        legend_size: The percentage (0 to 1) width and height of the window for the legend.

                     For example (0.1, 0.1) would make the legend 10% the size of the entire figure window.
                     If legend_size is None, it is derived from the number of entries (counting at most 10).
        legend_loc: The location of the legend in the window. Available ``legend_loc`` are:

                * ``legend_loc = 'upper right'``
                * ``legend_loc = 'upper left'``
                * ``legend_loc = 'lower left'``
                * ``legend_loc = 'lower right'``
                * ``legend_loc = 'center left'``
                * ``legend_loc = 'center right'``
                * ``legend_loc = 'lower center'``
                * ``legend_loc = 'upper center'``
                * ``legend_loc = 'center'``
    """
    label_col = labels.reshape(-1, 1).astype(object)
    color_col = colors.reshape(-1, 1).astype(object)
    entries = np.concatenate([label_col, color_col], axis=1)

    # "mask" labels are never shown in the legend.
    keep = entries[:, 0] != "mask"
    entries = entries[keep, :]
    assert len(entries) != 0, "No legend can be added, please set `show_legend=False`."

    # np.lexsort treats its last key as primary; reversing the columns makes the label the primary key.
    order = np.lexsort(entries[:, ::-1].T)
    entries = entries[order]

    if legend_size is None:
        n_shown = min(len(entries), 10)
        legend_size = (0.1 + 0.01 * n_shown, 0.1 + 0.012 * n_shown)

    plotter.add_legend(
        entries.tolist(),
        face="circle",
        bcolor=None,
        loc=legend_loc,
        size=legend_size,
    )
def add_num_legend(
    plotter: Plotter,
    title: Optional[str] = "",
    title_font_size: Optional[Union[int, float]] = None,
    label_font_size: Optional[Union[int, float]] = None,
    font_color: str = "black",
    font_family: Literal["times", "courier", "arial"] = "arial",
    fmt="%.2e",
    n_labels: int = 5,
    legend_size: Union[tuple, list] = (0.1, 0.4),
    legend_loc: Union[tuple, list] = (0.85, 0.3),
    vertical: bool = True,
):
    """
    Add a numeric legend (scalar bar / colorbar) to the plotter.

    Args:
        plotter: The plotting object to display pyvista/vtk model.
        title: Title of the legend. Default '' which is rendered as an empty title.
        title_font_size: Sets the size of the title font.
        label_font_size: Sets the size of the label font.
        font_color: The color of the font.
        font_family: The font of the text. Available `font_family` are:

                * ``font_family = times``
                * ``font_family = courier``
                * ``font_family = arial``
        fmt: printf format for labels.
        n_labels: Number of labels to use for the legend.
        legend_size: The percentage (0 to 1) width and height of the window for the legend.
        legend_loc: The location of the legend in the window.

                    legend_loc[0]: The percentage (0 to 1) along the window's horizontal direction to place
                    the bottom left corner of the legend.

                    legend_loc[1]: The percentage (0 to 1) along the window's vertical direction to place
                    the bottom left corner of the legend.
        vertical: Use vertical or horizontal legend.
    """
    # Only the first two entries of the size/location sequences are used.
    bar_width, bar_height = legend_size[0], legend_size[1]
    anchor_x, anchor_y = legend_loc[0], legend_loc[1]

    scalar_bar_kwargs = {
        "title": title,
        "n_labels": n_labels,
        "title_font_size": title_font_size,
        "label_font_size": label_font_size,
        "color": font_color,
        "font_family": font_family,
        "use_opacity": True,
        "width": bar_width,
        "height": bar_height,
        "position_x": anchor_x,
        "position_y": anchor_y,
        "vertical": vertical,
        "fmt": fmt,
    }
    plotter.add_scalar_bar(**scalar_bar_kwargs)
def _legend_labels_and_colors(sub_model, sub_key):
    """
    Return the flattened per-point labels of ``sub_key`` and matching hex colors for one model.

    Colors are read from the ``{sub_key}_rgba`` array of ``sub_model`` and converted with
    ``matplotlib.colors.to_hex``.
    """
    labels = np.asarray(sub_model[sub_key]).flatten()
    colors = np.asarray([mpl.colors.to_hex(rgba) for rgba in sub_model[f"{sub_key}_rgba"]]).flatten()
    return labels, colors


def add_legend(
    plotter: Plotter,
    model: Union[PolyData, UnstructuredGrid, MultiBlock, list],
    key: Optional[Union[str, list]] = None,
    colormap: Optional[Union[str, list]] = None,
    legend_size: Optional[Tuple] = None,
    legend_loc: Union[str, tuple, list] = None,
    title: Optional[str] = "",
    title_font_size: Optional[Union[int, float]] = None,
    label_font_size: Optional[Union[int, float]] = None,
    font_color: str = "black",
    font_family: Literal["times", "courier", "arial"] = "arial",
    fmt="%.2e",
    n_labels: int = 5,
    vertical: bool = True,
):
    """
    Add a legend to the plotter.

    When ``colormap`` is None, a categorical (string) legend is built from the unique (label, color) pairs
    found in ``key`` and ``{key}_rgba``. Otherwise a numeric scalar-bar legend is added.

    Args:
        plotter: The plotting object to display pyvista/vtk model.
        model: A reconstructed model, or a MultiBlock/list of models.
        key: The key under which are the legend labels. For multiple models, a list with one key per model,
            or a single key (or one-element list) shared by all of them.
        colormap: Name of the Matplotlib colormap to use when mapping the scalars.

                  When the colormap is None, use {key}_rgba to map the scalars, otherwise use the colormap to
                  map scalars.
        legend_size: The percentage (0 to 1) width and height of the window for the legend.
        legend_loc: The location of the legend in the window. Defaults to ``'lower right'`` for the string
            legend and ``(0.85, 0.3)`` for the numeric legend.
        title: Title of the legend. Default '' which is rendered as an empty title.
        title_font_size: Sets the size of the title font. Only available when colormap is not None.
        label_font_size: Sets the size of the label font. Only available when colormap is not None.
        font_color: The color of the font. Only available when colormap is not None.
        font_family: The font of the text. Only available when colormap is not None. Available
                     `font_family` are:

                * ``font_family = times``
                * ``font_family = courier``
                * ``font_family = arial``
        fmt: printf format for labels. Only available when colormap is not None.
        n_labels: Number of labels to use for the legend. Only available when colormap is not None.
        vertical: Use vertical or horizontal legend. Only available when colormap is not None.
    """
    if colormap is None:
        assert key is not None, "When colormap is None, key cannot be None at the same time."

        if isinstance(model, (MultiBlock, list)):
            keys = key if isinstance(key, list) else [key]
            # A single key applies to every sub-model (same broadcasting rule as add_model).
            if len(keys) == 1:
                keys = keys * len(model)
            pairs = [_legend_labels_and_colors(sub_model, sub_key) for sub_model, sub_key in zip(model, keys)]
            legend_label_data = np.concatenate([labels for labels, _ in pairs], axis=0)
            legend_color_data = np.concatenate([colors for _, colors in pairs], axis=0)
        else:
            legend_label_data, legend_color_data = _legend_labels_and_colors(model, key)

        legend_data = np.concatenate([legend_label_data.reshape(-1, 1), legend_color_data.reshape(-1, 1)], axis=1)
        # One legend entry per distinct (label, color) pair.
        unique_legend_data = np.unique(legend_data, axis=0)

        add_str_legend(
            plotter=plotter,
            labels=unique_legend_data[:, 0],
            colors=unique_legend_data[:, 1],
            legend_size=legend_size,
            legend_loc="lower right" if legend_loc is None else legend_loc,
        )
    else:
        add_num_legend(
            plotter=plotter,
            legend_size=(0.1, 0.4) if legend_size is None else legend_size,
            legend_loc=(0.85, 0.3) if legend_loc is None else legend_loc,
            title=title,
            n_labels=n_labels,
            title_font_size=title_font_size,
            label_font_size=label_font_size,
            font_color=font_color,
            font_family=font_family,
            fmt=fmt,
            vertical=vertical,
        )
def add_outline(
    plotter: Plotter,
    model: Union[PolyData, UnstructuredGrid, MultiBlock],
    outline_width: float = 5.0,
    outline_color: Union[str, tuple] = "black",
    show_labels: bool = True,
    font_size: int = 16,
    font_color: Union[str, tuple] = "white",
    font_family: Literal["times", "courier", "arial"] = "arial",
):
    """
    Produce an outline of the full extent for the model.

    If ``show_labels`` is True, add the length, width and height of the model (rounded to 5 decimals) as
    labels at the midpoints of three outline edges.

    Args:
        plotter: The plotting object to display pyvista/vtk model.
        model: A reconstructed model.
        outline_width: The width of the outline.
        outline_color: The color of the outline (also used as the label background color).
        show_labels: Whether to add the length, width and height information of the model to the outline.
        font_family: The font family of the text. Available `font_family` are:

                * ``font_family = times``
                * ``font_family = courier``
                * ``font_family = arial``
        font_size: The size of the font.
        font_color: The color of the font.
    """
    model_outline = model.outline()
    plotter.add_mesh(model_outline, color=outline_color, line_width=outline_width)

    if show_labels:
        mo_points = np.asarray(model_outline.points)
        lower, upper = mo_points.min(axis=0), mo_points.max(axis=0)
        extent = upper - lower
        # Side lengths along x, y and z, rounded for display (plain floats, not numpy scalars).
        model_x, model_y, model_z = (round(float(length), 5) for length in extent)

        # Label the midpoints of the three edges meeting at the (x_max, y_min, z_min) corner. The corner is
        # derived from the extents instead of assuming where the outline filter puts it in its point list.
        corner = np.array([upper[0], lower[1], lower[2]], dtype=float)
        momid_points = [
            corner - [extent[0] / 2, 0, 0],
            corner + [0, extent[1] / 2, 0],
            corner + [0, 0, extent[2] / 2],
        ]
        momid_labels = [model_x, model_y, model_z]
        plotter.add_point_labels(
            points=momid_points,
            labels=momid_labels,
            bold=True,
            font_size=font_size,
            font_family=font_family,
            shape="rounded_rect",
            shape_color=outline_color,
            show_points=False,
            text_color=font_color,
        )
def add_text(
    plotter: Plotter,
    text: str,
    font_family: Literal["times", "courier", "arial"] = "arial",
    font_size: Union[int, float] = 15,
    font_color: Union[str, tuple, list] = "black",
    text_loc: Literal[
        "lower_left",
        "lower_right",
        "upper_left",
        "upper_right",
        "lower_edge",
        "upper_edge",
        "right_edge",
        "left_edge",
    ] = "upper_left",
):
    """
    Place a text annotation in the plotter window.

    Args:
        plotter: The plotting object to display pyvista/vtk model.
        text: The text to add to the rendering.
        font_family: The font family of the text. Available `font_family` are:

                * ``font_family = times``
                * ``font_family = courier``
                * ``font_family = arial``
        font_size: The size of the font.
        font_color: The color of the font.
        text_loc: The location of the text in the window. Available `text_loc` are:

                * ``text_loc = 'lower_left'``
                * ``text_loc = 'lower_right'``
                * ``text_loc = 'upper_left'``
                * ``text_loc = 'upper_right'``
                * ``text_loc = 'lower_edge'``
                * ``text_loc = 'upper_edge'``
                * ``text_loc = 'right_edge'``
                * ``text_loc = 'left_edge'``
    """
    # Map this module's argument names onto the keyword names expected by Plotter.add_text.
    text_kwargs = {
        "text": text,
        "font": font_family,
        "color": font_color,
        "font_size": font_size,
        "position": text_loc,
    }
    plotter.add_text(**text_kwargs)
def output_plotter(
    plotter: Plotter,
    filename: Optional[str] = None,
    view_up: tuple = (0.5, 0.5, 1),
    framerate: int = 15,
    jupyter: Union[bool, Literal["trame", "none", "static"]] = False,
):
    """
    Output plotter as image, gif file or mp4 file.

    Args:
        plotter: The plotting object to display pyvista/vtk model.
        filename: Filename of output file. Writer type is inferred from the (case-insensitive) extension of
                  the filename.

                * Output an image file, please enter a filename ending with
                  ``'.png', '.tif', '.tiff', '.bmp', '.jpeg', '.jpg', '.svg', '.eps', '.ps', '.pdf', '.tex'``.
                  When ``jupyter=False``, if you want to save '.png' file, please ensure ``off_screen=True``.
                * Output a gif file, please enter a filename ending with ``.gif``.
                * Output a mp4 file, please enter a filename ending with ``.mp4``.
        view_up: The normal to the orbital plane. Only available when filename ending with ``.mp4`` or
                 ``.gif``.
        framerate: Frames per second. Only available when filename ending with ``.mp4`` or ``.gif``.
        jupyter: Whether to plot in jupyter notebook. Available ``jupyter`` are:

                * ``False`` - Do not plot in a notebook.
                * ``True`` - Display a static figure.
                * ``'none'`` - Do not display in the notebook.
                * ``'trame'`` - Show a trame widget
                * ``'static'`` - Display a static figure.

    Returns:
        ``(cpo, img)`` when filename is None or ends with an image/graphic extension, ``jupyter`` is
        ``False`` or ``'none'``, and the plotter has a single render window (``plotter.shape == (1, 1)``).
        Otherwise None.

        cpo: List of camera position, focal point, and view up.
        img: Numpy array of the last image.

    Raises:
        ValueError: If the filename extension is not supported, or ``jupyter`` is not an accepted value.
    """

    def _to_graph(_screenshot, _jupyter_backend):
        """Show the plotter (optionally saving ``_screenshot``); return ``(cpo, img)`` when available."""
        if (jupyter is False or jupyter == "none") and plotter.shape == (1, 1):
            cpo, img = plotter.show(
                screenshot=_screenshot,
                return_img=True,
                return_cpos=True,
                jupyter_backend=_jupyter_backend,
            )
            return cpo, img
        plotter.show(screenshot=_screenshot, jupyter_backend=_jupyter_backend)
        return None

    def _to_gif(_filename, _view_up):
        """Output plotter to gif file."""
        path = plotter.generate_orbital_path(factor=2.0, shift=0, viewup=_view_up, n_points=20)
        plotter.open_gif(_filename)
        # NOTE(review): the orbit itself uses a fixed viewup of (0, 0, 1) rather than ``_view_up``; kept as-is.
        plotter.orbit_on_path(path, write_frames=True, viewup=(0, 0, 1), step=0.1)
        plotter.close()

    def _to_mp4(_filename, _view_up, _framerate):
        """Output plotter to mp4 file."""
        path = plotter.generate_orbital_path(factor=2.0, shift=0, viewup=_view_up, n_points=20)
        plotter.open_movie(_filename, framerate=_framerate, quality=5)
        # NOTE(review): the orbit itself uses a fixed viewup of (0, 0, 1) rather than ``_view_up``; kept as-is.
        plotter.orbit_on_path(path, write_frames=True, viewup=(0, 0, 1), step=0.1)
        plotter.close()

    _, _, jupyter_backend = _set_jupyter(jupyter=jupyter)

    if filename is None:
        return _to_graph(_screenshot=None, _jupyter_backend=jupyter_backend)

    # The format of the output file (extension matching is case-insensitive).
    filename_format = filename.split(".")[-1].lower()

    # Output the plotter in the format of the output file.
    if filename_format in ["png", "tif", "tiff", "bmp", "jpeg", "jpg"]:
        return _to_graph(_screenshot=filename, _jupyter_backend=jupyter_backend)
    elif filename_format in ["svg", "eps", "ps", "pdf", "tex"]:
        plotter.save_graphic(filename, title="PyVista Export", raster=True, painter=True)
        return _to_graph(_screenshot=None, _jupyter_backend=jupyter_backend)
    elif filename_format == "gif":
        _to_gif(_filename=filename, _view_up=view_up)
        return None
    elif filename_format == "mp4":
        _to_mp4(_filename=filename, _view_up=view_up, _framerate=framerate)
        return None
    else:
        raise ValueError(
            "\nFilename is wrong."
            "\nIf outputting an image file, "
            "please enter a filename ending with "
            "`.png`, `.tif`, `.tiff`, `.bmp`, `.jpeg`, `.jpg`, `.svg`, `.eps`, `.ps`, `.pdf`, `.tex`."
            "\nIf outputting a gif file, please enter a filename ending with `.gif`."
            "\nIf outputting a mp4 file, please enter a filename ending with `.mp4`."
        )
def save_plotter(
    plotter: Plotter,
    filename: str,
):
    """Save plotter as gltf file, html file, obj file or vtkjs file.

    Args:
        plotter: The plotting object to display pyvista/vtk model.
        filename: The filename of the file where the plotter is saved.
                  Writer type is inferred from the (case-insensitive) extension of the filename.

                * Output a gltf file, please enter a filename ending with ``.gltf``.
                * Output a html file, please enter a filename ending with ``.html``.
                * Output an obj file, please enter a filename ending with ``.obj``.
                * Output a vtkjs file, please enter a filename without format. Any other filename is also
                  passed to the vtkjs exporter.
    """
    # Compare the extension case-insensitively so that e.g. "scene.HTML" is not routed to the vtkjs exporter.
    filename_format = filename.split(".")[-1].lower()

    # Save the plotter in the format of the output file.
    if filename_format == "gltf":
        plotter.export_gltf(filename)
    elif filename_format == "html":
        plotter.export_html(filename)
    elif filename_format == "obj":
        plotter.export_obj(filename)
    else:
        plotter.export_vtkjs(filename)