try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
from typing import List, Optional, Union
import numpy as np
from anndata import AnnData
from pyvista import PolyData
from spateo.tdr import (
add_model_labels,
center_to_zero,
collect_models,
construct_pc,
merge_models,
translate_model,
)
from .three_dims_plots import three_d_multi_plot
def _check_cpos_in_multi_plot(
    models: List,
    window_size: Optional[tuple] = None,
    cpo: Union[str, list] = "xy",
):
    """
    Pick one camera position that can be shared by all sub-plots.

    Every model is rendered off-screen once with the requested ``cpo`` and the
    camera position reported by each render is collected. The camera position
    whose camera location has the largest third (z) coordinate is returned.

    Args:
        models: The models to inspect. Each one must provide a ``plot`` method that
            accepts ``cpos``, ``jupyter_backend``, ``return_cpos``, ``off_screen`` and
            ``window_size`` and returns a camera position indexable as ``cpos[0][2]``.
        window_size: Window size in pixels. ``(512, 512)`` is used when it is ``None``.
        cpo: The camera position requested for every off-screen render.

    Returns:
        The camera position, as returned by ``model.plot``, with the largest
        ``cpos[0][2]`` value. The first one wins on ties.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if len(models) == 0:
        raise ValueError("`models` is empty; at least one model is required to determine the camera position.")

    # The window size does not depend on the model, so resolve it only once.
    render_size = (512, 512) if window_size is None else window_size

    cpos = [
        model.plot(
            cpos=cpo,
            jupyter_backend="none",
            return_cpos=True,
            off_screen=True,
            window_size=render_size,
        )
        for model in models
    ]

    # cpos[k][0] is the camera location of the k-th render; compare its z coordinate.
    camera_heights = np.asarray([model_cpo[0][2] for model_cpo in cpos])
    return cpos[np.argmax(camera_heights)]
def _flatten_adata_inputs(adata: tuple) -> list:
    """Flatten the positional ``*adata`` arguments (objects and/or lists of objects) into one list."""
    flattened = []
    for item in adata:
        if isinstance(item, (list, tuple)):
            flattened.extend(item)
        else:
            flattened.append(item)
    return flattened


def _build_point_cloud(adata, layer, group_key, spatial_key, id_key, colormap, alphamap, center_zero):
    """Build the point cloud of one AnnData object and return ``(pc, model_id, key, plot_cmap)``."""
    # Work on a private copy: the coordinates may be padded below and the caller's object must stay untouched.
    adata = adata.copy()
    model_id = str(adata.obs[id_key].unique().tolist()[0])

    # 2D coordinates are lifted to 3D by appending a zero z column.
    coords = adata.obsm[spatial_key]
    if coords.shape[1] == 2:
        adata.obsm[spatial_key] = np.c_[coords, np.zeros(shape=(coords.shape[0], 1))]

    key = f"{model_id}-{group_key}"
    pc, plot_cmap = construct_pc(
        adata=adata,
        layer=layer,
        spatial_key=spatial_key,
        groupby=group_key,
        key_added=key,
        colormap=colormap,
        alphamap=alphamap,
    )
    if center_zero:
        center_to_zero(model=pc, inplace=True)
    return pc, model_id, key, plot_cmap


def _label_overlap_model(model, source_key, pair_key, single_key, colormap, alphamap):
    """Copy ``model`` and store its ``source_key`` labels under ``pair_key`` and ``single_key``."""
    labeled = model.copy()
    labels = np.asarray(labeled.point_data[source_key])
    single_cmap = None
    for key_added in (pair_key, single_key):
        _, single_cmap = add_model_labels(
            model=labeled,
            labels=labels,
            key_added=key_added,
            where="point_data",
            colormap=colormap,
            alphamap=alphamap,
            inplace=True,
        )
    # The colormap of the last call (``single_key``) is the one used when the model is plotted alone.
    return labeled, single_cmap


def _build_overlap_models(pcs, ids, keys, cmaps, mode, colormap, alphamap, overlap_cmap, overlap_amap):
    """Build the models, ids, keys and colormaps plotted by the ``'overlap'`` and ``'both'`` modes."""
    models, model_ids, model_keys, model_cmaps = [], [], [], []
    for i in range(len(pcs) - 1):
        pair_key = f"overlap-{keys[i]}-{keys[i + 1]}"
        pair_id = f"{ids[i]} & {ids[i + 1]}"
        front_key, back_key = f"overlap-{keys[i]}", f"overlap-{keys[i + 1]}"

        # The first model of the pair uses the overlap colors, the second one the regular colors.
        front, front_cmap = _label_overlap_model(pcs[i], keys[i], pair_key, front_key, overlap_cmap, overlap_amap)
        back, back_cmap = _label_overlap_model(pcs[i + 1], keys[i + 1], pair_key, back_key, colormap, alphamap)

        merged = merge_models([front, back])
        if cmaps[i] is not None:
            # Re-register the merged labels so that the merged model gets its own colormap.
            _, merged_cmap = add_model_labels(
                model=merged,
                labels=np.asarray(merged.point_data[pair_key]),
                key_added=pair_key,
                where="point_data",
                colormap=colormap,
                inplace=True,
            )
        else:
            merged_cmap = None

        if mode == "overlap":
            models.append(merged)
            model_ids.append(pair_id)
            model_keys.append(pair_key)
            model_cmaps.append(merged_cmap)
        else:  # mode == "both"
            models.extend([front, back, merged])
            model_ids.extend([ids[i], ids[i + 1], pair_id])
            model_keys.extend([front_key, back_key, pair_key])
            model_cmaps.extend([front_cmap, back_cmap, merged_cmap])
    return models, model_ids, model_keys, model_cmaps


def multi_models(
    *adata: AnnData,
    layer: str = "X",
    group_key: Union[str, list] = None,
    spatial_key: str = "align_spatial",
    id_key: str = "slices",
    mode: Literal["single", "overlap", "both"] = "single",
    center_zero: bool = False,
    filename: Optional[str] = None,
    jupyter: Union[bool, Literal["none", "static", "trame"]] = False,
    off_screen: bool = False,
    cpo: Union[str, list] = "xy",
    shape: Union[str, list, tuple] = None,
    window_size: Optional[tuple] = None,
    background: str = "white",
    colormap: Union[str, list, dict] = "red",
    overlap_cmap: Union[str, list, dict] = "dodgerblue",
    alphamap: Union[float, list, dict] = 1.0,
    overlap_amap: Union[float, list, dict] = 0.5,
    ambient: Union[float, list] = 0.2,
    opacity: Union[float, np.ndarray, list] = 1.0,
    model_size: Union[float, list] = 3.0,
    show_axes: bool = True,
    show_legend: bool = True,
    legend_kwargs: Optional[dict] = None,
    text: Union[bool, str] = True,
    text_kwargs: Optional[dict] = None,
    **kwargs,
):
    """
    Visualize multiple models separately in one figure.

    Args:
        *adata: The models to plot, given either as several AnnData objects or as a list of AnnData objects.
        layer: If ``'X'``, uses ``.X``, otherwise uses the representation given by ``.layers[layer]``.
        group_key: The key that stores clustering or annotation information in ``.obs``, a gene name or a list of
            gene names in ``.var``. When it is ``None``, ``id_key`` is used.
        spatial_key: The key in ``.obsm`` that corresponds to the spatial coordinate of each bucket.
            Two-dimensional coordinates are padded with a zero z coordinate.
        id_key: The key in ``.obs`` that corresponds to the model id of each bucket.
        mode: Three modes of visualization. Available ``mode`` are:

            * ``'single'`` - Visualize each model individually.
            * ``'overlap'`` - Simultaneously visualize two models aligned front to back in one subplot.
            * ``'both'`` - Simultaneously visualize both types above.
        center_zero: Whether to move the center point of the model to the (0, 0, 0).
        filename: Filename of output file. Writer type is inferred from the extension of the filename.

            * Output an image file, please enter a filename ending with
              ``'.png', '.tif', '.tiff', '.bmp', '.jpeg', '.jpg', '.svg', '.eps', '.ps', '.pdf', '.tex'``.
              When ``jupyter=False``, if you want to save '.png' file, please ensure ``off_screen=True``.
        jupyter: Whether to plot in jupyter notebook. Available ``jupyter`` are:

            * ``'none'`` - Do not display in the notebook.
            * ``'trame'`` - Show a trame widget.
            * ``'static'`` - Display a static figure.
        off_screen: Renders off-screen when True. Useful for automated screenshots.
        cpo: Camera position of the active render window. Available ``cpo`` are:

            * Iterable containing position, focal_point, and view up.
              ``E.g.: [(2.0, 5.0, 13.0), (0.0, 0.0, 0.0), (-0.7, -0.5, 0.3)].``
            * Iterable containing a view vector.
              ``E.g.: [-1.0, 2.0, -5.0].``
            * A string containing the plane orthogonal to the view direction.
              ``E.g.: 'xy', 'xz', 'yz', 'yx', 'zx', 'zy', 'iso'.``
        shape: Number of sub-render windows inside the main window. By default, there is only one render window.

            * Specify two across with ``shape``=(2, 1) and a two by two grid with ``shape``=(2, 2).
            * ``shape`` can also accept a string descriptor as shape.
              ``E.g.: shape="3|1" means 3 plots on the left and 1 on the right,``
              ``E.g.: shape="4/2" means 4 plots on top and 2 at the bottom.``
        window_size: Window size in pixels. The default window_size is ``[512, 512]``.
        background: The background color of the window.
        colormap: Colors to use for plotting pc. The default colormap is ``'red'``.
        overlap_cmap: Colors to use for plotting overlapped pc. The default colormap is ``'dodgerblue'``.
        alphamap: The opacity of the colors to use for plotting pc. The default alphamap is ``1.0``.
        overlap_amap: The opacity of the colors to use for plotting overlapped pc. The default alphamap is ``.5``.
        ambient: When lighting is enabled, this is the amount of light in the range of 0 to 1 (default ``0.2``) that
            reaches the actor when not directed at the light source emitted from the viewer.
        opacity: Opacity of the model.

            If a single float value is given, it will be the global opacity of the model and uniformly applied
            everywhere, elif a numpy.ndarray with single float values is given, it
            will be the opacity of each point. - should be between 0 and 1.

            A string can also be specified to map the scalars range to a predefined opacity transfer function
            (options include: 'linear', 'linear_r', 'geom', 'geom_r').
        model_size: The point size of any nodes in the dataset plotted.
        show_axes: Whether to add a camera orientation widget to the active renderer.
        show_legend: Whether to add a legend to the plotter.
        legend_kwargs: A dictionary that will be passed to the ``add_legend`` function.
            By default, it is an empty dictionary and the ``add_legend`` function will use the
            ``{"legend_size": None, "legend_loc": None, "title_font_size": None, "label_font_size": None,
            "font_family": "arial", "fmt": "%.2e", "n_labels": 5, "vertical": True}`` as its parameters.
            Otherwise, you can provide a dictionary that properly modifies those keys according to your needs.
        text: The text to add to the rendering. When it is ``True``, each subplot is titled with its model id.
        text_kwargs: A dictionary that will be passed to the ``add_text`` function.
            By default, it is an empty dictionary and the ``add_text`` function will use the
            ``{"font_family": "arial", "font_size": 12, "font_color": "black", "text_loc": "upper_left"}``
            as its parameters. Otherwise, you can provide a dictionary that properly modifies those keys
            according to your needs.
        **kwargs: Additional parameters that will be passed to ``three_d_multi_plot`` function.

    Raises:
        ValueError: If ``mode`` is not one of ``'single'``, ``'overlap'`` or ``'both'``, if no model is given, or
            if ``mode`` is ``'overlap'`` / ``'both'`` and fewer than two models are given.
    """
    # Validate everything that can be validated before any expensive work is done.
    valid_modes = ("single", "overlap", "both")
    if mode not in valid_modes:
        raise ValueError(f"`mode` must be one of {valid_modes}, but got {mode!r}.")

    adata_list = _flatten_adata_inputs(adata)
    if len(adata_list) == 0:
        raise ValueError("At least one AnnData object is required.")
    if mode != "single" and len(adata_list) < 2:
        raise ValueError(f"`mode={mode!r}` requires at least two models, but got {len(adata_list)}.")

    group_key = id_key if group_key is None else group_key

    # Construct a point cloud model for every AnnData object.
    pcs, ids, keys, cmaps = [], [], [], []
    for single_adata in adata_list:
        pc, model_id, key, plot_cmap = _build_point_cloud(
            adata=single_adata,
            layer=layer,
            group_key=group_key,
            spatial_key=spatial_key,
            id_key=id_key,
            colormap=colormap,
            alphamap=alphamap,
            center_zero=center_zero,
        )
        pcs.append(pc)
        ids.append(model_id)
        keys.append(key)
        cmaps.append(plot_cmap)

    # Use one camera position for all subplots.
    cpo = _check_cpos_in_multi_plot(models=pcs, window_size=window_size, cpo=cpo)

    # Select what is shown in the subplots.
    if mode == "single":
        plot_models, plot_ids, plot_keys, plot_cmaps = pcs, ids, keys, cmaps
    else:
        plot_models, plot_ids, plot_keys, plot_cmaps = _build_overlap_models(
            pcs=pcs,
            ids=ids,
            keys=keys,
            cmaps=cmaps,
            mode=mode,
            colormap=colormap,
            alphamap=alphamap,
            overlap_cmap=overlap_cmap,
            overlap_amap=overlap_amap,
        )

    # Visualization.
    three_d_multi_plot(
        model=collect_models(plot_models),
        key=plot_keys,
        filename=filename,
        jupyter=jupyter,
        off_screen=off_screen,
        shape=shape,
        window_size=window_size,
        background=background,
        ambient=ambient,
        opacity=opacity,
        colormap=plot_cmaps,
        cpo=[cpo],
        model_style=["points"],
        model_size=[model_size],
        show_axes=show_axes,
        show_legend=show_legend,
        legend_kwargs=legend_kwargs,
        text=[f"\nModel id: {model_id}" for model_id in plot_ids] if text is True else text,
        text_kwargs=text_kwargs,
        **kwargs,
    )