Source code for spateo.plotting.static.utils

# code adapted from https://github.com/aristoteleo/dynamo-release/blob/master/dynamo/plot/utils.py
import copy
import math
import os
import warnings
from inspect import signature
from typing import Any, Collection, Dict, List, Optional, Tuple, Union

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import geopandas as gpd

import matplotlib
import matplotlib.patheffects as PathEffects
import matplotlib.pyplot as plt
import mpl_toolkits
import numba
import numpy as np
import pandas as pd
import scipy
from anndata import AnnData
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from pandas.api.types import is_categorical_dtype
from scipy.cluster import hierarchy as sch
from scipy.spatial import distance
from shapely.wkb import loads
from sklearn.decomposition import PCA
from typing_extensions import Literal

from ...configuration import SKM, _themes
from ...logging import logger_manager as lm


# ---------------------------------------------------------------------------------------------------
# variable checking utilities
def is_gene_name(adata, var):
    if type(var) in [str, np.str_]:
        return var in adata.var.index
    else:
        return False


def is_cell_anno_column(adata, var):
    if type(var) in [str, np.str_]:
        return var in adata.obs.columns
    else:
        return False


def is_layer_keys(adata, var):
    if type(var) in [str, np.str_]:
        return var in adata.layers.keys()
    else:
        return False


def is_list_of_lists(list_of_lists):
    # the original body dropped the `return`, so the function always returned None
    return all(isinstance(elem, list) for elem in list_of_lists)


def _get_adata_color_vec(adata, layer, col):
    if layer in ["protein", "X_protein"]:
        _color = adata.obsm[layer].loc[col, :]
    elif layer == "X":
        _color = adata.obs_vector(col, layer=None)
    else:
        _color = adata.obs_vector(col, layer=layer)
    return np.array(_color).flatten()
# ---------------------------------------------------------------------------------------------------
# plotting utilities that borrowed from umap
# link: https://github.com/lmcinnes/umap/blob/7e051d8f3c4adca90ca81eb45f6a9d1372c076cf/umap/plot.py
def map2color(val, min=None, max=None, cmap="viridis"):
    import matplotlib
    import matplotlib.cm as cm
    import matplotlib.pyplot as plt

    minima = np.min(val) if min is None else min
    maxima = np.max(val) if max is None else max

    norm = matplotlib.colors.Normalize(vmin=minima, vmax=maxima, clip=True)
    mapper = cm.ScalarMappable(norm=norm, cmap=plt.get_cmap(cmap))
    cols = [mapper.to_rgba(v) for v in val]

    return cols
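# Example (illustrative sketch, not part of the original module; the value array is hypothetical):
#     >>> example_vals = np.array([0.1, 0.5, 2.3, 7.0])
#     >>> rgba_colors = map2color(example_vals, cmap="magma")
#     >>> plt.scatter(np.arange(len(example_vals)), example_vals, c=rgba_colors)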
def _to_hex(arr):
    return [matplotlib.colors.to_hex(c) for c in arr]
# https://stackoverflow.com/questions/8468855/convert-a-rgb-colour-value-to-decimal
# Convert RGB color to decimal: RGB integers are typically treated as three distinct bytes where
# the left-most (highest-order) byte is red, the middle byte is green and the right-most
# (lowest-order) byte is blue.


@numba.vectorize(["uint8(uint32)", "uint8(uint32)"])
def _red(x):
    return (x & 0xFF0000) >> 16


@numba.vectorize(["uint8(uint32)", "uint8(uint32)"])
def _green(x):
    return (x & 0x00FF00) >> 8


@numba.vectorize(["uint8(uint32)", "uint8(uint32)"])
def _blue(x):
    return x & 0x0000FF


def _embed_datashader_in_an_axis(datashader_image, ax):
    img_rev = datashader_image.data[::-1]
    mpl_img = np.dstack([_blue(img_rev), _green(img_rev), _red(img_rev)])
    ax.imshow(mpl_img)
    return ax
def _get_extent(points):
    """Compute bounds on a space with appropriate padding"""
    min_x = np.min(points[:, 0])
    max_x = np.max(points[:, 0])
    min_y = np.min(points[:, 1])
    max_y = np.max(points[:, 1])

    extent = (
        np.round(min_x - 0.05 * (max_x - min_x)),
        np.round(max_x + 0.05 * (max_x - min_x)),
        np.round(min_y - 0.05 * (max_y - min_y)),
        np.round(max_y + 0.05 * (max_y - min_y)),
    )

    return extent
def _select_font_color(background):
    if background in ["k", "black"]:
        font_color = "white"
    elif background in ["w", "white"]:
        font_color = "black"
    elif background.startswith("#"):
        mean_val = np.mean(
            # specify 0 as the base in order to invoke this prefix-guessing behavior;
            # omitting it means to assume base-10
            [int("0x" + c, 0) for c in (background[1:3], background[3:5], background[5:7])]
        )
        if mean_val > 126:
            font_color = "black"
        else:
            font_color = "white"
    else:
        font_color = "black"

    return font_color
def _scatter_projection(ax, points, projection, **kwargs):
    if projection == "3d":
        ax.scatter(points[:, 0], points[:, 1], points[:, 2], **kwargs)
    else:
        ax.scatter(points[:, 0], points[:, 1], **kwargs)
def _vector_projection(
    ax, points: np.ndarray, vectors: np.ndarray, projection: str = "2d", geo: bool = False, **kwargs
):
    """Plot a 2D field of arrows over spatial transcriptomics data

    Args:
        ax: Matplotlib axis object
        points: Point coordinates of shape [n_samples, 2], either grid coordinates (for grid or streamlines plots)
            or coordinates of the cells themselves (for cell plots)
        vectors: Array of shape [n_samples, 2] or [n_samples, 3] containing the vector field
        projection: Either '2d' or '3d' to indicate if plot is 2D or 3D
        geo: Set True if plotting atop geometrical objects. If only generating a scatterplot, set False.
        **kwargs: Additional keyword arguments provided to :func:`ax.quiver`
    """
    if geo and not isinstance(points, np.ndarray):
        centroids = np.array([polygon.centroid.coords[0] for polygon in points])
        if projection == "3d":
            ax.quiver(
                centroids[:, 0], centroids[:, 1], centroids[:, 2], vectors[:, 0], vectors[:, 1], vectors[:, 2], **kwargs
            )
        else:
            ax.quiver(centroids[:, 0], centroids[:, 1], vectors[:, 0], vectors[:, 1], **kwargs)
    else:
        if projection == "3d":
            ax.quiver(points[:, 0], points[:, 1], points[:, 2], vectors[:, 0], vectors[:, 1], vectors[:, 2], **kwargs)
        else:
            ax.quiver(points[:, 0], points[:, 1], vectors[:, 0], vectors[:, 1], **kwargs)
def _streamlines_projection(
    ax, points: np.ndarray, vectors: np.ndarray, projection: str = "2d", geo: bool = False, **kwargs
):
    """Plot streamlines over spatial transcriptomics data

    Args:
        ax: Matplotlib axis object
        points: Point coordinates of shape [n_samples, 2], either grid coordinates (for grid or streamlines plots)
            or coordinates of the cells themselves (for cell plots)
        vectors: Array of shape [n_samples, 2] or [n_samples, 3] containing the vector field
        projection: Either '2d' or '3d' to indicate if plot is 2D or 3D
        geo: Set True if plotting atop geometrical objects. If only generating a scatterplot, set False.
        **kwargs: Additional keyword arguments provided to :func:`ax.streamplot`
    """
    if geo and not isinstance(points, np.ndarray):
        centroids = np.array([polygon.centroid.coords[0] for polygon in points])
        if projection == "3d":
            raise NotImplementedError("Streamlines are not supported in 3D")
        else:
            ax.streamplot(centroids[:, 0], centroids[:, 1], vectors[:, 0], vectors[:, 1], **kwargs)
    else:
        if projection == "3d":
            raise NotImplementedError("Streamlines are not supported in 3D")
        else:
            ax.streamplot(points[:, 0], points[:, 1], vectors[:, 0], vectors[:, 1], **kwargs)
def _geo_projection(ax, points, **kwargs):
    linecolor = kwargs.pop("linecolor")
    if "values" in kwargs:
        # using value
        gdf = gpd.GeoDataFrame(data={"values": kwargs.pop("values"), "points": points}, geometry="points")
        ax = gdf.plot("values", ax=ax, **kwargs)
    else:
        # using color
        gdf = gpd.GeoDataFrame(geometry=points)
        ax = gdf.plot(ax=ax, **kwargs)

    # clean args for boundary plotting
    if "color" in kwargs:
        kwargs.pop("color")
    if "cmap" in kwargs:
        kwargs.pop("cmap")
    gdf.boundary.plot(ax=ax, color=linecolor, **kwargs)
def plot_vectors(
    ax: Union[plt.Axes, mpl_toolkits.mplot3d.Axes3D],
    points: np.ndarray,
    V: np.ndarray,
    vf_plot_method: str = "cell",
    projection: str = "2d",
    geo: bool = False,
    **kwargs,
):
    """Wrapper for plotting vector fields.

    Args:
        ax: Matplotlib axis object
        points: Point coordinates of shape [n_samples, 2], either grid coordinates (for grid or streamlines plots)
            or coordinates of the cells themselves (for cell plots)
        V: Array of shape [n_samples, 2] or [n_samples, 3] containing the vector field
        vf_plot_method: Specifies whether to plot vectors coming from the cells themselves ('cell'), vectors coming
            from vertices of a grid ('grid') or streamlines coming from vertices of a grid ('streamplot')
        projection: Either '2d' or '3d' to indicate if plot is 2D or 3D
        geo: Set True if plotting atop geometrical objects. If only generating a scatterplot, set False.
    """
    if vf_plot_method == "grid" or vf_plot_method == "cell":
        _vector_projection(
            ax,
            points,
            V,
            projection,
            geo,
            **kwargs,
        )
    elif vf_plot_method == "streamplot":
        _streamlines_projection(
            ax,
            points,
            V,
            projection,
            geo,
            **kwargs,
        )
    else:
        raise ValueError("vf_plot_method must be one of 'cell', 'grid' or 'streamplot', got {}".format(vf_plot_method))
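# Example (illustrative sketch, not part of the original module; the arrays below are hypothetical):
#     >>> fig, ax = plt.subplots()
#     >>> pts = np.random.rand(100, 2) * 10        # cell coordinates
#     >>> vecs = np.random.randn(100, 2) * 0.1     # a 2D vector field sampled at those cells
#     >>> plot_vectors(ax, pts, vecs, vf_plot_method="cell", projection="2d")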
[docs]def _matplotlib_points( points, ax=None, labels=None, values=None, highlights=None, cmap: str = "Blues", color_key: Optional[str] = None, color_key_cmap: str = "Spectral", background: str = "white", width: int = 7, height: int = 5, show_legend: bool = True, vmin: float = 2, vmax: float = 98, sort: str = "raw", frontier: bool = False, contour: bool = False, ccmap=None, calpha: float = 0.4, sym_c: bool = False, inset_dict={}, show_colorbar: bool = True, projection=None, # default in matplotlib geo: bool = False, X_grid: Optional[np.ndarray] = None, V: Optional[np.ndarray] = None, vf_plot_method: str = "cell", vf_kwargs: Optional[Dict] = None, **kwargs, ): import matplotlib.pyplot as plt from matplotlib.ticker import MaxNLocator dpi = plt.rcParams["figure.dpi"] width, height = width * dpi, height * dpi rasterized = kwargs["rasterized"] if "rasterized" in kwargs.keys() else None # """Use matplotlib to plot points""" # point_size = 500.0 / np.sqrt(points.shape[0]) legend_elements = None if ax is None: dpi = plt.rcParams["figure.dpi"] fig = plt.figure(figsize=(width / dpi, height / dpi)) ax = fig.add_subplot(111, projection=projection) ax.set_facecolor(background) # Color by labels unique_labels = [] # Separate keyword arguments used for scatter and quiver plots (if the latter is applicable): if V is not None: if vf_plot_method == "grid" or vf_plot_method == "cell": quiver_params = [ "scale", "scale_units", "angles", "width", "color", "pivot", "headwidth", "headlength", "headaxislength", "minshaft", "minlength", "linewidth", "edgecolor", "norm", "cmap", ] quiver_kwargs = {} for key, value in vf_kwargs.items(): if key in quiver_params: quiver_kwargs[key] = value elif vf_plot_method == "stream": streamplot_params = [ "density", "linewidth", "color", "cmap", "norm", "arrowsize", "arrowstyle", "minlength", "start_points", "zorder", "maxlength", "integration_direction", "data", ] streamplot_kwargs = {} for key, value in vf_kwargs.items(): if key in streamplot_params: streamplot_kwargs[key] = value if labels is not None: # main_debug("labels are not None, drawing by labels") if labels.shape[0] != points.shape[0]: raise ValueError( "Labels must have a label for " "each sample (size mismatch: {} {})".format(labels.shape[0], points.shape[0]) ) if color_key is None: # main_debug("color_key is None") cmap = copy.copy(matplotlib.cm.get_cmap(color_key_cmap)) cmap.set_bad("lightgray") colors = None if highlights is None: unique_labels = np.unique(labels) num_labels = unique_labels.shape[0] color_key = plt.get_cmap(color_key_cmap)(np.linspace(0, 1, num_labels)) else: if type(highlights) is str: highlights = [highlights] highlights.append("other") unique_labels = np.array(highlights) num_labels = unique_labels.shape[0] color_key = _to_hex(plt.get_cmap(color_key_cmap)(np.linspace(0, 1, num_labels))) color_key[-1] = "#bdbdbd" # lightgray hex code https://www.color-hex.com/color/d3d3d3 labels[[i not in highlights[:-1] for i in labels]] = "other" points = pd.DataFrame(points) points["label"] = pd.Categorical(labels) # reorder data so that highlighting points will be on top of background points highlight_ids, background_ids = ( points["label"] != "other", points["label"] == "other", ) if V is not None and vf_plot_method == "cell": # Get the indices that would sort the DataFrame in ascending order sorted_indices = np.argsort(np.concatenate((background_ids.values, highlight_ids.values))) V = V[sorted_indices, :] # reorder_data = points.copy(deep=True) # ( # reorder_data.loc[:(sum(background_ids) - 1), :], # 
reorder_data.loc[sum(background_ids):, :], # ) = (points.loc[background_ids, :].values, points.loc[highlight_ids, :].values) points = pd.concat( ( points.loc[background_ids, :], points.loc[highlight_ids, :], ) ).values # labels = points[:, 2] labels = points["label"] # WARNING: do not change the following line to "elif" during refactor # This if-else branch is not logically parallel to the previous one. The following branch sets `colors`. if isinstance(color_key, dict): # main_debug("color_key is a dict") colors = pd.Series(labels).map(color_key).values unique_labels = np.unique(labels) legend_elements = [ # Patch(facecolor=color_key[k], label=k) for k in unique_labels Line2D( [0], [0], marker="o", color=color_key[k], label=k, linestyle="None", ) for k in unique_labels ] else: # main_debug("color_key is not None and not a dict") unique_labels = np.unique(labels) if len(color_key) < unique_labels.shape[0]: raise ValueError("Color key must have enough colors for the number of labels") new_color_key = {k: color_key[i] for i, k in enumerate(unique_labels)} legend_elements = [ # Patch(facecolor=color_key[i], label=k) Line2D( [0], [0], marker="o", color=color_key[i], label=k, linestyle="None", ) for i, k in enumerate(unique_labels) ] colors = pd.Series(labels).map(new_color_key) if frontier: # main_debug("drawing frontier") _scatter_projection( ax, points, projection, s=kwargs["s"] * 2, c="0.0", lw=2, rasterized=rasterized, ) _scatter_projection( ax, points, projection, s=kwargs["s"] * 2, c="1.0", lw=0, rasterized=rasterized, ) _scatter_projection( ax, points, projection, c=colors, plotnonfinite=True, **kwargs, ) if V is not None: vf_kwargs = streamplot_kwargs if vf_plot_method == "stream" else quiver_kwargs vec_points = points if vf_plot_method == "cell" else X_grid if len(vf_kwargs) == 0: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method) else: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method, **vf_kwargs) elif contour: import seaborn as sns ccmap = "viridis" if ccmap is None else ccmap df = pd.DataFrame(points, columns=["x", "y", "z"][: points.shape[1]]) ax = sns.kdeplot( data=df.iloc[:, :2], x="x", y="y", fill=True, alpha=calpha, palette=ccmap, ax=ax, thresh=0, levels=100, ) x, y = points[:, :2].T _scatter_projection( ax, points, projection, c=colors, plotnonfinite=True, zorder=21, **kwargs, ) if V is not None: vf_kwargs = streamplot_kwargs if vf_plot_method == "stream" else quiver_kwargs vec_points = points if vf_plot_method == "cell" else X_grid if len(vf_kwargs) == 0: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method) else: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method, **vf_kwargs) else: # main_debug("drawing without frontiers and contour") if geo: _geo_projection( ax, points, color=colors, **kwargs, ) if V is not None: vf_kwargs = streamplot_kwargs if vf_plot_method == "stream" else quiver_kwargs vec_points = points if vf_plot_method == "cell" else X_grid if len(vf_kwargs) == 0: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method, geo=True) else: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method, geo=True, **vf_kwargs) else: _scatter_projection( ax, points, projection, c=colors, plotnonfinite=True, **kwargs, ) if V is not None: vf_kwargs = streamplot_kwargs if vf_plot_method == "stream" else quiver_kwargs vec_points = points if vf_plot_method == "cell" else X_grid if len(vf_kwargs) == 0: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method) else: plot_vectors(ax, vec_points, V, 
vf_plot_method=vf_plot_method, **vf_kwargs) # Color by values elif values is not None: # main_debug("drawing points by values") cmap_ = copy.copy(matplotlib.cm.get_cmap(cmap)) cmap_.set_bad("lightgray") with warnings.catch_warnings(): warnings.simplefilter("ignore") matplotlib.cm.register_cmap(name=cmap_.name, cmap=cmap_, override_builtin=True) if values.shape[0] != points.shape[0]: raise ValueError( "Values must have a value for " "each sample (size mismatch: {} {})".format(values.shape[0], points.shape[0]) ) # reorder data so that high values points will be on top of background points sorted_id = ( np.argsort(abs(values)) if sort == "abs" else np.argsort(-values) if sort == "neg" else np.argsort(values) ) values, points = values[sorted_id], points[sorted_id] if V is not None and vf_plot_method == "cell": V = V[sorted_id] # if there are very few cells have expression, set the vmin/vmax only based on positive values to # get rid of outliers if np.nanmin(values) == 0: n_pos_cells = sum(values > 0) if 0 < n_pos_cells / len(values) < 0.02: vmin = 0 if n_pos_cells == 1 else np.percentile(values[values > 0], 2) vmax = np.nanmax(values) if n_pos_cells == 1 else np.percentile(values[values > 0], 98) if vmin + vmax in [1, 100]: vmin += 1e-12 vmax += 1e-12 # if None: min/max from data # if positive and sum up to 1, take fraction # if positive and sum up to 100, take percentage # otherwise take the data _vmin = ( np.nanmin(values) if vmin is None else np.nanpercentile(values, vmin * 100) if (vmin + vmax == 1 and 0 <= vmin < vmax) else np.nanpercentile(values, vmin) if (vmin + vmax == 100 and 0 <= vmin < vmax) else vmin ) _vmax = ( np.nanmax(values) if vmax is None else np.nanpercentile(values, vmax * 100) if (vmin + vmax == 1 and 0 <= vmin < vmax) else np.nanpercentile(values, vmax) if (vmin + vmax == 100 and 0 <= vmin < vmax) else vmax ) if sym_c and _vmin < 0 and _vmax > 0: bounds = np.nanmax([np.abs(_vmin), _vmax]) bounds = bounds * np.array([-1, 1]) _vmin, _vmax = bounds if frontier: # main_debug("drawing frontier") _scatter_projection( ax, points, projection, s=kwargs["s"] * 2, c="0.0", lw=2, rasterized=rasterized, ) _scatter_projection( ax, points, projection, s=kwargs["s"] * 2, c="1.0", lw=0, rasterized=rasterized, ) _scatter_projection( ax, points, projection, c=values, cmap=cmap, vmin=_vmin, vmax=_vmax, plotnonfinite=True, **kwargs, ) if V is not None: vf_kwargs = streamplot_kwargs if vf_plot_method == "stream" else quiver_kwargs vec_points = points if vf_plot_method == "cell" else X_grid if len(vf_kwargs) == 0: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method) else: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method, **vf_kwargs) elif contour: ccmap = "viridis" if ccmap is None else ccmap # # ax.tricontourf(triang, values, cmap=ccmap) # _scatter_projection(x, y, # c=values, # cmap=cmap, # plotnonfinite=True, # **kwargs, ) import seaborn as sns df = pd.DataFrame(points, columns=["x", "y", "z"][: points.shape[1]]) ax = sns.kdeplot( data=df.iloc[:, :2], x="x", y="y", fill=True, alpha=calpha, palette=ccmap, ax=ax, thresh=0, levels=100, ) _scatter_projection( ax, points, projection, c=values, cmap=cmap, vmin=_vmin, vmax=_vmax, plotnonfinite=True, **kwargs, ) if V is not None: vf_kwargs = streamplot_kwargs if vf_plot_method == "stream" else quiver_kwargs vec_points = points if vf_plot_method == "cell" else X_grid if len(vf_kwargs) == 0: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method) else: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method, 
**vf_kwargs) else: # main_debug("drawing without frontiers and contour") # main_debug("using cmap: %s" % (str(cmap))) if geo: _geo_projection( ax, points, values=values, cmap=cmap, vmin=_vmin, vmax=_vmax, **kwargs, ) if V is not None: vf_kwargs = streamplot_kwargs if vf_plot_method == "stream" else quiver_kwargs vec_points = points if vf_plot_method == "cell" else X_grid if len(vf_kwargs) == 0: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method, geo=True) else: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method, geo=True, **vf_kwargs) else: _scatter_projection( ax, points, projection, c=values, cmap=cmap, vmin=_vmin, vmax=_vmax, **kwargs, ) if V is not None: vf_kwargs = streamplot_kwargs if vf_plot_method == "stream" else quiver_kwargs vec_points = points if vf_plot_method == "cell" else X_grid if len(vf_kwargs) == 0: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method) else: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method, **vf_kwargs) if "norm" in kwargs: norm = kwargs["norm"] else: norm = matplotlib.colors.Normalize(vmin=_vmin, vmax=_vmax) mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) mappable.set_array(values) if show_colorbar: cb = plt.colorbar(mappable, cax=set_colorbar(ax, inset_dict), ax=ax) cb.set_alpha(1) cb.draw_all() cb.locator = MaxNLocator(nbins=3, integer=True) cb.update_ticks() cmap = matplotlib.cm.get_cmap(cmap) colors = cmap(values) # No color (just pick the midpoint of the cmap) else: # main_debug("drawing points without color passed in args, using midpoint of the cmap") colors = plt.get_cmap(cmap)(0.5) if geo: _geo_projection(ax, points, color=colors, **kwargs) if V is not None: vf_kwargs = streamplot_kwargs if vf_plot_method == "stream" else quiver_kwargs vec_points = points if vf_plot_method == "cell" else X_grid if len(vf_kwargs) == 0: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method, geo=True) else: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method, geo=True, **vf_kwargs) else: _scatter_projection(ax, points, projection, c=colors, **kwargs) if V is not None: vf_kwargs = streamplot_kwargs if vf_plot_method == "stream" else quiver_kwargs vec_points = points if vf_plot_method == "cell" else X_grid if len(vf_kwargs) == 0: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method) else: plot_vectors(ax, vec_points, V, vf_plot_method=vf_plot_method, **vf_kwargs) if show_legend and legend_elements is not None: if len(unique_labels) == 1 and show_legend == "on data": ax.legend( handles=legend_elements, bbox_to_anchor=(1.04, 1), loc=matplotlib.rcParams["legend.loc"], ncol=len(unique_labels) // 20 + 1, prop=dict(size=8), ) elif len(unique_labels) > 1 and show_legend == "on data": font_color = "white" if background in ["black", "#ffffff"] else "black" for i in unique_labels: if i == "other": continue if not geo: color_cnt_x, color_cnt_y = np.nanmedian(points[np.where(labels == i)[0], :2].astype("float"), 0) else: color_cnt_x = np.nanmedian(points[np.where(labels == i)[0]].centroid.x.astype("float"), 0) color_cnt_y = np.nanmedian(points[np.where(labels == i)[0]].centroid.x.astype("float"), 0) txt = plt.text( color_cnt_x, color_cnt_y, str(i), color=_select_font_color(font_color), zorder=1000, verticalalignment="center", horizontalalignment="center", weight="bold", ) # txt.set_path_effects( [ PathEffects.Stroke(linewidth=1.5, foreground=font_color, alpha=0.8), PathEffects.Normal(), ] ) else: show_legend = "best" if show_legend == "on data" else show_legend ax.legend( 
handles=legend_elements, bbox_to_anchor=(1.04, 1), loc=show_legend, ncol=len(unique_labels) // 20 + 1, ) else: # main_debug("hiding legend") pass return ax, colors
# ---------------------------------------------------------------------------------------------------
# plotting utilities borrowed from velocyto
# link: https://github.com/velocyto-team/velocyto-notebooks/blob/master/python/DentateGyrus.ipynb
def despline(ax=None):
    import matplotlib.pyplot as plt

    ax = plt.gca() if ax is None else ax
    # Hide the right and top spines
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    # Only show ticks on the left and bottom spines
    ax.yaxis.set_ticks_position("left")
    ax.xaxis.set_ticks_position("bottom")


def despline_all(ax=None, sides=None):
    # removing the default axis on all sides:
    import matplotlib.pyplot as plt

    ax = plt.gca() if ax is None else ax

    if sides is None:
        sides = ["bottom", "right", "top", "left"]
    for side in sides:
        ax.spines[side].set_visible(False)


def deaxis_all(ax=None):
    # removing the axis ticks
    import matplotlib.pyplot as plt

    ax = plt.gca() if ax is None else ax
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)


def minimal_xticks(start, end):
    import matplotlib.pyplot as plt

    end_ = np.around(end, -int(np.log10(end)) + 1)
    xlims = np.linspace(start, end_, 5)
    xlims_tx = [""] * len(xlims)
    xlims_tx[0], xlims_tx[-1] = f"{xlims[0]:.0f}", f"{xlims[-1]:.02f}"
    plt.xticks(xlims, xlims_tx)


def minimal_yticks(start, end):
    import matplotlib.pyplot as plt

    end_ = np.around(end, -int(np.log10(end)) + 1)
    ylims = np.linspace(start, end_, 5)
    ylims_tx = [""] * len(ylims)
    ylims_tx[0], ylims_tx[-1] = f"{ylims[0]:.0f}", f"{ylims[-1]:.02f}"
    plt.yticks(ylims, ylims_tx)


def set_spine_linewidth(ax, lw):
    for axis in ["top", "bottom", "left", "right"]:
        ax.spines[axis].set_linewidth(lw)

    return ax
# ---------------------------------------------------------------------------------------------------
# scatter plot utilities
def scatter_with_colorbar(fig, ax, x, y, c, cmap, **kwargs):
    # https://stackoverflow.com/questions/32462881/add-colorbar-to-existing-axis
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    g = ax.scatter(x, y, c=c, cmap=cmap, **kwargs)
    fig.colorbar(g, cax=cax, orientation="vertical")

    return fig, ax
def scatter_with_legend(fig, ax, df, font_color, x, y, c, cmap, legend, **kwargs):
    import matplotlib.patheffects as PathEffects
    import seaborn as sns

    unique_labels = np.unique(c)

    if legend == "on data":
        _ = sns.scatterplot(x=x, y=y, hue=c, palette=cmap, ax=ax, legend=False, **kwargs)

        for i in unique_labels:
            color_cnt = np.nanmedian(df.iloc[np.where(c == i)[0], :2], 0)
            txt = ax.text(
                color_cnt[0],
                color_cnt[1],
                str(i),
                color=font_color,
                zorder=1000,
                verticalalignment="center",
                horizontalalignment="center",
                weight="bold",
            )
            txt.set_path_effects(
                [
                    PathEffects.Stroke(linewidth=1.5, foreground=font_color, alpha=0.8),
                    PathEffects.Normal(),
                ]
            )
    else:
        _ = sns.scatterplot(x=x, y=y, hue=c, palette=cmap, ax=ax, legend="full", **kwargs)
        # `ncol` must be a positive integer; the original passed the label array itself
        ax.legend(loc=legend, ncol=len(unique_labels) // 15 + 1)

    return fig, ax
def set_colorbar(ax, inset_dict={}):
    """https://matplotlib.org/3.1.0/gallery/axes_grid1/demo_colorbar_with_inset_locator.html"""
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes

    if len(inset_dict) == 0:
        # see more at https://matplotlib.org/gallery/axes_grid1/inset_locator_demo.html
        axins = inset_axes(
            ax,
            width="12%",  # width = 12% of parent_bbox width
            height="100%",  # height = 100% of parent_bbox height
            loc="upper right",
            bbox_to_anchor=(0.85, 0.97, 0.145, 0.17),
            bbox_transform=ax.transAxes,
            borderpad=1.85,
        )
    else:
        axins = inset_axes(ax, bbox_transform=ax.transAxes, **inset_dict)

    return axins
def arrowed_spines(ax, columns, background="white"):
    """https://stackoverflow.com/questions/33737736/matplotlib-axis-arrow-tip
    modified based on Answer 6
    """
    if type(columns) == str:
        columns = [columns.upper() + " 0", columns.upper() + " 1"]

    import matplotlib.pyplot as plt

    fig = plt.gcf()

    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()

    # removing the default axis on all sides:
    despline_all(ax)
    # removing the axis ticks
    deaxis_all(ax)

    # get width and height of axes object to compute
    # matching arrowhead length and width
    dps = fig.dpi_scale_trans.inverted()
    bbox = ax.get_window_extent().transformed(dps)
    width, height = bbox.width, bbox.height

    # manual arrowhead width and length (x-axis)
    hw = 1.0 / 20.0 * (ymax - ymin)
    hl = 1.0 / 20.0 * (xmax - xmin)
    lw = 1.0  # axis line width
    ohg = 0.2  # arrow overhang

    # compute matching arrowhead length and width (y-axis)
    yhw = hw / (ymax - ymin) * (xmax - xmin) * height / width
    yhl = hl / (xmax - xmin) * (ymax - ymin) * width / height

    # draw x and y axis
    fc, ec = ("w", "w") if background in ["black", "#ffffff"] else ("k", "k")
    ax.arrow(
        xmin,
        ymin,
        hl * 5 / 2,
        0,
        fc=fc,
        ec=ec,
        lw=lw,
        head_width=hw / 2,
        head_length=hl / 2,
        overhang=ohg / 2,
        length_includes_head=True,
        clip_on=False,
    )
    ax.arrow(
        xmin,
        ymin,
        0,
        hw * 5 / 2,
        fc=fc,
        ec=ec,
        lw=lw,
        head_width=yhw / 2,
        head_length=yhl / 2,
        overhang=ohg / 2,
        length_includes_head=True,
        clip_on=False,
    )

    ax.text(
        xmin + hl * 2.5 / 2,
        ymin - 1.5 * hw / 2,
        columns[0],
        ha="center",
        va="center",
        rotation=0,
        # size=hl * 5 / (2 * len(str(columns[0]))) * 20,
        # size=matplotlib.rcParams['axes.titlesize'],
        size=np.clip((hl + yhw) * 8 / 2, 6, 18),
    )
    ax.text(
        xmin - 1.5 * yhw / 2,
        ymin + hw * 2.5 / 2,
        columns[1],
        ha="center",
        va="center",
        rotation=90,
        # size=hw * 5 / (2 * len(str(columns[1]))) * 20,
        # size=matplotlib.rcParams['axes.titlesize'],
        size=np.clip((hl + yhw) * 8 / 2, 6, 18),
    )

    return ax
# ---------------------------------------------------------------------------------------------------
# vector field plot related utilities


def quiver_autoscaler(X_emb, V_emb):
    """Function to automatically calculate the value for the scale parameter of quiver plot, adapted from scVelo

    Parameters
    ----------
        X_emb: `np.ndarray`
            X, Y-axis coordinates
        V_emb: `np.ndarray`
            Velocity (U, V) values on the X, Y-axis

    Returns
    -------
        The scale for quiver plot
    """
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    scale_factor = np.ptp(X_emb, 0).mean()
    X_emb = X_emb - X_emb.min(0)

    if len(V_emb.shape) == 3:
        Q = ax.quiver(
            X_emb[0] / scale_factor,
            X_emb[1] / scale_factor,
            V_emb[0],
            V_emb[1],
            angles="xy",
            scale_units="xy",
            scale=None,
        )
    else:
        Q = ax.quiver(
            X_emb[:, 0] / scale_factor,
            X_emb[:, 1] / scale_factor,
            V_emb[:, 0],
            V_emb[:, 1],
            angles="xy",
            scale_units="xy",
            scale=None,
        )

    Q._init()
    fig.clf()
    plt.close(fig)

    return Q.scale / scale_factor * 2


def default_quiver_args(arrow_size, arrow_len=None):
    if isinstance(arrow_size, (list, tuple)) and len(arrow_size) == 3:
        head_w, head_l, ax_l = arrow_size
    elif type(arrow_size) in [int, float]:
        head_w, head_l, ax_l = 10 * arrow_size, 12 * arrow_size, 8 * arrow_size
    else:
        head_w, head_l, ax_l = 10, 12, 8

    scale = 1 / arrow_len if arrow_len is not None else 1 / arrow_size

    return head_w, head_l, ax_l, scale
# ---------------------------------------------------------------------------------------------------


def _plot_traj(y0, t, args, integration_direction, ax, color, lw, f):
    from dynamo.tools.utils import integrate_vf

    _, y = integrate_vf(y0, t, args, integration_direction, f)  # integrate_vf_ivp

    ax.plot(*y.transpose(), color=color, lw=lw, linestyle="dashed", alpha=0.5)
    ax.scatter(*y0.transpose(), color=color, marker="*")

    return ax
# ---------------------------------------------------------------------------------------------------
# streamline related aesthetics
# ---------------------------------------------------------------------------------------------------


def set_arrow_alpha(ax=None, alpha=1):
    from matplotlib import patches

    ax = plt.gca() if ax is None else ax

    # iterate through the children of ax
    for art in ax.get_children():
        # we are only interested in FancyArrowPatches
        if not isinstance(art, patches.FancyArrowPatch):
            continue
        art.set_alpha(alpha)


def set_stream_line_alpha(s=None, alpha=1):
    """s has to be a StreamplotSet"""
    s.lines.set_alpha(alpha)
# ---------------------------------------------------------------------------------------------------
# save_fig figure related
# ---------------------------------------------------------------------------------------------------


def save_fig(
    path=None,
    prefix=None,
    dpi=None,
    ext="pdf",
    transparent=True,
    close=True,
    verbose=True,
):
    """Save a figure from pyplot.

    code adapted from http://www.jesshamrick.com/2012/09/03/saving-figures-from-pyplot/

    Parameters
    ----------
        path: `string`
            The path (and filename, without the extension) to save_fig the figure to.
        prefix: `str` or `None`
            The prefix added to the figure name. This will be automatically set accordingly to the plotting function
            used.
        dpi: [ None | scalar > 0 | 'figure' ]
            The resolution in dots per inch. If None, defaults to rcParams["savefig.dpi"]. If 'figure', uses the
            figure's dpi value.
        ext: `string` (default='pdf')
            The file extension. This must be supported by the active matplotlib backend (see matplotlib.backends
            module). Most backends support 'png', 'pdf', 'ps', 'eps', and 'svg'.
        transparent: `boolean` (default=True)
            Whether the saved figure should have a transparent background; passed through to `plt.savefig`.
        close: `boolean` (default=True)
            Whether to close the figure after saving. If you want to save_fig the figure multiple times (e.g., to
            multiple formats), you should NOT close it in between saves or you will have to re-plot it.
        verbose: boolean (default=True)
            Whether to print information about when and where the image has been saved.
    """
    import matplotlib.pyplot as plt

    if path is None:
        path = os.getcwd() + "/"

    # Extract the directory and filename from the given path
    directory = os.path.split(path)[0]
    filename = os.path.split(path)[1]
    if directory == "":
        directory = "."
    if filename == "":
        filename = "spateo_savefig"

    # If the directory does not exist, create it
    if not os.path.exists(directory):
        os.makedirs(directory)

    # The final path to save_fig to
    savepath = (
        os.path.join(directory, filename + "." + ext)
        if prefix is None
        else os.path.join(directory, prefix + "_" + filename + "." + ext)
    )

    if verbose:
        print(f"Saving figure to {savepath}...")

    # Actually save the figure
    plt.savefig(
        savepath,
        dpi=300 if dpi is None else dpi,
        transparent=transparent,
        format=ext,
        bbox_inches="tight",
    )

    # Close it
    if close:
        plt.close()

    if verbose:
        print("Done")
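# Example (illustrative sketch, not part of the original module; the path below is hypothetical):
#     >>> plt.plot([0, 1], [0, 1])
#     >>> save_fig(path="./figures/my_plot", prefix="scatter", ext="png", dpi=150)
#     # writes ./figures/scatter_my_plot.png and closes the figure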
# ---------------------------------------------------------------------------------------------------


def alpha_shape(x, y, alpha):
    # Start Using SHAPELY
    try:
        import shapely.geometry as geometry
        from shapely.geometry import MultiPoint
        from shapely.ops import cascaded_union, polygonize
    except ImportError:
        raise ImportError(
            "If you want to use tricontourf in the plotting function, you need to install the `shapely` "
            "package via `pip install shapely`. See more details at https://pypi.org/project/Shapely/."
        )

    from scipy.spatial import Delaunay

    crds = np.array([x.flatten(), y.flatten()]).transpose()
    points = MultiPoint(crds)

    if len(points) < 4:
        # When you have a triangle, there is no sense
        # in computing an alpha shape.
        return geometry.MultiPoint(list(points)).convex_hull

    def add_edge(edges, edge_points, coords, i, j):
        """Add a line between the i-th and j-th points, if not in the list already"""
        if (i, j) in edges or (j, i) in edges:
            # already added
            return
        edges.add((i, j))
        edge_points.append(coords[[i, j]])

    coords = np.array([point.coords[0] for point in points])

    tri = Delaunay(coords)
    edges = set()
    edge_points = []
    # loop over triangles:
    # ia, ib, ic = indices of corner points of the triangle
    for ia, ib, ic in tri.vertices:
        pa = coords[ia]
        pb = coords[ib]
        pc = coords[ic]

        # Lengths of sides of triangle
        a = math.sqrt((pa[0] - pb[0]) ** 2 + (pa[1] - pb[1]) ** 2)
        b = math.sqrt((pb[0] - pc[0]) ** 2 + (pb[1] - pc[1]) ** 2)
        c = math.sqrt((pc[0] - pa[0]) ** 2 + (pc[1] - pa[1]) ** 2)

        # Semiperimeter of triangle
        s = (a + b + c) / 2.0

        # Area of triangle by Heron's formula
        area = math.sqrt(s * (s - a) * (s - b) * (s - c))
        circum_r = a * b * c / (4.0 * area)

        # Here's the radius filter.
        if circum_r < 1.0 / alpha:
            add_edge(edges, edge_points, coords, ia, ib)
            add_edge(edges, edge_points, coords, ib, ic)
            add_edge(edges, edge_points, coords, ic, ia)

    m = geometry.MultiLineString(edge_points)
    triangles = list(polygonize(m))

    return cascaded_union(triangles), edge_points
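# Example (illustrative sketch, not part of the original module; the random point cloud is hypothetical):
#     >>> xs, ys = np.random.rand(200), np.random.rand(200)
#     >>> concave_hull, edge_points = alpha_shape(xs, ys, alpha=2.0)
#     >>> plot_polygon(concave_hull)  # visualize; tune `alpha` if the hull looks too loose or too tight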
# View the polygon and adjust alpha if needed


def plot_polygon(polygon, margin=1, fc="#999999", ec="#000000", fill=True, ax=None, **kwargs):
    try:
        from descartes.patch import PolygonPatch
    except ImportError:
        raise ImportError(
            "If you want to use tricontourf in the plotting function, you need to install the `descartes` "
            "package via `pip install descartes`. See more details at https://pypi.org/project/descartes/."
        )

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)

    x_min, y_min, x_max, y_max = polygon.bounds
    ax.set_xlim([x_min - margin, x_max + margin])
    ax.set_ylim([y_min - margin, y_max + margin])
    patch = PolygonPatch(polygon, fc=fc, ec=ec, fill=fill, zorder=-1, lw=3, alpha=0.4, **kwargs)
    ax.add_patch(patch)

    return ax
# ---------------------------------------------------------------------------------------------------
# the following Loess class is taken from:
# link: https://github.com/joaofig/pyloess/blob/master/pyloess/Loess.py


def tricubic(x):
    y = np.zeros_like(x)
    idx = (x >= -1) & (x <= 1)
    y[idx] = np.power(1.0 - np.power(np.abs(x[idx]), 3), 3)
    return y


class Loess(object):
    @staticmethod
    def normalize_array(array):
        min_val = np.min(array)
        max_val = np.max(array)
        return (array - min_val) / (max_val - min_val), min_val, max_val

    def __init__(self, xx, yy, degree=1):
        self.n_xx, self.min_xx, self.max_xx = self.normalize_array(xx)
        self.n_yy, self.min_yy, self.max_yy = self.normalize_array(yy)
        self.degree = degree

    @staticmethod
    def get_min_range(distances, window):
        min_idx = np.argmin(distances)
        n = len(distances)
        if min_idx == 0:
            return np.arange(0, window)
        if min_idx == n - 1:
            return np.arange(n - window, n)

        min_range = [min_idx]
        while len(min_range) < window:
            i0 = min_range[0]
            i1 = min_range[-1]
            if i0 == 0:
                min_range.append(i1 + 1)
            elif i1 == n - 1:
                min_range.insert(0, i0 - 1)
            elif distances[i0 - 1] < distances[i1 + 1]:
                min_range.insert(0, i0 - 1)
            else:
                min_range.append(i1 + 1)
        return np.array(min_range)

    @staticmethod
    def get_weights(distances, min_range):
        max_distance = np.max(distances[min_range])
        weights = tricubic(distances[min_range] / max_distance)
        return weights

    def normalize_x(self, value):
        return (value - self.min_xx) / (self.max_xx - self.min_xx)

    def denormalize_y(self, value):
        return value * (self.max_yy - self.min_yy) + self.min_yy

    def estimate(self, x, window, use_matrix=False, degree=1):
        n_x = self.normalize_x(x)
        distances = np.abs(self.n_xx - n_x)
        min_range = self.get_min_range(distances, window)
        weights = self.get_weights(distances, min_range)

        if use_matrix or degree > 1:
            wm = np.multiply(np.eye(window), weights)
            xm = np.ones((window, degree + 1))

            xp = np.array([[math.pow(n_x, p)] for p in range(degree + 1)])
            for i in range(1, degree + 1):
                xm[:, i] = np.power(self.n_xx[min_range], i)

            ym = self.n_yy[min_range]
            xmt_wm = np.transpose(xm) @ wm
            beta = np.linalg.pinv(xmt_wm @ xm) @ xmt_wm @ ym
            y = (beta @ xp)[0]
        else:
            xx = self.n_xx[min_range]
            yy = self.n_yy[min_range]
            sum_weight = np.sum(weights)
            sum_weight_x = np.dot(xx, weights)
            sum_weight_y = np.dot(yy, weights)
            sum_weight_x2 = np.dot(np.multiply(xx, xx), weights)
            sum_weight_xy = np.dot(np.multiply(xx, yy), weights)

            mean_x = sum_weight_x / sum_weight
            mean_y = sum_weight_y / sum_weight

            b = (sum_weight_xy - mean_x * mean_y * sum_weight) / (sum_weight_x2 - mean_x * mean_x * sum_weight)
            a = mean_y - b * mean_x
            y = a + b * n_x
        return self.denormalize_y(y)
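# Example (illustrative sketch, not part of the original module; the noisy sine data is hypothetical):
#     >>> xs = np.linspace(0, 2 * np.pi, 100)
#     >>> ys = np.sin(xs) + np.random.normal(0, 0.2, size=100)
#     >>> loess = Loess(xs, ys)
#     >>> smoothed = np.array([loess.estimate(x, window=15, degree=1) for x in xs])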
def _convert_to_geo_dataframe(adata, basis):
    # convert to AnnData with GeoDataFrame as obs
    adata.obs[basis] = pd.Series(adata.obsm[basis]).apply(loads, hex=True).values
    adata.obs = gpd.GeoDataFrame(adata.obs, geometry=basis)
    return adata
def save_return_show_fig_utils(
    save_show_or_return: Literal["save", "show", "return", "both", "all"],
    show_legend: bool,
    background: str,
    prefix: str,
    save_kwargs: Dict,
    total_panels: int,
    fig: matplotlib.figure.Figure,
    axes: matplotlib.axes.Axes,
    return_all: bool,
    return_all_list: Union[List, Tuple, None],
) -> Optional[Tuple]:
    from ...configuration import reset_rcParams
    from ...tools.utils import update_dict

    if show_legend:
        plt.subplots_adjust(right=0.85)

    if save_show_or_return in ["save", "both", "all"]:
        s_kwargs = {
            "path": None,
            "prefix": prefix,
            "dpi": None,
            "ext": "pdf",
            "transparent": True,
            "close": True if save_show_or_return == "save" else False,
            "verbose": True,
        }
        s_kwargs = update_dict(s_kwargs, save_kwargs)

        save_fig(**s_kwargs)
        if background is not None:
            reset_rcParams()

    if save_show_or_return in ["show", "both", "all"]:
        # with warnings.catch_warnings():
        #     warnings.simplefilter("ignore")
        #     plt.tight_layout()
        plt.show()
        if background is not None:
            reset_rcParams()

    if save_show_or_return in ["return", "all"]:
        if background is not None:
            reset_rcParams()

        if return_all:
            return (fig, *return_all_list) if total_panels > 1 else (fig, *return_all_list)
        else:
            return (fig, axes) if total_panels > 1 else (fig, axes)
# ---------------------------------------------------------------------------------------------------
# for plotting: subset and reorder data array
# ---------------------------------------------------------------------------------------------------


def _get_array_values(
    X: Union[np.ndarray, scipy.sparse.base.spmatrix],
    dim_names: pd.Index,
    keys: List[str],
    axis: Literal[0, 1],
    backed: bool,
):
    """Subset and reorder data array, given array and corresponding array index.

    Args:
        X: np.ndarray or scipy sparse matrix to subset
        dim_names: pd.Index containing the names along the dimension given by 'axis'
        keys: Index names to subset to
        axis: Subset rows or columns of 'X' (0 for rows, 1 for columns)
        backed: Interfaces w/ AnnData objects; is True if AnnData is backed to disk

    Returns:
        matrix: Subsetted np.ndarray (densified if the input was sparse)
    """
    mutable_idxer = [slice(None), slice(None)]
    idx = dim_names.get_indexer(keys)

    if backed:
        idx_order = np.argsort(idx)
        rev_idxer = mutable_idxer.copy()
        mutable_idxer[axis] = idx[idx_order]
        rev_idxer[axis] = np.argsort(idx_order)
        matrix = X[tuple(mutable_idxer)][tuple(rev_idxer)]
    else:
        mutable_idxer[axis] = idx
        matrix = X[tuple(mutable_idxer)]

    from scipy.sparse import issparse

    if issparse(matrix):
        matrix = matrix.toarray()

    return matrix
# ---------------------------------------------------------------------------------------------------
# for plotting: generating object to map from feature magnitudes to color intensities
# ---------------------------------------------------------------------------------------------------


def check_colornorm(
    vmin: Union[None, float] = None,
    vmax: Union[None, float] = None,
    vcenter: Union[None, float] = None,
    norm: Union[None, matplotlib.colors.Normalize] = None,
):
    """When plotting continuous variables, configure a normalizer object for the purposes of mapping the data to
    varying color intensities.

    Args:
        vmin: optional float
            The data value that defines 0.0 in the normalization. Defaults to the min value of the dataset.
        vmax: optional float
            The data value that defines 1.0 in the normalization. Defaults to the max value of the dataset.
        vcenter: optional float
            The data value that defines 0.5 in the normalization
        norm: optional `matplotlib.colors.Normalize` object
            Optional already-initialized normalizing object that scales data, typically into the interval [0, 1],
            for the purposes of mapping to color intensities for plotting. Do not pass both 'norm' and
            'vmin'/'vmax', etc.

    Returns:
        normalize: `matplotlib.colors.Normalize` object
            The normalizing object that scales data, typically into the interval [0, 1], for the purposes of
            mapping to color intensities for plotting.
    """
    from matplotlib.colors import Normalize

    try:
        from matplotlib.colors import TwoSlopeNorm as DivNorm
    except ImportError:
        from matplotlib.colors import DivergingNorm as DivNorm

    if norm is not None:
        if (vmin is not None) or (vmax is not None) or (vcenter is not None):
            raise ValueError("Passing both norm and vmin/vmax/vcenter is not allowed.")
    else:
        if vcenter is not None:
            norm = DivNorm(vmin=vmin, vmax=vmax, vcenter=vcenter)
        else:
            norm = Normalize(vmin=vmin, vmax=vmax)

    return norm
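# Example (illustrative sketch, not part of the original module): build a diverging normalizer centered at zero
# for a log-fold-change-like array (values below are hypothetical).
#     >>> lfc = np.array([-2.0, -0.5, 0.0, 1.5, 3.0])
#     >>> norm = check_colornorm(vmin=lfc.min(), vmax=lfc.max(), vcenter=0.0)
#     >>> plt.scatter(np.arange(len(lfc)), lfc, c=lfc, cmap="RdBu_r", norm=norm)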
# ---------------------------------------------------------------------------------------------------
# for plotting: ensure no duplicate keyword arguments
# ---------------------------------------------------------------------------------------------------


def deduplicate_kwargs(kwargs_dict, **kwargs):
    """Given a dictionary of plot parameters (kwargs_dict) and any number of additional keyword arguments, merge the
    parameters into a single consolidated dictionary to avoid argument duplication errors.

    If kwargs_dict contains a key that matches any of the additional keyword arguments, only the value in kwargs_dict
    is kept.

    Args:
        kwargs_dict: dict
            Each key is a variable name and each value is the value of that variable
        kwargs:
            Any additional keyword arguments, the keywords of which may or may not already be in 'kwargs_dict'
    """
    kwargs.update(kwargs_dict)
    return kwargs
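# Example (illustrative sketch, not part of the original module): a caller-supplied dict wins over hard-coded
# defaults, so a downstream `ax.scatter(**merged)` never sees the same keyword twice.
#     >>> merged = deduplicate_kwargs({"s": 20, "alpha": 0.8}, s=5, cmap="viridis")
#     >>> merged
#     {'s': 20, 'cmap': 'viridis', 'alpha': 0.8}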
# ---------------------------------------------------------------------------------------------------
# Dendrogram and utilities for dendrogram generation
# ---------------------------------------------------------------------------------------------------


def _dendrogram_sig(data: np.ndarray, method: str, **kwargs) -> Tuple[List[int], List[int], List[int], List[int]]:
    sch_linkage_params = {k for k in signature(sch.linkage).parameters.keys()}
    sch_dendro_params = {k for k in signature(sch.dendrogram).parameters.keys()}

    # Extract the kwargs that correspond to each function:
    link_kwargs = {k: v for k, v in kwargs.items() if k in sch_linkage_params}
    dendro_kwargs = {k: v for k, v in kwargs.items() if k in sch_dendro_params}

    # Row cluster:
    row_link = sch.linkage(data, method=method, **link_kwargs)
    row_dendro = sch.dendrogram(row_link, no_plot=True, **dendro_kwargs)
    row_order = row_dendro["leaves"]

    # Column cluster:
    col_link = sch.linkage(np.transpose(data), method=method, **link_kwargs)
    col_dendro = sch.dendrogram(col_link, no_plot=True, **dendro_kwargs)
    col_order = col_dendro["leaves"]

    return row_order, col_order, row_link, col_link
@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata")
def dendrogram(
    adata: AnnData,
    cat_key: str,
    n_pcs: int = 30,
    use_rep: Union[None, str] = None,
    var_names: Union[None, List[str]] = None,
    cor_method: str = "pearson",
    linkage_method: str = "complete",
    optimal_ordering: bool = False,
    key_added: Union[None, str] = None,
    inplace: bool = True,
) -> Optional[Dict[str, Any]]:
    """Computes a hierarchical clustering for the categories given by 'cat_key'.

    By default, the PCA representation is used unless `.X` has less than 50 variables. Alternatively, a list of
    `var_names` (e.g. genes) can be given. If this is the case, will subset to these features and use them for the
    dendrogram.

    Args:
        adata: object of class `anndata.AnnData`
        cat_key: Name of key in .obs specifying group labels for each sample
        n_pcs: Number of principal components to use in computing hierarchical clustering
        use_rep: Entry in .obsm to use for computing hierarchical clustering
        var_names: List of genes to define a subset of 'adata' to compute hierarchical clustering directly on
            expression values.
        cor_method: Correlation method to use. Options are 'pearson', 'kendall', and 'spearman'
        linkage_method: Linkage method to use. See :func:`scipy.cluster.hierarchy.linkage` for more information.
        optimal_ordering: Same as the optimal_ordering argument of :func:`scipy.cluster.hierarchy.linkage` which
            reorders the linkage matrix so that the distance between successive leaves is minimal.
        key_added: Sets key in .uns in which dendrogram information is saved. By default, the dendrogram information
            is added to `.uns[f'dendrogram_{cat_key}']`.
        inplace: If `True`, adds dendrogram information to `adata.uns[key_added]`, else this function returns the
            information.

    Returns:
        If `inplace=False`, returns dendrogram information, else adata object is updated in place with information
        stored in `adata.uns[key_added]`.
    """
    logger = lm.get_main_logger()

    if not isinstance(cat_key, list):
        cat_key = [cat_key]

    # For each category label given in 'cat_key':
    for cat in cat_key:
        if cat not in adata.obs_keys():
            logger.error(
                "'cat_key' has to be a valid observation. "
                f"Given value: {cat}, valid observations: {adata.obs_keys()}"
            )
        # check the individual column (the original passed the full key list here)
        if not is_categorical_dtype(adata.obs[cat]):
            logger.error(
                "'cat_key' has to be a categorical observation. "
                f"Given value: {cat}, Column type: {adata.obs[cat].dtype}"
            )

    if var_names is None:
        # Choose representation to use for hierarchical clustering:
        if use_rep is None and n_pcs == 0:
            use_rep = "X"

        if use_rep is None:
            if adata.n_vars > n_pcs:
                if "X_pca" in adata.obsm.keys():
                    if n_pcs is not None and n_pcs > adata.obsm["X_pca"].shape[1]:
                        logger.error("Existing 'X_pca' does not have enough PCs.")
                    X = adata.obsm["X_pca"][:, :n_pcs]
                    logger.info(f"Using 'X_pca' with n_pcs = {X.shape[1]} to compute dendrogram...")
                else:
                    logger.warning(
                        "'n_pcs' was provided, but 'X_pca' does not already exist. If you meant to use "
                        "gene expression, set 'use_rep' = 'X' or 'n_pcs' = 0. For now, will proceed with "
                        "computing PCA representation and using rep 'X_pca'."
                    )
                    pca = PCA(
                        n_components=min(n_pcs, adata.X.shape[1] - 1),
                        svd_solver="arpack",
                        random_state=0,
                    )
                    fit = pca.fit(adata.X.toarray()) if scipy.sparse.issparse(adata.X) else pca.fit(adata.X)
                    X_pca = (
                        fit.transform(adata.X.toarray()) if scipy.sparse.issparse(adata.X) else fit.transform(adata.X)
                    )
                    adata.obsm["X_pca"] = X_pca
                    # use the newly computed PCs downstream (the original left `X` unassigned here)
                    X = adata.obsm["X_pca"][:, :n_pcs]
            else:
                logger.info("Using data matrix X directly")
                X = adata.X.toarray() if scipy.sparse.issparse(adata.X) else adata.X
        else:
            if use_rep in adata.obsm.keys() and n_pcs is not None:
                if n_pcs > adata.obsm[use_rep].shape[1]:
                    logger.error(
                        f"{use_rep} does not have enough dimensions. Provide a representation with equal or more "
                        f"dimensions than 'n_pcs' or lower 'n_pcs'."
                    )
                X = adata.obsm[use_rep][:, :n_pcs]
            elif use_rep in adata.obsm.keys() and n_pcs is None:
                X = adata.obsm[use_rep]
            elif use_rep == "X":
                X = adata.X.toarray() if scipy.sparse.issparse(adata.X) else adata.X
            else:
                logger.error("Did not find {} in `.obsm.keys()`. Needs to be computed first.".format(use_rep))

        rep_df = pd.DataFrame(X)
        categorical = adata.obs[cat_key[0]]
        # If multiple category keys are given, create new categories by merging their combinations:
        if len(cat_key) > 1:
            for cat in cat_key[1:]:
                categorical = (categorical.astype(str) + "_" + adata.obs[cat].astype(str)).astype("category")
            categorical.name = "_".join(cat_key)

        rep_df.set_index(categorical, inplace=True)
        categories = rep_df.index.categories
    else:
        gene_names = adata.var_names
        from .dotplot import adata_to_frame

        categories, rep_df = adata_to_frame(adata, gene_names, cat_key)

    # Aggregate values within categories using "mean":
    mean_df = rep_df.groupby(level=0).mean()

    corr_matrix = mean_df.T.corr(method=cor_method)
    corr_condensed = distance.squareform(1 - corr_matrix)
    z_var = sch.linkage(corr_condensed, method=linkage_method, optimal_ordering=optimal_ordering)
    dendro_info = sch.dendrogram(z_var, labels=list(categories), no_plot=True)

    dat = dict(
        linkage=z_var,
        cat_key=cat_key,
        use_rep=use_rep,
        cor_method=cor_method,
        linkage_method=linkage_method,
        categories_ordered=dendro_info["ivl"],
        categories_idx_ordered=dendro_info["leaves"],
        dendrogram_info=dendro_info,
        correlation_matrix=corr_matrix.values,
    )

    if inplace:
        if key_added is None:
            key_added = f'dendrogram_{"_".join(cat_key)}'
        logger.info_insert_adata(key_added, adata_attr="uns")
        adata.uns[key_added] = dat
    else:
        return dat
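# Example (illustrative sketch, not part of the original module; assumes `adata` has a categorical .obs column
# named "cell_type" and, optionally, a precomputed "X_pca"):
#     >>> dendrogram(adata, cat_key="cell_type", n_pcs=30)
#     >>> adata.uns["dendrogram_cell_type"]["categories_ordered"]  # leaf order of the categories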
def plot_dendrogram(
    dendro_ax: matplotlib.axes.Axes,
    adata: AnnData,
    cat_key: str,
    dendrogram_key: Union[None, str] = None,
    orientation: Literal["top", "bottom", "left", "right"] = "right",
    remove_labels: bool = True,
    ticks: Union[None, Collection[float]] = None,
):
    """Plots dendrogram on the provided Axes, using the dendrogram information stored in `.uns[dendrogram_key]`

    Args:
        dendro_ax: object of class `matplotlib.axes.Axes`
        adata: object of class `anndata.AnnData`. Contains dendrogram information as well as the data that will be
            plotted (and was used to hierarchically cluster)
        cat_key: Key in .obs containing category labels for all samples
        dendrogram_key: Key in .uns containing the dendrogram information; defaults to 'dendrogram_{cat_key}'
        orientation: Specifies dendrogram placement relative to the plotting window. Options: 'top', 'bottom',
            'left', 'right'
        remove_labels: Removes labels along the side that dendrogram is on, if any
        ticks: Assumes original ticks come from `scipy.cluster.hierarchy.dendrogram`, but if not can also pass a
            list of custom tick values.
    """
    logger = lm.get_main_logger()

    # Get dendrogram key:
    if not isinstance(dendrogram_key, str):
        if isinstance(cat_key, str):
            dendrogram_key = f"dendrogram_{cat_key}"
        elif isinstance(cat_key, list):
            dendrogram_key = f'dendrogram_{"_".join(cat_key)}'

    if dendrogram_key not in adata.uns:
        logger.warning(
            f"Dendrogram data not found (using key={dendrogram_key}). Running :func:`st.pl.dendrogram` with "
            f"default parameters. For fine tuning it is recommended to run `st.pl.dendrogram` independently."
        )
        dendrogram(adata, cat_key, key_added=dendrogram_key)

    if "dendrogram_info" not in adata.uns[dendrogram_key]:
        raise ValueError(
            f"The given dendrogram key ({dendrogram_key!r}) does not contain valid dendrogram information."
        )

    def translate_pos(pos_list: List[float], new_ticks: List[int], old_ticks: Union[np.ndarray, List[int]]):
        """Transforms the dendrogram coordinates to a given new position.

        The xlabel_pos ('pos_list') and orig_ticks ('old_ticks') should be of the same length. This is mostly done
        for the heatmap case, where the position of the dendrogram leaves needs to be adjusted depending on the
        category size.

        Args:
            pos_list: List of dendrogram positions that should be translated
            new_ticks: Sorted list of desired tick positions (e.g. [0, 1, 2, 3])
            old_ticks: Sorted list of original tick positions (e.g. [5, 15, 25, 35]). This list is usually the
                default position used by `scipy.cluster.hierarchy.dendrogram`.

        Returns:
            new_xs: Translated list of positions
        """
        if not isinstance(old_ticks, list):
            old_ticks = old_ticks.tolist()
        new_xs = []
        for x_val in pos_list:
            if x_val in old_ticks:
                new_x_val = new_ticks[old_ticks.index(x_val)]
            else:
                # find smaller and bigger indices
                idx_next = np.searchsorted(old_ticks, x_val, side="left")
                idx_prev = idx_next - 1
                old_min = old_ticks[idx_prev]
                old_max = old_ticks[idx_next]
                new_min = new_ticks[idx_prev]
                new_max = new_ticks[idx_next]
                new_x_val = ((x_val - old_min) / (old_max - old_min)) * (new_max - new_min) + new_min
            new_xs.append(new_x_val)
        return new_xs

    dendro_info = adata.uns[dendrogram_key]["dendrogram_info"]
    leaves = dendro_info["ivl"]
    icoord = np.array(dendro_info["icoord"])
    dcoord = np.array(dendro_info["dcoord"])

    orig_ticks = np.arange(5, len(leaves) * 10 + 5, 10).astype(float)

    # check that ticks has the same length as orig_ticks
    if ticks is not None and len(orig_ticks) != len(ticks):
        logger.warning("'ticks' argument does not have the same size as orig_ticks. The argument will be ignored.")
        ticks = None

    for xs, ys in zip(icoord, dcoord):
        if ticks is not None:
            xs = translate_pos(xs, ticks, orig_ticks)
        if orientation in ["right", "left"]:
            xs, ys = ys, xs
        dendro_ax.plot(xs, ys, color="#555555")

    dendro_ax.tick_params(bottom=False, top=False, left=False, right=False)
    ticks = ticks if ticks is not None else orig_ticks

    if orientation in ["right", "left"]:
        dendro_ax.set_yticks(ticks)
        dendro_ax.set_yticklabels(leaves, fontsize="small", rotation=0)
        dendro_ax.tick_params(labelbottom=False, labeltop=False)
        if orientation == "left":
            xmin, xmax = dendro_ax.get_xlim()
            dendro_ax.set_xlim(xmax, xmin)
            dendro_ax.tick_params(labelleft=False, labelright=True)
    else:
        dendro_ax.set_xticks(ticks)
        dendro_ax.set_xticklabels(leaves, fontsize="small", rotation=90)
        dendro_ax.tick_params(labelleft=False, labelright=False)
        if orientation == "bottom":
            ymin, ymax = dendro_ax.get_ylim()
            dendro_ax.set_ylim(ymax, ymin)
            dendro_ax.tick_params(labeltop=True, labelbottom=False)

    if remove_labels:
        dendro_ax.tick_params(labelbottom=False, labeltop=False, labelleft=False, labelright=False)

    dendro_ax.grid(False)

    dendro_ax.spines["right"].set_visible(False)
    dendro_ax.spines["top"].set_visible(False)
    dendro_ax.spines["left"].set_visible(False)
    dendro_ax.spines["bottom"].set_visible(False)