"""Implementation of EM algorithm to identify parameter estimates for a
Negative Binomial mixture model.
https://iopscience.iop.org/article/10.1088/1742-6596/1324/1/012093/meta
Written by @HailinPan, optimized by @Lioscro.
"""
from functools import partial
from typing import Dict, Optional, Tuple, Union
import numpy as np
from joblib import Parallel, delayed
from scipy import special, stats
from tqdm import tqdm
from ..configuration import config
from ..errors import SegmentationError
try:
    from ngs_tools.utils import ParallelWithProgress
except ModuleNotFoundError:
    # Fallback when ngs_tools is unavailable: define a local equivalent so the
    # rest of the module can use the same class name either way.
    progress = partial(tqdm, ascii=True, smoothing=0.1)

    class ParallelWithProgress(Parallel):
        """Wrapper around joblib.Parallel that uses tqdm to print execution progress.
        Taken from https://github.com/Lioscro/ngs-tools/blob/main/ngs_tools/utils.py
        """

        def __init__(
            self,
            pbar: Optional[tqdm] = None,
            total: Optional[int] = None,
            desc: Optional[str] = None,
            disable: bool = False,
            *args,
            **kwargs
        ):
            # Use the caller-supplied bar if given; otherwise create a default
            # ascii tqdm bar with the provided total/description.
            self._pbar = pbar or progress(total=total, desc=desc, disable=disable)
            super(ParallelWithProgress, self).__init__(*args, **kwargs)

        def __call__(self, *args, **kwargs):
            try:
                return Parallel.__call__(self, *args, **kwargs)
            finally:
                # Always close the progress bar, even if a worker raised.
                self._pbar.close()

        def print_progress(self):
            # Hook called by joblib as tasks complete; mirror the completed
            # task count onto the tqdm bar instead of joblib's own output.
            self._pbar.n = self.n_completed_tasks
            self._pbar.refresh()
def lamtheta_to_r(lam: float, theta: float) -> float:
    """Recover the negative binomial ``r`` parameter from ``lam`` and ``theta``."""
    log_theta = np.log(theta)
    return -(lam / log_theta)
def muvar_to_lamtheta(mu: float, var: float) -> Tuple[float, float]:
    """Convert the mean and variance to lambda and theta."""
    # theta is the negative binomial success probability; r the dispersion.
    theta = mu / var
    r = mu**2 / (var - mu)
    lam = -r * np.log(theta)
    return lam, theta
def lamtheta_to_muvar(lam: float, theta: float) -> Tuple[float, float]:
    """Convert the lambda and theta to mean and variance."""
    # Inline of lamtheta_to_r: r = -lam / log(theta).
    r = -lam / np.log(theta)
    mu = r / theta - r
    var = mu + (mu * mu) / r
    return mu, var
def nbn_pmf(n, p, X):
    """Evaluate the negative binomial PMF at `X` with parameters `n` and `p`.

    This function is used instead of calling :func:`stats.nbinom` directly because there
    is some weird behavior when float32 is used. This function essentially casts the `n` and
    `p` parameters as floats.
    """
    dist = stats.nbinom(n=float(n), p=float(p))
    return dist.pmf(X)
def nbn_em(
    X: np.ndarray,
    w: Tuple[float, float] = (0.99, 0.01),
    mu: Tuple[float, float] = (10.0, 300.0),
    var: Tuple[float, float] = (20.0, 400.0),
    max_iter: int = 2000,
    precision: float = 1e-3,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run the EM algorithm to estimate the parameters for background and cell
    UMIs.

    Args:
        X: Numpy array containing mixture counts
        w: Initial proportions of cell and background as a tuple.
        mu: Initial means of cell and background negative binomial distributions.
        var: Initial variances of cell and background negative binomial
            distributions.
        max_iter: Maximum number of iterations.
        precision: Desired precision. Algorithm will stop once this is reached.

    Returns:
        Estimated `w`, `r`, `p`.
    """
    w = np.array(w)
    mu = np.array(mu)
    var = np.array(var)
    lam, theta = muvar_to_lamtheta(mu, var)
    tau = np.zeros((2,) + X.shape)
    prev_w = w.copy()
    prev_lam = lam.copy()
    prev_theta = theta.copy()
    r = lamtheta_to_r(lam, theta)
    # BUGFIX: initialize before the loop so that `max_iter <= 0` returns the
    # initial parameter estimates instead of raising a NameError at the final
    # return, which references `use_prev`.
    use_prev = False
    for i in range(max_iter):
        # E step: unnormalized responsibilities of both components.
        bp = nbn_pmf(r[0], theta[0], X)
        cp = nbn_pmf(r[1], theta[1], X)
        tau[0] = w[0] * bp
        tau[1] = w[1] * cp
        # Clip to avoid 0/0 when both component likelihoods underflow.
        tau = np.clip(tau, 1e-10, 1e10)
        tau /= tau.sum(axis=0)

        # M step: closed-form updates for w, lam, theta (see referenced paper).
        beta = 1 - 1 / (1 - theta) - 1 / np.log(theta)
        r = r.reshape(-1, 1)
        delta = r * (special.digamma(r + X) - special.digamma(r))
        tau_sum = tau.sum(axis=1)
        w = tau_sum / tau_sum.sum()
        lam = (tau * delta).sum(axis=1) / tau_sum
        theta = beta * (tau * delta).sum(axis=1) / (tau * (X - (1 - beta).reshape(-1, 1) * delta)).sum(axis=1)
        r = lamtheta_to_r(lam, theta)

        # If the update produced NaN/inf or out-of-range parameters, fall back
        # to the previous iteration's estimates.
        isnan = np.any(np.isnan(r) | np.isnan(w) | np.isnan(theta))
        isinf = np.any(np.isinf(r) | np.isinf(w) | np.isinf(theta))
        isinvalid = np.any((r <= 0) | (theta > 1) | (theta < 0) | (w < 0) | (w > 1))
        use_prev = isnan or isinf or isinvalid
        if (
            max(
                np.abs(w - prev_w).max(),
                np.abs(lam - prev_lam).max(),
                np.abs(theta - prev_theta).max(),
            )
            < precision
        ) or use_prev:
            break
        prev_w = w.copy()
        prev_lam = lam.copy()
        prev_theta = theta.copy()
    return (prev_w, lamtheta_to_r(prev_lam, prev_theta), prev_theta) if use_prev else (w, r, theta)
def conditionals(
    X: np.ndarray,
    em_results: Union[
        Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]],
        Dict[int, Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]]],
    ],
    bins: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the conditional probabilities, for each pixel, of observing the
    observed number of UMIs given that the pixel is background/foreground.

    Args:
        X: UMI counts per pixel
        em_results: Return value of :func:`run_em`.
        bins: Pixel bins, as was passed to :func:`run_em`.

    Returns:
        Two Numpy arrays, the first corresponding to the background conditional
        probabilities, and the second to the foreground conditional probabilities

    Raises:
        SegmentationError: If `em_results` is a dictionary but `bins` was not
            provided.
    """
    # Un-binned case: one parameter set shared by every pixel.
    if not isinstance(em_results, dict):
        _, r, p = em_results
        return nbn_pmf(r[0], p[0], X), nbn_pmf(r[1], p[1], X)

    if bins is None:
        raise SegmentationError("`em_results` indicate binning was used, but `bins` was not provided")
    background_cond = np.ones(X.shape)
    cell_cond = np.zeros(X.shape)
    # Fill each bin's pixels using that bin's estimated parameters.
    for label, (_, r, p) in em_results.items():
        in_bin = bins == label
        counts = X[in_bin]
        background_cond[in_bin] = nbn_pmf(r[0], p[0], counts)
        cell_cond[in_bin] = nbn_pmf(r[1], p[1], counts)
    return background_cond, cell_cond
def confidence(
    X: np.ndarray,
    em_results: Union[
        Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]],
        Dict[int, Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]]],
    ],
    bins: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Compute confidence of each pixel being a cell, using the parameters
    estimated by the EM algorithm.

    Args:
        X: Numpy array containing mixture counts.
        em_results: Return value of :func:`run_em`.
        bins: Pixel bins, as was passed to :func:`run_em`.

    Returns:
        Numpy array of confidence scores within the range [0, 1].
    """
    background_probs, cell_probs = conditionals(X, em_results, bins)
    if isinstance(em_results, dict):
        # Weight each bin's conditionals by that bin's mixture proportions.
        tau0 = np.zeros(X.shape)
        tau1 = np.zeros(X.shape)
        for label, (w, _, _) in em_results.items():
            in_bin = bins == label
            tau0[in_bin] = w[0] * background_probs[in_bin]
            tau1[in_bin] = w[1] * cell_probs[in_bin]
    else:
        w = em_results[0]
        tau0 = w[0] * background_probs
        tau1 = w[1] * cell_probs
    # Posterior probability of the foreground (cell) component.
    return tau1 / (tau0 + tau1)
def run_em(
    X: np.ndarray,
    downsample: Union[int, float] = 0.001,
    params: Union[Dict[str, Tuple[float, float]], Dict[int, Dict[str, Tuple[float, float]]]] = dict(
        w=(0.5, 0.5), mu=(10.0, 300.0), var=(20.0, 400.0)
    ),
    max_iter: int = 2000,
    precision: float = 1e-6,
    bins: Optional[np.ndarray] = None,
    seed: Optional[int] = None,
) -> Union[
    Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]],
    Dict[int, Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]]],
]:
    """Estimate negative binomial mixture parameters per bin by running
    :func:`nbn_em` on a (possibly downsampled) set of pixels.

    Args:
        X: UMI counts per pixel.
        downsample: Use at most this many samples. Samples are chosen randomly
            with probability proportional to the log UMI counts. When `bins` is
            provided, the size of each bin is used as a scaling factor. If this
            is a float, then samples are downsampled by this fraction.
        params: Initial parameters. This is a dictionary that contains `w`, `mu`,
            `var` as its keys, each corresponding to initial proportions, means and
            variances of background and foreground pixels. The values must be a
            2-element tuple containing the values for background and foreground.
            This may also be a nested dictionary, where the outermost key maps bin
            labels provided in the `bins` argument. In this case, each of the inner
            dictionaries will be used as the initial paramters corresponding to
            each bin.
        max_iter: Maximum number of EM iterations.
        precision: Stop EM algorithm once desired precision has been reached.
        bins: Bins of pixels to estimate separately, such as those obtained by
            density segmentation. Zeros are ignored.
        seed: Random seed.

    Returns:
        Tuple of parameters estimated by the EM algorithm if `bins` is not provided.
        Otherwise, a dictionary of tuple of parameters, with bin labels as keys.

    Raises:
        SegmentationError: If `params` does not contain exactly the keys
            `w`, `mu`, `var`.
    """
    samples = {}  # key 0 when bins = None
    if bins is not None:
        # Collect pixels per positive bin label; label 0 is ignored.
        for label in np.unique(bins):
            if label > 0:
                samples[label] = X[bins == label]
                # `params` may be nested (keyed by bin label); fall back to the
                # flat dictionary when the label key is absent.
                _params = params.get(label, params)
                if set(_params.keys()) != {"w", "mu", "var"}:
                    raise SegmentationError("`params` must contain exactly the keys `w`, `mu`, `var`.")
    else:
        samples[0] = X.flatten()
        if set(params.keys()) != {"w", "mu", "var"}:
            raise SegmentationError("`params` must contain exactly the keys `w`, `mu`, `var`.")
    # downsample <= 1 is interpreted as a fraction of each bin; > 1 as an
    # absolute target count split across bins proportionally to bin size.
    downsample_scale = True
    if downsample > 1:
        downsample_scale = False
    rng = np.random.default_rng(seed)
    final_samples = {}
    total = sum(len(_samples) for _samples in samples.values())
    for label, _samples in samples.items():
        _downsample = int(len(_samples) * downsample) if downsample_scale else int(downsample * (len(_samples) / total))
        if len(_samples) > _downsample:
            # Weight selection toward higher-count pixels.
            # NOTE(review): np.log1p(_samples + 1) == log(counts + 2); possibly
            # intended np.log1p(_samples) == log(counts + 1) -- confirm.
            weights = np.log1p(_samples + 1)
            _samples = rng.choice(_samples, _downsample, replace=False, p=weights / weights.sum())
        final_samples[label] = np.array(_samples)
    # Run in parallel: one nbn_em job per bin, with per-bin initial params.
    results = {}
    for label, (res_w, res_r, res_p) in zip(
        final_samples.keys(),
        ParallelWithProgress(n_jobs=config.n_threads, total=len(final_samples), desc="Running EM")(
            delayed(nbn_em)(final_samples[label], max_iter=max_iter, precision=precision, **params.get(label, params))
            for label in final_samples
        ),
    ):
        results[label] = (tuple(res_w), tuple(res_r), tuple(res_p))
    return results if bins is not None else results[0]