import math
from typing import Optional, Union
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from anndata import AnnData
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
from .utils import save_return_show_fig_utils
[docs]def glm_fit(
adata: AnnData,
genes: Optional[Union[str, list]] = None,
feature_x: str = None,
feature_y: str = "expression",
glm_key: str = "glm_degs",
remove_zero: bool = False,
color_key: Optional[str] = None,
color_key_cmap: Optional[str] = "vlag",
point_size: float = 14,
point_color: Union[str, np.ndarray, list] = "skyblue",
line_size: float = 2,
line_color: str = "black",
ax_size: Union[tuple, list] = (6, 4),
background_color: str = "white",
ncols: int = 4,
show_point: bool = True,
show_line: bool = True,
show_legend: bool = True,
save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
save_kwargs: Optional[dict] = None,
**kwargs,
):
"""
Plot the glm_degs result in a scatterplot.
Args:
adata: An Anndata object contain glm_degs result in ``.uns[glm_key]``.
genes: A gene name or a list of genes that will be used to plot.
feature_x: The key in ``.uns[glm_key]['correlation'][gene]`` that corresponds to the independent variables, such as ``'torsion'``, etc.
feature_y: The key in ``.uns[glm_key]['correlation'][gene]`` that corresponds to the dependent variables, such as ``'expression'``, etc.
glm_key: The key in ``.uns`` that corresponds to the glm_degs result.
remove_zero: Whether to remove the data equal to 0 saved in ``.uns[glm_key]['correlation'][gene][feature_y]``.
color_key: This can either be an explicit dict mapping labels to colors (as strings of form ‘#RRGGBB’), or an array like object providing one color for each distinct category being provided in labels.
color_key_cmap: The name of a matplotlib colormap to use for categorical coloring.
point_size: The scale of the feature_y point size.
point_color: The color of the feature_y point.
line_size: The scale of the fitted line width.
line_color: The color of the fitted line.
ax_size: The width and height of each ax.
background_color: The background color of the figure.
ncols: Number of columns for the figure.
show_point: Whether to show the scatter plot.
show_line: Whether to show the line plot.
show_legend: Whether to show the legend.
save_show_or_return: If ``'both'``, it will save and plot the figure at the same time.
If ``'all'``, the figure will be saved, displayed and the associated axis and other object will be return.
save_kwargs: A dictionary that will be passed to the save_fig function.
By default, it is an empty dictionary and the save_fig function will use the ``{"path": None, "prefix": 'scatter',
"dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True}`` as its parameters.
Otherwise, you can provide a dictionary that properly modify those keys according to your needs.
**kwargs: Additional parameters that will be passed into the ``seaborn.scatterplot`` function.
"""
assert not (feature_x is None), "``feature_x`` cannot be None."
assert not (feature_y is None), "``feature_y`` cannot be None."
assert (
glm_key in adata.uns
), f"``glm_key`` does not exist in adata.uns, please replace ``glm_key`` or run st.tl.glm_degs(key_added={glm_key})."
genes = list(adata.uns[glm_key]["glm_result"].index) if genes is None else genes
genes = list(genes) if isinstance(genes, list) else [genes]
genes_data = [adata.uns[glm_key]["correlation"][g].copy() for g in genes]
ncols = len(genes) if len(genes) < ncols else ncols
nrows = math.ceil(len(genes) / ncols)
fig = plt.figure(figsize=(ax_size[0] * ncols, ax_size[1] * nrows))
axes_list = []
for i, data in enumerate(genes_data):
data.sort_values(by=feature_x, ascending=True, axis=0, inplace=True)
if remove_zero:
data = data[np.asarray(data[feature_y]).flatten() != 0]
ax = plt.subplot(nrows, ncols, i + 1)
ax.set_title(f"Gene: {genes[i]}")
if show_point:
sns.scatterplot(
data=data,
x=feature_x,
y=feature_y,
hue=color_key,
palette=color_key_cmap,
color=point_color,
s=point_size,
legend=show_legend,
ax=ax,
**kwargs,
)
ax.set_ylabel(feature_y)
if show_line:
ax = ax.twinx() if show_point is True else ax
sns.lineplot(
data=data,
x=feature_x,
y="mu",
color=line_color,
lw=line_size,
legend=False,
ax=ax,
)
ax.set_ylabel("mu")
axes_list.append(ax)
added_pad = nrows * 0.1 if ncols * 2 < nrows else ncols * 0.2
plt.tight_layout(pad=1 + added_pad)
return save_return_show_fig_utils(
save_show_or_return=save_show_or_return,
show_legend=show_legend,
background=background_color,
prefix="glm_degs",
save_kwargs=save_kwargs,
total_panels=len(genes),
fig=fig,
axes=axes_list,
return_all=False,
return_all_list=None,
)
[docs]def glm_heatmap(
adata: AnnData,
genes: Optional[Union[str, list]] = None,
feature_x: str = None,
feature_y: str = "expression",
glm_key: str = "glm_degs",
lowess_smooth: bool = True,
frac: float = 0.2,
robust: bool = True,
colormap: str = "vlag",
figsize: tuple = (6, 6),
background_color: str = "white",
show_legend: bool = True,
save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
save_kwargs: Optional[dict] = None,
**kwargs,
):
"""
Plot the glm_degs result in a heatmap.
Args:
adata: An Anndata object contain glm_degs result in ``.uns[glm_key]``.
genes: A gene name or a list of genes that will be used to plot.
feature_x: The key in ``.uns[glm_key]['correlation'][gene]`` that corresponds to the independent variables, such as ``'torsion'``, etc.
feature_y: The key in ``.uns[glm_key]['correlation'][gene]`` that corresponds to the dependent variables, such as ``'expression'``, etc.
glm_key: The key in ``.uns`` that corresponds to the glm_degs result.
lowess_smooth: If True, use statsmodels to estimate a nonparametric lowess model (locally weighted linear regression).
frac: Between 0 and 1. The fraction of the data used when estimating each y-value.
robust: If True and vmin or vmax are absent, the colormap range is computed with robust quantiles instead of the extreme values.
colormap: The name of a matplotlib colormap.
figsize: The width and height of figure.
background_color: The background color of the figure.
show_legend: Whether to show the legend.
save_show_or_return: If ``'both'``, it will save and plot the figure at the same time.
If ``'all'``, the figure will be saved, displayed and the associated axis and other object will be return.
save_kwargs: A dictionary that will be passed to the save_fig function.
By default, it is an empty dictionary and the save_fig function will use the ``{"path": None, "prefix": 'scatter',
"dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True}`` as its parameters.
Otherwise, you can provide a dictionary that properly modify those keys according to your needs.
**kwargs: Additional parameters that will be passed into the ``seaborn.heatmap`` function.
"""
assert not (feature_x is None), "``feature_x`` cannot be None."
assert not (feature_y is None), "``feature_y`` cannot be None."
assert (
glm_key in adata.uns
), f"``glm_key`` does not exist in adata.uns, please replace ``glm_key`` or run st.tl.glm_degs(key_added={glm_key})."
genes = list(adata.uns[glm_key]["glm_result"].index) if genes is None else genes
genes = list(genes) if isinstance(genes, list) else [genes]
genes_data = []
for g in genes:
gene_data = adata.uns[glm_key]["correlation"][g].copy()
gene_data.sort_values(by=feature_x, ascending=True, axis=0, inplace=True)
gene_data = gene_data.loc[:, [feature_x, feature_y]]
data = pd.DataFrame(gene_data.groupby(by=feature_x)[feature_y].mean())
if lowess_smooth:
import statsmodels.api as sm
data = pd.DataFrame(sm.nonparametric.lowess(exog=data.index, endog=data[feature_y], frac=frac))[1]
genes_data.append(data)
genes_data = pd.concat(genes_data, axis=1)
genes_data.fillna(value=0, inplace=True)
genes_data.columns = genes
genes_data = genes_data.T
max_sort = np.argsort(np.argmax(genes_data.values, axis=1))
genes_data = genes_data.iloc[max_sort]
fig, ax = plt.subplots(figsize=figsize)
sns.heatmap(genes_data, cmap=colormap, robust=robust, ax=ax, **kwargs)
ax.set_xlabel(feature_x)
ax.set_ylabel(feature_y)
plt.tight_layout(pad=1)
return save_return_show_fig_utils(
save_show_or_return=save_show_or_return,
show_legend=show_legend,
background=background_color,
prefix="glm_degs",
save_kwargs=save_kwargs,
total_panels=len(genes),
fig=fig,
axes=ax,
return_all=False,
return_all_list=None,
)