spateo.segmentation.vi¶
Variational inference implementation of a negative binomial mixture model using Pyro.
Classes¶

NegativeBinomialMixture
    Subclass of torch.nn.Module whose attributes can be modified by Pyro effects.

Functions¶

conditionals
    Compute the conditional probabilities, for each pixel, of observing the observed number of UMIs given that the pixel is background/foreground.

run_vi
    Run negative binomial mixture variational inference.
Module Contents¶
- class spateo.segmentation.vi.NegativeBinomialMixture(x: numpy.ndarray, n: int = 2, n_init: int = 5, w: numpy.ndarray | None = None, mu: numpy.ndarray | None = None, var: numpy.ndarray | None = None, zero_inflated: bool = False, seed: int | None = None)[source]¶
Bases: pyro.nn.PyroModule

Subclass of torch.nn.Module whose attributes can be modified by Pyro effects. Attributes can be set using the helpers PyroParam and PyroSample, and methods can be decorated by pyro_method().

Parameters

To create a Pyro-managed parameter attribute, set that attribute using either torch.nn.Parameter (for unconstrained parameters) or PyroParam (for constrained parameters). Reading that attribute will then trigger a pyro.param statement. For example:

    # Create Pyro-managed parameter attributes.
    my_module = PyroModule()
    my_module.loc = nn.Parameter(torch.tensor(0.))
    my_module.scale = PyroParam(torch.tensor(1.), constraint=constraints.positive)

    # Read the attributes.
    loc = my_module.loc  # Triggers a pyro.param statement.
    scale = my_module.scale  # Triggers another pyro.param statement.
Note that, unlike normal torch.nn.Modules, PyroModules should not be registered with pyro.module statements. PyroModules can contain other PyroModules and normal torch.nn.Modules. Accessing a normal torch.nn.Module attribute of a PyroModule triggers a pyro.module statement. If multiple PyroModules appear in a single Pyro model or guide, they should be included in a single root PyroModule for that model.

PyroModules synchronize data with the param store at each setattr, getattr, and delattr event, based on the nested name of an attribute:

- Setting mod.x = x_init tries to read x from the param store. If a value is found in the param store, that value is copied into mod and x_init is ignored; otherwise x_init is copied into both mod and the param store.
- Reading mod.x tries to read x from the param store. If a value is found in the param store, that value is copied into mod; otherwise mod’s value is copied into the param store. Finally mod and the param store agree on a single value to return.
- Deleting del mod.x removes a value from both mod and the param store.
Note that two PyroModules of the same name will both synchronize with the global param store and thus contain the same data. When creating a PyroModule, then deleting it, then creating another with the same name, the latter will be populated with the former’s data from the param store. To avoid this persistence, either call pyro.clear_param_store() or call clear() before deleting a PyroModule.

PyroModules can be saved and loaded either directly using torch.save() / torch.load() or indirectly using the param store’s save() / load(). Note that values restored by torch.load() will be overridden by any values already in the param store, so it is safest to pyro.clear_param_store() before loading.
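The synchronization and persistence rules above can be observed directly from the param store. A minimal sketch, assuming Pyro's default global param store behavior described above (the names and values are arbitrary):

    import torch
    import pyro
    from pyro.nn import PyroModule, PyroParam

    pyro.clear_param_store()

    a = PyroModule(name="shared")
    a.x = PyroParam(torch.tensor(1.0))
    _ = a.x  # reading triggers pyro.param("shared.x"), registering it in the store
    del a    # dropping the Python reference does not clear the param store

    # A new PyroModule with the same name picks the old value back up:
    # the new initial value 99.0 is ignored because "shared.x" already exists.
    b = PyroModule(name="shared")
    b.x = PyroParam(torch.tensor(99.0))
    assert float(b.x) == 1.0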
Samples

To create a Pyro-managed random attribute, set that attribute using the PyroSample helper, specifying a prior distribution. Reading that attribute will then trigger a pyro.sample statement. For example:

    # Create Pyro-managed random attributes.
    my_module.x = PyroSample(dist.Normal(0, 1))
    my_module.y = PyroSample(lambda self: dist.Normal(self.loc, self.scale))

    # Sample the attributes.
    x = my_module.x  # Triggers a pyro.sample statement.
    y = my_module.y  # Triggers one pyro.sample + two pyro.param statements.
Sampling is cached within each invocation of .__call__() or of a method decorated by pyro_method(). Because sample statements can appear only once in a Pyro trace, you should ensure that traced access to sample attributes is wrapped in a single invocation of .__call__() or of a method decorated by pyro_method().

To make an existing module probabilistic, you can create a subclass and overwrite some parameters with PyroSample attributes:

    class RandomLinear(nn.Linear, PyroModule):  # used as a mixin
        def __init__(self, in_features, out_features):
            super().__init__(in_features, out_features)
            self.weight = PyroSample(
                lambda self: dist.Normal(0, 1)
                                 .expand([self.out_features, self.in_features])
                                 .to_event(2))
Mixin classes

PyroModule can be used as a mixin class, and supports simple syntax for dynamically creating mixins. For example, the following are equivalent:

    # Version 1. create a named mixin class
    class PyroLinear(nn.Linear, PyroModule):
        pass

    m.linear = PyroLinear(m, n)

    # Version 2. create a dynamic mixin class
    m.linear = PyroModule[nn.Linear](m, n)

This notation can be used recursively to create Bayesian modules, e.g.:

    model = PyroModule[nn.Sequential](
        PyroModule[nn.Linear](28 * 28, 100),
        PyroModule[nn.Sigmoid](),
        PyroModule[nn.Linear](100, 100),
        PyroModule[nn.Sigmoid](),
        PyroModule[nn.Linear](100, 10),
    )
    assert isinstance(model, nn.Sequential)
    assert isinstance(model, PyroModule)

    # Now we can be Bayesian about weights in the first layer.
    model[0].weight = PyroSample(
        prior=dist.Normal(0, 1).expand([28 * 28, 100]).to_event(2))
    guide = AutoDiagonalNormal(model)

Note that PyroModule[...] does not recursively mix in PyroModule to submodules of the input Module; hence we needed to wrap each submodule of the nn.Sequential above.

- Parameters:
  - name (str) – Optional name for a root PyroModule. This is ignored in sub-PyroModules of another PyroModule.
- train(n_epochs: int = 500)[source]¶
Set the module in training mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

- Parameters:
  - mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.
- Returns:
  self
- Return type:
  Module
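A minimal usage sketch based only on the constructor and train() signatures documented above; the simulated counts and argument choices are illustrative, not prescriptive:

    import numpy as np
    from spateo.segmentation.vi import NegativeBinomialMixture

    # Stand-in for per-pixel UMI counts; real counts would come from a
    # spatial transcriptomics dataset.
    x = np.random.negative_binomial(10, 0.5, size=10_000)

    # Two-component (background/foreground) mixture, optionally zero-inflated.
    nbm = NegativeBinomialMixture(x, n=2, n_init=5, zero_inflated=False, seed=0)
    nbm.train(n_epochs=500)  # fit by variational inference with Pyro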
- spateo.segmentation.vi.conditionals(X: numpy.ndarray, vi_results: Dict[int, Dict[str, float]] | Dict[str, float], bins: numpy.ndarray | None = None) Tuple[numpy.ndarray, numpy.ndarray] [source]¶
Compute the conditional probabilities, for each pixel, of observing the observed number of UMIs given that the pixel is background/foreground.
- Parameters:
  - X – UMI counts per pixel.
  - vi_results – Return value of run_vi().
  - bins – Pixel bins, as was passed to run_vi().
- Returns:
  Two numpy arrays, the first containing the background conditional probabilities and the second the foreground conditional probabilities.
- Raises:
  SegmentationError – If vi_results is a dictionary (per-bin results) but bins was not provided.
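A rough end-to-end sketch chaining run_vi() and conditionals(). It assumes the return value of run_vi() can be passed directly as vi_results, as the parameter description above indicates; the random input and the final thresholding are purely illustrative:

    import numpy as np
    from spateo.segmentation.vi import run_vi, conditionals

    # Stand-in UMI-count image; real counts would come from spatial data.
    X = np.random.negative_binomial(10, 0.5, size=(256, 256))

    vi_results = run_vi(X, downsample=0.01, n_epochs=500, seed=0)

    # Per-pixel likelihoods of the observed counts under each mixture component.
    bg_probs, fg_probs = conditionals(X, vi_results)

    # One simple use: keep pixels better explained by the foreground component.
    foreground_mask = fg_probs > bg_probs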
- spateo.segmentation.vi.run_vi(X: numpy.ndarray, downsample: int | float = 0.01, n_epochs: int = 500, bins: numpy.ndarray | None = None, params: Dict[str, Tuple[float, float]] | Dict[int, Dict[str, Tuple[float, float]]] = dict(w=(0.5, 0.5), mu=(10.0, 300.0), var=(20.0, 400.0)), zero_inflated: bool = False, seed: int | None = None) Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]] | Dict[int, Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]]] [source]¶
Run negative binomial mixture variational inference.
- Parameters:
  - X – UMI counts per pixel.
  - downsample – Fraction (if a float) or number (if an int) of pixels to sample for fitting.
  - n_epochs – Number of training epochs.
  - bins – Pixel bins to fit separate mixture models to.
  - params – Initial values of the mixture parameters, given as (background, foreground) pairs for w, mu and var; when bins is provided, a per-bin dictionary of such values may be given instead.
  - zero_inflated – Whether to fit a zero-inflated negative binomial mixture.
  - seed – Random seed.
- Returns:
  Estimated (w, mu, var) parameters, each a (background, foreground) pair, or a dictionary of such tuples keyed by bin label when bins is provided.
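A short sketch of the non-binned call, unpacking the return value according to the annotated return type. The initial params below are arbitrary and only illustrate the (background, foreground) pairing implied by the defaults:

    import numpy as np
    from spateo.segmentation.vi import run_vi

    X = np.random.negative_binomial(10, 0.5, size=(256, 256))  # stand-in UMI counts

    # Custom initialization, mirroring the default (background, foreground) pairs
    # w=(0.5, 0.5), mu=(10.0, 300.0), var=(20.0, 400.0).
    params = dict(w=(0.5, 0.5), mu=(5.0, 100.0), var=(10.0, 200.0))

    w, mu, var = run_vi(X, downsample=0.05, n_epochs=200, params=params, seed=0)
    print("mixture weights:", w)  # (background, foreground)
    print("means:", mu)
    print("variances:", var)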