Source code for spateo.tools.CCI_effects_modeling.MuSIC_downstream

"""
Additional functionalities to characterize signaling patterns from spatial transcriptomics

These include:
    - prediction of the effects of spatial perturbation on gene expression; this can include the effect of
    perturbing known regulators of ligand/receptor expression or the effect of perturbing the ligand/receptor itself.
    - following spatially-aware regression (or a sequence of spatially-aware regressions), combine regression results
    with data such that each cell can be associated with region-specific coefficient(s).
    - following spatially-aware regression (or a sequence of spatially-aware regressions), overlay the directionality
    of the predicted influence of the ligand on downstream expression.
"""

import argparse
import collections
import gc
import itertools
import math
import os
from collections import Counter
from itertools import product
from typing import List, Literal, Optional, Tuple, Union

import anndata
import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import plotly
import plotly.graph_objs as go
import scipy.cluster.hierarchy as sch
import scipy.sparse
import scipy.stats
import seaborn as sns
from adjustText import adjust_text
from joblib import Parallel, delayed
from matplotlib import rcParams
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.stats import mannwhitneyu, pearsonr, spearmanr, ttest_1samp, ttest_ind
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import (
    confusion_matrix,
    f1_score,
    mean_squared_error,
    roc_auc_score,
)
from sklearn.preprocessing import normalize
from tqdm.auto import tqdm

from ...configuration import config_spateo_rcParams
from ...logging import logger_manager as lm
from ...plotting.static.colorlabel import godsnot_102, vega_10
from ...plotting.static.networks import plot_network
from ...plotting.static.utils import save_return_show_fig_utils
from ...tools.find_neighbors import neighbors
from ...tools.utils import filter_adata_spatial
from ..dimensionality_reduction import find_optimal_pca_components, pca_fit
from ..utils import compute_corr_ci, create_new_coordinate
from .MuSIC import MuSIC
from .regression_utils import assign_significance, multitesting_correction, wald_test
from .SWR import define_spateo_argparse


# ---------------------------------------------------------------------------------------------------
# Statistical testing, correlated differential expression analysis
# ---------------------------------------------------------------------------------------------------
class MuSIC_Interpreter(MuSIC):
    """
    Interpretation and downstream analysis of spatially weighted regression models.

    Args:
        parser: ArgumentParser object initialized with argparse, to parse command line arguments for arguments
            pertinent to modeling.
        args_list: If parser is provided by function call, the arguments to parse must be provided as a separate
            list. It is recommended to use the return from :func `define_spateo_argparse()` for this.
        keep_column_threshold_proportion_cells: If provided, coefficient columns (and their matching standard
            errors) are set to zero unless they are nonzero in at least this proportion (rounded down to a cell
            count) of the cells expressing the corresponding target. For example, if this is set to 0.5, at least
            half of the target-expressing cells must have a nonzero value for a given column for it to be retained
            for further inspection. Intended to be used to filter out likely false positives.
    """

    def __init__(
        self,
        parser: argparse.ArgumentParser,
        args_list: Optional[List[str]] = None,
        keep_column_threshold_proportion_cells: Optional[float] = None,
    ):
        # Don't need to re-save the subsampling results, they have already been defined:
        super().__init__(parser, args_list, verbose=False, save_subsampling=False)

        self.k = self.arg_retrieve.top_k_receivers

        # Coefficients:
        if not self.set_up:
            self.logger.info(
                "Running :func `SWR._set_up_model()` to organize predictors and targets for downstream "
                "analysis now..."
            )
            self._set_up_model(verbose=False)
        # Per-target coefficients and standard errors:
        self.coeffs, self.standard_errors = self.return_outputs(adjust_for_subsampling=False, load_for_interpreter=True)

        if keep_column_threshold_proportion_cells is not None:
            # Number of cells expressing each target (nonzero expression):
            n_cells_expressing_targets = self.targets_expr.apply(lambda x: sum(x > 0), axis=0)
            for target, df in self.coeffs.items():
                # Threshold columns to only keep those that are nonzero in at least this many cells:
                threshold = int(keep_column_threshold_proportion_cells * n_cells_expressing_targets[target])
                se_df = self.standard_errors[target]
                for col in df.columns:
                    if (df[col] != 0).sum() < threshold:
                        df[col] = 0
                        # Coefficient columns are "b_<feature>" while standard-error columns are "se_<feature>",
                        # so translate the name before zeroing the matching standard error:
                        se_col = f"se_{col[len('b_'):]}" if col.startswith("b_") else col
                        if se_col in se_df.columns:
                            se_df[se_col] = 0

        # Check for coefficients and feature names of downstream models as well:
        self.logger.info("Checking for coefficients of possible downstream models as well...")
        downstream_parent_dir = os.path.dirname(os.path.splitext(self.output_path)[0])
        model_id = os.path.basename(os.path.splitext(self.output_path)[0])

        def _load_downstream_design_matrix(analysis_dir: str) -> Optional[pd.DataFrame]:
            """Read a downstream model's saved design matrix, or return None if the file does not exist."""
            path = os.path.join(
                downstream_parent_dir,
                "cci_deg_detection",
                analysis_dir,
                model_id,
                "downstream_design_matrix",
                "design_matrix.csv",
            )
            return pd.read_csv(path, index_col=0) if os.path.exists(path) else None

        self.downstream_model_ligand_coeffs, self.downstream_model_ligand_standard_errors = self.return_outputs(
            adjust_for_subsampling=False, load_for_interpreter=True, load_from_downstream="ligand"
        )
        self.downstream_model_ligand_design_matrix = _load_downstream_design_matrix("ligand_analysis")

        self.downstream_model_receptor_coeffs, self.downstream_model_receptor_standard_errors = self.return_outputs(
            adjust_for_subsampling=False, load_for_interpreter=True, load_from_downstream="receptor"
        )
        self.downstream_model_receptor_design_matrix = _load_downstream_design_matrix("receptor_analysis")

        self.downstream_model_target_coeffs, self.downstream_model_target_standard_errors = self.return_outputs(
            adjust_for_subsampling=False, load_for_interpreter=True, load_from_downstream="target_gene"
        )
        self.downstream_model_target_design_matrix = _load_downstream_design_matrix("target_gene_analysis")

        # Design matrix:
        self.design_matrix = pd.read_csv(
            os.path.join(os.path.splitext(self.output_path)[0], "design_matrix", "design_matrix.csv"), index_col=0
        )

        parent_dir = os.path.dirname(self.output_path)
        # If predictions of an L:R model have been computed, load these as well:
        predictions_path = os.path.join(parent_dir, "predictions.csv")
        if os.path.exists(predictions_path):
            self.predictions = pd.read_csv(predictions_path, index_col=0)

        # Save directory for significance results:
        os.makedirs(os.path.join(parent_dir, "significance"), exist_ok=True)

        # Arguments for cell type coupling computation:
        self.filter_targets = self.arg_retrieve.filter_targets
        self.filter_target_threshold = self.arg_retrieve.filter_target_threshold

        # Get targets for the downstream ligand(s), receptor(s), target(s), etc. to use for analysis:
        self.ligand_for_downstream = self.arg_retrieve.ligand_for_downstream
        self.receptor_for_downstream = self.arg_retrieve.receptor_for_downstream
        self.pathway_for_downstream = self.arg_retrieve.pathway_for_downstream
        self.target_for_downstream = self.arg_retrieve.target_for_downstream
        self.sender_ct_for_downstream = self.arg_retrieve.sender_ct_for_downstream
        self.receiver_ct_for_downstream = self.arg_retrieve.receiver_ct_for_downstream

        # Other downstream analysis-pertinent argparse arguments:
        self.cci_degs_model_interactions = self.arg_retrieve.cci_degs_model_interactions
        self.no_cell_type_markers = self.arg_retrieve.no_cell_type_markers
        self.compute_pathway_effect = self.arg_retrieve.compute_pathway_effect
        self.diff_sending_or_receiving = self.arg_retrieve.diff_sending_or_receiving
def compute_coeff_significance(self, method: str = "fdr_bh", significance_threshold: float = 0.05):
    """Computes local statistical significance for fitted coefficients and saves the results to disk.

    For each target, a Wald test is run for every (cell, feature) pair whose coefficient and standard error are
    both nonzero; all other pairs keep a p-value of 1. P-values are then corrected for multiple testing
    separately for each cell (row). Targets whose "<target>_is_significant.csv" file already exists in the
    "significance" folder next to the output path are skipped.

    Args:
        method: Method to use for correction. Available methods can be found in the documentation for
            statsmodels.stats.multitest.multipletests(), and are also listed below (in correct case) for
            convenience:
                - Named methods:
                    - bonferroni
                    - sidak
                    - holm-sidak
                    - holm
                    - simes-hochberg
                    - hommel
                - Abbreviated methods:
                    - fdr_bh: Benjamini-Hochberg correction
                    - fdr_by: Benjamini-Yekutieli correction
                    - fdr_tsbh: Two-stage Benjamini-Hochberg
                    - fdr_tsbky: Two-stage Benjamini-Krieger-Yekutieli method
        significance_threshold: p-value (or q-value) needed to call a parameter significant.

    Saves (per target, in "<output dir>/significance/"), each with cells as rows and features as columns:
        <target>_p_values.csv: Wald-test p-value for each cell and feature.
        <target>_q_values.csv: Multiple-testing-corrected q-value for each cell and feature.
        <target>_is_significant.csv: True where the q-value is below `significance_threshold`.
    """
    self.logger.info(
        "Computing significance for all coefficients, note this may take a long time for large "
        "datasets (> 10k cells)..."
    )
    parent_dir = os.path.dirname(self.output_path)
    significance_dir = os.path.join(parent_dir, "significance")
    os.makedirs(significance_dir, exist_ok=True)

    for target in self.coeffs.keys():
        # Check for existing file:
        if os.path.exists(os.path.join(significance_dir, f"{target}_is_significant.csv")):
            self.logger.info(f"Significance already computed for target {target}, moving to the next...")
            continue

        # Get coefficients and standard errors for this key
        coef = self.coeffs[target]
        coef = coef[[col for col in coef.columns if col.startswith("b_") and "intercept" not in col]]
        se = self.standard_errors[target]

        # Strip only the leading "se_" prefix (str.replace would also alter names containing "se_" elsewhere),
        # and only test features that have both a coefficient and a standard-error column:
        se_features = {c[len("se_"):] for c in se.columns if c.startswith("se_")}
        coef_columns = set(coef.columns)
        testable_features = [f for f in self.feature_names if f in se_features and f"b_{f}" in coef_columns]

        def compute_p_value(cell_name, feat):
            return wald_test(coef.loc[cell_name, f"b_{feat}"], se.loc[cell_name, f"se_{feat}"])

        filtered_tasks = [
            (cell_name, feat)
            for cell_name, feat in product(self.sample_names, testable_features)
            if se.loc[cell_name, f"se_{feat}"] != 0 and coef.loc[cell_name, f"b_{feat}"] != 0
        ]

        # Untested entries keep p = 1; float dtype so that update() does not need to upcast:
        p_values_all = pd.DataFrame(1.0, index=self.sample_names, columns=self.feature_names)
        if filtered_tasks:
            # Parallelize computations for filtered tasks
            results = Parallel(n_jobs=-1)(
                delayed(compute_p_value)(cell_name, feat)
                for cell_name, feat in tqdm(filtered_tasks, desc=f"Processing for target {target}")
            )
            results_df = pd.DataFrame(
                results, index=pd.MultiIndex.from_tuples(filtered_tasks, names=["sample", "feature"])
            )
            p_values_all.update(results_df.unstack(level="feature").droplevel(0, axis=1))

        # Multiple testing correction for each observation (row):
        qvals = np.zeros(p_values_all.shape, dtype=float)
        for i in range(p_values_all.shape[0]):
            qvals[i, :] = multitesting_correction(p_values_all.iloc[i, :], method=method, alpha=significance_threshold)
        q_values_df = pd.DataFrame(qvals, index=self.sample_names, columns=self.feature_names)

        # Significance:
        is_significant_df = q_values_df < significance_threshold

        # Save dataframes:
        p_values_all.to_csv(os.path.join(significance_dir, f"{target}_p_values.csv"))
        q_values_df.to_csv(os.path.join(significance_dir, f"{target}_q_values.csv"))
        is_significant_df.to_csv(os.path.join(significance_dir, f"{target}_is_significant.csv"))

        self.logger.info(f"Finished computing significance for target {target}.")
def filter_adata_spatial(self, instructions: List[str]):
    """Restrict later analyses to the cells that survive a sequence of spatial-coordinate filters.

    The actual filtering is delegated to the module-level ``filter_adata_spatial`` utility, applied to
    ``self.adata`` using the coordinates stored under ``self.coords_key``.

    Args:
        instructions: Filtering rules applied one after another, each written like
            "x less than 0.5 and y greater than 0.5".

    Side effects:
        Sets ``self.remaining_cells`` (names of the retained cells) and ``self.remaining_indices`` (their
        integer positions within ``self.adata.obs_names``).
    """
    filtered = filter_adata_spatial(self.adata, self.coords_key, instructions)
    kept_cells = filtered.obs_names
    self.remaining_cells = kept_cells
    self.remaining_indices = np.flatnonzero(self.adata.obs_names.isin(kept_cells))
def filter_adata_custom(self, cell_ids: List[str]):
    """Restrict later analyses to an explicit collection of cells.

    Args:
        cell_ids: Cell IDs to retain; each is expected to appear in ``self.adata.obs_names``.

    Side effects:
        Stores ``cell_ids`` as ``self.remaining_cells`` and the integer positions of the matching
        ``obs_names`` entries as ``self.remaining_indices``.
    """
    self.remaining_cells = cell_ids
    membership = self.adata.obs_names.isin(cell_ids)
    self.remaining_indices = np.flatnonzero(membership)
def add_interaction_effect_to_adata(
    self,
    targets: Union[str, List[str]],
    interactions: Union[str, List[str]],
    visualize: bool = False,
) -> "anndata.AnnData":
    """For each specified interaction/list of interactions, add the predicted interaction effect to the adata object.

    Args:
        targets: Target(s) to add interaction effect for. Can be a single target or a list of targets.
        interactions: Interaction(s) to add interaction effect for. Can be a single interaction or a list of
            interactions. Should be the name of a gene for ligand models, or an L:R pair for L:R models (for
            example, "Igf1:Igf1r").
        visualize: Whether to visualize the interaction effect for each target/interaction pair. If True, will
            generate a spatial scatter plot and save it to an HTML file in the "figures" folder next to the
            output path.

    Returns:
        adata: Copy of the AnnData object (restricted to ``self.remaining_cells`` if a filter was applied) with
            one ``"{target}_{interaction}_effect"`` column added to .obs per available target/interaction pair.
            Pairs whose target or coefficient column is missing are logged and skipped.
    """
    if visualize:
        figure_folder = os.path.join(os.path.dirname(self.output_path), "figures")
        os.makedirs(figure_folder, exist_ok=True)

    if not isinstance(targets, list):
        targets = [targets]
    if not isinstance(interactions, list):
        interactions = [interactions]

    use_subset = hasattr(self, "remaining_indices")
    adata = self.adata[self.remaining_cells, :].copy() if use_subset else self.adata.copy()

    for target, interaction in product(targets, interactions):
        coef_col = f"b_{interaction}"
        if target not in self.coeffs or coef_col not in self.coeffs[target].columns:
            self.logger.info(
                f"Information for interaction {interaction} not found for target {target}, " f"skipping..."
            )
            continue

        target_coefs = self.coeffs[target][coef_col]
        if use_subset:
            target_coefs = target_coefs.loc[self.remaining_cells]

        # Add to adata:
        effect_key = f"{target}_{interaction}_effect"
        adata.obs[effect_key] = target_coefs

        if visualize:
            # plotly to create 3D scatter plot (2D coordinates are placed on the z = 0 plane):
            spatial_coords = np.asarray(adata.obsm[self.coords_key])
            x, y = spatial_coords[:, 0], spatial_coords[:, 1]
            z = spatial_coords[:, 2] if spatial_coords.shape[1] > 2 else np.zeros(len(x))

            # Clip a separate array for display so the values stored in adata.obs are left untouched:
            effect_values = adata.obs[effect_key].to_numpy(dtype=float)
            p997 = np.percentile(effect_values, 99.7)
            plot_vals = np.where(effect_values > p997, p997, effect_values)

            scatter = go.Scatter3d(
                x=x,
                y=y,
                z=z,
                mode="markers",
                marker=dict(
                    color=plot_vals,
                    colorscale="Magma",
                    size=2,
                    colorbar=dict(
                        title=dict(text=f"{interaction.title()} Effect on {target.title()}", font=dict(size=16)),
                        x=0.8,
                        tickfont=dict(size=18),
                    ),
                ),
                showlegend=False,
            )
            fig = go.Figure(data=scatter)

            title_dict = dict(
                text=f"{interaction.title()} Effect on {target.title()}",
                y=0.9,
                yanchor="top",
                x=0.5,
                xanchor="center",
                font=dict(size=28),
            )
            # Hide grid, axis lines, ticks and labels on all three axes:
            hidden_axis = dict(
                showgrid=False,
                showline=False,
                linewidth=2,
                linecolor="black",
                backgroundcolor="white",
                title="",
                showticklabels=False,
                ticks="",
            )
            fig.update_layout(
                showlegend=True,
                legend=dict(x=0.65, y=0.85, orientation="v", font=dict(size=18)),
                scene=dict(aspectmode="data", xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis),
                margin=dict(l=0, r=0, b=0, t=50),  # Adjust margins to minimize spacing
                title=title_dict,
            )
            path = os.path.join(figure_folder, f"{interaction}_effect_on_{target}.html")
            fig.write_html(path)

    return adata
def compute_and_visualize_diagnostics(
    self, type: Literal["correlations", "confusion", "rmse"], n_genes_per_plot: int = 20
):
    """
    For true and predicted gene expression, compute and generate either: confusion matrices,
    or correlations, including the Pearson correlation, Spearman correlation, or root mean-squared-error (RMSE).

    Predictions are read from "predictions.csv" next to the output path; column i of that file is compared
    (positionally, cell by cell) against the observed expression of the same gene in ``self.adata``.

    Args:
        type: Type of diagnostic to compute and visualize. Options: "correlations" for Pearson & Spearman
            correlation, "confusion" for confusion matrix, "rmse" for root mean-squared-error.
        n_genes_per_plot: Only used if "type" is "confusion". Number of genes to plot per figure. If there are
            more than this number of genes, multiple figures will be generated.

    Raises:
        ValueError: If `type` is not one of the supported options.
    """
    if type not in ("correlations", "confusion", "rmse"):
        raise ValueError(f"Invalid diagnostic type '{type}'; choose 'correlations', 'confusion' or 'rmse'.")

    # Plot title:
    file_name = os.path.splitext(os.path.basename(self.adata_path))[0]
    parent_dir = os.path.dirname(self.output_path)
    predictions = pd.read_csv(os.path.join(parent_dir, "predictions.csv"), index_col=0)
    all_genes = predictions.columns
    pred_vals = predictions.values
    # Scale figure width with the number of genes, with a floor so small gene sets stay legible:
    width = max(0.5 * len(all_genes), 6.0)

    def observed_expression(gene):
        """Observed expression of `gene` across all cells of self.adata as a flat dense array."""
        values = self.adata[:, gene].X
        values = values.toarray() if scipy.sparse.issparse(values) else np.asarray(values)
        return values.reshape(-1)

    def safe_corr(corr_fn, a, b):
        """Correlation statistic from `corr_fn`, or NaN when fewer than two paired observations exist."""
        if len(a) < 2:
            return np.nan
        statistic, _ = corr_fn(a, b)
        return statistic

    def barplot_with_mean(df, column, color, title):
        """Per-gene bar plot of `column` with a dashed horizontal line (and legend entry) at its mean."""
        mean_value = df[column].mean()
        plt.figure(figsize=(width, 6))
        ax = sns.barplot(data=df, x="Gene", y=column, color=color, edgecolor="black")
        # Rotate after plotting so the categorical gene labels created by seaborn are the ones rotated:
        ax.tick_params(axis="x", labelrotation=90)
        # Mean line:
        line_style = "--"
        line_thickness = 2
        ax.axhline(mean_value, color="black", linestyle=line_style, linewidth=line_thickness)
        # Update legend:
        handles, labels = ax.get_legend_handles_labels()
        handles.append(plt.Line2D([0], [0], color="black", linewidth=line_thickness, linestyle=line_style))
        labels.append(f"Mean: {mean_value}")
        ax.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 0.5))
        plt.title(f"{title} {file_name}")
        plt.tight_layout()
        plt.show()

    if type == "correlations":
        rows = []
        for i, gene in enumerate(all_genes):
            y = observed_expression(gene)
            y_pred = pred_vals[:, i]
            # Remove the largest predicted value (to mitigate sensitivity of these metrics to outliers):
            outlier_index = int(np.argmax(y_pred))
            y_pred = np.delete(y_pred, outlier_index)
            y = np.delete(y, outlier_index)
            # Cells where the target is expressed:
            expressing = y != 0
            rows.append(
                {
                    "Gene": gene,
                    "Pearson coefficient": safe_corr(pearsonr, y, y_pred),
                    "Spearman coefficient": safe_corr(spearmanr, y, y_pred),
                    "Pearson coefficient (expressing cells)": safe_corr(pearsonr, y[expressing], y_pred[expressing]),
                    "Spearman coefficient (expressing cells)": safe_corr(
                        spearmanr, y[expressing], y_pred[expressing]
                    ),
                }
            )
        df = pd.DataFrame(rows)

        sns.set(font_scale=2)
        sns.set_style("white")
        for column, color, title in (
            ("Pearson coefficient", "#FF7F00", "Pearson correlation"),
            ("Spearman coefficient", "#87CEEB", "Spearman correlation"),
            ("Pearson coefficient (expressing cells)", "#0BDA51", "Pearson correlation (expressing cells)"),
            ("Spearman coefficient (expressing cells)", "#FF6961", "Spearman correlation (expressing cells)"),
        ):
            barplot_with_mean(df, column, color, title)

    elif type == "confusion":
        confusion_matrices = {}
        for i, gene in enumerate(all_genes):
            y_binary = (observed_expression(gene) > 0).astype(int)
            predictions_binary = (pred_vals[:, i] > 0).astype(int)
            # Fix the label set so every matrix is 2x2 even if a gene is never (or always) expressed:
            confusion_matrices[gene] = confusion_matrix(y_binary, predictions_binary, labels=[0, 1])

        total_figs = int(math.ceil(len(all_genes) / n_genes_per_plot))
        for fig_index in range(total_figs):
            start_index = fig_index * n_genes_per_plot
            genes_to_plot = all_genes[start_index : start_index + n_genes_per_plot]
            # squeeze=False keeps a 2D array of Axes even when a figure holds a single gene:
            fig, axs = plt.subplots(1, len(genes_to_plot), figsize=(width, width / 5), squeeze=False)
            for ax, gene in zip(axs.ravel(), genes_to_plot):
                sns.heatmap(
                    confusion_matrices[gene],
                    annot=True,
                    fmt="d",
                    cmap="Blues",
                    ax=ax,
                    cbar=False,
                    xticklabels=["Predicted \nnot expressed", "Predicted \nexpressed"],
                    yticklabels=["Actual \nnot expressed", "Actual \nexpressed"],
                )
                ax.set_title(gene)
            plt.tight_layout()
            # Save confusion matrices, then release the figure:
            fig.savefig(os.path.join(parent_dir, f"confusion_matrices_{fig_index}.png"), bbox_inches="tight")
            plt.close(fig)

    else:  # type == "rmse"
        rows = []
        for i, gene in enumerate(all_genes):
            y = observed_expression(gene)
            y_pred = pred_vals[:, i]
            # Cells where the target is expressed:
            expressing = y != 0
            nz_rmse = (
                np.sqrt(mean_squared_error(y[expressing], y_pred[expressing])) if expressing.any() else np.nan
            )
            rows.append(
                {
                    "Gene": gene,
                    "RMSE": np.sqrt(mean_squared_error(y, y_pred)),
                    "RMSE (expressing cells)": nz_rmse,
                }
            )
        df = pd.DataFrame(rows)

        sns.set(font_scale=2)
        sns.set_style("white")
        for column, color in (("RMSE", "#FF7F00"), ("RMSE (expressing cells)", "#87CEEB")):
            barplot_with_mean(df, column, color, column)
def plot_interaction_effect_3D(
    self,
    target: str,
    interaction: str,
    save_path: str,
    pcutoff: Optional[float] = 99.7,
    min_value: Optional[float] = 0,
    zero_opacity: float = 1.0,
    size: float = 2.0,
    n_neighbors_smooth: Optional[int] = 0,
):
    """Quick-visualize the magnitude of the predicted effect on target for a given interaction.

    Args:
        target: Target gene to visualize
        interaction: Interaction to visualize (e.g. "Igf1:Igf1r" for L:R model, "Igf1" for ligand model)
        save_path: Path to save the figure to (will save as HTML file)
        pcutoff: Percentile cutoff for the colorbar. Will set all values above this percentile to this value.
            If 0 or None, the 99.9th percentile is used instead.
        min_value: Minimum value to set the colorbar to. Will set all values below this value to this value.
            Defaults to 0. If None, no lower bound is applied.
        zero_opacity: Opacity of points with zero expression. Between 0.0 and 1.0. Default is 1.0.
        size: Size of the points in the scatter plot. Default is 2.
        n_neighbors_smooth: Number of neighbors to use for smoothing (to make effect patterns more apparent).
            Each cell is replaced by the mean of its nonzero neighbors (excluding itself) when at least
            min(5, n_neighbors_smooth) neighbors are nonzero, otherwise by 0. If 0 or None, no smoothing is
            applied. Default is 0.

    Raises:
        ValueError: If `target` is not among this model's targets or `interaction` is not among its predictors.
    """
    targets = pd.read_csv(
        os.path.join(os.path.splitext(self.output_path)[0], "design_matrix", "targets.csv"), index_col=0
    )
    if target not in targets.columns:
        raise ValueError(f"Target {target} not found in this model's directory. Please provide a valid target.")
    if interaction not in self.X_df.columns:
        raise ValueError(f"Interaction {interaction} not found in this model's directory.")

    if hasattr(self, "remaining_cells"):
        adata = self.adata[self.remaining_cells, :].copy()
    else:
        adata = self.adata.copy()

    # 2D coordinates are placed on the z = 0 plane:
    coords = np.asarray(adata.obsm[self.coords_key])
    x, y = coords[:, 0], coords[:, 1]
    z = coords[:, 2] if coords.shape[1] > 2 else np.zeros(len(x))

    effect = self.coeffs[target].loc[adata.obs_names, f"b_{interaction}"].to_numpy(dtype=float)
    self.logger.info(f"{(effect > 0).sum()} {target}-expressing cells affected by {interaction}.")

    n_neighbors_smooth = n_neighbors_smooth or 0
    if n_neighbors_smooth > 0:
        from scipy.spatial import cKDTree

        n_cells = coords.shape[0]
        # Never ask for more neighbors than there are cells (cKDTree would pad with out-of-range indices):
        k = min(n_neighbors_smooth + 1, n_cells)
        _, indices = cKDTree(coords).query(coords, k=k)
        indices = np.asarray(indices).reshape(n_cells, -1)
        # Drop column 0 (each cell's own index) and average only the nonzero neighbor coefficients:
        neighbor_vals = effect[indices[:, 1:]]
        is_nonzero_neighbor = neighbor_vals != 0
        n_nonzero = is_nonzero_neighbor.sum(axis=1)
        neighbor_sums = np.where(is_nonzero_neighbor, neighbor_vals, 0.0).sum(axis=1)
        # Require 5 nonzero neighbors, capped at the neighborhood size so small neighborhoods still work:
        min_nonzero = min(5, n_neighbors_smooth)
        effect = np.where(n_nonzero >= min_nonzero, neighbor_sums / np.maximum(n_nonzero, 1), 0.0)
        self.logger.info(f"{(effect > 0).sum()} {target} affected cells post-smoothing.")

    # Lenient w/ the max value cutoff so that the colored dots are more distinct from black background
    cutoff = np.percentile(effect, pcutoff if pcutoff else 99.9)
    plot_vals = np.where(effect > cutoff, cutoff, effect)
    if min_value is not None:
        plot_vals = np.where(plot_vals < min_value, min_value, plot_vals)

    # Separate data into zero and non-zero, also keeping the first zero-valued point in the colored trace
    # (presumably so the color scale includes 0):
    is_zero = plot_vals == 0.0
    is_nonzero = ~is_zero
    if is_zero.any():
        is_nonzero[np.flatnonzero(is_zero)[0]] = True

    # Two plots, one for the zeros and one for the nonzeros
    scatter_effect = go.Scatter3d(
        x=x[is_nonzero],
        y=y[is_nonzero],
        z=z[is_nonzero],
        mode="markers",
        marker=dict(
            color=plot_vals[is_nonzero],
            colorscale="Hot",
            size=size,
            colorbar=dict(
                title=dict(text=f"{interaction.title()} Effect on {target.title()}", font=dict(size=24)),
                x=0.75,
                tickfont=dict(size=24),
            ),
        ),
        showlegend=False,
    )
    fig = go.Figure(data=[scatter_effect])

    # Plot zeros separately (if there are any), drawn in black:
    if is_zero.any():
        scatter_zeros = go.Scatter3d(
            x=x[is_zero],
            y=y[is_zero],
            z=z[is_zero],
            mode="markers",
            marker=dict(
                color="#000000",
                size=size,
                opacity=zero_opacity,
            ),
            showlegend=False,
        )
        fig.add_trace(scatter_zeros)

    title_dict = dict(
        text=f"{interaction.title()} Effect on {target.title()}",
        y=0.9,
        yanchor="top",
        x=0.5,
        xanchor="center",
        font=dict(size=36),
    )
    # Hide grid, axis lines, ticks and labels on all three axes:
    hidden_axis = dict(
        showgrid=False,
        showline=False,
        linewidth=2,
        linecolor="black",
        backgroundcolor="white",
        title="",
        showticklabels=False,
        ticks="",
    )
    fig.update_layout(
        scene=dict(aspectmode="data", xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis),
        margin=dict(l=0, r=0, b=0, t=50),  # Adjust margins to minimize spacing
        title=title_dict,
    )
    fig.write_html(save_path)
def plot_multiple_interaction_effects_3D(
    self, effects: List[str], save_path: str, include_combos_of_two: bool = False
):
    """Quick-visualize, in 3D, which of several interaction effects is active in each cell.

    Every plotted cell is assigned to exactly one category:

    - an effect is "active" in a cell when its coefficient is positive, and "strong" when its coefficient is at
      least the mean of that effect's positive coefficients;
    - cells with two or more strong effects are labeled "Multiple interactions" (when ``include_combos_of_two`` is
      True, cells with exactly two strong effects instead get "<effect 1> and <effect 2>", and only cells with
      three or more strong effects are labeled "Multiple interactions");
    - otherwise, cells with exactly one active effect are labeled with that effect;
    - all remaining cells are labeled "Other".

    Args:
        effects: List of effects to visualize, each written as "<interaction>:<target>". The target is the text
            after the final colon, so for an L:R model the interaction keeps its own colon (e.g.
            "Igf1:Igf1r:Foxo1" for the effect of Igf1:Igf1r on Foxo1); for a ligand model use e.g. "Igf1:Foxo1".
            Effects whose target or interaction is not part of this model are logged and skipped.
        save_path: Path to save the figure to (will save as HTML file)
        include_combos_of_two: Whether to include paired combinations of effects (e.g. "Igf1:Igf1r:Foxo1 and
            Igf1:InsR:Foxo1") as separate categories. If False, will include these in the generic "Multiple
            interactions" category.

    Raises:
        ValueError: If none of the given effects could be found in this model.
    """
    # Read-only access, so a view is sufficient (no copy of the AnnData object needed):
    adata = self.adata[self.remaining_cells, :] if hasattr(self, "remaining_cells") else self.adata
    coords = adata.obsm[self.coords_key]
    x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]

    # Per-effect boolean flags over the plotted cells; dict.fromkeys drops duplicates but keeps the given order.
    active_flags = {}
    strong_flags = {}
    for effect in dict.fromkeys(effects):
        if ":" not in effect:
            self.logger.info(f"{effect} is not of the form '<interaction>:<target>'. Skipping this effect.")
            continue
        # Split on the final colon so that L:R interactions (which contain a colon themselves) stay intact:
        interaction, target = effect.rsplit(":", 1)
        if target not in self.coeffs.keys():
            self.logger.info(
                f"{target} not found in this model's directory. Skipping this interaction-target pair."
            )
            continue
        if f"b_{interaction}" not in self.coeffs[target].columns:
            self.logger.info(f"{interaction} not found for {target}. Skipping this interaction-target pair.")
            continue

        coef = self.coeffs[target].loc[adata.obs_names, f"b_{interaction}"].to_numpy(dtype=float)
        positive = coef[coef > 0]
        active_flags[effect] = coef > 0
        # "Strong" means at least the mean positive effect; if the effect is never positive, no cell is strong.
        if positive.size > 0:
            strong_flags[effect] = coef >= positive.mean()
        else:
            strong_flags[effect] = np.zeros(len(coef), dtype=bool)

    if not active_flags:
        raise ValueError("None of the provided effects were found in this model; nothing to plot.")

    valid_effects = list(active_flags)
    active = np.column_stack([active_flags[effect] for effect in valid_effects])
    strong = np.column_stack([strong_flags[effect] for effect in valid_effects])
    n_active = active.sum(axis=1)
    n_strong = strong.sum(axis=1)

    # Assign categories from lowest to highest precedence so that later assignments win:
    categories = np.full(len(coords), "Other", dtype=object)
    single = n_active == 1
    # argmax on a boolean row returns the position of its (only) True entry:
    categories[single] = np.asarray(valid_effects, dtype=object)[active[single].argmax(axis=1)]
    if include_combos_of_two:
        for (i, first), (j, second) in itertools.combinations(enumerate(valid_effects), 2):
            pair = (n_strong == 2) & strong[:, i] & strong[:, j]
            categories[pair] = f"{first} and {second}"
        categories[n_strong >= 3] = "Multiple interactions"
    else:
        categories[n_strong >= 2] = "Multiple interactions"

    cat_counts = pd.Series(categories).value_counts()
    # Map each category to a color; the palette is cycled so no category is silently dropped from the plot, and
    # the two generic categories get fixed greys without consuming palette colors.
    fixed_colors = {"Multiple interactions": "#71797E", "Other": "#D3D3D3"}
    palette = itertools.cycle(godsnot_102 if include_combos_of_two else vega_10)
    color_mapping = {}
    for category in cat_counts.index:
        color_mapping[category] = fixed_colors[category] if category in fixed_colors else next(palette)

    traces = []
    for group, color in color_mapping.items():
        marker_size = 1.25 if group == "Other" else 2
        mask = categories == group
        traces.append(
            go.Scatter3d(
                x=x[mask],
                y=y[mask],
                z=z[mask],
                mode="markers",
                marker=dict(size=marker_size, color=color),
                showlegend=False,
            )
        )
        # Invisible trace for the legend (so the colored point is larger than the plot points):
        traces.append(
            go.Scatter3d(
                x=[None],
                y=[None],
                z=[None],
                mode="markers",
                marker=dict(size=30, color=color),  # Adjust size as needed
                name=group,
                showlegend=True,
            )
        )

    fig = go.Figure(data=traces)
    title = (
        "L:R Interaction Effect on Target (format Ligand:Receptor:Target)"
        if self.mod_type == "lr"
        else "Ligand Effect on Target (format Ligand:Target)"
    )
    title_dict = dict(
        text=title,
        y=0.9,
        yanchor="top",
        x=0.5,
        xanchor="center",
        font=dict(size=28),
    )
    # Identical, invisible styling for all three scene axes:
    axis_style = dict(
        showgrid=False,
        showline=False,
        linewidth=2,
        linecolor="black",
        backgroundcolor="white",
        title="",
        showticklabels=False,
        ticks="",
    )
    fig.update_layout(
        showlegend=True,
        legend=dict(x=0.7, y=0.85, orientation="v", font=dict(size=14)),
        scene=dict(
            aspectmode="data",
            xaxis=dict(axis_style),
            yaxis=dict(axis_style),
            zaxis=dict(axis_style),
        ),
        margin=dict(l=0, r=0, b=0, t=50),  # Adjust margins to minimize spacing
        title=title_dict,
    )
    fig.write_html(save_path)
def plot_tf_effect_3D(
    self,
    target: str,
    tf: str,
    save_path: str,
    ligand_targets: bool = True,
    receptor_targets: bool = False,
    target_gene_targets: bool = False,
    pcutoff: float = 99.7,
    min_value: float = 0,
    zero_opacity: float = 1.0,
    size: float = 2.0,
):
    """Quick-visualize the magnitude of the predicted effect on target for a given TF.

    Can only find the files necessary for this if :func `CCI_deg_detection()` has been run.

    Args:
        target: Target gene of interest
        tf: TF of interest (e.g. "Foxo1")
        save_path: Path to save the figure to (will save as HTML file)
        ligand_targets: Set True if ligands were used as the target genes for the :func `CCI_deg_detection()`
            model. This flag is checked first, so it must be set to False to use either of the other two.
        receptor_targets: Set True if receptors were used as the target genes for the :func
            `CCI_deg_detection()` model.
        target_gene_targets: Set True if target genes were used as the target genes for the :func
            `CCI_deg_detection()` model.
        pcutoff: Percentile cutoff for the colorbar. Will set all values above this percentile to this value.
        min_value: Minimum value to set the colorbar to. Will set all values below this value to this value.
        zero_opacity: Opacity of points with zero expression. Between 0.0 and 1.0. Default is 1.0.
        size: Size of the points in the scatter plot. Default is 2.

    Raises:
        ValueError: If none of the three target flags is set, or if the target or TF is not part of the
            downstream model.
    """
    downstream_parent_dir = os.path.dirname(os.path.splitext(self.output_path)[0])
    model_id = os.path.splitext(os.path.basename(self.output_path))[0]

    if sum(map(bool, (ligand_targets, receptor_targets, target_gene_targets))) > 1:
        self.logger.info(
            "More than one of 'ligand_targets', 'receptor_targets' and 'target_gene_targets' is True; using the "
            "first one that is set (in that order)."
        )
    if ligand_targets:
        target_type, folder = "ligand", "ligand_analysis"
    elif receptor_targets:
        target_type, folder = "receptor", "receptor_analysis"
    elif target_gene_targets:
        target_type, folder = "target_gene", "target_gene_analysis"
    else:
        raise ValueError(
            "Please set either 'ligand_targets', 'receptor_targets', or 'target_gene_targets' to True."
        )

    design_dir = os.path.join(
        downstream_parent_dir, "cci_deg_detection", folder, model_id, "downstream_design_matrix"
    )
    # Only the column names are needed for validation, so skip reading the (potentially large) data rows:
    targets = pd.read_csv(os.path.join(design_dir, "targets.csv"), index_col=0, nrows=0)
    regulators = pd.read_csv(os.path.join(design_dir, "design_matrix.csv"), index_col=0, nrows=0)
    regulator_names = [col.replace("regulator_", "") for col in regulators.columns]

    if target not in targets.columns:
        raise ValueError(f"Target {target} not found in this model's directory. Please provide a valid target.")
    if tf not in regulator_names:
        raise ValueError(f"TF {tf} not found in this model's directory.")

    # Read-only access, so a view is sufficient (no copy of the AnnData object needed):
    adata = self.adata[self.remaining_cells, :] if hasattr(self, "remaining_cells") else self.adata
    coords = adata.obsm[self.coords_key]
    x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]

    downstream_coeffs, _ = self.return_outputs(
        adjust_for_subsampling=False, load_from_downstream=target_type, load_for_interpreter=True
    )
    target_tf_coef = downstream_coeffs[target].loc[adata.obs_names, f"b_{tf}"]
    # Lenient w/ the max value cutoff so that the colored dots are more distinct from black background. The cap is
    # applied before the floor; clip() returns a new Series instead of writing into a slice of the coefficients.
    cutoff = np.percentile(target_tf_coef.values, pcutoff)
    plot_vals = target_tf_coef.clip(upper=cutoff).clip(lower=min_value).to_numpy()

    # Separate zero and non-zero points, keeping the first zero in the colored trace as well (presumably so that
    # the colorscale extends down to zero):
    is_zero = plot_vals == 0
    is_nonzero = ~is_zero
    if is_zero.any():
        is_nonzero[np.flatnonzero(is_zero)[0]] = True

    # Two plots, one for the zeros and one for the nonzeros
    scatter_effect = go.Scatter3d(
        x=x[is_nonzero],
        y=y[is_nonzero],
        z=z[is_nonzero],
        mode="markers",
        marker=dict(
            color=plot_vals[is_nonzero],
            colorscale="Hot",
            size=size,
            colorbar=dict(
                # title.font replaces the deprecated (and in recent plotly removed) "titlefont" property:
                title=dict(text=f"{tf.title()} Effect on {target.title()}", font=dict(size=24)),
                x=0.75,
                tickfont=dict(size=24),
            ),
        ),
        showlegend=False,
    )

    fig = go.Figure(data=[scatter_effect])
    # Plot zeros separately (if there are any):
    if is_zero.any():
        fig.add_trace(
            go.Scatter3d(
                x=x[is_zero],
                y=y[is_zero],
                z=z[is_zero],
                mode="markers",
                marker=dict(
                    color="#000000",
                    size=size,
                    opacity=zero_opacity,
                ),
                showlegend=False,
            )
        )

    title_dict = dict(
        text=f"{tf.title()} Effect on {target.title()}",
        y=0.9,
        yanchor="top",
        x=0.5,
        xanchor="center",
        font=dict(size=36),
    )
    # Identical, invisible styling for all three scene axes:
    axis_style = dict(
        showgrid=False,
        showline=False,
        linewidth=2,
        linecolor="black",
        backgroundcolor="white",
        title="",
        showticklabels=False,
        ticks="",
    )
    fig.update_layout(
        scene=dict(
            aspectmode="data",
            xaxis=dict(axis_style),
            yaxis=dict(axis_style),
            zaxis=dict(axis_style),
        ),
        margin=dict(l=0, r=0, b=0, t=50),  # Adjust margins to minimize spacing
        title=title_dict,
    )
    fig.write_html(save_path)
def visualize_overlap_between_interacting_components_3D(
    self, target: str, interaction: str, save_path: str, size: float = 2.0
):
    """Visualize the spatial distribution of signaling features (ligand, receptor, or L:R field) and target gene,
    as well as the overlapping region. Intended for use with 3D spatial coordinates.

    Each plotted cell is labeled as expressing only the target, having a nonzero interaction feature but no target
    expression, having both, or "Other" (the interaction labels are only assigned for "lr" and "ligand" models).

    Args:
        target: Target gene to visualize
        interaction: Interaction to visualize (e.g. "Igf1:Igf1r" for L:R model, "Igf1" for ligand model)
        save_path: Path to save the figure to (will save as HTML file)
        size: Size of the points in the plot. Defaults to 2.

    Raises:
        ValueError: If the target or the interaction is not part of this model.
    """
    # Customize a copy of the shared palette; editing the imported list in place would change the colors used by
    # every other plot that relies on it.
    palette = list(godsnot_102)
    palette[1] = "#B200ED"
    palette[2] = "#FFA500"
    palette[3] = "#1CE6FF"

    # Only the column names are needed for validation, so skip reading the data rows:
    targets = pd.read_csv(
        os.path.join(os.path.splitext(self.output_path)[0], "design_matrix", "targets.csv"), index_col=0, nrows=0
    )
    if target not in targets.columns:
        raise ValueError(f"Target {target} not found in this model's directory. Please provide a valid target.")
    if interaction not in self.X_df.columns:
        raise ValueError(f"Interaction {interaction} not found in this model's directory.")

    # Read-only access, so a view is sufficient (labels are kept in a separate Series below):
    adata = self.adata[self.remaining_cells, :] if hasattr(self, "remaining_cells") else self.adata
    coords = adata.obsm[self.coords_key]
    x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]

    # Label cells expressing target (X may be stored sparse or dense):
    target_expr = adata[:, target].X
    target_expr = target_expr.toarray() if scipy.sparse.issparse(target_expr) else np.asarray(target_expr)
    target_expressing = adata.obs_names[target_expr.reshape(-1) != 0]
    # Label cells with nonzero interaction feature value (for ligand model, cells that have expression of the
    # ligand in their neighborhood (in addition to other caveats incorporated in model setup), for L:R model,
    # cells that have expression of the ligand in their neighborhood and expression of the receptor). Restrict
    # to the plotted cells so that labeling below cannot hit cells that are not being plotted:
    interaction_expressing = self.X_df[self.X_df[interaction] != 0].index.intersection(adata.obs_names)
    # Label cells expressing target and with nonzero interaction feature value:
    overlap = target_expressing.intersection(interaction_expressing)

    labels = pd.Series("Other", index=adata.obs_names, dtype=object)
    labels.loc[target_expressing] = f"{target} only (no {interaction} in neighborhood and/or receptor)"
    if self.mod_type == "lr":
        ligand, receptor = interaction.split(":")
        labels.loc[interaction_expressing] = f"{ligand.title()} in Neighborhood and {receptor}, no {target}"
        labels.loc[overlap] = f"{ligand.title()} in Neighborhood, {receptor} and {target}"
    elif self.mod_type == "ligand":
        labels.loc[interaction_expressing] = f"{interaction.title()} in Neighborhood and Receptor, no {target}"
        labels.loc[overlap] = f"{interaction.title()} in Neighborhood, Receptor and {target}"

    color_mapping = dict(zip(labels.value_counts().index, palette))
    color_mapping["Other"] = "#D3D3D3"

    traces = []
    for group, color in color_mapping.items():
        marker_size = size * 0.75 if group == "Other" else size
        opacity = 0.5 if group == "Other" else 1.0
        mask = (labels == group).to_numpy()
        traces.append(
            go.Scatter3d(
                x=x[mask],
                y=y[mask],
                z=z[mask],
                mode="markers",
                marker=dict(size=marker_size, color=color, opacity=opacity),
                showlegend=False,
            )
        )
        # Invisible trace for the legend (so the colored point is larger than the plot points):
        traces.append(
            go.Scatter3d(
                x=[None],
                y=[None],
                z=[None],
                mode="markers",
                marker=dict(size=30, color=color),  # Adjust size as needed
                name=group,
                showlegend=True,
            )
        )

    fig = go.Figure(data=traces)
    if self.mod_type == "ligand":
        title = (
            f"Distribution of interacting components: <br>{interaction}, {interaction} receptor/downstream "
            f"components and {target}"
        )
    else:
        # "lr" models and any other model type (previously left `title` unbound):
        title = f"Distribution of interacting components: <br>{interaction} and {target}"
    title_dict = dict(
        text=title,
        y=0.9,
        yanchor="top",
        x=0.5,
        xanchor="center",
        font=dict(size=36),
    )
    # Identical, invisible styling for all three scene axes:
    axis_style = dict(
        showgrid=False,
        showline=False,
        linewidth=2,
        linecolor="black",
        backgroundcolor="white",
        title="",
        showticklabels=False,
        ticks="",
    )
    fig.update_layout(
        showlegend=True,
        legend=dict(x=0.65, y=0.85, orientation="v", font=dict(size=18)),
        scene=dict(
            aspectmode="data",
            xaxis=dict(axis_style),
            yaxis=dict(axis_style),
            zaxis=dict(axis_style),
        ),
        margin=dict(l=0, r=0, b=0, t=50),  # Adjust margins to minimize spacing
        title=title_dict,
    )
    fig.write_html(save_path)
def gene_expression_heatmap(
    self,
    use_ligands: bool = False,
    use_receptors: bool = False,
    use_target_genes: bool = False,
    genes: Optional[List[str]] = None,
    position_key: str = "spatial",
    coord_column: Optional[Union[int, str]] = None,
    reprocess: bool = False,
    neatly_arrange_y: bool = True,
    window_size: int = 3,
    recompute: bool = False,
    title: Optional[str] = None,
    fontsize: Union[None, int] = None,
    figsize: Union[None, Tuple[float, float]] = None,
    cmap: str = "magma",
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = None,
):
    """Visualize the distribution of gene expression across cells in the spatial coordinates of cells; provides an
    idea of the simultaneous relative positions/patternings of different genes.

    For each gene, each cell's expression is divided by the gene's mean expression, log1p-transformed and z-scored
    across cells. The z-scores are averaged over all cells sharing a position value and smoothed with a centered
    rolling mean. Genes that rank among the 30 highest mean z-scores for at least five consecutive positions are
    kept for plotting. The processed table is cached as a CSV file in the "analyses" folder next to the model
    output and re-used on later calls unless ``recompute`` (or ``reprocess``) is set.

    Args:
        use_ligands: Set True to use ligands as the genes to visualize. If True, will ignore "genes" argument.
            "ligands_expr" file must be present in the model's directory.
        use_receptors: Set True to use receptors as the genes to visualize. If True, will ignore "genes" argument.
            "receptors_expr" file must be present in the model's directory.
        use_target_genes: Set True to use target genes as the genes to visualize. If True, will ignore "genes"
            argument. "targets" file must be present in the model's directory.
        genes: Optional list of genes to visualize. If "use_ligands", "use_receptors", and "use_target_genes" are
            all False, this must be given. This can also be used to visualize only a subset of the genes once
            processing & saving has already completed using e.g. "use_ligands", "use_receptors", etc.
        position_key: Key in adata.obs or adata.obsm that provides a relative indication of the position of cells.
            i.e. spatial coordinates. Defaults to "spatial". For each value in the position array (each
            coordinate, each category), multiple cells must have the same value.
        coord_column: Optional, only used if "position_key" points to an entry in .obsm. In this case, this is the
            index or name of the column to be used to provide the positional context. Can also provide "xy", "yz",
            "xz", "-xy", "-yz", "-xz" to draw a line between the two coordinate axes. "xy" will extend the new
            axis in the direction of increasing x and increasing y starting from x=0 and y=0 (or min. x/min. y),
            "-xy" will extend the new axis in the direction of decreasing x and increasing y starting from
            x=minimum x and y=maximum y, and so on.
        reprocess: Same as "recompute": set to True to reprocess the data and overwrite the existing files. Use if
            the genes to visualize have changed compared to the saved file (if existing), e.g. if "use_ligands" is
            True when the initial analysis used "use_target_genes".
        neatly_arrange_y: Set True to order the y-axis in terms of how early along the position axis the max
            z-scores for each row occur in. Used for a more uniform plot where similarly patterned genes are
            grouped together. If False, rows follow the order of the gene list.
        window_size: Size of window to use for smoothing. Must be a positive odd integer. If 1, no smoothing is
            applied.
        recompute: Set to True to recompute the data and overwrite the existing files
        title: Optional, can be used to provide title for plot
        fontsize: Size of font for x and y labels.
        figsize: Size of figure.
        cmap: Colormap to use. The color range is always centered at 0.
        save_show_or_return: Whether to save, show or return the figure.
            If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
            displayed and the associated axis and other object will be return.
        save_kwargs: A dictionary that will passed to the save_fig function. The "ext", "dpi" and "path" keys are
            set by this function when saving. The caller's dictionary is not modified.

    Returns:
        The result of ``save_return_show_fig_utils`` for the chosen ``save_show_or_return`` mode.

    Raises:
        ValueError: For an invalid window size, missing gene selection, unknown position key/column, or when no
            genes are left to plot.
        FileNotFoundError: If the requested expression file is missing from the model's directory.
    """
    logger = lm.get_main_logger()

    if window_size < 1 or window_size % 2 == 0:
        raise ValueError("Window size must be a positive odd integer.")
    if not (use_ligands or use_receptors or use_target_genes) and genes is None:
        raise ValueError(
            "Please set either 'use_ligands', 'use_receptors', or 'use_target_genes' to True, or provide a list "
            "of genes to visualize."
        )
    # "reprocess" and "recompute" are documented with the same meaning; honor either one.
    recompute = recompute or reprocess
    # Never mutate the caller's dictionary (or a shared default):
    save_kwargs = {} if save_kwargs is None else dict(save_kwargs)

    # Genes given explicitly; when a predefined gene set is used, these select a subset of previously saved results.
    custom_genes = genes
    design_dir = os.path.join(os.path.splitext(self.output_path)[0], "design_matrix")
    if use_ligands:
        ligands_path = os.path.join(design_dir, "ligands_expr.csv")
        if not os.path.exists(ligands_path):
            raise FileNotFoundError("ligands_expr.csv not found in this model's directory.")
        expr_df = pd.read_csv(ligands_path, index_col=0)
        # Hard-coded ligands that are always left out of this plot (the rationale is not recorded here- TODO confirm).
        excluded_ligands = {
            "Lta4h", "Fdx1", "Tfrc", "Trf", "Lamc1", "Aldh1a2", "Dhcr24", "Rnaset2a", "Ptges3", "Nampt", "Kdr",
            "Apoa2", "Apoe", "Dhcr7", "Enho", "Ptgr1", "Agrp", "Akr1b3", "Daglb", "Ubash3d",
        }
        genes = [g for g in expr_df.columns if g not in excluded_ligands]
        file_id = "ligand_expression"
    elif use_receptors:
        receptors_path = os.path.join(design_dir, "receptors_expr.csv")
        if not os.path.exists(receptors_path):
            raise FileNotFoundError("receptors_expr.csv not found in this model's directory.")
        expr_df = pd.read_csv(receptors_path, index_col=0)
        genes = list(expr_df.columns)
        file_id = "receptor_expression"
    elif use_target_genes:
        targets_path = os.path.join(design_dir, "targets.csv")
        if not os.path.exists(targets_path):
            raise FileNotFoundError("targets.csv not found in this model's directory.")
        expr_df = pd.read_csv(targets_path, index_col=0)
        genes = list(expr_df.columns)
        file_id = "target_gene_expression"
    else:
        expr = self.adata[:, genes].X
        expr = expr.toarray() if scipy.sparse.issparse(expr) else np.asarray(expr)
        expr_df = pd.DataFrame(expr, index=self.adata.obs_names, columns=genes)
        file_id = "expression"

    # Resolve the positional axis: `pos` (one value per cell), the x-axis label, a default title and file id.
    if position_key not in self.adata.obsm.keys() and position_key not in self.adata.obs.keys():
        raise ValueError(
            f"Position key {position_key} not found in adata.obsm or adata.obs. Please provide a valid key."
        )
    if position_key in self.adata.obsm.keys():
        if coord_column in ["xy", "yz", "xz", "-xy", "-yz", "-xz"]:
            self.adata = create_new_coordinate(self.adata, position_key, coord_column)
            pos = self.adata.obs[f"{coord_column} Coordinate"]
            x_label = f"Relative position along custom {coord_column} axis"
            default_title = f"Gene expression distribution along {coord_column} axis"
            save_id = f"{coord_column}_axis"
        else:
            position_array = self.adata.obsm[position_key]
            if isinstance(coord_column, str):
                if not isinstance(position_array, pd.DataFrame):
                    raise ValueError(
                        f"Array stored at position key {position_key} has no column names; provide the column "
                        f"index."
                    )
                pos = position_array[coord_column]
                x_label = f"Relative position along {coord_column}"
                default_title = f"Gene expression distribution along {coord_column}"
                save_id = coord_column
            elif isinstance(coord_column, (int, np.integer)):
                if isinstance(position_array, pd.DataFrame):
                    pos = position_array.iloc[:, coord_column]
                    x_label = f"Relative position along {coord_column}"
                    default_title = f"Gene expression distribution along {coord_column}"
                    save_id = coord_column
                else:
                    pos = pd.Series(position_array[:, coord_column], index=self.adata.obs_names)
                    axis_name = {0: "X", 1: "Y", 2: "Z"}.get(int(coord_column))
                    if axis_name is not None:
                        x_label = f"Relative position along {axis_name}"
                        default_title = f"Gene expression distribution along {axis_name}"
                        save_id = f"{axis_name.lower()}_axis"
                    else:
                        x_label = f"Relative position along column {coord_column}"
                        default_title = f"Gene expression distribution along column {coord_column}"
                        save_id = f"column_{coord_column}"
            elif position_array.shape[1] != 1:
                raise ValueError(
                    f"Array stored at position key {position_key} has more than one column; provide the column "
                    f"index."
                )
            else:
                if isinstance(position_array, pd.DataFrame):
                    pos = position_array.iloc[:, 0]
                else:
                    pos = pd.Series(np.asarray(position_array).flatten(), index=self.adata.obs_names)
                x_label = "Relative position"
                default_title = f"Gene expression distribution along axis given by {position_key} key"
                save_id = position_key
    else:
        pos = self.adata.obs[position_key]
        x_label = "Relative position"
        default_title = f"Gene expression distribution along axis given by {position_key} key"
        save_id = position_key
    if title is None:
        title = default_title

    # If position array is numerical, there may not be an exact match- convert the data type to integer:
    if pos.dtype == float:
        pos = pos.astype(int)

    figure_folder = None
    if save_show_or_return in ["save", "both", "all"]:
        figure_folder = os.path.join(os.path.dirname(self.output_path), "figures", "temp")
        os.makedirs(figure_folder, exist_ok=True)
    output_folder = os.path.join(os.path.dirname(self.output_path), "analyses")
    os.makedirs(output_folder, exist_ok=True)

    # Use the saved name for the AnnData object to define part of the name of the saved file:
    adata_id = os.path.splitext(os.path.basename(self.adata_path))[0]
    csv_path = os.path.join(output_folder, f"{adata_id}_distribution_{file_id}_along_{save_id}.csv")

    if os.path.exists(csv_path) and not recompute:
        to_plot = pd.read_csv(csv_path, index_col=0)
        # Can plot a subset once this is already processed & saved:
        if custom_genes is not None:
            custom_genes = [g for g in custom_genes if g in to_plot.index]
            to_plot = to_plot.loc[custom_genes]
    else:
        n_top_genes = 30
        min_consecutive = 5
        # Log fold change of each cell's expression over the gene's mean expression. Only cells present in both
        # the expression table and the AnnData object are used, so that missing cells cannot turn whole columns
        # into NaN in the z-scoring step.
        mean_expr = expr_df[genes].mean()
        cells = expr_df.index.intersection(self.adata.obs_names)
        all_fc = np.log1p(expr_df.loc[cells, genes] / mean_expr)
        # z-score the fold change values:
        all_fc = all_fc.apply(scipy.stats.zscore, axis=0)
        all_fc["pos"] = pos
        # Mean z-score at each coordinate position, smoothed in the case of dropouts:
        binned = all_fc.sort_values(by="pos").groupby("pos").mean()
        binned = binned.rolling(window_size, center=True, min_periods=1).mean()

        # For each position, the genes with the highest mean z-score:
        top_genes = binned.apply(lambda row: row.nlargest(n_top_genes).index.tolist(), axis=1)
        # Keep genes that are among the top genes for at least `min_consecutive` consecutive positions:
        run_lengths = dict.fromkeys(genes, 0)
        selected = set()
        for _, top_list in top_genes.items():
            top_here = set(top_list)
            for g in genes:
                if g in top_here:
                    run_lengths[g] += 1
                    if run_lengths[g] >= min_consecutive:
                        selected.add(g)
                else:
                    run_lengths[g] = 0
        # Deterministic row order (following the gene list) instead of arbitrary set order:
        genes_of_interest = [g for g in genes if g in selected]
        to_plot = binned[genes_of_interest]

        index = to_plot.index
        if pd.api.types.is_numeric_dtype(index) and not pd.api.types.is_bool_dtype(index):
            # Minmax scale to normalize positional context (skipped when there is only a single position):
            span = index.max() - index.min()
            if span != 0:
                to_plot.index = (index - index.min()) / span
        to_plot = to_plot.T  # so that the features are labeled along the y-axis

    if to_plot.empty:
        raise ValueError(
            "No genes left to plot: either no gene ranked among the top genes for enough consecutive positions, "
            "or none of the requested genes are present in the saved results."
        )

    # Sort by "heat" if applicable (i.e. in order roughly determined by how early along the relative position
    # the highest z-scores occur in for each gene):
    if neatly_arrange_y:
        logger.info("Sorting by position of enrichment along axis...")
        column_indices = np.tile(np.arange(len(to_plot.columns)), (len(to_plot), 1))
        # Look only at the entries above each row's 95th percentile of positive values:
        percentile_95 = to_plot.apply(
            lambda row: np.percentile(row[row > 0], 95) if row[row > 0].size > 0 else 0, axis=1
        )
        weights_matrix = to_plot.gt(percentile_95, axis=0) * to_plot
        weighted_sum = np.sum(weights_matrix.values * column_indices, axis=1)
        total_weight = np.sum(weights_matrix.values, axis=1)
        with np.errstate(divide="ignore", invalid="ignore"):
            weighted_avg = pd.Series(
                np.where(total_weight != 0, weighted_sum / total_weight, 0), index=to_plot.index
            )
        to_plot = to_plot.loc[weighted_avg.sort_values().index]

    # Symmetric color range around 0, bounded by the 95th percentile of all values:
    max_val = pd.Series(to_plot.values.flatten()).quantile(0.95)

    if figsize is None:
        figsize = (8, len(to_plot) * 40 / 200)
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
    if fontsize is None:
        fontsize = rcParams.get("font.size")

    # Format the position labels. Positions read back from the cached CSV arrive as strings, and pandas renames
    # repeated headers (e.g. "0.250" -> "0.250.1"); drop such suffixes and convert back to floats.
    if all(isinstance(name, str) for name in to_plot.columns):
        try:
            numeric_positions = [float(".".join(name.split(".")[:2])) for name in to_plot.columns]
        except ValueError:
            # Non-numeric (e.g. categorical) positions: keep the labels as they are.
            numeric_positions = None
        if numeric_positions is not None:
            col_series = pd.Series(numeric_positions)
            if col_series.duplicated().any():
                # Spread repeated positions evenly between their value and the next distinct value:
                unique_values, counts = np.unique(col_series, return_counts=True)
                for rank, (value, count) in enumerate(zip(unique_values, counts)):
                    if count <= 1:
                        continue
                    if rank + 1 < len(unique_values):
                        next_value = unique_values[rank + 1]
                    elif len(unique_values) > 1:
                        next_value = value + (value - unique_values[rank - 1])
                    else:
                        next_value = value + 1.0
                    step = (next_value - value) / count
                    for offset, col_idx in enumerate(np.flatnonzero(col_series.to_numpy() == value)):
                        col_series.iloc[col_idx] = value + step * offset
            to_plot.columns = col_series.to_numpy()

    if all(isinstance(name, float) for name in to_plot.columns):
        to_plot.columns = [f"{float(col):.3f}" for col in to_plot.columns]
    to_plot.columns = [str(col) for col in to_plot.columns]

    if genes is not None and not neatly_arrange_y:
        # Order rows by the gene list, without inserting empty rows for genes that were not selected:
        to_plot = to_plot.loc[[g for g in genes if g in to_plot.index]]

    m = sns.heatmap(to_plot, vmin=-max_val, vmax=max_val, ax=ax, cmap=cmap)
    cbar = m.collections[0].colorbar
    cbar.set_label("Z-score", fontsize=fontsize * 1.5, labelpad=10)
    # Adjust colorbar tick font size
    cbar.ax.tick_params(labelsize=fontsize * 1.25)
    cbar.ax.set_aspect(np.min([len(to_plot), 70]))

    ax.set_xlabel(x_label, fontsize=fontsize * 1.25)
    ax.set_ylabel("Gene", fontsize=fontsize * 1.25)
    ax.tick_params(axis="x", labelsize=fontsize)
    ax.tick_params(axis="y", labelsize=fontsize)
    ax.set_title(title, fontsize=fontsize * 1.5, pad=20)

    # Cache the processed table (also overwriting it when recomputing, as documented):
    if recompute or not os.path.exists(csv_path):
        to_plot.to_csv(csv_path)

    if save_show_or_return in ["save", "both", "all"]:
        save_kwargs["ext"] = "png"
        save_kwargs["dpi"] = 300
        save_kwargs["path"] = figure_folder
    # Save figure:
    return save_return_show_fig_utils(
        save_show_or_return=save_show_or_return,
        show_legend=False,
        background="white",
        prefix=f"distribution_{file_id}_along_{save_id}",
        save_kwargs=save_kwargs,
        total_panels=1,
        fig=fig,
        axes=ax,
        return_all=False,
        return_all_list=None,
    )
def effect_distribution_heatmap(
    self,
    target_subset: Optional[List[str]] = None,
    interaction_subset: Optional[List[str]] = None,
    position_key: str = "spatial",
    coord_column: Optional[Union[int, str]] = None,
    effect_threshold: Optional[float] = None,
    check_downstream_ligand_effects: bool = False,
    check_downstream_receptor_effects: bool = False,
    check_downstream_target_effects: bool = False,
    use_significant: bool = False,
    sort_by_target: bool = False,
    neatly_arrange_y: bool = True,
    window_size: int = 3,
    recompute: bool = False,
    title: Optional[str] = None,
    fontsize: Union[None, int] = None,
    figsize: Union[None, Tuple[float, float]] = None,
    cmap: str = "magma",
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = None,
):
    """Visualize the distribution of interaction effects across cells in the spatial coordinates of cells; provides
    an idea of the simultaneous relative positions of different interaction effects.

    For every (target, interaction) pair, per-cell coefficients are divided by the pair's mean coefficient,
    log1p-transformed and z-scored across cells. The z-scores are then averaged at each (binned) position along the
    chosen axis and smoothed with a centered rolling window. Pairs that rank among the 30 highest mean z-scores at
    enough positions are plotted as heatmap rows whose labels have the form "target;interaction" (with ":" replaced
    by "-"). The plotted dataframe is saved as a CSV in the "analyses" folder of the parent directory of
    :attr `output_path` and is reused by later calls unless "recompute" is True.

    Args:
        target_subset: List of targets to consider. If None, will use all targets used in model fitting. When a
            saved dataframe is reused, its rows are filtered to these targets.
        interaction_subset: List of interactions to consider. If None, will use all interactions used in model.
        position_key: Key in adata.obs or adata.obsm that provides a relative indication of the position of cells.
            i.e. spatial coordinates. Defaults to "spatial". For each value in the position array (each
            coordinate, each category), multiple cells must have the same value.
        coord_column: Optional, only used if "position_key" points to an entry in .obsm. In this case, this is the
            index or name of the column to be used to provide the positional context. Can also provide "xy", "yz",
            "xz", "-xy", "-yz", "-xz" to draw a line between the two coordinate axes. "xy" will extend the new axis
            in the direction of increasing x and increasing y starting from x=0 and y=0 (or min. x/min. y), "-xy"
            will extend the new axis in the direction of decreasing x and increasing y starting from x=minimum x
            and y=maximum y, and so on.
        effect_threshold: Optional threshold minimum effect size, as an absolute value. Coefficients whose magnitude
            is below the threshold are set to 0, so that only the cells for which an interaction is predicted to
            have a strong effect contribute. If None, no threshold is applied.
        check_downstream_ligand_effects: Set True to check the coefficients of downstream ligand models instead of
            coefficients of the upstream CCI model. Note that this may not necessarily look nice because TF-target
            relationships are not spatially dependent like L:R effects are.
        check_downstream_receptor_effects: Set True to check the coefficients of downstream receptor models instead
            of coefficients of the upstream CCI model. Note that this may not necessarily look nice because
            TF-target relationships are not spatially dependent like L:R effects are.
        check_downstream_target_effects: Set True to check the coefficients of downstream target models instead of
            coefficients of the upstream CCI model. Note that this may not necessarily look nice because TF-target
            relationships are not spatially dependent like L:R effects are.
        use_significant: Whether to use only significant effects. If True, each target's coefficients are
            multiplied by the table "significance/{target}_is_significant.csv" (parent directory of
            :attr `output_path`), i.e. filtered to cells + interactions where the interaction is significant for
            the target. Only valid if :func `compute_coeff_significance()` has been run.
        sort_by_target: Set True to order the y-axis by the identity of the target gene. Takes precedence over
            "neatly_arrange_y". If both this and "neatly_arrange_y" are False, will sort this axis by the identity
            of the interaction (i.e. all "Fgf1" rows will be grouped together).
        neatly_arrange_y: Set True to order the y-axis in terms of how early along the position axis the max
            z-scores for each row occur in. Used for a more uniform plot where similarly patterned
            interaction-target pairs are grouped together.
        window_size: Size of the centered rolling window used for smoothing along the position axis. Must be an odd
            integer. If 1, no smoothing is applied. A pair is only plotted if it ranks among the 30 highest mean
            z-scores at no fewer than int(1.67 * window_size) positions.
        recompute: Set to True to recompute the data and overwrite the existing saved dataframe.
        title: Optional, can be used to provide title for plot
        fontsize: Size of font for x and y labels.
        figsize: Size of figure.
        cmap: Colormap to use. The color scale spans +/- the 95th percentile of the plotted values (centered at 0),
            so a divergent colormap is recommended.
        save_show_or_return: Whether to save, show or return the figure.
            If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
            displayed and the associated axis and other object will be return.
        save_kwargs: A dictionary that will passed to the save_fig function. If None or empty, the save_fig
            function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent":
            True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a dictionary that
            properly modifies those keys according to your needs. The dictionary is copied, never modified.

    Raises:
        ValueError: If "window_size" is even, "position_key" cannot be found, the requested column cannot be
            resolved, the requested model results are missing, or nothing remains to be plotted.
    """
    logger = lm.get_main_logger()
    # Work on a private copy so that neither the caller's dictionary nor a shared default is modified below:
    save_kwargs = dict(save_kwargs) if save_kwargs else {}

    def _split_label(label) -> Tuple[str, str]:
        """Split a row label "target;interaction" into (target, interaction); falls back to the first "-"."""
        target, sep, interaction = str(label).partition(";")
        if not sep:
            target, _, interaction = str(label).partition("-")
        return target, interaction

    def _sort_rows_by_label_part(frame: pd.DataFrame, part: int) -> pd.DataFrame:
        """Stable-sort rows by the target (part=0) or the interaction (part=1) portion of their labels."""
        keys = [_split_label(label)[part] for label in frame.index]
        order = sorted(range(len(keys)), key=keys.__getitem__)
        return frame.iloc[order]

    def _arrange_by_enrichment_position(frame: pd.DataFrame) -> pd.DataFrame:
        """Order rows by the weighted mean column position of their highest values (those above the 95th
        percentile of the row's positive values), so rows peaking earlier along the axis come first."""
        percentile_95 = frame.apply(
            lambda row: np.percentile(row[row > 0], 95) if (row > 0).any() else 0, axis=1
        )
        weights = (frame.gt(percentile_95, axis=0) * frame).to_numpy()
        weighted_sum = weights @ np.arange(frame.shape[1])
        total_weight = weights.sum(axis=1)
        with np.errstate(divide="ignore", invalid="ignore"):
            weighted_avg = np.where(total_weight != 0, weighted_sum / total_weight, 0)
        return frame.loc[pd.Series(weighted_avg, index=frame.index).sort_values().index]

    def _format_position_columns(frame: pd.DataFrame) -> pd.DataFrame:
        """Turn relative-position column labels into unique values, formatted to 3 decimals when floats.

        Labels read back from CSV are strings, and pandas mangles duplicated labels into e.g. "0.500.1"; such
        suffixes are dropped and repeated values are spread evenly towards the next distinct position.
        """
        if all(isinstance(name, str) for name in frame.columns):
            if any(name.count(".") > 1 for name in frame.columns):
                frame.columns = [".".join(name.split(".")[:2]) for name in frame.columns]
            frame.columns = [float(col) for col in frame.columns]

        col_series = pd.Series(frame.columns)
        if col_series.duplicated().any():
            unique_values, counts = np.unique(col_series, return_counts=True)
            for idx, (value, count) in enumerate(zip(unique_values, counts)):
                if count < 2:
                    continue
                if idx + 1 < len(unique_values):
                    next_value = unique_values[idx + 1]
                else:
                    # Last distinct value: extrapolate using the spacing to the previous distinct value.
                    next_value = value + (value - unique_values[idx - 1])
                step = (next_value - value) / count
                for offset, position in enumerate(np.flatnonzero(col_series.to_numpy() == value)):
                    col_series.iloc[position] = value + step * offset
            frame.columns = col_series.values

        if all(isinstance(name, float) for name in frame.columns):
            frame.columns = [f"{float(col):.3f}" for col in frame.columns]
        return frame

    if window_size % 2 == 0:
        raise ValueError("Window size must be an odd integer.")
    if position_key not in self.adata.obsm.keys() and position_key not in self.adata.obs.keys():
        raise ValueError(
            f"Position key {position_key} not found in adata.obsm or adata.obs. Please provide a valid key."
        )

    # Resolve the positional axis, its label, the default title and the identifier used in saved file names:
    if position_key in self.adata.obsm.keys():
        if coord_column in ["xy", "yz", "xz", "-xy", "-yz", "-xz"]:
            self.adata = create_new_coordinate(self.adata, position_key, coord_column)
            pos = self.adata.obs[f"{coord_column} Coordinate"]
            x_label = f"Relative position along custom {coord_column} axis"
            if title is None:
                title = f"Signaling effect distribution along {coord_column} axis"
            save_id = f"{coord_column}_axis"
        else:
            positions = self.adata.obsm[position_key]
            if isinstance(coord_column, str):
                if not isinstance(positions, pd.DataFrame):
                    raise ValueError(
                        f"Array stored at position key {position_key} has no column names; provide the column "
                        f"index."
                    )
                pos = positions[coord_column]
                x_label = f"Relative position along {coord_column}"
                if title is None:
                    title = f"Signaling effect distribution along {coord_column}"
                save_id = coord_column
            elif isinstance(coord_column, (int, np.integer)):
                if isinstance(positions, pd.DataFrame):
                    pos = positions.iloc[:, coord_column]
                    x_label = f"Relative position along {coord_column}"
                    if title is None:
                        title = f"Signaling effect distribution along {coord_column}"
                    save_id = coord_column
                else:
                    pos = pd.Series(positions[:, coord_column], index=self.adata.obs_names)
                    axis_name = {0: "X", 1: "Y", 2: "Z"}.get(int(coord_column), str(coord_column))
                    x_label = f"Relative position along {axis_name}"
                    if title is None:
                        title = f"Signaling effect distribution along {axis_name}"
                    save_id = f"{axis_name.lower()}_axis"
            elif positions.shape[1] != 1:
                raise ValueError(
                    f"Array stored at position key {position_key} has more than one column; provide the column "
                    f"index."
                )
            else:
                pos = (
                    positions.iloc[:, 0]
                    if isinstance(positions, pd.DataFrame)
                    else pd.Series(np.asarray(positions).flatten(), index=self.adata.obs_names)
                )
                x_label = "Relative position"
                if title is None:
                    title = f"Signaling effect distribution along axis given by {position_key} key"
                save_id = position_key
    else:
        pos = self.adata.obs[position_key]
        x_label = "Relative position"
        if title is None:
            title = f"Signaling effect distribution along axis given by {position_key} key"
        save_id = position_key

    # Bin positions so that each position value is shared by several cells: round to the nearest 10 when the largest
    # coordinate is below 1000, otherwise to the nearest 100.
    max_value = pos.max()
    pos = pos.round(-1 if max_value < 1000 else -2)
    # Numerical positions may not match exactly- convert to integer so identical bins group together:
    if pos.dtype == float:
        pos = pos.astype(int)

    parent_dir = os.path.dirname(self.output_path)
    figure_folder = None
    if save_show_or_return in ["save", "both", "all"]:
        figure_folder = os.path.join(parent_dir, "figures", "temp")
        os.makedirs(figure_folder, exist_ok=True)
    output_folder = os.path.join(parent_dir, "analyses")
    os.makedirs(output_folder, exist_ok=True)

    # Use the saved name for the AnnData object to define part of the name of the saved file:
    adata_id = os.path.splitext(os.path.basename(self.adata_path))[0]

    # Check for existing dataframe:
    if check_downstream_ligand_effects:
        logger.info(
            "Checking downstream TF-ligand effects...note that this may not look very nice because TF-target "
            "relationships are not spatially dependent like L:R effects are."
        )
        df_path = os.path.join(
            output_folder, f"{adata_id}_distribution_downstream_ligand_effects_along_{save_id}.csv"
        )
    elif check_downstream_receptor_effects:
        logger.info(
            "Checking downstream TF-receptor effects...note that this may not look very nice because TF-target "
            "relationships are not spatially dependent like L:R effects are."
        )
        df_path = os.path.join(
            output_folder, f"{adata_id}_distribution_downstream_receptor_effects_along_{save_id}.csv"
        )
    elif check_downstream_target_effects:
        logger.info(
            "Checking downstream TF-target effects...note that this may not look very nice because TF-target "
            "relationships are not spatially dependent like L:R effects are."
        )
        df_path = os.path.join(
            output_folder, f"{adata_id}_distribution_downstream_target_effects_along_{save_id}.csv"
        )
    else:
        df_path = os.path.join(output_folder, f"{adata_id}_distribution_interaction_effects_along_{save_id}.csv")

    loaded_from_cache = os.path.exists(df_path) and not recompute
    if loaded_from_cache:
        to_plot = pd.read_csv(df_path, index_col=0)
        # Saved row labels are "target;interaction" with ":" replaced by "-":
        if interaction_subset is not None:
            wanted_interactions = {str(i).replace(":", "-") for i in interaction_subset}
            to_plot = to_plot.loc[[i for i in to_plot.index if _split_label(i)[1] in wanted_interactions]]
        if target_subset is not None:
            wanted_targets = set(target_subset)
            to_plot = to_plot.loc[[i for i in to_plot.index if _split_label(i)[0] in wanted_targets]]
    else:
        # Determine where to look for coefficients:
        if check_downstream_ligand_effects:
            all_coeffs = self.downstream_model_ligand_coeffs.copy()
            if len(all_coeffs) == 0:
                raise ValueError("No downstream model results found for ligands.")
        elif check_downstream_receptor_effects:
            all_coeffs = self.downstream_model_receptor_coeffs.copy()
            if len(all_coeffs) == 0:
                raise ValueError("No downstream model results found for receptors.")
        elif check_downstream_target_effects:
            all_coeffs = self.downstream_model_target_coeffs.copy()
            if len(all_coeffs) == 0:
                raise ValueError("No downstream model results found for targets.")
        else:
            all_coeffs = self.coeffs.copy()

        if target_subset is None:
            target_subset = list(all_coeffs.keys())
        else:
            # Determine the missing targets before filtering, otherwise none could ever be reported:
            removed = [t for t in target_subset if t not in all_coeffs.keys()]
            target_subset = [t for t in target_subset if t in all_coeffs.keys()]
            if len(removed) > 0:
                logger.warning(
                    f"Targets {removed} were not found in the model, and will be removed from the target subset."
                )

        if check_downstream_ligand_effects:
            if self.downstream_model_ligand_design_matrix is None:
                raise ValueError(
                    "No downstream model design matrix found for ligands. Run "
                    "`CCI_deg_detection_setup` with use_ligands=True first."
                )
            all_feature_names = [
                feat.replace("regulator_", "") for feat in self.downstream_model_ligand_design_matrix.columns
            ]
        elif check_downstream_receptor_effects:
            if self.downstream_model_receptor_design_matrix is None:
                raise ValueError(
                    "No downstream model design matrix found for receptors. Run "
                    "`CCI_deg_detection_setup` with use_receptors=True first."
                )
            all_feature_names = [
                feat.replace("regulator_", "") for feat in self.downstream_model_receptor_design_matrix.columns
            ]
        elif check_downstream_target_effects:
            if self.downstream_model_target_design_matrix is None:
                raise ValueError(
                    "No downstream model design matrix found for target genes. Run "
                    "`CCI_deg_detection_setup` with use_targets=True first."
                )
            all_feature_names = [
                feat.replace("regulator_", "") for feat in self.downstream_model_target_design_matrix.columns
            ]
        else:
            all_feature_names = [feat for feat in self.feature_names if feat != "intercept"]

        if interaction_subset is None:
            feature_names = all_feature_names
        else:
            feature_names = [feat for feat in all_feature_names if feat in interaction_subset]
            removed = [feat for feat in interaction_subset if feat not in all_feature_names]
            if len(removed) > 0:
                logger.warning(
                    f"Interactions {removed} were not found in the model, and will be removed from the interaction "
                    f"subset."
                )

        if use_significant:
            for target in target_subset:
                sig = pd.read_csv(
                    os.path.join(parent_dir, "significance", f"{target}_is_significant.csv"), index_col=0
                )
                coeffs = all_coeffs[target]
                # Multiply out of place: `*=` updates a DataFrame in place and, since `all_coeffs` is only a shallow
                # copy, would also overwrite the coefficients stored on the model object.
                all_coeffs[target] = (coeffs * sig).reindex_like(coeffs)

        if effect_threshold is not None:
            for target in target_subset:
                coeffs = all_coeffs[target]
                # Discard (set to 0) effects whose magnitude is below the threshold:
                all_coeffs[target] = coeffs.where(coeffs.abs() >= effect_threshold, 0)

        # Collect feature-target combinations present in the model, mapping "target;feature" -> (target, column):
        combo_parts = {}
        for target, feature in product(target_subset, feature_names):
            column = f"b_{feature}"
            if column not in all_coeffs[target].columns:
                continue
            # Remove combinations where the effect is hardly present (arbitrarily defined at 0.5% of cells):
            if (all_coeffs[target][column] != 0).mean() >= 0.005:
                combo_parts[f"{target};{feature}"] = (target, column)
        combinations = list(combo_parts)
        if len(combinations) == 0:
            raise ValueError("No interaction-target combinations with nonzero effects were found.")

        # For each cell, compute the fold change over the average effect for each combination:
        all_fc = pd.DataFrame(index=self.adata.obs_names, columns=combinations, dtype=float)
        for combo in tqdm(combinations, desc="Computing fold changes for interaction-target combinations..."):
            target, column = combo_parts[combo]
            target_coefs = all_coeffs[target][column]
            all_fc[combo] = target_coefs / target_coefs.mean()

        # Log fold change; undefined values (fold change <= -1, zero mean effect) are set to 0:
        all_fc = np.log1p(all_fc)
        all_fc = all_fc.replace([np.inf, -np.inf], np.nan).fillna(0)
        # z-score the fold change values of each combination across cells:
        all_fc = all_fc.apply(scipy.stats.zscore, axis=0)

        all_fc["pos"] = pos
        # Mean z-score at each coordinate position (groupby orders the positions):
        position_means = all_fc.groupby("pos").mean()
        # Smooth in the case of dropouts:
        position_means = position_means.rolling(window_size, center=True, min_periods=1).mean()

        # For each position, the combinations with the 30 highest mean z-scores:
        top_combinations = position_means.apply(lambda row: row.nlargest(30).index.tolist(), axis=1)
        # Keep combinations that are among the top ones at no fewer than int(window_size * 1.67) positions (counted
        # over all positions, not necessarily consecutive ones):
        min_positions = int(window_size * 1.67)
        top_counts = Counter(combo for combos in top_combinations for combo in combos)
        feats_of_interest = [combo for combo in combinations if top_counts[combo] >= min_positions]
        if len(feats_of_interest) == 0:
            raise ValueError(
                "No interaction-target combinations were consistently among the strongest effects along the axis."
            )

        to_plot = position_means[feats_of_interest].copy()
        if pd.api.types.is_numeric_dtype(to_plot.index):
            # Minmax scale to normalize positional context (a single position maps to 0):
            index_min = to_plot.index.min()
            index_span = to_plot.index.max() - index_min
            if index_span != 0:
                to_plot.index = (to_plot.index - index_min) / index_span
            else:
                to_plot.index = to_plot.index - index_min
        to_plot = to_plot.T  # so that the features are labeled along the y-axis

    if to_plot.shape[0] == 0:
        raise ValueError("No interaction-target combinations remain to be plotted; check the target/interaction subsets.")

    to_plot = _format_position_columns(to_plot)
    to_plot.columns = [str(col) for col in to_plot.columns]
    to_plot.index = [str(i).replace(":", "-") for i in to_plot.index]

    if sort_by_target:
        logger.info("Sorting by target gene...")
        to_plot = _sort_rows_by_label_part(to_plot, 0)
    elif neatly_arrange_y:
        # Order roughly determined by how early along the relative position the highest z-scores occur:
        logger.info("Sorting by position of enrichment along axis...")
        to_plot = _arrange_by_enrichment_position(to_plot)
    else:
        logger.info("Sorting by interaction...")
        to_plot = _sort_rows_by_label_part(to_plot, 1)

    # Symmetric color limits at the 95th percentile of all plotted values:
    max_val = pd.Series(to_plot.to_numpy().ravel()).quantile(0.95)

    if figsize is None:
        figsize = (8, len(to_plot) * 40 / 200)
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
    if fontsize is None:
        fontsize = rcParams.get("font.size")

    heatmap_ax = sns.heatmap(to_plot, vmin=-max_val, vmax=max_val, ax=ax, cmap=cmap)
    cbar = heatmap_ax.collections[0].colorbar
    cbar.set_label("Z-score", fontsize=fontsize * 1.5, labelpad=10)
    # Adjust colorbar tick font size
    cbar.ax.tick_params(labelsize=fontsize * 1.25)
    cbar.ax.set_aspect(np.min([len(to_plot), 70]))

    ax.set_xlabel(x_label, fontsize=fontsize * 1.25)
    ax.set_ylabel("Interaction Effect on Target (formatted target-interaction)", fontsize=fontsize * 1.25)
    ax.tick_params(axis="x", labelsize=fontsize)
    ax.tick_params(axis="y", labelsize=fontsize)
    ax.set_title(title, fontsize=fontsize * 1.5, pad=20)

    # Persist freshly computed results (this also overwrites an existing file when recompute=True):
    if not loaded_from_cache:
        to_plot.to_csv(df_path)

    if save_show_or_return in ["save", "both", "all"]:
        save_kwargs["ext"] = "png"
        save_kwargs["dpi"] = 300
        save_kwargs["path"] = figure_folder

    # Save figure:
    save_return_show_fig_utils(
        save_show_or_return=save_show_or_return,
        show_legend=False,
        background="white",
        prefix=f"distribution_interaction_effects_along_{save_id}",
        save_kwargs=save_kwargs,
        total_panels=1,
        fig=fig,
        axes=ax,
        return_all=False,
        return_all_list=None,
    )
def effect_distribution_density(
    self,
    effect_names: List[str],
    position_key: str = "spatial",
    coord_column: Optional[Union[int, str]] = None,
    max_coord_val: float = 1.0,
    title: Optional[str] = None,
    x_label: Optional[str] = None,
    region_lower_bound: Optional[float] = None,
    region_upper_bound: Optional[float] = None,
    region_label: Optional[str] = None,
    fontsize: Union[None, int] = None,
    figsize: Union[None, Tuple[float, float]] = None,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = None,
):
    """Visualize the spatial enrichment of cell-cell interaction effects using density plots over spatial
    coordinates.

    Uses the existing interaction-effect dataframe saved by :func:`effect_distribution_heatmap()` (in the "analyses"
    folder of the parent directory of :attr `output_path`), which must be run first with the same position settings.
    Negative z-scores are set to 0 and each selected effect is drawn as a KDE over position, weighted by its values.

    Args:
        effect_names: List of interaction effects to include in plot, given as row labels of the saved dataframe,
            i.e. "Target;Ligand-Receptor" (for L:R models) or "Target;Ligand" (for ligand models). Names are also
            matched when ";", ":" and "-" are used interchangeably as separators, so e.g.
            "Target-Ligand:Receptor" is accepted.
        position_key: Key in adata.obs or adata.obsm that provides a relative indication of the position of cells.
            i.e. spatial coordinates. Defaults to "spatial". For each value in the position array (each
            coordinate, each category), multiple cells must have the same value.
        coord_column: Optional, only used if "position_key" points to an entry in .obsm. In this case, this is the
            index or name of the column to be used to provide the positional context. Can also provide "xy", "yz",
            "xz", "-xy", "-yz", "-xz" to draw a line between the two coordinate axes. "xy" will extend the new axis
            in the direction of increasing x and increasing y starting from x=0 and y=0 (or min. x/min. y), "-xy"
            will extend the new axis in the direction of decreasing x and increasing y starting from x=minimum x
            and y=maximum y, and so on.
        max_coord_val: Optional, can be used to adjust the numbers displayed along the x-axis for the relative
            position along the coordinate axis: positions are rescaled to [0, max_coord_val]. Defaults to 1.0.
        title: Optional, can be used to provide title for plot
        x_label: Optional, can be used to provide x-axis label for plot
        region_lower_bound: Optional, can be used to provide a lower bound for the region of interest to label on
            the plot- this can correspond to a spatial domain, etc. Only used together with "region_upper_bound".
        region_upper_bound: Optional, can be used to provide an upper bound for the region of interest to label on
            the plot- this can correspond to a spatial domain, etc. Only used together with "region_lower_bound".
        region_label: Optional, can be used to provide a label for the region of interest to label on the plot
        fontsize: Size of font for x and y labels.
        figsize: Size of figure.
        save_show_or_return: Whether to save, show or return the figure.
            If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
            displayed and the associated axis and other object will be return.
        save_kwargs: A dictionary that will passed to the save_fig function. If None or empty, the save_fig
            function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent":
            True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a dictionary that
            properly modifies those keys according to your needs. The dictionary is copied, never modified.

    Raises:
        ValueError: If "position_key" cannot be found, the requested column cannot be resolved, the saved dataframe
            does not exist, or none of "effect_names" is present in it.
    """
    logger = lm.get_main_logger()
    # Work on a private copy so that neither the caller's dictionary nor a shared default is modified below:
    save_kwargs = dict(save_kwargs) if save_kwargs else {}

    def _normalize_name(name) -> str:
        """Treat ";", ":" and "-" as equivalent separators when matching effect names."""
        return str(name).replace(";", "-").replace(":", "-")

    if position_key not in self.adata.obsm.keys() and position_key not in self.adata.obs.keys():
        raise ValueError(
            f"Position key {position_key} not found in adata.obsm or adata.obs. Please provide a valid key."
        )

    # Resolve the default axis label/title and the identifier used by effect_distribution_heatmap() in file names:
    default_x_label = "Relative position"
    default_title = f"Signaling effect density along axis given by {position_key} key"
    save_id = position_key
    if position_key in self.adata.obsm.keys():
        positions = self.adata.obsm[position_key]
        if coord_column in ["xy", "yz", "xz", "-xy", "-yz", "-xz"]:
            default_x_label = f"Relative position along custom {coord_column} axis"
            default_title = f"Signaling effect density along {coord_column} axis"
            save_id = f"{coord_column}_axis"
        elif coord_column is not None and not isinstance(coord_column, (int, np.integer)):
            if not isinstance(positions, pd.DataFrame):
                raise ValueError(
                    f"Array stored at position key {position_key} has no column names; provide the column "
                    f"index."
                )
            default_x_label = f"Relative position along {coord_column}"
            default_title = f"Signaling effect density along {coord_column}"
            save_id = coord_column
        elif coord_column is not None:
            if isinstance(positions, pd.DataFrame):
                default_x_label = f"Relative position along {coord_column}"
                default_title = f"Signaling effect density along {coord_column}"
                save_id = coord_column
            else:
                axis_name = {0: "X", 1: "Y", 2: "Z"}.get(int(coord_column), str(coord_column))
                default_x_label = f"Relative position along {axis_name}"
                default_title = f"Signaling effect density along {axis_name}"
                save_id = f"{axis_name.lower()}_axis"
        elif positions.shape[1] != 1:
            raise ValueError(
                f"Array stored at position key {position_key} has more than one column; provide the column "
                f"index."
            )
    if x_label is None:
        x_label = default_x_label
    if title is None:
        title = default_title

    # Check for existing dataframe:
    parent_dir = os.path.dirname(self.output_path)
    output_folder = os.path.join(parent_dir, "analyses")
    # Use the saved name for the AnnData object to define part of the name of the saved file:
    adata_id = os.path.splitext(os.path.basename(self.adata_path))[0]
    df_path = os.path.join(output_folder, f"{adata_id}_distribution_interaction_effects_along_{save_id}.csv")
    if not os.path.exists(df_path):
        raise ValueError(
            f"Could not find dataframe saved by effect_distribution_heatmap() for position key {position_key}. "
            f"Please run effect_distribution_heatmap() before running this function."
        )
    to_plot = pd.read_csv(df_path, index_col=0)

    # Format the numerical columns: labels read from CSV are strings, and pandas mangles duplicated labels into
    # e.g. "0.500.1"- drop such suffixes, then spread repeated values towards the next distinct position.
    if all(isinstance(name, str) for name in to_plot.columns):
        if any(name.count(".") > 1 for name in to_plot.columns):
            to_plot.columns = [".".join(name.split(".")[:2]) for name in to_plot.columns]
        to_plot.columns = [float(col) for col in to_plot.columns]

    col_series = pd.Series(to_plot.columns)
    if col_series.duplicated().any():
        unique_values, counts = np.unique(col_series, return_counts=True)
        for idx, (value, count) in enumerate(zip(unique_values, counts)):
            if count < 2:
                continue
            if idx + 1 < len(unique_values):
                next_value = unique_values[idx + 1]
            else:
                # Last distinct value: extrapolate using the spacing to the previous distinct value.
                next_value = value + (value - unique_values[idx - 1])
            step = (next_value - value) / count
            for offset, position in enumerate(np.flatnonzero(col_series.to_numpy() == value)):
                col_series.iloc[position] = value + step * offset
        to_plot.columns = col_series.values

    if all(isinstance(name, float) for name in to_plot.columns):
        to_plot.columns = [f"{float(col):.3f}" for col in to_plot.columns]

    # Rescale positions to [0, max_coord_val] (a single position maps to 0):
    float_columns = [float(col) for col in to_plot.columns]
    current_min = min(float_columns)
    current_span = max(float_columns) - current_min
    normalized_columns = [
        (col - current_min) / current_span * max_coord_val if current_span != 0 else 0.0 for col in float_columns
    ]
    to_plot.columns = [f"{col:.3f}" for col in normalized_columns]

    # Rearrange dataframe such that each interaction is its own column:
    to_plot = to_plot.T
    if not pd.api.types.is_numeric_dtype(to_plot.index):
        to_plot.index = pd.to_numeric(to_plot.index)
    # For this function, weights cannot be negative, so set all negative values to 0:
    to_plot[to_plot < 0] = 0
    to_plot["Coord"] = to_plot.index

    # Resolve requested names to dataframe columns, exactly or up to separator style:
    effect_columns = [col for col in to_plot.columns if col != "Coord"]
    exact_columns = set(effect_columns)
    by_normalized_name = {}
    for col in effect_columns:
        by_normalized_name.setdefault(_normalize_name(col), col)
    resolved = []
    missing = []
    for name in effect_names:
        if name in exact_columns:
            resolved.append((name, name))
        elif _normalize_name(name) in by_normalized_name:
            resolved.append((name, by_normalized_name[_normalize_name(name)]))
        else:
            missing.append(name)
    if len(missing) > 0:
        logger.warning(
            f"Interactions {missing} were not found in the dataframe. They will be removed from the plot."
        )
    if len(resolved) == 0:
        raise ValueError("None of the requested interaction effects were found in the saved dataframe.")

    if figsize is None:
        figsize = (8, 6)
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
    if fontsize is None:
        fontsize = rcParams.get("font.size")

    sns.set_style("white")
    for (label, column), color in zip(resolved, godsnot_102):
        sns.kdeplot(x="Coord", weights=column, data=to_plot, color=color, label=label, lw=2, ax=ax)

    if region_lower_bound is not None and region_upper_bound is not None:
        width = region_upper_bound - region_lower_bound
        y_min, y_max = ax.get_ylim()
        region_box = mpl.patches.Rectangle(
            (region_lower_bound, y_min),
            width,
            y_max - y_min,
            linewidth=1,
            edgecolor="#1CE6FF",
            facecolor="#1CE6FF",
            alpha=0.2,
        )
        ax.add_patch(region_box)
        region_box_legend = mpl.patches.Patch(color="#1CE6FF", alpha=0.2, label=region_label)
        handles, labels = ax.get_legend_handles_labels()
        handles.append(region_box_legend)
        labels.append(region_label)
        ax.legend(handles=handles, labels=labels, loc="upper left", bbox_to_anchor=(1, 1), fontsize=fontsize * 1.25)
    else:
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1), fontsize=fontsize * 1.25)

    ax.set_xlabel(x_label, fontsize=fontsize * 1.25)
    ax.set_ylabel("Density", fontsize=fontsize * 1.25)
    ax.tick_params(axis="x", labelsize=fontsize)
    ax.tick_params(axis="y", labelsize=fontsize, labelleft=False, left=False)
    ax.set_title(title, fontsize=fontsize * 1.5, pad=20)

    if save_show_or_return in ["save", "both", "all"]:
        save_kwargs["ext"] = "png"
        save_kwargs["dpi"] = 300
        figure_folder = os.path.join(parent_dir, "figures", "temp")
        os.makedirs(figure_folder, exist_ok=True)
        save_kwargs["path"] = figure_folder

    # Save figure:
    save_return_show_fig_utils(
        save_show_or_return=save_show_or_return,
        show_legend=True,
        background="white",
        prefix=f"density_interaction_effects_along_{save_id}",
        save_kwargs=save_kwargs,
        total_panels=1,
        fig=fig,
        axes=ax,
        return_all=False,
        return_all_list=None,
    )
[docs] def visualize_effect_specificity( self, agg_method: Literal["mean", "percentage"] = "mean", plot_type: Literal["heatmap", "volcano"] = "heatmap", target_subset: Optional[List[str]] = None, interaction_subset: Optional[List[str]] = None, ct_subset: Optional[List[str]] = None, group_key: Optional[str] = None, n_anchors: Optional[int] = None, effect_threshold: Optional[float] = None, use_significant: bool = False, target_cooccurrence_threshold: float = 0.1, significance_cutoff: float = 1.3, fold_change_cutoff: float = 1.5, fold_change_cutoff_for_labels: float = 3.0, fontsize: Union[None, int] = None, figsize: Union[None, Tuple[float, float]] = None, cmap: str = "seismic", save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", save_kwargs: Optional[dict] = {}, save_df: bool = False, ): """Computes and visualizes the specificity of each interaction on each target. This is done by first separating the target-expressing cells (and their neighbors) from the rest of the cells (conditioned on predicted effect and also conditioned on receptor expression if L:R model is used). Then, computing the fold change of the average expression of the ligand in the neighborhood of the first subset vs. the neighborhoods of the second subset. Args: agg_method: Method to use for aggregating the specificity of each interaction on each target. Options: "mean" for mean ligand expression, "percentage" for the percentage of cells expressing the ligand. plot_type: Type of plot to use for visualization. Options: "heatmap" for heatmap, "volcano" for volcano plot. target_subset: List of targets to consider. If None, will use all targets used in model fitting. interaction_subset: List of interactions to consider. If None, will use all interactions used in model. ct_subset: Can be used to constrain the first group of cells (the query group) to the target-expressing cells of a particular type (conditioned on any other relevant variables). 
If given, will search for cell types in "group_key" attribute from model initialization. If not given, will use all cell types. group_key: Can be used to specify entry in adata.obs that contains cell type groupings. If None, will use :attr `group_key` from model initialization. n_anchors: Optional, number of target gene-expressing cells to use as anchors for analysis. Will be selected randomly from the set of target gene-expressing cells (conditioned on any other relevant values). effect_threshold: Optional threshold minimum effect size to consider an effect for further analysis, as an absolute value. Use this to choose only the cells for which an interaction is predicted to have a strong effect. If None, use the median interaction effect. use_significant: Whether to use only significant effects in computing the specificity. If True, will filter to cells + interactions where the interaction is significant for the target. Only valid if :func `compute_coeff_significance()` has been run. significance_cutoff: Cutoff for negative log-10 q-value to consider an interaction/effect significant. Only used if "plot_type" is "volcano". Defaults to 1.3 (corresponding to an approximate q-value of 0.05). fold_change_cutoff: Cutoff for fold change to consider an interaction/effect significant. Only used if "plot_type" is "volcano". Defaults to 1.5. fold_change_cutoff_for_labels: Cutoff for fold change to include the label for an interaction/effect. Only used if "plot_type" is "volcano". Defaults to 3.0. fontsize: Size of font for x and y labels. figsize: Size of figure. cmap: Colormap to use. Options: Any divergent matplotlib colormap. save_show_or_return: Whether to save, show or return the figure. If "both", it will save and plot the figure at the same time. If "all", the figure will be saved, displayed and the associated axis and other object will be return. save_kwargs: A dictionary that will passed to the save_fig function. 
By default it is an empty dictionary and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those keys according to your needs. save_df: Set True to save the metric dataframe in the end """ from ..find_neighbors import neighbors logger = lm.get_main_logger() config_spateo_rcParams() # But set display DPI to 300: plt.rcParams["figure.dpi"] = 300 if self.mod_type != "lr" and self.mod_type != "ligand": raise ValueError("This function is only applicable for ligand-based models.") if save_show_or_return in ["save", "both", "all"]: if not os.path.exists(os.path.join(os.path.dirname(self.output_path), "figures")): os.makedirs(os.path.join(os.path.dirname(self.output_path), "figures")) figure_folder = os.path.join(os.path.dirname(self.output_path), "figures", "temp") if not os.path.exists(figure_folder): os.makedirs(figure_folder) if save_df: output_folder = os.path.join(os.path.dirname(self.output_path), "analyses") if not os.path.exists(output_folder): os.makedirs(output_folder) # Use the saved name for the AnnData object to define part of the name of the saved file: base_name = os.path.basename(self.adata_path) adata_id = os.path.splitext(base_name)[0] # Colormap should be divergent: divergent_cmaps = [ "seismic", "coolwarm", "bwr", "RdBu", "RdGy", "PuOr", "PiYG", "PRGn", "BrBG", "RdYlBu", "RdYlGn", "Spectral", ] if cmap not in divergent_cmaps: logger.warning( f"Colormap {cmap} is not divergent, which is recommended for this plot type. Using 'seismic' instead." 
) cmap = "seismic" if target_subset is None: target_subset = list(self.coeffs.keys()) else: target_subset = [t for t in target_subset if t in self.coeffs.keys()] removed = [t for t in target_subset if t not in self.coeffs.keys()] if len(removed) > 0: logger.warning( f"Targets {removed} were not found in the model, and will be removed from the target subset." ) all_feature_names = [feat for feat in self.feature_names if feat != "intercept"] if interaction_subset is None: feature_names = all_feature_names else: feature_names = [feat for feat in all_feature_names if feat in interaction_subset] removed = [feat for feat in interaction_subset if feat not in all_feature_names] if len(removed) > 0: logger.warning( f"Interactions {removed} were not found in the model, and will be removed from the interaction " f"subset." ) if ct_subset is not None and group_key is None: group_key = self.group_key if fontsize is None: fontsize = rcParams.get("font.size") if plot_type == "heatmap": x_label = "Neighboring Ligand" if self.mod_type == "ligand" else "L:R Interaction" y_label = "Target Gene" title = "Fold Change Interaction Enrichment \n Target-Expressing Cells vs. Others" cbar_label = "$\\log_2$(Fold change Interaction Enrichment \n Target-Expressing Cells vs. Others" else: x_label = "$\\log_2$(Fold change Interaction Enrichment \n Target-Expressing Cells vs. Others" y_label = r"$-log_10$(qval)" title = "Fold Change Interaction Enrichment \n Target-Expressing Cells vs. 
Others" # Check for already-existing dataframe: try: output_folder = os.path.join(os.path.dirname(self.output_path), "analyses") df = pd.read_csv( os.path.join( os.path.dirname(self.output_path), "analyses", f"{plot_type}_{adata_id}_interaction_enrichment_fold_change_target_expressing_v_nonexpressing.csv", ), index_col=0, ) if interaction_subset is not None: df = df.loc[[i for i in df.index if i.split(":")[0] in interaction_subset]] if target_subset is not None: df = df.loc[[i for i in df.index if i.split(":")[1] in target_subset]] except: if plot_type == "heatmap": df = pd.DataFrame(index=target_subset, columns=feature_names) else: combinations = product(feature_names, target_subset) combinations = [f"{feature}-{target}" for feature, target in combinations] df = pd.DataFrame( index=combinations, columns=["log2FC", "p-value", "q-value", "Significance", "-log10(qval)"] ) if ( "spatial_connectivities_secreted" in self.adata.obsp.keys() and "spatial_connectivities_membrane_bound" in self.adata.obsp.keys() ): conn_secreted = self.adata.obsp["spatial_connectivities_secreted"] conn_membrane_bound = self.adata.obsp["spatial_connectivities_membrane_bound"] else: logger.info("Spatial graph not found, computing...") adata = self.adata.copy() _, adata = neighbors( adata, n_neighbors=self.n_neighbors_secreted, basis="spatial", spatial_key=self.coords_key, n_neighbors_method="ball_tree", ) conn_secreted = adata.obsp["spatial_connectivities"] adata = self.adata.copy() _, adata = neighbors( adata, n_neighbors=self.n_neighbors_membrane_bound, basis="spatial", spatial_key=self.coords_key, n_neighbors_method="ball_tree", ) conn_membrane_bound = adata.obsp["spatial_connectivities"] self.adata.obsp["spatial_connectivities_secreted"] = conn_secreted self.adata.obsp["spatial_connectivities_membrane_bound"] = conn_membrane_bound # For each target, split cells into two groups: target-expressing and all neighbors of target-expressing # cells, and the remainder. 
for target in target_subset: coef_target = self.coeffs[target].loc[adata.obs_names] if effect_threshold is None: nonzero_values = coef_target.values.flatten() nonzero_values = nonzero_values[nonzero_values != 0] effect_threshold = pd.Series(nonzero_values).quantile(0.75) if use_significant: parent_dir = os.path.dirname(self.output_path) sig = pd.read_csv( os.path.join(parent_dir, "significance", f"{target}_is_significant.csv"), index_col=0 ) coef_target *= sig # Taking the first group (the query group)- first subset to cell types of interest, if given: if ct_subset is not None: query_adata = self.adata[self.adata.obs[group_key].isin(ct_subset)].copy() else: query_adata = self.adata.copy() # Define masks for target expression: target_expression = query_adata[:, target].X.toarray().reshape(-1) target_expressing_mask = target_expression > 0 target_expressing_cells = query_adata.obs_names[target_expressing_mask] # Define interaction-specific masks: (optionally, if L:R model) for cells expressing receptor, # for cells predicted to be affected by an interaction and all of the neighbors of these cells: for interaction in feature_names: if f"b_{interaction}" not in coef_target.columns: # Significance for this interaction-target combination: if plot_type == "volcano": df.loc[f"{interaction}-{target}", "p-value"] = 1.0 df.loc[f"{interaction}-{target}", "log2FC"] = 0.0 else: df.loc[target, interaction] = 0.0 continue if self.mod_type == "lr": ligand, receptor = interaction.split(":") receptor_expressing_mask = np.ones(query_adata.shape[0], dtype=bool) for r in receptor.split("_"): receptor_expression = query_adata[:, r].X.toarray().reshape(-1) receptor_expressing_mask &= receptor_expression > 0 receptor_expressing_cells = query_adata.obs_names[receptor_expressing_mask] coef_interaction_target = coef_target[f"b_{interaction}"] coef_interaction_target_mask = coef_interaction_target > effect_threshold coef_interaction_target_cells = 
query_adata.obs_names[coef_interaction_target_mask] # Get mask for any neighbors of these cells: # Check whether to use the neighbors found for membrane-bound interaction or those for secreted # interactions: to_check = interaction.split(":")[0] if ":" in interaction else interaction if "/" in to_check: interaction_components = to_check.split("/") separator = "/" elif "_" in to_check: interaction_components = to_check.split("_") separator = "_" else: interaction_components = [to_check] separator = None matching_rows = self.lr_db[self.lr_db["from"].isin(interaction_components)] if ( matching_rows["type"].str.contains("Secreted Signaling").any() or matching_rows["type"].str.contains("ECM-Receptor").any() ): conn = conn_secreted else: conn = conn_membrane_bound if self.mod_type != "lr": # Get the intersection of cells expressing target and predicted to be affected by the interaction: adata_mask = target_expressing_cells.intersection(coef_interaction_target_cells) else: # Get the intersection of cells expressing target and receptor and are predicted to be affected by # interaction: adata_mask = target_expressing_cells.intersection(receptor_expressing_cells) adata_mask = adata_mask.intersection(coef_interaction_target_cells) # This object contains samples that can constitute the query group: query_adata_sub = query_adata[adata_mask].copy() # This object contains the other samples, that can constitute the reference: neg_mask = [ n for n in self.adata.obs_names if n not in target_expressing_cells and n not in coef_interaction_target_cells ] reference_adata_sub = self.adata[neg_mask].copy() if query_adata_sub.n_obs <= 30: logger.info( f"Insufficient query cells found for this interaction-target combination (likely based on " f"absence of strong interaction effect)- {interaction}-{target}. Skipping." 
) del conn, query_adata_sub, reference_adata_sub gc.collect() if plot_type == "volcano": df.loc[f"{interaction}-{target}", "p-value"] = 1 df.loc[f"{interaction}-{target}", "log2FC"] = 0.0 else: df.loc[target, interaction] = 0.0 continue # Query group: # If applicable, select a subset of these cells to use as anchors: if n_anchors is not None: if query_adata_sub.n_obs < n_anchors: logger.warning( f"Number of anchors ({n_anchors}) is greater than number of target-expressing cells " f"({query_adata_sub.n_obs}) for target {target} and interaction {interaction}. " f"Skipping." ) del conn, query_adata_sub, reference_adata_sub gc.collect() if plot_type == "volcano": df.loc[f"{interaction}-{target}", "p-value"] = 1 df.loc[f"{interaction}-{target}", "log2FC"] = 0.0 else: df.loc[target, interaction] = 0.0 continue else: if query_adata_sub.n_obs < 200: logger.warning( f"Number of target-expressing cells ({query_adata_sub.n_obs}) is less than 100 for " f"target {target} and interaction {interaction}. Skipping." 
) del conn, query_adata_sub, reference_adata_sub gc.collect() if plot_type == "volcano": df.loc[f"{interaction}-{target}", "p-value"] = 1 df.loc[f"{interaction}-{target}", "log2FC"] = 0.0 else: df.loc[target, interaction] = 0.0 continue anchors = np.random.choice(query_adata_sub.obs_names, size=n_anchors, replace=False) selected_indices = [np.where(self.adata.obs_names == string)[0][0] for string in anchors] # Get neighbors of these cells: neighbors = conn[selected_indices].nonzero()[1] neighbors = np.unique(neighbors) # Remove the anchor cells from the neighbors: neighbors = neighbors[~np.isin(neighbors, selected_indices)] neighbors_selected = self.adata.obs_names[neighbors] # The query group: anchors and their neighbors: query_group = anchors.tolist() + neighbors_selected.tolist() # Reference group: # If applicable, select a subset of these cells to use as anchors: anchors = np.random.choice(reference_adata_sub.obs_names, size=n_anchors, replace=False) selected_indices = [np.where(self.adata.obs_names == string)[0][0] for string in anchors] # Get neighbors of these cells: neighbors = conn[selected_indices].nonzero()[1] neighbors = np.unique(neighbors) # Remove the anchor cells from the neighbors: neighbors = neighbors[~np.isin(neighbors, selected_indices)] neighbors_selected = self.adata.obs_names[neighbors] # The reference group: anchors and their neighbors: reference_group = anchors.tolist() + neighbors_selected.tolist() # Ligand expression in the selected cells: ligand = interaction.split(":")[0] if ":" in interaction else interaction components = ligand.split(separator) if separator is not None else [ligand] # Compute ligand values for query + reference together before separating them for fold change # calculation: ligand_values = self.adata[query_group + reference_group, components].X.toarray() if separator == "/": # Arithmetic mean of the genes ligand_values = np.mean(ligand_values, axis=1) elif separator == "_": # Geometric mean of the genes # Replace 
zeros with np.nan ligand_values[ligand_values == 0] = np.nan # Compute product along the rows products = np.nanprod(ligand_values, axis=1) # Count non-nan values in each row for nth root calculation non_nan_counts = np.sum(~np.isnan(ligand_values), axis=1) # Avoid division by zero non_nan_counts[non_nan_counts == 0] = np.nan ligand_values = np.power(products, 1 / non_nan_counts) ligand_values = pd.DataFrame(ligand_values, index=query_group + reference_group, columns=[ligand]) ligand_query = ligand_values.loc[query_group, :] ligand_reference = ligand_values.loc[reference_group, :] # Significance for this interaction-target combination: if plot_type == "volcano": if (ligand_reference == 0).all().all(): df.loc[f"{interaction}-{target}", "p-value"] = 0 else: df.loc[f"{interaction}-{target}", "p-value"] = mannwhitneyu(ligand_query, ligand_reference)[ 1 ] if agg_method == "mean": ligand_query = ligand_query.mean().values ligand_reference = ligand_reference.mean().values elif agg_method == "percentage": ligand_query = (ligand_query > 0).mean().values ligand_reference = (ligand_reference > 0).mean().values if ligand_reference == 0: # Prevent division by zero, this will get set to the max threshold anyways: ligand_reference = 0.001 fold_change = np.log2(ligand_query / ligand_reference) if plot_type == "volcano": df.loc[f"{interaction}-{target}", "log2FC"] = fold_change else: df.loc[target, interaction] = fold_change del conn, query_adata_sub, reference_adata_sub gc.collect() logger.info(f"Finished computing specificity for target {target}.") # If relevant, compute adjusted p-values: if plot_type == "volcano": df["log2FC"] = df["log2FC"].apply(lambda x: x[0] if isinstance(x, np.ndarray) else x) df["p-value"] = df["p-value"].apply(lambda x: x[0] if isinstance(x, np.ndarray) else x) df["q-value"] = multitesting_correction(df["p-value"].values, method="fdr_bh") df["Significance"] = df["q-value"] < 0.05 df["-log10(qval)"] = -np.log10(df["q-value"]) # And if relevant, perform 
hierarchical clustering- first to group interactions w/ similar fold changes # across targets: if plot_type == "heatmap": col_linkage = sch.linkage(df.transpose(), method="ward") col_dendro = sch.dendrogram(col_linkage, no_plot=True) col_clustered_order = col_dendro["leaves"] df = df.iloc[:, col_clustered_order] # Then to group targets w/ similar fold changes across interactions: row_linkage = sch.linkage(df, method="ward") row_dendro = sch.dendrogram(row_linkage, no_plot=True) row_clustered_order = row_dendro["leaves"] df = df.iloc[row_clustered_order, :] # Plot: if figsize is None: if plot_type == "heatmap": # Set figure size based on the number of interaction features and targets: m = len(target_subset) * 50 / 200 n = len(feature_names) * 50 / 200 else: m = 6 n = 6 figsize = (n, m) fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize) cmap = mpl.colormaps[cmap] # Center colormap at 0 for heatmap: if plot_type == "heatmap": max_distance = max(abs(df.max().max()), abs(df.min().min())) norm = plt.Normalize(-max_distance, max_distance) colors = cmap(norm(df)) custom_cmap = mpl.colors.LinearSegmentedColormap.from_list("custom", colors) else: max_distance = max(abs(df["log2FC"].max()), abs(df["log2FC"].min())) if plot_type == "volcano": if len(df) > 20: size = 20 else: size = 40 significant = df["-log10(qval)"] > significance_cutoff significant_up = df["log2FC"] > fold_change_cutoff significant_down = df["log2FC"] < -fold_change_cutoff # Check if max -log10(qval) is greater than 8 if df["-log10(qval)"].max() > 8: ax.set_yscale("log", base=2) # Set y-axis to log y_label = r"$-log_10$(qval) ($log_2$ scale)" sns.scatterplot( x=df["log2FC"][significant & significant_up], y=df["-log10(qval)"][significant & significant_up], hue=df["log2FC"][significant & significant_up], palette="Reds", vmin=0, edgecolor="black", ax=ax, s=size, legend=False, ) sns.scatterplot( x=df["log2FC"][significant & significant_down], y=df["-log10(qval)"][significant & significant_down], 
hue=df["log2FC"][significant & significant_down], palette="Blues_r", vmax=0, edgecolor="black", ax=ax, s=size, legend=False, ) sns.scatterplot( x=df["log2FC"][~(significant & (significant_up | significant_down))], y=df["-log10(qval)"][~(significant & (significant_up | significant_down))], color="grey", edgecolor="black", ax=ax, s=size, ) # Add labels for significant interactions: high_fold_change = df[abs(df["log2FC"]) > fold_change_cutoff_for_labels] while high_fold_change.empty: fold_change_cutoff_for_labels /= 2 # Halve the cutoff high_fold_change = df[abs(df["log2FC"]) > fold_change_cutoff_for_labels] text_labels = high_fold_change.index.tolist() x_coord_text_labels = high_fold_change["log2FC"].tolist() y_coord_text_labels = high_fold_change["-log10(qval)"].tolist() text_objects = [] for i, label in enumerate(text_labels): t = ax.text( x_coord_text_labels[i], y_coord_text_labels[i], label, fontsize=fontsize * 0.75, color="black", ha="center", va="center", ) text_objects.append(t) adjust_text(text_objects, ax=ax, arrowprops=dict(arrowstyle="-", color="black", lw=0.5)) ax.axhline(y=significance_cutoff, color="grey", linestyle="--", linewidth=1.5) ax.axvline(x=fold_change_cutoff, color="grey", linestyle="--", linewidth=1.5) ax.axvline(x=-fold_change_cutoff, color="grey", linestyle="--", linewidth=1.5) ax.set_xlim(df["log2FC"].min() - 0.2, max_distance + 0.2) ax.set_xticklabels(["{:.2f}".format(x) for x in ax.get_xticks()], fontsize=fontsize) ax.set_yticklabels(["{:.2f}".format(y) for y in ax.get_yticks()], fontsize=fontsize) ax.set_xlabel(x_label, fontsize=fontsize * 1.25) ax.set_ylabel(y_label, fontsize=fontsize * 1.25) ax.set_title(title, fontsize=fontsize * 1.5) prefix = "volcano" elif plot_type == "heatmap": vmin = -max_distance vmax = max_distance thickness = 0.3 * figsize[0] / 10 mask = np.abs(df) < 0.1 m = sns.heatmap( df, square=True, linecolor="grey", linewidths=thickness, cbar_kws={"label": cbar_label, "location": "top", "pad": 0}, cmap=custom_cmap, 
vmin=vmin, vmax=vmax, mask=mask, ax=ax, ) # Outer frame: for _, spine in m.spines.items(): spine.set_visible(True) spine.set_linewidth(thickness * 2.5) # Adjust colorbar settings: divider = make_axes_locatable(ax) # Append axes to the top of the plot, where the colorbar will be placed if df.shape[0] > df.shape[1]: cax = divider.append_axes("top", size="30%", pad=0) else: cax = divider.append_axes("top", size="30%", pad="30%") # Create the colorbar manually in the appended axes cbar = plt.colorbar(m.collections[0], cax=cax, orientation="horizontal") cbar.set_label(cbar_label.title(), fontsize=fontsize * 1.5, labelpad=10) cbar.ax.xaxis.set_ticks_position("top") # Move ticks to the top cbar.ax.xaxis.set_label_position("top") # Move the label (title) to the top cbar.ax.tick_params(labelsize=fontsize * 1.5) cbar.ax.set_aspect(0.02) ax.set_xlabel(x_label, fontsize=fontsize * 1.25) ax.set_ylabel("Cell Type-Specific Target", fontsize=fontsize * 1.25) ax.tick_params(axis="x", labelsize=fontsize, rotation=90) ax.tick_params(axis="y", labelsize=fontsize) ax.set_title(title, fontsize=fontsize * 1.5, pad=20) prefix = "heatmap" if save_df: df.to_csv( os.path.join( output_folder, f"{prefix}_{adata_id}_interaction_enrichment_fold_change_target_expressing_v_nonexpressing.csv", ) ) if save_show_or_return in ["save", "both", "all"]: save_kwargs["ext"] = "png" save_kwargs["dpi"] = 300 if "figure_folder" in locals(): save_kwargs["path"] = figure_folder # Save figure: save_return_show_fig_utils( save_show_or_return=save_show_or_return, show_legend=False, background="white", prefix=prefix, save_kwargs=save_kwargs, total_panels=1, fig=fig, axes=ax, return_all=False, return_all_list=None, )
def _split_components(feature: str, separators: Tuple[str, ...] = ("/", "_")) -> Tuple[List[str], Optional[str]]:
    """Split a multi-gene feature name into its component genes.

    The separators are tried in the given order and the first one present in ``feature`` is used; a feature that
    contains none of them is treated as a single gene.

    Args:
        feature: Feature name, e.g. "Col4a1", "Tgfb1/Tgfb2" or "Itga1_Itgb1".
        separators: Separators to look for, in priority order.

    Returns:
        genes: List of component gene names.
        separator: The separator used to split, or None if the feature is a single gene.
    """
    for sep in separators:
        if sep in feature:
            return feature.split(sep), sep
    return [feature], None


def _expression_matrix(adata_obj, genes: List[str]) -> np.ndarray:
    """Return expression of ``genes`` in ``adata_obj`` as a dense (n_obs, len(genes)) array (sparse or dense X)."""
    values = adata_obj[:, genes].X
    if scipy.sparse.issparse(values):
        values = values.toarray()
    return np.asarray(values).reshape(adata_obj.n_obs, len(genes))


def _expressed_mask(values: np.ndarray, separator: Optional[str]) -> np.ndarray:
    """Per-row expression mask for a (cells, genes) matrix.

    For "/"-joined features a cell counts as expressing if ANY component is > 0; otherwise ("_"-joined or single
    gene) ALL components must be > 0.
    """
    expressed = np.asarray(values) > 0
    if separator == "/":
        return expressed.any(axis=1)
    return expressed.all(axis=1)


def _aggregate_gene_values(values: np.ndarray, separator: Optional[str]) -> np.ndarray:
    """Collapse a (cells, genes) expression matrix into one value per cell.

    - "/": arithmetic mean of the components.
    - "_": geometric mean of the nonzero components; cells where every component is zero get 0.
    - None: the single gene's values.
    """
    values = np.asarray(values, dtype=float)
    if separator == "/":
        return values.mean(axis=1)
    if separator == "_":
        nonzero = values != 0
        counts = nonzero.sum(axis=1)
        # Zeros are skipped (replaced by the multiplicative identity) rather than zeroing the whole product:
        products = np.prod(np.where(nonzero, values, 1.0), axis=1)
        # np.maximum avoids 1/0; rows with no nonzero entries are overwritten with 0 below.
        exponents = 1.0 / np.maximum(counts, 1)
        return np.where(counts > 0, np.power(products, exponents), 0.0)
    return values[:, 0]


def _count_flagged_neighbors(conn, flagged: np.ndarray) -> np.ndarray:
    """For each row of the connectivity matrix, count the nonzero-connected columns whose ``flagged`` entry is True."""
    adjacency = scipy.sparse.csr_matrix(conn)
    adjacency = (adjacency != 0).astype(np.int64)
    return np.asarray(adjacency @ np.asarray(flagged, dtype=np.int64)).ravel()


def _first_positions(index: pd.Index, names) -> np.ndarray:
    """Integer position of the first occurrence of each of ``names`` in ``index`` (KeyError if a name is absent)."""
    lookup = {}
    for position, name in enumerate(index):
        lookup.setdefault(name, position)
    return np.array([lookup[name] for name in names], dtype=np.int64)


def visualize_neighborhood(
    self,
    target: str,
    interaction: str,
    interaction_type: Literal["secreted", "membrane-bound"],
    select_examples_criterion: Literal["positive", "negative"] = "positive",
    effect_threshold: Optional[float] = None,
    cell_type: Optional[str] = None,
    group_key: Optional[str] = None,
    use_significant: bool = False,
    n_anchors: int = 100,
    n_neighbors_expressing: int = 20,
    display_plot: bool = True,
) -> "anndata.AnnData":
    """Sets up AnnData object for visualization of interaction effects- cells will be colored by expression of the
    target gene, potentially conditioned on receptor expression, and neighboring cells will be colored by ligand
    expression.

    Args:
        target: Target gene of interest
        interaction: Interaction feature to visualize, given in the same form as in the design matrix (if model is a
            ligand-based model or receptor-based model, this will be of form "Col4a1". If model is a ligand-receptor
            based model, this will be of form "Col4a1:Itgb1", for example).
        interaction_type: Specifies whether the chosen interaction is secreted or membrane-bound. Options:
            "secreted" or "membrane-bound".
        select_examples_criterion: Whether to select cells with positive or negative interaction effects for
            visualization. Defaults to "positive", which searches for cells for which the absolute predicted
            interaction effect is above the given threshold, that express the receptor (L:R models) or the design
            matrix feature (ligand models), and that have more than "n_neighbors_expressing" ligand-expressing
            neighbors. "negative" selects cells for which the predicted interaction effect is exactly zero and that
            do not express the receptor (L:R models) or have a zero design-matrix value (ligand models).
        effect_threshold: Optional threshold for the absolute effect size of an interaction to be considered. If not
            given, uses the upper quartile of all nonzero coefficient values for the target.
        cell_type: Optional, can be used to select anchor cells from only a particular cell type. If None, will
            select from all cells.
        group_key: Can be used to specify entry in adata.obs that contains cell type groupings. If None, will use
            :attr `group_key` from model initialization. Only used if "cell_type" is not None.
        use_significant: Whether to use only significant effects. If True, coefficients are multiplied by the
            per-target significance table. Only valid if :func `compute_coeff_significance()` has been run.
        n_anchors: Number of target gene-expressing cells to use as anchors for visualization. Will be selected
            randomly from the set of eligible cells (all eligible cells if fewer remain).
        n_neighbors_expressing: Anchor cells (for "positive" examples) must have strictly more than this many
            neighbors expressing the chosen ligand.
        display_plot: Whether to save an interactive 3D plot (HTML) to the "figures" folder. If False, will return
            the AnnData object without doing anything else- this can then be visualized e.g. using spateo-viewer.

    Returns:
        adata: Copy of the AnnData object with an obs column "{interaction}_{target}_{criterion}_example_points"
            holding target expression for the anchor cells, ligand expression for their neighbors and 0 elsewhere.

    Raises:
        ValueError: If the model is not a ligand or L:R model, or if "select_examples_criterion" or
            "interaction_type" is invalid.
    """
    # Compute connectivity matrix if not already existing- only needed for ligand and L:R models:
    from ..find_neighbors import neighbors

    logger = lm.get_main_logger()

    if display_plot:
        figure_folder = os.path.join(os.path.dirname(self.output_path), "figures")
        os.makedirs(figure_folder, exist_ok=True)
        path = os.path.join(
            figure_folder, f"{target}_{select_examples_criterion}_cells_example_neighborhoods_{interaction}.html"
        )
        logger.info(f"Saving plot to {path}")

    if self.mod_type != "lr" and self.mod_type != "ligand":
        raise ValueError("This function is only applicable for ligand-based models.")
    if select_examples_criterion not in ["positive", "negative"]:
        raise ValueError("Invalid criterion for selecting examples. Options: 'positive', 'negative'.")
    if interaction_type not in ["secreted", "membrane-bound"]:
        raise ValueError("Invalid interaction type. Options: 'secreted', 'membrane-bound'.")
    positive = select_examples_criterion == "positive"

    # ---- Connectivity: saved spatial weights -> cached obsp graph -> freshly computed graph ----
    suffix = "secreted" if interaction_type == "secreted" else "membrane_bound"
    conn_key = f"spatial_connectivities_{suffix}"
    weights_path = os.path.join(
        os.path.splitext(self.output_path)[0], "spatial_weights", f"spatial_weights_{suffix}.npz"
    )
    try:
        conn = scipy.sparse.load_npz(weights_path) > 0
    except Exception as err:  # missing or unreadable weights file: fall back to the AnnData graph
        logger.info(f"Could not load spatial weights from {weights_path} ({err}); using the spatial graph instead.")
        if conn_key in self.adata.obsp.keys():
            conn = self.adata.obsp[conn_key]
        else:
            logger.info("Spatial graph not found, computing...")
            n_neighbors = (
                self.n_neighbors_secreted if interaction_type == "secreted" else self.n_neighbors_membrane_bound
            )
            _, adata_graph = neighbors(
                self.adata.copy(),
                n_neighbors=n_neighbors,
                basis="spatial",
                spatial_key=self.coords_key,
                n_neighbors_method="ball_tree",
            )
            conn = adata_graph.obsp["spatial_connectivities"]
            # Cache so later calls with the same interaction type can reuse it:
            self.adata.obsp[conn_key] = conn

    adata = self.adata.copy()
    if cell_type is not None and group_key is None:
        group_key = self.group_key

    coef_target = self.coeffs[target].loc[adata.obs_names]
    if effect_threshold is None:
        nonzero_values = coef_target.values.flatten()
        nonzero_values = nonzero_values[nonzero_values != 0]
        effect_threshold = pd.Series(nonzero_values).quantile(0.75)
    if use_significant:
        parent_dir = os.path.dirname(self.output_path)
        sig = pd.read_csv(os.path.join(parent_dir, "significance", f"{target}_is_significant.csv"), index_col=0)
        coef_target *= sig

    # Subset to the cells kept by the model (if any were removed); all later lookups use this subset:
    if hasattr(self, "remaining_cells"):
        adata = adata[self.remaining_cells, :].copy()
        conn = conn[self.remaining_indices, :][:, self.remaining_indices].copy()
    obs_names = adata.obs_names

    # ---- Candidate pools, intersected in the same order as before ----
    target_expressing_cells = obs_names[_expression_matrix(adata, [target])[:, 0] > 0]
    interaction_effect = coef_target.loc[obs_names, f"b_{interaction}"].values
    if positive:
        interaction_mask = np.abs(interaction_effect) > effect_threshold
    else:
        interaction_mask = interaction_effect == 0
    candidate_pools = [obs_names[interaction_mask]]

    # Ligand components are kept in dedicated variables so receptor parsing cannot overwrite them:
    ligand = interaction.split(":")[0] if ":" in interaction else interaction
    ligand_genes, ligand_separator = _split_components(ligand)

    if positive:
        # Cells with a sufficient number of ligand-expressing neighbors:
        ligand_expr_mask = _expressed_mask(_expression_matrix(adata, ligand_genes), ligand_separator)
        neighbor_counts = _count_flagged_neighbors(conn, ligand_expr_mask)
        candidate_pools.append(obs_names[neighbor_counts > n_neighbors_expressing])

    if self.mod_type == "lr":
        receptor = interaction.split(":")[1] if ":" in interaction else interaction
        receptor_genes, receptor_separator = _split_components(receptor, separators=("_",))
        receptor_mask = _expressed_mask(_expression_matrix(adata, receptor_genes), receptor_separator)
        # Negative examples are cells without receptor expression (matches the plot legend):
        factor_mask = receptor_mask if positive else ~receptor_mask
    else:
        # True negative examples will express the target, but not be predicted to be affected by the interaction
        # and have no value for the interaction feature in the design matrix:
        X_df = pd.read_csv(
            os.path.join(os.path.splitext(self.output_path)[0], "design_matrix", "design_matrix.csv"), index_col=0
        )
        design_values = X_df.loc[obs_names, interaction].to_numpy()
        factor_mask = design_values > 0 if positive else design_values == 0
    candidate_pools.append(obs_names[factor_mask])

    # Index.intersection (not `&`, which is an element-wise logical op in pandas >= 2.0):
    eligible_cells = target_expressing_cells
    for pool in candidate_pools:
        eligible_cells = eligible_cells.intersection(pool)

    adata_sub = adata[eligible_cells].copy()
    if cell_type is not None:
        adata_sub = adata_sub[adata_sub.obs[group_key] == cell_type].copy()

    logger.info(
        f"Randomly selecting {select_examples_criterion} example cells from a pool of {adata_sub.n_obs} for target"
        f" {target} and interaction {interaction}."
    )
    if adata_sub.n_obs < n_anchors:
        logger.info(
            f"Given the constraints, not enough cells remain to choose {n_anchors} cells. Selecting all "
            f"{adata_sub.n_obs} eligible cells instead."
        )
    n_anchors = min(n_anchors, adata_sub.n_obs)

    example_key = f"{interaction}_{target}_{select_examples_criterion}_example_points"
    adata.obs[example_key] = 0.0
    if adata_sub.n_obs == 0:
        logger.warning(
            f"No eligible {select_examples_criterion} example cells for target {target} and interaction "
            f"{interaction}; nothing to visualize."
        )
        return adata

    # ---- Anchors and their neighbors ----
    if n_anchors == adata_sub.n_obs:
        target_expressing_selected = adata_sub.obs_names
    else:
        target_expressing_selected = np.random.choice(adata_sub.obs_names, size=n_anchors, replace=False)
    selected_indices = _first_positions(obs_names, target_expressing_selected)

    neighbor_indices = np.unique(conn[selected_indices].nonzero()[1])
    # Remove the anchor cells from the neighbors:
    neighbor_indices = neighbor_indices[~np.isin(neighbor_indices, selected_indices)]
    neighbors_selected = obs_names[neighbor_indices]

    # Remaining cells of the requested type (computed on the same, possibly subset, AnnData as the indices):
    if cell_type is not None:
        excluded = set(selected_indices.tolist()) | set(neighbor_indices.tolist())
        ct_positions = np.flatnonzero((adata.obs[group_key] == cell_type).to_numpy())
        ct_other_indices = [pos for pos in ct_positions.tolist() if pos not in excluded]

    target_expression = _expression_matrix(adata[selected_indices], [target])[:, 0]
    ligand_expression = _aggregate_gene_values(
        _expression_matrix(adata[neighbor_indices], ligand_genes), ligand_separator
    )

    adata.obs.loc[target_expressing_selected, example_key] = target_expression
    adata.obs.loc[neighbors_selected, example_key] = ligand_expression

    if display_plot:
        # plotly to create 3D scatter plot:
        spatial_coords = np.asarray(adata.obsm[self.coords_key])
        if spatial_coords.shape[1] == 2:
            x, y = spatial_coords[:, 0], spatial_coords[:, 1]
            z = np.zeros(len(x))
        else:
            x, y, z = spatial_coords[:, 0], spatial_coords[:, 1], spatial_coords[:, 2]

        # Color assignment:
        default_color = "#D3D3D3"
        if cell_type is not None:
            ct_other_color = "#71797E"
        target_color = "#39FF14"

        scatter_target = go.Scatter3d(
            x=x[selected_indices],
            y=y[selected_indices],
            z=z[selected_indices],
            mode="markers",
            # Draw target cells larger
            marker=dict(color=target_color, size=6.5),
            showlegend=False,
        )

        # Lenient w/ the max value cutoff so that the colored dots are more distinct from black background
        plot_vals = adata.obs.loc[neighbors_selected, example_key].to_numpy(dtype=float)
        if plot_vals.size > 0:
            p95 = np.percentile(plot_vals, 95)
            plot_vals = np.where(plot_vals > p95, p95, plot_vals)
        scatter_ligand = go.Scatter3d(
            x=x[neighbor_indices],
            y=y[neighbor_indices],
            z=z[neighbor_indices],
            mode="markers",
            marker=dict(
                color=plot_vals,
                colorscale="Hot",
                size=2.5,
                # `title.font` replaces the deprecated `titlefont` colorbar property:
                colorbar=dict(
                    title=dict(text=f"{ligand} Expression", font=dict(size=16)),
                    x=0.8,
                    tickfont=dict(size=18),
                ),
            ),
            showlegend=False,
        )

        rest_indices = np.setdiff1d(np.arange(len(x)), np.concatenate([selected_indices, neighbor_indices]))
        scatter_rest = go.Scatter3d(
            x=x[rest_indices],
            y=y[rest_indices],
            z=z[rest_indices],
            mode="markers",
            marker=dict(color=default_color, size=2),
            name="Other Cells",
            showlegend=False,
        )

        if cell_type is not None:
            scatter_ct = go.Scatter3d(
                x=x[ct_other_indices],
                y=y[ct_other_indices],
                z=z[ct_other_indices],
                mode="markers",
                marker=dict(color=ct_other_color, size=2),
                name=f"Other Cells of Type {cell_type}",
                showlegend=False,
            )
            # Invisible traces for the legend
            legend_ct = go.Scatter3d(
                x=[None],
                y=[None],
                z=[None],
                mode="markers",
                marker=dict(size=10, color=ct_other_color),  # Adjust size as needed
                name=f"Other Cells of Type <br>{cell_type}",
                showlegend=True,
            )

        # Invisible traces for the legend
        name = (
            f"{target}-Expressing Cells <br>(w/ Receptor Expression)"
            if select_examples_criterion == "positive"
            else f"{target}-Expressing Cells <br>(w/o Receptor Expression)"
        )
        legend_target = go.Scatter3d(
            x=[None],
            y=[None],
            z=[None],
            mode="markers",
            marker=dict(size=30, color=target_color),  # Adjust size as needed
            name=name,
            showlegend=True,
        )
        legend_rest = go.Scatter3d(
            x=[None],
            y=[None],
            z=[None],
            mode="markers",
            marker=dict(size=15, color=default_color),  # Adjust size as needed
            name="Other Cells",
            showlegend=True,
        )

        # Create the figure and add the scatter plots
        if cell_type is not None:
            fig = go.Figure(
                data=[
                    scatter_rest,
                    scatter_target,
                    scatter_ligand,
                    scatter_ct,
                    legend_target,
                    legend_rest,
                    legend_ct,
                ]
            )
            title_text = (
                f"Target: {target}, Ligand: {ligand}, <br>Cell Type: {cell_type} "
                f"(Example {select_examples_criterion.title()} Predicted Effects)"
            )
        else:
            fig = go.Figure(data=[scatter_rest, scatter_target, scatter_ligand, legend_target, legend_rest])
            title_text = (
                f"Target: {target}, Ligand: {ligand} "
                f"<br>(Example {select_examples_criterion.title()} Predicted Effects)"
            )
        title_dict = dict(text=title_text, y=0.9, yanchor="top", x=0.5, xanchor="center", font=dict(size=28))

        # Turn off the grid / axes on all three dimensions:
        hidden_axis = dict(
            showgrid=False,
            showline=False,
            linewidth=2,
            linecolor="black",
            backgroundcolor="white",
            title="",
            showticklabels=False,
            ticks="",
        )
        fig.update_layout(
            showlegend=True,
            legend=dict(x=0.65, y=0.85, orientation="v", font=dict(size=18)),
            scene=dict(aspectmode="data", xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis),
            margin=dict(l=0, r=0, b=0, t=50),  # Adjust margins to minimize spacing
            title=title_dict,
        )
        fig.write_html(path)

    return adata
def cell_type_specific_interactions(
    self,
    to_plot: Literal["mean", "percentage"] = "mean",
    plot_type: Literal["heatmap", "barplot"] = "heatmap",
    group_key: Optional[str] = None,
    ct_subset: Optional[List[str]] = None,
    target_subset: Optional[List[str]] = None,
    interaction_subset: Optional[List[str]] = None,
    lower_threshold: float = 0.3,
    upper_threshold: float = 1.0,
    effect_threshold: Optional[float] = None,
    use_significant: bool = False,
    row_normalize: bool = False,
    col_normalize: bool = False,
    normalize_targets: bool = False,
    hierarchical_cluster_ct: bool = False,
    group_y_cell_type: bool = False,
    fontsize: Union[None, int] = None,
    figsize: Union[None, Tuple[float, float]] = None,
    center: Optional[float] = None,
    cmap: str = "Reds",
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = None,
    save_df: bool = False,
):
    """Map interactions and interaction effects that are specific to particular cell type groupings.

    For every (cell type, target) combination, the fitted coefficients of each interaction feature are summarized
    over the cells of that type that express the target. Coefficients below an effect threshold are zeroed first.
    The result is a dataframe with rows labeled "{cell type}-{target}" and one column per interaction. It is plotted
    either as a heatmap or as one barplot per interaction.

    Args:
        to_plot: "mean" plots the mean (thresholded) coefficient among target-expressing cells of the cell type.
            "percentage" plots the fraction of those cells whose coefficient exceeds the effect threshold.
        plot_type: "heatmap" or "barplot". A barplot requires ``interaction_subset`` with at most four interactions.
        group_key: Key in adata.obs containing cell type groupings. If None, uses :attr `group_key` from model
            initialization.
        ct_subset: Optional cell types to restrict the analysis to (looked up in ``group_key``). A single string is
            accepted. Recommended for restricting to cell types with sufficient numbers.
        target_subset: Targets to consider (a single string is accepted). If None, uses all targets with fitted
            coefficients.
        interaction_subset: Interactions to consider (a single string is accepted). If None, uses all model features
            except the intercept. Required when ``plot_type`` is "barplot", which accommodates up to four interactions.
        lower_threshold: Within the rows of each target, values below ``lower_threshold`` times that interaction's
            maximum over those rows are set to 0 (not colored). Applied before row/column normalization.
            Defaults to 0.3.
        upper_threshold: If not 1.0, every value >= ``upper_threshold`` times the global maximum is set to the
            global maximum. Defaults to 1.0 (no change).
        effect_threshold: Coefficients below this value are set to 0. If None, the upper quartile of the nonzero
            coefficients of each (cell type, target) is used.
        use_significant: If True, coefficients are multiplied by the table read from
            "<output dir>/significance/{target}_is_significant.csv". That file is written by
            :func `compute_coeff_significance()`. Applies to both "mean" and "percentage".
        row_normalize: Min-max scale each row (each cell type-target combination).
        col_normalize: Min-max scale each column (each interaction). Ignored if ``row_normalize`` is True.
        normalize_targets: Within the rows of each target, divide each interaction's values by that interaction's
            maximum over those rows. This removes scale differences between targets.
        hierarchical_cluster_ct: For the heatmap, order the rows (cell type-target combinations) by hierarchical
            clustering. If False, rows are sorted by target (or by cell type, see ``group_y_cell_type``).
        group_y_cell_type: Sort and color-group rows by cell type first. If False, rows are grouped by target.
            Defaults to False.
        fontsize: Size of font for x and y labels. Defaults to rcParams["font.size"].
        figsize: Size of figure. Derived from the number of rows/interactions if not given.
        center: Optional, determines position of the colormap center (heatmap only).
        cmap: Colormap for the heatmap/bars. Must be one of the supported sequential colormaps; any other name falls
            back to "viridis".
        save_show_or_return: Whether to save, show or return the figure. "both" saves and shows it. "all" saves and
            shows it and also returns the associated axes and other objects.
        save_kwargs: Keyword arguments for the save_fig function. If None or empty, save_fig uses its own defaults.
            "ext" ("png") and "dpi" (300) are always set, and "path" is set when saving. The caller's dict is copied,
            never modified.
        save_df: Set True to save the final metric dataframe as a CSV in "<output dir>/analyses".

    Returns:
        Whatever :func `save_return_show_fig_utils` returns for the chosen ``save_show_or_return`` mode. That is
        presumably the figure/axes for "return"/"all" (NOTE(review): confirm against the helper).

    Raises:
        ValueError: On invalid ``to_plot`` or model type. For barplots, also when ``interaction_subset`` is missing,
            has more than four entries, or none of its interactions has any nonzero value left.
    """
    logger = lm.get_main_logger()
    config_spateo_rcParams()
    # But set display DPI to 300:
    plt.rcParams["figure.dpi"] = 300

    # Keys are written into this dict below; copy it so neither the caller's dict nor a shared default is mutated.
    save_kwargs = {} if save_kwargs is None else dict(save_kwargs)

    # Normalize single-string inputs before validation, so the barplot length check counts entries, not characters.
    if isinstance(ct_subset, str):
        ct_subset = [ct_subset]
    if isinstance(target_subset, str):
        target_subset = [target_subset]
    if isinstance(interaction_subset, str):
        interaction_subset = [interaction_subset]

    if to_plot not in ["mean", "percentage"]:
        raise ValueError("Unrecognized input for plotting. Options are 'mean' or 'percentage'.")
    if plot_type == "barplot" and interaction_subset is None:
        raise ValueError("Must provide a subset of interactions to visualize if 'plot_type' is 'barplot'.")
    if plot_type == "barplot" and len(interaction_subset) > 4:
        raise ValueError(
            "Can only visualize up to four interactions at once with 'barplot' (for practical/plot "
            "readability reasons)."
        )
    if self.mod_type not in ["lr", "ligand", "receptor"]:
        raise ValueError("Model type must be one of 'lr', 'ligand', or 'receptor'.")

    parent_dir = os.path.dirname(self.output_path)
    figure_folder = None
    if save_show_or_return in ["save", "both", "all"]:
        figure_folder = os.path.join(parent_dir, "figures", "temp")
        os.makedirs(figure_folder, exist_ok=True)
    output_folder = None
    if save_df:
        output_folder = os.path.join(parent_dir, "analyses")
        os.makedirs(output_folder, exist_ok=True)

    # Colormap should be sequential:
    sequential_colormaps = [
        "Blues", "BuGn", "BuPu", "GnBu", "Greens", "Greys", "Oranges", "OrRd", "PuBu", "PuBuGn", "PuRd", "Purples",
        "RdPu", "Reds", "YlGn", "YlGnBu", "YlOrBr", "YlOrRd", "afmhot", "autumn", "bone", "cool", "copper",
        "gist_heat", "gray", "hot", "pink", "spring", "summer", "winter", "viridis", "plasma", "inferno", "magma",
        "cividis",
    ]
    if cmap not in sequential_colormaps:
        logger.info(f"For option {to_plot}, colormap should be sequential: using 'viridis'.")
        cmap = "viridis"

    if group_key is None:
        group_key = self.group_key
    if fontsize is None:
        fontsize = rcParams.get("font.size")

    # Nothing below writes to ``adata``, so the full object is used directly instead of being copied.
    if ct_subset is None:
        adata = self.adata
    else:
        adata = self.adata[self.adata.obs[group_key].isin(ct_subset)].copy()
    cell_types = adata.obs[group_key].unique()
    all_feature_names = [feat for feat in self.feature_names if feat != "intercept"]
    feature_names = all_feature_names if interaction_subset is None else interaction_subset
    targets = list(self.coeffs.keys()) if target_subset is None else target_subset

    # Row labels are "{cell type}-{target}". Keep the original (cell type, target) pair for every label, so nothing
    # below re-parses a label by splitting on "-". Splitting breaks for names such as "HLA-A" or "CD4-T" and loses
    # non-string cell type labels.
    combo_pairs = list(product(cell_types, targets))
    combinations = [f"{ct}-{target}" for ct, target in combo_pairs]
    label_to_pair = dict(zip(combinations, combo_pairs))

    if figsize is None:
        if plot_type == "heatmap":
            # Set figure size based on the number of interaction features and cell type-target combos:
            m = len(combinations) * 50 / 200
            n = len(feature_names) * 50 / 200
        else:
            # Set figure size based on the number of cell type-target combos:
            n = len(combinations) * 50 / 200
            m = 3 * len(feature_names)
        figsize = (n, m)

    def _expressing_cells(target_name):
        """Boolean Series over ``adata.obs_names``: True where ``target_name`` has nonzero expression."""
        values = adata[:, target_name].X
        values = values.toarray() if scipy.sparse.issparse(values) else np.asarray(values)
        return pd.Series(values.reshape(-1) > 0, index=adata.obs_names)

    def _sort_rows(frame):
        """Order rows by (cell type, target) if ``group_y_cell_type``, else by (target, cell type), as strings."""
        sort_keys = []
        for row_label in frame.index:
            ct_label, target_label = (str(part) for part in label_to_pair[row_label])
            sort_keys.append((ct_label, target_label) if group_y_cell_type else (target_label, ct_label))
        order = sorted(range(len(sort_keys)), key=sort_keys.__getitem__)
        return frame.iloc[order]

    def _group_spans(labels):
        """Return [label, first_position, last_position] for each run of consecutive equal labels."""
        spans = []
        for position, lab in enumerate(labels):
            if spans and spans[-1][0] == lab:
                spans[-1][2] = position
            else:
                spans.append([lab, position, position])
        return spans

    # Target expression does not depend on the cell type, so compute it once per target.
    expressing = {target: _expressing_cells(target) for target in targets}
    significance = {}
    df = pd.DataFrame(0.0, index=combinations, columns=feature_names)
    for ct in cell_types:
        cell_type_mask = adata.obs[group_key] == ct
        ct_obs_names = adata.obs_names[cell_type_mask.values]
        for target in targets:
            total_mask = cell_type_mask & expressing[target]
            total_mask = total_mask[total_mask].index

            # Coefficients of this cell type only, since the effect is cell type-specific:
            coef_all = self.coeffs[target]
            coef_columns = [c for c in coef_all.columns if "intercept" not in c]
            coef_target = coef_all.loc[ct_obs_names, coef_columns].copy()
            if effect_threshold is None:
                # Cell type-specific threshold: upper quartile of the nonzero coefficients.
                nonzero_values = coef_target.values.flatten()
                nonzero_values = nonzero_values[nonzero_values != 0]
                target_effect_threshold = pd.Series(nonzero_values).quantile(0.75)
            else:
                target_effect_threshold = effect_threshold
            coef_target[coef_target < target_effect_threshold] = 0
            if use_significant:
                if target not in significance:
                    significance[target] = pd.read_csv(
                        os.path.join(parent_dir, "significance", f"{target}_is_significant.csv"), index_col=0
                    )
                coef_target *= significance[target]

            # If a given cell type does not have much expression of the target gene, mask out the effect (use an
            # arbitrary cutoff of 2% of cells):
            if len(total_mask) < 0.02 * len(ct_obs_names):
                values = [0.0] * len(feature_names)
            else:
                values = []
                for feat in feature_names:
                    column = f"b_{feat}"
                    if column not in coef_target.columns:
                        values.append(0.0)
                        continue
                    feat_values = coef_target.loc[total_mask, column].values
                    if to_plot == "mean":
                        # Mean effect size among the cells that express the target gene:
                        values.append(feat_values.mean())
                    else:
                        # Fraction of target-expressing cells whose effect exceeds the threshold:
                        values.append((feat_values > target_effect_threshold).mean())
            df.loc[f"{ct}-{target}", :] = values

    # Apply the lower threshold separately within the rows of each target (relative to that target's
    # per-interaction max), optionally normalizing by that max as well:
    row_targets = [label_to_pair[row_label][1] for row_label in df.index]
    for target in dict.fromkeys(row_targets):
        rows = np.array([row_target == target for row_target in row_targets])
        block = df.loc[rows]
        block_max = block.max()
        block = block.where(block.ge(lower_threshold * block_max), 0)
        if normalize_targets:
            # Take 0 to be the min. value in all cases:
            block = block / block_max
        df.loc[rows] = block.to_numpy()

    if upper_threshold != 1.0:
        overall_max = df.max().max()
        df[df >= upper_threshold * overall_max] = overall_max

    # Optionally, normalize each row by minmax scaling (to get an idea of the top effects for each target),
    # or each column by minmax scaling:
    is_normalized = row_normalize or col_normalize or normalize_targets
    if row_normalize:
        row_min = df.min(axis=1).values.reshape(-1, 1)
        row_max = df.max(axis=1).values.reshape(-1, 1)
        df = (df - row_min) / (row_max - row_min)
    elif col_normalize:
        df = (df - df.min()) / (df.max() - df.min())
    # Constant rows/columns produce 0/0 above:
    df.fillna(0, inplace=True)

    if plot_type == "heatmap":
        # Hierarchical clustering- first to group interactions w/ similar patterns across cell types (ward linkage
        # needs at least two observations):
        if df.shape[1] > 1:
            col_linkage = sch.linkage(df.transpose(), method="ward")
            df = df.iloc[:, sch.dendrogram(col_linkage, no_plot=True)["leaves"]]
        # Then to group cell types w/ similar interaction patterns, if specified:
        if hierarchical_cluster_ct:
            if df.shape[0] > 1:
                row_linkage = sch.linkage(df, method="ward")
                df = df.iloc[sch.dendrogram(row_linkage, no_plot=True)["leaves"], :]
        else:
            df = _sort_rows(df)
    else:
        df = _sort_rows(df)

    # Delete all-zero columns (interactions with no retained value in any row):
    df = df.loc[:, ~(df == 0).all()]
    logger.info(f"Final dataframe shape: {df.shape}")

    if is_normalized and to_plot == "mean":
        if plot_type == "heatmap":
            label = (
                "Normalized avg. effect per cell type for cells expressing target"
                if not normalize_targets
                else "Normalized avg. effect per cell type \nfor cells expressing target (normalized within target)"
            )
        else:
            label = (
                "Normalized avg. effect\n per cell type \nfor cells expressing target"
                if not normalize_targets
                else "Normalized avg. effect\n per cell type \nfor cells expressing target \n(normalized within "
                "target)"
            )
    elif is_normalized and to_plot == "percentage":
        if plot_type == "heatmap":
            label = (
                "Normalized enrichment of effect per cell type \nfor cells expressing target"
                if not normalize_targets
                else "Normalized enrichment of effect per cell type \nfor cells expressing target (normalized "
                "within target)"
            )
        else:
            label = (
                "Normalized enrichment of\n effect per cell type \nfor cells expressing target"
                if not normalize_targets
                else "Normalized enrichment \nof effect per cell type\n for cells expressing target\n(normalized "
                "within target)"
            )
    elif not is_normalized and to_plot == "mean":
        label = (
            "Avg. effect per cell type \nfor cells expressing target"
            if plot_type == "heatmap"
            else "Avg. effect\n per cell type \nfor cells expressing target"
        )
    else:
        label = (
            "Enrichment of effect per cell type \nfor cells expressing target"
            if plot_type == "heatmap"
            else "Enrichment of effect\n per cell type \nfor cells expressing target"
        )

    if self.mod_type == "lr":
        x_label = "Interaction"
        title = "Enrichment of L:R interaction in each cell type"
    elif self.mod_type == "ligand":
        x_label = "Neighboring ligand expression"
        title = "Enrichment of neighboring ligand expression in each cell type for each target"
    else:
        x_label = "Receptor expression"
        title = "Enrichment of receptor expression in each cell type"

    # Formatting color legend (group labels come from the stored pairs, so they match the color_mapping keys):
    row_pairs = [label_to_pair[row_label] for row_label in df.index]
    group_labels = [pair[0] if group_y_cell_type else pair[1] for pair in row_pairs]
    target_colors = mpl.colormaps["tab20"].colors
    legend_groups = dict.fromkeys(cell_types if group_y_cell_type else targets)
    color_mapping = {
        annotation: target_colors[i % len(target_colors)] for i, annotation in enumerate(legend_groups)
    }
    max_annotation_length = max((len(str(annotation)) for annotation in color_mapping), default=0)
    if max_annotation_length > 30:
        ax2_size = "30%"
    elif max_annotation_length > 20:
        ax2_size = "20%"
    else:
        ax2_size = "10%"

    n_rows = len(df.index)
    if plot_type == "heatmap":
        vmin = 0
        vmax = 1 if is_normalized else df.max().max()
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
        # The save/show helper below expects ``axes``:
        axes = ax
        divider = make_axes_locatable(ax)
        ax2 = divider.append_axes("right", size=ax2_size, pad=0)
        # Color legend. seaborn draws heatmap row 0 at the top, while ax2 keeps a bottom-up y-axis, so row i is
        # drawn at y = n_rows - i - 1:
        for annotation, start, end in _group_spans(group_labels):
            ax2.add_patch(
                plt.Rectangle((0, n_rows - end - 1), 0.2, end - start + 1, color=color_mapping[annotation])
            )
            group_center = n_rows - (start + end) / 2 - 0.5
            ax2.text(0.22, group_center, str(annotation), va="center", ha="left", fontsize=fontsize)
        ax2.set_ylim(0, n_rows)
        ax2.axis("off")

        thickness = 0.3 * figsize[0] / 10
        mask = df == 0
        hm = sns.heatmap(
            df,
            square=True,
            linecolor="grey",
            linewidths=thickness,
            cbar_kws={"label": label, "location": "top", "pad": 0},
            cmap=cmap,
            center=center,
            vmin=vmin,
            vmax=vmax,
            mask=mask,
            ax=ax,
        )
        # Outer frame:
        for _, spine in hm.spines.items():
            spine.set_visible(True)
            spine.set_linewidth(thickness * 2.5)

        # Adjust colorbar settings: append axes to the top of the plot, where the colorbar will be placed
        divider = make_axes_locatable(ax)
        if df.shape[0] > df.shape[1]:
            cax = divider.append_axes("top", size="30%", pad=0)
        else:
            cax = divider.append_axes("top", size="30%", pad="30%")
        cbar = plt.colorbar(hm.collections[0], cax=cax, orientation="horizontal")
        cbar.set_label(to_plot.title(), fontsize=fontsize * 1.5, labelpad=10)
        cbar.ax.xaxis.set_ticks_position("top")  # Move ticks to the top
        cbar.ax.xaxis.set_label_position("top")  # Move the label (title) to the top
        cbar.ax.tick_params(labelsize=fontsize * 1.5)
        cbar.ax.set_aspect(0.02)

        ax.set_xlabel(x_label, fontsize=fontsize * 1.25)
        ax.set_ylabel("Cell Type-Specific Target", fontsize=fontsize * 1.25)
        ax.tick_params(axis="x", labelsize=fontsize, rotation=90)
        ax.tick_params(axis="y", labelsize=fontsize)
        ax.set_title(title, fontsize=fontsize * 1.5, pad=20)
    else:
        # Only interactions that survived the all-zero column removal can be plotted:
        rem_interactions = [i for i in interaction_subset if i in df.columns]
        if not rem_interactions:
            raise ValueError(
                "None of the interactions in 'interaction_subset' have nonzero values after thresholding; "
                "nothing to plot."
            )
        fig, axes = plt.subplots(nrows=len(rem_interactions), ncols=1, figsize=figsize)
        fig.subplots_adjust(hspace=0.4)
        colormap = mpl.colormaps[cmap]

        # Order the chosen interaction columns by their average within-group rank. This affects the column order of
        # the (optionally saved) dataframe; subplots follow the order of ``interaction_subset``.
        ranks = df[rem_interactions].assign(Group=group_labels).groupby("Group").rank(ascending=False)
        df = df[ranks.mean().sort_values().index.tolist()]

        # plt.subplots returns a single Axes for one row and an array otherwise:
        axes_list = list(axes) if isinstance(axes, np.ndarray) else [axes]

        # Color legend above the first subplot (bar i is centered at i + 0.5 on ax2's 0..n_rows axis):
        divider = make_axes_locatable(axes_list[0])
        ax2 = divider.append_axes("top", size=ax2_size, pad=0)
        for annotation, start, end in _group_spans(group_labels):
            ax2.add_patch(plt.Rectangle((start, 0), end - start + 1, 0.4, color=color_mapping[annotation]))
            group_center = (start + end) / 2 + 0.5
            ax2.text(group_center, 0.42, str(annotation), va="bottom", ha="center", fontsize=fontsize)
        ax2.set_xlim(0, n_rows)
        ax2.axis("off")

        for i, ax in enumerate(axes_list):
            interaction = rem_interactions[i]
            interaction_series = df[interaction]
            vmin = 0
            vmax = 1 if is_normalized else interaction_series.max()
            norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
            colors = [colormap(norm(val)) for val in interaction_series]
            sns.barplot(
                x=interaction_series.index,
                y=interaction_series.values,
                edgecolor="black",
                linewidth=1,
                palette=colors,
                ax=ax,
            )
            # The first subplot leaves room for the color legend above it:
            ax.set_title(interaction, fontsize=fontsize * 1.5, pad=35 if i == 0 else 10)
            ax.set_xlabel("Cell Type-Specific Target", fontsize=fontsize)
            ax.set_ylabel(label, fontsize=fontsize)
            ax.tick_params(axis="y", labelsize=fontsize * 1.1)
            if i == len(axes_list) - 1:
                ax.tick_params(axis="x", labelsize=fontsize * 0.9, rotation=90)
            else:
                ax.tick_params(axis="x", labelbottom=False)

    # Use the saved name for the AnnData object to define part of the name of the saved file:
    base_name = os.path.basename(self.adata_path)
    adata_id = os.path.splitext(base_name)[0]
    prefix = f"{adata_id}_{to_plot}_enrichment_cell_type"

    # Save figure:
    save_kwargs["ext"] = "png"
    save_kwargs["dpi"] = 300
    if figure_folder is not None:
        save_kwargs["path"] = figure_folder
    result = save_return_show_fig_utils(
        save_show_or_return=save_show_or_return,
        show_legend=False,
        background="white",
        prefix=prefix,
        save_kwargs=save_kwargs,
        total_panels=1,
        fig=fig,
        axes=axes,
        return_all=False,
        return_all_list=None,
    )

    if save_df:
        df.to_csv(os.path.join(output_folder, f"{prefix}.csv"))
    return result
[docs] def cell_type_interaction_fold_change( self, ref_ct: str, query_ct: str, group_key: Optional[str] = None, target_subset: Optional[List[str]] = None, interaction_subset: Optional[List[str]] = None, to_plot: Literal["mean", "percentage"] = "mean", plot_type: Literal["volcano", "barplot"] = "barplot", source_data: Literal["interaction", "effect", "target"] = "effect", top_n_to_plot: Optional[int] = None, significance_cutoff: float = 1.3, fold_change_cutoff: float = 1.5, fold_change_cutoff_for_labels: float = 3.0, plot_query_over_ref: bool = False, plot_ref_over_query: bool = False, plot_only_significant: bool = False, fontsize: Union[None, int] = None, figsize: Union[None, Tuple[float, float]] = None, cmap: str = "seismic", save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", save_kwargs: Optional[dict] = {}, save_df: bool = False, ): """Computes fold change in predicted interaction effects between two cell types, and visualizes result. Args: ref_ct: Label of the first cell type to consider. Fold change will be computed with respect to the level in this cell type. query_ct: Label of the second cell type to consider group_key: Name of the key in .obs containing cell type information. If not given, will use :attr `group_key` from model initialization. target_subset: List of targets to consider. If None, will use all targets used in model fitting. interaction_subset: List of interactions to consider. If None, will use all interactions used in model. to_plot: Whether to plot the mean effect size or the proportion of cells in a cell type w/ effect on target. Options are "mean" or "percentage". plot_type: Whether to plot the results as a volcano plot or barplot. Options are "volcano" or "barplot". source_data: Selects what to use in computing fold changes. Options: - "interaction": will use the design matrix (e.g. 
neighboring ligand expression or L:R mapping) - "effect": will use the coefficient arrays for each target - "target": will use the target gene expression top_n_to_plot: If given, will only include the top n features in the visualization. Recommended if "source_data" is "effect", as all combinations of interaction and target will be considered in this case. significance_cutoff: Cutoff for negative log-10 q-value to consider an interaction/effect significant. Only used if "plot_type" is "volcano". Defaults to 1.3 (corresponding to an approximate q-value of 0.05). fold_change_cutoff: Cutoff for fold change to consider an interaction/effect significant. Only used if "plot_type" is "volcano". Defaults to 1.5. fold_change_cutoff_for_labels: Cutoff for fold change to include the label for an interaction/effect. Only used if "plot_type" is "volcano". Defaults to 3.0. plot_query_over_ref: Whether to plot/visualize only the portion that corresponds to the fold change of the query cell type over the reference cell type (and the portion that is significant). If False (and "plot_ref_over_query" is False), will plot the entire volcano plot. Only used if "plot_type" is "volcano". plot_ref_over_query: Whether to plot/visualize only the portion that corresponds to the fold change of the reference cell type over the query cell type (and the portion that is significant). If False (and "plot_query_over_ref" is False), will plot the entire volcano plot. Only used if "plot_type" is "volcano". plot_only_significant: Whether to plot/visualize only the portion that passes the "significance_cutoff" p-value threshold. Only used if "plot_type" is "volcano". fontsize: Size of font for x and y labels. figsize: Size of figure. cmap: Colormap to use for heatmap. If metric is "number", "proportion", "specificity", the bottom end of the range is 0. It is recommended to use a sequential colormap (e.g. "Reds", "Blues", "Viridis", etc.). 
For metric = "fc", if a divergent colormap is not provided, "seismic" will automatically be used. save_show_or_return: Whether to save, show or return the figure. If "both", it will save and plot the figure at the same time. If "all", the figure will be saved, displayed and the associated axis and other object will be return. save_kwargs: A dictionary that will passed to the save_fig function. By default it is an empty dictionary and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those keys according to your needs. save_df: Set True to save the metric dataframe in the end """ config_spateo_rcParams() # But set display DPI to 300: plt.rcParams["figure.dpi"] = 300 parent_dir = os.path.dirname(self.output_path) if save_show_or_return in ["save", "both", "all"]: if not os.path.exists(os.path.join(os.path.dirname(self.output_path), "figures")): os.makedirs(os.path.join(os.path.dirname(self.output_path), "figures")) figure_folder = os.path.join(os.path.dirname(self.output_path), "figures", "temp") if not os.path.exists(figure_folder): os.makedirs(figure_folder) # Use the saved name for the AnnData object to define part of the name of the saved figure & file (if # applicable): base_name = os.path.basename(self.adata_path) adata_id = os.path.splitext(base_name)[0] prefix = f"{adata_id}_fold_changes_{source_data}_{ref_ct}_{query_ct}" output_folder = os.path.join(os.path.dirname(self.output_path), "analyses") if not os.path.exists(output_folder): os.makedirs(output_folder) if fontsize is None: fontsize = rcParams.get("font.size") if group_key is None: group_key = self.group_key if target_subset is None: target_subset = self.targets_expr.columns if interaction_subset is None: interaction_subset = self.feature_names # Formatting: if source_data == "effect": x_label = f"$\\log_2$(Fold change 
effect on target- \n{ref_ct} and {query_ct})" title = f"Fold change effect on target \n{ref_ct} and {query_ct}" if self.mod_type == "lr": y_label = "L:R effect on target" elif self.mod_type == "ligand": y_label = "Ligand effect on target" elif source_data == "interaction": x_label = f"$\\log_2$(Fold change interaction enrichment \n {ref_ct} and {query_ct})" title = f"Fold change interaction enrichment \n{ref_ct} and {query_ct}" if self.mod_type == "lr": y_label = "L:R interaction" elif self.mod_type == "ligand": y_label = "Ligand" elif source_data == "target": x_label = f"$\\log_2$(Fold change target expression \n {ref_ct} and {query_ct})" title = f"Fold change target expression \n {ref_ct} and {query_ct}" y_label = "Target" # Check for already-existing dataframe: if os.path.exists(os.path.join(parent_dir, output_folder, f"{prefix}.csv")): results = pd.read_csv(os.path.join(parent_dir, output_folder, f"{prefix}.csv"), index_col=0) else: ref_names = self.adata[self.adata.obs[group_key] == ref_ct].obs_names query_names = self.adata[self.adata.obs[group_key] == query_ct].obs_names # Series/dataframes for each group: if source_data == "interaction": ref_data = self.X_df.loc[ref_names, interaction_subset] query_data = self.X_df.loc[query_names, interaction_subset] elif source_data == "effect": # Coefficients for all targets in subset: for target in target_subset: if target not in self.coeffs.keys(): raise ValueError(f"Target {target} not found in model.") else: coef_target = self.coeffs[target].loc[self.adata.obs_names] coef_target.columns = coef_target.columns.str.replace("b_", "") coef_target = coef_target[[col for col in coef_target.columns if col != "intercept"]] coef_target.columns = [replace_col_with_collagens(col) for col in coef_target.columns] coef_target.columns = [f"{col}-> target {target}" for col in coef_target.columns] duplicates = coef_target.columns[coef_target.columns.duplicated(keep=False)] for item in duplicates.unique(): # Calculate mean for 
collagens: mean_series = coef_target.filter(like=item).mean(axis=1) coef_target.drop(columns=coef_target.filter(like=item).columns, inplace=True) coef_target[item] = mean_series target_interaction_subset = [replace_col_with_collagens(i) for i in interaction_subset] target_interaction_subset = list( set([f"{i}-> target {target}" for i in target_interaction_subset]) ) target_interaction_subset = [i for i in target_interaction_subset if i in coef_target.columns] if "effect_df" not in locals(): effect_df = coef_target.loc[:, target_interaction_subset] else: effect_df = pd.concat([effect_df, coef_target.loc[:, target_interaction_subset]], axis=1) ref_data = effect_df.loc[ref_names, :] query_data = effect_df.loc[query_names, :] elif source_data == "target": ref_data = self.targets_expr.loc[ref_names, target_subset] query_data = self.targets_expr.loc[query_names, target_subset] else: raise ValueError( f"Unrecognized input for source_data: {source_data}. Options are 'interaction', 'effect', or " f"'target'." 
) # Compute significance for each column: pvals = [] for col in tqdm(ref_data.columns, desc="Computing significance..."): if source_data == "effect" or source_data == "interaction": pvals.append(ttest_ind(ref_data[col], query_data[col])[1]) elif source_data == "target": pvals.append(mannwhitneyu(ref_data[col], query_data[col])[1]) # Correct for multiple hypothesis testing: qvals = multitesting_correction(pvals, method="fdr_bh") results = pd.DataFrame(qvals, index=ref_data.columns, columns=["qval"]) results["qval"] = results["qval"].apply(lambda x: x[0] if isinstance(x, np.ndarray) else x) results["Significance"] = results.apply(assign_significance, axis=1) # Negative log q-value (in the case of volcano plot): results["-log10(qval)"] = -np.log10(qvals) # Threshold at the highest non-infinity q-value: max_non_inf = results[results["-log10(qval)"] != np.inf]["-log10(qval)"].max() results["-log10(qval)"] = results["-log10(qval)"].apply(lambda x: x if x != np.inf else max_non_inf) if to_plot == "mean": ref_data = ref_data.mean(axis=0) query_data = query_data.mean(axis=0) elif to_plot == "percentage": ref_data = (ref_data > 0).mean(axis=0) query_data = (query_data > 0).mean(axis=0) # Add small offset to ensure reference value is not 0: ref_data += 1e-3 query_data += 1e-3 # Compute fold change: fold_change = query_data / ref_data results["Fold Change"] = fold_change results["Fold Change"] = results["Fold Change"].apply(lambda x: x[0] if isinstance(x, np.ndarray) else x) # Take the log of the fold change: results["Fold Change"] = np.log2(results["Fold Change"]) # Remove NaNs: results = results[~results["Fold Change"].isna()] results = results.sort_values("Fold Change") if top_n_to_plot is not None: results = results.iloc[:top_n_to_plot, :] # Plot: if figsize is None: # Set figure size based on the number of interaction features and targets: if plot_type == "barplot": m = len(results) / 2 n = m / 2 elif plot_only_significant or plot_query_over_ref or plot_ref_over_query: m 
= 7 n = m * 2 else: m = 10 n = m figsize = (n, m) fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize) cmap = mpl.colormaps[cmap] # Center colormap at 0: max_distance = max(abs(results["Fold Change"]).max(), abs(results["Fold Change"]).min()) max_pos = results["Fold Change"].max() max_neg = results["Fold Change"].min() norm = plt.Normalize(-max_distance, max_distance) colors = cmap(norm(results["Fold Change"])) if plot_type == "barplot": barplot = sns.barplot( x="Fold Change", y=results.index, data=results, orient="h", palette=colors, edgecolor="black", linewidth=1, ax=ax, ) for index, row in results.iterrows(): ax.text(row["Fold Change"], index, f"{row['Significance']}", color="black", ha="right") ax.axvline(x=0, color="grey", linestyle="--", linewidth=2) ax.set_xlim(max_neg * 1.1, max_pos * 1.1) ax.set_xticklabels(results.index, fontsize=fontsize) ax.set_yticklabels(["{:.2f}".format(y) for y in ax.get_yticks()], fontsize=fontsize) ax.set_xlabel(x_label, fontsize=fontsize * 1.25) ax.set_ylabel(y_label, fontsize=fontsize * 1.25) ax.set_title(title, fontsize=fontsize * 1.5) elif plot_type == "volcano": if len(results) > 20: size = 20 else: size = 40 # Check if max -log10(qval) is greater than 8 if results["-log10(qval)"].max() > 8: ax.set_yscale("log", base=2) # Set y-axis to log y_label = r"$-log_{10}$(qval) ($log_2$ scale)" else: y_label = r"$-log_{10}$(qval)" significant = results["-log10(qval)"] > significance_cutoff significant_up = results["Fold Change"] > fold_change_cutoff significant_down = results["Fold Change"] < -fold_change_cutoff if plot_only_significant: results = results[significant] size *= 1.5 positive_fold_change = results["Fold Change"] > 0 negative_fold_change = results["Fold Change"] < 0 # Check if only plotting query over ref or ref over query: if plot_query_over_ref: size *= 1.5 fc_up = ax.scatter( x=results["Fold Change"][significant & significant_up & positive_fold_change], y=results["-log10(qval)"][significant & significant_up & 
positive_fold_change], c=results["Fold Change"][significant & significant_up & positive_fold_change], cmap="Reds", edgecolor="black", s=size, ) elif plot_ref_over_query: size *= 1.5 fc_down = ax.scatter( x=results["Fold Change"][significant & significant_down & negative_fold_change], y=results["-log10(qval)"][significant & significant_down & negative_fold_change], c=results["Fold Change"][significant & significant_down & negative_fold_change], cmap="Blues_r", edgecolor="black", s=size, ) else: fc_up = ax.scatter( x=results["Fold Change"][significant & significant_up], y=results["-log10(qval)"][significant & significant_up], c=results["Fold Change"][significant & significant_up], cmap="Reds", edgecolor="black", s=size, ) fc_down = ax.scatter( x=results["Fold Change"][significant & significant_down], y=results["-log10(qval)"][significant & significant_down], c=results["Fold Change"][significant & significant_down], cmap="Blues_r", edgecolor="black", s=size, ) ax.scatter( x=results["Fold Change"][~(significant & (significant_up | significant_down))], y=results["-log10(qval)"][~(significant & (significant_up | significant_down))], color="grey", edgecolor="black", s=size, ) # Add color bars if "fc_up" in locals(): cbar_red = fig.colorbar(fc_up, ax=ax, orientation="vertical", pad=0.0, aspect=40) cbar_red.ax.set_ylabel( f"Fold Changes- {query_ct} over {ref_ct}", rotation=90, labelpad=15, fontsize=fontsize ) cbar_red.ax.yaxis.set_label_position("left") cbar_red.ax.yaxis.label.set_horizontalalignment("right") cbar_red.ax.yaxis.label.set_position((0, 1.0)) for label in cbar_red.ax.get_yticklabels(): label.set_fontsize(fontsize) if "fc_down" in locals(): cbar_blue = fig.colorbar(fc_down, ax=ax, orientation="vertical", pad=0.1, aspect=40) cbar_blue.ax.set_ylabel( f"Fold Changes- {ref_ct} over {query_ct}", rotation=90, labelpad=15, fontsize=fontsize ) cbar_blue.ax.yaxis.set_label_position("left") cbar_blue.ax.yaxis.label.set_horizontalalignment("right") 
cbar_blue.ax.yaxis.label.set_position((0, 1.0)) for label in cbar_blue.ax.get_yticklabels(): label.set_fontsize(fontsize) # Add text for most significant interactions: # Get the highest fold changes: high_fold_change = results[abs(results["Fold Change"]) > fold_change_cutoff_for_labels] while high_fold_change.empty: fold_change_cutoff_for_labels /= 2 high_fold_change = results[abs(results["Fold Change"]) > fold_change_cutoff_for_labels] # Take only the top few (it is impossible to view all at once clearly): if len(high_fold_change) > 3: high_fold_change = high_fold_change.sort_values(by="Fold Change", ascending=False) high_fold_change_selected = high_fold_change.iloc[:3, :] else: high_fold_change_selected = high_fold_change # And a few more from not as high but still significant q-values: max_log10_qval = high_fold_change["-log10(qval)"].max() log10_qval_steps = [] i = 0 current_value = max_log10_qval while current_value >= 10: log10_qval_steps.append(current_value) i += 1 # Add to labels in descending half steps, with smaller steps taken if only visualizing the positive # or the negative fold changes: if plot_query_over_ref or plot_ref_over_query or plot_only_significant: step_size = 1.25 else: step_size = 1.5 current_value = max_log10_qval / (step_size**i) selected_rows = [] for value in log10_qval_steps: # Find the row closest to the current value without duplicates closest_index = abs(high_fold_change["-log10(qval)"] - value).idxmin() if closest_index not in high_fold_change_selected.index: selected_rows.append(high_fold_change.loc[closest_index]) high_fold_change_log10_qval = pd.DataFrame(selected_rows) # Combine with high_fold_change and remove duplicates high_fold_change_selected = pd.concat( [high_fold_change_selected, high_fold_change_log10_qval] ).drop_duplicates() text_labels = high_fold_change_selected.index.tolist() x_coord_text_labels = high_fold_change_selected["Fold Change"].tolist() y_coord_text_labels = 
high_fold_change_selected["-log10(qval)"].tolist() text_objects = [] for i, label in enumerate(text_labels): t = ax.text( x_coord_text_labels[i], y_coord_text_labels[i], label, fontsize=fontsize * 0.75, color="black", ha="center", va="center", ) text_objects.append(t) adjust_text(text_objects, ax=ax, arrowprops=dict(arrowstyle="<|-", color="black", lw=1.0)) y_label = r"$-log_{10}$(qval)" if "y_label" not in locals() else y_label ax.axhline(y=significance_cutoff, color="grey", linestyle="--", linewidth=1.5) ax.axvline(x=fold_change_cutoff, color="grey", linestyle="--", linewidth=1.5) ax.axvline(x=-fold_change_cutoff, color="grey", linestyle="--", linewidth=1.5) ax.set_xlim(max_neg * 1.1, max_pos * 1.1) ax.set_xticklabels(["{:.2f}".format(x) for x in ax.get_xticks()], fontsize=fontsize) ax.set_yticklabels(["{:.2f}".format(y) for y in ax.get_yticks()], fontsize=fontsize) ax.set_xlabel(x_label, fontsize=fontsize * 1.25) ax.set_ylabel(y_label, fontsize=fontsize * 1.25) ax.set_title(title, fontsize=fontsize * 1.5) save_kwargs["ext"] = "png" save_kwargs["dpi"] = 300 if "figure_folder" in locals(): save_kwargs["path"] = figure_folder # Save figure: save_return_show_fig_utils( save_show_or_return=save_show_or_return, show_legend=False, background="white", prefix=prefix, save_kwargs=save_kwargs, total_panels=1, fig=fig, axes=ax, return_all=False, return_all_list=None, ) if save_df: results.to_csv(os.path.join(output_folder, f"{prefix}.csv"))
def enriched_interactions_barplot(
    self,
    interactions: Optional[Union[str, List[str]]] = None,
    targets: Optional[Union[str, List[str]]] = None,
    plot_type: Literal["average", "proportion"] = "average",
    effect_size_threshold: float = 0.0,
    fontsize: Union[None, int] = None,
    figsize: Union[None, Tuple[float, float]] = None,
    cmap: str = "Reds",
    top_n: Optional[int] = None,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Optional[dict] = None,
):
    """Visualize the top predicted effect sizes for each interaction on particular target gene(s).

    One bar plot is produced per target. For each target, the model coefficients are the "b_"-prefixed columns of
    that target's coefficient table (the intercept is excluded); only the requested interactions that were modeled
    for that particular target are considered.

    Args:
        interactions: Optional subset of interactions to focus on, given in the form ligand(s):receptor(s),
            following the formatting in the design matrix. If not given, will consider all interactions that were
            specified in model fitting.
        targets: Can optionally specify a subset of the targets to compute this on. If not given, will use all
            targets that were specified in model fitting. If multiple targets are given, "save_show_or_return"
            should be "save" (and provide appropriate keyword arguments for saving using "save_kwargs"), otherwise
            only the last target will be shown.
        plot_type: Options: "average" or "proportion".
            - "average": mean coefficient over cells that express the target and are also predicted (according to
              "predictions.csv" next to the model output) to express it.
            - "proportion": fraction of target-expressing cells whose coefficient for the interaction is positive.
        effect_size_threshold: Exclusive lower bound on the plotted value for an interaction to be included in the
            barplot.
        fontsize: Size of font for x and y labels. Defaults to the current rcParams font size.
        figsize: Size of figure. If not given, the width scales with the number of plotted interactions, computed
            separately for each target.
        cmap: Colormap to use for barplot. It is recommended to use a sequential colormap (e.g. "Reds", "Blues",
            "Viridis", etc.).
        top_n: If given, will only include the top n features in the visualization. If not given, will include all
            features that pass the "effect_size_threshold".
        save_show_or_return: Whether to save, show or return the figure. If "both", it will save and plot the
            figure at the same time. If "all", the figure will be saved, displayed and the associated axis and
            other object will be returned.
        save_kwargs: A dictionary that will be passed to the save_fig function. By default it is an empty
            dictionary and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None,
            "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can
            provide a dictionary that properly modifies those keys according to your needs. Note that "ext" and
            "dpi" are always set to "png" and 300, and "path" is set to the "figures/temp" folder next to the model
            output when saving. The caller's dictionary itself is not modified.

    Raises:
        ValueError: If "plot_type" is not "average" or "proportion".
        TypeError: If "interactions" or "targets" is neither a string nor a list.
    """
    # Validate up front so a bad option fails before any file I/O or per-target work:
    if plot_type not in ("average", "proportion"):
        raise ValueError(f"Unrecognized input for plot_type: {plot_type}. Options are 'average' or 'proportion'.")

    config_spateo_rcParams()
    # But set display DPI to 300:
    plt.rcParams["figure.dpi"] = 300

    if fontsize is None:
        fontsize = rcParams.get("font.size")

    if interactions is None:
        interactions = self.X_df.columns.tolist()
    elif isinstance(interactions, str):
        interactions = [interactions]
    elif not isinstance(interactions, list):
        raise TypeError(f"Interactions must be a list or string, not {type(interactions)}.")

    # Model predictions, stored in the same directory as the model output:
    parent_dir = os.path.dirname(self.output_path)
    predictions = pd.read_csv(os.path.join(parent_dir, "predictions.csv"), index_col=0)

    figure_folder = None
    if save_show_or_return in ["save", "both", "all"]:
        figure_folder = os.path.join(parent_dir, "figures", "temp")
        os.makedirs(figure_folder, exist_ok=True)

    if targets is None:
        targets = self.targets_expr.columns
    elif isinstance(targets, str):
        targets = [targets]
    elif not isinstance(targets, list):
        raise TypeError(f"targets must be a list or string, not {type(targets)}.")

    # Work on a copy so neither the caller's dict nor a shared default is mutated by the overrides below:
    save_kwargs = {} if save_kwargs is None else dict(save_kwargs)
    save_kwargs["ext"] = "png"
    save_kwargs["dpi"] = 300
    if figure_folder is not None:
        save_kwargs["path"] = figure_folder

    # Use the saved name for the AnnData object to define part of the name of the saved file:
    adata_id = os.path.splitext(os.path.basename(self.adata_path))[0]

    for target in targets:
        # Coefficients for this target, excluding the intercept:
        coef = self.coeffs[target]
        effect_cols = [col for col in coef.columns if col.startswith("b_") and "intercept" not in col]
        effects = coef[effect_cols].copy()
        # Strip only the leading "b_" prefix- feature names can themselves contain underscores (e.g. multi-subunit
        # receptors), and splitting on every "_" would truncate them:
        effects.columns = [col.split("_", 1)[1] for col in effect_cols]
        # Per-target subset, kept separate from "interactions" so a feature missing for one target is not dropped
        # for every subsequent target:
        target_interactions = [interaction for interaction in interactions if interaction in effects.columns]
        effects = effects[target_interactions]

        target_X = self.adata[:, target].X
        if scipy.sparse.issparse(target_X):
            target_X = target_X.toarray()
        target_expr = np.asarray(target_X).ravel() > 0

        if plot_type == "average":
            # NOTE(review): predictions are matched positionally to adata's cell order- confirm predictions.csv
            # rows follow adata.obs_names.
            target_pred_np = predictions[target].values.astype(bool)
            # Cells expressing the target that are also predicted to be expressing the target:
            true_pos_cells = self.adata.obs_names[target_expr & target_pred_np]
            to_plot = effects.loc[true_pos_cells, :].mean(axis=0)
        else:
            # Proportion of cells expressing the target that are predicted to be affected by each interaction
            # (selected by label, consistent with the "average" branch):
            expressing_cells = self.adata.obs_names[target_expr]
            to_plot = (effects.loc[expressing_cells, :] > 0).mean(axis=0)

        # Filter based on the threshold, then sort in descending order:
        to_plot = to_plot[to_plot > effect_size_threshold].sort_values(ascending=False)
        if self.mod_type == "ligand":
            to_plot.index = [replace_col_with_collagens(idx) for idx in to_plot.index]
            to_plot.index = [replace_hla_with_hlas(idx) for idx in to_plot.index]
        if top_n is not None:
            to_plot = to_plot.iloc[:top_n]

        if to_plot.empty:
            # A zero-width figure cannot be drawn- report and move on to the next target instead of crashing:
            lm.main_warning(f"No interactions passed the effect size threshold for target {target}; skipping plot.")
            continue

        # Figure size is decided per target so each plot scales with its own number of bars:
        target_figsize = figsize if figsize is not None else (len(to_plot) / 2, 5)
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=target_figsize)
        palette = sns.color_palette(cmap, n_colors=len(to_plot))
        sns.barplot(
            x=to_plot.index,
            y=to_plot.values,
            edgecolor="black",
            linewidth=1,
            palette=palette,
            ax=ax,
        )
        ax.set_xticklabels(to_plot.index, rotation=90, fontsize=fontsize)
        ax.set_yticklabels(["{:.2f}".format(y) for y in ax.get_yticks()], fontsize=fontsize)
        ax.set_xlabel("Interaction (ligand(s):receptor(s))", fontsize=fontsize)
        if plot_type == "average":
            title = f"Average Predicted Interaction Effects on {target}"
            ylabel = "Mean Coefficient \nMagnitude"
        else:
            if top_n is not None and top_n < 20:
                title = f"Proportion of {target}-Expressing Cells \nPredicted to be Affected by Interaction"
            else:
                title = f"Proportion of {target}-Expressing Cells Predicted to be Affected by Interaction"
            ylabel = "Proportion of Cells"
        ax.set_ylabel(ylabel, fontsize=fontsize)
        ax.set_title(title, fontsize=fontsize)

        # Save figure:
        save_return_show_fig_utils(
            save_show_or_return=save_show_or_return,
            show_legend=False,
            background="white",
            prefix=f"{adata_id}_interaction_barplot_{target}",
            save_kwargs=save_kwargs,
            total_panels=1,
            fig=fig,
            axes=ax,
            return_all=False,
            return_all_list=None,
        )
def summarize_interaction_effects(
    self,
    interactions: Optional[Union[str, List[str]]] = None,
    targets: Optional[Union[str, List[str]]] = None,
    effect_size_threshold: float = 0.0,
):
    """Summarize the interaction effects for each target gene in dataframe format.

    Each element is the average effect size (model coefficient) of a particular interaction on a particular target
    gene. The average is taken over the cells that both express the target and are predicted (according to
    "predictions.csv" next to the model output) to express it. Averages that do not exceed
    "effect_size_threshold", and interactions that were not modeled for a given target, are reported as 0.

    Args:
        interactions: Optional subset of interactions to focus on. If not given, will consider all interactions.
        targets: Can optionally specify a subset of the targets. If not given, will use all targets.
        effect_size_threshold: Exclusive lower bound for average effect size to include a particular interaction.

    Returns:
        effects_df: Dataframe with the average effect size for each interaction (rows) on each target gene
            (columns).

    Raises:
        TypeError: If "interactions" or "targets" is neither a string nor a list.
        KeyError: If a target is missing from the model coefficients or from the saved predictions.
    """
    if interactions is None:
        interactions = self.X_df.columns.tolist()
    elif isinstance(interactions, str):
        interactions = [interactions]
    elif not isinstance(interactions, list):
        raise TypeError(f"Interactions must be a list or string, not {type(interactions)}.")

    if targets is None:
        targets = self.targets_expr.columns
    elif isinstance(targets, str):
        targets = [targets]
    elif not isinstance(targets, list):
        raise TypeError(f"Targets must be a list or string, not {type(targets)}.")

    # Predictions, stored in the same directory as the model output:
    parent_dir = os.path.dirname(self.output_path)
    predictions = pd.read_csv(os.path.join(parent_dir, "predictions.csv"), index_col=0)

    # Initialize the DataFrame to store the results:
    effects_df = pd.DataFrame(0.0, index=interactions, columns=targets)

    for target in targets:
        # Coefficients for this target, excluding the intercept:
        coef = self.coeffs[target]
        effect_cols = [col for col in coef.columns if col.startswith("b_") and "intercept" not in col]
        effects = coef[effect_cols].copy()
        # Strip only the leading "b_" prefix- feature names can themselves contain underscores (e.g. multi-subunit
        # receptors), and splitting on every "_" would truncate them:
        effects.columns = [col.split("_", 1)[1] for col in effect_cols]
        # Per-target subset, kept separate from "interactions" so a feature missing for one target is not dropped
        # for every subsequent target:
        target_interactions = [interaction for interaction in interactions if interaction in effects.columns]
        effects = effects[target_interactions]

        target_X = self.adata[:, target].X
        if scipy.sparse.issparse(target_X):
            target_X = target_X.toarray()
        target_expr = np.asarray(target_X).ravel() > 0
        # NOTE(review): predictions are matched positionally to adata's cell order- confirm predictions.csv rows
        # follow adata.obs_names.
        target_pred = predictions[target].values.astype(bool)

        # Cells expressing the target that are also predicted to be expressing the target:
        true_pos_cells = self.adata.obs_names[target_expr & target_pred]
        average_effect_size = effects.loc[true_pos_cells, :].mean(axis=0)

        # Keep only interactions above the threshold; the rest become NaN on index alignment and are zeroed below:
        effects_df[target] = average_effect_size[average_effect_size > effect_size_threshold]

    return effects_df.fillna(0.0)
[docs] def enriched_tfs_barplot( self, tfs: Optional[Union[str, List[str]]] = None, targets: Optional[Union[str, List[str]]] = None, target_type: Literal["ligand", "receptor", "target_gene"] = "target_gene", plot_type: Literal["average", "proportion"] = "average", effect_size_threshold: float = 0.0, fontsize: Union[None, int] = None, figsize: Union[None, Tuple[float, float]] = None, cmap: str = "Reds", top_n: Optional[int] = None, save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", save_kwargs: Optional[dict] = {}, ): """Visualize the top predicted effect sizes for each transcription factor on particular target gene(s). Args: tfs: Optional subset of transcription factors to focus on. If not given, will consider all transcription factors that were specified in model fitting. targets: Can optionally specify a subset of the targets to compute this on. If not given, will use all targets that were specified in model fitting. If multiple targets are given, "save_show_or_return" should be "save" (and provide appropriate keyword arguments for saving using "save_kwargs"), otherwise only the last target will be shown. target_type: Set whether the given targets are ligands, receptors or target genes. Used to determine which folder to check for outputs. plot_type: Options: "average" or "proportion". Whether to plot the average effect size or the proportion of cells expressing the target predicted to be affected by the interaction. effect_size_threshold: Lower bound for average effect size to include a particular interaction in the barplot fontsize: Size of font for x and y labels figsize: Size of figure cmap: Colormap to use for barplot. It is recommended to use a sequential colormap (e.g. "Reds", "Blues", "Viridis", etc.). top_n: If given, will only include the top n features in the visualization. If not given, will include all features that pass the "effect_size_threshold". 
save_show_or_return: Whether to save, show or return the figure If "both", it will save and plot the figure at the same time. If "all", the figure will be saved, displayed and the associated axis and other object will be return. save_kwargs: A dictionary that will passed to the save_fig function By default it is an empty dictionary and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those keys according to your needs. """ config_spateo_rcParams() # But set display DPI to 300: plt.rcParams["figure.dpi"] = 300 if fontsize is None: fontsize = rcParams.get("font.size") if target_type == "ligand": coeffs = self.downstream_model_ligand_coeffs if tfs is None: tfs = self.downstream_model_ligand_design_matrix.columns.tolist() elif target_type == "receptor": coeffs = self.downstream_model_receptor_coeffs if tfs is None: tfs = self.downstream_model_receptor_design_matrix.columns.tolist() elif target_type == "target_gene": coeffs = self.downstream_model_target_coeffs if tfs is None: tfs = self.downstream_model_target_design_matrix.columns.tolist() else: raise ValueError( f"Unrecognized input for target_type: {target_type}. Options are 'ligand', 'receptor', " f"or 'target_gene'." 
) tfs = [tf.replace("regulator_", "") for tf in tfs] if isinstance(tfs, str): interactions = [tfs] elif not isinstance(tfs, list): raise TypeError(f"TFs must be a list or string, not {type(tfs)}.") # Predictions: downstream_parent_dir = os.path.dirname(os.path.splitext(self.output_path)[0]) id = os.path.splitext(os.path.basename(self.output_path))[0] if target_type == "ligand": folder = "ligand_analysis" elif target_type == "receptor": folder = "receptor_analysis" elif target_type == "target_gene": folder = "target_gene_analysis" pred_path = os.path.join(downstream_parent_dir, "cci_deg_detection", folder, "downstream/predictions.csv") predictions = pd.read_csv(pred_path, index_col=0) if save_show_or_return in ["save", "both", "all"]: if not os.path.exists(os.path.join(os.path.dirname(self.output_path), "figures")): os.makedirs(os.path.join(os.path.dirname(self.output_path), "figures")) figure_folder = os.path.join(os.path.dirname(self.output_path), "figures", "temp") if not os.path.exists(figure_folder): os.makedirs(figure_folder) if targets is None: targets = coeffs.keys() elif isinstance(targets, str): targets = [targets] elif not isinstance(targets, list): raise TypeError(f"targets must be a list or string, not {type(targets)}.") for target in targets: # Get coefficients for this key coef = coeffs[target] effects = coef[[col for col in coef.columns if col.startswith("b_") and "intercept" not in col]] effects.columns = [col.split("_")[1] for col in effects.columns] # Subset to only the TFs of interest: tfs = [tf for tf in tfs if tf in effects.columns] effects = effects[tfs] target_expr = self.adata[:, target].X.toarray().squeeze() > 0 target_pred = predictions[target] target_pred_np = target_pred.values.astype(bool) if plot_type == "average": # Subset to cells expressing the target that are predicted to be expressing the target: target_true_pos_indices = np.where(target_expr & target_pred_np)[0] target_expr_sub = self.adata[target_true_pos_indices, :].copy() # 
Subset effects dataframe to same subset: effects_sub = effects.loc[target_expr_sub.obs_names, :] # Compute average for each column: to_plot = effects_sub.mean(axis=0) elif plot_type == "proportion": # Get proportion of cells expressing the target that are predicted to be affected by particular # molecule: effects_sub = effects.loc[target_expr == 1, :] to_plot = (effects_sub > 0).mean(axis=0) else: raise ValueError(f"Unrecognized input for to_plot: {plot_type}. Options are 'average' or 'proportion'.") # Filter based on the threshold to_plot = to_plot[to_plot > effect_size_threshold] # Sort the average_expression in descending order to_plot = to_plot.sort_values(ascendin