"""
Additional functionalities to characterize signaling patterns from spatial transcriptomics
These include:
- prediction of the effects of spatial perturbation on gene expression- this can include the effect of perturbing
known regulators of ligand/receptor expression or the effect of perturbing the ligand/receptor itself.
- following spatially-aware regression (or a sequence of spatially-aware regressions), combine regression results
with data such that each cell can be associated with region-specific coefficient(s).
- following spatially-aware regression (or a sequence of spatially-aware regressions), overlay the directionality
of the predicted influence of the ligand on downstream expression.
"""
import argparse
import collections
import gc
import itertools
import math
import os
from collections import Counter
from itertools import product
from typing import List, Literal, Optional, Tuple, Union
import anndata
import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import plotly
import plotly.graph_objs as go
import scipy.cluster.hierarchy as sch
import scipy.sparse
import scipy.stats
import seaborn as sns
from adjustText import adjust_text
from joblib import Parallel, delayed
from matplotlib import rcParams
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.stats import mannwhitneyu, pearsonr, spearmanr, ttest_1samp, ttest_ind
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import (
confusion_matrix,
f1_score,
mean_squared_error,
roc_auc_score,
)
from sklearn.preprocessing import normalize
from tqdm.auto import tqdm
from ...configuration import config_spateo_rcParams
from ...logging import logger_manager as lm
from ...plotting.static.colorlabel import godsnot_102, vega_10
from ...plotting.static.networks import plot_network
from ...plotting.static.utils import save_return_show_fig_utils
from ...tools.find_neighbors import neighbors
from ...tools.utils import filter_adata_spatial
from ..dimensionality_reduction import find_optimal_pca_components, pca_fit
from ..utils import compute_corr_ci, create_new_coordinate
from .MuSIC import MuSIC
from .regression_utils import assign_significance, multitesting_correction, wald_test
from .SWR import define_spateo_argparse
# ---------------------------------------------------------------------------------------------------
# Statistical testing, correlated differential expression analysis
# ---------------------------------------------------------------------------------------------------
class MuSIC_Interpreter(MuSIC):
"""
Interpretation and downstream analysis of spatially weighted regression models.
Args:
parser: ArgumentParser object initialized with argparse, to parse command line arguments for arguments
pertinent to modeling.
args_list: If parser is provided by function call, the arguments to parse must be provided as a separate
list. It is recommended to use the return from :func `define_spateo_argparse()` for this.
        keep_column_threshold_proportion_cells: If provided, will threshold columns to only keep those that are
            nonzero in a proportion of cells greater than this threshold. For example, if this is set to 0.5,
            more than half of the cells must have a nonzero value for a given column for it to be retained for
            further inspection. Intended to be used to filter out likely false positives.
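    Example:
        A minimal construction sketch; this assumes :func `define_spateo_argparse()` returns the parser and
        arguments list as recommended above. The keyword values shown are illustrative, not defaults::

            parser, args_list = define_spateo_argparse(
                adata_path="./adata.h5ad",  # hypothetical path to the AnnData file
                mod_type="lr",
                output_path="./outputs/model.csv",
            )
            interpreter = MuSIC_Interpreter(parser, args_list)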
"""
def __init__(
self,
parser: argparse.ArgumentParser,
args_list: Optional[List[str]] = None,
keep_column_threshold_proportion_cells: Optional[float] = None,
):
# Don't need to re-save the subsampling results, they have already been defined:
super().__init__(parser, args_list, verbose=False, save_subsampling=False)
        self.k = self.arg_retrieve.top_k_receivers
# Coefficients:
if not self.set_up:
self.logger.info(
"Running :func `SWR._set_up_model()` to organize predictors and targets for downstream "
"analysis now..."
)
self._set_up_model(verbose=False)
# self.logger.info("Finished preprocessing, getting fitted coefficients and standard errors.")
# Dictionary containing coefficients:
self.coeffs, self.standard_errors = self.return_outputs(adjust_for_subsampling=False, load_for_interpreter=True)
        n_cells_expressing_targets = self.targets_expr.apply(lambda x: sum(x > 0), axis=0)
        if keep_column_threshold_proportion_cells is not None:
for target, df in self.coeffs.items():
# Threshold columns to only keep those that are nonzero in a proportion of cells greater than this
# threshold:
threshold = int(keep_column_threshold_proportion_cells * n_cells_expressing_targets[target])
for col in df.columns:
if sum(df[col] != 0) < threshold:
df[col] = 0
self.standard_errors[target][col] = 0
self.coeffs[target] = df
# Check for coefficients and feature names of downstream models as well:
self.logger.info("Checking for coefficients of possible downstream models as well...")
        downstream_parent_dir = os.path.dirname(os.path.splitext(self.output_path)[0])
        model_id = os.path.basename(os.path.splitext(self.output_path)[0])
self.downstream_model_ligand_coeffs, self.downstream_model_ligand_standard_errors = self.return_outputs(
adjust_for_subsampling=False, load_for_interpreter=True, load_from_downstream="ligand"
)
        dm_dir = os.path.join(
downstream_parent_dir,
"cci_deg_detection",
"ligand_analysis",
            model_id,
"downstream_design_matrix",
"design_matrix.csv",
)
        self.downstream_model_ligand_design_matrix = (
pd.read_csv(dm_dir, index_col=0) if os.path.exists(dm_dir) else None
)
self.downstream_model_receptor_coeffs, self.downstream_model_receptor_standard_errors = self.return_outputs(
adjust_for_subsampling=False, load_for_interpreter=True, load_from_downstream="receptor"
)
dm_dir = os.path.join(
downstream_parent_dir,
"cci_deg_detection",
"receptor_analysis",
            model_id,
"downstream_design_matrix",
"design_matrix.csv",
)
        self.downstream_model_receptor_design_matrix = (
pd.read_csv(dm_dir, index_col=0) if os.path.exists(dm_dir) else None
)
self.downstream_model_target_coeffs, self.downstream_model_target_standard_errors = self.return_outputs(
adjust_for_subsampling=False, load_for_interpreter=True, load_from_downstream="target_gene"
)
dm_dir = os.path.join(
downstream_parent_dir,
"cci_deg_detection",
"target_gene_analysis",
            model_id,
"downstream_design_matrix",
"design_matrix.csv",
)
        self.downstream_model_target_design_matrix = (
pd.read_csv(dm_dir, index_col=0) if os.path.exists(dm_dir) else None
)
# Design matrix:
        self.design_matrix = pd.read_csv(
os.path.join(os.path.splitext(self.output_path)[0], "design_matrix", "design_matrix.csv"), index_col=0
)
# If predictions of an L:R model have been computed, load these as well:
if os.path.exists(os.path.join(os.path.dirname(self.output_path), "predictions.csv")):
self.predictions = pd.read_csv(
os.path.join(os.path.dirname(self.output_path), "predictions.csv"), index_col=0
)
# Save directory:
        parent_dir = os.path.dirname(self.output_path)
if not os.path.exists(os.path.join(parent_dir, "significance")):
os.makedirs(os.path.join(parent_dir, "significance"))
# Arguments for cell type coupling computation:
        self.filter_targets = self.arg_retrieve.filter_targets
        self.filter_target_threshold = self.arg_retrieve.filter_target_threshold
# Get targets for the downstream ligand(s), receptor(s), target(s), etc. to use for analysis:
        self.ligand_for_downstream = self.arg_retrieve.ligand_for_downstream
        self.receptor_for_downstream = self.arg_retrieve.receptor_for_downstream
        self.pathway_for_downstream = self.arg_retrieve.pathway_for_downstream
        self.target_for_downstream = self.arg_retrieve.target_for_downstream
        self.sender_ct_for_downstream = self.arg_retrieve.sender_ct_for_downstream
        self.receiver_ct_for_downstream = self.arg_retrieve.receiver_ct_for_downstream
# Other downstream analysis-pertinent argparse arguments:
        self.cci_degs_model_interactions = self.arg_retrieve.cci_degs_model_interactions
        self.no_cell_type_markers = self.arg_retrieve.no_cell_type_markers
        self.compute_pathway_effect = self.arg_retrieve.compute_pathway_effect
        self.diff_sending_or_receiving = self.arg_retrieve.diff_sending_or_receiving
    def compute_coeff_significance(self, method: str = "fdr_bh", significance_threshold: float = 0.05):
"""Computes local statistical significance for fitted coefficients.
Args:
method: Method to use for correction. Available methods can be found in the documentation for
statsmodels.stats.multitest.multipletests(), and are also listed below (in correct case) for
convenience:
- Named methods:
- bonferroni
- sidak
- holm-sidak
- holm
- simes-hochberg
- hommel
- Abbreviated methods:
- fdr_bh: Benjamini-Hochberg correction
- fdr_by: Benjamini-Yekutieli correction
- fdr_tsbh: Two-stage Benjamini-Hochberg
- fdr_tsbky: Two-stage Benjamini-Krieger-Yekutieli method
significance_threshold: p-value (or q-value) needed to call a parameter significant.
        Saves (one set of files per target, to the "significance" subdirectory of the output directory):
            is_significant: Dataframe of identical shape to coeffs, where each element is True or False
                depending on whether it meets the threshold for significance
            pvalues: Dataframe of identical shape to coeffs, where each element is a p-value for that instance
                of that feature
            qvalues: Dataframe of identical shape to coeffs, where each element is a q-value for that instance
                of that feature
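        Example:
            A minimal usage sketch, assuming `interpreter` is a constructed :class `MuSIC_Interpreter` (see
            the class docstring)::

                interpreter.compute_coeff_significance(method="fdr_bh", significance_threshold=0.05)
                # Results are written under <output directory>/significance/, one set of CSVs per target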
"""
self.logger.info(
"Computing significance for all coefficients, note this may take a long time for large "
"datasets (> 10k cells)..."
)
for target in self.coeffs.keys():
# Check for existing file:
parent_dir = os.path.dirname(self.output_path)
if os.path.exists(os.path.join(parent_dir, "significance", f"{target}_is_significant.csv")):
self.logger.info(f"Significance already computed for target {target}, moving to the next...")
continue
# Get coefficients and standard errors for this key
coef = self.coeffs[target]
columns = [col for col in coef.columns if col.startswith("b_") and "intercept" not in col]
coef = coef[columns]
se = self.standard_errors[target]
se_feature_match = [c.replace("se_", "") for c in se.columns]
def compute_p_value(cell_name, feat):
return wald_test(coef.loc[cell_name, f"b_{feat}"], se.loc[cell_name, f"se_{feat}"])
filtered_tasks = [
(cell_name, feat)
for cell_name, feat in product(self.sample_names, self.feature_names)
if feat in se_feature_match
and se.loc[cell_name, f"se_{feat}"] != 0
and coef.loc[cell_name, f"b_{feat}"] != 0
]
# Parallelize computations for filtered tasks
results = Parallel(n_jobs=-1)(
delayed(compute_p_value)(cell_name, feat)
for cell_name, feat in tqdm(filtered_tasks, desc=f"Processing for target {target}")
)
# Convert results to a DataFrame
results_df = pd.DataFrame(
results, index=pd.MultiIndex.from_tuples(filtered_tasks, names=["sample", "feature"])
)
            p_values_all = pd.DataFrame(1.0, index=self.sample_names, columns=self.feature_names)
p_values_all.update(results_df.unstack(level="feature").droplevel(0, axis=1))
# Multiple testing correction for each observation:
qvals = np.zeros_like(p_values_all.values)
for i in range(p_values_all.shape[0]):
qvals[i, :] = multitesting_correction(
p_values_all.iloc[i, :], method=method, alpha=significance_threshold
)
q_values_df = pd.DataFrame(qvals, index=self.sample_names, columns=self.feature_names)
# Significance:
is_significant_df = q_values_df < significance_threshold
# Save dataframes:
parent_dir = os.path.dirname(self.output_path)
p_values_all.to_csv(os.path.join(parent_dir, "significance", f"{target}_p_values.csv"))
q_values_df.to_csv(os.path.join(parent_dir, "significance", f"{target}_q_values.csv"))
is_significant_df.to_csv(os.path.join(parent_dir, "significance", f"{target}_is_significant.csv"))
self.logger.info(f"Finished computing significance for target {target}.")
    def filter_adata_spatial(self, instructions: List[str]):
"""Based on spatial coordinates, filter the adata object to only include cells that meet the criteria.
Criteria provided in the form of a list of instructions of the form "x less than 0.5 and y greater than 0.5",
etc., where each instruction is executed sequentially.
Args:
instructions: List of instructions to filter adata object by. Each instruction is a string of the form
"x less than 0.5 and y greater than 0.5", etc., where each instruction is executed sequentially.
"""
adata_filt = filter_adata_spatial(self.adata, self.coords_key, instructions)
# Cells still left post-filter
self.remaining_cells = adata_filt.obs_names
self.remaining_indices = np.where(self.adata.obs_names.isin(self.remaining_cells))[0]
    def filter_adata_custom(self, cell_ids: List[str]):
"""Filter AnnData object to only the cells specified by the custom list.
Args:
cell_ids: List of cell IDs to keep. Each ID must be found in adata.obs_names
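        Example:
            A minimal sketch (the cell IDs are illustrative and must exist in adata.obs_names)::

                interpreter.filter_adata_custom(cell_ids=["cell_0", "cell_1", "cell_2"])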
"""
self.remaining_cells = cell_ids
self.remaining_indices = np.where(self.adata.obs_names.isin(self.remaining_cells))[0]
    def add_interaction_effect_to_adata(
self,
targets: Union[str, List[str]],
interactions: Union[str, List[str]],
visualize: bool = False,
) -> anndata.AnnData:
"""For each specified interaction/list of interactions, add the predicted interaction effect to the adata
object.
Args:
targets: Target(s) to add interaction effect for. Can be a single target or a list of targets.
interactions: Interaction(s) to add interaction effect for. Can be a single interaction or a list of
interactions. Should be the name of a gene for ligand models, or an L:R pair for L:R models (for
example, "Igf1:Igf1r").
visualize: Whether to visualize the interaction effect for each target/interaction pair. If True,
will generate spatial scatter plot and save to HTML file.
Returns:
adata: AnnData object with interaction effects added to .obs.
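        Example:
            A minimal sketch for an L:R model (gene names are illustrative)::

                adata = interpreter.add_interaction_effect_to_adata(
                    targets="Apoe", interactions="Igf1:Igf1r", visualize=False
                )
                # The effect is stored in adata.obs["Apoe_Igf1:Igf1r_effect"]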
"""
if visualize:
figure_folder = os.path.join(os.path.dirname(self.output_path), "figures")
if not os.path.exists(figure_folder):
os.makedirs(figure_folder)
if not isinstance(targets, list):
targets = [targets]
if not isinstance(interactions, list):
interactions = [interactions]
if hasattr(self, "remaining_indices"):
adata = self.adata[self.remaining_cells, :].copy()
else:
adata = self.adata.copy()
combinations = list(product(targets, interactions))
for target, interaction in combinations:
if f"b_{interaction}" not in self.coeffs[target].columns:
self.logger.info(
f"Information for interaction {interaction} not found for target {target}, " f"skipping..."
)
continue
if hasattr(self, "remaining_indices"):
target_coefs = self.coeffs[target].loc[self.remaining_cells, f"b_{interaction}"]
else:
target_coefs = self.coeffs[target][f"b_{interaction}"]
# Add to adata:
adata.obs[f"{target}_{interaction}_effect"] = target_coefs
if visualize:
# plotly to create 3D scatter plot:
spatial_coords = adata.obsm[self.coords_key]
if spatial_coords.shape[1] == 2:
x, y = spatial_coords[:, 0], spatial_coords[:, 1]
z = np.zeros(len(x))
else:
x, y, z = spatial_coords[:, 0], spatial_coords[:, 1], spatial_coords[:, 2]
plot_data = adata.obs[f"{target}_{interaction}_effect"]
p997 = np.percentile(plot_data.values, 99.7)
plot_data[plot_data > p997] = p997
plot_vals = plot_data.values
scatter = go.Scatter3d(
x=x,
y=y,
z=z,
mode="markers",
marker=dict(
color=plot_vals,
colorscale="Magma",
size=2,
colorbar=dict(
title=f"{interaction.title()} Effect on {target.title()}",
x=0.8,
titlefont=dict(size=16),
tickfont=dict(size=18),
),
),
showlegend=False,
)
fig = go.Figure(data=scatter)
title_dict = dict(
text=f"{interaction.title()} Effect on {target.title()}",
y=0.9,
yanchor="top",
x=0.5,
xanchor="center",
font=dict(size=28),
)
# Turn off the grid
fig.update_layout(
showlegend=True,
legend=dict(x=0.65, y=0.85, orientation="v", font=dict(size=18)),
scene=dict(
aspectmode="data",
xaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
yaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
zaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
),
margin=dict(l=0, r=0, b=0, t=50), # Adjust margins to minimize spacing
title=title_dict,
)
path = os.path.join(figure_folder, f"{interaction}_effect_on_{target}.html")
fig.write_html(path)
return adata
    def compute_and_visualize_diagnostics(
self, type: Literal["correlations", "confusion", "rmse"], n_genes_per_plot: int = 20
):
"""
For true and predicted gene expression, compute and generate either: confusion matrices,
or correlations, including the Pearson correlation, Spearman correlation, or root mean-squared-error (RMSE).
Args:
type: Type of diagnostic to compute and visualize. Options: "correlations" for Pearson & Spearman
correlation, "confusion" for confusion matrix, "rmse" for root mean-squared-error.
n_genes_per_plot: Only used if "type" is "confusion". Number of genes to plot per figure. If there are
more than this number of genes, multiple figures will be generated.
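        Example:
            An illustrative call; "predictions.csv" must already exist in this model's output directory::

                interpreter.compute_and_visualize_diagnostics(type="confusion", n_genes_per_plot=10)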
"""
# Plot title:
file_name = os.path.splitext(os.path.basename(self.adata_path))[0]
parent_dir = os.path.dirname(self.output_path)
pred_path = os.path.join(parent_dir, "predictions.csv")
predictions = pd.read_csv(pred_path, index_col=0)
all_genes = predictions.columns
width = 0.5 * len(all_genes)
pred_vals = predictions.values
if type == "correlations":
# Pearson and Spearman dictionary for all cells:
pearson_dict = {}
spearman_dict = {}
# Pearson and Spearman dictionary for only the expressing subset of cells:
nz_pearson_dict = {}
nz_spearman_dict = {}
for i, gene in enumerate(all_genes):
y = self.adata[:, gene].X.toarray().reshape(-1)
music_results_target = pred_vals[:, i]
# Remove index of the largest predicted value (to mitigate sensitivity of these metrics to outliers):
                outlier_index = np.where(music_results_target == np.max(music_results_target))[0]
music_results_target_to_plot = np.delete(music_results_target, outlier_index)
y_plot = np.delete(y, outlier_index)
# Indices where target is nonzero:
nonzero_indices = y_plot != 0
rp, _ = pearsonr(y_plot, music_results_target_to_plot)
r, _ = spearmanr(y_plot, music_results_target_to_plot)
rp_nz, _ = pearsonr(y_plot[nonzero_indices], music_results_target_to_plot[nonzero_indices])
r_nz, _ = spearmanr(y_plot[nonzero_indices], music_results_target_to_plot[nonzero_indices])
pearson_dict[gene] = rp
spearman_dict[gene] = r
nz_pearson_dict[gene] = rp_nz
nz_spearman_dict[gene] = r_nz
# Mean of diagnostic metrics:
mean_pearson = sum(pearson_dict.values()) / len(pearson_dict.values())
mean_spearman = sum(spearman_dict.values()) / len(spearman_dict.values())
mean_nz_pearson = sum(nz_pearson_dict.values()) / len(nz_pearson_dict.values())
mean_nz_spearman = sum(nz_spearman_dict.values()) / len(nz_spearman_dict.values())
data = []
for gene in pearson_dict.keys():
data.append(
{
"Gene": gene,
"Pearson coefficient": pearson_dict[gene],
"Spearman coefficient": spearman_dict[gene],
"Pearson coefficient (expressing cells)": nz_pearson_dict[gene],
"Spearman coefficient (expressing cells)": nz_spearman_dict[gene],
}
)
# Color palette:
            colors = {
                "Pearson coefficient": "#FF7F00",
                "Spearman coefficient": "#87CEEB",
                "Pearson coefficient (expressing cells)": "#0BDA51",
                "Spearman coefficient (expressing cells)": "#FF6961",
            }
df = pd.DataFrame(data)
# Plot Pearson correlation barplot:
sns.set(font_scale=2)
sns.set_style("white")
plt.figure(figsize=(width, 6))
plt.xticks(rotation="vertical")
            ax = sns.barplot(
                data=df,
                x="Gene",
                y="Pearson coefficient",
                color=colors["Pearson coefficient"],
                edgecolor="black",
                dodge=True,
            )
# Mean line:
line_style = "--"
line_thickness = 2
ax.axhline(mean_pearson, color="black", linestyle=line_style, linewidth=line_thickness)
# Update legend:
legend_label = f"Mean: {mean_pearson}"
handles, labels = ax.get_legend_handles_labels()
handles.append(plt.Line2D([0], [0], color="black", linewidth=line_thickness, linestyle=line_style))
labels.append(legend_label)
ax.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 0.5))
plt.title(f"Pearson correlation {file_name}")
plt.tight_layout()
plt.show()
# Plot Spearman correlation barplot:
plt.figure(figsize=(width, 6))
plt.xticks(rotation="vertical")
            ax = sns.barplot(
                data=df,
                x="Gene",
                y="Spearman coefficient",
                color=colors["Spearman coefficient"],
                edgecolor="black",
                dodge=True,
            )
# Mean line:
ax.axhline(mean_spearman, color="black", linestyle=line_style, linewidth=line_thickness)
# Update legend:
legend_label = f"Mean: {mean_spearman}"
handles, labels = ax.get_legend_handles_labels()
handles.append(plt.Line2D([0], [0], color="black", linewidth=line_thickness, linestyle=line_style))
labels.append(legend_label)
ax.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 0.5))
plt.title(f"Spearman correlation {file_name}")
plt.tight_layout()
plt.show()
# Plot Pearson correlation barplot (expressing cells):
plt.figure(figsize=(width, 6))
plt.xticks(rotation="vertical")
            ax = sns.barplot(
                data=df,
                x="Gene",
                y="Pearson coefficient (expressing cells)",
                color=colors["Pearson coefficient (expressing cells)"],
                edgecolor="black",
                dodge=True,
            )
# Mean line:
ax.axhline(mean_nz_pearson, color="black", linestyle=line_style, linewidth=line_thickness)
# Update legend:
legend_label = f"Mean: {mean_nz_pearson}"
handles, labels = ax.get_legend_handles_labels()
handles.append(plt.Line2D([0], [0], color="black", linewidth=line_thickness, linestyle=line_style))
labels.append(legend_label)
ax.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 0.5))
plt.title(f"Pearson correlation (expressing cells) {file_name}")
plt.tight_layout()
plt.show()
# Plot Spearman correlation barplot (expressing cells):
plt.figure(figsize=(width, 6))
plt.xticks(rotation="vertical")
            ax = sns.barplot(
                data=df,
                x="Gene",
                y="Spearman coefficient (expressing cells)",
                color=colors["Spearman coefficient (expressing cells)"],
                edgecolor="black",
                dodge=True,
            )
# Mean line:
ax.axhline(mean_nz_spearman, color="black", linestyle=line_style, linewidth=line_thickness)
# Update legend:
legend_label = f"Mean: {mean_nz_spearman}"
handles, labels = ax.get_legend_handles_labels()
handles.append(plt.Line2D([0], [0], color="black", linewidth=line_thickness, linestyle=line_style))
labels.append(legend_label)
ax.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 0.5))
plt.title(f"Spearman correlation (expressing cells) {file_name}")
plt.tight_layout()
plt.show()
elif type == "confusion":
confusion_matrices = {}
for i, gene in enumerate(all_genes):
y = self.adata[:, gene].X.toarray().reshape(-1)
music_results_target = pred_vals[:, i]
predictions_binary = (music_results_target > 0).astype(int)
y_binary = (y > 0).astype(int)
confusion_matrices[gene] = confusion_matrix(y_binary, predictions_binary)
total_figs = int(math.ceil(len(all_genes) / n_genes_per_plot))
for fig_index in range(total_figs):
start_index = fig_index * n_genes_per_plot
end_index = min(start_index + n_genes_per_plot, len(all_genes))
genes_to_plot = all_genes[start_index:end_index]
                fig, axs = plt.subplots(1, len(genes_to_plot), figsize=(width, width / 5))
                # Ensure "axs" is always an iterable array, even when only a single gene is plotted:
                axs = np.atleast_1d(axs).flatten()
for i, gene in enumerate(genes_to_plot):
sns.heatmap(
confusion_matrices[gene],
annot=True,
fmt="d",
cmap="Blues",
ax=axs[i],
cbar=False,
xticklabels=["Predicted \nnot expressed", "Predicted \nexpressed"],
yticklabels=["Actual \nnot expressed", "Actual \nexpressed"],
)
axs[i].set_title(gene)
# Hide any unused subplots on the last figure if total genes don't fill up the grid
for j in range(len(genes_to_plot), len(axs)):
axs[j].axis("off")
plt.tight_layout()
# Save confusion matrices:
parent_dir = os.path.dirname(self.output_path)
plt.savefig(os.path.join(parent_dir, f"confusion_matrices_{fig_index}.png"), bbox_inches="tight")
elif type == "rmse":
rmse_dict = {}
nz_rmse_dict = {}
for i, gene in enumerate(all_genes):
y = self.adata[:, gene].X.toarray().reshape(-1)
music_results_target = pred_vals[:, i]
rmse_dict[gene] = np.sqrt(mean_squared_error(y, music_results_target))
# Indices where target is nonzero:
nonzero_indices = y != 0
nz_rmse_dict[gene] = np.sqrt(
mean_squared_error(y[nonzero_indices], music_results_target[nonzero_indices])
)
mean_rmse = sum(rmse_dict.values()) / len(rmse_dict.values())
mean_nz_rmse = sum(nz_rmse_dict.values()) / len(nz_rmse_dict.values())
data = []
for gene in rmse_dict.keys():
data.append({"Gene": gene, "RMSE": rmse_dict[gene], "RMSE (expressing cells)": mean_nz_rmse[gene]})
# Color palette:
colors = {"RMSE": "#FF7F00", "RMSE (expressing cells)": "#87CEEB"}
df = pd.DataFrame(data)
# Plot RMSE barplot:
sns.set(font_scale=2)
sns.set_style("white")
plt.figure(figsize=(width, 6))
plt.xticks(rotation="vertical")
            ax = sns.barplot(
                data=df,
                x="Gene",
                y="RMSE",
                color=colors["RMSE"],
                edgecolor="black",
                dodge=True,
            )
# Mean line:
line_style = "--"
line_thickness = 2
ax.axhline(mean_rmse, color="black", linestyle=line_style, linewidth=line_thickness)
# Update legend:
legend_label = f"Mean: {mean_rmse}"
handles, labels = ax.get_legend_handles_labels()
handles.append(plt.Line2D([0], [0], color="black", linewidth=line_thickness, linestyle=line_style))
labels.append(legend_label)
ax.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 0.5))
plt.title(f"RMSE {file_name}")
plt.tight_layout()
plt.show()
# Plot RMSE barplot (expressing cells):
plt.figure(figsize=(width, 6))
plt.xticks(rotation="vertical")
            ax = sns.barplot(
                data=df,
                x="Gene",
                y="RMSE (expressing cells)",
                color=colors["RMSE (expressing cells)"],
                edgecolor="black",
                dodge=True,
            )
# Mean line:
ax.axhline(mean_nz_rmse, color="black", linestyle=line_style, linewidth=line_thickness)
# Update legend:
legend_label = f"Mean: {mean_nz_rmse}"
handles, labels = ax.get_legend_handles_labels()
handles.append(plt.Line2D([0], [0], color="black", linewidth=line_thickness, linestyle=line_style))
labels.append(legend_label)
ax.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 0.5))
plt.title(f"RMSE (expressing cells) {file_name}")
plt.tight_layout()
plt.show()
    def plot_interaction_effect_3D(
self,
target: str,
interaction: str,
save_path: str,
pcutoff: Optional[float] = 99.7,
min_value: Optional[float] = 0,
zero_opacity: float = 1.0,
size: float = 2.0,
n_neighbors_smooth: Optional[int] = 0,
):
"""Quick-visualize the magnitude of the predicted effect on target for a given interaction.
Args:
target: Target gene to visualize
interaction: Interaction to visualize (e.g. "Igf1:Igf1r" for L:R model, "Igf1" for ligand model)
save_path: Path to save the figure to (will save as HTML file)
pcutoff: Percentile cutoff for the colorbar. Will set all values above this percentile to this value.
min_value: Minimum value to set the colorbar to. Will set all values below this value to this value.
Defaults to 0.
zero_opacity: Opacity of points with zero expression. Between 0.0 and 1.0. Default is 1.0.
size: Size of the points in the scatter plot. Default is 2.
n_neighbors_smooth: Number of neighbors to use for smoothing (to make effect patterns more apparent). If 0,
no smoothing is applied. Default is 0.
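        Example:
            A minimal sketch (gene names and save path are illustrative; assumes 3D spatial coordinates)::

                interpreter.plot_interaction_effect_3D(
                    target="Apoe",
                    interaction="Igf1:Igf1r",
                    save_path="./figures/Igf1_Igf1r_effect_on_Apoe.html",
                    n_neighbors_smooth=10,
                )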
"""
targets = pd.read_csv(
os.path.join(os.path.splitext(self.output_path)[0], "design_matrix", "targets.csv"), index_col=0
)
if target not in targets.columns:
raise ValueError(f"Target {target} not found in this model's directory. Please provide a valid target.")
if interaction not in self.X_df.columns:
raise ValueError(f"Interaction {interaction} not found in this model's directory.")
if hasattr(self, "remaining_cells"):
adata = self.adata[self.remaining_cells, :].copy()
else:
adata = self.adata.copy()
coords = adata.obsm[self.coords_key]
x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]
target_interaction_coef = self.coeffs[target].loc[adata.obs_names, f"b_{interaction}"]
self.logger.info(f"{(target_interaction_coef > 0).sum()} {target}-expressing cells affected by {interaction}.")
if n_neighbors_smooth > 0:
from scipy.spatial import cKDTree
tree = cKDTree(coords)
distances, indices = tree.query(coords, k=n_neighbors_smooth + 1)
smoothed_values = np.zeros(len(target_interaction_coef))
for i in range(len(smoothed_values)):
neighbor_indices = indices[i, 1:]
neighbor_coeffs = target_interaction_coef.iloc[neighbor_indices]
# Filter to keep only nonzero values
nonzero_neighbor_coeffs = neighbor_coeffs[neighbor_coeffs != 0]
# Proceed only if there are at least 5 nonzero neighbors
if len(nonzero_neighbor_coeffs) >= 5:
# Calculate the mean of the nonzero target interaction coefficients
smoothed_values[i] = np.mean(nonzero_neighbor_coeffs)
target_interaction_coef = pd.Series(smoothed_values, index=target_interaction_coef.index)
self.logger.info(f"{(target_interaction_coef > 0).sum()} {target} affected cells post-smoothing.")
# Lenient w/ the max value cutoff so that the colored dots are more distinct from black background
cutoff = np.percentile(target_interaction_coef.values, pcutoff)
if pcutoff == 0:
cutoff = np.percentile(target_interaction_coef.values, 99.9)
target_interaction_coef[target_interaction_coef > cutoff] = cutoff
target_interaction_coef[target_interaction_coef < min_value] = min_value
plot_vals = target_interaction_coef.values
# Separate data into zero and non-zero (keeping one zero with non-zeros)
is_zero = plot_vals == 0.0
if np.any(is_zero):
non_zeros = np.where(is_zero, 0, plot_vals)
# Select the first zero to keep
first_zero_idx = np.where(is_zero)[0][0]
# Temp- to get the correct indices of nonzeros
non_zeros[first_zero_idx] = 1
is_nonzero = non_zeros != 0
non_zeros[first_zero_idx] = 0
else:
is_nonzero = np.ones(len(plot_vals), dtype=bool)
# Two plots, one for the zeros and one for the nonzeros
scatter_effect = go.Scatter3d(
x=x[is_nonzero],
y=y[is_nonzero],
z=z[is_nonzero],
mode="markers",
marker=dict(
color=plot_vals[is_nonzero],
colorscale="Hot",
size=size,
colorbar=dict(
title=f"{interaction.title()} Effect on {target.title()}",
x=0.75,
titlefont=dict(size=24),
tickfont=dict(size=24),
),
),
showlegend=False,
)
# Plot zeros separately (if there are any):
scatter_zeros = None
if np.any(is_zero):
scatter_zeros = go.Scatter3d(
x=x[is_zero],
y=y[is_zero],
z=z[is_zero],
mode="markers",
marker=dict(
color="#000000", # Use zero values for color to match the scale
size=size,
opacity=zero_opacity,
),
showlegend=False,
)
fig = go.Figure(data=[scatter_effect])
if scatter_zeros is not None:
fig.add_trace(scatter_zeros)
title_dict = dict(
text=f"{interaction.title()} Effect on {target.title()}",
y=0.9,
yanchor="top",
x=0.5,
xanchor="center",
font=dict(size=36),
)
fig.update_layout(
scene=dict(
aspectmode="data",
xaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
yaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
zaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
),
margin=dict(l=0, r=0, b=0, t=50), # Adjust margins to minimize spacing
title=title_dict,
)
fig.write_html(save_path)
    def plot_multiple_interaction_effects_3D(
self, effects: List[str], save_path: str, include_combos_of_two: bool = False
):
"""Quick-visualize the magnitude of the predicted effect on target for a given interaction.
Args:
            effects: List of interaction-target pairs to visualize, each formatted "interaction-target" (e.g.
                "Igf1:Igf1r-Apoe" for an L:R model, "Igf1-Apoe" for a ligand model; gene names are illustrative)
save_path: Path to save the figure to (will save as HTML file)
include_combos_of_two: Whether to include paired combinations of effects (e.g. "Igf1:Igf1r and
Igf1:InsR") as separate categories. If False, will include these in the generic "Multiple interactions"
category.
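        Example:
            A minimal sketch for an L:R model (pairs are illustrative and follow the "interaction-target"
            format described above)::

                interpreter.plot_multiple_interaction_effects_3D(
                    effects=["Igf1:Igf1r-Apoe", "Igf1:InsR-Apoe"],
                    save_path="./figures/multiple_effects.html",
                )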
"""
if hasattr(self, "remaining_cells"):
adata = self.adata[self.remaining_cells, :].copy()
else:
adata = self.adata.copy()
coords = adata.obsm[self.coords_key]
x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]
mean_values = {}
adata.obs["interaction_categories"] = "Other"
for effect in effects:
interaction, target = effect.split(":")
if target not in self.coeffs.keys():
self.logger.info(
f"{target} not found in this model's directory. Skipping this interaction-target pair."
)
continue
if f"b_{interaction}" not in self.coeffs[target].columns:
self.logger.info(f"{interaction} not found for {target}. Skipping this interaction-target pair.")
continue
target_interaction_coef = self.coeffs[target].loc[adata.obs_names, f"b_{interaction}"]
mean_values[effect] = np.mean(target_interaction_coef[target_interaction_coef > 0])
adata.obs[f"{effect} nonzero"] = target_interaction_coef > 0
# Temporarily, the key labeled with the effect name stores whether the interaction is nonzero to a
# substantial degree:
adata.obs.loc[target_interaction_coef >= mean_values[effect], effect] = True
# Categorize cells based on their interaction effects
for idx, row in tqdm(adata.obs.iterrows(), total=len(adata.obs_names), desc="Categorizing cells..."):
active_effects = [effect for effect in effects if row[f"{effect} nonzero"]]
strong_active_effects = [effect for effect in effects if row[effect]]
if include_combos_of_two:
if len(strong_active_effects) >= 3:
adata.obs.loc[idx, "interaction_categories"] = "Multiple interactions"
elif len(strong_active_effects) == 2:
adata.obs.loc[
idx, "interaction_categories"
] = f"{strong_active_effects[0]} and {strong_active_effects[1]}"
elif len(active_effects) == 1:
adata.obs.loc[idx, "interaction_categories"] = active_effects[0]
else:
if len(strong_active_effects) >= 2:
adata.obs.loc[idx, "interaction_categories"] = "Multiple interactions"
elif len(active_effects) == 1:
adata.obs.loc[idx, "interaction_categories"] = active_effects[0]
cat_counts = adata.obs["interaction_categories"].value_counts()
# Map each category to color:
if include_combos_of_two:
color_mapping = dict(zip(cat_counts.index, godsnot_102))
else:
color_mapping = dict(zip(cat_counts.index, vega_10))
color_mapping["Multiple interactions"] = "#71797E"
color_mapping["Other"] = "#D3D3D3"
traces = []
for group, color in color_mapping.items():
marker_size = 1.25 if group == "Other" else 2
mask = adata.obs["interaction_categories"] == group
scatter = go.Scatter3d(
x=x[mask],
y=y[mask],
z=z[mask],
mode="markers",
marker=dict(size=marker_size, color=color),
showlegend=False,
)
traces.append(scatter)
# Invisible trace for the legend (so the colored point is larger than the plot points):
legend_target = go.Scatter3d(
x=[None],
y=[None],
z=[None],
mode="markers",
marker=dict(size=30, color=color), # Adjust size as needed
name=group,
showlegend=True,
)
traces.append(legend_target)
fig = go.Figure(data=traces)
title = (
"L:R Interaction Effect on Target (format Ligand:Receptor-Target)"
if self.mod_type == "lr"
else "Ligand Effect on Target (format Ligand-Target)"
)
title_dict = dict(
text=title,
y=0.9,
yanchor="top",
x=0.5,
xanchor="center",
font=dict(size=28),
)
fig.update_layout(
showlegend=True,
legend=dict(x=0.7, y=0.85, orientation="v", font=dict(size=14)),
scene=dict(
aspectmode="data",
xaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
yaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
zaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
),
margin=dict(l=0, r=0, b=0, t=50), # Adjust margins to minimize spacing
title=title_dict,
)
fig.write_html(save_path)
    def plot_tf_effect_3D(
self,
target: str,
tf: str,
save_path: str,
ligand_targets: bool = True,
receptor_targets: bool = False,
target_gene_targets: bool = False,
pcutoff: float = 99.7,
min_value: float = 0,
zero_opacity: float = 1.0,
size: float = 2.0,
):
"""Quick-visualize the magnitude of the predicted effect on target for a given TF. Can only find the files
necessary for this if :func `CCI_deg_detection()` has been run.
Args:
target: Target gene of interest
tf: TF of interest (e.g. "Foxo1")
save_path: Path to save the figure to (will save as HTML file)
ligand_targets: Set True if ligands were used as the target genes for the :func `CCI_deg_detection()`
model.
receptor_targets: Set True if receptors were used as the target genes for the :func `CCI_deg_detection()`
model.
target_gene_targets: Set True if target genes were used as the target genes for the :func
`CCI_deg_detection()` model.
pcutoff: Percentile cutoff for the colorbar. Will set all values above this percentile to this value.
min_value: Minimum value to set the colorbar to. Will set all values below this value to this value.
zero_opacity: Opacity of points with zero expression. Between 0.0 and 1.0. Default is 1.0.
size: Size of the points in the scatter plot. Default is 2.
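        Example:
            A minimal sketch (names and path are illustrative; assumes :func `CCI_deg_detection()` was run
            with ligands as the target genes)::

                interpreter.plot_tf_effect_3D(
                    target="Igf1",
                    tf="Foxo1",
                    save_path="./figures/Foxo1_effect_on_Igf1.html",
                    ligand_targets=True,
                )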
"""
downstream_parent_dir = os.path.dirname(os.path.splitext(self.output_path)[0])
        model_id = os.path.splitext(os.path.basename(self.output_path))[0]
if ligand_targets:
target_type = "ligand"
folder = "ligand_analysis"
elif receptor_targets:
target_type = "receptor"
folder = "receptor_analysis"
elif target_gene_targets:
target_type = "target_gene"
folder = "target_gene_analysis"
else:
raise ValueError(
"Please set either 'ligand_targets', 'receptor_targets', or 'target_gene_targets' to True."
)
targets = pd.read_csv(
os.path.join(
downstream_parent_dir,
"cci_deg_detection",
folder,
                model_id,
"downstream_design_matrix",
"targets.csv",
),
index_col=0,
)
regulators = pd.read_csv(
os.path.join(
downstream_parent_dir,
"cci_deg_detection",
folder,
                model_id,
"downstream_design_matrix",
"design_matrix.csv",
),
index_col=0,
)
regulators.columns = [col.replace("regulator_", "") for col in regulators.columns]
if target not in targets.columns:
raise ValueError(f"Target {target} not found in this model's directory. Please provide a valid target.")
if tf not in regulators.columns:
raise ValueError(f"TF {tf} not found in this model's directory.")
if hasattr(self, "remaining_cells"):
adata = self.adata[self.remaining_cells, :].copy()
else:
adata = self.adata.copy()
coords = adata.obsm[self.coords_key]
x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]
downstream_coeffs, downstream_standard_errors = self.return_outputs(
adjust_for_subsampling=False, load_from_downstream=target_type, load_for_interpreter=True
)
target_tf_coef = downstream_coeffs[target].loc[adata.obs_names, f"b_{tf}"]
# Lenient w/ the max value cutoff so that the colored dots are more distinct from black background
cutoff = np.percentile(target_tf_coef.values, pcutoff)
target_tf_coef[target_tf_coef > cutoff] = cutoff
target_tf_coef[target_tf_coef < min_value] = min_value
plot_vals = target_tf_coef.values
# Separate data into zero and non-zero (keeping one zero with non-zeros)
is_zero = plot_vals == 0
if np.any(is_zero):
non_zeros = np.where(is_zero, 0, plot_vals)
# Select the first zero to keep
first_zero_idx = np.where(is_zero)[0][0]
# Temp- to get the correct indices of nonzeros
non_zeros[first_zero_idx] = 1
is_nonzero = non_zeros != 0
non_zeros[first_zero_idx] = 0
else:
is_nonzero = np.ones(len(plot_vals), dtype=bool)
# Two plots, one for the zeros and one for the nonzeros
scatter_effect = go.Scatter3d(
x=x[is_nonzero],
y=y[is_nonzero],
z=z[is_nonzero],
mode="markers",
marker=dict(
color=plot_vals[is_nonzero],
colorscale="Hot",
size=size,
colorbar=dict(
title=f"{tf.title()} Effect on {target.title()}",
x=0.75,
titlefont=dict(size=24),
tickfont=dict(size=24),
),
),
showlegend=False,
)
# Plot zeros separately (if there are any):
scatter_zeros = None
if np.any(is_zero):
scatter_zeros = go.Scatter3d(
x=x[is_zero],
y=y[is_zero],
z=z[is_zero],
mode="markers",
marker=dict(
color="#000000", # Use zero values for color to match the scale
size=size,
opacity=zero_opacity,
),
showlegend=False,
)
fig = go.Figure(data=[scatter_effect])
if scatter_zeros is not None:
fig.add_trace(scatter_zeros)
title_dict = dict(
text=f"{tf.title()} Effect on {target.title()}",
y=0.9,
yanchor="top",
x=0.5,
xanchor="center",
font=dict(size=36),
)
fig.update_layout(
scene=dict(
aspectmode="data",
xaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
yaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
zaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
),
margin=dict(l=0, r=0, b=0, t=50), # Adjust margins to minimize spacing
title=title_dict,
)
fig.write_html(save_path)
    def visualize_overlap_between_interacting_components_3D(
self, target: str, interaction: str, save_path: str, size: float = 2.0
):
"""Visualize the spatial distribution of signaling features (ligand, receptor, or L:R field) and target gene,
as well as the overlapping region. Intended for use with 3D spatial coordinates.
Args:
target: Target gene to visualize
interaction: Interaction to visualize (e.g. "Igf1:Igf1r" for L:R model, "Igf1" for ligand model)
save_path: Path to save the figure to (will save as HTML file)
size: Size of the points in the plot. Defaults to 2.
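        Example:
            A minimal sketch for an L:R model (gene names and save path are illustrative)::

                interpreter.visualize_overlap_between_interacting_components_3D(
                    target="Apoe", interaction="Igf1:Igf1r", save_path="./figures/Igf1_Igf1r_Apoe_overlap.html"
                )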
"""
        from ...plotting.static.colorlabel import godsnot_102

        # Work on a copy so the shared module-level palette is not mutated; rearrange slightly:
        godsnot_102 = list(godsnot_102)
        godsnot_102[1] = "#B200ED"
        godsnot_102[2] = "#FFA500"
        godsnot_102[3] = "#1CE6FF"
targets = pd.read_csv(
os.path.join(os.path.splitext(self.output_path)[0], "design_matrix", "targets.csv"), index_col=0
)
if target not in targets.columns:
raise ValueError(f"Target {target} not found in this model's directory. Please provide a valid target.")
if interaction not in self.X_df.columns:
raise ValueError(f"Interaction {interaction} not found in this model's directory.")
if hasattr(self, "remaining_cells"):
adata = self.adata[self.remaining_cells, :].copy()
else:
adata = self.adata.copy()
coords = adata.obsm[self.coords_key]
x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]
# Label cells expressing target:
target_expressing = adata.obs_names[adata[:, target].X.toarray().reshape(-1) != 0]
        # Label cells with nonzero interaction feature value (for the ligand model, cells that have expression
        # of the ligand in their neighborhood (in addition to other caveats incorporated in model setup); for
        # the L:R model, cells that have expression of the ligand in their neighborhood and expression of the
        # receptor):
interaction_expressing = self.X_df[self.X_df[interaction] != 0].index
# Label cells expressing target and with nonzero interaction feature value:
overlap = target_expressing.intersection(interaction_expressing)
adata.obs[f"{interaction}_{target}"] = "Other"
adata.obs.loc[
target_expressing, f"{interaction}_{target}"
] = f"{target} only (no {interaction} in neighborhood and/or receptor)"
if self.mod_type == "lr":
ligand, receptor = interaction.split(":")
adata.obs.loc[
interaction_expressing, f"{interaction}_{target}"
] = f"{ligand.title()} in Neighborhood and {receptor}, no {target}"
adata.obs.loc[
overlap, f"{interaction}_{target}"
] = f"{ligand.title()} in Neighborhood, {receptor} and {target}"
elif self.mod_type == "ligand":
adata.obs.loc[
interaction_expressing, f"{interaction}_{target}"
] = f"{interaction.title()} in Neighborhood and Receptor, no {target}"
adata.obs.loc[
overlap, f"{interaction}_{target}"
] = f"{interaction.title()} in Neighborhood, Receptor and {target}"
color_mapping = dict(zip(adata.obs[f"{interaction}_{target}"].value_counts().index, godsnot_102))
color_mapping["Other"] = "#D3D3D3"
traces = []
for group, color in color_mapping.items():
marker_size = size * 0.75 if group == "Other" else size
opacity = 0.5 if group == "Other" else 1.0
mask = adata.obs[f"{interaction}_{target}"] == group
scatter = go.Scatter3d(
x=x[mask],
y=y[mask],
z=z[mask],
mode="markers",
marker=dict(size=marker_size, color=color, opacity=opacity),
showlegend=False,
)
traces.append(scatter)
# Invisible trace for the legend (so the colored point is larger than the plot points):
legend_target = go.Scatter3d(
x=[None],
y=[None],
z=[None],
mode="markers",
marker=dict(size=30, color=color), # Adjust size as needed
name=group,
showlegend=True,
)
traces.append(legend_target)
fig = go.Figure(data=traces)
if self.mod_type == "lr":
title = f"Distribution of interacting components: <br>{interaction} and {target}"
elif self.mod_type == "ligand":
title = (
f"Distribution of interacting components: <br>{interaction}, {interaction} receptor/downstream "
f"components and {target}"
)
title_dict = dict(
text=title,
y=0.9,
yanchor="top",
x=0.5,
xanchor="center",
font=dict(size=36),
)
fig.update_layout(
showlegend=True,
legend=dict(x=0.65, y=0.85, orientation="v", font=dict(size=18)),
scene=dict(
aspectmode="data",
xaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
yaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
zaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
),
margin=dict(l=0, r=0, b=0, t=50), # Adjust margins to minimize spacing
title=title_dict,
)
fig.write_html(save_path)
    def gene_expression_heatmap(
self,
use_ligands: bool = False,
use_receptors: bool = False,
use_target_genes: bool = False,
genes: Optional[List[str]] = None,
position_key: str = "spatial",
coord_column: Optional[Union[int, str]] = None,
reprocess: bool = False,
neatly_arrange_y: bool = True,
window_size: int = 3,
recompute: bool = False,
title: Optional[str] = None,
fontsize: Union[None, int] = None,
figsize: Union[None, Tuple[float, float]] = None,
cmap: str = "magma",
save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
save_kwargs: Optional[dict] = {},
):
"""Visualize the distribution of gene expression across cells in the spatial coordinates of cells; provides
an idea of the simultaneous relative positions/patternings of different genes.
Args:
use_ligands: Set True to use ligands as the genes to visualize. If True, will ignore "genes" argument.
"ligands_expr" file must be present in the model's directory.
use_receptors: Set True to use receptors as the genes to visualize. If True, will ignore "genes" argument.
"receptors_expr" file must be present in the model's directory.
use_target_genes: Set True to use target genes as the genes to visualize. If True, will ignore "genes"
argument. "targets" file must be present in the model's directory.
genes: Optional list of genes to visualize. If "use_ligands", "use_receptors", and "use_target_genes" are
all False, this must be given. This can also be used to visualize only a subset of the genes once
processing & saving has already completed using e.g. "use_ligands", "use_receptors", etc.
position_key: Key in adata.obs or adata.obsm that provides a relative indication of the position of
cells. i.e. spatial coordinates. Defaults to "spatial". For each value in the position array (each
coordinate, each category), multiple cells must have the same value.
coord_column: Optional, only used if "position_key" points to an entry in .obsm. In this case,
this is the index or name of the column to be used to provide the positional context. Can also
provide "xy", "yz", "xz", "-xy", "-yz", "-xz" to draw a line between the two coordinate axes. "xy"
will extend the new axis in the direction of increasing x and increasing y starting from x=0 and y=0 (or
min. x/min. y), "-xy" will extend the new axis in the direction of decreasing x and increasing y
starting from x=minimum x and y=maximum y, and so on.
reprocess: Set to True to reprocess the data and overwrite the existing files. Use if the genes to
visualize have changed compared to the saved file (if existing), e.g. if "use_ligands" is True when
the initial analysis used "use_target_genes".
            neatly_arrange_y: Set True to order the y-axis in terms of how early along the position axis the max
                z-scores for each row occur. Used for a more uniform plot where similarly patterned genes are
                grouped together. If False, will sort this axis by gene identity (i.e. keeping the order given
                by the "genes" argument).
window_size: Size of window to use for smoothing. Must be an odd integer. If 1, no smoothing is applied.
recompute: Set to True to recompute the data and overwrite the existing files
title: Optional, can be used to provide title for plot
fontsize: Size of font for x and y labels.
figsize: Size of figure.
cmap: Colormap to use. Options: Any divergent matplotlib colormap.
save_show_or_return: Whether to save, show or return the figure.
If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
displayed and the associated axis and other object will be return.
            save_kwargs: A dictionary that will be passed to the save_fig function.
By default it is an empty dictionary and the save_fig function will use the
{"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True,
"verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those
keys according to your needs.
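        Example:
            An illustrative call, plotting ligand expression along the second spatial coordinate (argument
            values are illustrative)::

                interpreter.gene_expression_heatmap(
                    use_ligands=True,
                    position_key="spatial",
                    coord_column=1,
                    window_size=5,
                    save_show_or_return="show",
                )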
"""
logger = lm.get_main_logger()
if window_size % 2 == 0:
raise ValueError("Window size must be an odd integer.")
if not use_ligands and not use_receptors and not use_target_genes and genes is None:
raise ValueError(
"Please set either 'use_ligands', 'use_receptors', or 'use_target_genes' to True, or provide a list "
"of genes to visualize."
)
# Check if custom genes are given:
custom_genes = genes
if use_ligands:
if not os.path.exists(
os.path.join(os.path.splitext(self.output_path)[0], "design_matrix", "ligands_expr.csv")
):
raise FileNotFoundError("ligands_expr.csv not found in this model's directory.")
expr_df = pd.read_csv(
os.path.join(os.path.splitext(self.output_path)[0], "design_matrix", "ligands_expr.csv"), index_col=0
)
genes = expr_df.columns
            genes = [
                g
                for g in genes
                if g
                not in [
                    "Lta4h",
                    "Fdx1",
                    "Tfrc",
                    "Trf",
                    "Lamc1",
                    "Aldh1a2",
                    "Dhcr24",
                    "Rnaset2a",
                    "Ptges3",
                    "Nampt",
                    "Kdr",
                    "Apoa2",
                    "Apoe",
                    "Dhcr7",
                    "Enho",
                    "Ptgr1",
                    "Agrp",
                    "Akr1b3",
                    "Daglb",
                    "Ubash3d",
                ]
            ]
file_id = "ligand_expression"
elif use_receptors:
if not os.path.exists(
os.path.join(os.path.splitext(self.output_path)[0], "design_matrix", "receptors_expr.csv")
):
raise FileNotFoundError("receptors_expr.csv not found in this model's directory.")
expr_df = pd.read_csv(
os.path.join(os.path.splitext(self.output_path)[0], "design_matrix", "receptors_expr.csv"), index_col=0
)
genes = expr_df.columns
file_id = "receptor_expression"
elif use_target_genes:
if not os.path.exists(os.path.join(os.path.splitext(self.output_path)[0], "design_matrix", "targets.csv")):
raise FileNotFoundError("targets.csv not found in this model's directory.")
expr_df = pd.read_csv(
os.path.join(os.path.splitext(self.output_path)[0], "design_matrix", "targets.csv"), index_col=0
)
genes = expr_df.columns
file_id = "target_gene_expression"
else:
expr_df = pd.DataFrame(self.adata[:, genes].X.toarray(), index=self.adata.obs_names, columns=genes)
file_id = "expression"
if hasattr(self, "remaining_cells"):
adata = self.adata[self.remaining_cells, :].copy()
else:
adata = self.adata.copy()
if position_key not in self.adata.obsm.keys() and position_key not in self.adata.obs.keys():
raise ValueError(
f"Position key {position_key} not found in adata.obsm or adata.obs. Please provide a valid key."
)
if position_key in self.adata.obsm.keys():
if coord_column in ["xy", "yz", "xz", "-xy", "-yz", "-xz"]:
self.adata = create_new_coordinate(self.adata, position_key, coord_column)
pos = self.adata.obs[f"{coord_column} Coordinate"]
x_label = f"Relative position along custom {coord_column} axis"
if title is None:
title = f"Signaling effect distribution along {coord_column} axis"
save_id = f"{coord_column}_axis"
else:
                if coord_column is not None and isinstance(coord_column, str):
                    if not isinstance(self.adata.obsm[position_key], pd.DataFrame):
                        raise ValueError(
                            f"Array stored at position key {position_key} has no column names; provide the "
                            f"column index."
                        )
                    else:
                        pos = self.adata.obsm[position_key][coord_column]
                        x_label = f"Relative position along {coord_column}"
                        if title is None:
                            title = f"Gene expression distribution along {coord_column}"
                        save_id = coord_column
elif coord_column is not None and isinstance(coord_column, int):
if isinstance(self.adata.obsm[position_key], pd.DataFrame):
pos = self.adata.obsm[position_key].iloc[:, coord_column]
x_label = f"Relative position along {coord_column}"
if title is None:
title = f"Signaling effect distribution along {coord_column}"
save_id = coord_column
else:
pos = pd.Series(self.adata.obsm[position_key][:, coord_column], index=self.adata.obs_names)
if coord_column == 0:
x_label = "Relative position along X"
if title is None:
title = "Signaling effect distribution along X"
save_id = "x_axis"
elif coord_column == 1:
x_label = "Relative position along Y"
if title is None:
title = "Signaling effect distribution along Y"
save_id = "y_axis"
elif coord_column == 2:
x_label = "Relative position along Z"
if title is None:
title = "Signaling effect distribution along Z"
save_id = "z_axis"
elif self.adata.obsm[position_key].shape[1] != 1:
raise ValueError(
f"Array stored at position key {position_key} has more than one column; provide the column "
f"index."
)
else:
pos = (
pd.Series(self.adata.obsm[position_key].flatten(), index=self.adata.obs_names)
if isinstance(self.adata.obsm[position_key], np.ndarray)
else self.adata.obsm[position_key]
)
x_label = "Relative position"
if title is None:
title = f"Signaling effect distribution along axis given by {position_key} key"
save_id = position_key
else:
pos = self.adata.obs[position_key]
x_label = "Relative position"
if title is None:
title = f"Signaling effect distribution along axis given by {position_key} key"
save_id = position_key
# If position array is numerical, there may not be an exact match- convert the data type to integer:
if pos.dtype == float:
pos = pos.astype(int)
if save_show_or_return in ["save", "both", "all"]:
if not os.path.exists(os.path.join(os.path.dirname(self.output_path), "figures")):
os.makedirs(os.path.join(os.path.dirname(self.output_path), "figures"))
figure_folder = os.path.join(os.path.dirname(self.output_path), "figures", "temp")
if not os.path.exists(figure_folder):
os.makedirs(figure_folder)
output_folder = os.path.join(os.path.dirname(self.output_path), "analyses")
if not os.path.exists(output_folder):
os.makedirs(output_folder)
# Use the saved name for the AnnData object to define part of the name of the saved file:
base_name = os.path.basename(self.adata_path)
adata_id = os.path.splitext(base_name)[0]
# If divergent colormap is specified, center the colormap at 0:
divergent_cmaps = [
"seismic",
"coolwarm",
"bwr",
"RdBu",
"RdGy",
"PuOr",
"PiYG",
"PRGn",
"BrBG",
"RdYlBu",
"RdYlGn",
"Spectral",
]
# Check for existing dataframe:
if (
os.path.exists(os.path.join(output_folder, f"{adata_id}_distribution_{file_id}_along_{save_id}.csv"))
and not recompute
):
to_plot = pd.read_csv(
os.path.join(output_folder, f"{adata_id}_distribution_{file_id}_along_{save_id}.csv"),
index_col=0,
)
# Can plot a subset once this is already processed & saved:
if custom_genes is not None:
custom_genes = [g for g in custom_genes if g in to_plot.index]
to_plot = to_plot.loc[custom_genes]
else:
# For each gene, compute the mean expression:
mean_expr = pd.Series(index=genes)
for g in genes:
mean_expr[g] = expr_df[g].mean()
# For each cell, compute the fold change over the average for each combination:
all_fc = pd.DataFrame(index=self.adata.obs_names, columns=genes)
for g in tqdm(genes, desc="Computing fold changes for each gene..."):
g_expr = expr_df[g]
all_fc[g] = g_expr / mean_expr[g]
# Log fold change:
all_fc = np.log1p(all_fc)
# z-score the fold change values:
all_fc = all_fc.apply(scipy.stats.zscore, axis=0)
all_fc["pos"] = pos
all_fc_coord_sorted = all_fc.sort_values(by="pos")
# Mean z-score at each coordinate position:
all_fc_coord_sorted = all_fc_coord_sorted.groupby("pos").mean()
# Smooth in the case of dropouts:
all_fc_coord_sorted = all_fc_coord_sorted.rolling(window_size, center=True, min_periods=1).mean()
# For each unique value in 'pos', find the top genes with the highest mean z-score
top_genes = all_fc_coord_sorted.apply(lambda x: x.nlargest(30).index.tolist(), axis=1)
            # Find interesting genes by position- get genes that are among the top genes for at least five
            # consecutive positions:
            consecutive_counts = {g: 0 for g in genes}
            genes_of_interest = set()
            for p in top_genes.index:
                for g in top_genes[p]:
                    consecutive_counts[g] += 1
                    if consecutive_counts[g] >= 5:
                        genes_of_interest.add(g)
                for g in genes:
                    if g not in top_genes[p]:
                        consecutive_counts[g] = 0
genes_of_interest = list(genes_of_interest)
to_plot = all_fc_coord_sorted[genes_of_interest]
if pd.api.types.is_numeric_dtype(to_plot.index):
# Minmax scale to normalize positional context:
to_plot.index = (to_plot.index - to_plot.index.min()) / (to_plot.index.max() - to_plot.index.min())
to_plot = to_plot.T # so that the features are labeled along the y-axis
# Sort by "heat" if applicable (i.e. in order roughly determined by how early along the relative position
# the highest z-scores occur in for each interaction-target pair):
if neatly_arrange_y:
logger.info("Sorting by position of enrichment along axis...")
column_indices = np.tile(np.arange(len(to_plot.columns)), (len(to_plot), 1)) # Column indices array
# Look only at the indices corresponding to the highest changes:
percentile_95 = to_plot.apply(
lambda row: np.percentile(row[row > 0], 95) if row[row > 0].size > 0 else 0, axis=1
)
# Create a DataFrame that replicates the shape of to_plot
weights_matrix = to_plot.gt(percentile_95, axis=0) * to_plot
weighted_sum = np.sum(weights_matrix.values * column_indices, axis=1)
total_weight = np.sum(weights_matrix.values, axis=1)
weighted_avg = pd.Series(np.where(total_weight != 0, weighted_sum / total_weight, 0), index=to_plot.index)
top_cols_sorted = weighted_avg.sort_values().index
to_plot = to_plot.loc[top_cols_sorted]
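# Net effect of the weighting above: rows whose top-5% z-scores fall at earlier column positions
# receive smaller weighted averages and sort toward the top of the heatmap, while rows enriched
# later along the axis sort toward the bottom.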
flattened = to_plot.values.flatten()
flattened_series = pd.Series(flattened)
percentile_95 = flattened_series.quantile(0.95)
max_val = percentile_95
if figsize is None:
m = len(to_plot) * 40 / 200
n = 8
figsize = (n, m)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
if fontsize is None:
fontsize = rcParams.get("font.size")
# Format the numerical columns for the plot:
# First check whether the columns contain duplicates:
if all(isinstance(name, str) for name in to_plot.columns):
if any([name.count(".") > 1 for name in to_plot.columns]):
to_plot.columns = [".".join(name.split(".")[:2]) for name in to_plot.columns]
to_plot.columns = [float(col) for col in to_plot.columns]
col_series = pd.Series(to_plot.columns)
if len(set(col_series)) != len(col_series):
unique_values, counts = np.unique(col_series, return_counts=True)
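# Spread each run of duplicated labels evenly up to (but not including) the next unique value;
# e.g. a hypothetical [10, 10, 10, 13] becomes [10.0, 11.0, 12.0, 13] (step = (13 - 10) / 3 = 1).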
# Iterate through unique values
for value, count in zip(unique_values, counts):
if count > 1:
# Find indices of the repeated value
indices = col_series[col_series == value].index
# Calculate step size
if value == unique_values[-1]:
next_value = value + (value - unique_values[-2])
else:
next_index = np.where(unique_values == value)[0][0] + 1
next_value = unique_values[next_index]
step = (next_value - value) / count
# Update the values
for i in range(count):
col_series.iloc[indices[i]] = value + step * i
to_plot.columns = col_series.values
if all(isinstance(name, float) for name in to_plot.columns):
to_plot.columns = [f"{float(col):.3f}" for col in to_plot.columns]
to_plot.columns = [str(col) for col in to_plot.columns]
if genes is not None and not neatly_arrange_y:
to_plot = to_plot.reindex([g for g in genes if g in to_plot.index])
m = sns.heatmap(to_plot, vmin=-max_val, vmax=max_val, ax=ax, cmap=cmap)
cbar = m.collections[0].colorbar
cbar.set_label("Z-score", fontsize=fontsize * 1.5, labelpad=10)
# Adjust colorbar tick font size
cbar.ax.tick_params(labelsize=fontsize * 1.25)
cbar.ax.set_aspect(np.min([len(to_plot), 70]))
ax.set_xlabel(x_label, fontsize=fontsize * 1.25)
ax.set_ylabel("Gene", fontsize=fontsize * 1.25)
ax.tick_params(axis="x", labelsize=fontsize)
ax.tick_params(axis="y", labelsize=fontsize)
ax.set_title(title, fontsize=fontsize * 1.5, pad=20)
if (
not os.path.exists(os.path.join(output_folder, f"{adata_id}_distribution_{file_id}_along_{save_id}.csv"))
or recompute
):
to_plot.to_csv(
os.path.join(
output_folder,
f"{adata_id}_distribution_{file_id}_along_{save_id}.csv",
)
)
if save_show_or_return in ["save", "both", "all"]:
save_kwargs["ext"] = "png"
save_kwargs["dpi"] = 300
if "figure_folder" in locals():
save_kwargs["path"] = figure_folder
# Save figure:
save_return_show_fig_utils(
save_show_or_return=save_show_or_return,
show_legend=False,
background="white",
prefix=f"distribution_{file_id}_along_{save_id}",
save_kwargs=save_kwargs,
total_panels=1,
fig=fig,
axes=ax,
return_all=False,
return_all_list=None,
)
[docs] def effect_distribution_heatmap(
self,
target_subset: Optional[List[str]] = None,
interaction_subset: Optional[List[str]] = None,
position_key: str = "spatial",
coord_column: Optional[Union[int, str]] = None,
effect_threshold: Optional[float] = None,
check_downstream_ligand_effects: bool = False,
check_downstream_receptor_effects: bool = False,
check_downstream_target_effects: bool = False,
use_significant: bool = False,
sort_by_target: bool = False,
neatly_arrange_y: bool = True,
window_size: int = 3,
recompute: bool = False,
title: Optional[str] = None,
fontsize: Union[None, int] = None,
figsize: Union[None, Tuple[float, float]] = None,
cmap: str = "magma",
save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
save_kwargs: Optional[dict] = {},
):
"""Visualize the distribution of interaction effects across cells in the spatial coordinates of cells;
provides an idea of the simultaneous relative positions of different interaction effects.
Args:
target_subset: List of targets to consider. If None, will use all targets used in model fitting.
interaction_subset: List of interactions to consider. If None, will use all interactions used in model.
position_key: Key in adata.obs or adata.obsm that provides a relative indication of the position of
cells. i.e. spatial coordinates. Defaults to "spatial". For each value in the position array (each
coordinate, each category), multiple cells must have the same value.
coord_column: Optional, only used if "position_key" points to an entry in .obsm. In this case,
this is the index or name of the column to be used to provide the positional context. Can also
provide "xy", "yz", "xz", "-xy", "-yz", "-xz" to draw a line between the two coordinate axes. "xy"
will extend the new axis in the direction of increasing x and increasing y starting from x=0 and y=0 (or
min. x/min. y), "-xy" will extend the new axis in the direction of decreasing x and increasing y
starting from x=minimum x and y=maximum y, and so on.
effect_threshold: Optional threshold minimum effect size to consider an effect for further analysis,
as an absolute value. Use this to choose only the cells for which an interaction is predicted to
have a strong effect. If None, no thresholding is applied.
check_downstream_ligand_effects: Set True to check the coefficients of downstream ligand models instead of
coefficients of the upstream CCI model. Note that this may not necessarily look nice because
TF-target relationships are not spatially dependent like L:R effects are.
check_downstream_receptor_effects: Set True to check the coefficients of downstream receptor models instead
of coefficients of the upstream CCI model. Note that this may not necessarily look nice because
TF-target relationships are not spatially dependent like L:R effects are.
check_downstream_target_effects: Set True to check the coefficients of downstream target models instead of
coefficients of the upstream CCI model. Note that this may not necessarily look nice because
TF-target relationships are not spatially dependent like L:R effects are.
use_significant: Whether to use only significant effects in computing the specificity. If True,
will filter to cells + interactions where the interaction is significant for the target. Only valid
if :func `compute_coeff_significance()` has been run.
sort_by_target: Set True to order the y-axis in terms of the identity of the target gene. Incompatible
with "neatly_arrange_y". If both this and "neatly_arrange_y" are False, will sort this axis by the
identity of the interaction (i.e. all "Fgf1" rows will be grouped together).
neatly_arrange_y: Set True to order the y-axis by how early along the position axis the max
z-scores for each row occur. Used for a more uniform plot where similarly patterned
interaction-target pairs are grouped together. If False, will sort this axis by the identity of the
interaction (i.e. all "Fgf1" rows will be grouped together).
window_size: Size of window to use for smoothing. Must be an odd integer. If 1, no smoothing is applied.
recompute: Set to True to recompute the data and overwrite the existing files
title: Optional, can be used to provide title for plot
fontsize: Size of font for x and y labels.
figsize: Size of figure.
cmap: Colormap to use. Options: Any divergent matplotlib colormap.
save_show_or_return: Whether to save, show or return the figure.
If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
displayed and the associated axis and other object will be return.
save_kwargs: A dictionary that will passed to the save_fig function.
By default it is an empty dictionary and the save_fig function will use the
{"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True,
"verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those
keys according to your needs.
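Example:
A minimal, illustrative sketch. The file paths and the "lr" model type here are hypothetical,
and the model is assumed to already be fit; :func `define_spateo_argparse` is assumed to return
a (parser, args_list) pair as described in the class docstring:
>>> parser, args_list = define_spateo_argparse(
...     adata_path="./adata.h5ad", mod_type="lr", output_path="./outputs/results.csv"
... )
>>> interpreter = MuSIC_Interpreter(parser, args_list)
>>> interpreter.effect_distribution_heatmap(position_key="spatial", coord_column=0, window_size=3)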
"""
logger = lm.get_main_logger()
if window_size % 2 == 0:
raise ValueError("Window size must be an odd integer.")
if position_key not in self.adata.obsm.keys() and position_key not in self.adata.obs.keys():
raise ValueError(
f"Position key {position_key} not found in adata.obsm or adata.obs. Please provide a valid key."
)
if position_key in self.adata.obsm.keys():
if coord_column in ["xy", "yz", "xz", "-xy", "-yz", "-xz"]:
self.adata = create_new_coordinate(self.adata, position_key, coord_column)
pos = self.adata.obs[f"{coord_column} Coordinate"]
x_label = f"Relative position along custom {coord_column} axis"
if title is None:
title = f"Signaling effect distribution along {coord_column} axis"
save_id = f"{coord_column}_axis"
else:
if coord_column is not None and isinstance(coord_column, str):
if not isinstance(self.adata.obsm[position_key], pd.DataFrame):
raise ValueError(
f"Array stored at position key {position_key} has no column names; provide the column "
f"index."
)
else:
pos = self.adata.obsm[position_key][coord_column]
elif coord_column is not None and isinstance(coord_column, int):
if isinstance(self.adata.obsm[position_key], pd.DataFrame):
pos = self.adata.obsm[position_key].iloc[:, coord_column]
x_label = f"Relative position along {coord_column}"
if title is None:
title = f"Signaling effect distribution along {coord_column}"
save_id = coord_column
else:
pos = pd.Series(self.adata.obsm[position_key][:, coord_column], index=self.adata.obs_names)
if coord_column == 0:
x_label = "Relative position along X"
if title is None:
title = "Signaling effect distribution along X"
save_id = "x_axis"
elif coord_column == 1:
x_label = "Relative position along Y"
if title is None:
title = "Signaling effect distribution along Y"
save_id = "y_axis"
elif coord_column == 2:
x_label = "Relative position along Z"
if title is None:
title = "Signaling effect distribution along Z"
save_id = "z_axis"
elif self.adata.obsm[position_key].shape[1] != 1:
raise ValueError(
f"Array stored at position key {position_key} has more than one column; provide the column "
f"index."
)
else:
pos = (
pd.Series(self.adata.obsm[position_key].flatten(), index=self.adata.obs_names)
if isinstance(self.adata.obsm[position_key], np.ndarray)
else self.adata.obsm[position_key]
)
x_label = "Relative position"
if title is None:
title = f"Signaling effect distribution along axis given by {position_key} key"
save_id = position_key
else:
pos = self.adata.obs[position_key]
x_label = "Relative position"
if title is None:
title = f"Signaling effect distribution along axis given by {position_key} key"
save_id = position_key
# To ensure each coordinate has enough samples, bin positions by rounding: to the nearest 10 if the
# maximum value is below 1000, and to the nearest 100 otherwise.
max_value = pos.max()
rounding_base = 10 if max_value < 1000 else 100
pos = pos.round(-int(np.log10(rounding_base)))
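# e.g. with rounding_base = 10, hypothetical positions 95.8 and 104.2 both bin to 100, so that
# multiple cells share each rounded coordinate value.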
# If the position array is numerical, there may not be an exact match; convert the data type to integer:
if pos.dtype == float:
pos = pos.astype(int)
if save_show_or_return in ["save", "both", "all"]:
if not os.path.exists(os.path.join(os.path.dirname(self.output_path), "figures")):
os.makedirs(os.path.join(os.path.dirname(self.output_path), "figures"))
figure_folder = os.path.join(os.path.dirname(self.output_path), "figures", "temp")
if not os.path.exists(figure_folder):
os.makedirs(figure_folder)
output_folder = os.path.join(os.path.dirname(self.output_path), "analyses")
if not os.path.exists(output_folder):
os.makedirs(output_folder)
# Use the saved name for the AnnData object to define part of the name of the saved file:
base_name = os.path.basename(self.adata_path)
adata_id = os.path.splitext(base_name)[0]
# If divergent colormap is specified, center the colormap at 0:
divergent_cmaps = [
"seismic",
"coolwarm",
"bwr",
"RdBu",
"RdGy",
"PuOr",
"PiYG",
"PRGn",
"BrBG",
"RdYlBu",
"RdYlGn",
"Spectral",
]
# Check for existing dataframe:
if check_downstream_ligand_effects:
logger.info(
"Checking downstream TF-ligand effects...note that this may not look very nice because TF-target "
"relationships are not spatially dependent like L:R effects are."
)
df_path = os.path.join(
output_folder, f"{adata_id}_distribution_downstream_ligand_effects_along_{save_id}.csv"
)
elif check_downstream_receptor_effects:
logger.info(
"Checking downstream TF-receptor effects...note that this may not look very nice because TF-target "
"relationships are not spatially dependent like L:R effects are."
)
df_path = os.path.join(
output_folder, f"{adata_id}_distribution_downstream_receptor_effects_along_{save_id}.csv"
)
elif check_downstream_target_effects:
logger.info(
"Checking downstream TF-target effects...note that this may not look very nice because TF-target "
"relationships are not spatially dependent like L:R effects are."
)
df_path = os.path.join(
output_folder, f"{adata_id}_distribution_downstream_target_effects_along_{save_id}.csv"
)
else:
df_path = os.path.join(output_folder, f"{adata_id}_distribution_interaction_effects_along_{save_id}.csv")
if os.path.exists(df_path) and not recompute:
to_plot = pd.read_csv(df_path, index_col=0)
if interaction_subset is not None:
# Saved row labels have the form "target;interaction", with ":" in interaction names replaced by "-":
interaction_subset_fmt = [i.replace(":", "-") for i in interaction_subset]
selected_interactions = [i for i in to_plot.index if i.split(";")[1] in interaction_subset_fmt]
to_plot = to_plot.loc[selected_interactions]
if target_subset is not None:
selected_targets = [t for t in to_plot.index if t.split(";")[0] in target_subset]
to_plot = to_plot.loc[selected_targets]
else:
# Determine where to look for coefficients:
if check_downstream_ligand_effects:
all_coeffs = self.downstream_model_ligand_coeffs.copy()
if len(all_coeffs) == 0:
raise ValueError("No downstream model results found for ligands.")
elif check_downstream_receptor_effects:
all_coeffs = self.downstream_model_receptor_coeffs.copy()
if len(all_coeffs) == 0:
raise ValueError("No downstream model results found for receptors.")
elif check_downstream_target_effects:
all_coeffs = self.downstream_model_target_coeffs.copy()
if len(all_coeffs) == 0:
raise ValueError("No downstream model results found for targets.")
else:
all_coeffs = self.coeffs.copy()
if target_subset is None:
target_subset = list(all_coeffs.keys())
else:
removed = [t for t in target_subset if t not in all_coeffs.keys()]
target_subset = [t for t in target_subset if t in all_coeffs.keys()]
if len(removed) > 0:
logger.warning(
f"Targets {removed} were not found in the model, and will be removed from the target subset."
)
if check_downstream_ligand_effects:
if self.downstream_model_ligand_design_matrix is None:
raise ValueError(
"No downstream model design matrix found for ligands. Run "
"`CCI_deg_detection_setup` with use_ligands=True first."
)
all_feature_names = [
feat.replace("regulator_", "") for feat in self.downstream_model_ligand_design_matrix.columns
]
elif check_downstream_receptor_effects:
if self.downstream_model_receptor_design_matrix is None:
raise ValueError(
"No downstream model design matrix found for receptors. Run "
"`CCI_deg_detection_setup` with use_receptors=True first."
)
all_feature_names = [
feat.replace("regulator_", "") for feat in self.downstream_model_receptor_design_matrix.columns
]
elif check_downstream_target_effects:
if self.downstream_model_target_design_matrix is None:
raise ValueError(
"No downstream model design matrix found for target genes. Run "
"`CCI_deg_detection_setup` with use_targets=True first."
)
all_feature_names = [
feat.replace("regulator_", "") for feat in self.downstream_model_target_design_matrix.columns
]
else:
all_feature_names = [feat for feat in self.feature_names if feat != "intercept"]
if interaction_subset is None:
feature_names = all_feature_names
else:
feature_names = [feat for feat in all_feature_names if feat in interaction_subset]
removed = [feat for feat in interaction_subset if feat not in all_feature_names]
if len(removed) > 0:
logger.warning(
f"Interactions {removed} were not found in the model, and will be removed from the interaction "
f"subset."
)
if use_significant:
for target in target_subset:
parent_dir = os.path.dirname(self.output_path)
sig = pd.read_csv(
os.path.join(parent_dir, "significance", f"{target}_is_significant.csv"), index_col=0
)
all_coeffs[target] *= sig
if effect_threshold is not None:
for target in target_subset:
all_coeffs[target] = all_coeffs[target].clip(lower=effect_threshold)
# For each feature-target combination, compute the mean effect across cells:
combinations = list(product(target_subset, feature_names))
combinations = [
(target, feature) for target, feature in combinations if f"b_{feature}" in all_coeffs[target].columns
]
# Remove combinations where the effect is hardly present (arbitrarily defined at 0.5% of cells):
combinations = [
f"{target};{feature}"
for target, feature in combinations
if (all_coeffs[target][f"b_{feature}"] != 0).mean() >= 0.005
]
mean_effect = pd.Series(index=combinations)
for combo in combinations:
target, feature = combo.split(";")
target_coefs = all_coeffs[target][f"b_{feature}"]
mean_effect[combo] = target_coefs.mean()
# For each cell, compute the fold change over the average for each combination:
all_fc = pd.DataFrame(index=self.adata.obs_names, columns=combinations)
for combo in tqdm(combinations, desc="Computing fold changes for interaction-target combinations..."):
target, feature = combo.split(";")
target_coefs = all_coeffs[target][f"b_{feature}"]
all_fc[combo] = target_coefs / mean_effect[combo]
# Log fold change:
all_fc = np.log1p(all_fc)
all_fc[np.isnan(all_fc)] = 0
# z-score the fold change values:
all_fc = all_fc.apply(scipy.stats.zscore, axis=0)
all_fc["pos"] = pos
all_fc_coord_sorted = all_fc.sort_values(by="pos")
# Mean z-score at each coordinate position:
all_fc_coord_sorted = all_fc_coord_sorted.groupby("pos").mean()
# Smooth in the case of dropouts:
all_fc_coord_sorted = all_fc_coord_sorted.rolling(window_size, center=True, min_periods=1).mean()
# For each unique value in 'pos', find the top features with the highest mean z-score
top_combinations = all_fc_coord_sorted.apply(lambda x: x.nlargest(30).index.tolist(), axis=1)
# Find interesting interaction effects by position: get features that are in the top features for a
# number of positions scaled to the window size (int(window_size * 1.67), i.e. 5 for the default window of 3):
consecutive_counts = {feature: 0 for feature in combinations}
feats_of_interest = set()
for pos in top_combinations.index:
for feature in top_combinations[pos]:
consecutive_counts[feature] += 1
if consecutive_counts[feature] >= int(window_size * 1.67):
feats_of_interest.add(feature)
# Note: the disabled reset below would make the count strictly consecutive (as in the per-gene
# variant above); as written, counts accumulate across positions, and consecutive_counts is not
# used after this loop.
# for feature in combinations:
#     if feature not in top_combinations[pos]:
#         consecutive_counts[feature] = 0
# Convert to a list, since pandas no longer supports indexing with a set:
feats_of_interest = list(feats_of_interest)
to_plot = all_fc_coord_sorted[feats_of_interest]
if pd.api.types.is_numeric_dtype(to_plot.index):
# Minmax scale to normalize positional context:
to_plot.index = (to_plot.index - to_plot.index.min()) / (to_plot.index.max() - to_plot.index.min())
to_plot = to_plot.T # so that the features are labeled along the y-axis
if sort_by_target:
logger.info("Sorting by target gene...")
to_plot["temp"] = to_plot.index.to_series().apply(lambda x: x.split("-")[0])
to_plot = to_plot.sort_values(by="temp")
to_plot = to_plot.drop(columns="temp")
# Sort by "heat" if applicable (i.e. in order roughly determined by how early along the relative position
# the highest z-scores occur in for each interaction-target pair):
elif neatly_arrange_y:
logger.info("Sorting by position of enrichment along axis...")
column_indices = np.tile(np.arange(len(to_plot.columns)), (len(to_plot), 1)) # Column indices array
# Look only at the indices corresponding to the highest changes:
percentile_95 = to_plot.apply(
lambda row: np.percentile(row[row > 0], 95) if row[row > 0].size > 0 else 0, axis=1
)
# Create a DataFrame that replicates the shape of to_plot
weights_matrix = to_plot.gt(percentile_95, axis=0) * to_plot
weighted_sum = np.sum(weights_matrix.values * column_indices, axis=1)
total_weight = np.sum(weights_matrix.values, axis=1)
weighted_avg = pd.Series(np.where(total_weight != 0, weighted_sum / total_weight, 0), index=to_plot.index)
top_cols_sorted = weighted_avg.sort_values().index
to_plot = to_plot.loc[top_cols_sorted]
else:
logger.info("Sorting by interaction...")
to_plot["temp"] = to_plot.index.to_series().apply(lambda x: x.split("-")[-1])
to_plot = to_plot.sort_values(by="temp")
to_plot = to_plot.drop(columns="temp")
flattened = to_plot.values.flatten()
flattened_series = pd.Series(flattened)
percentile_95 = flattened_series.quantile(0.95)
max_val = percentile_95
if figsize is None:
m = len(to_plot) * 40 / 200
n = 8
figsize = (n, m)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
if fontsize is None:
fontsize = rcParams.get("font.size")
# Format the numerical columns for the plot:
# First check whether the columns contain duplicates:
if all(isinstance(name, str) for name in to_plot.columns):
if any([name.count(".") > 1 for name in to_plot.columns]):
to_plot.columns = [".".join(name.split(".")[:2]) for name in to_plot.columns]
to_plot.columns = [float(col) for col in to_plot.columns]
col_series = pd.Series(to_plot.columns)
if len(set(col_series)) != len(col_series):
unique_values, counts = np.unique(col_series, return_counts=True)
# Iterate through unique values
for value, count in zip(unique_values, counts):
if count > 1:
# Find indices of the repeated value
indices = col_series[col_series == value].index
# Calculate step size
if value == unique_values[-1]:
next_value = value + (value - unique_values[-2])
else:
next_index = np.where(unique_values == value)[0][0] + 1
next_value = unique_values[next_index]
step = (next_value - value) / count
# Update the values
for i in range(count):
col_series.iloc[indices[i]] = value + step * i
to_plot.columns = col_series.values
if all(isinstance(name, float) for name in to_plot.columns):
to_plot.columns = [f"{float(col):.3f}" for col in to_plot.columns]
to_plot.columns = [str(col) for col in to_plot.columns]
to_plot.index = [i.replace(":", "-") for i in to_plot.index]
if not sort_by_target and not neatly_arrange_y:
to_plot["sort"] = to_plot.index.to_series().apply(lambda x: x.split("-")[1])
to_plot = to_plot.sort_values(by="sort")
to_plot = to_plot.drop("sort", axis=1)
m = sns.heatmap(to_plot, vmin=-max_val, vmax=max_val, ax=ax, cmap=cmap)
cbar = m.collections[0].colorbar
cbar.set_label("Z-score", fontsize=fontsize * 1.5, labelpad=10)
# Adjust colorbar tick font size
cbar.ax.tick_params(labelsize=fontsize * 1.25)
cbar.ax.set_aspect(np.min([len(to_plot), 70]))
ax.set_xlabel(x_label, fontsize=fontsize * 1.25)
ax.set_ylabel("Interaction Effect on Target (formatted target-interaction)", fontsize=fontsize * 1.25)
ax.tick_params(axis="x", labelsize=fontsize)
ax.tick_params(axis="y", labelsize=fontsize)
ax.set_title(title, fontsize=fontsize * 1.5, pad=20)
if not os.path.exists(df_path) or recompute:
to_plot.to_csv(df_path)
if save_show_or_return in ["save", "both", "all"]:
save_kwargs["ext"] = "png"
save_kwargs["dpi"] = 300
if "figure_folder" in locals():
save_kwargs["path"] = figure_folder
# Save figure:
save_return_show_fig_utils(
save_show_or_return=save_show_or_return,
show_legend=False,
background="white",
prefix=f"distribution_interaction_effects_along_{save_id}",
save_kwargs=save_kwargs,
total_panels=1,
fig=fig,
axes=ax,
return_all=False,
return_all_list=None,
)
[docs] def effect_distribution_density(
self,
effect_names: List[str],
position_key: str = "spatial",
coord_column: Optional[Union[int, str]] = None,
max_coord_val: float = 1.0,
title: Optional[str] = None,
x_label: Optional[str] = None,
region_lower_bound: Optional[float] = None,
region_upper_bound: Optional[float] = None,
region_label: Optional[str] = None,
fontsize: Union[None, int] = None,
figsize: Union[None, Tuple[float, float]] = None,
save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
save_kwargs: Optional[dict] = {},
):
"""Visualize the spatial enrichment of cell-cell interaction effects using density plots over spatial
coordinates. Uses existing dataframe saved by :func:`effect_distribution_heatmap()`, which must be run first.
Args:
effect_names: List of interaction effects to include in plot, given in the same format as the row
labels saved by :func `effect_distribution_heatmap` (i.e. "Target;Ligand-Receptor" for L:R
models, "Target;Ligand" for ligand models).
position_key: Key in adata.obs or adata.obsm that provides a relative indication of the position of
cells. i.e. spatial coordinates. Defaults to "spatial". For each value in the position array (each
coordinate, each category), multiple cells must have the same value.
coord_column: Optional, only used if "position_key" points to an entry in .obsm. In this case,
this is the index or name of the column to be used to provide the positional context. Can also
provide "xy", "yz", "xz", "-xy", "-yz", "-xz" to draw a line between the two coordinate axes. "xy"
will extend the new axis in the direction of increasing x and increasing y starting from x=0 and y=0 (or
min. x/min. y), "-xy" will extend the new axis in the direction of decreasing x and increasing y
starting from x=minimum x and y=maximum y, and so on.
max_coord_val: Optional, can be used to adjust the numbers displayed along the x-axis for the relative
position along the coordinate axis. Defaults to 1.0.
title: Optional, can be used to provide title for plot
x_label: Optional, can be used to provide x-axis label for plot
region_lower_bound: Optional, can be used to provide a lower bound for the region of interest to label on
the plot- this can correspond to a spatial domain, etc.
region_upper_bound: Optional, can be used to provide an upper bound for the region of interest to label on
the plot- this can correspond to a spatial domain, etc.
region_label: Optional, can be used to provide a label for the region of interest to label on the plot
fontsize: Size of font for x and y labels.
figsize: Size of figure.
save_show_or_return: Whether to save, show or return the figure.
If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
displayed and the associated axis and other object will be return.
save_kwargs: A dictionary that will passed to the save_fig function.
By default it is an empty dictionary and the save_fig function will use the
{"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True,
"verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those
keys according to your needs.
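Example:
An illustrative sketch, given a fitted MuSIC_Interpreter instance `interpreter` (see :func
`effect_distribution_heatmap`); the effect name here is hypothetical and must match a row label
saved by that function, which must be run first:
>>> interpreter.effect_distribution_heatmap(position_key="spatial", coord_column=0)
>>> interpreter.effect_distribution_density(
...     effect_names=["Apoe;Tgfb1-Tgfbr1"], position_key="spatial", coord_column=0
... )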
"""
logger = lm.get_main_logger()
if position_key not in self.adata.obsm.keys() and position_key not in self.adata.obs.keys():
raise ValueError(
f"Position key {position_key} not found in adata.obsm or adata.obs. Please provide a valid key."
)
if position_key in self.adata.obsm.keys():
if coord_column in ["xy", "yz", "xz", "-xy", "-yz", "-xz"]:
if title is None:
title = f"Signaling effect density along {coord_column} axis"
if x_label is None:
x_label = f"Relative position along custom {coord_column} axis"
save_id = f"{coord_column}_axis"
else:
if coord_column is not None and not isinstance(coord_column, int):
if not isinstance(self.adata.obsm[position_key], pd.DataFrame):
raise ValueError(
f"Array stored at position key {position_key} has no column names; provide the column "
f"index."
)
elif coord_column is not None and isinstance(coord_column, int):
if isinstance(self.adata.obsm[position_key], pd.DataFrame):
if x_label is None:
x_label = f"Relative position along {coord_column}"
if title is None:
title = f"Signaling effect density along {coord_column}"
save_id = coord_column
else:
if coord_column == 0:
if x_label is None:
x_label = "Relative position along X"
if title is None:
title = "Signaling effect density along X"
save_id = "x_axis"
elif coord_column == 1:
if x_label is None:
x_label = "Relative position along Y"
if title is None:
title = "Signaling effect density along Y"
save_id = "y_axis"
elif coord_column == 2:
if x_label is None:
x_label = "Relative position along Z"
if title is None:
title = "Signaling effect density along Z"
save_id = "z_axis"
elif self.adata.obsm[position_key].shape[1] != 1:
raise ValueError(
f"Array stored at position key {position_key} has more than one column; provide the column "
f"index."
)
else:
if x_label is None:
x_label = "Relative position"
if title is None:
title = f"Signaling effect density along axis given by {position_key} key"
save_id = position_key
else:
if x_label is None:
x_label = "Relative position"
if title is None:
title = f"Signaling effect density along axis given by {position_key} key"
save_id = position_key
# Check for existing dataframe:
output_folder = os.path.join(os.path.dirname(self.output_path), "analyses")
# Use the saved name for the AnnData object to define part of the name of the saved file:
base_name = os.path.basename(self.adata_path)
adata_id = os.path.splitext(base_name)[0]
if not os.path.exists(
os.path.join(output_folder, f"{adata_id}_distribution_interaction_effects_along_{save_id}.csv")
):
raise ValueError(
f"Could not find dataframe saved by effect_distribution_heatmap() for position key {position_key}. "
f"Please run effect_distribution_heatmap() before running this function."
)
to_plot = pd.read_csv(
os.path.join(output_folder, f"{adata_id}_distribution_interaction_effects_along_{save_id}.csv"),
index_col=0,
)
# Format the numerical columns for the plot:
# First check whether the columns contain duplicates:
if all(isinstance(name, str) for name in to_plot.columns):
if any([name.count(".") > 1 for name in to_plot.columns]):
to_plot.columns = [".".join(name.split(".")[:2]) for name in to_plot.columns]
to_plot.columns = [float(col) for col in to_plot.columns]
col_series = pd.Series(to_plot.columns)
if len(set(col_series)) != len(col_series):
unique_values, counts = np.unique(col_series, return_counts=True)
# Iterate through unique values
for value, count in zip(unique_values, counts):
if count > 1:
# Find indices of the repeated value
indices = col_series[col_series == value].index
# Calculate step size
if value == unique_values[-1]:
next_value = value + (value - unique_values[-2])
else:
next_index = np.where(unique_values == value)[0][0] + 1
next_value = unique_values[next_index]
step = (next_value - value) / count
# Update the values
for i in range(count):
col_series.iloc[indices[i]] = value + step * i
to_plot.columns = col_series.values
if all(isinstance(name, float) for name in to_plot.columns):
to_plot.columns = [f"{float(col):.3f}" for col in to_plot.columns]
# Normalize to custom max value if desired:
float_columns = [float(col) for col in to_plot.columns]
current_min = min(float_columns)
current_max = max(float_columns)
normalized_columns = [
(col - current_min) / (current_max - current_min) * max_coord_val for col in float_columns
]
to_plot.columns = [f"{col:.3f}" for col in normalized_columns]
# Rearrange dataframe such that each interaction is its own column:
to_plot = to_plot.T
if not pd.api.types.is_numeric_dtype(to_plot.index):
to_plot.index = pd.to_numeric(to_plot.index)
# For this function, weights cannot be negative, so set all negative values to 0:
to_plot[to_plot < 0] = 0
to_plot["Coord"] = to_plot.index
# Check if any inputs are not included in the dataframe:
missing = [name for name in effect_names if name not in to_plot.columns]
if len(missing) > 0:
logger.warning(
f"Interactions {missing} were not found in the dataframe. They will be removed from the plot."
)
effect_names = [name for name in effect_names if name in to_plot.columns]
if figsize is None:
figsize = (8, 6)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
if fontsize is None:
fontsize = rcParams.get("font.size")
sns.set_style("white")
for effect, color in zip(effect_names, godsnot_102):
sns.kdeplot(x="Coord", weights=effect, data=to_plot, color=color, label=effect, lw=2, ax=ax)
if region_lower_bound is not None and region_upper_bound is not None:
width = region_upper_bound - region_lower_bound
region_box = mpl.patches.Rectangle(
(region_lower_bound, ax.get_ylim()[0]),
width,
ax.get_ylim()[1] - ax.get_ylim()[0],
linewidth=1,
edgecolor="#1CE6FF",
facecolor="#1CE6FF",
alpha=0.2,
)
ax.add_patch(region_box)
region_box_legend = mpl.patches.Patch(color="#1CE6FF", alpha=0.2, label=region_label)
handles, labels = ax.get_legend_handles_labels()
handles.append(region_box_legend)
labels.append(region_label)
ax.legend(handles=handles, labels=labels, loc="upper left", bbox_to_anchor=(1, 1), fontsize=fontsize * 1.25)
else:
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), fontsize=fontsize * 1.25)
ax.set_xlabel(x_label, fontsize=fontsize * 1.25)
ax.set_ylabel("Density", fontsize=fontsize * 1.25)
ax.tick_params(axis="x", labelsize=fontsize)
ax.tick_params(axis="y", labelsize=fontsize, labelleft=False, left=False)
ax.set_title(title, fontsize=fontsize * 1.5, pad=20)
if save_show_or_return in ["save", "both", "all"]:
save_kwargs["ext"] = "png"
save_kwargs["dpi"] = 300
if not os.path.exists(os.path.join(os.path.dirname(self.output_path), "figures")):
os.makedirs(os.path.join(os.path.dirname(self.output_path), "figures"))
figure_folder = os.path.join(os.path.dirname(self.output_path), "figures", "temp")
if not os.path.exists(figure_folder):
os.makedirs(figure_folder)
save_kwargs["path"] = figure_folder
# Save figure:
save_return_show_fig_utils(
save_show_or_return=save_show_or_return,
show_legend=True,
background="white",
prefix=f"density_interaction_effects_along_{save_id}",
save_kwargs=save_kwargs,
total_panels=1,
fig=fig,
axes=ax,
return_all=False,
return_all_list=None,
)
[docs] def visualize_effect_specificity(
self,
agg_method: Literal["mean", "percentage"] = "mean",
plot_type: Literal["heatmap", "volcano"] = "heatmap",
target_subset: Optional[List[str]] = None,
interaction_subset: Optional[List[str]] = None,
ct_subset: Optional[List[str]] = None,
group_key: Optional[str] = None,
n_anchors: Optional[int] = None,
effect_threshold: Optional[float] = None,
use_significant: bool = False,
target_cooccurrence_threshold: float = 0.1,
significance_cutoff: float = 1.3,
fold_change_cutoff: float = 1.5,
fold_change_cutoff_for_labels: float = 3.0,
fontsize: Union[None, int] = None,
figsize: Union[None, Tuple[float, float]] = None,
cmap: str = "seismic",
save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
save_kwargs: Optional[dict] = {},
save_df: bool = False,
):
"""Computes and visualizes the specificity of each interaction on each target. This is done by first
separating the target-expressing cells (and their neighbors) from the rest of the cells (conditioned on
predicted effect, and also conditioned on receptor expression if an L:R model is used). Then, the fold
change of the average expression of the ligand in the neighborhoods of the first subset vs. the
neighborhoods of the second subset is computed.
Args:
agg_method: Method to use for aggregating the specificity of each interaction on each target. Options:
"mean" for mean ligand expression, "percentage" for the percentage of cells expressing the ligand.
plot_type: Type of plot to use for visualization. Options: "heatmap" for heatmap, "volcano" for volcano
plot.
target_subset: List of targets to consider. If None, will use all targets used in model fitting.
interaction_subset: List of interactions to consider. If None, will use all interactions used in model.
ct_subset: Can be used to constrain the first group of cells (the query group) to the target-expressing
cells of a particular type (conditioned on any other relevant variables). If given, will search
for cell types in "group_key" attribute from model initialization. If not given, will use all cell
types.
group_key: Can be used to specify entry in adata.obs that contains cell type groupings. If None,
will use :attr `group_key` from model initialization.
n_anchors: Optional, number of target gene-expressing cells to use as anchors for analysis. Will be
selected randomly from the set of target gene-expressing cells (conditioned on any other relevant
values).
effect_threshold: Optional threshold minimum effect size to consider an effect for further analysis,
as an absolute value. Use this to choose only the cells for which an interaction is predicted to
have a strong effect. If None, the upper quartile of nonzero effect values is used as the threshold.
use_significant: Whether to use only significant effects in computing the specificity. If True,
will filter to cells + interactions where the interaction is significant for the target. Only valid
if :func `compute_coeff_significance()` has been run.
significance_cutoff: Cutoff for negative log-10 q-value to consider an interaction/effect significant. Only
used if "plot_type" is "volcano". Defaults to 1.3 (corresponding to an approximate q-value of 0.05).
fold_change_cutoff: Cutoff for fold change to consider an interaction/effect significant. Only used if
"plot_type" is "volcano". Defaults to 1.5.
fold_change_cutoff_for_labels: Cutoff for fold change to include the label for an interaction/effect.
Only used if "plot_type" is "volcano". Defaults to 3.0.
fontsize: Size of font for x and y labels.
figsize: Size of figure.
cmap: Colormap to use. Options: Any divergent matplotlib colormap.
save_show_or_return: Whether to save, show or return the figure.
If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
displayed and the associated axis and other object will be return.
save_kwargs: A dictionary that will passed to the save_fig function.
By default it is an empty dictionary and the save_fig function will use the
{"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True,
"verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those
keys according to your needs.
save_df: Set True to save the metric dataframe in the end
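Example:
An illustrative sketch, given a fitted MuSIC_Interpreter instance `interpreter` built from an
"lr" or "ligand" model (the target gene name here is hypothetical):
>>> interpreter.visualize_effect_specificity(
...     agg_method="mean", plot_type="volcano", target_subset=["Apoe"], n_anchors=100
... )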
"""
from ..find_neighbors import neighbors
logger = lm.get_main_logger()
config_spateo_rcParams()
# But set display DPI to 300:
plt.rcParams["figure.dpi"] = 300
if self.mod_type != "lr" and self.mod_type != "ligand":
raise ValueError("This function is only applicable for ligand-based models.")
if save_show_or_return in ["save", "both", "all"]:
if not os.path.exists(os.path.join(os.path.dirname(self.output_path), "figures")):
os.makedirs(os.path.join(os.path.dirname(self.output_path), "figures"))
figure_folder = os.path.join(os.path.dirname(self.output_path), "figures", "temp")
if not os.path.exists(figure_folder):
os.makedirs(figure_folder)
if save_df:
output_folder = os.path.join(os.path.dirname(self.output_path), "analyses")
if not os.path.exists(output_folder):
os.makedirs(output_folder)
# Use the saved name for the AnnData object to define part of the name of the saved file:
base_name = os.path.basename(self.adata_path)
adata_id = os.path.splitext(base_name)[0]
# Colormap should be divergent:
divergent_cmaps = [
"seismic",
"coolwarm",
"bwr",
"RdBu",
"RdGy",
"PuOr",
"PiYG",
"PRGn",
"BrBG",
"RdYlBu",
"RdYlGn",
"Spectral",
]
if cmap not in divergent_cmaps:
logger.warning(
f"Colormap {cmap} is not divergent, which is recommended for this plot type. Using 'seismic' instead."
)
cmap = "seismic"
if target_subset is None:
target_subset = list(self.coeffs.keys())
else:
removed = [t for t in target_subset if t not in self.coeffs.keys()]
target_subset = [t for t in target_subset if t in self.coeffs.keys()]
if len(removed) > 0:
logger.warning(
f"Targets {removed} were not found in the model, and will be removed from the target subset."
)
all_feature_names = [feat for feat in self.feature_names if feat != "intercept"]
if interaction_subset is None:
feature_names = all_feature_names
else:
feature_names = [feat for feat in all_feature_names if feat in interaction_subset]
removed = [feat for feat in interaction_subset if feat not in all_feature_names]
if len(removed) > 0:
logger.warning(
f"Interactions {removed} were not found in the model, and will be removed from the interaction "
f"subset."
)
if ct_subset is not None and group_key is None:
group_key = self.group_key
if fontsize is None:
fontsize = rcParams.get("font.size")
if plot_type == "heatmap":
x_label = "Neighboring Ligand" if self.mod_type == "ligand" else "L:R Interaction"
y_label = "Target Gene"
title = "Fold Change Interaction Enrichment \n Target-Expressing Cells vs. Others"
cbar_label = "$\\log_2$(Fold change Interaction Enrichment \n Target-Expressing Cells vs. Others)"
else:
x_label = "$\\log_2$(Fold change Interaction Enrichment \n Target-Expressing Cells vs. Others)"
y_label = r"$-\log_{10}$(qval)"
title = "Fold Change Interaction Enrichment \n Target-Expressing Cells vs. Others"
# Check for already-existing dataframe:
try:
output_folder = os.path.join(os.path.dirname(self.output_path), "analyses")
df = pd.read_csv(
os.path.join(
os.path.dirname(self.output_path),
"analyses",
f"{plot_type}_{adata_id}_interaction_enrichment_fold_change_target_expressing_v_nonexpressing.csv",
),
index_col=0,
)
if interaction_subset is not None:
# Volcano rows are labeled "interaction-target":
df = df.loc[[i for i in df.index if i.split("-")[0] in interaction_subset]]
if target_subset is not None:
df = df.loc[[i for i in df.index if i.split("-")[-1] in target_subset]]
except Exception:
if plot_type == "heatmap":
df = pd.DataFrame(index=target_subset, columns=feature_names)
else:
combinations = product(feature_names, target_subset)
combinations = [f"{feature}-{target}" for feature, target in combinations]
df = pd.DataFrame(
index=combinations, columns=["log2FC", "p-value", "q-value", "Significance", "-log10(qval)"]
)
if (
"spatial_connectivities_secreted" in self.adata.obsp.keys()
and "spatial_connectivities_membrane_bound" in self.adata.obsp.keys()
):
conn_secreted = self.adata.obsp["spatial_connectivities_secreted"]
conn_membrane_bound = self.adata.obsp["spatial_connectivities_membrane_bound"]
else:
logger.info("Spatial graph not found, computing...")
adata = self.adata.copy()
_, adata = neighbors(
adata,
n_neighbors=self.n_neighbors_secreted,
basis="spatial",
spatial_key=self.coords_key,
n_neighbors_method="ball_tree",
)
conn_secreted = adata.obsp["spatial_connectivities"]
adata = self.adata.copy()
_, adata = neighbors(
adata,
n_neighbors=self.n_neighbors_membrane_bound,
basis="spatial",
spatial_key=self.coords_key,
n_neighbors_method="ball_tree",
)
conn_membrane_bound = adata.obsp["spatial_connectivities"]
self.adata.obsp["spatial_connectivities_secreted"] = conn_secreted
self.adata.obsp["spatial_connectivities_membrane_bound"] = conn_membrane_bound
# For each target, split cells into two groups: target-expressing and all neighbors of target-expressing
# cells, and the remainder.
for target in target_subset:
coef_target = self.coeffs[target].loc[self.adata.obs_names]
if effect_threshold is None:
nonzero_values = coef_target.values.flatten()
nonzero_values = nonzero_values[nonzero_values != 0]
target_effect_threshold = pd.Series(nonzero_values).quantile(0.75)
else:
target_effect_threshold = effect_threshold
if use_significant:
parent_dir = os.path.dirname(self.output_path)
sig = pd.read_csv(
os.path.join(parent_dir, "significance", f"{target}_is_significant.csv"), index_col=0
)
coef_target *= sig
# Taking the first group (the query group)- first subset to cell types of interest, if given:
if ct_subset is not None:
query_adata = self.adata[self.adata.obs[group_key].isin(ct_subset)].copy()
else:
query_adata = self.adata.copy()
# Define masks for target expression:
target_expression = query_adata[:, target].X.toarray().reshape(-1)
target_expressing_mask = target_expression > 0
target_expressing_cells = query_adata.obs_names[target_expressing_mask]
# Define interaction-specific masks: (optionally, if L:R model) for cells expressing receptor,
# for cells predicted to be affected by an interaction and all of the neighbors of these cells:
for interaction in feature_names:
if f"b_{interaction}" not in coef_target.columns:
# Significance for this interaction-target combination:
if plot_type == "volcano":
df.loc[f"{interaction}-{target}", "p-value"] = 1.0
df.loc[f"{interaction}-{target}", "log2FC"] = 0.0
else:
df.loc[target, interaction] = 0.0
continue
if self.mod_type == "lr":
ligand, receptor = interaction.split(":")
receptor_expressing_mask = np.ones(query_adata.shape[0], dtype=bool)
for r in receptor.split("_"):
receptor_expression = query_adata[:, r].X.toarray().reshape(-1)
receptor_expressing_mask &= receptor_expression > 0
receptor_expressing_cells = query_adata.obs_names[receptor_expressing_mask]
coef_interaction_target = coef_target[f"b_{interaction}"]
coef_interaction_target_mask = coef_interaction_target > target_effect_threshold
coef_interaction_target_cells = query_adata.obs_names[coef_interaction_target_mask]
# Get mask for any neighbors of these cells:
# Check whether to use the neighbors found for membrane-bound interaction or those for secreted
# interactions:
to_check = interaction.split(":")[0] if ":" in interaction else interaction
if "/" in to_check:
interaction_components = to_check.split("/")
separator = "/"
elif "_" in to_check:
interaction_components = to_check.split("_")
separator = "_"
else:
interaction_components = [to_check]
separator = None
matching_rows = self.lr_db[self.lr_db["from"].isin(interaction_components)]
if (
matching_rows["type"].str.contains("Secreted Signaling").any()
or matching_rows["type"].str.contains("ECM-Receptor").any()
):
conn = conn_secreted
else:
conn = conn_membrane_bound
if self.mod_type != "lr":
# Get the intersection of cells expressing target and predicted to be affected by the interaction:
adata_mask = target_expressing_cells.intersection(coef_interaction_target_cells)
else:
# Get the intersection of cells expressing target and receptor and are predicted to be affected by
# interaction:
adata_mask = target_expressing_cells.intersection(receptor_expressing_cells)
adata_mask = adata_mask.intersection(coef_interaction_target_cells)
# This object contains samples that can constitute the query group:
query_adata_sub = query_adata[adata_mask].copy()
# This object contains the other samples, that can constitute the reference:
neg_mask = [
n
for n in self.adata.obs_names
if n not in target_expressing_cells and n not in coef_interaction_target_cells
]
reference_adata_sub = self.adata[neg_mask].copy()
if query_adata_sub.n_obs <= 30:
logger.info(
f"Insufficient query cells found for this interaction-target combination (likely based on "
f"absence of strong interaction effect)- {interaction}-{target}. Skipping."
)
del conn, query_adata_sub, reference_adata_sub
gc.collect()
if plot_type == "volcano":
df.loc[f"{interaction}-{target}", "p-value"] = 1
df.loc[f"{interaction}-{target}", "log2FC"] = 0.0
else:
df.loc[target, interaction] = 0.0
continue
# Query group:
# If applicable, select a subset of these cells to use as anchors:
if n_anchors is not None:
if query_adata_sub.n_obs < n_anchors:
logger.warning(
f"Number of anchors ({n_anchors}) is greater than number of target-expressing cells "
f"({query_adata_sub.n_obs}) for target {target} and interaction {interaction}. "
f"Skipping."
)
del conn, query_adata_sub, reference_adata_sub
gc.collect()
if plot_type == "volcano":
df.loc[f"{interaction}-{target}", "p-value"] = 1
df.loc[f"{interaction}-{target}", "log2FC"] = 0.0
else:
df.loc[target, interaction] = 0.0
continue
else:
if query_adata_sub.n_obs < 200:
logger.warning(
f"Number of target-expressing cells ({query_adata_sub.n_obs}) is less than 100 for "
f"target {target} and interaction {interaction}. Skipping."
)
del conn, query_adata_sub, reference_adata_sub
gc.collect()
if plot_type == "volcano":
df.loc[f"{interaction}-{target}", "p-value"] = 1
df.loc[f"{interaction}-{target}", "log2FC"] = 0.0
else:
df.loc[target, interaction] = 0.0
continue
if n_anchors is None:
# If not given, use all eligible cells in the query group as anchors:
n_anchors = query_adata_sub.n_obs
anchors = np.random.choice(query_adata_sub.obs_names, size=n_anchors, replace=False)
selected_indices = [np.where(self.adata.obs_names == string)[0][0] for string in anchors]
# Get neighbors of these cells:
neighbor_indices = np.unique(conn[selected_indices].nonzero()[1])
# Remove the anchor cells from the neighbors:
neighbor_indices = neighbor_indices[~np.isin(neighbor_indices, selected_indices)]
neighbors_selected = self.adata.obs_names[neighbor_indices]
# The query group: anchors and their neighbors:
query_group = anchors.tolist() + neighbors_selected.tolist()
# Reference group:
# If applicable, select a subset of these cells to use as anchors:
anchors = np.random.choice(reference_adata_sub.obs_names, size=min(n_anchors, reference_adata_sub.n_obs), replace=False)
selected_indices = [np.where(self.adata.obs_names == string)[0][0] for string in anchors]
# Get neighbors of these cells:
neighbor_indices = np.unique(conn[selected_indices].nonzero()[1])
# Remove the anchor cells from the neighbors:
neighbor_indices = neighbor_indices[~np.isin(neighbor_indices, selected_indices)]
neighbors_selected = self.adata.obs_names[neighbor_indices]
# The reference group: anchors and their neighbors:
reference_group = anchors.tolist() + neighbors_selected.tolist()
# Ligand expression in the selected cells:
ligand = interaction.split(":")[0] if ":" in interaction else interaction
components = ligand.split(separator) if separator is not None else [ligand]
# Compute ligand values for query + reference together before separating them for fold change
# calculation:
ligand_values = self.adata[query_group + reference_group, components].X.toarray()
if separator == "/":
# Arithmetic mean of the genes
ligand_values = np.mean(ligand_values, axis=1)
elif separator == "_":
# Geometric mean of the genes
# Replace zeros with np.nan
ligand_values[ligand_values == 0] = np.nan
# Compute product along the rows
products = np.nanprod(ligand_values, axis=1)
# Count non-nan values in each row for nth root calculation
non_nan_counts = np.sum(~np.isnan(ligand_values), axis=1)
# Avoid division by zero
non_nan_counts[non_nan_counts == 0] = np.nan
ligand_values = np.power(products, 1 / non_nan_counts)
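# Worked example (hypothetical): a two-gene complex with per-cell values (2, 8) yields
# (2 * 8) ** (1 / 2) = 4; zeros are excluded, so (0, 8) yields 8 ** (1 / 1) = 8.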
ligand_values = pd.DataFrame(ligand_values, index=query_group + reference_group, columns=[ligand])
ligand_query = ligand_values.loc[query_group, :]
ligand_reference = ligand_values.loc[reference_group, :]
# Significance for this interaction-target combination:
if plot_type == "volcano":
if (ligand_reference == 0).all().all():
df.loc[f"{interaction}-{target}", "p-value"] = 0
else:
df.loc[f"{interaction}-{target}", "p-value"] = mannwhitneyu(ligand_query, ligand_reference)[
1
]
if agg_method == "mean":
ligand_query = ligand_query.mean().values
ligand_reference = ligand_reference.mean().values
elif agg_method == "percentage":
ligand_query = (ligand_query > 0).mean().values
ligand_reference = (ligand_reference > 0).mean().values
if ligand_reference == 0:
# Prevent division by zero, this will get set to the max threshold anyways:
ligand_reference = 0.001
fold_change = np.log2(ligand_query / ligand_reference)
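# e.g. a hypothetical mean query expression of 0.8 vs. 0.2 in the reference gives
# log2(0.8 / 0.2) = 2.0.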
if plot_type == "volcano":
df.loc[f"{interaction}-{target}", "log2FC"] = fold_change
else:
df.loc[target, interaction] = fold_change
del conn, query_adata_sub, reference_adata_sub
gc.collect()
logger.info(f"Finished computing specificity for target {target}.")
# If relevant, compute adjusted p-values:
if plot_type == "volcano":
df["log2FC"] = df["log2FC"].apply(lambda x: x[0] if isinstance(x, np.ndarray) else x)
df["p-value"] = df["p-value"].apply(lambda x: x[0] if isinstance(x, np.ndarray) else x)
df["q-value"] = multitesting_correction(df["p-value"].values, method="fdr_bh")
df["Significance"] = df["q-value"] < 0.05
df["-log10(qval)"] = -np.log10(df["q-value"])
# And if relevant, perform hierarchical clustering- first to group interactions w/ similar fold changes
# across targets:
if plot_type == "heatmap":
col_linkage = sch.linkage(df.transpose(), method="ward")
col_dendro = sch.dendrogram(col_linkage, no_plot=True)
col_clustered_order = col_dendro["leaves"]
df = df.iloc[:, col_clustered_order]
# Then to group targets w/ similar fold changes across interactions:
row_linkage = sch.linkage(df, method="ward")
row_dendro = sch.dendrogram(row_linkage, no_plot=True)
row_clustered_order = row_dendro["leaves"]
df = df.iloc[row_clustered_order, :]
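# After the two reorderings, interactions (columns) and targets (rows) with similar fold-change
# profiles are adjacent, making block structure visible in the heatmap.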
# Plot:
if figsize is None:
if plot_type == "heatmap":
# Set figure size based on the number of interaction features and targets:
m = len(target_subset) * 50 / 200
n = len(feature_names) * 50 / 200
else:
m = 6
n = 6
figsize = (n, m)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
cmap = mpl.colormaps[cmap]
# Center colormap at 0 for heatmap:
if plot_type == "heatmap":
max_distance = max(abs(df.max().max()), abs(df.min().min()))
norm = plt.Normalize(-max_distance, max_distance)
colors = cmap(norm(df))
custom_cmap = mpl.colors.LinearSegmentedColormap.from_list("custom", colors)
else:
max_distance = max(abs(df["log2FC"].max()), abs(df["log2FC"].min()))
if plot_type == "volcano":
if len(df) > 20:
size = 20
else:
size = 40
significant = df["-log10(qval)"] > significance_cutoff
significant_up = df["log2FC"] > fold_change_cutoff
significant_down = df["log2FC"] < -fold_change_cutoff
# Check if max -log10(qval) is greater than 8
if df["-log10(qval)"].max() > 8:
ax.set_yscale("log", base=2) # Set y-axis to log
y_label = r"$-log_10$(qval) ($log_2$ scale)"
sns.scatterplot(
x=df["log2FC"][significant & significant_up],
y=df["-log10(qval)"][significant & significant_up],
hue=df["log2FC"][significant & significant_up],
palette="Reds",
vmin=0,
edgecolor="black",
ax=ax,
s=size,
legend=False,
)
sns.scatterplot(
x=df["log2FC"][significant & significant_down],
y=df["-log10(qval)"][significant & significant_down],
hue=df["log2FC"][significant & significant_down],
palette="Blues_r",
vmax=0,
edgecolor="black",
ax=ax,
s=size,
legend=False,
)
sns.scatterplot(
x=df["log2FC"][~(significant & (significant_up | significant_down))],
y=df["-log10(qval)"][~(significant & (significant_up | significant_down))],
color="grey",
edgecolor="black",
ax=ax,
s=size,
)
# Add labels for interactions with high fold changes, halving the cutoff until at least one qualifies:
high_fold_change = df[abs(df["log2FC"]) > fold_change_cutoff_for_labels]
while high_fold_change.empty and fold_change_cutoff_for_labels > 1e-6:
fold_change_cutoff_for_labels /= 2  # Halve the cutoff
high_fold_change = df[abs(df["log2FC"]) > fold_change_cutoff_for_labels]
text_labels = high_fold_change.index.tolist()
x_coord_text_labels = high_fold_change["log2FC"].tolist()
y_coord_text_labels = high_fold_change["-log10(qval)"].tolist()
text_objects = []
for i, label in enumerate(text_labels):
t = ax.text(
x_coord_text_labels[i],
y_coord_text_labels[i],
label,
fontsize=fontsize * 0.75,
color="black",
ha="center",
va="center",
)
text_objects.append(t)
adjust_text(text_objects, ax=ax, arrowprops=dict(arrowstyle="-", color="black", lw=0.5))
ax.axhline(y=significance_cutoff, color="grey", linestyle="--", linewidth=1.5)
ax.axvline(x=fold_change_cutoff, color="grey", linestyle="--", linewidth=1.5)
ax.axvline(x=-fold_change_cutoff, color="grey", linestyle="--", linewidth=1.5)
ax.set_xlim(df["log2FC"].min() - 0.2, max_distance + 0.2)
ax.set_xticklabels(["{:.2f}".format(x) for x in ax.get_xticks()], fontsize=fontsize)
ax.set_yticklabels(["{:.2f}".format(y) for y in ax.get_yticks()], fontsize=fontsize)
ax.set_xlabel(x_label, fontsize=fontsize * 1.25)
ax.set_ylabel(y_label, fontsize=fontsize * 1.25)
ax.set_title(title, fontsize=fontsize * 1.5)
prefix = "volcano"
elif plot_type == "heatmap":
vmin = -max_distance
vmax = max_distance
thickness = 0.3 * figsize[0] / 10
mask = np.abs(df) < 0.1
m = sns.heatmap(
df,
square=True,
linecolor="grey",
linewidths=thickness,
cbar_kws={"label": cbar_label, "location": "top", "pad": 0},
cmap=custom_cmap,
vmin=vmin,
vmax=vmax,
mask=mask,
ax=ax,
)
# Outer frame:
for _, spine in m.spines.items():
spine.set_visible(True)
spine.set_linewidth(thickness * 2.5)
# Adjust colorbar settings:
divider = make_axes_locatable(ax)
# Append axes to the top of the plot, where the colorbar will be placed
if df.shape[0] > df.shape[1]:
cax = divider.append_axes("top", size="30%", pad=0)
else:
cax = divider.append_axes("top", size="30%", pad="30%")
# Create the colorbar manually in the appended axes
cbar = plt.colorbar(m.collections[0], cax=cax, orientation="horizontal")
cbar.set_label(cbar_label.title(), fontsize=fontsize * 1.5, labelpad=10)
cbar.ax.xaxis.set_ticks_position("top") # Move ticks to the top
cbar.ax.xaxis.set_label_position("top") # Move the label (title) to the top
cbar.ax.tick_params(labelsize=fontsize * 1.5)
cbar.ax.set_aspect(0.02)
ax.set_xlabel(x_label, fontsize=fontsize * 1.25)
ax.set_ylabel("Cell Type-Specific Target", fontsize=fontsize * 1.25)
ax.tick_params(axis="x", labelsize=fontsize, rotation=90)
ax.tick_params(axis="y", labelsize=fontsize)
ax.set_title(title, fontsize=fontsize * 1.5, pad=20)
prefix = "heatmap"
if save_df:
df.to_csv(
os.path.join(
output_folder,
f"{prefix}_{adata_id}_interaction_enrichment_fold_change_target_expressing_v_nonexpressing.csv",
)
)
if save_show_or_return in ["save", "both", "all"]:
save_kwargs["ext"] = "png"
save_kwargs["dpi"] = 300
if "figure_folder" in locals():
save_kwargs["path"] = figure_folder
# Save figure:
save_return_show_fig_utils(
save_show_or_return=save_show_or_return,
show_legend=False,
background="white",
prefix=prefix,
save_kwargs=save_kwargs,
total_panels=1,
fig=fig,
axes=ax,
return_all=False,
return_all_list=None,
)
    def visualize_neighborhood(
self,
target: str,
interaction: str,
interaction_type: Literal["secreted", "membrane-bound"],
select_examples_criterion: Literal["positive", "negative"] = "positive",
effect_threshold: Optional[float] = None,
cell_type: Optional[str] = None,
group_key: Optional[str] = None,
use_significant: bool = False,
n_anchors: int = 100,
n_neighbors_expressing: int = 20,
display_plot: bool = True,
) -> anndata.AnnData:
"""Sets up AnnData object for visualization of interaction effects- cells will be colored by expression of
the target gene, potentially conditioned on receptor expression, and neighboring cells will be colored by
ligand expression.
Args:
target: Target gene of interest
interaction: Interaction feature to visualize, given in the same form as in the design matrix (if model
is a ligand-based model or receptor-based model, this will be of form "Col4a1". If model is a
ligand-receptor based model, this will be of form "Col4a1:Itgb1", for example).
interaction_type: Specifies whether the chosen interaction is secreted or membrane-bound. Options:
"secreted" or "membrane-bound".
select_examples_criterion: Whether to select cells with positive or negative interaction effects for
visualization. Defaults to "positive", which searches for cells for which the predicted interaction
effect is above the given threshold. "Negative" will select cells for which the predicted interaction
has no effect on the target expression.
            effect_threshold: Optional threshold for the magnitude of the interaction effect a cell must exceed
                to be selected as a "positive" example. If not given, will use the upper quartile value among
                all nonzero interaction effect values to determine the threshold.
cell_type: Optional, can be used to select anchor cells from only a particular cell type. If None,
will select from all cells.
group_key: Can be used to specify entry in adata.obs that contains cell type groupings. If None,
will use :attr `group_key` from model initialization. Only used if "cell_type" is not None.
use_significant: Whether to use only significant effects in computing the specificity. If True,
will filter to cells + interactions where the interaction is significant for the target. Only valid
if :func `compute_coeff_significance()` has been run.
n_anchors: Number of target gene-expressing cells to use as anchors for visualization. Will be selected
randomly from the set of target gene-expressing cells.
n_neighbors_expressing: Filters the set of cells that can be selected as anchors based on the number of
their neighbors that express the chosen ligand. Only used for models that incorporate ligand expression.
            display_plot: Whether to save an interactive HTML plot. If False, will return the AnnData object
                without plotting; the object can then be visualized e.g. using spateo-viewer.
Returns:
adata: Modified AnnData object containing the expression information for the target gene and neighboring
ligand expression.
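        Example:
            A minimal usage sketch; the target and interaction names below are hypothetical and must correspond
            to genes present in the fitted model::

                # "interpreter" is an already-initialized MuSIC_Interpreter instance
                adata_vis = interpreter.visualize_neighborhood(
                    target="Dlx1",  # hypothetical target gene
                    interaction="Col4a1:Itgb1",  # hypothetical L:R pair from the design matrix
                    interaction_type="secreted",
                    select_examples_criterion="positive",
                    n_anchors=50,
                    display_plot=False,  # return the AnnData object for external visualization
                )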
"""
# Compute connectivity matrix if not already existing- only needed for ligand and L:R models:
from ..find_neighbors import neighbors
logger = lm.get_main_logger()
if display_plot:
figure_folder = os.path.join(os.path.dirname(self.output_path), "figures")
if not os.path.exists(figure_folder):
os.makedirs(figure_folder)
path = os.path.join(
figure_folder, f"{target}_{select_examples_criterion}_cells_example_neighborhoods_{interaction}.html"
)
logger.info(f"Saving plot to {path}")
if self.mod_type != "lr" and self.mod_type != "ligand":
raise ValueError("This function is only applicable for ligand-based models.")
if select_examples_criterion not in ["positive", "negative"]:
raise ValueError("Invalid criterion for selecting examples. Options: 'positive', 'negative'.")
try:
membrane_bound_path = os.path.join(
os.path.splitext(self.output_path)[0], "spatial_weights", "spatial_weights_membrane_bound.npz"
)
secreted_path = os.path.join(
os.path.splitext(self.output_path)[0], "spatial_weights", "spatial_weights_secreted.npz"
)
spatial_weights_membrane_bound = scipy.sparse.load_npz(membrane_bound_path)
conn_membrane_bound = spatial_weights_membrane_bound > 0
spatial_weights_secreted = scipy.sparse.load_npz(secreted_path)
conn_secreted = spatial_weights_secreted > 0
        except (FileNotFoundError, OSError):
if (
"spatial_connectivities_secreted" in self.adata.obsp.keys()
and "spatial_connectivities_membrane_bound" in self.adata.obsp.keys()
):
conn_secreted = self.adata.obsp["spatial_connectivities_secreted"]
conn_membrane_bound = self.adata.obsp["spatial_connectivities_membrane_bound"]
else:
logger.info("Spatial graph not found, computing...")
if interaction_type == "secreted":
adata = self.adata.copy()
_, adata_secreted = neighbors(
adata,
n_neighbors=self.n_neighbors_secreted,
basis="spatial",
spatial_key=self.coords_key,
n_neighbors_method="ball_tree",
)
conn_secreted = adata_secreted.obsp["spatial_connectivities"]
self.adata.obsp["spatial_connectivities_secreted"] = conn_secreted
conn = conn_secreted
elif interaction_type == "membrane-bound":
adata = self.adata.copy()
_, adata_membrane_bound = neighbors(
adata,
n_neighbors=self.n_neighbors_membrane_bound,
basis="spatial",
spatial_key=self.coords_key,
n_neighbors_method="ball_tree",
)
conn_membrane_bound = adata_membrane_bound.obsp["spatial_connectivities"]
self.adata.obsp["spatial_connectivities_membrane_bound"] = conn_membrane_bound
conn = conn_membrane_bound
else:
raise ValueError("Invalid interaction type. Options: 'secreted', 'membrane-bound'.")
if interaction_type == "secreted":
conn = conn_secreted
elif interaction_type == "membrane-bound":
conn = conn_membrane_bound
else:
raise ValueError("Invalid interaction type. Options: 'secreted', 'membrane-bound'.")
adata = self.adata.copy()
if cell_type is not None:
if group_key is None:
group_key = self.group_key
# Get the cells of the specified cell type:
cell_type_mask = adata.obs[group_key] == cell_type
adata_ct = adata[cell_type_mask, :].copy()
adata_ct_cells = adata_ct.obs_names
coef_target = self.coeffs[target].loc[adata.obs_names]
if effect_threshold is None:
nonzero_values = coef_target.values.flatten()
nonzero_values = nonzero_values[nonzero_values != 0]
effect_threshold = pd.Series(nonzero_values).quantile(0.75)
if use_significant:
parent_dir = os.path.dirname(self.output_path)
sig = pd.read_csv(os.path.join(parent_dir, "significance", f"{target}_is_significant.csv"), index_col=0)
coef_target *= sig
if hasattr(self, "remaining_cells"):
adata = adata[self.remaining_cells, :].copy()
conn = conn[self.remaining_indices, :][:, self.remaining_indices].copy()
# Compute the multiple possible masks that can be used to subset to the cells of interest:
# Get the target gene expression:
target_expression = adata[:, target].X.toarray().reshape(-1)
# Get the interaction effect:
interaction_effect = coef_target.loc[adata.obs_names, f"b_{interaction}"].values
# Get the cells expressing the target gene:
target_expressing_mask = target_expression > 0
target_expressing_cells = adata.obs_names[target_expressing_mask]
# Cells with significant interaction effect on target:
if select_examples_criterion == "positive":
interaction_mask = np.abs(interaction_effect) > effect_threshold
else:
interaction_mask = interaction_effect == 0
interaction_cells = adata.obs_names[interaction_mask]
# If applicable, split the interaction feature and get the ligand and receptor- for features w/ multiple
# ligands or multiple receptors, process accordingly:
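        # Naming convention (assumed from the design matrix): "/" joins alternative ligands, any one of which may
        # be expressed (e.g. "Col4a1/Col4a2"), while "_" joins subunits of a complex, all of which must be
        # expressed (e.g. "Itga1_Itgb1"):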
to_check = interaction.split(":")[0] if ":" in interaction else interaction
if "/" in to_check:
genes = to_check.split("/")
separator = "/"
elif "_" in to_check:
genes = to_check.split("_")
separator = "_"
else:
genes = [to_check]
separator = None
if separator == "/":
# Cells expressing any of the genes
ligand_expr_mask = np.zeros(len(adata), dtype=bool)
for gene in genes:
ligand_expr_mask |= adata[:, gene].X.toarray().squeeze() > 0
elif separator == "_":
# Cells expressing all of the genes
ligand_expr_mask = np.ones(len(adata), dtype=bool)
for gene in genes:
ligand_expr_mask &= adata[:, gene].X.toarray().squeeze() > 0
else:
# Single gene
ligand_expr_mask = adata[:, to_check].X.toarray().squeeze() > 0
# Check how many cells have sufficient number of neighbors expressing the ligand:
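        # (An equivalent vectorized form would be conn.astype(int) @ ligand_expr_mask.astype(int); the explicit
        # loop is kept here for clarity.)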
        neighbor_counts = np.zeros(len(adata))
        for i in range(len(adata)):
            # Get neighbors of cell i (renamed to avoid shadowing the imported ``neighbors`` function):
            nbr_idx = conn[i].nonzero()[1]
            neighbor_counts[i] = np.sum(ligand_expr_mask[nbr_idx])
# Get the cells with sufficient number of neighbors expressing the ligand:
cells_meeting_neighbor_ligand_threshold = adata.obs_names[neighbor_counts > n_neighbors_expressing]
if self.mod_type == "lr":
to_check = interaction.split(":")[1] if ":" in interaction else interaction
if "_" in to_check:
genes = to_check.split("_")
separator = "_"
else:
genes = [to_check]
separator = None
if separator == "_":
# Cells expressing all of the genes
receptor_expr_mask = np.ones(len(adata), dtype=bool)
for gene in genes:
receptor_expr_mask &= adata[:, gene].X.toarray().squeeze() > 0
else:
# Single gene
receptor_expr_mask = adata[:, to_check].X.toarray().squeeze() > 0
            # Get the cells expressing the receptor, to further subset the target-expressing cells to those that
            # also express the receptor:
receptor_expressing_cells = adata.obs_names[receptor_expr_mask]
elif self.mod_type == "ligand":
# True negative examples will express the target, but not be predicted to be affected by the interaction
# and either not have evidence of receptor/TF expression or not have ligand expression in the neighborhood:
X_df = pd.read_csv(
os.path.join(os.path.splitext(self.output_path)[0], "design_matrix", "design_matrix.csv"), index_col=0
)
if select_examples_criterion == "positive":
factor_expr_mask = X_df.loc[adata.obs_names, interaction] > 0
else:
factor_expr_mask = X_df.loc[adata.obs_names, interaction] == 0
factor_expr_cells = adata.obs_names[factor_expr_mask]
if select_examples_criterion == "positive":
if self.mod_type == "lr":
# Get the intersection of cells expressing target, predicted to be affected by interaction,
# with sufficient number of neighbors expressing the chosen ligand and expressing receptor:
                adata_mask = (
                    target_expressing_cells.intersection(interaction_cells)
                    .intersection(cells_meeting_neighbor_ligand_threshold)
                    .intersection(receptor_expressing_cells)
                )
else:
# Get the intersection of cells expressing target, predicted to be affected by interaction,
# with sufficient number of neighbors expressing the chosen ligand and expressing the receptor or the
# downstream factors of the receptor:
                adata_mask = (
                    target_expressing_cells.intersection(interaction_cells)
                    .intersection(cells_meeting_neighbor_ligand_threshold)
                    .intersection(factor_expr_cells)
                )
else:
# In this case, note that "interaction_cells" are actually those cells that are predicted not to be
# affected by the interaction (and "factor_expr_cells" are actually those that don't express any of the
# key downstream factors or the receptor):
            adata_mask = target_expressing_cells.intersection(interaction_cells).intersection(factor_expr_cells)
adata_sub = adata[adata_mask].copy()
if cell_type is not None:
adata_sub = adata_sub[adata_sub.obs[group_key] == cell_type].copy()
logger.info(
f"Randomly selecting {select_examples_criterion} example cells from a pool of {adata_sub.n_obs} for target"
f" {target} and interaction {interaction}."
)
if adata_sub.n_obs < n_anchors:
logger.info(
f"Given the constraints, not enough cells remain to choose {n_anchors} cells. Selecting all "
f"{adata_sub.n_obs} eligible cells instead."
)
n_anchors = min(n_anchors, adata_sub.n_obs)
# Randomly choose a subset of target cells to use as anchors:
if n_anchors == adata_sub.n_obs:
target_expressing_selected = adata_sub.obs_names
else:
target_expressing_selected = np.random.choice(adata_sub.obs_names, size=n_anchors, replace=False)
        selected_indices = adata.obs_names.get_indexer(target_expressing_selected)
# Find the neighbors of these anchor cells:
        neighbor_indices = conn[selected_indices].nonzero()[1]
        neighbor_indices = np.unique(neighbor_indices)
        # Remove the anchor cells from the neighbors:
        neighbor_indices = neighbor_indices[~np.isin(neighbor_indices, selected_indices)]
        neighbors_selected = adata.obs_names[neighbor_indices]
# Also make note of the nonselected cells & their neighbors if cell type parameter was given:
if cell_type is not None:
selected_and_neighbors = target_expressing_selected.tolist() + neighbors_selected.tolist()
ct_other_cells = [cell for cell in adata_ct_cells if cell not in selected_and_neighbors]
            ct_other_indices = adata.obs_names.get_indexer(ct_other_cells)
# Target expression in the selected cells:
target_expression = adata_sub[target_expressing_selected, target].X.toarray().squeeze()
# Ligand expression in the neighbors:
        ligand = interaction.split(":")[0] if ":" in interaction else interaction
        # Re-derive the separator from the ligand itself- for L:R models, the receptor parsing above may have
        # overwritten the separator derived from the ligand:
        separator = "/" if "/" in ligand else "_" if "_" in ligand else None
        genes = ligand.split(separator) if separator is not None else [ligand]
gene_values = adata[neighbors_selected, genes].X.toarray()
if separator == "/":
# Arithmetic mean of the genes
ligand_expression = np.mean(gene_values, axis=1)
elif separator == "_":
# Geometric mean of the genes
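            # (Assumed rationale: for a complex, all subunits must be present, so the geometric mean of subunit
            # expression is a natural summary; zeros are replaced with NaN below so that a missing subunit does
            # not zero out the whole product.)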
# Replace zeros with np.nan
gene_values[gene_values == 0] = np.nan
# Compute product along the rows
products = np.nanprod(gene_values, axis=1)
# Count non-nan values in each row for nth root calculation
            non_nan_counts = np.sum(~np.isnan(gene_values), axis=1).astype(float)
            # Avoid division by zero (cast to float above so that NaN can be assigned into the array):
            non_nan_counts[non_nan_counts == 0] = np.nan
ligand_expression = np.power(products, 1 / non_nan_counts)
else:
ligand_expression = adata[neighbors_selected, ligand].X.toarray().squeeze()
adata.obs[f"{interaction}_{target}_{select_examples_criterion}_example_points"] = 0.0
adata.obs.loc[
target_expressing_selected, f"{interaction}_{target}_{select_examples_criterion}_example_points"
] = target_expression
adata.obs.loc[
neighbors_selected, f"{interaction}_{target}_{select_examples_criterion}_example_points"
] = ligand_expression
if display_plot:
# plotly to create 3D scatter plot:
spatial_coords = adata.obsm[self.coords_key]
if spatial_coords.shape[1] == 2:
x, y = spatial_coords[:, 0], spatial_coords[:, 1]
z = np.zeros(len(x))
else:
x, y, z = spatial_coords[:, 0], spatial_coords[:, 1], spatial_coords[:, 2]
# Color assignment:
default_color = "#D3D3D3"
if cell_type is not None:
ct_other_color = "#71797E"
target_color = "#39FF14"
# target_data = adata.obs.loc[target_expressing_selected, f"{interaction}_{target}_example_points"]
# p997 = np.percentile(target_data.values, 99.7)
# target_data[target_data > p997] = p997
# plot_vals = target_data.values
scatter_target = go.Scatter3d(
x=x[selected_indices],
y=y[selected_indices],
z=z[selected_indices],
mode="markers",
# Draw target cells larger
marker=dict(color=target_color, size=6.5),
showlegend=False,
# marker=dict(
# color=plot_vals, colorscale="Plotly3", size=6, colorbar=dict(title=f"{target} Expression", x=1.05)
# ),
)
nbr_data = adata.obs.loc[
neighbors_selected, f"{interaction}_{target}_{select_examples_criterion}_example_points"
]
            # Lenient with the max value cutoff so that the colored dots are more distinct from the background
p95 = np.percentile(nbr_data.values, 95)
nbr_data[nbr_data > p95] = p95
plot_vals = nbr_data.values
scatter_ligand = go.Scatter3d(
                x=x[neighbor_indices],
                y=y[neighbor_indices],
                z=z[neighbor_indices],
mode="markers",
marker=dict(
color=plot_vals,
colorscale="Hot",
size=2.5,
                    colorbar=dict(title=dict(text=f"{ligand} Expression", font=dict(size=16)), x=0.8, tickfont=dict(size=18)),
),
showlegend=False,
)
            rest_indices = list(set(range(len(x))) - set(selected_indices) - set(neighbor_indices))
scatter_rest = go.Scatter3d(
x=x[rest_indices],
y=y[rest_indices],
z=z[rest_indices],
mode="markers",
marker=dict(color=default_color, size=2),
name="Other Cells",
showlegend=False,
)
if cell_type is not None:
scatter_ct = go.Scatter3d(
x=x[ct_other_indices],
y=y[ct_other_indices],
z=z[ct_other_indices],
mode="markers",
marker=dict(color=ct_other_color, size=2),
name=f"Other Cells of Type {cell_type}",
showlegend=False,
)
# Invisible traces for the legend
legend_ct = go.Scatter3d(
x=[None],
y=[None],
z=[None],
mode="markers",
marker=dict(size=10, color=ct_other_color), # Adjust size as needed
name=f"Other Cells of Type <br>{cell_type}",
showlegend=True,
)
# Invisible traces for the legend
name = (
f"{target}-Expressing Cells <br>(w/ Receptor Expression)"
if select_examples_criterion == "positive"
else f"{target}-Expressing Cells <br>(w/o Receptor Expression)"
)
legend_target = go.Scatter3d(
x=[None],
y=[None],
z=[None],
mode="markers",
marker=dict(size=30, color=target_color), # Adjust size as needed
name=name,
showlegend=True,
)
legend_rest = go.Scatter3d(
x=[None],
y=[None],
z=[None],
mode="markers",
marker=dict(size=15, color=default_color), # Adjust size as needed
name="Other Cells",
showlegend=True,
)
# Create the figure and add the scatter plots
if cell_type is not None:
fig = go.Figure(
data=[
scatter_rest,
scatter_target,
scatter_ligand,
scatter_ct,
legend_target,
legend_rest,
legend_ct,
]
)
else:
fig = go.Figure(data=[scatter_rest, scatter_target, scatter_ligand, legend_target, legend_rest])
if cell_type is None:
title_dict = dict(
text=f"Target: {target}, Ligand: {ligand} "
f"<br>(Example {select_examples_criterion.title()} Predicted Effects)",
y=0.9,
yanchor="top",
x=0.5,
xanchor="center",
font=dict(size=28),
)
else:
title_dict = dict(
text=f"Target: {target}, Ligand: {ligand}, <br>Cell Type: {cell_type} "
f"(Example {select_examples_criterion.title()} Predicted Effects)",
y=0.9,
yanchor="top",
x=0.5,
xanchor="center",
font=dict(size=28),
)
# Turn off the grid
fig.update_layout(
showlegend=True,
legend=dict(x=0.65, y=0.85, orientation="v", font=dict(size=18)),
scene=dict(
aspectmode="data",
xaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
yaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
zaxis=dict(
showgrid=False,
showline=False,
linewidth=2,
linecolor="black",
backgroundcolor="white",
title="",
showticklabels=False,
ticks="",
),
),
margin=dict(l=0, r=0, b=0, t=50), # Adjust margins to minimize spacing
title=title_dict,
)
fig.write_html(path)
return adata
    def cell_type_specific_interactions(
self,
to_plot: Literal["mean", "percentage"] = "mean",
plot_type: Literal["heatmap", "barplot"] = "heatmap",
group_key: Optional[str] = None,
ct_subset: Optional[List[str]] = None,
target_subset: Optional[List[str]] = None,
interaction_subset: Optional[List[str]] = None,
lower_threshold: float = 0.3,
upper_threshold: float = 1.0,
effect_threshold: Optional[float] = None,
use_significant: bool = False,
row_normalize: bool = False,
col_normalize: bool = False,
normalize_targets: bool = False,
hierarchical_cluster_ct: bool = False,
group_y_cell_type: bool = False,
fontsize: Union[None, int] = None,
figsize: Union[None, Tuple[float, float]] = None,
center: Optional[float] = None,
cmap: str = "Reds",
save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
save_kwargs: Optional[dict] = {},
save_df: bool = False,
):
"""Map interactions and interaction effects that are specific to particular cell type groupings. Returns a
heatmap representing the enrichment of the interaction/effect within cells of that grouping (if "to_plot" is
effect, this will be enrichment of the effect on cell type-specific expression). Enrichment determined by
mean effect size or expression.
Args:
to_plot: Whether to plot the mean effect size or the proportion of cells in a cell type w/ effect on
target. Options are "mean" or "percentage".
plot_type: Whether to plot the results as a heatmap or barplot. Options are "heatmap" or "barplot". If
"barplot", must provide a subset of up to four interactions to visualize.
group_key: Can be used to specify entry in adata.obs that contains cell type groupings. If None,
will use :attr `group_key` from model initialization.
ct_subset: Can be used to restrict the enrichment analysis to only cells of a particular type. If given,
will search for cell types in "group_key" attribute from model initialization. Recommended to use to
subset to cell types with sufficient numbers.
target_subset: List of targets to consider. If None, will use all targets used in model fitting.
            interaction_subset: List of interactions to consider. If None, will use all interactions used in model.
                Required if "plot_type" is "barplot", since the barplot is only designed to accommodate up to four
                interactions at once.
lower_threshold: Lower threshold for the proportion of cells in a cell type group that must express a
particular interaction/effect for it to be colored on the plot, as a proportion of the max value.
Threshold will be applied to the non-normalized values (if normalization is applicable). Defaults to
0.3.
upper_threshold: Upper threshold for the proportion of cells in a cell type group that must express a
particular interaction/effect for it to be colored on the plot, as a proportion of the max value.
Threshold will be applied to the non-normalized values (if normalization is applicable). Defaults to
1.0 (the max value).
effect_threshold: Optional threshold for the effect size of an interaction/effect to be considered for
analysis; only used if "to_plot" is "percentage". If not given, will use the upper quartile value
among all interaction effect values to determine the threshold.
use_significant: Whether to use only significant effects in computing the specificity. If True,
will filter to cells + interactions where the interaction is significant for the target. Only valid
if :func `compute_coeff_significance()` has been run.
row_normalize: Whether to minmax scale the metric values by row (i.e. for each interaction/effect). Helps
to alleviate visual differences that result from scale rather than differences in mean value across
cell types.
col_normalize: Whether to minmax scale the metric values by column (i.e. for each interaction/effect). Helps
to alleviate visual differences that result from scale rather than differences in mean value across
cell types.
normalize_targets: Whether to minmax scale the metric values by column for each target (i.e. for each
interaction/effect), to remove differences that occur as a result of scale of expression. Provides a
clearer picture of enrichment for each target.
hierarchical_cluster_ct: Whether to cluster the x-axis (target gene in cell type) using hierarchical
clustering. If False, will order the x-axis by the order of the target genes for organization purposes.
group_y_cell_type: Whether to group the y-axis (target gene in cell type) by cell type. If False,
will group by target gene instead. Defaults to False.
fontsize: Size of font for x and y labels.
figsize: Size of figure.
            center: Optional, determines the value at which the colormap is centered.
            cmap: Colormap to use for the heatmap. The bottom end of the range is 0, so it is recommended to use a
                sequential colormap (e.g. "Reds", "Blues", "viridis", etc.); if a non-sequential colormap is
                given, "viridis" will automatically be used.
            save_show_or_return: Whether to save, show or return the figure.
                If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
                displayed, and the associated axes and other objects will be returned.
            save_kwargs: A dictionary that will be passed to the save_fig function.
                By default it is an empty dictionary and the save_fig function will use the
                {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True,
                "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modifies
                those keys according to your needs.
            save_df: Set True to save the metric dataframe at the end
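        Example:
            A minimal sketch; the cell type and gene names below are hypothetical::

                interpreter.cell_type_specific_interactions(
                    to_plot="mean",
                    plot_type="heatmap",
                    ct_subset=["Astrocyte", "Microglia"],  # hypothetical cell type labels
                    target_subset=["Dlx1"],  # hypothetical target gene
                    save_show_or_return="show",
                )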
"""
logger = lm.get_main_logger()
config_spateo_rcParams()
# But set display DPI to 300:
plt.rcParams["figure.dpi"] = 300
if to_plot not in ["mean", "percentage"]:
raise ValueError("Unrecognized input for plotting. Options are 'mean' or 'percentage'.")
if plot_type == "barplot" and interaction_subset is None:
raise ValueError("Must provide a subset of interactions to visualize if 'plot_type' is 'barplot'.")
if plot_type == "barplot" and len(interaction_subset) > 4:
raise ValueError(
"Can only visualize up to four interactions at once with 'barplot' (for practical/plot "
"readability reasons)."
)
if self.mod_type not in ["lr", "ligand", "receptor"]:
raise ValueError("Model type must be one of 'lr', 'ligand', or 'receptor'.")
if save_show_or_return in ["save", "both", "all"]:
if not os.path.exists(os.path.join(os.path.dirname(self.output_path), "figures")):
os.makedirs(os.path.join(os.path.dirname(self.output_path), "figures"))
figure_folder = os.path.join(os.path.dirname(self.output_path), "figures", "temp")
if not os.path.exists(figure_folder):
os.makedirs(figure_folder)
if save_df:
output_folder = os.path.join(os.path.dirname(self.output_path), "analyses")
if not os.path.exists(output_folder):
os.makedirs(output_folder)
# Colormap should be sequential:
sequential_colormaps = [
"Blues",
"BuGn",
"BuPu",
"GnBu",
"Greens",
"Greys",
"Oranges",
"OrRd",
"PuBu",
"PuBuGn",
"PuRd",
"Purples",
"RdPu",
"Reds",
"YlGn",
"YlGnBu",
"YlOrBr",
"YlOrRd",
"afmhot",
"autumn",
"bone",
"cool",
"copper",
"gist_heat",
"gray",
"hot",
"pink",
"spring",
"summer",
"winter",
"viridis",
"plasma",
"inferno",
"magma",
"cividis",
]
if cmap not in sequential_colormaps:
logger.info(f"For option {to_plot}, colormap should be sequential: using 'viridis'.")
cmap = "viridis"
if group_key is None:
group_key = self.group_key
# Get appropriate adata:
if isinstance(ct_subset, str):
ct_subset = [ct_subset]
if ct_subset is None:
adata = self.adata.copy()
else:
adata = self.adata[self.adata.obs[group_key].isin(ct_subset)].copy()
cell_types = adata.obs[group_key].unique()
all_targets = list(self.coeffs.keys())
all_feature_names = [feat for feat in self.feature_names if feat != "intercept"]
if isinstance(interaction_subset, str):
interaction_subset = [interaction_subset]
feature_names = all_feature_names if interaction_subset is None else interaction_subset
if fontsize is None:
fontsize = rcParams.get("font.size")
if isinstance(target_subset, str):
target_subset = [target_subset]
targets = all_targets if target_subset is None else target_subset
combinations = product(cell_types, targets)
combinations = [f"{ct}-{target}" for ct, target in combinations]
if figsize is None:
if plot_type == "heatmap":
# Set figure size based on the number of interaction features and cell type-target combos:
m = len(combinations) * 50 / 200
n = len(feature_names) * 50 / 200
else:
# Set figure size based on the number of cell type-target combos:
n = len(combinations) * 50 / 200
m = 3 * len(feature_names)
figsize = (n, m)
df = pd.DataFrame(0, index=combinations, columns=feature_names)
for ct in cell_types:
cell_type_mask = adata.obs[group_key] == ct
cell_in_ct = adata[cell_type_mask].copy()
# Get appropriate coefficient arrays:
for target in targets:
expressing_target = pd.DataFrame(
adata[:, target].X.toarray().reshape(-1) > 0, index=adata.obs_names, columns=[target]
)
total_mask = cell_type_mask & expressing_target[target]
total_mask = total_mask[total_mask].index
if to_plot == "mean":
mean_effects = []
# coef_target = self.coeffs[target].loc[adata.obs_names]
coef_target = self.coeffs[target].loc[
cell_in_ct.obs_names
] # This should be cell_in_ct, since it's cell type specific effect, and not the entire adata object
coef_target = coef_target[[c for c in coef_target.columns if "intercept" not in c]]
if effect_threshold is None:
# Cell type-specific threshold:
                        # coef_target is already subset to this cell type above, so no further masking is needed:
                        nonzero_values = coef_target.values.flatten()
nonzero_values = nonzero_values[nonzero_values != 0]
target_effect_threshold = pd.Series(nonzero_values).quantile(0.75)
else:
target_effect_threshold = effect_threshold
coef_target[coef_target < target_effect_threshold] = 0
if use_significant:
parent_dir = os.path.dirname(self.output_path)
sig = pd.read_csv(
os.path.join(parent_dir, "significance", f"{target}_is_significant.csv"), index_col=0
)
coef_target *= sig
for feat in feature_names:
if f"b_{feat}" in coef_target.columns:
# If a given cell type does not have much expression of the target gene, mask out the
# mean effect (use an arbitrary cutoff of 2% of cells):
if len(total_mask) < 0.02 * cell_in_ct.n_obs:
mean_effects.append(0)
else:
# Get mean effect size for each interaction feature in each cell type, from among the
# cells that express the target gene:
mean_effects.append(coef_target.loc[total_mask, f"b_{feat}"].values.mean())
else:
mean_effects.append(0)
df.loc[f"{ct}-{target}", :] = mean_effects
elif to_plot == "percentage":
percentages = []
coef_target = self.coeffs[target].loc[adata.obs_names]
coef_target = coef_target[[c for c in coef_target.columns if "intercept" not in c]]
if effect_threshold is None:
# Cell type-specific threshold:
nonzero_values = coef_target.loc[cell_type_mask].values.flatten()
nonzero_values = nonzero_values[nonzero_values != 0]
target_effect_threshold = pd.Series(nonzero_values).quantile(0.75)
else:
target_effect_threshold = effect_threshold
coef_target[coef_target < target_effect_threshold] = 0
for feat in feature_names:
if f"b_{feat}" in coef_target.columns:
                            # If a given cell type does not have much expression of the target gene, mask out the
                            # percentage (use an arbitrary cutoff of 2% of cells):
if len(total_mask) < 0.02 * cell_in_ct.n_obs:
percentages.append(0)
else:
# Get percentage of cells in each cell type that express each interaction feature:
percentages.append(
(coef_target.loc[total_mask, f"b_{feat}"].values > target_effect_threshold).mean()
)
else:
percentages.append(0)
df.loc[f"{ct}-{target}", :] = percentages
# Apply metric threshold (do this in a grouped manner, for each target):
# Split the index to get the targets portion of the tuples
grouping_element = df.index.map(lambda x: x.split("-")[1])
# Compute the maximum (and optionally used minimum) for each group
group_max = df.groupby(grouping_element).max()
# Apply the threshold in a grouped fashion
for group in group_max.index:
# Select the rows belonging to the current group
group_rows = df.index[df.index.str.contains(f"-{group}$")]
# Apply the lower threshold specific to this group
df.loc[group_rows] = df.loc[group_rows].where(
df.loc[group_rows].ge(lower_threshold * group_max.loc[group]), 0
)
if normalize_targets:
# Take 0 to be the min. value in all cases:
df.loc[group_rows] = df.loc[group_rows] / group_max.loc[group]
if upper_threshold != 1.0:
df[df >= upper_threshold * df.max().max()] = df.max().max()
# Optionally, normalize each row by minmax scaling (to get an idea of the top effects for each target),
# or each column by minmax scaling:
if row_normalize or col_normalize or normalize_targets:
normalize = True
else:
normalize = False
if row_normalize:
# Calculate row-wise min and max
row_min = df.min(axis=1).values.reshape(-1, 1)
row_max = df.max(axis=1).values.reshape(-1, 1)
df = (df - row_min) / (row_max - row_min)
elif col_normalize:
df = (df - df.min()) / (df.max() - df.min())
df.fillna(0, inplace=True)
if plot_type == "heatmap":
# Hierarchical clustering- first to group interactions w/ similar patterns across cell types:
col_linkage = sch.linkage(df.transpose(), method="ward")
col_dendro = sch.dendrogram(col_linkage, no_plot=True)
col_clustered_order = col_dendro["leaves"]
df = df.iloc[:, col_clustered_order]
# Then to group cell types w/ similar interaction patterns, if specified:
if hierarchical_cluster_ct:
row_linkage = sch.linkage(df, method="ward")
row_dendro = sch.dendrogram(row_linkage, no_plot=True)
row_clustered_order = row_dendro["leaves"]
df = df.iloc[row_clustered_order, :]
else:
# Sort by target:
# Create a temporary MultiIndex
df.index = pd.MultiIndex.from_tuples(df.index.str.split("-").map(tuple), names=["first", "second"])
if group_y_cell_type:
# Sort by the first element, then the second
df.sort_index(level=["first", "second"], inplace=True)
else:
# Sort by the second element, then the first
df.sort_index(level=["second", "first"], inplace=True)
# Revert to the original index format
df.index = df.index.map("-".join)
else:
# Sort by target:
# Create a temporary MultiIndex
df.index = pd.MultiIndex.from_tuples(df.index.str.split("-").map(tuple), names=["first", "second"])
if group_y_cell_type:
# Sort by the first element, then the second
df.sort_index(level=["first", "second"], inplace=True)
else:
# Sort by the second element, then the first
df.sort_index(level=["second", "first"], inplace=True)
# Revert to the original index format
df.index = df.index.map("-".join)
# Delete all-zero rows and all-zero columns:
df = df.loc[:, ~(df == 0).all()]
logger.info(f"Final dataframe for {ct} shape: {df.shape}")
if normalize and to_plot == "mean":
if plot_type == "heatmap":
label = (
"Normalized avg. effect per cell type for cells expressing target"
if not normalize_targets
else "Normalized avg. effect per cell type \nfor cells expressing target (normalized within target)"
)
else:
label = (
"Normalized avg. effect\n per cell type \nfor cells expressing target"
if not normalize_targets
else "Normalized avg. effect\n per cell type \nfor cells expressing target \n(normalized within "
"target)"
)
elif normalize and to_plot == "percentage":
if plot_type == "heatmap":
label = (
"Normalized enrichment of effect per cell type \nfor cells expressing target"
if not normalize_targets
else "Normalized enrichment of effect per cell type \nfor cells expressing target (normalized "
"within target)"
)
else:
label = (
"Normalized enrichment of\n effect per cell type \nfor cells expressing target"
if not normalize_targets
else "Normalized enrichment \nof effect per cell type\n for cells expressing target\n(normalized "
"within target)"
)
elif not normalize and to_plot == "mean":
label = (
"Avg. effect per cell type \nfor cells expressing target"
if plot_type == "heatmap"
else "Avg. effect\n per cell type \nfor cells expressing target"
)
else:
label = (
"Enrichment of effect per cell type \nfor cells expressing target"
if plot_type == "heatmap"
else "Enrichment of effect\n per cell type \nfor cells expressing target"
)
if self.mod_type == "lr":
x_label = "Interaction"
title = "Enrichment of L:R interaction in each cell type"
elif self.mod_type == "ligand":
x_label = "Neighboring ligand expression"
title = "Enrichment of neighboring ligand expression in each cell type for each target"
elif self.mod_type == "receptor":
x_label = "Receptor expression"
title = "Enrichment of receptor expression in each cell type"
# Formatting color legend:
if group_y_cell_type:
group_labels = [idx.split("-")[0] for idx in df.index]
else:
group_labels = [idx.split("-")[1] for idx in df.index]
target_colors = mpl.colormaps["tab20"].colors
if group_y_cell_type:
color_mapping = {
annotation: target_colors[i % len(target_colors)] for i, annotation in enumerate(set(cell_types))
}
else:
color_mapping = {
annotation: target_colors[i % len(target_colors)] for i, annotation in enumerate(set(targets))
}
max_annotation_length = max([len(annotation) for annotation in color_mapping.keys()])
if max_annotation_length > 30:
ax2_size = "30%"
elif max_annotation_length > 20:
ax2_size = "20%"
else:
ax2_size = "10%"
if plot_type == "heatmap":
# Plot heatmap:
vmin = 0
vmax = 1 if normalize else df.max().max()
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
divider = make_axes_locatable(ax)
ax2 = divider.append_axes("right", size=ax2_size, pad=0)
# Keep track of groups:
current_group = None
group_start = None
# Color legend:
for i, annotation in enumerate(group_labels):
if annotation != current_group:
if current_group is not None:
group_center = len(df) - ((group_start + i - 1) / 2) - 1
ax2.text(0.22, group_center, current_group, va="center", ha="left", fontsize=fontsize)
current_group = annotation
group_start = i
color = color_mapping[annotation]
ax2.add_patch(plt.Rectangle((0, i), 0.2, 1, color=color))
# Add label for the last group:
group_center = len(df) - ((group_start + len(df) - 1) / 2) - 1
ax2.text(0.22, group_center, current_group, va="center", ha="left", fontsize=fontsize)
ax2.set_ylim(0, len(df.index))
ax2.axis("off")
thickness = 0.3 * figsize[0] / 10
mask = df == 0
m = sns.heatmap(
df,
square=True,
linecolor="grey",
linewidths=thickness,
cbar_kws={"label": label, "location": "top", "pad": 0},
cmap=cmap,
center=center,
vmin=vmin,
vmax=vmax,
mask=mask,
ax=ax,
)
# Outer frame:
for _, spine in m.spines.items():
spine.set_visible(True)
spine.set_linewidth(thickness * 2.5)
# Adjust colorbar settings:
divider = make_axes_locatable(ax)
# Append axes to the top of the plot, where the colorbar will be placed
if df.shape[0] > df.shape[1]:
cax = divider.append_axes("top", size="30%", pad=0)
else:
cax = divider.append_axes("top", size="30%", pad="30%")
# Create the colorbar manually in the appended axes
cbar = plt.colorbar(m.collections[0], cax=cax, orientation="horizontal")
cbar.set_label(to_plot.title(), fontsize=fontsize * 1.5, labelpad=10)
cbar.ax.xaxis.set_ticks_position("top") # Move ticks to the top
cbar.ax.xaxis.set_label_position("top") # Move the label (title) to the top
cbar.ax.tick_params(labelsize=fontsize * 1.5)
cbar.ax.set_aspect(0.02)
ax.set_xlabel(x_label, fontsize=fontsize * 1.25)
ax.set_ylabel("Cell Type-Specific Target", fontsize=fontsize * 1.25)
ax.tick_params(axis="x", labelsize=fontsize, rotation=90)
ax.tick_params(axis="y", labelsize=fontsize)
ax.set_title(title, fontsize=fontsize * 1.5, pad=20)
# Use the saved name for the AnnData object to define part of the name of the saved file:
base_name = os.path.basename(self.adata_path)
adata_id = os.path.splitext(base_name)[0]
prefix = f"{adata_id}_{to_plot}_enrichment_cell_type"
else:
rem_interactions = [i for i in interaction_subset if i in df.columns]
fig, axes = plt.subplots(nrows=len(rem_interactions), ncols=1, figsize=figsize)
fig.subplots_adjust(hspace=0.4)
colormap = mpl.colormaps[cmap]
# Determine the order of the plot based on averaging over the chosen interactions (if there is more than
# one):
df_sub = df[rem_interactions]
df_sub["Group"] = group_labels
# Ranks within each group:
grouped_ranked_df = df_sub.groupby("Group").rank(ascending=False)
# Average rank across groups:
avg_ranked_df = grouped_ranked_df.mean()
# Sort by average rank:
sorted_features = avg_ranked_df.sort_values().index.tolist()
df = df[sorted_features]
# Color legend:
if not isinstance(axes, (list, np.ndarray)):
divider = make_axes_locatable(axes)
else:
# If 'axes' is an array, and we want to apply to the first one
if len(axes) > 0:
divider = make_axes_locatable(axes[0])
else:
raise ValueError("No axes found in the 'axes' array")
ax2 = divider.append_axes("top", size=ax2_size, pad=0)
current_group = None
group_start = None
for i, annotation in enumerate(group_labels):
if annotation != current_group:
if current_group is not None:
group_center = (group_start + i - 1) / 2
ax2.text(group_center, 0.42, current_group, va="bottom", ha="center", fontsize=fontsize)
current_group = annotation
group_start = i
color = color_mapping[annotation]
ax2.add_patch(plt.Rectangle((i, 0), 1, 0.4, color=color))
# Add label for the last group:
group_center = (group_start + len(df) - 1) / 2
ax2.text(group_center, 0.42, current_group, va="bottom", ha="center", fontsize=fontsize)
ax2.set_xlim(0, len(df.index))
ax2.axis("off")
if not isinstance(axes, (list, np.ndarray)):
vmin = 0
                vmax = 1 if normalize else df[interaction_subset].values.max()
                norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
                colors = [colormap(norm(val)) for val in df[interaction_subset].values.flatten()]
sns.barplot(
x=df[interaction_subset].index,
y=df[interaction_subset].values.flatten(),
edgecolor="black",
linewidth=1,
palette=colors,
ax=axes,
)
axes.set_title(interaction_subset[0], fontsize=fontsize * 1.5, pad=35)
axes.set_xlabel("Cell Type-Specific Target", fontsize=fontsize)
axes.set_ylabel(label, fontsize=fontsize)
axes.tick_params(axis="y", labelsize=fontsize * 1.1)
axes.tick_params(axis="x", labelsize=fontsize * 0.9, rotation=90)
else:
for i, ax in enumerate(axes):
# From the larger dataframe, get the column for the chosen interaction as a series:
interaction = interaction_subset[i]
interaction_series = df[interaction]
vmin = 0
vmax = 1 if normalize else interaction_series.max()
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
colors = [colormap(norm(val)) for val in interaction_series]
sns.barplot(
x=interaction_series.index,
y=interaction_series.values,
edgecolor="black",
linewidth=1,
palette=colors,
ax=ax,
)
if ax is axes[0]:
ax.set_title(interaction, fontsize=fontsize * 1.5, pad=35)
else:
ax.set_title(interaction, fontsize=fontsize * 1.5, pad=10)
ax.set_xlabel("Cell Type-Specific Target", fontsize=fontsize)
ax.set_ylabel(label, fontsize=fontsize)
ax.tick_params(axis="y", labelsize=fontsize * 1.1)
if ax is axes[-1]:
ax.tick_params(axis="x", labelsize=fontsize * 0.9, rotation=90)
else:
ax.tick_params(axis="x", labelbottom=False)
# Use the saved name for the AnnData object to define part of the name of the saved file:
base_name = os.path.basename(self.adata_path)
adata_id = os.path.splitext(base_name)[0]
prefix = f"{adata_id}_{to_plot}_enrichment_cell_type"
# Save figure:
save_kwargs["ext"] = "png"
save_kwargs["dpi"] = 300
if "figure_folder" in locals():
save_kwargs["path"] = figure_folder
save_return_show_fig_utils(
save_show_or_return=save_show_or_return,
show_legend=False,
background="white",
prefix=prefix,
save_kwargs=save_kwargs,
total_panels=1,
fig=fig,
axes=axes,
return_all=False,
return_all_list=None,
)
if save_df:
df.to_csv(os.path.join(output_folder, f"{prefix}.csv"))
    def cell_type_interaction_fold_change(
self,
ref_ct: str,
query_ct: str,
group_key: Optional[str] = None,
target_subset: Optional[List[str]] = None,
interaction_subset: Optional[List[str]] = None,
to_plot: Literal["mean", "percentage"] = "mean",
plot_type: Literal["volcano", "barplot"] = "barplot",
source_data: Literal["interaction", "effect", "target"] = "effect",
top_n_to_plot: Optional[int] = None,
significance_cutoff: float = 1.3,
fold_change_cutoff: float = 1.5,
fold_change_cutoff_for_labels: float = 3.0,
plot_query_over_ref: bool = False,
plot_ref_over_query: bool = False,
plot_only_significant: bool = False,
fontsize: Union[None, int] = None,
figsize: Union[None, Tuple[float, float]] = None,
cmap: str = "seismic",
save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
save_kwargs: Optional[dict] = {},
save_df: bool = False,
):
"""Computes fold change in predicted interaction effects between two cell types, and visualizes result.
Args:
ref_ct: Label of the first cell type to consider. Fold change will be computed with respect to the level
in this cell type.
query_ct: Label of the second cell type to consider
group_key: Name of the key in .obs containing cell type information. If not given, will use
:attr `group_key` from model initialization.
target_subset: List of targets to consider. If None, will use all targets used in model fitting.
interaction_subset: List of interactions to consider. If None, will use all interactions used in model.
to_plot: Whether to plot the mean effect size or the proportion of cells in a cell type w/ effect on
target. Options are "mean" or "percentage".
plot_type: Whether to plot the results as a volcano plot or barplot. Options are "volcano" or "barplot".
source_data: Selects what to use in computing fold changes. Options:
- "interaction": will use the design matrix (e.g. neighboring ligand expression or L:R mapping)
- "effect": will use the coefficient arrays for each target
- "target": will use the target gene expression
top_n_to_plot: If given, will only include the top n features in the visualization. Recommended if
"source_data" is "effect", as all combinations of interaction and target will be considered in this
case.
significance_cutoff: Cutoff for negative log-10 q-value to consider an interaction/effect significant. Only
used if "plot_type" is "volcano". Defaults to 1.3 (corresponding to an approximate q-value of 0.05).
fold_change_cutoff: Cutoff for fold change to consider an interaction/effect significant. Only used if
"plot_type" is "volcano". Defaults to 1.5.
fold_change_cutoff_for_labels: Cutoff for fold change to include the label for an interaction/effect.
Only used if "plot_type" is "volcano". Defaults to 3.0.
plot_query_over_ref: Whether to plot/visualize only the portion that corresponds to the fold change of
the query cell type over the reference cell type (and the portion that is significant). If False (and
"plot_ref_over_query" is False), will plot the entire volcano plot. Only used if "plot_type" is
"volcano".
plot_ref_over_query: Whether to plot/visualize only the portion that corresponds to the fold change of
the reference cell type over the query cell type (and the portion that is significant). If False (and
"plot_query_over_ref" is False), will plot the entire volcano plot. Only used if "plot_type" is
"volcano".
plot_only_significant: Whether to plot/visualize only the portion that passes the "significance_cutoff"
p-value threshold. Only used if "plot_type" is "volcano".
fontsize: Size of font for x and y labels.
figsize: Size of figure.
            cmap: Colormap to use. Fold changes are centered at 0, so a divergent colormap (e.g. "seismic",
                "coolwarm", "RdBu_r") is recommended; defaults to "seismic".
            save_show_or_return: Whether to save, show or return the figure.
                If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
                displayed, and the associated axes and other objects will be returned.
            save_kwargs: A dictionary that will be passed to the save_fig function.
                By default it is an empty dictionary and the save_fig function will use the
                {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True,
                "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modifies
                those keys according to your needs.
            save_df: Set True to save the metric dataframe at the end
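        Example:
            A minimal sketch; the cell type labels below are hypothetical::

                interpreter.cell_type_interaction_fold_change(
                    ref_ct="Astrocyte",  # hypothetical reference cell type
                    query_ct="Microglia",  # hypothetical query cell type
                    source_data="effect",
                    plot_type="volcano",
                    top_n_to_plot=20,
                )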
"""
config_spateo_rcParams()
# But set display DPI to 300:
plt.rcParams["figure.dpi"] = 300
parent_dir = os.path.dirname(self.output_path)
if save_show_or_return in ["save", "both", "all"]:
if not os.path.exists(os.path.join(os.path.dirname(self.output_path), "figures")):
os.makedirs(os.path.join(os.path.dirname(self.output_path), "figures"))
figure_folder = os.path.join(os.path.dirname(self.output_path), "figures", "temp")
if not os.path.exists(figure_folder):
os.makedirs(figure_folder)
# Use the saved name for the AnnData object to define part of the name of the saved figure & file (if
# applicable):
base_name = os.path.basename(self.adata_path)
adata_id = os.path.splitext(base_name)[0]
prefix = f"{adata_id}_fold_changes_{source_data}_{ref_ct}_{query_ct}"
output_folder = os.path.join(os.path.dirname(self.output_path), "analyses")
if not os.path.exists(output_folder):
os.makedirs(output_folder)
if fontsize is None:
fontsize = rcParams.get("font.size")
if group_key is None:
group_key = self.group_key
if target_subset is None:
target_subset = self.targets_expr.columns
if interaction_subset is None:
interaction_subset = self.feature_names
# Formatting:
if source_data == "effect":
x_label = f"$\\log_2$(Fold change effect on target- \n{ref_ct} and {query_ct})"
title = f"Fold change effect on target \n{ref_ct} and {query_ct}"
if self.mod_type == "lr":
y_label = "L:R effect on target"
elif self.mod_type == "ligand":
y_label = "Ligand effect on target"
elif source_data == "interaction":
x_label = f"$\\log_2$(Fold change interaction enrichment \n {ref_ct} and {query_ct})"
title = f"Fold change interaction enrichment \n{ref_ct} and {query_ct}"
if self.mod_type == "lr":
y_label = "L:R interaction"
elif self.mod_type == "ligand":
y_label = "Ligand"
elif source_data == "target":
x_label = f"$\\log_2$(Fold change target expression \n {ref_ct} and {query_ct})"
title = f"Fold change target expression \n {ref_ct} and {query_ct}"
y_label = "Target"
# Check for already-existing dataframe:
        # output_folder is already an absolute path, so join it with the file name directly:
        if os.path.exists(os.path.join(output_folder, f"{prefix}.csv")):
            results = pd.read_csv(os.path.join(output_folder, f"{prefix}.csv"), index_col=0)
else:
ref_names = self.adata[self.adata.obs[group_key] == ref_ct].obs_names
query_names = self.adata[self.adata.obs[group_key] == query_ct].obs_names
# Series/dataframes for each group:
if source_data == "interaction":
ref_data = self.X_df.loc[ref_names, interaction_subset]
query_data = self.X_df.loc[query_names, interaction_subset]
elif source_data == "effect":
# Coefficients for all targets in subset:
for target in target_subset:
if target not in self.coeffs.keys():
raise ValueError(f"Target {target} not found in model.")
else:
coef_target = self.coeffs[target].loc[self.adata.obs_names]
coef_target.columns = coef_target.columns.str.replace("b_", "")
coef_target = coef_target[[col for col in coef_target.columns if col != "intercept"]]
coef_target.columns = [replace_col_with_collagens(col) for col in coef_target.columns]
coef_target.columns = [f"{col}-> target {target}" for col in coef_target.columns]
duplicates = coef_target.columns[coef_target.columns.duplicated(keep=False)]
for item in duplicates.unique():
# Calculate mean for collagens:
mean_series = coef_target.filter(like=item).mean(axis=1)
coef_target.drop(columns=coef_target.filter(like=item).columns, inplace=True)
coef_target[item] = mean_series
target_interaction_subset = [replace_col_with_collagens(i) for i in interaction_subset]
target_interaction_subset = list(
set([f"{i}-> target {target}" for i in target_interaction_subset])
)
target_interaction_subset = [i for i in target_interaction_subset if i in coef_target.columns]
if "effect_df" not in locals():
effect_df = coef_target.loc[:, target_interaction_subset]
else:
effect_df = pd.concat([effect_df, coef_target.loc[:, target_interaction_subset]], axis=1)
ref_data = effect_df.loc[ref_names, :]
query_data = effect_df.loc[query_names, :]
elif source_data == "target":
ref_data = self.targets_expr.loc[ref_names, target_subset]
query_data = self.targets_expr.loc[query_names, target_subset]
else:
raise ValueError(
f"Unrecognized input for source_data: {source_data}. Options are 'interaction', 'effect', or "
f"'target'."
)
# Compute significance for each column:
pvals = []
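            # (Assumed rationale for the choice of tests: effect/design-matrix values are treated as approximately
            # continuous, so a two-sample t-test is used; raw target expression is zero-inflated, so the
            # nonparametric Mann-Whitney U test is used instead.)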
for col in tqdm(ref_data.columns, desc="Computing significance..."):
if source_data == "effect" or source_data == "interaction":
pvals.append(ttest_ind(ref_data[col], query_data[col])[1])
elif source_data == "target":
pvals.append(mannwhitneyu(ref_data[col], query_data[col])[1])
# Correct for multiple hypothesis testing:
qvals = multitesting_correction(pvals, method="fdr_bh")
results = pd.DataFrame(qvals, index=ref_data.columns, columns=["qval"])
results["qval"] = results["qval"].apply(lambda x: x[0] if isinstance(x, np.ndarray) else x)
results["Significance"] = results.apply(assign_significance, axis=1)
# Negative log q-value (in the case of volcano plot):
results["-log10(qval)"] = -np.log10(qvals)
# Threshold at the highest non-infinity q-value:
max_non_inf = results[results["-log10(qval)"] != np.inf]["-log10(qval)"].max()
results["-log10(qval)"] = results["-log10(qval)"].apply(lambda x: x if x != np.inf else max_non_inf)
if to_plot == "mean":
ref_data = ref_data.mean(axis=0)
query_data = query_data.mean(axis=0)
elif to_plot == "percentage":
ref_data = (ref_data > 0).mean(axis=0)
query_data = (query_data > 0).mean(axis=0)
# Add small offset to ensure reference value is not 0:
ref_data += 1e-3
query_data += 1e-3
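            # e.g. a query mean of 0.399 and a reference mean of 0.099 give (0.399 + 1e-3) / (0.099 + 1e-3) = 4,
            # i.e. a log2 fold change of 2 after the log transform below.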
# Compute fold change:
fold_change = query_data / ref_data
results["Fold Change"] = fold_change
results["Fold Change"] = results["Fold Change"].apply(lambda x: x[0] if isinstance(x, np.ndarray) else x)
# Take the log of the fold change:
results["Fold Change"] = np.log2(results["Fold Change"])
# Remove NaNs:
results = results[~results["Fold Change"].isna()]
results = results.sort_values("Fold Change")
if top_n_to_plot is not None:
results = results.iloc[:top_n_to_plot, :]
# Plot:
if figsize is None:
# Set figure size based on the number of interaction features and targets:
if plot_type == "barplot":
m = len(results) / 2
n = m / 2
elif plot_only_significant or plot_query_over_ref or plot_ref_over_query:
m = 7
n = m * 2
else:
m = 10
n = m
figsize = (n, m)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
cmap = mpl.colormaps[cmap]
# Center colormap at 0:
max_distance = max(abs(results["Fold Change"]).max(), abs(results["Fold Change"]).min())
max_pos = results["Fold Change"].max()
max_neg = results["Fold Change"].min()
norm = plt.Normalize(-max_distance, max_distance)
colors = cmap(norm(results["Fold Change"]))
if plot_type == "barplot":
barplot = sns.barplot(
x="Fold Change",
y=results.index,
data=results,
orient="h",
palette=colors,
edgecolor="black",
linewidth=1,
ax=ax,
)
            # Seaborn draws the horizontal bars at integer y-positions 0..n-1 in row order, so use the enumerated
            # position rather than the (string) row label as the y-coordinate:
            for y_pos, (index, row) in enumerate(results.iterrows()):
                ax.text(row["Fold Change"], y_pos, f"{row['Significance']}", color="black", ha="right")
ax.axvline(x=0, color="grey", linestyle="--", linewidth=2)
ax.set_xlim(max_neg * 1.1, max_pos * 1.1)
            # For a horizontal barplot, the y-axis carries the feature names and the x-axis is numeric:
            ax.set_xticklabels(["{:.2f}".format(x) for x in ax.get_xticks()], fontsize=fontsize)
            ax.set_yticklabels(results.index, fontsize=fontsize)
ax.set_xlabel(x_label, fontsize=fontsize * 1.25)
ax.set_ylabel(y_label, fontsize=fontsize * 1.25)
ax.set_title(title, fontsize=fontsize * 1.5)
elif plot_type == "volcano":
if len(results) > 20:
size = 20
else:
size = 40
# Check if max -log10(qval) is greater than 8
if results["-log10(qval)"].max() > 8:
ax.set_yscale("log", base=2) # Set y-axis to log
y_label = r"$-log_{10}$(qval) ($log_2$ scale)"
else:
y_label = r"$-log_{10}$(qval)"
significant = results["-log10(qval)"] > significance_cutoff
significant_up = results["Fold Change"] > fold_change_cutoff
significant_down = results["Fold Change"] < -fold_change_cutoff
if plot_only_significant:
results = results[significant]
size *= 1.5
positive_fold_change = results["Fold Change"] > 0
negative_fold_change = results["Fold Change"] < 0
# Check if only plotting query over ref or ref over query:
if plot_query_over_ref:
size *= 1.5
fc_up = ax.scatter(
x=results["Fold Change"][significant & significant_up & positive_fold_change],
y=results["-log10(qval)"][significant & significant_up & positive_fold_change],
c=results["Fold Change"][significant & significant_up & positive_fold_change],
cmap="Reds",
edgecolor="black",
s=size,
)
elif plot_ref_over_query:
size *= 1.5
fc_down = ax.scatter(
x=results["Fold Change"][significant & significant_down & negative_fold_change],
y=results["-log10(qval)"][significant & significant_down & negative_fold_change],
c=results["Fold Change"][significant & significant_down & negative_fold_change],
cmap="Blues_r",
edgecolor="black",
s=size,
)
else:
fc_up = ax.scatter(
x=results["Fold Change"][significant & significant_up],
y=results["-log10(qval)"][significant & significant_up],
c=results["Fold Change"][significant & significant_up],
cmap="Reds",
edgecolor="black",
s=size,
)
fc_down = ax.scatter(
x=results["Fold Change"][significant & significant_down],
y=results["-log10(qval)"][significant & significant_down],
c=results["Fold Change"][significant & significant_down],
cmap="Blues_r",
edgecolor="black",
s=size,
)
ax.scatter(
x=results["Fold Change"][~(significant & (significant_up | significant_down))],
y=results["-log10(qval)"][~(significant & (significant_up | significant_down))],
color="grey",
edgecolor="black",
s=size,
)
# Add color bars
if "fc_up" in locals():
cbar_red = fig.colorbar(fc_up, ax=ax, orientation="vertical", pad=0.0, aspect=40)
cbar_red.ax.set_ylabel(
f"Fold Changes- {query_ct} over {ref_ct}", rotation=90, labelpad=15, fontsize=fontsize
)
cbar_red.ax.yaxis.set_label_position("left")
cbar_red.ax.yaxis.label.set_horizontalalignment("right")
cbar_red.ax.yaxis.label.set_position((0, 1.0))
for label in cbar_red.ax.get_yticklabels():
label.set_fontsize(fontsize)
if "fc_down" in locals():
cbar_blue = fig.colorbar(fc_down, ax=ax, orientation="vertical", pad=0.1, aspect=40)
cbar_blue.ax.set_ylabel(
f"Fold Changes- {ref_ct} over {query_ct}", rotation=90, labelpad=15, fontsize=fontsize
)
cbar_blue.ax.yaxis.set_label_position("left")
cbar_blue.ax.yaxis.label.set_horizontalalignment("right")
cbar_blue.ax.yaxis.label.set_position((0, 1.0))
for label in cbar_blue.ax.get_yticklabels():
label.set_fontsize(fontsize)
# Add text for most significant interactions:
# Get the highest fold changes:
high_fold_change = results[abs(results["Fold Change"]) > fold_change_cutoff_for_labels]
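# Relax the labeling cutoff by halving it until at least one interaction qualifies: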
while high_fold_change.empty:
fold_change_cutoff_for_labels /= 2
high_fold_change = results[abs(results["Fold Change"]) > fold_change_cutoff_for_labels]
# Take only the top few (it is impossible to view all at once clearly):
if len(high_fold_change) > 3:
high_fold_change = high_fold_change.sort_values(by="Fold Change", ascending=False)
high_fold_change_selected = high_fold_change.iloc[:3, :]
else:
high_fold_change_selected = high_fold_change
# And a few more from not as high but still significant q-values:
# Choose the step size once: smaller steps are taken if only visualizing the positive or the negative
# fold changes:
if plot_query_over_ref or plot_ref_over_query or plot_only_significant:
step_size = 1.25
else:
step_size = 1.5
# Walk down from the maximum -log10(q-value) in descending geometric steps to collect thresholds for
# selecting additional labels:
max_log10_qval = high_fold_change["-log10(qval)"].max()
log10_qval_steps = []
current_value = max_log10_qval
i = 0
while current_value >= 10:
log10_qval_steps.append(current_value)
i += 1
current_value = max_log10_qval / (step_size**i)
selected_rows = []
for value in log10_qval_steps:
# Find the row closest to the current value without duplicates
closest_index = abs(high_fold_change["-log10(qval)"] - value).idxmin()
if closest_index not in high_fold_change_selected.index:
selected_rows.append(high_fold_change.loc[closest_index])
high_fold_change_log10_qval = pd.DataFrame(selected_rows)
# Combine with high_fold_change and remove duplicates
high_fold_change_selected = pd.concat(
[high_fold_change_selected, high_fold_change_log10_qval]
).drop_duplicates()
text_labels = high_fold_change_selected.index.tolist()
x_coord_text_labels = high_fold_change_selected["Fold Change"].tolist()
y_coord_text_labels = high_fold_change_selected["-log10(qval)"].tolist()
text_objects = []
for i, label in enumerate(text_labels):
t = ax.text(
x_coord_text_labels[i],
y_coord_text_labels[i],
label,
fontsize=fontsize * 0.75,
color="black",
ha="center",
va="center",
)
text_objects.append(t)
adjust_text(text_objects, ax=ax, arrowprops=dict(arrowstyle="<|-", color="black", lw=1.0))
y_label = r"$-log_{10}$(qval)" if "y_label" not in locals() else y_label
ax.axhline(y=significance_cutoff, color="grey", linestyle="--", linewidth=1.5)
ax.axvline(x=fold_change_cutoff, color="grey", linestyle="--", linewidth=1.5)
ax.axvline(x=-fold_change_cutoff, color="grey", linestyle="--", linewidth=1.5)
ax.set_xlim(max_neg * 1.1, max_pos * 1.1)
ax.set_xticklabels(["{:.2f}".format(x) for x in ax.get_xticks()], fontsize=fontsize)
ax.set_yticklabels(["{:.2f}".format(y) for y in ax.get_yticks()], fontsize=fontsize)
ax.set_xlabel(x_label, fontsize=fontsize * 1.25)
ax.set_ylabel(y_label, fontsize=fontsize * 1.25)
ax.set_title(title, fontsize=fontsize * 1.5)
save_kwargs["ext"] = "png"
save_kwargs["dpi"] = 300
if "figure_folder" in locals():
save_kwargs["path"] = figure_folder
# Save figure:
save_return_show_fig_utils(
save_show_or_return=save_show_or_return,
show_legend=False,
background="white",
prefix=prefix,
save_kwargs=save_kwargs,
total_panels=1,
fig=fig,
axes=ax,
return_all=False,
return_all_list=None,
)
if save_df:
results.to_csv(os.path.join(output_folder, f"{prefix}.csv"))
def enriched_interactions_barplot(
self,
interactions: Optional[Union[str, List[str]]] = None,
targets: Optional[Union[str, List[str]]] = None,
plot_type: Literal["average", "proportion"] = "average",
effect_size_threshold: float = 0.0,
fontsize: Union[None, int] = None,
figsize: Union[None, Tuple[float, float]] = None,
cmap: str = "Reds",
top_n: Optional[int] = None,
save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
save_kwargs: Optional[dict] = None,
):
"""Visualize the top predicted effect sizes for each interaction on particular target gene(s).
Args:
interactions: Optional subset of interactions to focus on, given in the form ligand(s):receptor(s),
following the formatting in the design matrix. If not given, will consider all interactions that were
specified in model fitting.
targets: Can optionally specify a subset of the targets to compute this on. If not given, will use all
targets that were specified in model fitting. If multiple targets are given, "save_show_or_return"
should be "save" (and provide appropriate keyword arguments for saving using "save_kwargs"), otherwise
only the last target will be shown.
plot_type: Options: "average" or "proportion". Whether to plot the average effect size or the proportion of
cells expressing the target predicted to be affected by the interaction.
effect_size_threshold: Lower bound for average effect size to include a particular interaction in the
barplot
fontsize: Size of font for x and y labels
figsize: Size of figure
cmap: Colormap to use for barplot. It is recommended to use a sequential colormap (e.g. "Reds", "Blues",
"viridis", etc.).
top_n: If given, will only include the top n features in the visualization. If not given, will include all
features that pass the "effect_size_threshold".
save_show_or_return: Whether to save, show or return the figure.
If "both", the figure will be saved and displayed at the same time. If "all", the figure will be
saved, displayed, and the associated axes and other objects will be returned.
save_kwargs: A dictionary that will be passed to the save_fig function.
By default, it is an empty dictionary and the save_fig function will use
{"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True,
"verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modifies those
keys according to your needs.
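Example:
A minimal usage sketch (assumes ``interpreter`` is an already-initialized MuSIC_Interpreter for a
fitted model; the target name "Epcam" is hypothetical):

>>> interpreter.enriched_interactions_barplot(
...     targets="Epcam", plot_type="proportion", top_n=10, save_show_or_return="show"
... )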
"""
config_spateo_rcParams()
# But set display DPI to 300:
plt.rcParams["figure.dpi"] = 300
# Avoid mutating a shared default argument:
if save_kwargs is None:
save_kwargs = {}
if fontsize is None:
fontsize = rcParams.get("font.size")
if interactions is None:
interactions = self.X_df.columns.tolist()
elif isinstance(interactions, str):
interactions = [interactions]
elif not isinstance(interactions, list):
raise TypeError(f"Interactions must be a list or string, not {type(interactions)}.")
# Predictions:
parent_dir = os.path.dirname(self.output_path)
pred_path = os.path.join(parent_dir, "predictions.csv")
predictions = pd.read_csv(pred_path, index_col=0)
if save_show_or_return in ["save", "both", "all"]:
if not os.path.exists(os.path.join(os.path.dirname(self.output_path), "figures")):
os.makedirs(os.path.join(os.path.dirname(self.output_path), "figures"))
figure_folder = os.path.join(os.path.dirname(self.output_path), "figures", "temp")
if not os.path.exists(figure_folder):
os.makedirs(figure_folder)
if targets is None:
targets = self.targets_expr.columns
elif isinstance(targets, str):
targets = [targets]
elif not isinstance(targets, list):
raise TypeError(f"targets must be a list or string, not {type(targets)}.")
for target in targets:
# Get coefficients for this key
coef = self.coeffs[target]
effects = coef[[col for col in coef.columns if col.startswith("b_") and "intercept" not in col]]
effects.columns = [col.split("_")[1] for col in effects.columns]
# Subset to only the interactions of interest:
interactions = [interaction for interaction in interactions if interaction in effects.columns]
effects = effects[interactions]
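# Boolean masks: cells with observed expression of the target, and cells the model predicts to express it: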
target_expr = self.adata[:, target].X.toarray().squeeze() > 0
target_pred = predictions[target]
target_pred_np = target_pred.values.astype(bool)
if plot_type == "average":
# Subset to cells expressing the target that are predicted to be expressing the target:
target_true_pos_indices = np.where(target_expr & target_pred_np)[0]
target_expr_sub = self.adata[target_true_pos_indices, :].copy()
# Subset effects dataframe to same subset:
effects_sub = effects.loc[target_expr_sub.obs_names, :]
# Compute average for each column:
to_plot = effects_sub.mean(axis=0)
elif plot_type == "proportion":
# Get proportion of cells expressing the target that are predicted to be affected by particular
# molecule:
effects_sub = effects.loc[target_expr == 1, :]
to_plot = (effects_sub > 0).mean(axis=0)
else:
raise ValueError(f"Unrecognized input for to_plot: {plot_type}. Options are 'average' or 'proportion'.")
# Filter based on the threshold
to_plot = to_plot[to_plot > effect_size_threshold]
# Sort the average_expression in descending order
to_plot = to_plot.sort_values(ascending=False)
if self.mod_type == "ligand":
to_plot.index = [replace_col_with_collagens(idx) for idx in to_plot.index]
to_plot.index = [replace_hla_with_hlas(idx) for idx in to_plot.index]
if top_n is not None:
to_plot = to_plot.iloc[:top_n]
# Plot:
if figsize is None:
# Set figure size based on the number of interaction features and targets:
m = len(to_plot) / 2
figsize = (m, 5)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
palette = sns.color_palette(cmap, n_colors=len(to_plot))
sns.barplot(
x=to_plot.index,
y=to_plot.values,
edgecolor="black",
linewidth=1,
palette=palette,
ax=ax,
)
ax.set_xticklabels(to_plot.index, rotation=90, fontsize=fontsize)
ax.set_yticklabels(["{:.2f}".format(y) for y in ax.get_yticks()], fontsize=fontsize)
ax.set_xlabel("Interaction (ligand(s):receptor(s))", fontsize=fontsize)
if plot_type == "average":
title = f"Average Predicted Interaction Effects on {target}"
ylabel = "Mean Coefficient \nMagnitude"
elif plot_type == "proportion":
if top_n is not None and top_n < 20:
title = f"Proportion of {target}-Expressing Cells \nPredicted to be Affected by Interaction"
else:
title = f"Proportion of {target}-Expressing Cells Predicted to be Affected by Interaction"
ylabel = "Proportion of Cells"
ax.set_ylabel("Mean Coefficient \nMagnitude", fontsize=fontsize)
ax.set_title(title, fontsize=fontsize)
# Use the saved name for the AnnData object to define part of the name of the saved file:
base_name = os.path.basename(self.adata_path)
adata_id = os.path.splitext(base_name)[0]
prefix = f"{adata_id}_interaction_barplot_{target}"
save_kwargs["ext"] = "png"
save_kwargs["dpi"] = 300
if "figure_folder" in locals():
save_kwargs["path"] = figure_folder
# Save figure:
save_return_show_fig_utils(
save_show_or_return=save_show_or_return,
show_legend=False,
background="white",
prefix=prefix,
save_kwargs=save_kwargs,
total_panels=1,
fig=fig,
axes=ax,
return_all=False,
return_all_list=None,
)
def summarize_interaction_effects(
self,
interactions: Optional[Union[str, List[str]]] = None,
targets: Optional[Union[str, List[str]]] = None,
effect_size_threshold: float = 0.0,
):
"""Summarize the interaction effects for each target gene in dataframe format. Each element will be the
average effect size for a particular interaction on a particular target gene.
Args:
interactions: Optional subset of interactions to focus on. If not given, will consider all interactions.
targets: Can optionally specify a subset of the targets. If not given, will use all targets.
effect_size_threshold: Lower bound for average effect size to include a particular interaction.
Returns:
effects_df: DataFrame with the average effect size for each interaction (rows) on each target gene
(columns).
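Example:
A minimal usage sketch (assumes ``interpreter`` is an already-initialized MuSIC_Interpreter; the
target names are hypothetical):

>>> effects_df = interpreter.summarize_interaction_effects(
...     targets=["Epcam", "Vim"], effect_size_threshold=0.05
... )
>>> effects_df.head()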
"""
if interactions is None:
interactions = self.X_df.columns.tolist()
elif isinstance(interactions, str):
interactions = [interactions]
elif not isinstance(interactions, list):
raise TypeError(f"Interactions must be a list or string, not {type(interactions)}.")
if targets is None:
targets = self.targets_expr.columns
elif isinstance(targets, str):
targets = [targets]
elif not isinstance(targets, list):
raise TypeError(f"Targets must be a list or string, not {type(targets)}.")
# Predictions:
parent_dir = os.path.dirname(self.output_path)
pred_path = os.path.join(parent_dir, "predictions.csv")
predictions = pd.read_csv(pred_path, index_col=0)
# Initialize the DataFrame to store the results
effects_df = pd.DataFrame(0.0, index=interactions, columns=targets)
for target in targets:
coef = self.coeffs[target]
effects = coef[[col for col in coef.columns if col.startswith("b_") and "intercept" not in col]]
effects.columns = [col.split("_")[1] for col in effects.columns]
# Subset to only the interactions of interest:
interactions = [interaction for interaction in interactions if interaction in effects.columns]
effects = effects[interactions]
target_expr = self.adata[:, target].X.toarray().squeeze() > 0
target_pred = predictions[target]
target_pred_np = target_pred.values.astype(bool)
# Subset to cells expressing the target that are predicted to be expressing the target:
target_true_pos_indices = np.where(target_expr & target_pred_np)[0]
target_expr_sub = self.adata[target_true_pos_indices, :].copy()
# Subset effects dataframe to same subset:
effects_sub = effects.loc[target_expr_sub.obs_names, :]
average_effect_size = effects_sub.mean(axis=0)
# Filter interactions based on threshold and add to results DataFrame
filtered_effect_sizes = average_effect_size[average_effect_size > effect_size_threshold]
effects_df[target] = filtered_effect_sizes
effects_df = effects_df.replace(np.nan, 0.0)
return effects_df
def enriched_tfs_barplot(
self,
tfs: Optional[Union[str, List[str]]] = None,
targets: Optional[Union[str, List[str]]] = None,
target_type: Literal["ligand", "receptor", "target_gene"] = "target_gene",
plot_type: Literal["average", "proportion"] = "average",
effect_size_threshold: float = 0.0,
fontsize: Union[None, int] = None,
figsize: Union[None, Tuple[float, float]] = None,
cmap: str = "Reds",
top_n: Optional[int] = None,
save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
save_kwargs: Optional[dict] = None,
):
"""Visualize the top predicted effect sizes for each transcription factor on particular target gene(s).
Args:
tfs: Optional subset of transcription factors to focus on. If not given, will consider all transcription
factors that were specified in model fitting.
targets: Can optionally specify a subset of the targets to compute this on. If not given, will use all
targets that were specified in model fitting. If multiple targets are given, "save_show_or_return"
should be "save" (and provide appropriate keyword arguments for saving using "save_kwargs"), otherwise
only the last target will be shown.
target_type: Set whether the given targets are ligands, receptors or target genes. Used to determine
which folder to check for outputs.
plot_type: Options: "average" or "proportion". Whether to plot the average effect size or the proportion of
cells expressing the target predicted to be affected by the interaction.
effect_size_threshold: Lower bound for average effect size to include a particular interaction in the
barplot
fontsize: Size of font for x and y labels
figsize: Size of figure
cmap: Colormap to use for barplot. It is recommended to use a sequential colormap (e.g. "Reds", "Blues",
"viridis", etc.).
top_n: If given, will only include the top n features in the visualization. If not given, will include
all features that pass the "effect_size_threshold".
save_show_or_return: Whether to save, show or return the figure.
If "both", the figure will be saved and displayed at the same time. If "all", the figure will be
saved, displayed, and the associated axes and other objects will be returned.
save_kwargs: A dictionary that will be passed to the save_fig function.
By default, it is an empty dictionary and the save_fig function will use
{"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True,
"verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modifies those
keys according to your needs.
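Example:
A minimal usage sketch (assumes the downstream transcription factor models have already been fit and
``interpreter`` is an already-initialized MuSIC_Interpreter):

>>> interpreter.enriched_tfs_barplot(
...     target_type="ligand", plot_type="average", top_n=15
... )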
"""
config_spateo_rcParams()
# But set display DPI to 300:
plt.rcParams["figure.dpi"] = 300
# Avoid mutating a shared default argument:
if save_kwargs is None:
save_kwargs = {}
if fontsize is None:
fontsize = rcParams.get("font.size")
if target_type == "ligand":
coeffs = self.downstream_model_ligand_coeffs
if tfs is None:
tfs = self.downstream_model_ligand_design_matrix.columns.tolist()
elif target_type == "receptor":
coeffs = self.downstream_model_receptor_coeffs
if tfs is None:
tfs = self.downstream_model_receptor_design_matrix.columns.tolist()
elif target_type == "target_gene":
coeffs = self.downstream_model_target_coeffs
if tfs is None:
tfs = self.downstream_model_target_design_matrix.columns.tolist()
else:
raise ValueError(
f"Unrecognized input for target_type: {target_type}. Options are 'ligand', 'receptor', "
f"or 'target_gene'."
)
tfs = [tf.replace("regulator_", "") for tf in tfs]
if isinstance(tfs, str):
interactions = [tfs]
elif not isinstance(tfs, list):
raise TypeError(f"TFs must be a list or string, not {type(tfs)}.")
# Predictions:
downstream_parent_dir = os.path.dirname(os.path.splitext(self.output_path)[0])
id = os.path.splitext(os.path.basename(self.output_path))[0]
if target_type == "ligand":
folder = "ligand_analysis"
elif target_type == "receptor":
folder = "receptor_analysis"
elif target_type == "target_gene":
folder = "target_gene_analysis"
pred_path = os.path.join(downstream_parent_dir, "cci_deg_detection", folder, "downstream", "predictions.csv")
predictions = pd.read_csv(pred_path, index_col=0)
if save_show_or_return in ["save", "both", "all"]:
if not os.path.exists(os.path.join(os.path.dirname(self.output_path), "figures")):
os.makedirs(os.path.join(os.path.dirname(self.output_path), "figures"))
figure_folder = os.path.join(os.path.dirname(self.output_path), "figures", "temp")
if not os.path.exists(figure_folder):
os.makedirs(figure_folder)
if targets is None:
targets = coeffs.keys()
elif isinstance(targets, str):
targets = [targets]
elif not isinstance(targets, list):
raise TypeError(f"targets must be a list or string, not {type(targets)}.")
for target in targets:
# Get coefficients for this key
coef = coeffs[target]
effects = coef[[col for col in coef.columns if col.startswith("b_") and "intercept" not in col]]
effects.columns = [col.split("_")[1] for col in effects.columns]
# Subset to only the TFs of interest:
tfs = [tf for tf in tfs if tf in effects.columns]
effects = effects[tfs]
target_expr = self.adata[:, target].X.toarray().squeeze() > 0
target_pred = predictions[target]
target_pred_np = target_pred.values.astype(bool)
if plot_type == "average":
# Subset to cells expressing the target that are predicted to be expressing the target:
target_true_pos_indices = np.where(target_expr & target_pred_np)[0]
target_expr_sub = self.adata[target_true_pos_indices, :].copy()
# Subset effects dataframe to same subset:
effects_sub = effects.loc[target_expr_sub.obs_names, :]
# Compute average for each column:
to_plot = effects_sub.mean(axis=0)
elif plot_type == "proportion":
# Get proportion of cells expressing the target that are predicted to be affected by particular
# molecule:
effects_sub = effects.loc[target_expr == 1, :]
to_plot = (effects_sub > 0).mean(axis=0)
else:
raise ValueError(f"Unrecognized input for to_plot: {plot_type}. Options are 'average' or 'proportion'.")
# Filter based on the threshold
to_plot = to_plot[to_plot > effect_size_threshold]
# Sort the average_expression in descending order
to_plot = to_plot.sort_values(ascending=False)