# -*- coding: utf-8 -*-
"""
@File : cci_two_cluster.py
@Time : 2022/07/03 11:50:40
@Author : LuluZuo
@Version : 1.0
@Desc : spatial cell cell communication
"""
import itertools
import random
from typing import Tuple
import numpy as np
import pandas as pd
from anndata import AnnData
from scipy.sparse import issparse
from scipy.stats import ttest_ind
from sklearn.datasets import make_blobs
from tqdm import tqdm as tqdm
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
from ..configuration import SKM
from ..logging import logger_manager as lm
from .cci_fdr import fdr_correct
# Gene symbols that are never tested as ligands. Symbols are compared after
# ``str.capitalize()`` so that the match does not depend on the species-specific casing.
# NOTE(review): the rationale for this particular list is not visible in this file.
_EXCLUDED_LIGANDS = frozenset(
    {
        "Lta4h",
        "Fdx1",
        "Tfrc",
        "Trf",
        "Lamc1",
        "Aldh1a1",
        "Aldh1a2",
        "Dhcr24",
        "Rnaset2a",
        "Ptges3",
        "Nampt",
        "Kdr",
        "Apoa1",
        "Apoa2",
        "Apoe",
        "Dhcr7",
        "Enho",
        "Ptgr1",
        "Agrp",
        "Pnmt",
        "Akr1b3",
        "Daglb",
        "Ubash3d",
        "Psap",
        "Lck",
        "Lipa",
        "Alox5",
        "Alox5ap",
        "Alox12",
        "Cbr1",
        "Srd5a3",
        "Ddc",
        "Ggt1",
        "Ggt5",
        "Srd5a1",
        "Tyr",
        "Mmp2",
        "Ttr",
        "Alb",
        "Sult2a1",
        "Hsd17b6",
        "Cyp11a1",
        "Cyp11b1",
        "Cyp11b2",
        "Cyp17a1",
        "Cyp19a1",
        "Cyp21a1",
        "Cyp27b1",
        "Sult1e1",
        "Dio3",
    }
)

# File name of the prior ligand-receptor network for each supported species.
_LR_DATABASE_FILES = {
    "human": "lr_db_human.csv",
    "mouse": "lr_db_mouse.csv",
    "drosophila": "lr_network_drosophila.csv",
    "zebrafish": "lr_network_zebrafish.csv",
    "axolotl": "lr_network_axolotl.csv",
}


def _to_dense_array(x):
    """Return ``x`` as a dense ndarray, whether it is stored sparse or dense."""
    return x.toarray() if issparse(x) else np.asarray(x)


def _top_group_specific_genes(adata, genes, group_key, target_group, min_cells_by_counts, top):
    """Return the ``top`` genes whose expressing cells are most concentrated in ``target_group``."""
    adata_sub = adata[:, genes]
    in_group = adata_sub.obs[group_key] == target_group

    # Of all cells expressing a gene, what proportion belongs to `target_group`.
    # Genes expressed nowhere give 0/0 = NaN, which the `> 0` filter below removes.
    with np.errstate(divide="ignore", invalid="ignore"):
        n_expressing_in_group = np.asarray((adata_sub[in_group].X > 0).sum(axis=0)).ravel()
        n_expressing = np.asarray((adata_sub.X > 0).sum(axis=0)).ravel()
        frac = n_expressing_in_group / n_expressing

    frac_col = target_group + "_frac"
    stats = adata_sub.var.copy()
    stats[frac_col] = frac
    # Reuse "n_cells_by_counts" if preprocessing has already stored it in .var.
    if "n_cells_by_counts" not in stats.columns:
        x = adata_sub.X
        stats["n_cells_by_counts"] = x.getnnz(axis=0) if issparse(x) else np.count_nonzero(x, axis=0)

    stats = stats[stats[frac_col] > 0]
    stats = stats[stats["n_cells_by_counts"] > min_cells_by_counts]
    return stats.sort_values(by=frac_col, ascending=False)[:top].index


def _sample_cell_pairs(rng, cells, n_pairs):
    """Draw ``n_pairs`` random (sender, receiver) cell assignments for one permutation."""
    try:
        # Preferred: 2 * n_pairs distinct cells, first half senders, second half receivers.
        drawn = rng.sample(cells, k=n_pairs * 2)
    except ValueError:
        # Fewer than 2 * n_pairs cells: sample distinct ordered pairs instead. The picks
        # index the lexicographic sequence produced by itertools.permutations(cells, 2)
        # and are decoded on the fly, so the n * (n - 1) pairs are never materialised.
        n_cells = len(cells)
        picks = rng.sample(range(n_cells * (n_cells - 1)), k=n_pairs)
        senders, receivers = [], []
        for pick in picks:
            first, offset = divmod(pick, n_cells - 1)
            second = offset if offset < first else offset + 1
            senders.append(cells[first])
            receivers.append(cells[second])
        return senders, receivers
    return drawn[:n_pairs], drawn[n_pairs:]


@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata")
def find_cci_two_group(
    adata: AnnData,
    path: str,
    species: Literal["human", "mouse", "drosophila", "zebrafish", "axolotl"] = "human",
    layer: Tuple[None, str] = None,
    group: str = None,
    lr_pair: list = None,
    sender_group: str = None,
    receiver_group: str = None,
    mode: Literal["mode1", "mode2"] = "mode2",
    filter_lr: Literal["outer", "inner"] = "outer",
    top: int = 20,
    spatial_neighbors: str = "spatial_neighbors",
    spatial_distances: str = "spatial_distances",
    min_cells_by_counts: int = 0,
    min_pairs: int = 5,
    min_pairs_ratio: float = 0.01,
    num: int = 1000,
    pvalue: float = 0.05,
    fdr: bool = False,
) -> dict:
    """Test ligand-receptor interactions between a sender and a receiver cell group.

    Cell pairs are taken from the spatial neighbor graph stored in
    ``adata.uns[spatial_neighbors]["indices"]`` and restricted to pairs whose first cell
    belongs to ``sender_group`` and whose neighbor belongs to ``receiver_group``.
    Ligand-receptor pairs are then tested by permutation.

    Side effects on ``adata``: if ``layer`` is given, ``adata.X`` is replaced by that layer;
    a column ``group + "sp"`` is written to ``adata.obs`` in which sender/receiver cells are
    relabelled as ``<group>_prox`` (part of a cell pair) or ``<group>_dist`` (not part of one).

    Args:
        adata: An AnnData object.
        path: Prefix of the ligand-receptor database location. The database file name is
            appended directly to this string, so a directory must end with a path separator.
        species: Species the data comes from; selects the ligand-receptor database file.
        layer: Key of the layer to use. If None, ``adata.X`` is used.
        group: Column in ``adata.obs`` holding the cell group labels.
        lr_pair: Optional list of "{ligand}-{receptor}" names. If given, the database is
            restricted to these pairs and the expression-based filtering is skipped.
        sender_group: Name of the cell group sending ligands.
        receiver_group: Name of the cell group receiving via receptors.
        mode: "mode1" permutes the group labels and compares, for every pair of groups, the
            averaged ligand/receptor mean expression against the permuted values; "mode2"
            tests the ligand-receptor co-expression ratio over the spatial cell pairs against
            randomly drawn cell pairs.
        filter_lr: Only used if ``lr_pair`` is None. "inner": keep pairs whose ligand is among
            the top sender-specific ligands and whose receptor is among the top
            receiver-specific receptors; "outer": keep pairs for which either holds.
        top: Number of most group-specific ligands (receptors) to keep.
        spatial_neighbors: Key in ``adata.uns`` holding the spatial neighbor graph.
        spatial_distances: Key of the spatial distances in ``adata.obsp``. Kept for interface
            compatibility; the distances are not used by this function.
        min_cells_by_counts: A ligand/receptor must be expressed in more than this many cells
            to be kept. Only used if ``lr_pair`` is None.
        min_pairs: Minimum number of sender-receiver cell pairs.
        min_pairs_ratio: Minimum ratio of the number of cell pairs to the theoretical number
            of cell pairs (n x M / 2) between the two groups.
        num: Number of permutations. At least 1000 is recommended.
        pvalue: P-value threshold used to flag significant ligand-receptor pairs.
        fdr: If True (mode2 only), additionally store multiple-testing corrected q-values in
            "lr_co_exp_ratio_qvalues" and the corresponding flag in "is_significant_fdr".

    Returns:
        mode2: a dictionary with keys "cell_pair" (DataFrame of spatial cell pairs) and
        "lr_pair" (ligand-receptor DataFrame with the columns "lr_product", "lr_co_exp_num",
        "lr_co_exp_ratio", "lr_co_exp_ratio_pvalue", "is_significant" and "sr_pair").
        mode1: a DataFrame of permutation p-values (ligand-receptor pairs x group pairs),
        ordered by the number of significant group pairs.
        None if no cell ends up labelled ``sender_group + "_prox"``.

    Raises:
        ValueError: If ``species`` is unknown, ``num`` is not positive, too few cell pairs are
            found, or no database ligand/receptor is present in ``adata.var_names``.
    """
    logger = lm.get_main_logger()

    if num <= 0:
        raise ValueError("`num` (number of permutations) must be a positive integer.")

    # prior lr_network
    if species not in _LR_DATABASE_FILES:
        raise ValueError(f"Unsupported species '{species}'. Choose one of {sorted(_LR_DATABASE_FILES)}.")
    lr_network = pd.read_csv(path + _LR_DATABASE_FILES[species], index_col=0)
    # The axolotl table names its pairs after the human orthologs.
    name_cols = ("human_ligand", "human_receptor") if species == "axolotl" else ("from", "to")
    lr_network["lr_pair"] = lr_network[name_cols[0]].str.cat(lr_network[name_cols[1]], sep="-")

    # layer
    if layer is not None:
        adata.X = adata.layers[layer]

    ### find cell_pair
    sender_id = adata.obs.index[adata.obs[group].isin([sender_group])]
    receiver_id = adata.obs.index[adata.obs[group].isin([receiver_group])]
    sender_set, receiver_set = set(sender_id), set(receiver_id)
    cell_pair_all = len(sender_id) * len(receiver_id) / 2

    # Spatially constrained cell pairs: every cell paired with each of its listed neighbors.
    # assumes each entry of "indices" is a sequence of integer positions — TODO confirm
    neighbors = adata.uns[spatial_neighbors]["indices"]
    obs_names = adata.obs.index
    pair_senders, pair_receivers = [], []
    for i, neighbor_idx in enumerate(neighbors):
        neighbor_names = [str(name) for name in obs_names[neighbor_idx]]
        pair_senders.extend([str(obs_names[i])] * len(neighbor_names))
        pair_receivers.extend(neighbor_names)
    cell_pair = pd.DataFrame(
        {
            # "-" may occur in obs names, so ">-<" is used as the separator.
            "cell_pair_name": [s + ">-<" + r for s, r in zip(pair_senders, pair_receivers)],
            "cell_sender": pair_senders,
            "cell_receiver": pair_receivers,
        }
    )
    cell_pair = cell_pair.loc[
        cell_pair["cell_sender"].isin(sender_set) & cell_pair["cell_receiver"].isin(receiver_set)
    ]

    # filter cell pairs
    n_pairs = cell_pair.shape[0]
    if n_pairs < min_pairs:
        raise ValueError(
            f"Cell pairs found between {sender_group} and {receiver_group} ({n_pairs}) "
            f"are fewer than min_pairs ({min_pairs})."
        )
    pair_ratio = n_pairs / cell_pair_all if cell_pair_all else 0.0
    if pair_ratio < min_pairs_ratio:
        raise ValueError(
            f"Ratio of cell pairs found between {sender_group} and {receiver_group} ({pair_ratio}) "
            f"is less than min_pairs_ratio ({min_pairs_ratio})."
        )

    # spatial-distal subcluster: group members that are not part of any cell pair
    sender_dist = list(sender_set - set(cell_pair["cell_sender"]))
    receiver_dist = list(receiver_set - set(cell_pair["cell_receiver"]))

    # spatial-proximal subcluster vs. spatial-distal subcluster
    group_sp = group + "sp"
    labels_sp = adata.obs[group].astype("category")
    wanted_categories = dict.fromkeys(
        [sender_group + "_prox", sender_group + "_dist", receiver_group + "_prox", receiver_group + "_dist"]
    )
    labels_sp = labels_sp.cat.add_categories([c for c in wanted_categories if c not in labels_sp.cat.categories])
    adata.obs[group_sp] = labels_sp
    adata.obs.loc[adata.obs.index.isin(cell_pair["cell_sender"].tolist()), group_sp] = sender_group + "_prox"
    adata.obs.loc[adata.obs.index.isin(cell_pair["cell_receiver"].tolist()), group_sp] = receiver_group + "_prox"

    # If count sender_group_prox is 0, return None:
    count_sender_prox = (adata.obs[group_sp] == sender_group + "_prox").sum()
    if count_sender_prox == 0:
        logger.info(f"No cells found in the category {sender_group + '_prox'}. Returning None.")
        return None

    adata.obs.loc[adata.obs.index.isin(sender_dist), group_sp] = sender_group + "_dist"
    adata.obs.loc[adata.obs.index.isin(receiver_dist), group_sp] = receiver_group + "_dist"

    ### filter lr
    if lr_pair is None:
        # Keep only database pairs whose ligand and receptor are both measured in adata.
        expressed_ligand = list(set(lr_network["from"].unique()) & set(adata.var_names))
        if len(expressed_ligand) == 0:
            raise ValueError("No intersected ligand between your adata object and lr_network dataset.")
        lr_network = lr_network[lr_network["from"].isin(expressed_ligand)]

        expressed_receptor = list(set(lr_network["to"].unique()) & set(adata.var_names))
        if len(expressed_receptor) == 0:
            raise ValueError("No intersected receptor between your adata object and lr_network dataset.")
        lr_network = lr_network[lr_network["to"].isin(expressed_receptor)]

        # Drop excluded ligands. The names as found in adata.var_names are kept for indexing;
        # only the comparison is case-normalised. `capitalize` is used because `title` would
        # upper-case letters that follow digits ("Aldh1a1" -> "Aldh1A1") and never match.
        ligands = [
            lig for lig in lr_network["from"].unique().tolist() if lig.capitalize() not in _EXCLUDED_LIGANDS
        ]

        # ligand_sender_spec
        ligand_sender_spec = _top_group_specific_genes(
            adata, ligands, group_sp, sender_group + "_prox", min_cells_by_counts, top
        )
        logger.info(
            f"{top} ligands for cell type {sender_group+'_prox'} with highest fraction of prevalence: "
            f"{list(ligand_sender_spec)}. Testing interactions involving these genes."
        )

        # receptor_receiver_spec
        receptor_receiver_spec = _top_group_specific_genes(
            adata, lr_network["to"].unique().tolist(), group_sp, receiver_group + "_prox", min_cells_by_counts, top
        )
        logger.info(
            f"{top} receptors for cell type {receiver_group+'_prox'} with highest fraction of prevalence: "
            f"{list(receptor_receiver_spec)}. Testing interactions involving these genes."
        )

        has_spec_ligand = lr_network["from"].isin(ligand_sender_spec.tolist())
        has_spec_receptor = lr_network["to"].isin(receptor_receiver_spec.tolist())
        if filter_lr == "inner":
            # Both the ligand and the receptor must be group-specific.
            lr_network = lr_network.loc[has_spec_ligand & has_spec_receptor]
        elif filter_lr == "outer":
            # Either the ligand or the receptor must be group-specific.
            lr_network = pd.concat(
                [lr_network.loc[has_spec_ligand], lr_network.loc[has_spec_receptor]], axis=0, join="outer"
            ).drop_duplicates(keep="first")
    else:
        lr_network = lr_network.loc[lr_network["lr_pair"].isin(lr_pair)]

    if mode == "mode1":
        # mode1: permutation of the annotation labels
        cols = adata.obs[group_sp].unique()
        group_pairs = list(itertools.combinations(cols, 2))
        # real mean result, each lr_pair expression in each group_pair.
        mean_res = calculate_group_pair_lr_pair(adata, group_sp, group_pairs, cols, lr_network)

        # Count, per entry, how often a permuted value exceeds the real one.
        original_labels = adata.obs[group_sp].copy()
        shuffled_labels = adata.obs[group_sp].tolist()
        exceed_counts = pd.DataFrame()
        try:
            for i in tqdm(range(num)):
                np.random.shuffle(shuffled_labels)
                adata.obs[group_sp] = shuffled_labels
                permuted = calculate_group_pair_lr_pair(adata, group_sp, group_pairs, cols, lr_network)
                exceeded = (permuted > mean_res) * 1
                exceed_counts = exceeded if i == 0 else exceed_counts.add(exceeded)
        finally:
            # Do not leave the permuted labels behind in adata.obs.
            adata.obs[group_sp] = original_labels

        # The threshold `pvalue` must stay distinct from the computed p-value table.
        pvalue_df = exceed_counts / num
        sig_num = (pvalue_df < pvalue).sum(axis=1)
        # Order rows by number of significant group pairs; positional indexing keeps this
        # correct even if lr_pair labels repeat. Use this to plot heatmap.
        row_order = sig_num.reset_index(drop=True).sort_values(ascending=False, na_position="last").index
        return pvalue_df.iloc[row_order]

    # mode2: co-expression of ligand (in the sender cell) and receptor (in the receiver cell)
    lr_network = lr_network.copy()  # columns are added below; avoid writing into a slice
    n_lr = lr_network.shape[0]

    # real lr_cp_exp_score
    ligand_data = adata[cell_pair["cell_sender"], lr_network["from"]]
    receptor_data = adata[cell_pair["cell_receiver"], lr_network["to"]]
    lr_data = np.array(_to_dense_array(ligand_data.X) * _to_dense_array(receptor_data.X))
    if n_pairs == 0:
        lr_prod = np.zeros(n_lr)
        lr_co_exp_ratio = np.zeros(n_lr)
        lr_co_exp_num = np.zeros(n_lr)
    else:
        lr_prod = lr_data.mean(axis=0)
        lr_co_exp_num = (lr_data > 0).sum(axis=0)
        lr_co_exp_ratio = lr_co_exp_num / lr_data.shape[0]
    lr_network["lr_product"] = lr_prod
    lr_network["lr_co_exp_num"] = lr_co_exp_num
    lr_network["lr_co_exp_ratio"] = lr_co_exp_ratio

    # permutation test
    all_cells = adata.obs.index.tolist()
    per_data = np.zeros((n_lr, num))
    for i in tqdm(range(num)):
        # A private generator seeded with the permutation index gives the same draws as
        # seeding the global `random` module, without clobbering its state.
        rng = random.Random(i)
        per_sender_id, per_receiver_id = _sample_cell_pairs(rng, all_cells, n_pairs)
        per_ligand_data = adata[per_sender_id, lr_network["from"]]
        per_receptor_data = adata[per_receiver_id, lr_network["to"]]
        per_lr_data = np.asarray(_to_dense_array(per_ligand_data.X) * _to_dense_array(per_receptor_data.X))
        with np.errstate(divide="ignore", invalid="ignore"):
            per_lr_co_exp_ratio = (per_lr_data > 0).sum(axis=0) / per_lr_data.shape[0]
        # With zero cell pairs the ratio is 0/0 everywhere; treat it as zero.
        if not np.isnan(per_lr_co_exp_ratio).all():
            per_data[:, i] = per_lr_co_exp_ratio

    # P-value: share of permutations whose ratio is at least the observed ratio.
    real_ratio = np.asarray(lr_co_exp_ratio, dtype=float)
    lr_network["lr_co_exp_ratio_pvalue"] = (per_data >= real_ratio[:, None]).sum(axis=1) / num
    lr_network["is_significant"] = lr_network["lr_co_exp_ratio_pvalue"] < pvalue
    if fdr:
        # Multiple hypothesis testing correction:
        qvalues = fdr_correct(pd.DataFrame(lr_network["lr_co_exp_ratio_pvalue"]), corr_method="fdr_bh")
        lr_network["lr_co_exp_ratio_qvalues"] = qvalues
        # After multiple testing correction:
        lr_network["is_significant_fdr"] = qvalues < pvalue
    lr_network["sr_pair"] = sender_group + "-" + receiver_group

    return {"cell_pair": cell_pair, "lr_pair": lr_network}
# Utility for the mode1 significance test
@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata")
def calculate_group_pair_lr_pair(
    adata,
    group,
    group_pairs,
    cols,
    lr_network,
):
    """Score every ligand-receptor pair for every pair of cell groups.

    For a group pair ``(a, b)`` the score of a ligand-receptor pair is the average of the
    ligand's mean expression in group ``a`` and the receptor's mean expression in group ``b``.

    Args:
        adata: An AnnData object; expression is read from ``adata.X`` (sparse or dense).
        group: Column in ``adata.obs`` holding the group labels.
        group_pairs: Iterable of ``(group_a, group_b)`` tuples to score.
        cols: Group labels for which mean expression is computed; must cover every label
            used in ``group_pairs``.
        lr_network: DataFrame with the columns "from" (ligand), "to" (receptor) and "lr_pair".

    Returns:
        DataFrame indexed by "lr_pair" with one column per entry of ``group_pairs``.
    """
    adata_l = adata[:, lr_network["from"].tolist()]
    adata_r = adata[:, lr_network["to"].tolist()]

    def _group_mean(view, label):
        # Mean expression per gene over the cells of one group; works for sparse and dense X.
        x = view[view.obs[group] == label].X
        dense = x.toarray() if issparse(x) else np.asarray(x)
        return np.mean(dense, axis=0)

    dfl = pd.DataFrame(index=lr_network["lr_pair"], columns=cols)
    dfr = pd.DataFrame(index=lr_network["lr_pair"], columns=cols)
    df = pd.DataFrame(index=lr_network["lr_pair"], columns=group_pairs)

    # Mean ligand / receptor expression in each group.
    for g in cols:
        dfl[g] = _group_mean(adata_l, g)
        dfr[g] = _group_mean(adata_r, g)

    # Combine ligand mean of the first group with receptor mean of the second group.
    for group_pair in group_pairs:
        df[group_pair] = (dfl[group_pair[0]] + dfr[group_pair[1]]) / 2
    return df
# Wrapper for preprocessing for plotting:
def prepare_cci_df(cci_df: pd.DataFrame, means_col: str, pval_col: str, lr_pair_col: str, sr_pair_col: str):
    """Reshape a long CCI table into two wide tables (means and p-values) for heatmap plotting.

    Each output table has one row per ligand-receptor pair, indexed by ("source", "target"),
    and one column per ("sender", "receiver") cell type pair. Combinations that do not occur
    in ``cci_df`` are NaN.

    Note:
        ``cci_df`` is modified in place: the columns "sender" and "receiver" are added.
        Pair names are split on "-", so ligand, receptor and cell type names must not
        themselves contain "-".

    Args:
        cci_df: CCI dataframe with columns for: ligand-receptor pair, sender-receiver cell type
            pair, L:R product and p-value.
        means_col: Label for the column corresponding to the mean product of L:R expression
            between two cell types.
        pval_col: Label for the column corresponding to the p-value of the interaction.
        lr_pair_col: Label for the column corresponding to the ligand-receptor pair in format
            "{ligand}-{receptor}".
        sr_pair_col: Label for the column corresponding to the sending-receiving cell type pair
            in format "{sender}-{receiver}".

    Returns:
        A dictionary with keys "means" and "pvalues" holding the two wide dataframes.

    Example:
        res = find_cci_two_group(adata, ...)
        # The df to save can be found under "lr_pair":
        res["lr_pair"].to_csv(...)
        tables = prepare_cci_df(
            res["lr_pair"], "lr_product", "lr_co_exp_ratio_pvalue", "lr_pair", "sr_pair"
        )
    """
    # Split sender and receiver into separate columns:
    cci_df[["sender", "receiver"]] = cci_df[sr_pair_col].str.split("-", expand=True)

    mean_rows, pval_rows = [], []
    # One row per ligand-receptor pair, with (sender, receiver) as the columns.
    for lr_name, df_group in cci_df.groupby(lr_pair_col):
        lig, rec = lr_name.split("-")
        row_index = pd.MultiIndex.from_tuples([(lig, rec)], names=["source", "target"])
        # The result of set_index must be kept: it is what keys the values by cell type pair.
        by_cell_types = df_group.set_index(["sender", "receiver"])
        for value_col, rows in ((means_col, mean_rows), (pval_col, pval_rows)):
            row = by_cell_types[value_col].to_frame().transpose()
            row.index = row_index
            rows.append(row)

    return {"means": pd.concat(mean_rows), "pvalues": pd.concat(pval_rows)}
@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata")
def prepare_cci_cellpair_adata(
    adata: AnnData,
    sender_group: str = None,
    receiver_group: str = None,
    group: str = None,
    cci_dict: dict = None,
    all_cell_pair: bool = False,
) -> AnnData:
    """Label sender and receiver cells in ``adata.obs["spec"]`` for plotting.

    Every cell starts out as "other". Sender cells are then labelled with ``sender_group``
    and receiver cells with ``receiver_group``; receiver labels are written last, so a cell
    matching both ends up with the receiver label.

    Args:
        adata: An AnnData object.
        sender_group: The cell group name of send ligands.
        receiver_group: The cell group name of receive receptors.
        group: The group name in adata.obs. Unused unless 'all_cell_pair' is True.
        cci_dict: A dictionary whose "cell_pair" entry has the columns "cell_sender" and
            "cell_receiver". Unused unless 'all_cell_pair' is False.
        all_cell_pair: If True, label all cells of the sender and receiver groups; if False,
            label only the cells listed in ``cci_dict["cell_pair"]``. Default `False`.

    Returns:
        adata: Updated AnnData object containing 'spec' in .obs.
    """
    logger = lm.get_main_logger()

    adata.obs["spec"] = "other"

    if not all_cell_pair:
        # Only the cells that take part in a listed cell pair.
        for column, label in (("cell_sender", sender_group), ("cell_receiver", receiver_group)):
            paired_cells = cci_dict["cell_pair"][column].tolist()
            adata.obs.loc[adata.obs.index.isin(paired_cells), "spec"] = label
        return adata

    # Every cell of the two groups.
    for label in (sender_group, receiver_group):
        adata.obs.loc[adata.obs[group] == label, "spec"] = label
    return adata