# Source code for spateo.plotting.static.interactions

"""
Plots to visualize results from cell-cell colocalization-based analyses, as well as cell-cell communication
inference-based analyses. Makes use of dotplot-generating functions.
"""
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

from matplotlib.collections import PolyCollection
from matplotlib.ticker import StrMethodFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

from inspect import signature

import matplotlib as mpl
import matplotlib.patheffects as PathEffects
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
from anndata import AnnData
from matplotlib import rcParams
from scipy.cluster import hierarchy as sch

from ...configuration import SKM, config_spateo_rcParams, set_pub_style
from ...logging import logger_manager as lm
from ...plotting.static.dotplot import CCDotplot
from ...tools.find_neighbors import neighbors
from ...tools.labels import Label, interlabel_connections
from .utils import _dendrogram_sig, save_return_show_fig_utils


@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata")
def ligrec(
    adata: AnnData,
    dict_key: str,
    source_groups: Union[None, str, List[str]] = None,
    target_groups: Union[None, str, List[str]] = None,
    means_range: Tuple[float, float] = (-np.inf, np.inf),
    pvalue_threshold: float = 1.0,
    remove_empty_interactions: bool = True,
    remove_nonsig_interactions: bool = False,
    dendrogram: Union[None, str] = None,
    alpha: float = 0.001,
    swap_axes: bool = False,
    title: Union[None, str] = None,
    figsize: Union[None, Tuple[float, float]] = None,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = None,
    **kwargs,
):
    """
    Dotplot for visualizing results of ligand-receptor interaction analysis

    For each L:R pair on the dotplot, molecule 1 is sent from the cluster(s) labeled on the top of the plot (or on
    the right, if 'swap_axes' is True), whereas molecule 2 is the receptor on the cluster(s) labeled on the bottom.

    Args:
        adata: Object of :class `anndata.AnnData`
        dict_key: Key in .uns to dictionary containing cell-cell communication information. Should contain keys
            labeled "means" and "pvalues", with values being dataframes for the mean cell type-cell type L:R
            product and significance values.
        source_groups: Source interaction clusters. If `None`, select all clusters.
        target_groups: Target interaction clusters. If `None`, select all clusters.
        means_range: Only show interactions whose means are within this **closed** interval
        pvalue_threshold: Only show interactions with p-value <= `pvalue_threshold`
        remove_empty_interactions: Remove rows and columns that contain NaN values
        remove_nonsig_interactions: Remove rows and columns that only contain interactions that are larger than
            `alpha`
        dendrogram: How to cluster based on the p-values. Valid options are:

            - None (no input) - do not perform clustering.
            - `'interacting_molecules'` - cluster the interacting molecules.
            - `'interacting_clusters'` - cluster the interacting clusters.
            - `'both'` - cluster both rows and columns. Note that in this case, the dendrogram is not shown.

        alpha: Significance threshold. All elements with p-values <= `alpha` will be marked by tori instead of dots.
        swap_axes: Whether to show the cluster combinations as rows and the interacting pairs as columns
        title: Title of the plot
        figsize: The width and height of a figure
        save_show_or_return: Options: "save", "show", "return", "both", "all"
            - "both" for save and show
        save_kwargs: A dictionary that will passed to the save_fig function. `None` (the default) is treated as an
            empty dictionary, in which case the save_fig function will use the {"path": None, "prefix": 'scatter',
            "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters.
            But to change any of these parameters, this dictionary can be used to do so.
        kwargs : Keyword arguments for :func `style` or :func `legend` of :class `Dotplot`

    Returns:
        Whatever `save_return_show_fig_utils` returns for the given `save_show_or_return` option (called with the
        figure and the dotplot's axes dictionary).

    Raises:
        ValueError: If no clusters are selected, or if filtering removes every row or column.
    """
    logger = lm.get_main_logger()

    config_spateo_rcParams()
    set_pub_style()

    if figsize is None:
        figsize = rcParams.get("figure.figsize")
    if title is None:
        title = "Ligand-Receptor Inference"
    if save_kwargs is None:
        save_kwargs = {}

    # Named so as not to shadow the builtin `dict`:
    ccc_results = adata.uns[dict_key]

    def filter_values(
        pvals: pd.DataFrame, means: pd.DataFrame, *, mask: pd.DataFrame, kind: str
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Keep only the rows, then only the columns, for which `mask` has at least one True entry."""
        mask_rows = mask.any(axis=1)
        pvals = pvals.loc[mask_rows]
        means = means.loc[mask_rows]
        if pvals.empty:
            raise ValueError(f"After removing rows with only {kind} interactions, none remain.")

        mask_cols = mask.any(axis=0)
        pvals = pvals.loc[:, mask_cols]
        means = means.loc[:, mask_cols]
        if pvals.empty:
            raise ValueError(f"After removing columns with only {kind} interactions, none remain.")

        return pvals, means

    def get_dendrogram(adata: AnnData, linkage: str = "complete") -> Mapping[str, Any]:
        """Hierarchically cluster the observations of `adata` and package the result for the dotplot."""
        z_var = sch.linkage(
            adata.X,
            metric="correlation",
            method=linkage,
            # Unlikely to ever be profiling this many LR pairings, but cap at 1500
            optimal_ordering=adata.n_obs <= 1500,
        )
        dendro_info = sch.dendrogram(z_var, labels=adata.obs_names.values, no_plot=True)
        # this is what the DotPlot requires
        return {
            "linkage": z_var,
            "cat_key": ["groups"],
            "cor_method": "pearson",
            "use_rep": None,
            "linkage_method": linkage,
            "categories_ordered": dendro_info["ivl"],
            "categories_idx_ordered": dendro_info["leaves"],
            "dendrogram_info": dendro_info,
        }

    if len(means_range) != 2:
        logger.error(f"Expected `means_range` to be a sequence of size `2`, found `{len(means_range)}`.")
    means_range = tuple(sorted(means_range))

    if alpha is not None and not (0 <= alpha <= 1):
        logger.error(f"Expected `alpha` to be in range `[0, 1]`, found `{alpha}`.")

    if source_groups is None:
        source_groups = ccc_results["pvalues"].columns.get_level_values(0)
    elif isinstance(source_groups, str):
        source_groups = (source_groups,)

    if target_groups is None:
        target_groups = ccc_results["pvalues"].columns.get_level_values(1)
    if isinstance(target_groups, str):
        target_groups = (target_groups,)

    # Get specified source and target groups from the dictionary:
    pvals: pd.DataFrame = ccc_results["pvalues"].loc[:, (source_groups, target_groups)]
    means: pd.DataFrame = ccc_results["means"].loc[:, (source_groups, target_groups)]

    if pvals.empty:
        raise ValueError("No valid clusters have been selected.")

    # Entries outside the requested ranges become NaN:
    means = means[(means >= means_range[0]) & (means <= means_range[1])]
    pvals = pvals[pvals <= pvalue_threshold]

    if remove_empty_interactions:
        pvals, means = filter_values(pvals, means, mask=~(pd.isnull(means) | pd.isnull(pvals)), kind="NaN")
    if remove_nonsig_interactions and alpha is not None:
        pvals, means = filter_values(pvals, means, mask=pvals <= alpha, kind="non-significant")

    start, label_ranges = 0, {}

    if dendrogram == "interacting_clusters":
        # Set rows to be cluster combinations, not LR pairs:
        pvals = pvals.T
        means = means.T

    # Number of columns per first-level column label. Grouping the transposed frame row-wise is equivalent to the
    # column-wise `groupby(level=0, axis=1)`, whose `axis` argument is deprecated in recent pandas releases.
    for cls, size in pvals.T.groupby(level=0).size().to_dict().items():
        label_ranges[cls] = (start, start + size - 1)
        start += size
    label_ranges = {k: label_ranges[k] for k in sorted(label_ranges)}

    pvals = pvals[list(label_ranges)].astype("float")
    # Add minimum value to p-values to avoid value error- 3.0 will be the largest possible value:
    pvals = -np.log10(pvals + min(1e-3, alpha if alpha is not None else 1e-3)).fillna(0)

    pvals.columns = map(" | ".join, pvals.columns.to_flat_index())
    pvals.index = map(" | ".join, pvals.index.to_flat_index())

    means = means[list(label_ranges)].fillna(0)
    means.columns = map(" | ".join, means.columns.to_flat_index())
    means.index = map(" | ".join, means.index.to_flat_index())
    means = np.log2(means + 1)

    var = pd.DataFrame(pvals.columns)
    var = var.set_index(var.columns[0])

    # Instantiate new AnnData object containing plot values:
    adata = AnnData(pvals.values, obs={"groups": pd.Categorical(pvals.index)}, var=var, dtype=pvals.values.dtype)
    adata.obs_names = pvals.index
    # Min-max scale the plotted values:
    minn = np.nanmin(adata.X)
    delta = np.nanmax(adata.X) - minn
    adata.X = (adata.X - minn) / delta
    # To satisfy conditional check that happens on instantiating dotplot:
    adata.uns["__type"] = "UMI"

    # Clustering is best-effort: on any failure, fall back to an unclustered plot.
    try:
        if dendrogram == "both":
            row_order, col_order, _, _ = _dendrogram_sig(
                adata.X, method="complete", metric="correlation", optimal_ordering=adata.n_obs <= 1500
            )
            adata = adata[row_order, :][:, col_order]
            pvals = pvals.iloc[row_order, :].iloc[:, col_order]
            means = means.iloc[row_order, :].iloc[:, col_order]
        elif dendrogram is not None:
            adata.uns["dendrogram"] = get_dendrogram(adata)
    except Exception as e:
        logger.warning(f"Unable to create a dendrogram. Reason: `{e}`. Will display without one.")
        dendrogram = None

    kwargs["dot_edge_lw"] = 0
    kwargs.setdefault("cmap", "magma")
    kwargs.setdefault("grid", True)
    kwargs.pop("color_on", None)

    # Set style and legend kwargs:
    dotplot_style_params = set(signature(CCDotplot.style).parameters)
    dotplot_style_kwargs = {k: v for k, v in kwargs.items() if k in dotplot_style_params}
    dotplot_legend_params = set(signature(CCDotplot.legend).parameters)
    dotplot_legend_kwargs = {k: v for k, v in kwargs.items() if k in dotplot_legend_params}

    dp = (
        CCDotplot(
            delta=delta,
            minn=minn,
            alpha=alpha,
            adata=adata,
            var_names=adata.var_names,
            cat_key="groups",
            dot_color_df=means,
            dot_size_df=pvals,
            title=title,
            var_group_labels=None if dendrogram == "both" else list(label_ranges.keys()),
            var_group_positions=None if dendrogram == "both" else list(label_ranges.values()),
            standard_scale=None,
            figsize=figsize,
        )
        .style(**dotplot_style_kwargs)
        .legend(
            size_title=r"$-\log_{10} ~ P$",
            colorbar_title=r"$log_2(molecule_1 * molecule_2 + 1)$",
            **dotplot_legend_kwargs,
        )
    )
    if dendrogram in ["interacting_molecules", "interacting_clusters"]:
        dp.add_dendrogram(size=1.6, dendrogram_key="dendrogram")
    if swap_axes:
        dp.swap_axes()
    dp.make_figure()

    main_ax = dp.ax_dict["mainplot_ax"]

    if dendrogram != "both":
        # Tick labels have the form "a | b"; keep only the second component
        labs = main_ax.get_yticklabels() if swap_axes else main_ax.get_xticklabels()
        for text in labs:
            text.set_text(text.get_text().split(" | ")[1])
        if swap_axes:
            main_ax.set_yticklabels(labs)
        else:
            main_ax.set_xticklabels(labs)

    if alpha is not None:
        yy, xx = np.where((pvals.values + alpha) >= -np.log10(alpha))
        if len(xx) and len(yy):
            # for dendrogram='both', they are already re-ordered
            mapper = (
                np.argsort(adata.uns["dendrogram"]["categories_idx_ordered"])
                if "dendrogram" in adata.uns
                else np.arange(len(pvals))
            )
            logger.info(f"Found `{len(yy)}` significant interactions at level `{alpha}`")
            ss = 0.33 * (adata.X[yy, xx] * (dp.largest_dot - dp.smallest_dot) + dp.smallest_dot)

            yy = mapper[yy]
            if swap_axes:
                xx, yy = yy, xx
            # White inner dots turn the significant markers into tori. `dot_edge_color` is optional for callers,
            # so it is looked up with `.get`; only one of the `linewidth`/`lw` aliases is passed.
            main_ax.scatter(
                xx + 0.5,
                yy + 0.5,
                color="white",
                edgecolor=kwargs.get("dot_edge_color"),
                linewidth=kwargs["dot_edge_lw"],
                s=ss,
            )

    # Save, show or return figures:
    return save_return_show_fig_utils(
        save_show_or_return=save_show_or_return,
        # Doesn't matter what show_legend is for this plotting function
        show_legend=False,
        background="white",
        prefix="dotplot",
        save_kwargs=save_kwargs,
        total_panels=1,
        fig=dp.fig,
        axes=dp.ax_dict,
        # Return all parameters are for returning multiple values for 'axes', but this function uses a single
        # dictionary
        return_all=False,
        return_all_list=None,
    )
@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata")
def plot_connections(
    adata: AnnData,
    cat_key: str,
    spatial_key: str = "spatial",
    n_spatial_neighbors: Union[None, int] = 6,
    spatial_weights_matrix: Union[None, scipy.sparse.csr_matrix, np.ndarray] = None,
    expr_weights_matrix: Union[None, scipy.sparse.csr_matrix, np.ndarray] = None,
    reverse_expr_plot_orientation: bool = True,
    ax: Union[None, mpl.axes.Axes] = None,
    figsize: tuple = (3, 3),
    zero_self_connections: bool = True,
    normalize_by_self_connections: bool = False,
    shapes_style: bool = True,
    label_outline: bool = False,
    max_scale: float = 0.46,
    colormap: Union[str, dict, "mpl.colormap"] = "Spectral",
    title_str: Union[None, str] = None,
    title_fontsize: Union[None, float] = None,
    label_fontsize: Union[None, float] = None,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = None,
):
    """Plot spatial_connections between labels- visualization of how closely labels are colocalized

    Args:
        adata: AnnData object
        cat_key: Key in .obs containing categorical grouping labels. Colocalization will be assessed
            for pairwise combinations of these labels. Labels with fewer than 30 cells are dropped.
        spatial_key: Key in .obsm containing coordinates in the physical space. Not used unless
            'spatial_weights_matrix' is None, in which case this is required. Defaults to "spatial".
        n_spatial_neighbors: Optional, number of neighbors in the physical space for each cell. Not used unless
            'spatial_weights_matrix' is None.
        spatial_weights_matrix: Spatial distance matrix, weighted by distance between spots. If not given,
            will compute at runtime. If its first dimension equals the number of cells in the input `adata`, it is
            subset to the cells that remain after the small labels are dropped.
        expr_weights_matrix: Gene expression distance matrix, weighted by distance in transcriptomic or PCA space.
            If not given, only the spatial distance matrix will be plotted. If given, will plot the spatial distance
            matrix in the left plot and the gene expression distance matrix in the right plot.
        reverse_expr_plot_orientation: If True, plot the gene expression connections in the form of a lower right
            triangle. If False, gene expression connections will be an upper left triangle just like the spatial
            connections.
        ax: Existing axes object, if applicable. Either a single axes, or a sequence of two axes (spatial plot
            first, expression plot second) when 'expr_weights_matrix' is given.
        figsize: Width x height of desired figure window in inches
        zero_self_connections: If True, ignores intra-label interactions
        normalize_by_self_connections: Only used if 'zero_self_connections' is False. If True, normalize intra-label
            connections by the number of spots of that label
        shapes_style: If True plots squares, if False plots heatmap
        label_outline: If True, gives dark outline to axis tick label text
        max_scale: Only used for the case that 'shapes_style' is True, gives maximum size of square
        colormap: Specifies colors to use for plotting. If dictionary, keys should be numerical labels corresponding
            to those of the Label object.
        title_str: Optionally used to give plot a title
        title_fontsize: Size of plot title. Defaults to the "axes.titlesize" rcParam.
        label_fontsize: Size of labels along the axes of the graph. Defaults to the "axes.labelsize" rcParam.
        save_show_or_return: Whether to save, show or return the figure.
            If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
            displayed and the associated axis and other object will be return.
        save_kwargs: A dictionary that will passed to the save_fig function. `None` (the default) is treated as an
            empty dictionary, in which case the save_fig function will use the {"path": None, "prefix": 'scatter',
            "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters.
            Otherwise you can provide a dictionary that properly modifies those keys according to your needs.

    Returns:
        (fig, ax): Returns plot and axis object if 'save_show_or_return' is "return" or "all"; otherwise None.

    Raises:
        ValueError: If 'expr_weights_matrix' is given together with an 'ax' that holds fewer than two axes.
    """
    from ...plotting.static.utils import save_fig
    from ...tools.utils import update_dict

    logger = lm.get_main_logger()

    config_spateo_rcParams()
    title_fontsize = rcParams.get("axes.titlesize") if title_fontsize is None else title_fontsize
    label_fontsize = rcParams.get("axes.labelsize") if label_fontsize is None else label_fontsize
    if save_kwargs is None:
        save_kwargs = {}

    ax_expr = None
    if ax is None:
        if expr_weights_matrix is not None:
            figsize = (figsize[0] * 2.25, figsize[1])
            fig, axes = plt.subplots(1, 2, figsize=figsize)
            ax_sp, ax_expr = axes[0], axes[1]
            if reverse_expr_plot_orientation:
                # Allow subplot boundaries to technically be partially overlapping (for better visual)
                box = ax_expr.get_position()
                box.x0 = box.x0 - 0.4
                box.x1 = box.x1 - 0.3
                ax_expr.set_position(box)
        else:
            fig, ax_sp = plt.subplots(1, 1, figsize=figsize)
            axes = ax_sp
    else:
        # Accept either a single Axes or a sequence/array of Axes
        axes = np.atleast_1d(np.asarray(ax, dtype=object)).ravel()
        ax_sp = axes[0]
        if axes.size > 1:
            ax_expr = axes[1]
        fig = ax_sp.get_figure()
        if expr_weights_matrix is not None and ax_expr is None:
            raise ValueError("Two axes are required in 'ax' when 'expr_weights_matrix' is given.")

    # Convert cell type labels to numerical using Label object:
    # Remove cell types with fewer than 30 cells:
    logger.info("Filtering out cell types with fewer than 30 cells...")
    n_cells_input = adata.n_obs
    categories_str = adata.obs[cat_key]
    # Count occurrences for each category
    category_counts = categories_str.value_counts()
    # Filter categories with at least 30 occurrences
    filtered_categories = category_counts[category_counts >= 30].index
    retained_mask = categories_str.isin(filtered_categories)
    indices_to_retain = np.where(retained_mask)[0]
    # Update the AnnData object to include only the filtered categories
    adata = adata[retained_mask].copy()

    categories_str_cat = np.unique(adata.obs[cat_key].values)
    categories_num_cat = range(len(categories_str_cat))
    map_dict = dict(zip(categories_num_cat, categories_str_cat))
    categories_num = adata.obs[cat_key].replace(categories_str_cat, categories_num_cat)

    # Update expression weights matrix if applicable to only include filtered categories:
    if expr_weights_matrix is not None:
        expr_weights_matrix = expr_weights_matrix[indices_to_retain, :][:, indices_to_retain]
    # Same for a user-supplied spatial weights matrix, but only when it is sized for the unfiltered input (so that
    # matrices that already match the filtered cells are left alone):
    if spatial_weights_matrix is not None and spatial_weights_matrix.shape[0] == n_cells_input:
        spatial_weights_matrix = spatial_weights_matrix[indices_to_retain, :][:, indices_to_retain]

    label = Label(categories_num.to_numpy(), str_map=map_dict)

    # If spatial weights matrix is not given, compute it. 'spatial_key' needs to be present in the AnnData object:
    if spatial_weights_matrix is None:
        if spatial_key not in adata.obsm_keys():
            logger.error(
                f"Given 'spatial_key' {spatial_key} does not exist as key in adata.obsm. Options: "
                f"{adata.obsm_keys()}."
            )
        _, adata = neighbors(adata, basis="spatial", spatial_key=spatial_key, n_neighbors=n_spatial_neighbors)
        spatial_weights_matrix = adata.obsp["connectivities"]

    # Compute spatial connections array:
    spatial_connections = interlabel_connections(label, spatial_weights_matrix)
    if zero_self_connections:
        np.fill_diagonal(spatial_connections, 0)
    elif normalize_by_self_connections:
        spatial_connections /= spatial_connections.diagonal()[:, np.newaxis]
    spatial_connections_max = np.amax(spatial_connections)

    # Optionally, compute gene expression connections array:
    if expr_weights_matrix is not None:
        expr_connections = interlabel_connections(label, expr_weights_matrix)
        if zero_self_connections:
            np.fill_diagonal(expr_connections, 0)
        elif normalize_by_self_connections:
            expr_connections /= expr_connections.diagonal()[:, np.newaxis]
        expr_connections_max = np.amax(expr_connections)

    # Set label colors:
    cmap = mpl.colormaps[colormap] if isinstance(colormap, str) else colormap

    # If colormap is given, map label ID to points along the colormap. If dictionary is given, instead map each label
    # to a color using the dictionary keys as guides.
    if isinstance(cmap, dict):
        if isinstance(next(iter(cmap), None), str):
            id_colors = {num_id: cmap[str_id] for num_id, str_id in zip(label.ids, label.str_ids)}
        else:
            id_colors = {label_id: cmap[label_id] for label_id in label.ids}
    else:
        id_colors = {label_id: cmap(label_id / label.max_id) for label_id in label.ids}

    text_outline = [PathEffects.Stroke(linewidth=0.5, foreground="black", alpha=0.8)] if label_outline else None

    # -------------------------------- Spatial Connections Plot- Setup -------------------------------- #
    if shapes_style:
        polygon_list, color_list = _connection_triangles(
            spatial_connections, spatial_connections_max, label, id_colors, max_scale
        )
        ax_sp.set_ylim(-0.55, label.num_labels - 0.45)
        ax_sp.set_xlim(-0.55, label.num_labels - 0.45)
        collection = PolyCollection(polygon_list, facecolors=color_list, edgecolors="face", linewidths=0)
        ax_sp.add_collection(collection)

        # Remove ticks
        ax_sp.tick_params(labelbottom=False, labeltop=True, top=False, bottom=False, left=False)
        ax_sp.xaxis.set_tick_params(pad=-2)
    else:
        _draw_connection_heatmap(fig, ax_sp, spatial_connections, spatial_connections_max, colormap)

    # Formatting adjustments
    ax_sp.set_aspect("equal")
    _label_connection_axes(ax_sp, label, id_colors, label_fontsize, text_outline)

    title_str_sp = "Spatial Connections" if title_str is None else title_str
    ax_sp.set_title(title_str_sp, fontsize=title_fontsize, fontweight="bold")

    # ------------------------------ Optional Gene Expression Connections Plot- Setup ------------------------------ #
    if expr_weights_matrix is not None:
        if shapes_style:
            polygon_list, color_list = _connection_triangles(
                expr_connections, expr_connections_max, label, id_colors, max_scale
            )
            ax_expr.set_ylim(-0.55, label.num_labels - 0.45)
            ax_expr.set_xlim(-0.55, label.num_labels - 0.45)

            # Remove ticks
            if reverse_expr_plot_orientation:
                ax_expr.tick_params(
                    labelbottom=True,
                    labeltop=False,
                    labelleft=False,
                    labelright=True,
                    top=False,
                    bottom=False,
                    left=False,
                )
                # Flip x- and y-axes of the expression plot:
                ax_expr.invert_xaxis()
                ax_expr.invert_yaxis()
            else:
                ax_expr.tick_params(labelbottom=False, labeltop=True, top=False, bottom=False, left=False)
                ax_expr.xaxis.set_tick_params(pad=-2)

            collection = PolyCollection(polygon_list, facecolors=color_list, edgecolors="face", linewidths=0)
            ax_expr.add_collection(collection)
        else:
            # The colorbar format is decided from the expression maximum, not the spatial one
            _draw_connection_heatmap(fig, ax_expr, expr_connections, expr_connections_max, colormap)

        # Formatting adjustments
        ax_expr.set_facecolor("none")
        ax_expr.set_aspect("equal")

        if reverse_expr_plot_orientation:
            # Despine both spatial connections & gene expression connections plots:
            for despined_ax in (ax_sp, ax_expr):
                for side in ("right", "top", "left", "bottom"):
                    despined_ax.spines[side].set_visible(False)

        _label_connection_axes(ax_expr, label, id_colors, label_fontsize, text_outline)

        title_str_expr = "Gene Expression Similarity" if title_str is None else title_str
        # In the reversed orientation the title is placed below the axes (y < 0):
        title_y = None
        if reverse_expr_plot_orientation:
            title_y = -0.3 if label_fontsize <= 8 else -0.35
        ax_expr.set_title(title_str_expr, fontsize=title_fontsize, fontweight="bold", y=title_y)

    prefix = "spatial_connections" if expr_weights_matrix is None else "spatial_and_expr_connections"

    # The three actions are independent checks so that "both" and "all" perform every step they name.
    if save_show_or_return in ["save", "both", "all"]:
        s_kwargs = {
            "path": None,
            "prefix": prefix,
            "dpi": None,
            "ext": "pdf",
            "transparent": True,
            "close": True,
            "verbose": True,
        }
        # Keep the figure open if it still needs to be shown or returned
        if save_show_or_return in ["both", "all"]:
            s_kwargs["close"] = False
        s_kwargs = update_dict(s_kwargs, save_kwargs)
        save_fig(**s_kwargs)
    if save_show_or_return in ["show", "both", "all"]:
        plt.show()
    if save_show_or_return in ["return", "all"]:
        return (fig, axes if expr_weights_matrix is not None else ax_sp)


def _connection_triangles(connections, connections_max, label, id_colors, max_scale):
    """Build triangle polygons and their face colors for each label pair (i, j) with i <= j.

    Each pair contributes two triangles centered at (i, j), scaled by sqrt(connections[i, j] / connections_max).
    """
    # Cell types/labels will be represented using triangles:
    left_triangle = np.array(((-1.0, 1.0), (1.0, -1.0), (-1.0, -1.0)))
    right_triangle = np.array(((-1.0, 1.0), (1.0, 1.0), (1.0, -1.0)))

    polygon_list = []
    color_list = []
    n_rows, n_cols = connections.shape[0], connections.shape[1]
    for label_1 in range(n_rows):
        for label_2 in range(label_1, n_cols):
            center = np.array((label_1, label_2))[np.newaxis, :]
            scale_factor = np.sqrt(connections[label_1, label_2] / connections_max)
            for triangle in (left_triangle, right_triangle):
                polygon_list.append(center + triangle * max_scale * scale_factor)
            # One color per triangle: left takes the color of label_2, right that of label_1
            color_list += (id_colors[label.ids[label_2]], id_colors[label.ids[label_1]])
    return polygon_list, color_list


def _draw_connection_heatmap(fig, ax, connections, connections_max, colormap):
    """Draw `connections` as a heatmap on `ax` with a colorbar appended to its right."""
    heatmap = ax.imshow(connections, cmap=colormap, interpolation="nearest")

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.1)
    fig.colorbar(heatmap, cax=cax)
    cax.tick_params(axis="both", which="major", labelsize=6, rotation=-45)

    # Change formatting if values too small
    if connections_max < 0.001:
        cax.yaxis.set_major_formatter(StrMethodFormatter("{x:,.1e}"))


def _label_connection_axes(ax, label, id_colors, label_fontsize, text_outline):
    """Put one tick per label on both axes of `ax`, with tick text colored by label."""
    # If label has categorical labels associated, use those to label the axes instead:
    has_names = label.str_map is not None
    tick_text = label.str_ids if has_names else label.ids

    ax.set_xticks(np.arange(label.num_labels))
    ax.set_xticklabels(
        tick_text,
        fontsize=label_fontsize,
        fontweight="bold",
        rotation=90 if has_names else 0,
        path_effects=text_outline,
    )

    ax.set_yticks(np.arange(label.num_labels))
    ax.set_yticklabels(
        tick_text,
        fontsize=label_fontsize,
        fontweight="bold",
        path_effects=text_outline,
    )

    for ticklabels in (ax.get_xticklabels(), ax.get_yticklabels()):
        for n, label_id in enumerate(label.ids):
            ticklabels[n].set_color(id_colors[label_id])