"""
Dotplot class adapted from https://github.com/scverse/scanpy with modifications for suitability to cell-cell
communication and interaction analyses
Development notes: some of the methods mention dendrograms and other extra plots, but there is currently no capability
to generate them; this is planned for a future update. Additions that will have to be made: functions for
plot_dendrogram and plot_totals, and an additional if-condition in make_figure().
"""
import collections.abc as cabc
from collections import namedtuple
from itertools import product
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
from typing import Iterable, Optional, Sequence, Tuple, Union
import matplotlib as mpl
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from anndata import AnnData
from matplotlib import gridspec, rcParams
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import Normalize
from matplotlib.figure import Figure
from matplotlib.path import Path
from pandas.api.types import is_numeric_dtype
from ...configuration import SKM, config_spateo_rcParams, set_pub_style
from ...logging import logger_manager as lm
from .utils import (
_get_array_values,
check_colornorm,
deduplicate_kwargs,
plot_dendrogram,
save_return_show_fig_utils,
)
# --------------------------------------- Data conversion for plotting --------------------------------------- #
@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE)
def adata_to_frame(
    adata: AnnData,
    var_names: Sequence[str],
    cat_key: Union[None, str, Sequence[str]],
    num_categories: int = 7,
    layer: Union[None, str] = None,
    gene_symbols_key: Union[None, str] = None,
):
    """
    For the purposes of dot plotting, converts the information given in AnnData object to a dataframe in which the
    row names are categories defined by groups and column names correspond to variable names.

    Args:
        adata: class `anndata.AnnData`
        var_names: Should be a subset of adata.var_names (or of adata.var[gene_symbols_key] if that is given). A
            single string is treated as a one-element list.
        cat_key: The key(s) in .obs (or the name of the .obs index) of the grouping to consider. Should be a
            categorical observation; if a single numeric key is given, it will be subdivided into 'num_categories'
            intervals. If several keys are given, their values are joined with "_" to form combined categories. If
            None, all observations are placed in a single category named "".
        num_categories: Only used if groupby observation is not categorical. This value determines the number of
            groups into which the groupby observation should be subdivided.
        layer: Key in .layers specifying layer to use. If not given, will use .X.
        gene_symbols_key: Key in .var containing gene symbols. If given, 'var_names' are looked up in this column
            instead of in adata.var_names.

    Returns:
        categories: The categories of the grouping (the categories of the returned dataframe's index)
        adata_tidy_df: Dataframe with one row per observation, indexed by the (categorical) group label of each
            observation, with one column per entry in 'var_names'

    Raises:
        KeyError: If a requested key cannot be found in adata.obs or among the variable names / gene symbols, or if
            a requested gene symbol is duplicated in adata.var[gene_symbols_key].
    """
    logger = lm.get_main_logger()
    if isinstance(var_names, str):
        var_names = [var_names]

    # Grouping can use .obs columns and/or the .obs index; 'cat_index' records when the index is one of the keys.
    cat_index = None
    if cat_key is not None:
        # Work on a fresh list: the caller's object is never modified, and tuples are accepted as well as lists.
        cat_key = [cat_key] if isinstance(cat_key, str) else list(cat_key)
        index_name = adata.obs.index.name
        for group in cat_key:
            if group not in list(adata.obs_keys()) + [index_name]:
                msg = f' or index name "{index_name}"' if index_name is not None else ""
                logger.error(
                    f"Grouping key cannot be found. Given: {group}, could not be found in {adata.obs_keys()}" + msg
                )
            if group in adata.obs.keys() and group == index_name:
                logger.error(f"Given group {group} is both an index and a column level, which is ambiguous.")
            if group == index_name:
                cat_index = group
        if cat_index is not None:
            # The index is turned into a column further below, so it must not be looked up as an .obs/.var key:
            cat_key.remove(cat_index)

    # With no grouping (cat_key is None) only the variables need to be looked up:
    keys = (list(cat_key) if cat_key is not None else []) + list(np.unique(var_names))

    # Map the names used for lookup (gene symbols if 'gene_symbols_key' is given, else var_names) to var_names:
    if gene_symbols_key is not None:
        alias_index = pd.Index(adata.var[gene_symbols_key])
        alt_names = pd.Series(adata.var.index, index=alias_index)
        alt_search_repr = f"var['{alias_index.name}']"
    else:
        alias_index = None
        alt_names = pd.Series(adata.var.index, index=adata.var.index)
        alt_search_repr = "var_names"

    # Duplicated .obs column names would be duplicated again when selected, and duplicated var_names are ambiguous:
    if not adata.obs.columns.is_unique:
        dup_cols = adata.obs.columns[adata.obs.columns.duplicated()].tolist()
        logger.error(
            "adata.obs contains duplicated columns. Please rename or remove these columns first.\n"
            f"Duplicated columns: {dup_cols}"
        )
    if not adata.var.index.is_unique:
        logger.error(
            "adata.var_names contains duplicated items. \n"
            "Rename variable names first for example using `adata.var_names_make_unique()`."
        )

    # Sort each unique key by the location it was found in:
    obs_cols = []
    var_idx_keys = []
    var_symbols = []
    not_found = []
    for key in np.unique(keys):
        if key in adata.obs.columns:
            obs_cols.append(key)
            if key in alt_names.index:
                logger.error(f"The key '{key}' is found in both adata.obs and adata.{alt_search_repr}.")
        elif key in alt_names.index:
            val = alt_names[key]
            if isinstance(val, pd.Series):
                # var_names must be unique, but adata.var[gene_symbols_key] need not be; referring to a duplicated
                # symbol is ambiguous.
                assert alias_index is not None
                raise KeyError(f"Found duplicate entries for '{key}' in adata.{alt_search_repr}.")
            var_idx_keys.append(val)
            var_symbols.append(key)
        else:
            not_found.append(key)
    if not_found:
        raise KeyError(f"Could not find keys '{not_found}' in columns of `adata.obs` or in adata.{alt_search_repr}.")

    adata_tidy_df = pd.DataFrame(index=adata.obs_names)
    # Adding var values to DataFrame (columns are named by the lookup names, i.e. symbols if given):
    if var_idx_keys:
        adata_arr = adata.X if layer is None else adata.layers[layer]
        matrix = _get_array_values(
            adata_arr, dim_names=adata.var.index, keys=var_idx_keys, axis=1, backed=adata.isbacked
        )
        adata_tidy_df = pd.concat(
            [adata_tidy_df, pd.DataFrame(matrix, columns=var_symbols, index=adata.obs_names)],
            axis=1,
        )
    # Adding obs values to DataFrame:
    if obs_cols:
        adata_tidy_df = pd.concat([adata_tidy_df, adata.obs[obs_cols]], axis=1)
    # Reorder columns to the requested order:
    if keys:
        adata_tidy_df = adata_tidy_df[keys]
        assert np.all(np.array(keys) == np.array(adata_tidy_df.columns))

    if cat_index is not None:
        # Move the .obs index into a regular column so it can be used as a grouping key like any .obs column:
        adata_tidy_df.reset_index(inplace=True)
        cat_key.append(cat_index)

    if cat_key is None:
        # No grouping: every observation belongs to the same (unnamed) category.
        categorical = pd.Series(np.repeat("", len(adata_tidy_df))).astype("category")
    elif len(cat_key) == 1 and is_numeric_dtype(adata_tidy_df[cat_key[0]]):
        # If category column is not categorical, turn it into one by subdividing ranges of values into
        # 'num_categories' categories:
        categorical = pd.cut(adata_tidy_df[cat_key[0]], num_categories)
    elif len(cat_key) == 1:
        categorical = adata_tidy_df[cat_key[0]].astype("category")
        categorical.name = cat_key[0]
    else:
        # Join the category values using "_" to make a new 'category' consisting of all categorical columns:
        categorical = adata_tidy_df[cat_key].apply("_".join, axis=1).astype("category")
        categorical.name = "_".join(cat_key)
        # Order the combined categories by the product of the individual columns' category orders:
        order = {
            "_".join(k): idx for idx, k in enumerate(product(*(adata_tidy_df[g].cat.categories for g in cat_key)))
        }
        categorical = categorical.cat.reorder_categories(sorted(categorical.cat.categories, key=lambda x: order[x]))

    adata_tidy_df = adata_tidy_df[var_names].set_index(categorical)
    categories = adata_tidy_df.index.categories
    return categories, adata_tidy_df
# --------------------------------------- Initialize plotting grid --------------------------------------- #
# For multi-component plots drawn within the same plotting window: multiple Axes objects are used rather than
# separate subplots
def make_grid_spec(
    ax_or_figsize: Union[Tuple[int, int], mpl.axes.Axes],
    nrows: int,
    ncols: int,
    wspace: Optional[float] = None,
    hspace: Optional[float] = None,
    width_ratios: Optional[Sequence[float]] = None,
    height_ratios: Optional[Sequence[float]] = None,
) -> Tuple[Figure, gridspec.GridSpecBase]:
    """
    Initialize grid layout to place subplots within a figure environment.

    If a figure size is given, a new figure is created and a top-level GridSpec is returned. If an existing Axes is
    given, that Axes is hidden (axis off, no frame, no ticks) and a sub-grid is created within its subplot slot, so
    that new subplots occupy the space of the given Axes.

    Args:
        ax_or_figsize: Either an already-existing Axes object or the (width, height) of a new figure window. The size
            may be given as a tuple or a list (e.g. the value of rcParams["figure.figsize"]).
        nrows: Number of rows in the grid
        ncols: Number of columns in the grid
        wspace: The amount of width reserved for space between subplots, expressed as a fraction of the average axis
            width
        hspace: The amount of height reserved for space between subplots, expressed as a fraction of the average axis
            height
        width_ratios: Defines the relative widths of the columns. Each column gets a relative width of
            width_ratios[i] / sum(width_ratios). If not given, all columns will have the same width.
        height_ratios: Defines the relative heights of the rows. Each row gets a relative height of
            height_ratios[i] / sum(height_ratios). If not given, all rows will have the same height.

    Returns:
        fig: Instantiated Figure object (the new figure, or the figure that owns the given Axes)
        gs: Instantiated gridspec object
    """
    kw = dict(
        wspace=wspace,
        hspace=hspace,
        width_ratios=width_ratios,
        height_ratios=height_ratios,
    )
    # A figure size can arrive as a tuple or as a list (rcParams stores figure.figsize as a list):
    if isinstance(ax_or_figsize, (tuple, list)):
        fig = plt.figure(figsize=tuple(ax_or_figsize))
        gs = gridspec.GridSpec(nrows, ncols, **kw)
        return fig, gs

    ax = ax_or_figsize
    # The given Axes only reserves space; hide it so that only the subplots created in the sub-grid are visible:
    ax.axis("off")
    ax.set_frame_on(False)
    ax.set_xticks([])
    ax.set_yticks([])
    fig = ax.figure
    gs = ax.get_subplotspec().subgridspec(nrows, ncols, **kw)
    return fig, gs
# --------------------------------------- Dotplot class --------------------------------------- #
[docs]class Dotplot:
"""
Simultaneous visualization of two variates that are encoded by the dot size and the dot color. Size usually
represents the fraction of samples that have non-zero values, and color usually represents the magnitude of the
value.
Args:
adata: class `anndata.AnnData`
var_names: Should be a subset of adata.var_names
cat_key: The key(s) in .obs of the grouping to consider. Should be a categorical observation; if not,
will be subdivided into 'num_categories'.
num_categories: Only used if groupby observation is not categorical. This value determines the number of
groups into which the groupby observation should be subdivided.
categories_order: Sets order of categories given by 'cat_key' along the plotting axis
title: Sets title for figure window
figsize: The width and height of a figure
gene_symbols_key: Key in .var containing gene symbols
var_group_positions: Each item in the list should contain the start and end position that the bracket
should cover. Eg. [(0, 4), (5, 8)] means that there are two brackets, one for the var_names in positions 0-4
and other for positions 5-8.
var_group_labels: List of group labels for the variable names (e.g. can group var_names in positions 0-4
as being "group A")
var_group_rotation: Rotation in degrees of the variable name labels. If not given, small labels (<4
characters) are not rotated, but otherwise labels are rotated 90 degrees.
layer: Key in .layers specifying layer to use. If not given, will use .X.
expression_cutoff: Used for binarizing feature expression- feature is considered to be expressed only if
the expression value is greater than this threshold
mean_only_expressed: If True, gene expression is averaged only over the cells expressing the given features
standard_scale: Whether or not to standardize the color values between 0 and 1, meaning for each variable or
group, subtract the minimum and divide each by its maximum. 'var' or 'group' is used to specify whether this
should be done over variables or groups.
dot_color_df: Pre-prepared dataframe with categories as indices, features as columns, and values
corresponding to color intensities
dot_size_df: Pre-prepared dataframe with categories as indices, features as columns, and values
corresponding to dot sizes
ax: Can be used to provide pre-existing plotting axis
vmin: The data value that defines 0.0 in the normalization. Defaults to the min value of the dataset.
vmax: The data value that defines 1.0 in the normalization. Defaults to the max value of the dataset.
vcenter: The data value that defines 0.5 in the normalization
norm: Optional already-initialized normalizing object that scales data, typically into the interval [0, 1],
for the purposes of mapping to color intensities for plotting. Do not pass both 'norm' and
'vmin'/'vmax', etc.
**kwargs:
Additional arguments passed to `matplotlib.pyplot.scatter()`
"""
# Default parameters- visualization:
default_colormap = "winter"
default_color_on = "dot"
# None: the dot size bounds are not fixed up front (see `style` for how dot_max/dot_min are interpreted); the values
# actually used are stored back on the instance when the main plot is drawn.
default_dot_max = None
default_dot_min = None
default_smallest_dot = 0.0
default_largest_dot = 200.0
default_dot_edgecolor = "black"
default_dot_edgelw = 0.2
default_size_exponent = 1.5
default_size_legend_title = "Fraction of cells\nin group (%)"
default_color_legend_title = "Mean expression\nin group"
default_num_colorbar_ticks = 5
default_num_size_legend_dots = 5
# Base of the log scale for the size-legend "benchmark" dots (see `legend`).
# NOTE(review): this name is referenced but no definition was visible; 10 is an assumed value - confirm.
default_base = 10
default_legends_width = 1.5  # inches
default_plot_x_padding = 0.8  # a unit is equivalent to the distance between two x-axis ticks
default_plot_y_padding = 1.0  # a unit is equivalent to the distance between two y-axis ticks
# Default parameters- spacing:
default_category_height = 0.35
default_category_width = 0.37
# Space between main plot, dendrogram and legend:
# NOTE(review): this name is referenced but no definition was visible; 0 is an assumed value - confirm.
default_wspace = 0
@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE)
def __init__(
    self,
    adata: AnnData,
    var_names: Sequence[str],
    cat_key: Union[str, Sequence[str]],
    num_categories: int = 7,
    prescale_adata: bool = False,
    categories_order: Union[None, Sequence[str]] = None,
    title: Union[None, str] = None,
    figsize: Union[None, Tuple[float, float]] = None,
    gene_symbols_key: Union[None, str] = None,
    var_group_positions: Union[None, Sequence[Tuple[int, int]]] = None,
    var_group_labels: Union[None, Sequence[str]] = None,
    var_group_rotation: Union[None, float] = None,
    layer: Union[None, str] = None,
    expression_cutoff: float = 0.0,
    mean_only_expressed: bool = False,
    standard_scale: Optional[Literal["var", "group"]] = None,
    dot_color_df: Optional[pd.DataFrame] = None,
    dot_size_df: Optional[pd.DataFrame] = None,
    ax: Union[None, mpl.axes.Axes] = None,
    vmin: Union[None, float] = None,
    vmax: Union[None, float] = None,
    vcenter: Union[None, float] = None,
    norm: Optional[Normalize] = None,
    **kwargs,
):
    """
    Build the per-category dot size and dot color tables and initialize all style settings to their defaults.

    The parameters are described in the class docstring. Notes specific to construction:
        - Unless given, dot sizes are the fraction of observations per category whose value exceeds
          'expression_cutoff', and dot colors are the mean value per category (over expressing observations only if
          'mean_only_expressed' is True), optionally rescaled per variable or per group by 'standard_scale'.
        - 'ax', 'title' and any extra keyword arguments are stored on the instance ('self.ax', 'self.title',
          'self.kwargs'); the extra keyword arguments are forwarded when the main plot is drawn.
        - 'prescale_adata' is currently not used by this constructor.

    Raises:
        ValueError: If 'categories_order' does not contain exactly the categories found for 'cat_key'.
    """
    # Default plotting parameters:
    config_spateo_rcParams()
    set_pub_style()

    self.logger = lm.get_main_logger()
    self.cat_key = [cat_key] if isinstance(cat_key, str) else cat_key
    # A single feature name is accepted for convenience; the code below expects a sequence of names:
    self.var_names = [var_names] if isinstance(var_names, str) else var_names
    self.var_group_labels = var_group_labels
    self.var_group_positions = var_group_positions
    self.var_group_rotation = var_group_rotation
    self.has_var_groups = var_group_positions is not None and len(var_group_positions) > 0
    # Update variable names if given as a dictionary:
    self._update_var_groups()

    # Figure formatting:
    self.figsize = rcParams.get("figure.figsize") if figsize is None else figsize
    self.width, self.height = self.figsize
    # Keep the figure-level inputs; they are not otherwise recoverable after construction:
    self.ax = ax
    self.title = title
    # Extra keyword arguments, forwarded to the scatter call that draws the dots:
    self.kwargs = kwargs

    # Limit for the number of categories that are allowed to be plotted:
    self.max_num_categories = 100
    self.categories, self.adata_tidy_df = adata_to_frame(
        adata,
        self.var_names,
        cat_key=cat_key,
        num_categories=num_categories,
        layer=layer,
        gene_symbols_key=gene_symbols_key,
    )
    # Check categories:
    if len(self.categories) > self.max_num_categories:
        self.logger.warning(f"Over {self.max_num_categories} categories found. Plot would be very large.")
    if categories_order is not None:
        found_categories = set(self.adata_tidy_df.index.categories)
        if found_categories != set(categories_order):
            self.logger.error(
                "Please check that the categories given by the `categories_order` parameter match the categories "
                "to be reordered. \n\n"
                "Mismatch: "
                f"{found_categories.symmetric_difference(categories_order)}\n\n"
                f"Given order categories: {categories_order}\n\n"
                f"{cat_key} categories: {list(self.adata_tidy_df.index.categories)}\n"
            )
            # Stop here rather than leaving a half-initialized object without dot size/color tables:
            raise ValueError("`categories_order` does not match the categories found for `cat_key`.")

    # Boolean matrix: True where the value exceeds the expression cutoff.
    obs_bool = self.adata_tidy_df > expression_cutoff
    if dot_size_df is None:
        # Fraction of observations per category above the cutoff (count above cutoff / total count):
        dot_size_df = obs_bool.groupby(level=0).sum() / obs_bool.groupby(level=0).count()

    if dot_color_df is None:
        # Mean value per category, either only over expressing observations or over all observations:
        if mean_only_expressed:
            dot_color_df = self.adata_tidy_df.mask(~obs_bool).groupby(level=0).mean().fillna(0)
        else:
            dot_color_df = self.adata_tidy_df.groupby(level=0).mean()

        if standard_scale == "group":
            dot_color_df = dot_color_df.sub(dot_color_df.min(1), axis=0)
            dot_color_df = dot_color_df.div(dot_color_df.max(1), axis=0).fillna(0)
        elif standard_scale == "var":
            dot_color_df -= dot_color_df.min(0)
            dot_color_df = (dot_color_df / dot_color_df.max(0)).fillna(0)
        elif standard_scale is not None:
            self.logger.warning(
                "Unknown input given for 'standard_scale', proceeding without further " "processing array"
            )
    else:
        # Check that both matrices have the same shape:
        if dot_color_df.shape != dot_size_df.shape:
            self.logger.error(
                "The given dot_color_df data frame has a different shape than the data frame used for the dot "
                "size. Both data frames need to have the same index and columns."
            )
        # Remove duplicated features (can occur e.g. if the same gene is a marker for two groups):
        unique_var_names, unique_idx = np.unique(dot_color_df.columns, return_index=True)
        if len(unique_var_names) != len(self.var_names):
            dot_color_df = dot_color_df.iloc[:, unique_idx]

    # Use the same order for rows and columns in the color and size dataframes:
    dot_color_df = dot_color_df.loc[dot_size_df.index][dot_size_df.columns]
    self.dot_color_df = dot_color_df
    self.dot_size_df = dot_size_df

    # Initialize all style parameters to the default:
    self.cmap = self.default_colormap
    self.dot_max = self.default_dot_max
    self.dot_min = self.default_dot_min
    self.smallest_dot = self.default_smallest_dot
    self.largest_dot = self.default_largest_dot
    self.color_on = self.default_color_on
    self.size_exponent = self.default_size_exponent
    self.plot_x_padding = self.default_plot_x_padding
    self.plot_y_padding = self.default_plot_y_padding
    self.dot_edge_color = self.default_dot_edgecolor
    self.dot_edge_lw = self.default_dot_edgelw
    # Grid lines are off unless enabled through `style(grid=True)`:
    self.grid = False

    # Set legend defaults:
    self.color_legend_title = self.default_color_legend_title
    self.size_title = self.default_size_legend_title
    self.num_colorbar_ticks = self.default_num_colorbar_ticks
    self.num_size_legend_dots = self.default_num_size_legend_dots
    self.base = self.default_base
    self.legends_width = self.default_legends_width
    self.show_size_legend = True
    self.show_colorbar = True
    self.wspace = self.default_wspace

    # For plotting:
    VBoundNorm = namedtuple("VBoundNorm", ["vmin", "vmax", "vcenter", "norm"])
    self.vboundnorm = VBoundNorm(vmin=vmin, vmax=vmax, vcenter=vcenter, norm=norm)

    # Label order:
    self.are_axes_swapped = False
    self.categories_order = categories_order
    self.var_names_idx_order = None

    # For creating extra plots (e.g. a dendrogram configured through `add_dendrogram`); none by default:
    self.group_extra_size = 0
    self.plot_group_extra = None
    # Instantiate plotting variables- ax_dict will contain a dictionary of axes used in the plot (built on demand):
    self.ax_dict = None
def swap_axes(self, swap_axes: bool = True):
    """
    Modifies variables to flip x- and y-axes of dotplot.

    By default, the x axis contains 'var_names' (e.g. genes) and the y axis the groupby categories. Swapping makes
    the x-axis show the categories and the y-axis the variable names.

    Args:
        swap_axes: True (default) to use the swapped orientation, False to restore the default orientation. Calling
            the method again with the same value has no further effect.

    Returns:
        self (instance of class Dotplot)
    """
    swap_axes = bool(swap_axes)
    # Exchange the per-category cell dimensions only when the orientation actually changes. Otherwise, repeated calls
    # would flip the dimensions back while 'are_axes_swapped' stayed True.
    if swap_axes != self.are_axes_swapped:
        self.default_category_height, self.default_category_width = (
            self.default_category_width,
            self.default_category_height,
        )
    self.are_axes_swapped = swap_axes
    return self
def add_dendrogram(self, show: bool = True, dendrogram_key: Union[None, str] = None, size: float = 0.8):
    """
    Show dendrogram based on the hierarchical clustering between the `cat_key` categories. Categories are
    reordered to match the dendrogram order.

    The reordering is delegated to `reorder_categories_after_dendrogram`, which presumably computes or loads the
    dendrogram via `utils.dendrogram` (with default parameters if it was not called before) - confirm there.
    The dendrogram is by default shown on the right side of the plot or on top if the axes are swapped.

    Args:
        show: Boolean to turn on (True) or off (False) 'add_dendrogram'
        dendrogram_key: Needed if :func `utils.dendrogram` saved the dendrogram using a key different than the
            default name.
        size: Size of the dendrogram. Corresponds to width when dendrogram shown on the right of the plot,
            or height when shown on top. The unit is the same as in matplotlib (inches).

    Returns:
        self (instance of class Dotplot)
    """
    if not show:
        self.plot_group_extra = None
        return self

    if self.cat_key is None or len(self.categories) <= 2:
        # dendrogram can only be computed between groupby categories
        self.logger.warning(
            "Too few categories for dendrogram. Dendrogram is added only when the number of categories to plot > 2"
        )
        return self

    # To correctly plot dataframe, categories need to be reordered according to the dendrogram ordering:
    self.reorder_categories_after_dendrogram(dendrogram_key)

    # So that dendrogram "spines" are aligned with dotplot labels (tick centers at 0.5, 1.5, ...):
    dendro_ticks = np.arange(len(self.categories)) + 0.5
    self.group_extra_size = size
    self.plot_group_extra = {
        "kind": "dendrogram",
        "width": size,
        "dendrogram_key": dendrogram_key,
        "dendrogram_ticks": dendro_ticks,
    }
    return self
# To modify the style of the plot:
def style(
    self,
    cmap: str = default_colormap,
    color_on: Optional[Literal["dot", "square"]] = default_color_on,
    dot_max: Optional[float] = default_dot_max,
    dot_min: Optional[float] = default_dot_min,
    smallest_dot: Optional[float] = default_smallest_dot,
    largest_dot: Optional[float] = default_largest_dot,
    dot_edge_color: Optional[str] = default_dot_edgecolor,
    dot_edge_lw: Optional[float] = default_dot_edgelw,
    size_exponent: Optional[float] = default_size_exponent,
    grid: Optional[bool] = False,
    x_padding: Optional[float] = default_plot_x_padding,
    y_padding: Optional[float] = default_plot_y_padding,
):
    """
    Modifies visual aspects of the dot plot. Every setting is overwritten by the given (or default) value.

    Args:
        cmap: Name of Matplotlib color map to use
        color_on: Options are 'dot' or 'square'. By default the colormap is applied to the color of the dot,
            but 'square' changes this to be applied to a square region behind the dot, in which case the dot becomes
            transparent with only the edge of the dot shown.
        dot_max: If none, the maximum dot size is set to the maximum fraction value found (e.g. 0.6). If given,
            the value should be a number between 0 and 1. All fractions larger than dot_max are clipped to this value.
        dot_min: If none, the minimum dot size is set to 0. If given, the value should be a number between 0 and 1.
            All fractions smaller than dot_min are clipped to this value.
        smallest_dot: If none, the smallest dot has size 0. All expression fractions with `dot_min` are plotted with
            this size.
        largest_dot: If none, the largest dot has size 200. All expression fractions with `dot_max` are plotted with
            this size.
        dot_edge_color: Only used if 'color_on' is 'square'. Sets dot edge color
        dot_edge_lw: Only used if 'color_on' is 'square'. Sets dot edge line width
        size_exponent: Dot size is computed as:
            fraction ** size exponent
            and is afterwards scaled to match the 'smallest_dot' and 'largest_dot' size parameters.
            Using a different size exponent changes the relative sizes of the dots to each other.
        grid: Set to true to show grid lines. By default grid lines are not shown. Further configuration of the grid
            lines can be achieved directly on the returned ax.
        x_padding: Space between the plot left/right borders and the dots center. A unit is the distance between the
            x ticks. Only applied when 'color_on' = 'dot'
        y_padding: Space between the plot top/bottom borders and the dots center. A unit is the distance between the
            y ticks. Only applied when 'color_on' = 'dot'

    Returns:
        self (instance of class DotPlot)

    Example:
        Creating a modified dot plot (w/ a loaded AnnData object given name 'adata'):
        markers = ['C1QA', 'PSAP', 'CD79A', 'CD79B', 'CST3', 'LYZ']
        st.pl.DotPlot(adata, var_names=markers, cat_key='Celltype').style(cmap='RdBu_r', color_on='square').show()
    """
    # Assign every setting directly. Comparing against the current value first would add nothing, and it fails for
    # attributes that were never initialized (e.g. 'grid').
    self.cmap = cmap
    self.dot_max = dot_max
    self.dot_min = dot_min
    self.smallest_dot = smallest_dot
    self.largest_dot = largest_dot
    self.color_on = color_on
    self.size_exponent = size_exponent
    self.dot_edge_color = dot_edge_color
    self.dot_edge_lw = dot_edge_lw
    self.grid = grid
    self.plot_x_padding = x_padding
    self.plot_y_padding = y_padding
    return self
# Working with the plot legends:
def legend(
    self,
    show: bool = True,
    show_size_legend: bool = True,
    show_colorbar: bool = True,
    size_title: Optional[str] = default_size_legend_title,
    colorbar_title: Optional[str] = default_color_legend_title,
    base: Optional[int] = default_base,
    num_colorbar_ticks: Optional[int] = default_num_colorbar_ticks,
    num_size_legend_dots: Optional[int] = default_num_size_legend_dots,
    width: Optional[float] = default_legends_width,
):
    """
    Configure the colorbar and the dot size legend of the dotplot.

    Args:
        show: False hides all legends by giving the legend area zero width (the main plot becomes wider). When
            False, none of the other settings are applied.
        show_size_legend: False hides the dot size legend
        show_colorbar: False hides the colorbar
        size_title: Title shown above the dot size legend; use '\\n' for line breaks
        colorbar_title: Title shown above the colorbar; use '\\n' for line breaks
        base: Base of the log scale used to size the "benchmark" dots of the size legend
        num_colorbar_ticks: Number of ticks on the colorbar
        num_size_legend_dots: Number of "benchmark" dots in the dot size legend
        width: Width of the legend area, in inches (the matplotlib unit)

    Returns:
        self (instance of class Dotplot), so calls can be chained

    Example:
        Setting the colorbar title (w/ a loaded AnnData object given name 'adata'):
        markers = {'T-cell': 'CD3D', 'B-cell': 'CD79A', 'myeloid': 'CST3'}
        dp = Dotplot(adata, markers, cat_key='Celltype')
        dp.legend(colorbar_title='log(UMI counts + 1)')
    """
    if not show:
        # A zero-width legend area removes the legends and leaves the space to the main plot:
        self.legends_width = 0
        return self

    new_settings = {
        "color_legend_title": colorbar_title,
        "size_title": size_title,
        "base": base,
        "num_colorbar_ticks": num_colorbar_ticks,
        "num_size_legend_dots": num_size_legend_dots,
        "legends_width": width,
        "show_size_legend": show_size_legend,
        "show_colorbar": show_colorbar,
    }
    for attr_name, attr_value in new_settings.items():
        setattr(self, attr_name, attr_value)
    return self
def get_axes(self):
    """
    Return the dictionary of Axes that make up the plot, building the figure first if that has not happened yet.

    Returns:
        ax_dict: Dictionary of the plot's Axes objects (expected to be populated by `make_figure`)
    """
    # 'ax_dict' may not exist at all if the figure was never built; treat that the same as "not built yet":
    if getattr(self, "ax_dict", None) is None:
        self.make_figure()
    return self.ax_dict
def _plot_size_legend(self, size_legend_ax: mpl.axes.Axes):
    """
    Draw the dot size legend onto the given Axes.

    A row of gray "benchmark" dots is drawn, one per fraction value stepping down from 'self.dot_max' towards
    'self.dot_min'. The step depends on the spread between the two: 0.05 for a spread of at most 0.3, 0.1 for a
    spread in (0.3, 0.6], and 0.2 otherwise. Each benchmark is normalized to [0, 1] (unless the bounds already are
    0 and 1), raised to 'self.size_exponent', and mapped linearly onto ['self.smallest_dot', 'self.largest_dot'].
    Tick labels show the benchmark fractions as rounded percentages.

    Args:
        size_legend_ax: Matplotlib Axes to draw the legend onto
    """
    spread = self.dot_max - self.dot_min
    if spread <= 0.3:
        step = 0.05
    elif spread <= 0.6:
        step = 0.1
    else:
        step = 0.2

    # Build the range downward from dot_max and reverse it, so dot_max itself is always one of the benchmarks
    # (dot_min matters less and is often zero anyway):
    benchmarks = np.arange(self.dot_max, self.dot_min, -step)[::-1]
    if self.dot_min == 0 and self.dot_max == 1:
        normalized = benchmarks
    else:
        normalized = (benchmarks - self.dot_min) / (self.dot_max - self.dot_min)

    marker_sizes = normalized**self.size_exponent
    marker_sizes = marker_sizes * (self.largest_dot - self.smallest_dot) + self.smallest_dot
    x_positions = np.arange(len(marker_sizes)) + 0.5

    # Plot size bar
    size_legend_ax.scatter(
        x_positions,
        np.repeat(0, len(marker_sizes)),
        s=marker_sizes,
        color="gray",
        edgecolor="black",
        linewidth=self.dot_edge_lw,
        zorder=100,
    )
    size_legend_ax.set_xticks(x_positions)
    tick_labels = ["{}".format(np.round((value * 100), decimals=0).astype(int)) for value in benchmarks]
    size_legend_ax.set_xticklabels(tick_labels, fontsize=8)

    # Hide y ticks/labels and every surrounding spine:
    size_legend_ax.tick_params(axis="y", left=False, labelleft=False, labelright=False)
    for side in ("right", "top", "left", "bottom"):
        size_legend_ax.spines[side].set_visible(False)
    size_legend_ax.grid(False)

    # The title position is taken from the upper y-limit *before* the limits are overridden:
    title_y = size_legend_ax.get_ylim()[1]
    size_legend_ax.set_ylim(-1.05 - self.largest_dot * 0.003, 4)
    size_legend_ax.set_title(self.size_title, y=title_y + 0.45, size=6)

    left, right = size_legend_ax.get_xlim()
    size_legend_ax.set_xlim(left - 0.15, right + 0.5)
def _plot_colorbar(self, color_legend_ax: mpl.axes.Axes, normalize: Union[None, mpl.colors.Normalize] = None):
    """
    Given axis object, plots a horizontal colorbar.

    The ticks are spread evenly between the minimum and maximum (ignoring NaN) of the dot color values. The tick
    count is 'self.num_colorbar_ticks' (configurable via `legend(num_colorbar_ticks=...)`); it falls back to
    'self.default_num_colorbar_ticks' when that is unset or None.

    Args:
        color_legend_ax: Matplotlib axis object to plot onto
        normalize: The normalizing object that scales data, typically into the interval [0, 1], for the purposes of
            mapping to color intensities for plotting. If None, norm defaults to a colors.Normalize object and
            automatically scales based on min/max values in the data.
    """
    cmap = plt.get_cmap(self.cmap)
    # Honor the tick count configured through `legend()` rather than always using the class default:
    num_ticks = getattr(self, "num_colorbar_ticks", None)
    if num_ticks is None:
        num_ticks = self.default_num_colorbar_ticks
    color_values = self.dot_color_df.values
    ColorbarBase(
        color_legend_ax,
        orientation="horizontal",
        cmap=cmap,
        norm=normalize,
        ticks=np.linspace(np.nanmin(color_values), np.nanmax(color_values), num_ticks),
        format="%.2f",
    )
    color_legend_ax.set_title(self.color_legend_title, fontsize=7)
    color_legend_ax.xaxis.set_tick_params(labelsize=7)
def _plot_legend(
    self, legend_ax: mpl.axes.Axes, return_ax_dict: dict, normalize: Union[None, mpl.colors.Normalize] = None
):
    """
    Lay out and draw the dot size legend and the colorbar inside the legend area.

    The legend Axes is hidden and split into four stacked rows:
        1. filler whose height absorbs the remaining figure height, so the other rows keep a fixed size
        2. dot size legend
        3. spacer that keeps the titles/labels of the two legends from overlapping
        4. colorbar
    The fixed row heights are fractions of 'self.min_figure_height'.

    Args:
        legend_ax: Matplotlib Axes reserved for the legends
        return_ax_dict: Dictionary updated in place with the Axes created here, under the keys "size_legend_ax"
            and/or "color_legend_ax" (depending on 'self.show_size_legend' / 'self.show_colorbar')
        normalize: Normalization object passed on to the colorbar
    """
    reference_height = self.min_figure_height
    cbar_row = reference_height * 0.08
    size_row = reference_height * 0.27
    spacer_row = reference_height * 0.3
    filler_row = self.height - size_row - cbar_row - spacer_row

    fig, legend_gs = make_grid_spec(
        legend_ax, nrows=4, ncols=1, height_ratios=[filler_row, size_row, spacer_row, cbar_row]
    )

    if self.show_size_legend:
        size_ax = fig.add_subplot(legend_gs[1])
        self._plot_size_legend(size_ax)
        return_ax_dict["size_legend_ax"] = size_ax

    if self.show_colorbar:
        cbar_ax = fig.add_subplot(legend_gs[3])
        self._plot_colorbar(cbar_ax, normalize)
        return_ax_dict["color_legend_ax"] = cbar_ax
# Working with the main body of the plot:
def _mainplot(self, ax: mpl.axes.Axes):
    """
    Draw the dots of the main plot onto the given Axes.

    Works on copies of the stored dot size/color tables and applies, in turn, any stored column order
    ('var_names_idx_order'), row order ('categories_order') and axis swap. Drawing is then delegated to
    `_dotplot`. The dot size bounds returned by `_dotplot` are stored back as 'self.dot_min' / 'self.dot_max',
    which the size legend reads.

    Args:
        ax: Matplotlib Axes to draw onto

    Returns:
        normalize: The color normalization object returned by `_dotplot` (presumably passed on to the colorbar)
    """
    # Work on copies of the dataframes so repeated calls never modify the stored tables:
    _color_df = self.dot_color_df.copy()
    _size_df = self.dot_size_df.copy()
    if self.var_names_idx_order is not None:
        _color_df = _color_df.iloc[:, self.var_names_idx_order]
        _size_df = _size_df.iloc[:, self.var_names_idx_order]
    if self.categories_order is not None:
        _color_df = _color_df.loc[self.categories_order, :]
        _size_df = _size_df.loc[self.categories_order, :]
    if self.are_axes_swapped:
        _size_df = _size_df.T
        _color_df = _color_df.T

    # Extra scatter keyword arguments; tolerate the attribute never having been set:
    scatter_kwargs = getattr(self, "kwargs", None)
    if scatter_kwargs is None:
        scatter_kwargs = {}
        self.kwargs = scatter_kwargs
    # A 'cmap' given among the extra keyword arguments overrides the styled colormap. It is removed so it is not
    # passed to `_dotplot` a second time (`_dotplot` takes 'cmap' explicitly).
    self.cmap = scatter_kwargs.pop("cmap", self.cmap)

    # Put dotplot together!
    normalize, dot_min, dot_max = self._dotplot(
        _size_df,
        _color_df,
        ax,
        cmap=self.cmap,
        dot_max=self.dot_max,
        dot_min=self.dot_min,
        color_on=self.color_on,
        edge_color=self.dot_edge_color,
        edge_lw=self.dot_edge_lw,
        smallest_dot=self.smallest_dot,
        largest_dot=self.largest_dot,
        size_exponent=self.size_exponent,
        grid=getattr(self, "grid", False),
        x_padding=self.plot_x_padding,
        y_padding=self.plot_y_padding,
        vmin=self.vboundnorm.vmin,
        vmax=self.vboundnorm.vmax,
        vcenter=self.vboundnorm.vcenter,
        norm=self.vboundnorm.norm,
        **scatter_kwargs,
    )
    self.dot_min, self.dot_max = dot_min, dot_max
    return normalize
@staticmethod
def _dotplot(
    dot_size: pd.DataFrame,
    dot_color: pd.DataFrame,
    dot_ax: mpl.axes.Axes,
    cmap: str = "Reds",
    color_on: str = "dot",
    y_label: Union[None, str] = None,
    dot_max: Union[None, float] = None,
    dot_min: Union[None, float] = None,
    standard_scale: Union[None, Literal["var", "group"]] = None,
    smallest_dot: float = 0.0,
    largest_dot: float = 200,
    size_exponent: float = 2,
    edge_color: Union[None, str] = None,
    edge_lw: Union[None, float] = None,
    grid: bool = False,
    x_padding: float = 0.8,
    y_padding: float = 1.0,
    vmin: Union[None, float] = None,
    vmax: Union[None, float] = None,
    vcenter: Union[None, float] = None,
    norm: Union[None, Normalize] = None,
    **kwargs,
):
    """
    Generate a dotplot given the axis object and two dataframes containing the dot size and dot color. The
    indices of the dataframes label the y-axis and the columns label the x-axis.

    The dots are plotted using :func:`matplotlib.pyplot.scatter()`. Thus, additional arguments can be passed.

    Args:
        dot_size : pd.DataFrame
            Data frame whose values set the dot sizes (e.g. fraction of cells expressing a feature).
        dot_color : pd.DataFrame
            Data frame whose values set the dot colors. Should have the same shape, columns and indices as
            dot_size. Mismatches are logged as errors but do not stop plotting. The frame is not modified.
        dot_ax : matplotlib Axes object
            Axis to plot figure onto
        cmap : str, default 'Reds'
            String denoting matplotlib color map. A 'cmap' entry in kwargs takes precedence over this argument.
        color_on : str, default 'dot'
            Options: 'dot' or 'square'. By default the colormap is applied to the color of the dot. With 'square',
            the colormap is applied to a square behind the dot. The dot is then transparent and only its edge is
            shown.
        y_label : optional str
            Label for y-axis
        dot_max : optional float
            If None, set to the maximum value of dot_size rounded up to the next tenth (e.g. 0.57 -> 0.6). If
            given, must be between 0 and 1. All values larger than dot_max are clipped to this value.
        dot_min : optional float
            If None, set to 0. If given, must be between 0 and 1. All values smaller than dot_min are clipped to
            this value.
        standard_scale : None, 'var' or 'group'
            Whether to standardize the color values between 0 and 1. 'group' operates on each row of dot_color
            and 'var' on each column: subtract the minimum, then divide by the resulting maximum. Any other
            non-None value is ignored with a warning.
        smallest_dot : float, default 0.0
            Marker size (matplotlib `s`) used for values at `dot_min`.
        largest_dot : float, default 200
            Marker size (matplotlib `s`) used for values at `dot_max`.
        size_exponent : float, default 2
            Values are first clipped to [dot_min, dot_max] and rescaled to [0, 1]. Dot size is then computed as:
                value ** size_exponent
            and mapped linearly onto [smallest_dot, largest_dot]. Changing the exponent changes the relative sizes
            of the dots to each other.
        edge_color : optional str
            Dot edge color. If None, the behavior depends on 'color_on'. For 'square', the edge is chosen per dot:
            dark grey or white, depending on the luminance of the square behind it. For 'dot', no edge is drawn.
        edge_lw : optional float
            Dot edge line width. If None, 1.5 is used for color_on='square' and 0.0 for color_on='dot'.
        grid : bool, default False
            Set to true to show grid lines. By default grid lines are not shown. Further configuration of the grid
            lines can be achieved directly on the returned ax.
        x_padding : float, default 0.8
            Space between the plot left/right borders and the dot centers, in units of the distance between the x
            ticks. Only applied when 'color_on' = 'dot'.
        y_padding : float, default 1.0
            Space between the plot top/bottom borders and the dot centers, in units of the distance between the y
            ticks. Only applied when 'color_on' = 'dot'.
        vmin : optional float
            The data value that defines 0.0 in the normalization. If None, it is presumably taken from the data
            by the normalization object.
        vmax : optional float
            The data value that defines 1.0 in the normalization. If None, it is presumably taken from the data
            by the normalization object.
        vcenter : optional float
            The data value that defines 0.5 in the normalization
        norm : optional `matplotlib.colors.Normalize` object
            Optional already-initialized normalizing object that scales data, typically into the interval [0, 1],
            for the purposes of mapping to color intensities for plotting. Do not pass both 'norm' and
            'vmin'/'vmax', etc.
        **kwargs :
            Additional arguments passed to `matplotlib.pyplot.scatter`

    Returns:
        normalize : `matplotlib.colors.Normalize` object
            The normalizing object that scales data, typically into the interval [0, 1], for the purposes of
            mapping to color intensities for plotting.
        dot_min : float
            The lower clipping bound actually used for the dot size values (computed if not given).
        dot_max : float
            The upper clipping bound actually used for the dot size values (computed if not given).
    """
    logger = lm.get_main_logger()
    if dot_size.shape != dot_color.shape:
        logger.error("Dot size and dot color dataframes are not the same size.")
    if list(dot_size.index) != list(dot_color.index):
        logger.error("Dot size and dot color dataframes do not have the same features.")
    if list(dot_size.columns) != list(dot_color.columns):
        logger.error("Dot size and dot color dataframes do not have the same categories.")

    # Standardize the color values. Plain (non-augmented) assignments are used so that the frame passed in by the
    # caller is never modified in place.
    if standard_scale == "group":
        dot_color = dot_color.sub(dot_color.min(1), axis=0)
        dot_color = dot_color.div(dot_color.max(1), axis=0).fillna(0)
    elif standard_scale == "var":
        dot_color = dot_color - dot_color.min(0)
        dot_color = (dot_color / dot_color.max(0)).fillna(0)
    elif standard_scale is not None:
        logger.warning(
            f"Unknown value for `standard_scale`: {standard_scale!r}; ignoring it. Valid options are 'var', "
            f"'group' or None."
        )

    # Set the center of each first dot at 0.5 to more easily line up dotplot w/ possible dendrograms:
    y, x = np.indices(dot_color.shape)
    y = y.flatten() + 0.5
    x = x.flatten() + 0.5
    frac = dot_size.values.flatten()
    mean_flat = dot_color.values.flatten()

    # A colormap given through the scatter kwargs takes precedence; remove it so it is not passed twice.
    cmap = plt.get_cmap(kwargs.pop("cmap", cmap))

    if dot_max is None:
        # nanmax: the builtin max over an array containing NaN gives a position-dependent result
        dot_max = np.ceil(np.nanmax(frac) * 10) / 10
    else:
        if dot_max < 0 or dot_max > 1:
            raise ValueError("`dot_max` value has to be between 0 and 1")
    if dot_min is None:
        dot_min = 0
    else:
        if dot_min < 0 or dot_min > 1:
            raise ValueError("`dot_min` value has to be between 0 and 1")

    if dot_min != 0 or dot_max != 1:
        # clip frac between dot_min and dot_max
        frac = np.clip(frac, dot_min, dot_max)
        old_range = dot_max - dot_min
        if old_range == 0:
            # Degenerate range (e.g. every value is 0, so dot_max was rounded to 0). Avoid 0/0 -> NaN sizes and
            # draw every dot at the smallest size instead.
            frac = np.zeros(np.shape(frac))
        else:
            # re-scale frac between 0 and 1
            frac = (frac - dot_min) / old_range
    size = frac**size_exponent
    # rescale size to match smallest_dot and largest_dot
    size = size * (largest_dot - smallest_dot) + smallest_dot
    normalize = check_colornorm(vmin, vmax, vcenter, norm)

    if color_on == "square":
        if edge_color is None:
            from seaborn.utils import relative_luminance

            # Use either dark grey or white for the edge color depending on the luminance of the background
            # square color
            edge_color = []
            for color_value in cmap(normalize(mean_flat)):
                lum = relative_luminance(color_value)
                edge_color.append(".15" if lum > 0.408 else "w")
        edge_lw = 1.5 if edge_lw is None else edge_lw

        # Create heatmap with squares, then create circles and plot them over the top:
        dot_ax.pcolor(dot_color.values, cmap=cmap, norm=normalize)
        for axis in ["top", "bottom", "left", "right"]:
            dot_ax.spines[axis].set_linewidth(1.5)
        # A few created variables will be used as keyword args to ax.scatter...ensure that they aren't already
        # given as keyword args to this function:
        kwargs = deduplicate_kwargs(
            kwargs,
            s=size,
            cmap=cmap,
            linewidth=edge_lw,
            facecolor="none",
            edgecolor=edge_color,
            norm=normalize,
        )
        dot_ax.scatter(x, y, **kwargs)
    else:
        edge_color = "none" if edge_color is None else edge_color
        edge_lw = 0.0 if edge_lw is None else edge_lw

        color = cmap(normalize(mean_flat))
        # A few created variables will be used as keyword args to ax.scatter...ensure that they aren't already
        # given as keyword args to this function:
        kwargs = deduplicate_kwargs(
            kwargs,
            s=size,
            cmap=cmap,
            color=color,
            linewidth=edge_lw,
            edgecolor=edge_color,
            norm=normalize,
        )
        dot_ax.scatter(x, y, **kwargs)

    # One tick per row/column, centered on the dots
    y_ticks = np.arange(dot_color.shape[0]) + 0.5
    dot_ax.set_yticks(y_ticks)
    dot_ax.set_yticklabels(list(dot_color.index), minor=False)

    x_ticks = np.arange(dot_color.shape[1]) + 0.5
    dot_ax.set_xticks(x_ticks)
    dot_ax.set_xticklabels(
        list(dot_color.columns),
        rotation=90,
        ha="center",
        minor=False,
    )
    dot_ax.tick_params(axis="both", labelsize=6)
    dot_ax.grid(False)
    dot_ax.set_ylabel(y_label)

    # To be consistent with the heatmap plot, is better to invert the order of the y-axis, such that the first
    # group is on top
    dot_ax.set_ylim(dot_color.shape[0], 0)
    dot_ax.set_xlim(0, dot_color.shape[1])

    if color_on == "dot":
        # Add padding to the x and y lims when the color is not in the square
        # Default y range goes from 0.5 to num rows + 0.5 and default x range goes from 0.5 to num cols + 0.5
        x_padding = x_padding - 0.5
        y_padding = y_padding - 0.5
        dot_ax.set_ylim(dot_color.shape[0] + y_padding, -y_padding)
        dot_ax.set_xlim(-x_padding, dot_color.shape[1] + x_padding)

    if grid:
        dot_ax.grid(True, color="lightgray", linewidth=0.1)
        dot_ax.set_axisbelow(True)

    return normalize, dot_min, dot_max
def reorder_categories_after_dendrogram(self, dendrogram_key):
    """
    Reorders categorical observations along plot axis based on dendrogram results.

    If no dendrogram is stored under the resolved key, `utils.dendrogram` is run with default parameters. The
    results found in `.uns[dendrogram_key]` set the category order. When the variable groups correspond
    one-to-one with the categories, they are also used to reorder `var_group_labels`, `var_group_positions` and
    the variable names.

    Args:
        dendrogram_key: Key in `.uns` holding the dendrogram. If not a string (e.g. True), the key defaults to
            "dendrogram_{cat_key}", with multiple keys joined by "_".

    Raises:
        ValueError: If the stored entry lacks 'dendrogram_info', was computed for a different observation key, or
            has a different number of categories than the current plot.

    Sets:
        `self.categories_idx_ordered`, `self.categories_order`, `self.var_names_idx_order` and
        `self.var_names_ordered` (the last two are None when `self.var_names` is None).
    """

    def _format_first_three_categories(_categories):
        """Return at most the first three categories as a comma-separated string (used in warning messages)."""
        _categories = [str(category) for category in _categories]
        if len(_categories) > 3:
            _categories = _categories[:3] + ["etc."]
        return ", ".join(_categories)

    # Get dendrogram key:
    if not isinstance(dendrogram_key, str):
        if isinstance(self.cat_key, str):
            dendrogram_key = f"dendrogram_{self.cat_key}"
        elif isinstance(self.cat_key, (list, tuple)):
            dendrogram_key = f'dendrogram_{"_".join(self.cat_key)}'

    if dendrogram_key not in self.adata.uns:
        from .utils import dendrogram

        self.logger.warning(
            f"Dendrogram data not found (using key={dendrogram_key}). Running :func `st.pl.dendrogram` with "
            f"default parameters. For fine tuning it is recommended to run `st.pl.dendrogram` independently."
        )
        dendrogram(self.adata, self.cat_key, key_added=dendrogram_key)

    if "dendrogram_info" not in self.adata.uns[dendrogram_key]:
        raise ValueError(
            f"The given dendrogram key ({dendrogram_key!r}) does not contain valid dendrogram information."
        )

    dendro_info = self.adata.uns[dendrogram_key]
    if self.cat_key != dendro_info["cat_key"]:
        raise ValueError(
            "Incompatible observations. The precomputed dendrogram contains information for the "
            f"observation: '{dendro_info['cat_key']}' while the plot is made for the observation: "
            f"'{self.cat_key}'. Please run :func `st.pl.dendrogram` using the right observation."
        )

    # Category order:
    categories_idx_ordered = dendro_info["categories_idx_ordered"]
    categories_ordered = dendro_info["categories_ordered"]

    if len(self.categories) != len(categories_idx_ordered):
        raise ValueError(
            f"Incompatible observations. Dendrogram data has {len(categories_idx_ordered)} categories but current "
            f"groupby observation {self.cat_key} contains {len(self.categories)} categories. Most likely the "
            "underlying groupby observation changed after the initial computation of :func `st.pl.dendrogram`. "
            "Please run `st.pl.dendrogram` again."
        )

    # Reorder var_groups (if any). Initialized to None so the lookup below is always defined, including when there
    # are no variable names to reorder.
    var_names_idx_ordered = None
    if self.var_names is not None:
        var_names_idx_ordered = list(range(len(self.var_names)))

        if self.has_var_groups:
            if set(self.var_group_labels) == set(self.categories):
                positions_ordered = []
                labels_ordered = []
                position_start = 0
                var_names_idx_ordered = []
                for cat_name in categories_ordered:
                    idx = self.var_group_labels.index(cat_name)
                    position = self.var_group_positions[idx]
                    group_size = len(self.var_names[position[0] : position[1] + 1])
                    var_names_idx_ordered.extend(range(position[0], position[1] + 1))
                    positions_ordered.append((position_start, position_start + group_size - 1))
                    position_start += group_size
                    labels_ordered.append(self.var_group_labels[idx])
                self.var_group_labels = labels_ordered
                self.var_group_positions = positions_ordered
            else:
                self.logger.warning(
                    "Groups are not reordered because the `groupby` categories and the `var_group_labels` are "
                    f"different.\n"
                    f"categories: {_format_first_three_categories(self.categories)}\n"
                    "var_group_labels: "
                    f"{_format_first_three_categories(self.var_group_labels)}"
                )

    if var_names_idx_ordered is not None:
        var_names_ordered = [self.var_names[x] for x in var_names_idx_ordered]
    else:
        var_names_ordered = None

    self.categories_idx_ordered = categories_idx_ordered
    self.categories_order = categories_ordered
    self.var_names_idx_order = var_names_idx_ordered
    self.var_names_ordered = var_names_ordered
@staticmethod
def _plot_var_groups_brackets(
    gene_groups_ax: mpl.axes.Axes,
    group_positions: Iterable[Tuple[int, int]],
    group_labels: Sequence[str],
    left_adjustment: float = -0.3,
    right_adjustment: float = 0.3,
    rotation: Optional[float] = None,
    orientation: Literal["top", "right"] = "top",
):
    """
    Draws brackets that represent groups of features on the given axis.

    The 'gene_groups_ax' Axes object should share the x-axis/y-axis with the main plot axis. Which axis is shared
    depends on the axis along which the features are plotted. For example, in instantiation:
        gene_groups_ax = fig.add_subplot(axs[0,0], sharex=dot_ax)

    Args:
        gene_groups_ax : `matplotlib.axes.Axes` object
            Axis to plot on, should correspond to the axis of the main plot on which the feature names/feature
            ticks are drawn
        group_positions : iterable of tuples of form (int, int)
            Each item should contain the start and end position (inclusive) that the bracket should cover.
            Eg. [(0, 4), (5, 8)] means that there are two brackets, one for the var_names (eg genes) in
            positions 0-4 and the second for positions 5-8. Any iterable (including a generator) is accepted.
        group_labels : sequence of str
            Labels for the feature groups, one per entry of group_positions. The sequence is not modified.
        left_adjustment : float, default -0.3
            Adjustment to plot the bracket start slightly before or after the first feature position.
            If the value is negative the start is moved before.
        right_adjustment : float, default 0.3
            Adjustment to plot the bracket end slightly before or after the last feature position.
            If the value is negative the end is moved before, if positive the end is moved after.
        rotation : optional float
            In degrees, angle of rotation for the labels (only used for orientation 'top'). If not given, labels
            are not rotated when all of them have at most 4 characters; otherwise they are rotated 90 degrees.
        orientation : str
            Options: 'top' or 'right' to set the location of the brackets. With 'right', labels are always
            rotated 270 degrees and are truncated (ending in ".") when longer than the space the bracket spans.
    """
    # Materialize the inputs. Positions are traversed twice below, so a one-shot iterator would otherwise leave
    # `right` empty. Labels become a list so their truthiness is well-defined (e.g. for numpy arrays).
    group_positions = list(group_positions)
    group_labels = list(group_labels)

    # Get the 'brackets' coordinates as lists of start and end positions
    left = [position[0] + left_adjustment for position in group_positions]
    right = [position[1] + right_adjustment for position in group_positions]

    # verts and codes are used by PathPatch to make the brackets: each bracket is one open 4-vertex polyline
    verts = []
    codes = []
    bracket_codes = [Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO]

    # If brackets are specified to be placed along the top of the figure:
    if orientation == "top":
        # If custom rotation is not specified, rotate labels if any of them is longer than 4 characters
        if rotation is None and group_labels:
            rotation = 90 if max(len(label) for label in group_labels) > 4 else 0
        for idx, (left_coor, right_coor) in enumerate(zip(left, right)):
            # lower-left, upper-left, upper-right, lower-right
            verts.extend([(left_coor, 0), (left_coor, 0.6), (right_coor, 0.6), (right_coor, 0)])
            codes.extend(bracket_codes)

            group_x_center = left_coor + float(right_coor - left_coor) / 2
            gene_groups_ax.text(
                group_x_center,
                1.1,
                group_labels[idx],
                ha="center",
                va="bottom",
                rotation=rotation,
            )
    # Else, the brackets will be placed to the right of the figure:
    else:
        for idx, (top_coor, bottom_coor) in enumerate(zip(left, right)):
            # upper-left, upper-right, lower-right, lower-left
            verts.extend([(0, top_coor), (0.4, top_coor), (0.4, bottom_coor), (0, bottom_coor)])
            codes.extend(bracket_codes)

            diff = bottom_coor - top_coor
            group_y_center = top_coor + float(diff) / 2
            label = group_labels[idx]
            if diff * 2 < len(label):
                # cut label to fit available space (on a local copy; the caller's labels stay untouched)
                label = label[: int(diff * 2)] + "."
            gene_groups_ax.text(
                1.1,
                group_y_center,
                label,
                ha="right",
                va="center",
                rotation=270,
                fontsize=8,
            )

    # matplotlib's Path rejects an empty vertex list, so only draw when there is at least one bracket
    if verts:
        path = Path(verts, codes)
        patch = patches.PathPatch(path, facecolor="none", lw=1.5)
        gene_groups_ax.add_patch(patch)

    gene_groups_ax.grid(False)
    gene_groups_ax.axis("off")
    # Remove all ticks from the bracket plot Axes object:
    gene_groups_ax.tick_params(axis="y", left=False, labelleft=False)
    gene_groups_ax.tick_params(axis="x", bottom=False, labelbottom=False, labeltop=False)
def _update_var_groups(self):
    """
    Expand a dictionary of variable names into flat variable names plus group annotations.

    If `self.var_names` is a mapping of {group label: variable name(s)}, it is flattened into a list. In that case
    `self.var_group_labels` and `self.var_group_positions` (inclusive (start, end) index pairs) are rebuilt from
    it, and `self.has_var_groups` is set to True. Previously set groups are overwritten, with a warning.
    If `self.var_names` is a single string, it is wrapped in a list. Any other value is left unchanged.

    Updates var_names, var_group_labels, var_group_positions and has_var_groups.
    """
    if isinstance(self.var_names, cabc.Mapping):
        if self.has_var_groups:
            self.logger.warning(
                "Given `var_names` is a dictionary. This will reset the current values of `var_group_labels` "
                "and `var_group_positions`."
            )
        var_group_labels = []
        flat_var_names = []
        var_group_positions = []
        start = 0
        for label, vars_list in self.var_names.items():
            # A lone string is one variable name, not an iterable of characters. Anything else (list, numpy array,
            # pandas Series, generator, ...) is materialized once so that its length is available.
            vars_list = [vars_list] if isinstance(vars_list, str) else list(vars_list)
            flat_var_names.extend(vars_list)
            var_group_labels.append(label)
            var_group_positions.append((start, start + len(vars_list) - 1))
            start += len(vars_list)
        self.var_names = flat_var_names
        self.var_group_labels = var_group_labels
        self.var_group_positions = var_group_positions
        self.has_var_groups = True

    elif isinstance(self.var_names, str):
        self.var_names = [self.var_names]
class CCDotplot(Dotplot):
    """
    Dotplot variant for cell-cell communication results.

    Because of the often much smaller values dealt with in cell-cell communication inference, this class creates a
    modified dot size legend.

    Args:
        minn : optional float
            For the dot size legend, sets the value corresponding to the smallest dot on the legend
        delta : optional float
            Distance between the largest value to consider and the smallest value to consider (see 'minn'
            parameter above)
        alpha : optional float
            Significance threshold. If given, all elements w/ p-values <= 'alpha' will be marked by rings instead of
            dots, and an extra legend panel explaining the ring/dot distinction is drawn.
        *args :
            Positional arguments to initialize :class `Dotplot`. Because 'minn', 'delta' and 'alpha' are the first
            three positional parameters, positional arguments for :class `Dotplot` (e.g. adata, var_names, cat_key)
            must come after them.
        **kwargs :
            Keyword arguments to initialize :class `Dotplot`
    """

    default_largest_dot = 50.0

    def __init__(self, minn: float, delta: float, alpha: float, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # All three are read by `_plot_size_legend`, so they must be stored on the instance.
        self.minn = minn
        self.delta = delta
        self.alpha = alpha
        self.largest_dot = self.default_largest_dot

    def _plot_size_legend(self, size_legend_ax: mpl.axes.Axes):
        """
        Given axis object, generates dot size legend and displays on plot.

        Overwrites the default :func `plot_size_legend` for :class `Dotplot`. If `self.alpha` is set, an
        additional legend panel is added to `self.fig` above the size legend, showing filled vs. ring markers for
        non-significant/significant entries.
        """
        # NOTE(review): `self.base` is not set in this class; presumably provided by :class `Dotplot` - confirm.
        # Ending point:
        end_point = self.base ** -((self.dot_max * self.delta) + self.minn)
        # Starting point:
        start_point = self.base ** -((self.dot_min * self.delta) + self.minn)
        size_range = -(np.logspace(start_point, end_point, self.num_size_legend_dots + 1, base=10).astype(np.float64))
        size_range = (size_range - np.min(size_range)) / (np.max(size_range) - np.min(size_range))
        # no point in showing dot of size 0
        size_range = size_range[1:]

        # Same scaling as the plotted dots in `_dotplot`: exponentiate, then map [0, 1] onto
        # [smallest_dot, largest_dot].
        size = size_range**self.size_exponent
        size = size * (self.largest_dot - self.smallest_dot) + self.smallest_dot

        # Plot size legend
        ymin = -self.largest_dot * 0.003
        ymax = 0.65
        size_legend_ax.scatter(
            np.arange(len(size)) + 0.5,
            np.repeat(0, len(size)),
            s=size,
            color="gray",
            edgecolor="black",
            linewidth=self.dot_edge_lw,
            zorder=100,
        )
        size_legend_ax.set_xticks(np.arange(len(size)) + 0.5)
        labels = [f"{(value * self.delta) + self.minn:.1f}" for value in size_range]
        size_legend_ax.set_xticklabels(labels, fontsize=9)

        # Remove y ticks and labels
        size_legend_ax.tick_params(axis="y", left=False, labelleft=False, labelright=False)
        # Remove surrounding lines
        for side in ("right", "top", "left", "bottom"):
            size_legend_ax.spines[side].set_visible(False)
        size_legend_ax.grid(False)

        size_legend_ax.set_ylim(ymin, ymax)
        size_legend_ax.set_title(self.size_title, y=ymax + 0.05, size=9)

        xmin, xmax = size_legend_ax.get_xlim()
        size_legend_ax.set_xlim(xmin - 0.15, xmax + 0.5)

        # If significance check is involved, a separate legend panel will be used to indicate significance w/
        # closed/open circles:
        if self.alpha is not None:
            # Attribute will be created/set to not-None upon calling 'make_figure()' during the process of creating
            # the outer plotting class
            ax = self.fig.add_subplot()
            ax.scatter(
                [0.35, 0.65],
                [0, 0],
                s=size[-1],
                color="black",
                edgecolor="black",
                linewidth=self.dot_edge_lw,
                zorder=100,
            )
            # White inner dot turns the second marker into a ring
            ax.scatter(
                [0.65],
                [0],
                s=0.33 * self.largest_dot,
                color="white",
                edgecolor="black",
                linewidth=self.dot_edge_lw,
                zorder=100,
            )
            ax.set_xlim([0, 1])
            ax.set_xticks([0.35, 0.65])
            ax.set_xticklabels(["false", "true"])
            ax.set_yticks([])
            ax.set_title(f"significant\n$p={self.alpha}$", y=ymax + 0.05, size=9)
            ax.set(frame_on=False)

            l, b, w, h = size_legend_ax.get_position().bounds
            ax.set_position([l, b + h + 0.2, w, h])
# --------------------------------------- Dotplot wrapper --------------------------------------- #
@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE)
def dotplot(
    adata: AnnData,
    var_names: Sequence[str],
    cat_key: Union[str, Sequence[str]],
    num_categories: int = 7,
    cell_cell_dp: bool = False,
    delta: Union[None, float] = None,
    minn: Union[None, float] = None,
    alpha: Union[None, float] = None,
    prescale_adata: bool = False,
    expression_cutoff: float = 0.0,
    mean_only_expressed: bool = False,
    cmap: str = "Reds",
    dot_max: float = Dotplot.default_dot_max,
    dot_min: float = Dotplot.default_dot_min,
    standard_scale: Union[None, Literal["var", "group"]] = None,
    smallest_dot: float = Dotplot.default_smallest_dot,
    largest_dot: float = Dotplot.default_largest_dot,
    title: Union[None, str] = None,
    colorbar_title: str = Dotplot.default_color_legend_title,
    size_title: str = Dotplot.default_size_legend_title,
    figsize: Union[None, Tuple[float, float]] = None,
    dendrogram: Union[bool, str] = False,
    gene_symbols_key: Union[None, str] = None,
    var_group_positions: Union[None, Sequence[Tuple[int, int]]] = None,
    var_group_labels: Union[None, Sequence[str]] = None,
    var_group_rotation: Union[None, float] = None,
    layer: Union[None, str] = None,
    swap_axes: bool = False,
    dot_color_df: Union[None, pd.DataFrame] = None,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "save",
    save_kwargs: dict = {},
    ax: Union[None, mpl.axes.Axes] = None,
    vmin: Union[None, float] = None,
    vmax: Union[None, float] = None,
    vcenter: Union[None, float] = None,
    norm: Union[None, Normalize] = None,
    **kwargs,
):
    """
    Makes a dot plot of the expression values of `var_names`. For each var_name and each `groupby` category a dot
    is plotted.

    Each dot represents two values: mean expression within each category (visualized by color) and fraction of cells
    expressing the `var_name` in the category (visualized by the size of the dot). If `groupby` is not given,
    the dotplot assumes that all data belongs to a single category.

    Args:
        adata: object of class `anndata.AnnData`
        var_names: Should be a subset of adata.var_names
        cat_key: The key(s) in .obs of the grouping to consider. Should be a categorical observation; if not,
            will be subdivided into 'num_categories'.
        num_categories: Only used if groupby observation is not categorical. This value determines the number of
            groups into which the groupby observation should be subdivided.
        cell_cell_dp: Set True to initialize specialized cell-cell dotplot instead of gene expression dotplot
        delta: Only used if 'cell_cell_dp' is True- distance between the largest value to consider and the smallest
            value to consider (see 'minn' parameter below)
        minn: Only used if 'cell_cell_dp' is True- for the dot size legend, sets the value corresponding to the
            smallest dot on the legend
        alpha: Only used if 'cell_cell_dp' is True- significance threshold. If given, all elements w/ p-values <=
            'alpha' will be marked by rings instead of dots.
        prescale_adata: Set True to indicate that AnnData object should be scaled- if so, will use 'delta' and
            'minn' to do so. If False, will proceed as though adata has already been processed as needed.
        expression_cutoff: Used for binarizing feature expression- feature is considered to be expressed only if
            the expression value is greater than this threshold
        mean_only_expressed: If True, gene expression is averaged only over the cells expressing the given features
        cmap: Name of Matplotlib color map to use
        dot_max: If none, the maximum dot size is set to the maximum fraction value found (e.g. 0.6). If given,
            the value should be a number between 0 and 1. All fractions larger than dot_max are clipped to this value.
        dot_min: If none, the minimum dot size is set to 0. If given, the value should be a number between 0 and 1.
            All fractions smaller than dot_min are clipped to this value.
        standard_scale: Whether or not to standardize that dimension between 0 and 1, meaning for each variable or
            group, subtract the minimum and divide each by its maximum. 'var' or 'group' is used to specify whether
            this should be done over variables or groups.
        smallest_dot: Marker size of the smallest dot. All expression fractions with `dot_min` are plotted with
            this size.
        largest_dot: Marker size of the largest dot. All expression fractions with `dot_max` are plotted with
            this size.
        title: Title for the entire plot
        colorbar_title: Title for the color legend. If None will use generic default title
        size_title: Title for the dot size legend. If None will use generic default title
        figsize: Sets width and height of figure window
        dendrogram: If True, adds dendrogram to plot. Will do the same thing if string is given here,
            but will recompute dendrogram and save using this argument to set key in .uns.
        gene_symbols_key: Key in .var containing gene symbols
        var_group_positions: Each item in the list should contain the start and end position that the bracket
            should cover. Eg. [(0, 4), (5, 8)] means that there are two brackets, one for the var_names in positions
            0-4 and other for positions 5-8
        var_group_labels: List of group labels for the variable names (e.g. can group var_names in positions 0-4 as
            being "group A")
        var_group_rotation: Rotation in degrees of the variable group labels. If not given, labels of at most 4
            characters are not rotated, but otherwise labels are rotated 90 degrees.
        layer: Key in .layers specifying layer to use. If not given, will use .X.
        swap_axes: Set True to switch what is plotted on the x- and y-axes
        dot_color_df: Pre-prepared dataframe with features as indices, categories as columns, and indices
            corresponding to color intensities
        save_show_or_return: Options: "save", "show", "return", "both", "all"
            - "both" for save and show
        save_kwargs: A dictionary that will passed to the save_fig function. By default it is an empty dictionary
            and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf',
            "transparent": True, "close": True, "verbose": True} as its parameters. But to change any of these
            parameters, this dictionary can be used to do so. The dictionary passed in is not modified.
        ax: Pre-initialized axis object to plot on
        vmin: The data value that defines 0.0 in the normalization. Defaults to the min value of the dataset.
        vmax: The data value that defines 1.0 in the normalization. Defaults to the the max value of the dataset.
        vcenter: The data value that defines 0.5 in the normalization
        norm: Optional already-initialized normalizing object that scales data, typically into the interval [0, 1],
            for the purposes of mapping to color intensities for plotting. Do not pass both 'norm' and
            'vmin'/'vmax', etc.
        kwargs: Additional keyword arguments forwarded to the plot class constructor (and from there, presumably, to
            :func:`matplotlib.pyplot.scatter`). A 'linewidth' entry also sets the dot edge line width.

    Returns:
        fig: Instantiated Figure object- only when `save_show_or_return` requests a return (e.g. "return")
        axes: Dictionary of instantiated Axes objects- only when `save_show_or_return` requests a return
    """
    # Arguments shared by both plot classes
    plot_kwargs = dict(
        prescale_adata=prescale_adata,
        num_categories=num_categories,
        expression_cutoff=expression_cutoff,
        mean_only_expressed=mean_only_expressed,
        standard_scale=standard_scale,
        title=title,
        figsize=figsize,
        gene_symbols_key=gene_symbols_key,
        var_group_positions=var_group_positions,
        var_group_labels=var_group_labels,
        var_group_rotation=var_group_rotation,
        layer=layer,
        dot_color_df=dot_color_df,
        ax=ax,
        vmin=vmin,
        vmax=vmax,
        vcenter=vcenter,
        norm=norm,
    )

    if cell_cell_dp:
        # CCDotplot takes 'minn', 'delta' and 'alpha' as its first three positional parameters, so they must be
        # given first. Passing adata/var_names/cat_key positionally and minn/delta/alpha by keyword would bind
        # the data to minn/delta/alpha and raise "got multiple values for argument".
        dp = CCDotplot(minn, delta, alpha, adata, var_names, cat_key, **plot_kwargs, **kwargs)
    else:
        dp = Dotplot(adata, var_names, cat_key, **plot_kwargs, **kwargs)

    if dendrogram or isinstance(dendrogram, str):
        dp.add_dendrogram(dendrogram_key=dendrogram)
    if swap_axes:
        dp.swap_axes()

    dp = dp.style(
        cmap=cmap,
        dot_max=dot_max,
        dot_min=dot_min,
        smallest_dot=smallest_dot,
        largest_dot=largest_dot,
        # 'linewidth' was already forwarded to the constructor above; it also sets the dot edge line width here.
        dot_edge_lw=kwargs.pop("linewidth", Dotplot.default_dot_edgelw),
    ).legend(colorbar_title=colorbar_title, size_title=size_title)
    dp.make_figure()

    # Save, show or return figures:
    return save_return_show_fig_utils(
        save_show_or_return=save_show_or_return,
        # Doesn't matter what show_legend is for this plotting function
        show_legend=False,
        background="white",
        prefix="dotplot",
        # Pass a copy so the shared default dictionary can never be mutated across calls
        save_kwargs={} if save_kwargs is None else dict(save_kwargs),
        total_panels=1,
        fig=dp.fig,
        axes=dp.ax_dict,
        # Return all parameters are for returning multiple values for 'axes', but this function uses a single
        # dictionary
        return_all=False,
        return_all_list=None,
    )