# Source code for spateo.tools.gene_expression_variance

"""
Characterizing cell-to-cell variability within spatial domains
"""
from collections import OrderedDict
from typing import Dict, List, Literal, Optional, Tuple, Union

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
from anndata import AnnData
from matplotlib import rcParams
from tqdm import tqdm

from ..configuration import SKM, config_spateo_rcParams
from ..logging import logger_manager as lm
from ..plotting.static.utils import save_return_show_fig_utils


# ---------------------------------------------------------------------------------------------------
# Comparative statistics for gene expression between groups
# ---------------------------------------------------------------------------------------------------
def compute_gene_groups_p_val(gene: str, group1: anndata.AnnData, group2: anndata.AnnData) -> Tuple[str, float]:
    """Calculate the Mann-Whitney U test p-value for a gene between two groups.

    Args:
        gene: Name of the gene
        group1: AnnData object containing cells from the first group to compare
        group2: AnnData object containing cells from the second group to compare

    Returns:
        gene: Name of the gene
        p_val: Two-sided Mann-Whitney U test p-value, as a plain Python float
    """

    def _expression_vector(group) -> np.ndarray:
        # Selecting one gene may hand back a 2D and/or sparse matrix. Passing that straight to
        # `mannwhitneyu` yields an array of p-values (or fails for sparse input), so the values are
        # densified and flattened to one observation per cell first.
        values = group[:, gene].X
        if scipy.sparse.issparse(values):
            values = values.toarray()
        return np.asarray(values).reshape(-1)

    _, p_val = scipy.stats.mannwhitneyu(
        _expression_vector(group1),
        _expression_vector(group2),
        alternative="two-sided",
    )
    return (gene, float(p_val))
# ---------------------------------------------------------------------------------------------------
# Compute highly variable genes
# ---------------------------------------------------------------------------------------------------
def get_highvar_genes(
    expression: Union[np.ndarray, scipy.sparse.csr_matrix, scipy.sparse.csc_matrix, scipy.sparse.coo_matrix],
    expected_fano_threshold: Optional[float] = None,
    numgenes: Optional[int] = None,
    minimal_mean: float = 0.5,
) -> Tuple[pd.DataFrame, Dict]:
    """Find highly-variable genes in single-cell data matrices.

    Each gene's Fano factor (variance / mean) is compared to an expected Fano factor of the form
    ``A**2 * mean + B**2``. ``A`` is the smallest coefficient of variation among the 20 genes with the highest
    mean, and ``B**2`` is the median Fano factor of the genes lying strictly inside the 10%-90% quantile box
    of both mean and Fano factor.

    Args:
        expression: Gene expression matrix; per-gene statistics are taken over axis 0. Dense and scipy
            sparse inputs are both accepted.
        expected_fano_threshold: Optionally can be used to set a manual dispersion threshold (for definition of
            "highly-variable"). If None, the threshold is 1 + the standard deviation of the Fano factor inside
            the quantile box. An explicit value of 0.0 is honored.
        numgenes: Optionally can be used to find the n most variable genes. When given, the genes with the
            largest observed/expected Fano ratio are flagged, and neither the threshold nor 'minimal_mean' is
            applied.
        minimal_mean: Sets a threshold on the minimum mean expression to consider

    Returns:
        gene_counts_stats: Results dataframe with columns "mean", "var", "fano", "expected_fano", "high_var"
            and "fano_ratio" for each gene
        gene_fano_parameters: Dictionary recording the fit parameters "A" and "B", the threshold "T" (None when
            'numgenes' is used) and "minimal_mean"
    """
    if scipy.sparse.issparse(expression):
        # np.square() and pd.Series() cannot consume sparse matrices (or the np.matrix returned by their
        # .mean()), so the first two moments are computed with sparse operations and flattened to 1D.
        gene_mean = np.asarray(expression.mean(axis=0), dtype=float).reshape(-1)
        gene2_mean = np.asarray(expression.multiply(expression).mean(axis=0), dtype=float).reshape(-1)
    else:
        gene_mean = expression.mean(axis=0)
        gene2_mean = np.square(expression).mean(axis=0)

    # Population variance: E[x^2] - E[x]^2
    gene_var = gene2_mean - np.square(gene_mean)
    gene_mean = pd.Series(gene_mean)
    gene_fano = gene_var / gene_mean

    # Find parameters for expected fano line
    top_genes = gene_mean.sort_values(ascending=False).head(20).index
    A = (np.sqrt(gene_var) / gene_mean).loc[top_genes].min()

    w_mean_low, w_mean_high = gene_mean.quantile([0.10, 0.90])
    w_fano_low, w_fano_high = gene_fano.quantile([0.10, 0.90])
    winsor_box = (
        (gene_fano > w_fano_low) & (gene_fano < w_fano_high) & (gene_mean > w_mean_low) & (gene_mean < w_mean_high)
    )
    fano_median = gene_fano[winsor_box].median()
    B = np.sqrt(fano_median)

    gene_expected_fano = (A**2) * gene_mean + (B**2)
    fano_ratio = gene_fano / gene_expected_fano

    # Identify high var genes
    if numgenes is not None:
        highvargenes = fano_ratio.sort_values(ascending=False).index[:numgenes]
        high_var_genes_ind = fano_ratio.index.isin(highvargenes)
        T = None
    else:
        # Compare against None rather than truthiness so that a user-supplied threshold of 0.0 is kept.
        if expected_fano_threshold is None:
            T = 1.0 + gene_fano[winsor_box].std()
        else:
            T = expected_fano_threshold
        high_var_genes_ind = (fano_ratio > T) & (gene_mean > minimal_mean)

    gene_counts_stats = pd.DataFrame(
        {
            "mean": gene_mean,
            "var": gene_var,
            "fano": gene_fano,
            "expected_fano": gene_expected_fano,
            "high_var": high_var_genes_ind,
            "fano_ratio": fano_ratio,
        }
    )
    gene_fano_parameters = {
        "A": A,
        "B": B,
        "T": T,
        "minimal_mean": minimal_mean,
    }
    return (gene_counts_stats, gene_fano_parameters)
def get_highvar_genes_sparse(
    expression: Union[
        np.ndarray,
        scipy.sparse.csr_matrix,
        scipy.sparse.csc_matrix,
        scipy.sparse.coo_matrix,
    ],
    expected_fano_threshold: Optional[float] = None,
    numgenes: Optional[int] = None,
    minimal_mean: float = 0.5,
) -> Tuple[pd.DataFrame, Dict]:
    """Find highly-variable genes in sparse single-cell data matrices.

    Each gene's Fano factor (variance / mean) is compared to an expected Fano factor of the form
    ``A**2 * mean + B**2``. ``A`` is the smallest coefficient of variation among the 20 genes with the highest
    mean, and ``B**2`` is the median Fano factor of the genes lying strictly inside the 10%-90% quantile box
    of both mean and Fano factor.

    Args:
        expression: Gene expression matrix; per-gene statistics are taken over axis 0. Scipy sparse matrices
            are handled without densifying; dense arrays are accepted as well.
        expected_fano_threshold: Optionally can be used to set a manual dispersion threshold (for definition of
            "highly-variable"). If None, the threshold is 1 + the standard deviation of the Fano factor inside
            the quantile box. An explicit value of 0.0 is honored.
        numgenes: Optionally can be used to find the n most variable genes. When given, the genes with the
            largest observed/expected Fano ratio are flagged, and neither the threshold nor 'minimal_mean' is
            applied.
        minimal_mean: Sets a threshold on the minimum mean expression to consider

    Returns:
        gene_counts_stats: Results dataframe containing pertinent information for each gene
        gene_fano_parameters: Additional informative dictionary (w/ records of dispersion for each gene,
            threshold, etc.)
    """
    if scipy.sparse.issparse(expression):
        gene_mean = np.asarray(expression.mean(axis=0), dtype=float).reshape(-1)
        # Element-wise product instead of squaring `.data` in place: this also gives the right answer when
        # the matrix stores duplicate entries for one coordinate (possible in COO format).
        gene2_mean = np.asarray(expression.multiply(expression).mean(axis=0)).reshape(-1)
    else:
        # Dense input is part of the declared signature but has no sparse `.data` buffer to square.
        dense = np.asarray(expression)
        gene_mean = dense.mean(axis=0).astype(float).reshape(-1)
        gene2_mean = np.square(dense).mean(axis=0).reshape(-1)

    # Population variance: E[x^2] - E[x]^2
    gene_var = pd.Series(gene2_mean - (gene_mean**2))
    gene_mean = pd.Series(gene_mean)
    gene_fano = gene_var / gene_mean

    # Find parameters for expected fano line
    top_genes = gene_mean.sort_values(ascending=False).head(20).index
    A = (np.sqrt(gene_var) / gene_mean).loc[top_genes].min()

    w_mean_low, w_mean_high = gene_mean.quantile([0.10, 0.90])
    w_fano_low, w_fano_high = gene_fano.quantile([0.10, 0.90])
    winsor_box = (
        (gene_fano > w_fano_low) & (gene_fano < w_fano_high) & (gene_mean > w_mean_low) & (gene_mean < w_mean_high)
    )
    fano_median = gene_fano[winsor_box].median()
    B = np.sqrt(fano_median)

    gene_expected_fano = (A**2) * gene_mean + (B**2)
    fano_ratio = gene_fano / gene_expected_fano

    # Identify high var genes
    if numgenes is not None:
        highvargenes = fano_ratio.sort_values(ascending=False).index[:numgenes]
        high_var_genes_ind = fano_ratio.index.isin(highvargenes)
        T = None
    else:
        # Compare against None rather than truthiness so that a user-supplied threshold of 0.0 is kept.
        if expected_fano_threshold is None:
            T = 1.0 + gene_fano[winsor_box].std()
        else:
            T = expected_fano_threshold
        high_var_genes_ind = (fano_ratio > T) & (gene_mean > minimal_mean)

    gene_counts_stats = pd.DataFrame(
        {
            "mean": gene_mean,
            "var": gene_var,
            "fano": gene_fano,
            "expected_fano": gene_expected_fano,
            "high_var": high_var_genes_ind,
            "fano_ratio": fano_ratio,
        }
    )
    gene_fano_parameters = {
        "A": A,
        "B": B,
        "T": T,
        "minimal_mean": minimal_mean,
    }
    return (gene_counts_stats, gene_fano_parameters)
### ----------------------------------- Cell-to-cell variability ----------------------------------- ###
@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata")
def compute_variance_decomposition(
    adata: AnnData,
    spatial_label_id: str,
    celltype_label_id: str,
    genes: Union[None, str, List[str]] = None,
    figsize: Union[None, Tuple[float, float]] = None,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = {},
):
    """Computes and then visualizes the variance decomposition for an AnnData object. Within spatial regions,
    determines the proportion of the total variation that occurs within the same cell type, the proportion of the
    variation that occurs between cell types in the region, and the proportion of the variation that comes from
    baseline differences in the expression levels of the genes in the data. The within-cell type variation could
    potentially come from differences in cell-cell communication.

    For reference, within each spatial domain:
        intra-cell type variance: for all cells of a given cell type, how much does each gene vary from the mean
            within that cell type?
        inter-cell type variance: for each gene, how much does the mean expression within each cell type vary
            compared to the overall mean of the spatial domain for that gene?
        gene variance: for each spatial domain, how much does the mean expression of each gene vary compared to
            the overall mean of all genes?

    Args:
        adata: AnnData object containing data
        spatial_label_id: Key in .obs containing spatial domain labels
        celltype_label_id: Key in .obs containing cell type labels
        genes: Can be used to filter to chosen subset of genes for variance computation. A single gene name or
            a list of gene names; if None, all genes are used.
        figsize: Can be optionally used to set the size of the plotted figure
        save_show_or_return: Whether to save, show or return the figure. The result is always handed to
            :func `plot_variance_decomposition`, which receives this option unchanged.
            If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
            displayed and the associated axis and other object will be return.
        save_kwargs: A dictionary that will passed to the save_fig function.
            By default it is an empty dictionary and the save_fig function will use the
            {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True,
            "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those
            keys according to your needs.

    Returns:
        var_decomposition: Dataframe indexed by spatial domain. Holds the raw sums of squares
            ("intra_celltype_var", "inter_celltype_var", "gene_var"), their sum ("Total variance") and the three
            proportions of that total ("Intra-cell type variance", "Inter-cell type variance", "Gene variance").
    """
    if genes is not None:
        genes = [genes] if isinstance(genes, str) else list(genes)

    # The data is only read, so a (possibly gene-subset) view is enough; no full copy of 'adata' is needed.
    adata_subset = adata if genes is None else adata[:, genes]
    expr = adata_subset.X
    expr = expr.toarray() if scipy.sparse.issparse(expr) else np.asarray(expr)

    # Labels are matched to the rows of 'expr' by position.
    domain_labels = np.asarray(adata.obs[spatial_label_id])
    celltype_labels = np.asarray(adata.obs[celltype_label_id])

    records = []
    for domain in tqdm(np.unique(domain_labels)):
        in_domain = domain_labels == domain
        domain_expr = expr[in_domain]
        domain_celltypes = celltype_labels[in_domain]

        # For each gene, compute mean within the domain (expression columns only, all cells of the domain):
        mean_domain_genes = domain_expr.mean(axis=0)
        # Compute average for all genes:
        mean_domain_global = mean_domain_genes.mean()
        # Squared deviation of each gene's domain mean from the all-gene mean; identical for every cell.
        gene_sq_dev = float(np.sum((mean_domain_genes - mean_domain_global) ** 2))

        intra_ct_var = 0.0
        inter_ct_var = 0.0
        gene_var = 0.0
        # Only cell types that actually occur in this domain contribute.
        for celltype in np.unique(domain_celltypes):
            celltype_expr = domain_expr[domain_celltypes == celltype]
            n_cells = celltype_expr.shape[0]
            # For each cell type, compute the mean expression for each gene
            mean_domain_celltype = celltype_expr.mean(axis=0)

            # Within the cell type, squared deviation of every cell from the cell type mean
            intra_ct_var += float(np.sum((celltype_expr - mean_domain_celltype) ** 2))
            # NOTE(review): the two terms below are accumulated once per cell, i.e. weighted by the number of
            # cells of the cell type, as in a sum-of-squares decomposition. The loop nesting was assumed
            # here -- confirm against the intended definition.
            inter_ct_var += n_cells * float(np.sum((mean_domain_celltype - mean_domain_genes) ** 2))
            gene_var += n_cells * gene_sq_dev

        # Keep the label and the numbers as separate Python objects: packing them into one numpy array
        # would turn the numbers into strings.
        records.append([domain, intra_ct_var, inter_ct_var, gene_var])

    df = (
        pd.DataFrame(records, columns=["Domain", "intra_celltype_var", "inter_celltype_var", "gene_var"])
        .astype(
            {
                "Domain": str,
                "intra_celltype_var": "float32",
                "inter_celltype_var": "float32",
                "gene_var": "float32",
            }
        )
        .set_index("Domain")
    )

    df["Total variance"] = df.intra_celltype_var + df.inter_celltype_var + df.gene_var
    # Normalize to sum to 1:
    df["Intra-cell type variance"] = df.intra_celltype_var / df["Total variance"]
    df["Inter-cell type variance"] = df.inter_celltype_var / df["Total variance"]
    df["Gene variance"] = df.gene_var / df["Total variance"]

    # 'genes' is None by default, so it must be checked before taking its length.
    if genes is not None and len(genes) == 1:
        title = f"Variance Decomposition for Spatial Domains: {genes}"
    else:
        title = None
    plot_variance_decomposition(
        df, title=title, figsize=figsize, save_show_or_return=save_show_or_return, save_kwargs=save_kwargs
    )
    return df
def genewise_variance_decomposition(
    adata: AnnData,
    celltype_label_id: str,
    genes: Union[str, List[str]],
    figsize: Union[None, Tuple[float, float]] = None,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = {},
):
    """For each gene in the chosen subset, computes a variance decomposition by computing the intra-cell type
    variance and the inter-cell type variance.

    Args:
        adata: AnnData object containing data
        celltype_label_id: Key in .obs containing cell type labels
        genes: Gene name or list of gene names for which to compute the decomposition
        figsize: Can be used to optionally set the size of the plotted figure
        save_show_or_return: Whether to save, show or return the figure. The result is always handed to
            :func `plot_variance_decomposition`, which receives this option unchanged.
            If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
            displayed and the associated axis and other object will be return.
        save_kwargs: A dictionary that will passed to the save_fig function.
            By default it is an empty dictionary and the save_fig function will use the
            {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True,
            "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those
            keys according to your needs.

    Returns:
        var_decomposition: Dataframe indexed by gene. Holds the raw sums of squares ("intra_celltype_var",
            "inter_celltype_var"), their sum ("Total variance") and the two proportions of that total
            ("Intra-cell type variance", "Inter-cell type variance").
    """
    # A bare string would otherwise be iterated character by character.
    genes = [genes] if isinstance(genes, str) else list(genes)

    # Only the requested genes are densified; the data is only read, so no copy of 'adata' is needed.
    adata_subset = adata[:, genes]
    # Take the names from the subset itself so that each column is guaranteed to carry its own label.
    gene_names = list(adata_subset.var_names)
    expr = adata_subset.X
    expr = expr.toarray() if scipy.sparse.issparse(expr) else np.asarray(expr)

    # Labels are matched to the rows of 'expr' by position; the masks are shared by all genes.
    celltype_labels = np.asarray(adata.obs[celltype_label_id])
    celltype_masks = [celltype_labels == celltype for celltype in np.unique(celltype_labels)]

    records = []
    for column, gene in enumerate(tqdm(gene_names)):
        gene_expr = expr[:, column]
        # For each gene, compute mean across entire sample:
        mean_expr = gene_expr.mean()

        intra_ct_var = 0.0
        inter_ct_var = 0.0
        for mask in celltype_masks:
            # Cell type-specific expression and its mean:
            celltype_expr = gene_expr[mask]
            mean_celltype = celltype_expr.mean()

            # Within the cell type, squared deviation of every cell from the cell type mean
            intra_ct_var += float(np.sum((celltype_expr - mean_celltype) ** 2))
            # NOTE(review): accumulated once per cell, i.e. weighted by the number of cells of the cell type,
            # as in a sum-of-squares decomposition. The loop nesting was assumed here -- confirm against the
            # intended definition.
            inter_ct_var += celltype_expr.shape[0] * float((mean_celltype - mean_expr) ** 2)

        # Keep the label and the numbers as separate Python objects: packing them into one numpy array
        # would turn the numbers into strings.
        records.append([gene, intra_ct_var, inter_ct_var])

    df = (
        pd.DataFrame(records, columns=["Gene", "intra_celltype_var", "inter_celltype_var"])
        .astype(
            {
                "Gene": str,
                "intra_celltype_var": "float32",
                "inter_celltype_var": "float32",
            }
        )
        .set_index("Gene")
    )

    df["Total variance"] = df.intra_celltype_var + df.inter_celltype_var
    # Normalize to sum to 1:
    df["Intra-cell type variance"] = df.intra_celltype_var / df["Total variance"]
    df["Inter-cell type variance"] = df.inter_celltype_var / df["Total variance"]

    title = "Variance Decomposition for Each Gene"
    plot_variance_decomposition(
        df, title=title, figsize=figsize, save_show_or_return=save_show_or_return, save_kwargs=save_kwargs
    )
    return df
def plot_variance_decomposition(
    var_df: pd.DataFrame,
    figsize: Optional[Tuple[float, float]] = (6, 2),
    cmap: str = "Blues_r",
    multiindex: bool = False,
    title: Union[None, str] = None,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = {},
):
    """Visualization of the parts-wise intra-cell type variation, cell type-independent gene variation to the
    total variation within the data.

    Args:
        var_df: Output from :func `compute_variance_decomposition`. Must contain the columns
            "Intra-cell type variance" and "Inter-cell type variance"; "Gene variance" is stacked on top when
            present. The dataframe is not modified.
        figsize: (width, height) of the figure window. If None, the "figure.figsize" rcParam is used.
        cmap: Name of the matplotlib colormap to use
        multiindex: Specifies whether to label the bars by the second level of two-level index information.
            The index may be a pandas MultiIndex (the innermost level is used) or plain labels of the form
            "<level1>_<level2>" (the part after the first underscore is used).
        title: Optionally, provide custom title to plot. If not given, will use default title.
        save_show_or_return: Whether to save, show or return the figure.
            If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
            displayed and the associated axis and other object will be returned.
        save_kwargs: A dictionary that will passed to the save_fig function.
            By default it is an empty dictionary and the save_fig function will use the
            {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True,
            "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those
            keys according to your needs.

    Returns:
        Whatever `save_return_show_fig_utils` returns for the chosen 'save_show_or_return' mode (presumably the
        figure/axes for "return" and "all", and None otherwise -- confirm against that helper).
    """
    logger = lm.get_main_logger()

    config_spateo_rcParams()
    figsize = rcParams.get("figure.figsize") if figsize is None else figsize

    y_plot = ["Intra-cell type variance", "Inter-cell type variance"]
    if "Gene variance" in var_df.columns:
        y_plot.append("Gene variance")

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    var_df.plot(
        y=y_plot,
        kind="bar",
        stacked=True,
        edgecolor="black",
        width=0.75,
        linewidth=0.6,
        figsize=figsize,
        ax=ax,
        colormap=cmap,
    )

    if multiindex:
        # One tick label per bar, in bar order, so the labels cannot get out of step with the bars. The
        # labels are derived without adding helper columns to the caller's dataframe.
        if isinstance(var_df.index, pd.MultiIndex):
            tick_labels = [str(levels[-1]) for levels in var_df.index]
        else:
            tick_labels = []
            found_unsplittable = False
            for label in var_df.index:
                _first, separator, second = str(label).partition("_")
                if separator:
                    tick_labels.append(second)
                else:
                    # No second level encoded in this label; fall back to the label itself.
                    found_unsplittable = True
                    tick_labels.append(str(label))
            if found_unsplittable:
                logger.error("'var_df' index is not a multi-level index. 'Multiindex' cannot be set True.")
        ax.set_xticklabels(tick_labels)

    # Configuring plot:
    ax.set_xlabel("")
    ax.legend(bbox_to_anchor=(1, 1), loc="upper left")
    if title is None:
        ax.set_title("Variance Decomposition for Spatial Domains")
    else:
        ax.set_title(title)
    ax.set_ylabel("Proportion of variance")
    plt.tight_layout()

    # Hand the helper's result back to the caller; otherwise the "return"/"all" modes could never
    # deliver the figure objects they promise.
    return save_return_show_fig_utils(
        save_show_or_return=save_show_or_return,
        show_legend=True,
        background="white",
        prefix="variance_decomposition",
        save_kwargs=save_kwargs,
        total_panels=1,
        fig=fig,
        axes=ax,
        return_all=False,
        return_all_list=None,
    )