# Source code for spateo.plotting.static.polarity

"""Written by @Jinerhal, adapted by @Xiaojieqiu.
"""

import numpy as np
import pandas as pd
import seaborn as sns
from anndata import AnnData


def _dense_column(x) -> np.ndarray:
    """Return the values of a single-column expression matrix as a flat 1-D ndarray.

    Accepts both sparse matrices (anything exposing ``toarray()``, e.g. scipy.sparse)
    and dense arrays/matrices, so callers need not know how ``adata.X`` is stored.
    """
    if hasattr(x, "toarray"):
        x = x.toarray()
    return np.asarray(x).ravel()


def _polarity_frame(
    adata: AnnData,
    gene_dict: dict,
    region_key: str,
    per_cell: bool,
) -> pd.DataFrame:
    """Collect long-form (region, gene label, expression) records for plotting.

    Args:
        adata: Annotated data matrix; ``adata.obs[region_key]`` gives each cell's region.
        gene_dict: Mapping of annotation label -> iterable of gene names.
        region_key: Column of ``adata.obs`` holding the region of each cell.
        per_cell: If True, emit one row per cell (its expression of the gene). If
            False, emit one row per (region, gene) holding the mean expression of
            that gene over the region's cells.

    Returns:
        DataFrame with columns ``region_key``, ``"Gene"`` and ``"Mean expression"``.
        The ``"Gene"`` label is ``"<gene> <annotation>"``.
    """
    regions, labels, values = [], [], []
    region_labels = adata.obs[region_key]
    # np.unique returns the regions sorted, so records are emitted in region order.
    for region in np.unique(region_labels):
        adata_region = adata[region_labels == region, :]
        for anno, genes in gene_dict.items():
            for gene in genes:
                expr = _dense_column(adata_region[:, gene].X)
                chunk = expr if per_cell else np.array([expr.mean()])
                regions.append(np.repeat(region, len(chunk)))
                labels.append(np.repeat(gene + " " + anno, len(chunk)))
                values.append(chunk)

    if not values:
        # No genes requested: return an empty frame with the expected schema.
        return pd.DataFrame({region_key: [], "Gene": [], "Mean expression": []})

    # Concatenate once at the end instead of growing arrays inside the loop.
    return pd.DataFrame(
        {
            region_key: np.concatenate(regions),
            "Gene": np.concatenate(labels),
            "Mean expression": np.concatenate(values),
        }
    )


def polarity(
    adata: AnnData,
    gene_dict: dict,
    region_key: str,
    mode: str = "density",
):
    """Visualize how expression of selected genes varies along regions.

    Args:
        adata: Annotated data matrix. ``adata.obs[region_key]`` holds the region of
            each cell. Each gene in ``gene_dict`` must be usable as a variable index
            into ``adata`` (e.g. a name in ``adata.var_names``).
        gene_dict: Mapping from an annotation label to a list of gene names. Each gene
            is drawn as its own series, labelled ``"<gene> <annotation>"``.
        region_key: Column of ``adata.obs`` used as the x-axis.
        mode: Either ``"exp"`` or ``"density"``. Defaults to ``"density"``.

            - ``"exp"`` draws a seaborn ``lineplot`` of per-cell expression against
              region. seaborn aggregates the cells within each region.
            - ``"density"`` draws a seaborn ``kdeplot`` over regions, weighting each
              region by the gene's mean expression there. The x-axis is then limited
              to ``[0, max(region)]``, so regions are presumably numeric and
              non-negative in this mode (TODO confirm with callers).

    Returns:
        The matplotlib Axes the plot was drawn on.

    Raises:
        ValueError: If ``mode`` is not ``"exp"`` or ``"density"``.
    """
    if mode not in ("exp", "density"):
        raise ValueError(f"`mode` must be 'exp' or 'density', got {mode!r}.")

    df_plt = _polarity_frame(adata, gene_dict, region_key, per_cell=(mode == "exp"))

    if mode == "exp":
        ax = sns.lineplot(data=df_plt, x=region_key, y="Mean expression", hue="Gene")
    else:
        ax = sns.kdeplot(data=df_plt, x=region_key, weights="Mean expression", hue="Gene")
        ax.set_xlim(0, max(adata.obs[region_key]))
    return ax