import math
from typing import List, Optional, Union
import matplotlib as mpl
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from anndata import AnnData
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import numpy as np
from scipy.sparse import issparse
from ...tools.cluster.utils import integrate, to_dense_matrix
from ...tools.utils import compute_smallest_distance
from .utils import save_return_show_fig_utils
def slices_2d(
    slices: Union[AnnData, List[AnnData]],
    slices_key: Optional[Union[bool, str]] = None,
    label_key: Optional[str] = None,
    label_type: Optional[str] = None,
    spatial_key: str = "spatial",
    point_size: Optional[float] = None,
    n_sampling: int = -1,
    palette: Optional[dict] = None,
    ncols: int = 4,
    title: str = "",
    title_kwargs: Optional[dict] = None,
    show_legend: bool = True,
    legend_kwargs: Optional[dict] = None,
    axis_off: bool = False,
    axis_kwargs: Optional[dict] = None,
    ticks_off: bool = True,
    x_min=None,
    x_max=None,
    y_min=None,
    y_max=None,
    height: float = 2,
    alpha: float = 1.0,  # TODO: alpha to be a key in adata
    cmap="tab20",
    center_coordinate: bool = False,
    gridspec_kws: Optional[dict] = None,
    return_palette: bool = False,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = None,
    sort_values: bool = True,
    sort_ascending: bool = True,
    **kwargs,
):
    """Plot multiple 2D spatial transcriptomics slices for both categorical clusters and scalar values.

    Parameters
    ----------
    slices
        A single :class:`~anndata.AnnData` object **or** a list of ``AnnData``,
        each element represents one spatial slice.
    slices_key
        Column in ``.obs`` holding the slice identifier. If the column exists it
        must contain exactly one unique value per slice (otherwise a ``ValueError``
        is raised) and that value is used in the panel title; if it does not exist
        the slice's position in ``slices`` is used. Pass ``False`` to suppress the
        per-panel titles.
    label_key
        Column name in ``.obs`` **or** gene name in ``.var_names`` whose values
        will be visualised.
    label_type
        Force interpretation of ``label_key`` as ``'cluster'`` or ``'scalar'``.
        If *None* the type is inferred from the first slice: NumPy integer and
        floating dtypes are treated as scalars, everything else (strings,
        categoricals, booleans, ...) as clusters.
    spatial_key
        Key in ``.obsm`` that stores the spatial coordinates (default
        ``'spatial'``).
    point_size
        Size of scatter points in points². If *None* an empirical value is
        computed from ``height``, the aspect ratio and the mean number of points
        per slice.
    n_sampling
        Down-sample each slice to *n_sampling* points (without replacement).
        Values ``<= 0`` (e.g. ``-1``) disable down-sampling.
    palette
        Mapping from category → color for cluster plots (generated from ``cmap``
        when *None*). For scalar plots a colormap name/object may be given; when
        *None* or a dict, ``cmap`` is used instead.
    ncols
        Number of subplot columns. The value is clipped to the number of slices.
    title, title_kwargs
        Common title string and optional ``Axes.set_title`` keyword arguments.
        An empty ``title`` yields ``"Slice <id>"`` per panel.
    show_legend, legend_kwargs
        Whether to draw a legend (clusters) or colour-bar (scalar) and optional
        keyword overrides for the cluster legend.
    axis_off, axis_kwargs
        Whether to hide the axis frame/ticks. ``axis_kwargs`` is currently unused
        and only accepted for API compatibility.
    ticks_off
        If *True* x/y-ticks are removed.
    x_min, x_max, y_min, y_max
        Explicit axis limits. When *None* the limits are derived from the
        data and padded by 5 % so that points at the border are still visible.
    height
        Height of **one** subplot in inches.
    alpha
        Global alpha (transparency) for all points.
    cmap
        Matplotlib colormap name used for scalar plots and to generate the
        cluster palette when ``palette`` is *None*.
    center_coordinate
        If *True* centre the coordinates of every slice at (0, 0).
    gridspec_kws
        Extra keyword arguments passed to ``plt.subplots`` via the
        ``gridspec_kw`` parameter (they override the defaults).
    return_palette
        If *True* the colour mapping used for plotting is returned alongside
        the result.
    save_show_or_return
        One of ``'save'``, ``'show'``, ``'return'`` or ``'both'`` / ``'all'`` –
        forwarded to :func:`save_return_show_fig_utils`.
    save_kwargs
        Extra keyword arguments for saving.
    sort_values
        Only relevant for scalar plots. If *True* points are drawn in sorted
        order of their value so that the sort direction controls which points
        end up on top.
    sort_ascending
        Sort direction used when ``sort_values`` is *True*. Ascending draws the
        largest values last, i.e. on top.
    **kwargs
        Additional keyword arguments forwarded to :meth:`matplotlib.axes.Axes.scatter`.

    Returns
    -------
    Whatever :func:`save_return_show_fig_utils` returns, or a tuple of that
    value and the palette when ``return_palette`` is *True*.

    Raises
    ------
    ValueError
        If ``slices`` is empty, a key is missing, a slice has several ids in
        ``slices_key`` or ``label_type`` is not ``'cluster'``/``'scalar'``.
    """

    def _take_rows(values, idx):
        # pandas objects must be indexed positionally through ``.iloc``: the
        # positional fallback of ``series[int_array]`` is deprecated in pandas.
        return values.iloc[idx] if isinstance(values, pd.Series) else values[idx]

    # Check slices object.
    if isinstance(slices, AnnData):
        slices = [slices]
    if len(slices) == 0:
        raise ValueError("`slices` must contain at least one AnnData object.")

    # Collect spatial coordinates, labels and slice identifiers.
    spatial_coords, labels, slice_ids = [], [], []
    for i, s in enumerate(slices):
        if spatial_key not in s.obsm.keys():
            raise ValueError(f"adata.obsm['{spatial_key}'] does not exist.")
        spatial_coords.append(s.obsm[spatial_key].copy())

        if label_key in s.obs.keys():
            labels.append(s.obs[label_key].copy())
        elif label_key in s.var_names:
            X_data = s[:, label_key].X
            # ``toarray()`` works for every scipy sparse type (``.A`` was removed in
            # SciPy 1.14); ``ravel`` keeps a 1-D result even for a single observation.
            dense = X_data.toarray() if issparse(X_data) else np.array(X_data)
            labels.append(dense.ravel())
        else:
            raise ValueError(f"adata.obs['{label_key}'] or adata.var['{label_key}'] does not exist.")

        if (slices_key is not None) and (slices_key in s.obs.keys()):
            unique_id = np.unique(s.obs[slices_key])
            if len(unique_id) != 1:
                raise ValueError(f"adata.obs['{slices_key}'] must have only one unique value.")
            slice_ids.append(unique_id[0])
        else:
            slice_ids.append(str(i))
        assert (
            spatial_coords[-1].shape[0] == labels[-1].shape[0]
        ), "The number of spatial coordinates and labels must be the same. Please check the data."

    # Infer the label type from the first slice when it is not given.
    if label_type is None:
        dtype = getattr(labels[0], "dtype", None)
        is_numeric = isinstance(dtype, np.dtype) and (
            np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.floating)
        )
        label_type = "scalar" if is_numeric else "cluster"
    elif label_type not in ("cluster", "scalar"):
        raise ValueError(f"`label_type` must be 'cluster' or 'scalar', got {label_type!r}.")

    # Down-sample each slice independently.
    for i, coords in enumerate(spatial_coords):
        n_obs = coords.shape[0]
        if 0 < n_sampling < n_obs:
            sampling_idx = np.random.choice(n_obs, n_sampling, replace=False)
            spatial_coords[i] = coords[sampling_idx]
            labels[i] = _take_rows(labels[i], sampling_idx)

    if center_coordinate:
        spatial_coords = [coords - np.mean(coords, axis=0) for coords in spatial_coords]

    # Global spatial bounds so that all panels share the same scale.
    all_coords = np.vstack(spatial_coords)
    x_data_min, x_data_max = all_coords[:, 0].min(), all_coords[:, 0].max()
    y_data_min, y_data_max = all_coords[:, 1].min(), all_coords[:, 1].max()
    # Pad the data bounds by 5% so that edge points are not cut off.
    expand_factor = 0.05
    x_pad = (x_data_max - x_data_min) * expand_factor
    y_pad = (y_data_max - y_data_min) * expand_factor
    # Explicit limits override the padded bounds.
    x_global_min = x_data_min - x_pad if x_min is None else x_min
    x_global_max = x_data_max + x_pad if x_max is None else x_max
    y_global_min = y_data_min - y_pad if y_min is None else y_min
    y_global_max = y_data_max + y_pad if y_max is None else y_max

    # ``ndarray.ptp`` was removed in NumPy 2.0; use the function form.
    ptp_vec = np.ptp(all_coords, axis=0)
    # Degenerate (zero width or height) data falls back to square panels.
    aspect_ratio = ptp_vec[0] / ptp_vec[1] if ptp_vec[0] > 0 and ptp_vec[1] > 0 else 1.0

    # Grid arrangement: never more columns than slices.
    ncols = max(1, min(ncols, len(slices)))
    nrows = math.ceil(len(slices) / ncols)

    # Colours: scalar plots use a colormap, cluster plots a category -> colour dict.
    if label_type == "scalar":
        if palette is None or isinstance(palette, dict):
            palette = cmap
    elif palette is None:
        palette = _agenerate_palette(*labels, cmap=cmap)

    if point_size is None:
        mean_points = sum(coords.shape[0] for coords in spatial_coords) / len(slices)
        point_size = 500 * height**2 * aspect_ratio / max(mean_points, 1)

    # Set plot theme
    sns.set_theme(
        context="paper",
        style="white",
        font="Arial",
        font_scale=1,
        rc={
            "font.family": ["sans-serif"],
            "font.sans-serif": ["Arial", "sans-serif", "Helvetica", "DejaVu Sans", "Bitstream Vera Sans"],
        },
    )

    _gridspec_kws = {"wspace": 0.1, "hspace": 0.2}
    if slices_key is False:
        _gridspec_kws["hspace"] = _gridspec_kws["wspace"] * aspect_ratio
    if gridspec_kws is not None:
        _gridspec_kws.update(gridspec_kws)

    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * height * aspect_ratio, nrows * height),
        gridspec_kw=_gridspec_kws,
    )
    # Flat, indexable sequence of axes (a lone Axes is wrapped in a list).
    axes = [axes] if nrows == 1 and ncols == 1 else axes.flatten()

    # One shared normalisation so that colours are comparable across slices.
    scalar_norm = None
    if label_type == "scalar":
        all_scalar_vals = np.concatenate([np.asarray(vals) for vals in labels])
        scalar_norm = mpl.colors.Normalize(vmin=all_scalar_vals.min(), vmax=all_scalar_vals.max())

    for i in range(len(slices)):
        ax = axes[i]
        coords = spatial_coords[i]
        if label_type == "cluster":
            label_arr = np.asarray(labels[i])
            for label in np.unique(label_arr):
                mask = label_arr == label
                if not np.any(mask):
                    continue
                ax.scatter(
                    coords[mask, 0],
                    coords[mask, 1],
                    c=[palette[label]],
                    s=point_size,
                    alpha=alpha,
                    label=label,
                    edgecolors="none",
                    **kwargs,
                )
        else:
            cur_vals = np.asarray(labels[i])
            cur_coords = coords
            if sort_values:
                # Later points are drawn on top, so the order controls visibility.
                sort_idx = np.argsort(cur_vals)
                if not sort_ascending:
                    sort_idx = sort_idx[::-1]
                cur_vals = cur_vals[sort_idx]
                cur_coords = coords[sort_idx]
            ax.scatter(
                cur_coords[:, 0],
                cur_coords[:, 1],
                c=cur_vals,
                s=point_size,
                alpha=alpha,
                cmap=palette,
                norm=scalar_norm,
                edgecolors="none",
                **kwargs,
            )

        ax.set_xlim(x_global_min, x_global_max)
        ax.set_ylim(y_global_min, y_global_max)
        ax.set_aspect("equal")
        if axis_off:
            ax.axis("off")
        if ticks_off:
            ax.set_xticks([])
            ax.set_yticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")

        if slices_key is False:
            ax.set_title("")
        else:
            title_text = f"Slice {slice_ids[i]}" if title == "" else title
            ax.set_title(title_text, **(title_kwargs or {}))

    # Hide the unused trailing panels of the grid.
    for unused_ax in axes[len(slices):]:
        unused_ax.axis("off")
        unused_ax.set_xticks([])
        unused_ax.set_yticks([])

    if show_legend:
        if label_type == "cluster":
            _legend_kwargs = {
                "loc": "center left",
                "bbox_to_anchor": (1, 0.5),
                "prop": {"family": "Arial", "size": 10},
                "fancybox": False,
                "edgecolor": "black",
                "framealpha": 1,
                "columnspacing": 0.8,
                "handletextpad": 0.5,
                "frameon": True,
            }
            if legend_kwargs:
                _legend_kwargs.update(legend_kwargs)
            legend_elements = [
                mpl.lines.Line2D(
                    [0], [0], marker="o", color="w", label=k, markerfacecolor=v, markersize=6, markeredgecolor="k"
                )
                for k, v in palette.items()
            ]
            fig.legend(handles=legend_elements, **_legend_kwargs)
        else:
            from mpl_toolkits.axes_grid1.inset_locator import inset_axes

            # Global colorbar shared by all panels, based on the common normalisation.
            sm = mpl.cm.ScalarMappable(cmap=palette, norm=scalar_norm)
            sm.set_array([])
            # Placed in an inset to the right of the last plotted panel.
            last_ax = axes[len(slices) - 1]
            cbar = fig.colorbar(
                sm,
                ax=list(axes),  # ``axes`` may be a plain list (single panel)
                orientation="vertical",
                fraction=0.05,
                pad=0.07,
                shrink=0.5,
                cax=inset_axes(
                    last_ax,
                    width="15%",
                    height="75%",
                    loc="center left",
                    bbox_to_anchor=(1.02, 0.0, 0.5, 1.0),
                    bbox_transform=last_ax.transAxes,
                ),
            )
            cbar.set_label(label_key if label_key else "Value", fontsize=10)

    plt.tight_layout()
    result = save_return_show_fig_utils(
        save_show_or_return=save_show_or_return,
        show_legend=show_legend,
        background="white",
        prefix="multi_slices_scatter",
        save_kwargs=save_kwargs,
        total_panels=len(slice_ids),
        fig=fig,
        axes=axes,
        return_all=False,
        return_all_list=None,
    )
    return (result, palette) if return_palette else result
def overlay_slices_2d(
    slices: Union[AnnData, List[AnnData]],
    slices_key: Optional[Union[bool, str]] = None,
    label_key: Optional[str] = None,
    overlay_type: Literal["forward", "backward", "both"] = "both",
    spatial_key: str = "spatial",
    point_size: Optional[float] = None,
    n_sampling: int = -1,
    palette: Optional[dict] = None,
    ncols: int = 4,
    title: str = "",
    title_kwargs: Optional[dict] = None,
    show_legend: bool = True,
    legend_kwargs: Optional[dict] = None,
    axis_off: bool = False,
    axis_kwargs: Optional[dict] = None,
    ticks_off: bool = True,
    x_min=None,
    x_max=None,
    y_min=None,
    y_max=None,
    height: float = 2,
    alpha: float = 1.0,  # TODO: alpha to be a key in adata
    cmap="tab20",
    center_coordinate: bool = False,
    gridspec_kws: Optional[dict] = None,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = None,
    **kwargs,
):
    """Overlay each 2D slice with its neighbouring slice(s), one grid panel per slice.

    Panel ``i`` shows slice ``i`` (overlay role ``"current"``) and, depending on
    ``overlay_type``, slice ``i - 1`` (role ``"forward"``), slice ``i + 1`` (role
    ``"backward"``) or both. With ``"forward"`` the first slice gets no panel;
    with ``"backward"`` the last slice gets no panel.

    Parameters
    ----------
    slices
        A single ``AnnData`` or a list of ``AnnData``, one per slice.
    slices_key
        ``.obs`` column with exactly one unique value per slice, used in the panel
        titles; if absent the slice position is used. ``False`` titles every panel
        with ``title`` instead.
    label_key
        ``.obs`` column (treated as clusters) or gene in ``.var_names`` (treated
        as scalar) used as the colour ``hue``. If *None*, points are coloured by
        overlay role (current: red, forward: green, backward: blue).
    overlay_type
        ``"forward"``, ``"backward"`` or ``"both"``.
    spatial_key
        Key in ``.obsm`` with the spatial coordinates.
    point_size
        Scatter size; if *None* computed from ``height``, the aspect ratio and the
        mean number of plotted rows per slice.
    n_sampling
        Down-sample each slice to this many points; ``<= 0`` disables it.
    palette
        Category -> colour dict for cluster labels (generated from ``cmap`` when
        *None*). Scalar labels always use ``cmap``. Ignored when ``label_key`` is
        *None* (fixed overlay colours are used).
    ncols
        Number of panel columns.
    title, title_kwargs
        Title used when ``slices_key is False``; ``title_kwargs`` are passed as
        keyword arguments to ``Axes.set_title`` for the ``"Slice <id>"`` titles.
    show_legend, legend_kwargs
        Draw a legend (clusters) or colorbar (scalar); ``legend_kwargs`` override
        the cluster legend settings.
    axis_off, axis_kwargs
        Hide axes; ``axis_kwargs`` is currently unused.
    ticks_off
        Remove x/y ticks.
    x_min, x_max, y_min, y_max
        Axis limits, applied per axis only when both bounds are given.
    height, alpha, cmap
        Panel height in inches, point transparency and colormap.
    center_coordinate
        Centre every slice's coordinates at (0, 0).
    gridspec_kws
        Overrides for the FacetGrid gridspec. When ``slices_key is False`` and no
        ``hspace`` is given, ``hspace`` is set to ``wspace * aspect_ratio``.
    save_show_or_return, save_kwargs
        Forwarded to :func:`save_return_show_fig_utils`.
    **kwargs
        Forwarded to :func:`seaborn.scatterplot` (override the defaults).

    Returns
    -------
    Whatever :func:`save_return_show_fig_utils` returns for the FacetGrid.

    Raises
    ------
    ValueError
        For an invalid ``overlay_type``, empty input, missing keys, a slice with
        several ids, or when no panel would be drawn.
    """
    if overlay_type not in ("forward", "backward", "both"):
        raise ValueError(f"`overlay_type` must be 'forward', 'backward' or 'both', got {overlay_type!r}.")
    if isinstance(slices, AnnData):
        slices = [slices]
    if len(slices) == 0:
        raise ValueError("`slices` must contain at least one AnnData object.")

    # Collect spatial coords, (optional) labels and slice identifiers.
    spatial_coords, labels, slice_ids = [], [], []
    label_type = "cluster"
    for i, s in enumerate(slices):
        if spatial_key not in s.obsm.keys():
            raise ValueError(f"adata.obsm['{spatial_key}'] does not exist.")
        spatial_coords.append(s.obsm[spatial_key].copy())
        if label_key is not None:
            if label_key in s.obs.keys():
                labels.append(s.obs[label_key].copy())
                label_type = "cluster"
            elif label_key in s.var_names:
                X_data = s[:, label_key].X
                # ``.A`` was removed in SciPy 1.14; ``toarray`` works for all sparse types.
                dense = X_data.toarray() if issparse(X_data) else np.array(X_data)
                labels.append(dense.ravel())
                label_type = "scalar"
            else:
                raise ValueError(f"adata.obs['{label_key}'] or adata.var['{label_key}'] does not exist.")
            assert (
                spatial_coords[-1].shape[0] == labels[-1].shape[0]
            ), "The number of spatial coordinates and labels must be the same. Please check the data."
        if (slices_key is not None) and (slices_key in s.obs.keys()):
            unique_id = np.unique(s.obs[slices_key])
            if len(unique_id) != 1:
                raise ValueError(f"adata.obs['{slices_key}'] must have only one unique value.")
            slice_ids.append(unique_id[0])
        else:
            slice_ids.append(str(i))

    # Down-sample each slice independently (positional indexing for pandas labels).
    for i, coords in enumerate(spatial_coords):
        n_obs = coords.shape[0]
        if 0 < n_sampling < n_obs:
            sampling_idx = np.random.choice(n_obs, n_sampling, replace=False)
            spatial_coords[i] = coords[sampling_idx]
            if label_key is not None:
                lab = labels[i]
                labels[i] = lab.iloc[sampling_idx] if isinstance(lab, pd.Series) else lab[sampling_idx]

    if center_coordinate:
        spatial_coords = [coords - np.mean(coords, axis=0) for coords in spatial_coords]

    show_forward = overlay_type in ("forward", "both")
    show_backward = overlay_type in ("backward", "both")
    n_slices = len(slices)

    def _layer(src, overlay_id, panel):
        """Plotting rows for slice ``src`` drawn into panel ``panel`` with role ``overlay_id``."""
        return pd.DataFrame(
            {
                "x": spatial_coords[src][:, 0],
                "y": spatial_coords[src][:, 1],
                "label": labels[src] if label_key is not None else "unknow",
                "overlay_id": overlay_id,
                "slice_id": slice_ids[src],
                "col": panel % ncols,
                "row": panel // ncols,
            }
        )

    frames, panels = [], []
    for i in range(n_slices):
        has_prev = show_forward and i > 0
        has_next = show_backward and i < n_slices - 1
        if not (overlay_type == "both" or has_prev or has_next):
            continue
        # The slice's own points are added exactly once per panel.
        panels.append(i)
        frames.append(_layer(i, "current", i))
        if has_prev:
            frames.append(_layer(i - 1, "forward", i))
        if has_next:
            frames.append(_layer(i + 1, "backward", i))
    if not frames:
        raise ValueError(f"Nothing to plot: overlay_type '{overlay_type}' needs at least two slices.")
    # A single concat is linear and avoids dtype inference from an empty frame.
    slices_spatial_data = pd.concat(frames, axis=0, ignore_index=True)

    # ``ndarray.ptp`` was removed in NumPy 2.0; degenerate extents fall back to 1.
    ptp_vec = np.ptp(slices_spatial_data[["x", "y"]].to_numpy(dtype=float), axis=0)
    aspect_ratio = ptp_vec[0] / ptp_vec[1] if ptp_vec[0] > 0 and ptp_vec[1] > 0 else 1.0

    sns.set_theme(
        context="paper",
        style="white",
        font="Arial",
        font_scale=1,
        rc={
            "font.family": ["sans-serif"],
            "font.sans-serif": ["Arial", "sans-serif", "Helvetica", "DejaVu Sans", "Bitstream Vera Sans"],
        },
    )

    if label_key is None:
        palette = {"current": "red"}
        if show_forward:
            palette["forward"] = "green"
        if show_backward:
            palette["backward"] = "blue"
    elif label_type == "scalar":
        palette = cmap
    elif palette is None:
        # A user-supplied cluster palette is kept as is.
        palette = _agenerate_palette(*labels, cmap=cmap)

    # User gridspec settings win, as in ``slices_2d``.
    _gridspec_kws = {"wspace": 0.1, "hspace": 0.2}
    if gridspec_kws is not None:
        _gridspec_kws.update(gridspec_kws)
    if slices_key is False and "hspace" not in (gridspec_kws or {}):
        _gridspec_kws["hspace"] = _gridspec_kws["wspace"] * aspect_ratio

    if point_size is None:
        point_size = 500 * height**2 * aspect_ratio / (slices_spatial_data.shape[0] / n_slices)

    g = sns.FacetGrid(
        slices_spatial_data,
        col="col",
        row="row",
        hue="label" if label_key is not None else "overlay_id",
        palette=palette,
        sharex=True,
        sharey=True,
        height=height,
        aspect=aspect_ratio,
        despine=False,
        gridspec_kws=_gridspec_kws,
        xlim=(x_min, x_max) if x_min is not None and x_max is not None else None,
        ylim=(y_min, y_max) if y_min is not None and y_max is not None else None,
    )
    scatterplot_kwargs = {"x": "x", "y": "y", "alpha": alpha, "s": point_size, "legend": False, "edgecolor": None}
    scatterplot_kwargs.update(kwargs)
    g.map_dataframe(sns.scatterplot, **scatterplot_kwargs)

    # Facet keys are (row, col); map them back to the slice index of the panel.
    plotted_panels = set(panels)
    for (row_val, col_val), ax in g.axes_dict.items():
        panel = int(row_val) * ncols + int(col_val)
        if panel in plotted_panels:
            if slices_key is False:
                ax.set_title(title)
            else:
                ax.set_title(f"Slice {slice_ids[panel]}", **(title_kwargs or {}))
        else:
            ax.set_title("")
            ax.set_xticks([])
            ax.set_yticks([])
            ax.axis("off")
        ax.set_aspect("equal")
        if axis_off:
            ax.axis("off")
        if ticks_off:
            ax.set_xticks([])
            ax.set_yticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")

    if show_legend:
        if label_type == "cluster":
            _legend_kwargs = {
                "loc": "upper center",
                "bbox_to_anchor": (0.5, 0),
                "prop": {"family": "Arial", "size": 10},
                "fancybox": False,
                "edgecolor": "black",
                "framealpha": 1,
                "columnspacing": 0.8,
                "handletextpad": 0.5,
                "ncol": 8,
                "borderaxespad": -4,
                "frameon": False,
            }
            if legend_kwargs:
                _legend_kwargs.update(legend_kwargs)
            legend_elements = [
                mpl.lines.Line2D(
                    [0], [0], marker="o", color="w", label=k, markerfacecolor=v, markersize=6, markeredgecolor="k"
                )
                for k, v in palette.items()
            ]
            g.figure.legend(handles=legend_elements, **_legend_kwargs)
        else:
            from mpl_toolkits.axes_grid1.inset_locator import inset_axes

            # NOTE(review): FacetGrid appears to map ``hue`` to discrete palette colours,
            # so the points may not follow this continuous colorbar -- TODO unify.
            mappable = mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=None, vmax=None), cmap=palette)
            mappable.set_array(slices_spatial_data["label"].values)
            anchor_ax = list(g.axes_dict.values())[-1]
            g.figure.colorbar(
                mappable,
                use_gridspec=False,
                shrink=0.5,
                cax=inset_axes(
                    anchor_ax,
                    width="15%",
                    height="75%",
                    loc="center left",
                    bbox_to_anchor=(1.02, 0.0, 0.5, 1.0),
                    bbox_transform=anchor_ax.transAxes,
                ),
            )

    return save_return_show_fig_utils(
        save_show_or_return=save_show_or_return,
        show_legend=show_legend,
        background="white",
        prefix="multi_slices",
        save_kwargs=save_kwargs,
        total_panels=len(slice_ids),
        fig=g,
        axes=g,
        return_all=False,
        return_all_list=None,
    )
def optimization_animation(
    aligned_slices: List[AnnData],
    label_key: Optional[str] = None,
    spatial_key: str = "spatial",
    key_added: str = "align_spatial",
    iter_key_added: Optional[str] = "iter_spatial",
    filename: Optional[str] = "Visualization2D",
    fps: int = 10,
    stepsize: int = 10,
    cmap="Set1",
    palette: Optional[dict] = None,
    point_size: Optional[float] = None,
    n_sampling: int = -1,
):
    """Save a GIF showing the second slice's coordinates over optimisation iterations.

    The first slice is drawn once, statically, from ``.obsm[spatial_key]``. For
    every ``stepsize``-th iteration a frame draws the second slice at
    ``.uns[iter_key_added][key_added][i]`` and a title with
    ``.uns[iter_key_added]["sigma2"][i]``.

    Parameters
    ----------
    aligned_slices
        Exactly two ``AnnData`` objects: the static slice and the moving slice.
    label_key
        ``.obs`` column used to colour both slices through ``palette``. If *None*,
        the first slice is drawn in ``#e41a1c`` and the second in ``#377eb8``.
    spatial_key
        ``.obsm`` key of the first slice's coordinates.
    key_added, iter_key_added
        Keys locating the per-iteration coordinates in the second slice's ``.uns``.
    filename
        Output path without extension; ``".gif"`` is appended.
    fps
        Frames per second of the saved GIF.
    stepsize
        Only every ``stepsize``-th iteration becomes a frame.
    cmap
        Colormap used to generate ``palette`` when it is *None* and ``label_key`` is set.
    palette
        Category -> colour mapping for ``label_key``.
    point_size
        Scatter size; defaults to ``5000 / total number of drawn points``.
    n_sampling
        Down-sample each slice to this many points; ``<= 0`` disables it.

    Raises
    ------
    AssertionError
        If ``aligned_slices`` does not contain exactly two slices.
    ValueError
        If ``label_key`` is missing from either slice's ``.obs``.
    """
    assert len(aligned_slices) == 2, "Input aligned_slices must be 2 slices!"
    ref_slice, mov_slice = aligned_slices

    if label_key is not None:
        # ``all`` is required: a non-empty list of booleans is always truthy.
        if not all(label_key in s.obs.keys() for s in aligned_slices):
            raise ValueError(f"adata.obs['{label_key}'] does not exist.")
        labels = [s.obs[label_key] for s in aligned_slices]
        # Plain arrays so that the integer sampling indices below are positional.
        label1 = np.asarray(labels[0])
        label2 = np.asarray(labels[1])
    else:
        label1 = np.zeros((ref_slice.shape[0],), dtype=np.int32)
        label2 = np.ones((mov_slice.shape[0],), dtype=np.int32)

    def _sample(n_obs):
        """Random subset of ``n_sampling`` row indices, or all rows."""
        if 0 < n_sampling < n_obs:
            return np.random.choice(n_obs, n_sampling, replace=False)
        return np.arange(n_obs)

    sampling_idx1 = _sample(ref_slice.shape[0])
    sampling_idx2 = _sample(mov_slice.shape[0])

    if label_key is not None:
        if palette is None:
            palette = _agenerate_palette(*labels, cmap=cmap)
        label1_colors = [palette[cat] for cat in label1[sampling_idx1]]
        label2_colors = [palette[cat] for cat in label2[sampling_idx2]]
    else:
        label1_colors = ["#e41a1c"] * len(sampling_idx1)
        label2_colors = ["#377eb8"] * len(sampling_idx2)

    if point_size is None:
        point_size = 500 * 10 / (len(sampling_idx1) + len(sampling_idx2))

    coords_ref = ref_slice.obsm[spatial_key]
    plt.ioff()
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.axis("equal")
        ax.set_xticks([])
        ax.set_yticks([])
        iter_dict = mov_slice.uns[iter_key_added]
        iter_coords = iter_dict[key_added]
        n_iter = len(iter_coords)

        # Static background: the reference slice.
        ax.scatter(
            coords_ref[sampling_idx1, 0],
            coords_ref[sampling_idx1, 1],
            marker="o",
            s=point_size,
            c=label1_colors,
            edgecolors=None,
        )
        artists = []
        for i in range(0, n_iter, stepsize):
            frame = ax.scatter(
                iter_coords[i][sampling_idx2, 0],
                iter_coords[i][sampling_idx2, 1],
                marker="o",
                s=point_size,
                c=label2_colors,
                edgecolors=None,
            )
            title_text = f"Iter: {i}, sigma2: {iter_dict['sigma2'][i]:.3f}."
            tit = ax.text(0.5, 1.02, title_text, ha="center", va="bottom", size=16, weight="bold", transform=ax.transAxes)
            artists.append([frame, tit])
        ani = animation.ArtistAnimation(fig=fig, artists=artists, interval=4, blit=False)
        ani.save(filename + ".gif", fps=fps, dpi=100)
    finally:
        # Always release the figure, even if reading data or saving fails.
        plt.close(fig)
# def plot_align_correspondence_2d(
# slices: List[AnnData],
# mapping: List[np.ndarray],
# label_key: Optional[str] = None,
# spatial_key: str = "spatial",
# point_size: Optional[float] = None,
# linewidth: Optional[float] = None,
# n_sampling: int = -1,
# mapping_sampling: int = -1,
# robust_threshold: Optional[float] = None,
# palette: Optional[dict] = None,
# show_legend: bool = True,
# legend_kwargs: Optional[dict] = None,
# axis_off: bool = False,
# axis_kwargs: Optional[dict] = None,
# ticks_off: bool = True,
# x_min=None,
# x_max=None,
# y_min=None,
# y_max=None,
# height: float = 2,
# alpha: float = 1.0, # TODO: alpha to be a key in adata
# cmap="tab20",
# center_coordinate: bool = False, # different from slices_2d
# ):
# assert len(mapping) == len(slices) - 1, "The length of mapping should be len(slices) - 1."
# # get spatial coords and labels
# spatial_coords = []
# labels = []
# for i, s in enumerate(slices):
# if spatial_key in s.obsm.keys():
# spatial_coords.append(s.obsm[spatial_key].copy())
# else:
# raise ValueError(f"adata.obsm['{spatial_key}'] does not exist.")
# if label_key in s.obs.keys():
# labels.append(s.obs[label_key].copy())
# # label_type = "cluster"
# elif label_key in s.var_names:
# labels.append(s[:, label_key].X.A.copy().squeeze())
# # label_type = "scalar"
# else:
# raise ValueError(f"adata.obs['{label_key}'] or adata.var['{label_key}'] does not exist.")
# assert (
# spatial_coords[-1].shape[0] == labels[-1].shape[0]
# ), "The number of spatial coordinates and labels must be the same. Please check the data."
# # add mapping
# correspondences = []
# for i in range(len(mapping)):
# if mapping[i].shape[1] == 2:
# correspondences.append(mapping[i])
# elif (mapping[i].shape[1] == slices[i+1].shape[0]) and (mapping[i].shape[0] == slices[i].shape[0]):
# sampling_idx = (
# np.random.choice(mapping[i].shape[0], mapping_sampling, replace=False)
# if mapping_sampling > 0 and mapping_sampling < mapping[i].shape[0]
# else np.arange(mapping[i].shape[0])
# )
# mapping_argmax = np.argmax(mapping[i][sampling_idx], axis=1)
# mapping_valmax = mapping[i][sampling_idx][np.arange(mapping_sampling), mapping_argmax]
# mask = np.arange(sampling_idx.shape[0]) if robust_threshold is None else mapping_valmax > robust_threshold
# correspondence = np.array([sampling_idx[mask], mapping_argmax[mask]]).T
# correspondences.append(correspondence)
# else:
# raise ValueError("The shape of mapping is not correct.")
# # infer the label_type if not specified
# if label_type is None:
# if labels[0].values.dtype in ["float16", "float32", "float64", "int16", "int32", "int64"]:
# label_type = "scalar"
# else:
# label_type = "cluster"
# # downsampling if n_sampling is set
# # TODO: implement downsampling
# # center the coordinates
# if center_coordinate:
# for i in range(len(slices)):
# spatial_coords[i] = spatial_coords[i] - np.mean(spatial_coords[i], axis=0)
# # determine the interval of slices
# slices_interval = []
# for i in range(len(slices) - 1):
# slices_interval.append(
# _compute_smallest_distance(spatial_coords[i], spatial_coords[i+1])
# )
# # Update the spatial coordinates
# cur_pos = 0
# for i in range(len(slices) - 1):
# cur_pos += slices_interval[i]
# spatial_coords[i+1] += cur_pos
# # determine the mapping line position and label
# mapping_lines = []
# mapping_labels = []
# for i in range(len(correspondences)):
# mapping_lines.append(
# np.concatenate([spatial_coords[i][correspondences[i][:, 0]], spatial_coords[i+1][correspondences[i][:, 1]]], axis=1)
# )
# mapping_labels.append(
# labels[i][correspondences[i][:, 0]]
# )
# # Set plot theme
# sns.set_theme(
# context="paper",
# style="white",
# font="Arial",
# font_scale=1,
# rc={
# # "font.size": font_size,
# "font.family": ["sans-serif"],
# "font.sans-serif": ["Arial", "sans-serif", "Helvetica", "DejaVu Sans", "Bitstream Vera Sans"],
# },
# )
# # generate palette
# if (palette is None) and (label_type == "cluster"):
# palette = _agenerate_palette(*labels, cmap=cmap)
# else:
# palette = cmap
# # generate figure
# fig, ax = plt.subplots(1, 1, figsize=(height * aspect_ratio, height))
# # determine the pointsize if not specified
# if point_size is None:
# point_size = 500 * height**2 * aspect_ratio / (slices_spatial_data.shape[0] / len(slices))
# # plotting
# sns.scatterplot(x=x, y=y, hue=label, legend=legend, palette = palette, ax=ax, s=point_size, edgecolor=edgecolor)
# TODO: finish this
# def plot_align_correspondence_3d(
# slices: List[AnnData],
# label_key: Optional[str] = None,
# spatial_key: str = "spatial",
# point_size: Optional[float] = None,
# n_sampling: int = -1,
# palette: Optional[dict] = None,
# show_legend: bool = True,
# legend_kwargs: Optional[dict] = None,
# axis_off: bool = False,
# axis_kwargs: Optional[dict] = None,
# ticks_off: bool = True,
# x_min=None,
# x_max=None,
# y_min=None,
# y_max=None,
# height: float = 2,
# alpha: float = 1.0, # TODO: alpha to be a key in adata
# cmap="tab20",
# center_coordinate: bool = False, # different from slices_2d
# ):
# # get spatial coords and labels
# spatial_coords = []
# labels = []
# slice_ids = []
# for i, s in enumerate(slices):
# if spatial_key in s.obsm.keys():
# spatial_coords.append(s.obsm[spatial_key].copy())
# else:
# raise ValueError(f"adata.obsm['{spatial_key}'] does not exist.")
# if label_key is not None:
# if label_key in s.obs.keys():
# labels.append(s.obs[label_key].copy())
# label_type = "cluster"
# elif label_key in s.var_names:
# labels.append(s[:, label_key].X.A.copy().squeeze())
# label_type = "scalar"
# else:
# raise ValueError(f"adata.obs['{label_key}'] or adata.var['{label_key}'] does not exist.")
# assert (
# spatial_coords[-1].shape[0] == labels[-1].shape[0]
# ), "The number of spatial coordinates and labels must be the same. Please check the data."
# else:
# label_type = "cluster"
# if (slices_key is not None) and (slices_key in s.obs.keys()):
# unique_id = np.unique(s.obs[slices_key])
# if len(unique_id) == 1:
# slice_ids.append(unique_id[0])
# else:
# raise ValueError(f"adata.obs['{slices_key}'] must have only one unique value.")
# else:
# slice_ids.append(str(i))
def multi_slices(
    slices: Union[AnnData, List[AnnData]],
    slices_key: Optional[str] = None,
    label: Optional[str] = None,
    spatial_key: str = "align_spatial",
    layer: str = "X",
    point_size: Optional[float] = None,
    font_size: Optional[float] = 20,
    color: Optional[str] = "skyblue",
    palette: Optional[str] = None,
    alpha: float = 1.0,
    ncols: int = 4,
    ax_height: float = 1,
    dpi: int = 100,
    show_legend: bool = True,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = None,
    **kwargs,
):
    """Plot several slices as a seaborn FacetGrid, one panel per slice.

    Parameters
    ----------
    slices
        A list of ``AnnData`` (copied, tagged in ``.obs[slices_key]`` with
        ``"slice_<i>"`` and merged via ``integrate``) or a single ``AnnData``
        that already holds several slices in ``.obs[slices_key]``.
    slices_key
        ``.obs`` column identifying the slice. Defaults to ``"slices"`` for a
        list input; required for a single ``AnnData``.
    label
        ``.obs`` column or gene name in ``.var_names`` used as ``hue``. If *None*
        every point gets the constant label ``"spatial coordinates"``.
    spatial_key
        ``.obsm`` key; its first two columns are used as x/y.
    layer
        ``"X"`` or a key of ``.layers`` copied into ``.X`` before plotting.
    point_size
        Scatter size; if *None* it is derived from the smallest inter-point
        distance (``compute_smallest_distance`` on up to 1000 points per slice),
        the panel size, ``dpi`` and the grid size.
    font_size
        Value for ``font.size`` in the seaborn theme.
    color
        Forwarded to ``sns.scatterplot``.
    palette
        Hue palette for the FacetGrid and colormap for numeric labels.
    alpha
        Point transparency (also passed to the colorbar).
    ncols
        Number of panels per row (``col_wrap``).
    ax_height
        Panel height factor; ``1`` is promoted to ``2`` for a single row.
    dpi
        Only used in the automatic point-size computation.
    show_legend, save_show_or_return, save_kwargs
        Forwarded to :func:`save_return_show_fig_utils`. A legend or colorbar is
        always drawn; ``show_legend`` does not suppress it here.
    **kwargs
        Forwarded to ``sns.scatterplot``.

    Returns
    -------
    Whatever :func:`save_return_show_fig_utils` returns for the FacetGrid.

    Raises
    ------
    AssertionError
        If a single ``AnnData`` is given without ``slices_key``.
    ValueError
        If ``label`` is neither an ``.obs`` column nor a gene name.
    """
    # Build one AnnData with a slice column.
    if isinstance(slices, list):
        if slices_key is None:
            slices_key = "slices"
        adatas = [s.copy() for s in slices]
        for i, s in enumerate(adatas):
            s.X = s.layers[layer].copy() if layer != "X" else s.X.copy()
            s.uns = {"__type": "UMI"}
            s.obs[slices_key] = f"slice_{i}"
        adata = integrate(adatas=adatas, batch_key=slices_key)
    else:
        assert slices_key is not None, "When `slices` is an anndata object, `slices_key` cannot be None."
        adata = slices.copy()
        adata.X = adata.layers[layer].copy() if layer != "X" else adata.X.copy()

    # Check label data and generate plotting data.
    slices_data = pd.DataFrame(adata.obsm[spatial_key][:, :2], columns=["x", "y"], dtype=float)
    slices_data[slices_key] = adata.obs[slices_key].values
    if label is None:
        label = "spatial coordinates"
        slices_data[label] = label
    elif label in adata.obs_keys():
        slices_data[label] = adata.obs[label].values
    elif label in adata.var_names:
        adata.X = to_dense_matrix(adata.X)
        # Flatten the (n_obs, 1) expression column to 1-D before assignment.
        slices_data[label] = np.asarray(adata[:, label].X).ravel()
    else:
        raise ValueError("`label` is not a valid column names or gene name.")

    # Grid arrangement.
    slices_id = slices_data[slices_key].unique().tolist()
    nrows = math.ceil(len(slices_id) / ncols)

    # ``ndarray.ptp`` was removed in NumPy 2.0; degenerate extents fall back to 1.
    spatial_coords = slices_data[["x", "y"]].values.copy()
    ptp_vec = np.ptp(spatial_coords, axis=0)
    aspect_ratio = ptp_vec[0] / ptp_vec[1] if ptp_vec[0] > 0 and ptp_vec[1] > 0 else 1.0
    ax_height = 2 if nrows == 1 and ax_height == 1 else ax_height
    axsize = (ax_height * aspect_ratio, ax_height * 2)

    sns.set_theme(
        context="paper",
        style="white",
        font="Arial",
        font_scale=1,
        rc={
            "font.size": font_size,
            "font.family": ["sans-serif"],
            "font.sans-serif": ["Arial", "sans-serif", "Helvetica", "DejaVu Sans", "Bitstream Vera Sans"],
        },
    )
    g = sns.FacetGrid(
        slices_data.copy(),
        col=slices_key,
        hue=label,
        palette=palette,
        sharex=True,
        sharey=True,
        height=axsize[1] * nrows,
        aspect=aspect_ratio,
        col_wrap=ncols,
        despine=False,
    )

    # Point size from the tightest point spacing over all slices.
    if point_size is None:
        min_dist_list = [
            compute_smallest_distance(coords=data[["x", "y"]].values, sample_num=min(len(data), 1000))
            for _, data in slices_data.groupby(by=slices_key)
        ]
        point_size = min(min_dist_list) * axsize[0] / ptp_vec[0] * dpi
        point_size = point_size**2 * ncols * nrows

    g.map_dataframe(sns.scatterplot, x="x", y="y", alpha=alpha, color=color, s=point_size, legend="brief", **kwargs)

    # Numeric labels get a colorbar, others a categorical legend.
    label_values = slices_data[label].values
    if label_values.dtype in ["float16", "float32", "float64", "int16", "int32", "int64"]:
        from mpl_toolkits.axes_grid1.inset_locator import inset_axes

        # Last panel of the first row; there may be fewer panels than ``ncols``.
        ax = g.facet_axis(row_i=0, col_j=min(ncols, len(slices_id)) - 1)
        norm = mpl.colors.Normalize(vmin=None, vmax=None)
        mappable = mpl.cm.ScalarMappable(norm=norm, cmap=palette)
        mappable.set_array(label_values)
        plt.colorbar(
            mappable,
            cax=inset_axes(
                ax,
                width="12%",
                height="100%",
                loc="center left",
                bbox_to_anchor=(1.02, 0.0, 0.5, 1.0),
                bbox_transform=ax.transAxes,
                borderpad=1.85,
            ),
            ax=ax,
            orientation="vertical",
            alpha=alpha,
            label=label,
        )
    else:
        g.add_legend()

    plt.tight_layout()
    return save_return_show_fig_utils(
        save_show_or_return=save_show_or_return,
        show_legend=show_legend,
        background="white",
        prefix="multi_slices",
        save_kwargs=save_kwargs,
        total_panels=len(slices_id),
        fig=g,
        axes=g,
        return_all=False,
        return_all_list=None,
    )
# TODO: Add a docstring and multi-slice plotting; legend markers should stay the same size as the legend text.
# def plot_clusters(
# adata: Union[AnnData, np.ndarray],
# spatial_key: str = 'spatial',
# label_key: Union[str, List[str], pd.DataFrame] = 'clusters',
# ax: Optional[mpl.axes.Axes] = None,
# point_size: float = 10,
# n_sampling: int = -1,
# palette: Optional[dict] = None,
# title: str = '',
# title_kwargs: Optional[dict] = None,
# # title_fontsize = 16,
# show_legend: bool = True,
# legend_kwargs: Optional[dict] = None,
# axis_off: bool = True,
# axis_kwargs: Optional[dict] = None,
# ticks_off: bool = True,
# color=None,
# x_min = None,
# x_max = None,
# y_min = None,
# y_max = None,
# **kwargs,
# ):
# """
# Plots spatial data with cluster labels stored in adata.obs[col]
# """
# # get spatial coords and labels
# if isinstance(adata, AnnData):
# if spatial_key in adata.obsm.keys():
# spatial_coords = adata.obsm[spatial_key].copy()
# else:
# raise ValueError(f"adata.obsm['{spatial_key}'] does not exist.")
# elif isinstance(adata, np.ndarray):
# if adata.shape[1] > 1:
# spatial_coords = adata.copy()[:,:2]
# else:
# raise ValueError("the input spatial coordinates must have at least 2 columns.")
# if isinstance(adata, AnnData):
# if label_key in adata.obs.keys():
# label = adata.obs[label_key].copy()
# else:
# raise ValueError(f"adata.obs['{label_key}'] does not exist.")
# elif isinstance(label_key, list):
# label = pd.Series(label_key)
# elif isinstance(label_key, pd.DataFrame):
# label = label_key
# else:
# raise ValueError(f"label_key must be a string, list, or DataFrame.")
# assert spatial_coords.shape[0] == label.shape[0], "The number of spatial coordinates and labels must be the same."
# # downsampling if n_sampling is set
# sampling_idx = np.random.choice(spatial_coords.shape[0], n_sampling, replace=False) if n_sampling > 0 and n_sampling < spatial_coords.shape[0] else np.arange(spatial_coords.shape[0])
# x = spatial_coords[sampling_idx, 0]
# y = spatial_coords[sampling_idx, 1]
# label = label[sampling_idx]
# # get unique labels
# unique_labels = np.unique(label)
# # get color palette if not provided
# if palette is None:
# n_colors = len(unique_labels)
# palette = sns.color_palette("tab20", n_colors)
# # plot
# scatterplot_kwargs = {
# 'x': x,
# 'y': y,
# 'hue': label,
# 'palette': palette,
# 'ax': ax,
# 's': point_size,
# 'legend': show_legend,
# }
# scatterplot_kwargs.update(kwargs)
# sns.scatterplot(**scatterplot_kwargs)
# # adjust the legend
# if show_legend:
# default_legend_kwargs = {
# 'loc': 'center left',
# 'bbox_to_anchor': (1, 0.5),
# 'prop': {'family': 'Arial', 'size': 10},
# 'fancybox': False,
# 'edgecolor': 'black',
# 'framealpha': 1,
# 'columnspacing': 0.5,
# 'handletextpad': 0.1,
# }
# if legend_kwargs:
# default_legend_kwargs.update(legend_kwargs)
# ax.legend(**default_legend_kwargs)
# # set axis limits
# if x_min is not None and x_max is not None:
# ax.set_xlim(x_min, x_max)
# if y_min is not None and y_max is not None:
# ax.set_ylim(y_min, y_max)
# # set other axis properties
# if axis_off:
# ax.axis('off')
# if ticks_off:
# ax.set_xticks([])
# ax.set_yticks([])
# if title_kwargs:
# default_title_kwargs = {
# 'label': title,
# 'fontsize': 16,
# }
# default_title_kwargs.update(title_kwargs)
# ax.set_title(title_kwargs)
# ax.set_aspect('equal')
# def _spatial_scatter(
# spatial_x: np.ndarray,
# spatial_y: np.ndarray,
# color: Union[str, np.ndarray],
# point_size: Union[float, np.ndarray],
# alpha: Union[float, np.ndarray],
# edgecolors: Optional[Union[str, np.ndarray]] = None,
# linewidths: Optional[Union[float, np.ndarray]] = None,
# marker: Optional[str] = "o",
# palette: Optional[Union[str, dict]] = None,
# ax: Optional[mpl.axes.Axes] = None,
# ):
# if ax is None:
# _, ax = plt.subplots()
# if isinstance(color, str):
# color = np.array([color] * len(spatial_x))
# if isinstance(point_size, float):
# point_size = np.array([point_size] * len(spatial_x))
# if isinstance(alpha, float):
# alpha = np.array([alpha] * len(spatial_x))
# if edgecolors is None:
# edgecolors = "none"
# if linewidths is None:
# linewidths = 0
# if palette is not None:
# if isinstance(palette, str):
# palette = sns.color_palette(palette, n_colors=len(np.unique(color)))
# elif isinstance(palette, dict):
# palette = [palette.get(c, "gray") for c in np.unique(color)]
# else:
# raise ValueError("`palette` must be a string or a dictionary.")
# color = [palette[np.where(np.unique(color) == c)[0][0]] for c in color]
# ax.scatter(
# spatial_x,
# spatial_y,
# c=color,
# s=point_size,
# alpha=alpha,
# edgecolors=edgecolors,
# linewidths=linewidths,
# marker=marker,
# )
# ax.set_aspect('equal')
# return ax
[docs]def _agenerate_palette(*labels, cmap="tab20"):
if len(labels) == 1:
labels = labels[0]
elif len(labels) > 1:
labels = np.concatenate(labels)
else:
raise ValueError("No labels provided.")
unique_labels = np.unique(labels)
n_labels = len(unique_labels)
palette = {l: sns.color_palette(cmap, n_labels)[i] for i, l in enumerate(unique_labels)}
return palette
[docs]def _compute_smallest_distance(spatial_coord1, spatial_coord2, direction="x", scale_factor=1.1):
if direction == "x":
spatial_coord1_max = np.max(spatial_coord1[:, 0])
spatial_coord2_min = -np.min(spatial_coord2[:, 0])
interval = (spatial_coord1_max + spatial_coord2_min) * scale_factor
elif direction == "y":
spatial_coord1_max = np.max(spatial_coord1[:, 1])
spatial_coord2_min = -np.min(spatial_coord2[:, 1])
interval = (spatial_coord1_max + spatial_coord2_min) * scale_factor
else:
raise ValueError("`direction` must be 'x' or 'y'.")
return interval
def get_min_max(x):
    """Return the per-column minimum of ``x`` and the per-column span above it.

    Parameters
    ----------
    x
        Array-like supporting ``.min(0)``, ``.max(0)`` and broadcasting
        subtraction (e.g. a 2D ``numpy.ndarray``).

    Returns
    -------
    tuple
        ``(lower, span)`` where ``lower`` is ``x.min(0)`` and ``span`` is the
        maximum of ``x - lower`` along axis 0 (the max-minus-min range).
        Note that the second element is a range, not the maximum itself.
    """
    lower = x.min(0)
    span = (x - lower).max(0)
    return lower, span
def get_H(h=0.5, w=0.2):
    """Return the homography that maps the unit-square corners onto a quadrilateral.

    Point correspondences passed to ``cv2.findHomography``:

    ======== =============
    source   destination
    ======== =============
    (0, 0)   (w, h)
    (0, 1)   (1 - w, h)
    (1, 0)   (0, 0)
    (1, 1)   (1, 0)
    ======== =============

    An alternative mapping kept from an earlier version was
    ``[[0, 0], [w, h], [1, 0], [1 - w, h]]``.

    Parameters
    ----------
    h, w
        Coordinates used to place the two displaced destination corners.

    Returns
    -------
    The homography matrix returned as the first output of
    ``cv2.findHomography``; its mask output is discarded.
    """
    src_corners = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
    dst_corners = np.array([[w, h], [1 - w, h], [0, 0], [1, 0]])
    # OpenCV is imported lazily so the module stays importable without it.
    import cv2

    homography, _inlier_mask = cv2.findHomography(srcPoints=src_corners, dstPoints=dst_corners)
    return homography