import math
from typing import List, Optional, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from anndata import AnnData

    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import numpy as np

from import integrate, to_dense_matrix
from import compute_smallest_distance
from .utils import save_return_show_fig_utils

def slices_2d(
    slices: Union[AnnData, List[AnnData]],
    slices_key: Optional[Union[bool, str]] = None,
    label_key: Optional[str] = None,
    label_type: Optional[str] = None,
    spatial_key: str = "spatial",
    point_size: Optional[float] = None,
    n_sampling: int = -1,
    palette: Optional[dict] = None,
    ncols: int = 4,
    title: str = "",
    title_kwargs: Optional[dict] = None,
    show_legend: bool = True,
    legend_kwargs: Optional[dict] = None,
    axis_off: bool = False,
    axis_kwargs: Optional[dict] = None,
    ticks_off: bool = True,
    x_min=None,
    x_max=None,
    y_min=None,
    y_max=None,
    height: float = 2,
    alpha: float = 1.0,  # TODO: alpha to be a key in adata
    cmap="tab20",
    center_coordinate: bool = False,
    gridspec_kws: Optional[dict] = None,
    return_palette: bool = False,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = None,
    **kwargs,
):
    """Plot every slice in its own panel of a shared seaborn ``FacetGrid``.

    Panels are laid out row-major, ``ncols`` panels per row, and share both axes.

    Args:
        slices: One AnnData or a list of AnnData objects (one per slice).
        slices_key: ``.obs`` column whose single unique value names each slice in the panel
            title. If the column is missing (or ``None``) the positional index is used;
            ``False`` suppresses the titles.
        label_key: ``.obs`` column or gene name in ``.var_names`` used to colour the points.
        label_type: ``"cluster"`` (discrete palette + legend) or ``"scalar"`` (colormap +
            colorbar). Inferred from the dtype of the first slice's labels when ``None``.
        spatial_key: ``.obsm`` key holding the coordinates; columns 0 and 1 are plotted.
        point_size: Marker size. When ``None`` it is derived from ``height``, the aspect
            ratio and the mean number of points per slice.
        n_sampling: If positive and smaller than a slice's size, keep this many random
            points of that slice.
        palette: ``{label: color}`` dict for cluster labels; generated from ``cmap`` if None.
        ncols: Number of panels per row.
        title: Currently unused.
        title_kwargs: Passed as ``fontdict`` to ``Axes.set_title`` for the panel titles.
        show_legend: Draw a legend (cluster labels) or a colorbar (scalar labels).
        legend_kwargs: Overrides for the cluster-legend keyword arguments.
        axis_off: Hide the axes of every panel.
        axis_kwargs: Currently unused.
        ticks_off: Remove tick marks from every panel.
        x_min, x_max, y_min, y_max: Shared axis limits; a pair is applied only when both
            of its ends are given.
        height: Height of each panel (seaborn ``FacetGrid`` ``height``).
        alpha: Marker opacity.
        cmap: Palette/colormap name used for the generated palette or scalar colouring.
        center_coordinate: Subtract each slice's mean coordinate before plotting.
        gridspec_kws: Overrides for the default gridspec spacing.
        return_palette: Also return the palette that was used.
        save_show_or_return: Forwarded to ``save_return_show_fig_utils``.
        save_kwargs: Forwarded to ``save_return_show_fig_utils``.
        **kwargs: Extra keyword arguments for ``seaborn.scatterplot``.

    Returns:
        The result of ``save_return_show_fig_utils``, or ``(result, palette)`` when
        ``return_palette`` is True.

    Raises:
        ValueError: If ``spatial_key`` or ``label_key`` is missing from a slice, or if
            ``slices_key`` does not hold exactly one unique value within a slice.
        AssertionError: If a slice has a different number of coordinates and labels.
    """
    numeric_dtypes = ["float16", "float32", "float64", "int16", "int32", "int64"]

    if isinstance(slices, AnnData):
        slices = [slices]

    # Collect per-slice coordinates, labels and display ids.
    spatial_coords = []
    labels = []
    slice_ids = []
    for i, s in enumerate(slices):
        if spatial_key not in s.obsm.keys():
            raise ValueError(f"adata.obsm['{spatial_key}'] does not exist.")
        spatial_coords.append(s.obsm[spatial_key].copy())

        if label_key in s.obs.keys():
            labels.append(s.obs[label_key].copy())
        elif label_key in s.var_names:
            expr = s[:, label_key].X
            # `.A` does not exist on dense ndarrays nor on recent scipy sparse classes;
            # `toarray()` covers sparse input and `np.array` copies dense input.
            dense = expr.toarray() if hasattr(expr, "toarray") else np.array(expr)
            labels.append(dense.reshape(-1))
        else:
            raise ValueError(f"adata.obs['{label_key}'] or adata.var['{label_key}'] does not exist.")

        if (slices_key is not None) and (slices_key in s.obs.keys()):
            unique_id = np.unique(s.obs[slices_key])
            if len(unique_id) != 1:
                raise ValueError(f"adata.obs['{slices_key}'] must have only one unique value.")
            slice_ids.append(unique_id[0])
        else:
            slice_ids.append(str(i))

        assert (
            spatial_coords[-1].shape[0] == labels[-1].shape[0]
        ), "The number of spatial coordinates and labels must be the same. Please check the data."

    # Infer the label type if not specified. `.dtype` works for both pd.Series (obs
    # columns) and np.ndarray (gene expression); `.values.dtype` crashed on the latter.
    if label_type is None:
        label_type = "scalar" if labels[0].dtype in numeric_dtypes else "cluster"

    # Downsample each slice independently if n_sampling is set.
    for i in range(len(slices)):
        n_obs = spatial_coords[i].shape[0]
        if 0 < n_sampling < n_obs:
            sampling_idx = np.random.choice(n_obs, n_sampling, replace=False)
        else:
            sampling_idx = np.arange(n_obs)
        spatial_coords[i] = spatial_coords[i][sampling_idx]
        # Positional indexing: plain `Series[int_array]` is deprecated in pandas.
        if isinstance(labels[i], pd.Series):
            labels[i] = labels[i].iloc[sampling_idx]
        else:
            labels[i] = labels[i][sampling_idx]

    if center_coordinate:
        for i in range(len(slices)):
            spatial_coords[i] = spatial_coords[i] - np.mean(spatial_coords[i], axis=0)

    # Long-format plotting frame; `col`/`row` place slice i at grid position (i // ncols, i % ncols).
    # A single concat keeps numeric dtypes (concatenating onto an empty object-dtype frame did not).
    frames = [
        pd.DataFrame(
            {
                "x": coords[:, 0],
                "y": coords[:, 1],
                "label": np.asarray(slice_labels),
                "slice_id": slice_id,
                "col": i % ncols,
                "row": i // ncols,
            }
        )
        for i, (coords, slice_labels, slice_id) in enumerate(zip(spatial_coords, labels, slice_ids))
    ]
    slices_spatial_data = pd.concat(frames, axis=0, ignore_index=True)

    # Aspect ratio of each subplot from the overall x/y extent (np.ptp: ndarray.ptp is gone in NumPy 2).
    ptp_vec = np.ptp(slices_spatial_data[["x", "y"]].to_numpy(dtype=float), axis=0)
    aspect_ratio = ptp_vec[0] / ptp_vec[1]

    sns.set_theme(
        context="paper",
        style="white",
        font="Arial",
        font_scale=1,
        rc={
            "font.family": ["sans-serif"],
            "font.sans-serif": ["Arial", "sans-serif", "Helvetica", "DejaVu Sans", "Bitstream Vera Sans"],
        },
    )

    if (palette is None) and (label_type == "cluster"):
        palette = _agenerate_palette(*labels, cmap=cmap)
    elif label_type == "scalar":
        palette = cmap

    _gridspec_kws = {"wspace": 0.1, "hspace": 0.2}
    if gridspec_kws is not None:
        _gridspec_kws.update(gridspec_kws)
    if slices_key is False:
        # Without titles, rows can be packed as tightly as columns.
        _gridspec_kws["hspace"] = _gridspec_kws["wspace"] * aspect_ratio

    if point_size is None:
        point_size = 500 * height**2 * aspect_ratio / (slices_spatial_data.shape[0] / len(slices))

    g = sns.FacetGrid(
        slices_spatial_data,
        col="col",
        row="row",
        hue="label",
        palette=palette,
        sharex=True,
        sharey=True,
        height=height,
        aspect=aspect_ratio,
        despine=False,
        gridspec_kws=_gridspec_kws,
        xlim=(x_min, x_max) if x_min is not None and x_max is not None else None,
        ylim=(y_min, y_max) if y_min is not None and y_max is not None else None,
    )
    scatterplot_kwargs = {"x": "x", "y": "y", "alpha": alpha, "s": point_size, "legend": False, "edgecolor": None}
    scatterplot_kwargs.update(kwargs)
    g.map_dataframe(sns.scatterplot, **scatterplot_kwargs)

    # axes_dict is ordered row-major, so the i-th axis holds slice i; extra grid cells are blanked.
    for i, ax in enumerate(g.axes_dict.values()):
        if i < len(slices):
            if slices_key is False:
                ax.set_title("")
            else:
                ax.set_title(f"Slice {slice_ids[i]}", title_kwargs)
        else:
            ax.set_title("")
            ax.set_xticks([])
            ax.set_yticks([])
            ax.axis("off")
        ax.set_aspect("equal")
        if axis_off:
            ax.axis("off")
        if ticks_off:
            ax.set_xticks([])
            ax.set_yticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")

    if show_legend:
        if label_type == "cluster":
            _legend_kwargs = {
                "loc": "center left",
                "bbox_to_anchor": (1, 0.5),
                "prop": {"family": "Arial", "size": 10},
                "fancybox": False,
                "edgecolor": "black",
                "framealpha": 1,
                "columnspacing": 0.8,
                "handletextpad": 0.5,
                "frameon": True,
            }
            if legend_kwargs:
                _legend_kwargs.update(legend_kwargs)
            legend_elements = [
                mpl.lines.Line2D(
                    [0], [0], marker="o", color="w", label=k, markerfacecolor=v, markersize=6, markeredgecolor="k"
                )
                for k, v in palette.items()
            ]
            g.figure.legend(handles=legend_elements, **_legend_kwargs)
        else:
            # Scalar labels: colorbar attached to the right of the last grid axis.
            # TODO: honour legend_kwargs for the colorbar.
            label_values = slices_spatial_data["label"].to_numpy(dtype=float)
            norm = mpl.colors.Normalize(vmin=None, vmax=None)
            mappable = mpl.cm.ScalarMappable(norm=norm, cmap=palette)
            mappable.set_array(label_values)
            from mpl_toolkits.axes_grid1.inset_locator import inset_axes

            g.figure.colorbar(
                mappable,
                use_gridspec=False,
                shrink=0.5,
                cax=inset_axes(
                    ax,
                    width="15%",
                    height="75%",
                    loc="center left",
                    bbox_to_anchor=(1.02, 0.0, 0.5, 1.0),
                    bbox_transform=ax.transAxes,
                ),
            )

    result = save_return_show_fig_utils(
        save_show_or_return=save_show_or_return,
        show_legend=show_legend,
        background="white",
        prefix="multi_slices",
        save_kwargs=save_kwargs,
        total_panels=len(slice_ids),
        fig=g,
        axes=g,
        return_all=False,
        return_all_list=None,
    )
    return (result, palette) if return_palette else result
def overlay_slices_2d(
    slices: Union[AnnData, List[AnnData]],
    slices_key: Optional[Union[bool, str]] = None,
    label_key: Optional[str] = None,
    overlay_type: Literal["forward", "backward", "both"] = "both",
    spatial_key: str = "spatial",
    point_size: Optional[float] = None,
    n_sampling: int = -1,
    palette: Optional[dict] = None,
    ncols: int = 4,
    title: str = "",
    title_kwargs: Optional[dict] = None,
    show_legend: bool = True,
    legend_kwargs: Optional[dict] = None,
    axis_off: bool = False,
    axis_kwargs: Optional[dict] = None,
    ticks_off: bool = True,
    x_min=None,
    x_max=None,
    y_min=None,
    y_max=None,
    height: float = 2,
    alpha: float = 1.0,  # TODO: alpha to be a key in adata
    cmap="tab20",
    center_coordinate: bool = False,  # different from slices_2d
    gridspec_kws: Optional[dict] = None,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = None,
    **kwargs,
):
    """Plot each slice overlaid with its neighbouring slice(s) in a shared ``FacetGrid``.

    Panel ``i`` shows slice ``i`` (``overlay_id == "current"``) plus:
      * slice ``i - 1`` tagged ``"forward"`` when ``overlay_type`` is ``"forward"``/``"both"``,
      * slice ``i + 1`` tagged ``"backward"`` when ``overlay_type`` is ``"backward"``/``"both"``.
    With ``"forward"`` the first panel is empty and with ``"backward"`` the last one is.

    Args:
        slices: One AnnData or a list of AnnData objects (one per slice).
        slices_key: ``.obs`` column whose single unique value names each slice in the panel
            title; positional index otherwise. ``False`` suppresses the titles.
        label_key: ``.obs`` column (cluster colouring) or gene in ``.var_names`` (scalar
            colouring). When ``None`` points are coloured by their overlay role instead.
        overlay_type: Which neighbour(s) to overlay: ``"forward"``, ``"backward"`` or ``"both"``.
        spatial_key: ``.obsm`` key holding the coordinates; columns 0 and 1 are plotted.
        point_size: Marker size; derived from ``height``, the aspect ratio and the mean rows
            per slice when ``None``.
        n_sampling: If positive and smaller than a slice's size, keep this many random points.
        palette: ``{label: color}`` dict for cluster labels; generated from ``cmap`` if None.
            Ignored for scalar labels and when ``label_key`` is None.
        ncols: Number of panels per row.
        title: Currently unused.
        title_kwargs: Passed as ``fontdict`` to ``Axes.set_title``.
        show_legend: Draw a legend (cluster/overlay colouring) or colorbar (scalar labels).
        legend_kwargs: Overrides for the cluster-legend keyword arguments.
        axis_off: Hide the axes of every panel.
        axis_kwargs: Currently unused.
        ticks_off: Remove tick marks from every panel.
        x_min, x_max, y_min, y_max: Shared axis limits; a pair is applied only when both
            of its ends are given.
        height: Height of each panel (seaborn ``FacetGrid`` ``height``).
        alpha: Marker opacity.
        cmap: Palette/colormap name used for the generated palette or scalar colouring.
        center_coordinate: Subtract each slice's mean coordinate before plotting.
        gridspec_kws: Overrides for the default gridspec spacing.
        save_show_or_return: Forwarded to ``save_return_show_fig_utils``.
        save_kwargs: Forwarded to ``save_return_show_fig_utils``.
        **kwargs: Extra keyword arguments for ``seaborn.scatterplot``.

    Returns:
        The result of ``save_return_show_fig_utils``.

    Raises:
        ValueError: If ``spatial_key`` or ``label_key`` is missing from a slice, or if
            ``slices_key`` does not hold exactly one unique value within a slice.
        AssertionError: If a slice has a different number of coordinates and labels.
    """
    if isinstance(slices, AnnData):
        slices = [slices]

    # Collect per-slice coordinates, labels and display ids.
    spatial_coords = []
    labels = []
    slice_ids = []
    for i, s in enumerate(slices):
        if spatial_key not in s.obsm.keys():
            raise ValueError(f"adata.obsm['{spatial_key}'] does not exist.")
        spatial_coords.append(s.obsm[spatial_key].copy())

        if label_key is not None:
            if label_key in s.obs.keys():
                labels.append(s.obs[label_key].copy())
                label_type = "cluster"
            elif label_key in s.var_names:
                expr = s[:, label_key].X
                # `.A` does not exist on dense ndarrays nor on recent scipy sparse classes.
                dense = expr.toarray() if hasattr(expr, "toarray") else np.array(expr)
                labels.append(dense.reshape(-1))
                label_type = "scalar"
            else:
                raise ValueError(f"adata.obs['{label_key}'] or adata.var['{label_key}'] does not exist.")
            assert (
                spatial_coords[-1].shape[0] == labels[-1].shape[0]
            ), "The number of spatial coordinates and labels must be the same. Please check the data."
        else:
            label_type = "cluster"

        if (slices_key is not None) and (slices_key in s.obs.keys()):
            unique_id = np.unique(s.obs[slices_key])
            if len(unique_id) != 1:
                raise ValueError(f"adata.obs['{slices_key}'] must have only one unique value.")
            slice_ids.append(unique_id[0])
        else:
            slice_ids.append(str(i))

    # Downsample each slice independently if n_sampling is set.
    for i in range(len(slices)):
        n_obs = spatial_coords[i].shape[0]
        if 0 < n_sampling < n_obs:
            sampling_idx = np.random.choice(n_obs, n_sampling, replace=False)
        else:
            sampling_idx = np.arange(n_obs)
        spatial_coords[i] = spatial_coords[i][sampling_idx]
        if label_key is not None:
            # Positional indexing: plain `Series[int_array]` is deprecated in pandas.
            if isinstance(labels[i], pd.Series):
                labels[i] = labels[i].iloc[sampling_idx]
            else:
                labels[i] = labels[i][sampling_idx]

    if center_coordinate:
        for i in range(len(slices)):
            spatial_coords[i] = spatial_coords[i] - np.mean(spatial_coords[i], axis=0)

    def _panel_frame(src, overlay_id, panel):
        # Points of slice `src`, drawn in grid panel `panel`, tagged with their overlay role.
        return pd.DataFrame(
            {
                "x": spatial_coords[src][:, 0],
                "y": spatial_coords[src][:, 1],
                "label": np.asarray(labels[src]) if label_key is not None else "unknow",
                "overlay_id": overlay_id,
                "slice_id": slice_ids[src],
                "col": panel % ncols,
                "row": panel // ncols,
            }
        )

    # Each active panel gets the current slice exactly once plus its neighbour(s); the
    # previous implementation appended the current slice up to three times per panel.
    n_slices = len(slices)
    draw_forward = overlay_type in ("forward", "both")
    draw_backward = overlay_type in ("backward", "both")
    frames = []
    for i in range(n_slices):
        has_prev = draw_forward and i > 0
        has_next = draw_backward and i < n_slices - 1
        if overlay_type == "both" or has_prev or has_next:
            frames.append(_panel_frame(i, "current", i))
        if has_prev:
            frames.append(_panel_frame(i - 1, "forward", i))
        if has_next:
            frames.append(_panel_frame(i + 1, "backward", i))
    if frames:
        slices_spatial_data = pd.concat(frames, axis=0, ignore_index=True)
    else:
        slices_spatial_data = pd.DataFrame(columns=["x", "y", "label", "overlay_id", "slice_id", "col", "row"])

    # Aspect ratio of each subplot (np.ptp: ndarray.ptp is gone in NumPy 2).
    ptp_vec = np.ptp(slices_spatial_data[["x", "y"]].to_numpy(dtype=float), axis=0)
    aspect_ratio = ptp_vec[0] / ptp_vec[1]

    sns.set_theme(
        context="paper",
        style="white",
        font="Arial",
        font_scale=1,
        rc={
            "font.family": ["sans-serif"],
            "font.sans-serif": ["Arial", "sans-serif", "Helvetica", "DejaVu Sans", "Bitstream Vera Sans"],
        },
    )

    if label_key is not None:
        # Keep a user-supplied cluster palette; only scalar labels switch to the colormap.
        if label_type == "scalar":
            palette = cmap
        elif palette is None:
            palette = _agenerate_palette(*labels, cmap=cmap)
    else:
        overlay_colors = {"current": "red", "forward": "green", "backward": "blue"}
        if overlay_type == "both":
            palette = overlay_colors
        elif overlay_type in ("forward", "backward"):
            palette = {"current": overlay_colors["current"], overlay_type: overlay_colors[overlay_type]}

    _gridspec_kws = {"wspace": 0.1, "hspace": 0.2}
    if gridspec_kws is not None:
        _gridspec_kws.update(gridspec_kws)
    if slices_key is False:
        # Without titles, rows can be packed as tightly as columns.
        _gridspec_kws["hspace"] = _gridspec_kws["wspace"] * aspect_ratio

    if point_size is None:
        point_size = 500 * height**2 * aspect_ratio / (slices_spatial_data.shape[0] / len(slices))

    g = sns.FacetGrid(
        slices_spatial_data,
        col="col",
        row="row",
        hue="label" if label_key is not None else "overlay_id",
        palette=palette,
        sharex=True,
        sharey=True,
        height=height,
        aspect=aspect_ratio,
        despine=False,
        gridspec_kws=_gridspec_kws,
        xlim=(x_min, x_max) if x_min is not None and x_max is not None else None,
        ylim=(y_min, y_max) if y_min is not None and y_max is not None else None,
    )
    scatterplot_kwargs = {"x": "x", "y": "y", "alpha": alpha, "s": point_size, "legend": False, "edgecolor": None}
    scatterplot_kwargs.update(kwargs)
    g.map_dataframe(sns.scatterplot, **scatterplot_kwargs)

    # axes_dict is ordered row-major, so the i-th axis corresponds to panel i.
    for i, ax in enumerate(g.axes_dict.values()):
        if i < len(slices):
            if slices_key is False:
                ax.set_title("")
            else:
                ax.set_title(f"Slice {slice_ids[i]}", title_kwargs)
        else:
            ax.set_title("")
            ax.set_xticks([])
            ax.set_yticks([])
            ax.axis("off")
        ax.set_aspect("equal")
        if axis_off:
            ax.axis("off")
        if ticks_off:
            ax.set_xticks([])
            ax.set_yticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")

    if show_legend:
        if label_type == "cluster":
            _legend_kwargs = {
                "loc": "upper center",
                "bbox_to_anchor": (0.5, 0),
                "prop": {"family": "Arial", "size": 10},
                "fancybox": False,
                "edgecolor": "black",
                "framealpha": 1,
                "columnspacing": 0.8,
                "handletextpad": 0.5,
                "frameon": False,
                "ncol": 8,
                "borderaxespad": -4,
            }
            if legend_kwargs:
                _legend_kwargs.update(legend_kwargs)
            legend_elements = [
                mpl.lines.Line2D(
                    [0], [0], marker="o", color="w", label=k, markerfacecolor=v, markersize=6, markeredgecolor="k"
                )
                for k, v in palette.items()
            ]
            g.figure.legend(handles=legend_elements, **_legend_kwargs)
        else:
            # Scalar labels: colorbar attached to the right of the last grid axis.
            # TODO: honour legend_kwargs for the colorbar.
            label_values = slices_spatial_data["label"].to_numpy(dtype=float)
            norm = mpl.colors.Normalize(vmin=None, vmax=None)
            mappable = mpl.cm.ScalarMappable(norm=norm, cmap=palette)
            mappable.set_array(label_values)
            from mpl_toolkits.axes_grid1.inset_locator import inset_axes

            g.figure.colorbar(
                mappable,
                use_gridspec=False,
                shrink=0.5,
                cax=inset_axes(
                    ax,
                    width="15%",
                    height="75%",
                    loc="center left",
                    bbox_to_anchor=(1.02, 0.0, 0.5, 1.0),
                    bbox_transform=ax.transAxes,
                ),
            )

    return save_return_show_fig_utils(
        save_show_or_return=save_show_or_return,
        show_legend=show_legend,
        background="white",
        prefix="multi_slices",
        save_kwargs=save_kwargs,
        total_panels=len(slice_ids),
        fig=g,
        axes=g,
        return_all=False,
        return_all_list=None,
    )
# def plot_align_correspondence_2d( # slices: List[AnnData], # mapping: List[np.ndarray], # label_key: Optional[str] = None, # spatial_key: str = "spatial", # point_size: Optional[float] = None, # linewidth: Optional[float] = None, # n_sampling: int = -1, # mapping_sampling: int = -1, # robust_threshold: Optional[float] = None, # palette: Optional[dict] = None, # show_legend: bool = True, # legend_kwargs: Optional[dict] = None, # axis_off: bool = False, # axis_kwargs: Optional[dict] = None, # ticks_off: bool = True, # x_min=None, # x_max=None, # y_min=None, # y_max=None, # height: float = 2, # alpha: float = 1.0, # TODO: alpha to be a key in adata # cmap="tab20", # center_coordinate: bool = False, # different from slices_2d # ): # assert len(mapping) == len(slices) - 1, "The length of mapping should be len(slices) - 1." # # get spatial coords and labels # spatial_coords = [] # labels = [] # for i, s in enumerate(slices): # if spatial_key in s.obsm.keys(): # spatial_coords.append(s.obsm[spatial_key].copy()) # else: # raise ValueError(f"adata.obsm['{spatial_key}'] does not exist.") # if label_key in s.obs.keys(): # labels.append(s.obs[label_key].copy()) # # label_type = "cluster" # elif label_key in s.var_names: # labels.append(s[:, label_key].X.A.copy().squeeze()) # # label_type = "scalar" # else: # raise ValueError(f"adata.obs['{label_key}'] or adata.var['{label_key}'] does not exist.") # assert ( # spatial_coords[-1].shape[0] == labels[-1].shape[0] # ), "The number of spatial coordinates and labels must be the same. Please check the data." 
# # add mapping # correspondences = [] # for i in range(len(mapping)): # if mapping[i].shape[1] == 2: # correspondences.append(mapping[i]) # elif (mapping[i].shape[1] == slices[i+1].shape[0]) and (mapping[i].shape[0] == slices[i].shape[0]): # sampling_idx = ( # np.random.choice(mapping[i].shape[0], mapping_sampling, replace=False) # if mapping_sampling > 0 and mapping_sampling < mapping[i].shape[0] # else np.arange(mapping[i].shape[0]) # ) # mapping_argmax = np.argmax(mapping[i][sampling_idx], axis=1) # mapping_valmax = mapping[i][sampling_idx][np.arange(mapping_sampling), mapping_argmax] # mask = np.arange(sampling_idx.shape[0]) if robust_threshold is None else mapping_valmax > robust_threshold # correspondence = np.array([sampling_idx[mask], mapping_argmax[mask]]).T # correspondences.append(correspondence) # else: # raise ValueError("The shape of mapping is not correct.") # # infer the label_type if not specified # if label_type is None: # if labels[0].values.dtype in ["float16", "float32", "float64", "int16", "int32", "int64"]: # label_type = "scalar" # else: # label_type = "cluster" # # downsampling if n_sampling is set # # TODO: implement downsampling # # center the coordinates # if center_coordinate: # for i in range(len(slices)): # spatial_coords[i] = spatial_coords[i] - np.mean(spatial_coords[i], axis=0) # # determine the interval of slices # slices_interval = [] # for i in range(len(slices) - 1): # slices_interval.append( # _compute_smallest_distance(spatial_coords[i], spatial_coords[i+1]) # ) # # Update the spatial coordinates # cur_pos = 0 # for i in range(len(slices) - 1): # cur_pos += slices_interval[i] # spatial_coords[i+1] += cur_pos # # determine the mapping line position and label # mapping_lines = [] # mapping_labels = [] # for i in range(len(correspondences)): # mapping_lines.append( # np.concatenate([spatial_coords[i][correspondences[i][:, 0]], spatial_coords[i+1][correspondences[i][:, 1]]], axis=1) # ) # mapping_labels.append( # 
labels[i][correspondences[i][:, 0]] # ) # # Set plot theme # sns.set_theme( # context="paper", # style="white", # font="Arial", # font_scale=1, # rc={ # # "font.size": font_size, # "": ["sans-serif"], # "font.sans-serif": ["Arial", "sans-serif", "Helvetica", "DejaVu Sans", "Bitstream Vera Sans"], # }, # ) # # generate palette # if (palette is None) and (label_type == "cluster"): # palette = _agenerate_palette(*labels, cmap=cmap) # else: # palette = cmap # # generate figure # fig, ax = plt.subplots(1, 1, figsize=(height * aspect_ratio, height)) # # determine the pointsize if not specified # if point_size is None: # point_size = 500 * height**2 * aspect_ratio / (slices_spatial_data.shape[0] / len(slices)) # # plotting # sns.scatterplot(x=x, y=y, hue=label, legend=legend, palette = palette, ax=ax, s=point_size, edgecolor=edgecolor) # TODO: finish this # def plot_align_correspondence_3d( # slices: List[AnnData], # label_key: Optional[str] = None, # spatial_key: str = "spatial", # point_size: Optional[float] = None, # n_sampling: int = -1, # palette: Optional[dict] = None, # show_legend: bool = True, # legend_kwargs: Optional[dict] = None, # axis_off: bool = False, # axis_kwargs: Optional[dict] = None, # ticks_off: bool = True, # x_min=None, # x_max=None, # y_min=None, # y_max=None, # height: float = 2, # alpha: float = 1.0, # TODO: alpha to be a key in adata # cmap="tab20", # center_coordinate: bool = False, # different from slices_2d # ): # # get spatial coords and labels # spatial_coords = [] # labels = [] # slice_ids = [] # for i, s in enumerate(slices): # if spatial_key in s.obsm.keys(): # spatial_coords.append(s.obsm[spatial_key].copy()) # else: # raise ValueError(f"adata.obsm['{spatial_key}'] does not exist.") # if label_key is not None: # if label_key in s.obs.keys(): # labels.append(s.obs[label_key].copy()) # label_type = "cluster" # elif label_key in s.var_names: # labels.append(s[:, label_key].X.A.copy().squeeze()) # label_type = "scalar" # else: # raise 
ValueError(f"adata.obs['{label_key}'] or adata.var['{label_key}'] does not exist.") # assert ( # spatial_coords[-1].shape[0] == labels[-1].shape[0] # ), "The number of spatial coordinates and labels must be the same. Please check the data." # else: # label_type = "cluster" # if (slices_key is not None) and (slices_key in s.obs.keys()): # unique_id = np.unique(s.obs[slices_key]) # if len(unique_id) == 1: # slice_ids.append(unique_id[0]) # else: # raise ValueError(f"adata.obs['{slices_key}'] must have only one unique value.") # else: # slice_ids.append(str(i))
def multi_slices(
    slices: Union[AnnData, List[AnnData]],
    slices_key: Optional[str] = None,
    label: Optional[str] = None,
    spatial_key: str = "align_spatial",
    layer: str = "X",
    point_size: Optional[float] = None,
    font_size: Optional[float] = 20,
    color: Optional[str] = "skyblue",
    palette: Optional[str] = None,
    alpha: float = 1.0,
    ncols: int = 4,
    ax_height: float = 1,
    dpi: int = 100,
    show_legend: bool = True,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = None,
    **kwargs,
):
    """Plot multiple slices as wrapped columns of a seaborn ``FacetGrid``.

    Args:
        slices: A list of AnnData objects, or one AnnData that already contains all slices.
        slices_key: ``.obs`` column identifying slices. Required for a single AnnData. For a
            list it defaults to ``"slices"``, and every slice is (re)labelled ``slice_{i}``.
        label: ``.obs`` column or gene in ``.var_names`` used as hue. When ``None`` all points
            share the label ``"spatial coordinates"``.
        spatial_key: ``.obsm`` key with coordinates; columns 0 and 1 are plotted.
        layer: Layer copied into ``.X`` before plotting (``"X"`` keeps ``.X``).
        point_size: Marker size; estimated from the smallest point spacing when ``None``.
        font_size: Value for the ``font.size`` rc parameter.
        color: Marker colour passed to ``seaborn.scatterplot``.
        palette: Palette/colormap for the hue (and colorbar for numeric labels).
        alpha: Marker opacity.
        ncols: Number of panels per row.
        ax_height: Base panel height; bumped to 2 for a single row left at the default of 1.
        dpi: Used only when estimating ``point_size``.
        show_legend: Forwarded to ``save_return_show_fig_utils``.
        save_show_or_return: Forwarded to ``save_return_show_fig_utils``.
        save_kwargs: Forwarded to ``save_return_show_fig_utils``.
        **kwargs: Extra keyword arguments for ``seaborn.scatterplot``.

    Returns:
        The result of ``save_return_show_fig_utils``.

    Raises:
        AssertionError: If a single AnnData is given without ``slices_key``.
        ValueError: If ``label`` is neither an ``.obs`` column nor a gene name.
    """
    if isinstance(slices, list):
        adatas = [s.copy() for s in slices]
        if slices_key is None:
            slices_key = "slices"
        for i, s in enumerate(adatas):
            s.X = s.layers[layer].copy() if layer != "X" else s.X.copy()
            s.uns = {"__type": "UMI"}
            s.obs[slices_key] = f"slice_{i}"
        adata = integrate(adatas=adatas, batch_key=slices_key)
    else:
        assert slices_key is not None, "When `slices` is an anndata object, `slices_key` cannot be None."
        adata = slices.copy()
        adata.X = adata.layers[layer].copy() if layer != "X" else adata.X.copy()

    # Check label data and generate plotting data.
    slices_data = pd.DataFrame(adata.obsm[spatial_key][:, :2], columns=["x", "y"], dtype=float)
    slices_data[slices_key] = adata.obs[slices_key].values
    if label is None:
        label = "spatial coordinates"
        slices_data[label] = label
    elif label in adata.obs_keys():
        slices_data[label] = adata.obs[label].values
    elif label in adata.var_names:
        adata.X = to_dense_matrix(adata.X)
        # Flatten the (n_obs, 1) expression column before storing it as a DataFrame column.
        slices_data[label] = np.asarray(adata[:, label].X).reshape(-1)
    else:
        raise ValueError("`label` is not a valid column names or gene name.")

    # Arrangement of subplots.
    slices_id = slices_data[slices_key].unique().tolist()
    nrows = math.ceil(len(slices_id) / ncols)

    # Aspect ratio of each subplot (np.ptp: ndarray.ptp is gone in NumPy 2).
    ptp_vec = np.ptp(slices_data[["x", "y"]].to_numpy(), axis=0)
    aspect_ratio = ptp_vec[0] / ptp_vec[1]
    ax_height = 2 if nrows == 1 and ax_height == 1 else ax_height
    axsize = (ax_height * aspect_ratio, ax_height * 2)

    sns.set_theme(
        context="paper",
        style="white",
        font="Arial",
        font_scale=1,
        rc={
            "font.size": font_size,
            "font.family": ["sans-serif"],
            "font.sans-serif": ["Arial", "sans-serif", "Helvetica", "DejaVu Sans", "Bitstream Vera Sans"],
        },
    )
    g = sns.FacetGrid(
        slices_data.copy(),
        col=slices_key,
        hue=label,
        palette=palette,
        sharex=True,
        sharey=True,
        height=axsize[1] * nrows,
        aspect=aspect_ratio,
        col_wrap=ncols,
        despine=False,
    )

    # Estimate a point size from the smallest inter-point distance (at most 1000 sampled points per slice).
    if point_size is None:
        min_dist_list = []
        for _, data in slices_data.groupby(by=slices_key):
            sample_num = min(len(data), 1000)
            min_dist_list.append(compute_smallest_distance(coords=data[["x", "y"]].values, sample_num=sample_num))
        point_size = min(min_dist_list) * axsize[0] / ptp_vec[0] * dpi
        point_size = point_size**2 * ncols * nrows

    g.map_dataframe(sns.scatterplot, x="x", y="y", alpha=alpha, color=color, s=point_size, legend="brief", **kwargs)

    # Numeric labels get a colorbar next to the last panel of the first row; others a legend.
    label_values = slices_data[label].values
    if label_values.dtype in ["float16", "float32", "float64", "int16", "int32", "int64"]:
        from mpl_toolkits.axes_grid1.inset_locator import inset_axes

        ax = g.facet_axis(row_i=0, col_j=ncols - 1)
        norm = mpl.colors.Normalize(vmin=None, vmax=None)
        mappable = mpl.cm.ScalarMappable(norm=norm, cmap=palette)
        mappable.set_array(label_values)
        plt.colorbar(
            mappable,
            cax=inset_axes(
                ax,
                width="12%",
                height="100%",
                loc="center left",
                bbox_to_anchor=(1.02, 0.0, 0.5, 1.0),
                bbox_transform=ax.transAxes,
                borderpad=1.85,
            ),
            ax=ax,
            orientation="vertical",
            alpha=alpha,
            label=label,
        )
    else:
        g.add_legend()

    plt.tight_layout()
    return save_return_show_fig_utils(
        save_show_or_return=save_show_or_return,
        show_legend=show_legend,
        background="white",
        prefix="multi_slices",
        save_kwargs=save_kwargs,
        total_panels=len(slices_id),
        fig=g,
        axes=g,
        return_all=False,
        return_all_list=None,
    )
# TODO: Add docstring, add multi slices plot, legend scatter plot should keep the same size of the text # def plot_clusters( # adata: Union[AnnData, np.ndarray], # spatial_key: str = 'spatial', # label_key: Union[str, List[str], pd.DataFrame] = 'clusters', # ax: Optional[mpl.axes.Axes] = None, # point_size: float = 10, # n_sampling: int = -1, # palette: Optional[dict] = None, # title: str = '', # title_kwargs: Optional[dict] = None, # # title_fontsize = 16, # show_legend: bool = True, # legend_kwargs: Optional[dict] = None, # axis_off: bool = True, # axis_kwargs: Optional[dict] = None, # ticks_off: bool = True, # color=None, # x_min = None, # x_max = None, # y_min = None, # y_max = None, # **kwargs, # ): # """ # Plots spatial data with cluster labels stored in adata.obs[col] # """ # # get spatial coords and labels # if isinstance(adata, AnnData): # if spatial_key in adata.obsm.keys(): # spatial_coords = adata.obsm[spatial_key].copy() # else: # raise ValueError(f"adata.obsm['{spatial_key}'] does not exist.") # elif isinstance(adata, np.ndarray): # if adata.shape[1] > 1: # spatial_coords = adata.copy()[:,:2] # else: # raise ValueError("the input spatial coordinates must have at least 2 columns.") # if isinstance(adata, AnnData): # if label_key in adata.obs.keys(): # label = adata.obs[label_key].copy() # else: # raise ValueError(f"adata.obs['{label_key}'] does not exist.") # elif isinstance(label_key, list): # label = pd.Series(label_key) # elif isinstance(label_key, pd.DataFrame): # label = label_key # else: # raise ValueError(f"label_key must be a string, list, or DataFrame.") # assert spatial_coords.shape[0] == label.shape[0], "The number of spatial coordinates and labels must be the same." 
# # downsampling if n_sampling is set # sampling_idx = np.random.choice(spatial_coords.shape[0], n_sampling, replace=False) if n_sampling > 0 and n_sampling < spatial_coords.shape[0] else np.arange(spatial_coords.shape[0]) # x = spatial_coords[sampling_idx, 0] # y = spatial_coords[sampling_idx, 1] # label = label[sampling_idx] # # get unique labels # unique_labels = np.unique(label) # # get color palette if not provided # if palette is None: # n_colors = len(unique_labels) # palette = sns.color_palette("tab20", n_colors) # # plot # scatterplot_kwargs = { # 'x': x, # 'y': y, # 'hue': label, # 'palette': palette, # 'ax': ax, # 's': point_size, # 'legend': show_legend, # } # scatterplot_kwargs.update(kwargs) # sns.scatterplot(**scatterplot_kwargs) # # adjust the legend # if show_legend: # default_legend_kwargs = { # 'loc': 'center left', # 'bbox_to_anchor': (1, 0.5), # 'prop': {'family': 'Arial', 'size': 10}, # 'fancybox': False, # 'edgecolor': 'black', # 'framealpha': 1, # 'columnspacing': 0.5, # 'handletextpad': 0.1, # } # if legend_kwargs: # default_legend_kwargs.update(legend_kwargs) # ax.legend(**default_legend_kwargs) # # set axis limits # if x_min is not None and x_max is not None: # ax.set_xlim(x_min, x_max) # if y_min is not None and y_max is not None: # ax.set_ylim(y_min, y_max) # # set other axis properties # if axis_off: # ax.axis('off') # if ticks_off: # ax.set_xticks([]) # ax.set_yticks([]) # if title_kwargs: # default_title_kwargs = { # 'label': title, # 'fontsize': 16, # } # default_title_kwargs.update(title_kwargs) # ax.set_title(title_kwargs) # ax.set_aspect('equal') # def _spatial_scatter( # spatial_x: np.ndarray, # spatial_y: np.ndarray, # color: Union[str, np.ndarray], # point_size: Union[float, np.ndarray], # alpha: Union[float, np.ndarray], # edgecolors: Optional[Union[str, np.ndarray]] = None, # linewidths: Optional[Union[float, np.ndarray]] = None, # marker: Optional[str] = "o", # palette: Optional[Union[str, dict]] = None, # ax: 
Optional[mpl.axes.Axes] = None, # ): # if ax is None: # _, ax = plt.subplots() # if isinstance(color, str): # color = np.array([color] * len(spatial_x)) # if isinstance(point_size, float): # point_size = np.array([point_size] * len(spatial_x)) # if isinstance(alpha, float): # alpha = np.array([alpha] * len(spatial_x)) # if edgecolors is None: # edgecolors = "none" # if linewidths is None: # linewidths = 0 # if palette is not None: # if isinstance(palette, str): # palette = sns.color_palette(palette, n_colors=len(np.unique(color))) # elif isinstance(palette, dict): # palette = [palette.get(c, "gray") for c in np.unique(color)] # else: # raise ValueError("`palette` must be a string or a dictionary.") # color = [palette[np.where(np.unique(color) == c)[0][0]] for c in color] # ax.scatter( # spatial_x, # spatial_y, # c=color, # s=point_size, # alpha=alpha, # edgecolors=edgecolors, # linewidths=linewidths, # marker=marker, # ) # ax.set_aspect('equal') # return ax
def _agenerate_palette(*labels, cmap="tab20"):
    """Build a ``{label: color}`` dict covering every unique label.

    Args:
        *labels: One or more label sequences (e.g. one per slice). Several sequences are
            concatenated before the unique values are taken.
        cmap: Palette name/spec accepted by ``seaborn.color_palette``.

    Returns:
        dict mapping each unique label (in ``np.unique`` order) to the colour at the same
        position in ``seaborn.color_palette(cmap, n_unique)``.

    Raises:
        ValueError: If no label sequence is given.
    """
    if len(labels) == 0:
        raise ValueError("No labels provided.")
    all_labels = labels[0] if len(labels) == 1 else np.concatenate(labels)
    unique_labels = np.unique(all_labels)
    if len(unique_labels) == 0:
        return {}
    # Build the colour list once; it was previously rebuilt for every label (O(n^2)).
    colors = sns.color_palette(cmap, len(unique_labels))
    return dict(zip(unique_labels, colors))
[docs]def _compute_smallest_distance(spatial_coord1, spatial_coord2, direction="x", scale_factor=1.1): if direction == "x": spatial_coord1_max = np.max(spatial_coord1[:, 0]) spatial_coord2_min = -np.min(spatial_coord2[:, 0]) interval = (spatial_coord1_max + spatial_coord2_min) * scale_factor elif direction == "y": spatial_coord1_max = np.max(spatial_coord1[:, 1]) spatial_coord2_min = -np.min(spatial_coord2[:, 1]) interval = (spatial_coord1_max + spatial_coord2_min) * scale_factor else: raise ValueError("`direction` must be 'x' or 'y'.") return interval
def transform_by_min_max(x, _min, _max, interval=0.1):
    """Shift by ``_min``, divide by ``_max``, then squeeze into ``[interval, 1 - interval]``.

    ``_min``/``_max`` are typically the pair returned by ``get_min_max`` (``_max`` being the
    range after the shift). Returns a new array; ``x`` is not modified.
    """
    usable_fraction = 1 - 2 * interval
    normalized = (x - _min) / _max
    return normalized * usable_fraction + interval
def get_min_max(x):
    """Return ``(column minima, column ranges)`` of a 2-D array.

    The second value is the maximum after subtracting the minima, i.e. ``max - min`` per column.
    """
    lower = x.min(0)
    extent = (x - lower).max(0)
    return lower, extent
def transform_H(x, H, z_shift=0):
    """Apply the 3x3 homography ``H`` to 2-D points and shift the resulting y by ``z_shift``.

    Args:
        x: ``(n, 2)`` array of points.
        H: ``(3, 3)`` homography matrix.
        z_shift: Offset added to the second output coordinate.

    Returns:
        ``(n, 2)`` array of projected points.
    """
    n_points = x.shape[0]
    homogeneous = np.hstack([x, np.ones((n_points, 1))])
    projected = (H @ homogeneous.T).T
    # Perspective divide by the homogeneous coordinate.
    projected = projected / projected[:, 2:]
    projected[:, 1] += z_shift
    return projected[:, :2]
def get_H(h=0.5, w=0.2):
    """Estimate, with OpenCV, the homography mapping the unit square's corners onto a trapezoid.

    Corner correspondences: (0, 0) -> (w, h), (0, 1) -> (1 - w, h), (1, 0) -> (0, 0),
    (1, 1) -> (1, 0). Source points with x = 0 therefore land on the narrower edge at height
    ``h``, and points with x = 1 land on the full-width edge at height 0.

    Returns:
        The homography matrix returned by ``cv2.findHomography``.
    """
    source_corners = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
    target_corners = np.array([[w, h], [1 - w, h], [0, 0], [1, 0]])
    import cv2

    homography, _mask = cv2.findHomography(srcPoints=source_corners, dstPoints=target_corners)
    return homography