Source code for spateo.plotting.static.align

import math
from typing import List, Optional, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from anndata import AnnData

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

from ...tools.cluster.utils import integrate, to_dense_matrix
from ...tools.utils import compute_smallest_distance
from .utils import save_return_show_fig_utils


def multi_slices(
    slices: Union[AnnData, List[AnnData]],
    slices_key: Optional[str] = None,
    label: Optional[str] = None,
    spatial_key: str = "align_spatial",
    layer: str = "X",
    point_size: Optional[float] = None,
    font_size: Optional[float] = 20,
    color: Optional[str] = "skyblue",
    palette: Optional[str] = None,
    alpha: float = 1.0,
    ncols: int = 4,
    ax_height: float = 1,
    dpi: int = 100,
    show_legend: bool = True,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = None,
    **kwargs,
):
    """Draw one scatter-plot panel per slice on a shared seaborn ``FacetGrid``.

    Args:
        slices: Either a single AnnData object holding all slices, or a list of
            AnnData objects (one per slice) that are merged with ``integrate``.
        slices_key: Column of ``.obs`` that identifies the slice of every
            observation. Required when ``slices`` is a single AnnData object.
            When ``slices`` is a list and this is ``None``, ``"slices"`` is used.
        label: What to colour the points by: a column of ``.obs`` or a name in
            ``.var_names``. When ``None``, every point gets the constant label
            ``"spatial coordinates"``.
        spatial_key: Key of ``.obsm`` holding the coordinates; only the first
            two columns are plotted, as ``x`` and ``y``.
        layer: Layer copied into ``.X`` before plotting; ``"X"`` keeps ``.X``.
        point_size: Marker size passed to ``sns.scatterplot`` as ``s``. When
            ``None`` it is derived from the smallest point-to-point distance
            reported by ``compute_smallest_distance`` for each slice.
        font_size: Value of the ``font.size`` rc parameter.
        color: Passed to ``sns.scatterplot`` as ``color``.
        palette: Passed to the ``FacetGrid`` as ``palette`` and used as the
            colormap of the colorbar for numeric labels.
        alpha: Opacity of the points and of the colorbar.
        ncols: Number of columns after which the facets wrap.
        ax_height: Base height of a panel. A value of ``1`` is replaced by ``2``
            when all panels fit in a single row.
        dpi: Only used when computing the automatic ``point_size``.
        show_legend: Forwarded to ``save_return_show_fig_utils``.
        save_show_or_return: Forwarded to ``save_return_show_fig_utils``.
        save_kwargs: Forwarded to ``save_return_show_fig_utils``.
        **kwargs: Extra keyword arguments forwarded to ``sns.scatterplot``.

    Returns:
        Whatever ``save_return_show_fig_utils`` returns for the given
        ``save_show_or_return`` mode.

    Raises:
        ValueError: If ``slices`` is a single AnnData object and ``slices_key``
            is ``None``, or if ``label`` is neither a column of ``.obs`` nor a
            name in ``.var_names``.
    """
    # Check slices object.
    if isinstance(slices, list):
        if slices_key is None:
            slices_key = "slices"
        adatas = [s.copy() for s in slices]
        for i, s in enumerate(adatas):
            s.X = s.layers[layer].copy() if layer != "X" else s.X.copy()
            s.uns = {"__type": "UMI"}
            # NOTE(review): every slice is (re)labelled here, which overwrites an
            # existing obs[slices_key] column - confirm this is intended when the
            # caller passes `slices_key` together with a list of slices.
            s.obs[slices_key] = f"slice_{i}"
        adata = integrate(adatas=adatas, batch_key=slices_key)
    else:
        # Raise explicitly instead of asserting: asserts vanish under `python -O`.
        if slices_key is None:
            raise ValueError("When `slices` is an anndata object, `slices_key` cannot be None.")
        adata = slices.copy()
        adata.X = adata.layers[layer].copy() if layer != "X" else adata.X.copy()

    # Check label data and generate plotting data.
    slices_data = pd.DataFrame(adata.obsm[spatial_key][:, :2], columns=["x", "y"], dtype=float)
    slices_data[slices_key] = adata.obs[slices_key].values
    if label is None:
        label = "spatial coordinates"
        slices_data[label] = label
    elif label in adata.obs_keys():
        slices_data[label] = adata.obs[label].values
    elif label in adata.var_names:
        adata.X = to_dense_matrix(adata.X)
        slices_data[label] = adata[:, label].X
    else:
        raise ValueError("`label` is not a valid column names or gene name.")

    # Set the arrangement of subgraphs
    slices_id = slices_data[slices_key].unique().tolist()
    nrows = math.ceil(len(slices_id) / ncols)

    # Set the aspect ratio of each subplot
    spatial_coords = slices_data[["x", "y"]].values.copy()
    # Peak-to-peak extent per axis. `ndarray.ptp` was removed in NumPy 2.0, so
    # compute it from max/min, which works on every NumPy version.
    ptp_vec = spatial_coords.max(axis=0) - spatial_coords.min(axis=0)
    aspect_ratio = ptp_vec[0] / ptp_vec[1]
    ax_height = 2 if nrows == 1 and ax_height == 1 else ax_height
    axsize = (ax_height * aspect_ratio, ax_height * 2)

    # Set multi-plot grid for plotting.
    sns.set_theme(
        context="paper",
        style="white",
        font="Arial",
        font_scale=1,
        rc={
            "font.size": font_size,
            "font.family": ["sans-serif"],
            "font.sans-serif": ["Arial", "sans-serif", "Helvetica", "DejaVu Sans", "Bitstream Vera Sans"],
        },
    )
    g = sns.FacetGrid(
        slices_data.copy(),
        col=slices_key,
        hue=label,
        palette=palette,
        sharex=True,
        sharey=True,
        height=axsize[1] * nrows,
        aspect=aspect_ratio,
        col_wrap=ncols,
        despine=False,
    )

    # Calculate the most suitable size of the point.
    if point_size is None:
        min_dist_list = [
            compute_smallest_distance(coords=data[["x", "y"]].values, sample_num=min(1000, len(data)))
            for _, data in slices_data.groupby(by=slices_key)
        ]
        point_size = min(min_dist_list) * axsize[0] / ptp_vec[0] * dpi
        point_size = point_size**2 * ncols * nrows

    # Draw scatter plots.
    g.map_dataframe(sns.scatterplot, x="x", y="y", alpha=alpha, color=color, s=point_size, legend="brief", **kwargs)

    # Set legend.
    label_values = slices_data[label].values
    if label_values.dtype in ["float16", "float32", "float64", "int16", "int32", "int64"]:
        from mpl_toolkits.axes_grid1.inset_locator import inset_axes

        # Attach the colorbar to the right-most panel of the first row. With
        # `col_wrap` the grid only holds one axis per slice, so clamp the index
        # when there are fewer slices than `ncols`.
        last_first_row_col = min(ncols, len(slices_id)) - 1
        ax = g.facet_axis(row_i=0, col_j=last_first_row_col)
        norm = mpl.colors.Normalize(vmin=None, vmax=None)
        mappable = mpl.cm.ScalarMappable(norm=norm, cmap=palette)
        mappable.set_array(label_values)
        plt.colorbar(
            mappable,
            cax=inset_axes(
                ax,
                width="12%",
                height="100%",
                loc="center left",
                bbox_to_anchor=(1.02, 0.0, 0.5, 1.0),
                bbox_transform=ax.transAxes,
                borderpad=1.85,
            ),
            ax=ax,
            orientation="vertical",
            alpha=alpha,
            label=label,
        )
    else:
        g.add_legend()

    plt.tight_layout()
    return save_return_show_fig_utils(
        save_show_or_return=save_show_or_return,
        show_legend=show_legend,
        background="white",
        prefix="multi_slices",
        save_kwargs=save_kwargs,
        total_panels=len(slices_id),
        fig=g,
        axes=g,
        return_all=False,
        return_all_list=None,
    )