spateo.segmentation.vi
Variational inference implementation of a negative binomial mixture model using Pyro.
Classes
NegativeBinomialMixture | Subclass of torch.nn.Module whose attributes can be modified by Pyro effects.
Functions
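conditionals(X, vi_results[, bins]) | Compute the conditional probabilities, for each pixel, of observing the observed number of UMIs given that the pixel is background/foreground.
run_vi(X[, downsample, n_epochs, bins, params, zero_inflated, seed]) | Run negative binomial mixture variational inference.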
Module Contents
- class spateo.segmentation.vi.NegativeBinomialMixture(x: numpy.ndarray, n: int = 2, n_init: int = 5, w: numpy.ndarray | None = None, mu: numpy.ndarray | None = None, var: numpy.ndarray | None = None, zero_inflated: bool = False, seed: int | None = None)
Bases: pyro.nn.PyroModule
Subclass of torch.nn.Module whose attributes can be modified by Pyro effects. Attributes can be set using the helpers PyroParam and PyroSample, and methods can be decorated by pyro_method().
Parameters
To create a Pyro-managed parameter attribute, set that attribute using either torch.nn.Parameter (for unconstrained parameters) or PyroParam (for constrained parameters). Reading that attribute will then trigger a pyro.param statement. For example:
```python
import torch
from torch import nn
from pyro.distributions import constraints
from pyro.nn import PyroModule, PyroParam

# Create Pyro-managed parameter attributes.
my_module = PyroModule()
my_module.loc = nn.Parameter(torch.tensor(0.))
my_module.scale = PyroParam(torch.tensor(1.), constraint=constraints.positive)

# Read the attributes.
loc = my_module.loc  # Triggers a pyro.param statement.
scale = my_module.scale  # Triggers another pyro.param statement.
```
Note that, unlike normal torch.nn.Modules, PyroModules should not be registered with pyro.module statements. PyroModules can contain other PyroModules and normal torch.nn.Modules. Accessing a normal torch.nn.Module attribute of a PyroModule triggers a pyro.module statement. If multiple PyroModules appear in a single Pyro model or guide, they should be included in a single root PyroModule for that model.
PyroModules synchronize data with the param store at each setattr, getattr, and delattr event, based on the nested name of an attribute:
- Setting mod.x = x_init tries to read x from the param store. If a value is found in the param store, that value is copied into mod and x_init is ignored; otherwise x_init is copied into both mod and the param store.
- Reading mod.x tries to read x from the param store. If a value is found in the param store, that value is copied into mod; otherwise mod's value is copied into the param store. Finally, mod and the param store agree on a single value to return.
- Deleting del mod.x removes a value from both mod and the param store.
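The set/read/delete rules above can be illustrated with a small, dependency-free toy. `ToyModule` and `PARAM_STORE` are hypothetical names, and this is not how Pyro implements PyroModule (the real class also handles constraints, tensors and nested submodules); it only mimics the synchronization logic, including the consequence that two modules with the same name end up sharing data.

```python
from typing import Any, Dict

# Hypothetical global store standing in for Pyro's param store.
PARAM_STORE: Dict[str, Any] = {}


class ToyModule:
    """Toy model of PyroModule's param-store synchronization (not Pyro code)."""

    def __init__(self, name: str) -> None:
        # Bypass our __setattr__ so bookkeeping fields never touch the store.
        object.__setattr__(self, "_name", name)
        object.__setattr__(self, "_values", {})

    def _key(self, attr: str) -> str:
        # Nested name used as the store key, e.g. "model.x".
        return f"{self._name}.{attr}"

    def __setattr__(self, attr: str, init: Any) -> None:
        key = self._key(attr)
        if key in PARAM_STORE:
            self._values[attr] = PARAM_STORE[key]  # stored value wins; init ignored
        else:
            self._values[attr] = init  # otherwise init goes into both places
            PARAM_STORE[key] = init

    def __getattr__(self, attr: str) -> Any:
        # Only reached for managed attributes (not found by normal lookup).
        values = object.__getattribute__(self, "_values")
        key = self._key(attr)
        if key in PARAM_STORE:
            values[attr] = PARAM_STORE[key]  # store -> module
        elif attr in values:
            PARAM_STORE[key] = values[attr]  # module -> store
        else:
            raise AttributeError(attr)
        return values[attr]

    def __delattr__(self, attr: str) -> None:
        # Remove the value from both the module and the store.
        self._values.pop(attr, None)
        PARAM_STORE.pop(self._key(attr), None)


PARAM_STORE.clear()
first = ToyModule("model")
first.x = 1.0   # store was empty, so 1.0 is copied into both
second = ToyModule("model")
second.x = 5.0  # ignored: "model.x" is already in the store
print(second.x)  # 1.0
```

Because the store is keyed only by name, `second` silently inherits `first`'s value, which is the persistence effect that clearing the store is meant to avoid.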
Note that two PyroModules of the same name will both synchronize with the global param store and thus contain the same data. When creating a PyroModule, then deleting it, then creating another with the same name, the latter will be populated with the former's data from the param store. To avoid this persistence, either call pyro.clear_param_store() or call clear() before deleting a PyroModule.
PyroModules can be saved and loaded either directly using torch.save()/torch.load() or indirectly using the param store's save()/load(). Note that torch.load() will be overridden by any values in the param store, so it is safest to call pyro.clear_param_store() before loading.
Samples
To create a Pyro-managed random attribute, set that attribute using the PyroSample helper, specifying a prior distribution. Reading that attribute will then trigger a pyro.sample statement. For example:
```python
import pyro.distributions as dist
from pyro.nn import PyroSample

# Create Pyro-managed random attributes (continuing with my_module from above).
my_module.x = PyroSample(dist.Normal(0, 1))
my_module.y = PyroSample(lambda self: dist.Normal(self.loc, self.scale))

# Sample the attributes.
x = my_module.x  # Triggers a pyro.sample statement.
y = my_module.y  # Triggers one pyro.sample + two pyro.param statements.
```
Sampling is cached within each invocation of .__call__() or of a method decorated by pyro_method(). Because sample statements can appear only once in a Pyro trace, you should ensure that traced access to sample attributes is wrapped in a single invocation of .__call__() or of a method decorated by pyro_method().
To make an existing module probabilistic, you can create a subclass and overwrite some parameters with PyroSamples:
```python
import pyro.distributions as dist
from torch import nn
from pyro.nn import PyroModule, PyroSample


class RandomLinear(nn.Linear, PyroModule):  # used as a mixin
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.weight = PyroSample(
            lambda self: dist.Normal(0, 1)
                             .expand([self.out_features, self.in_features])
                             .to_event(2))
```
Mixin classes
PyroModule can be used as a mixin class, and supports simple syntax for dynamically creating mixins; for example, the following are equivalent:
```python
# Version 1. create a named mixin class
class PyroLinear(nn.Linear, PyroModule):
    pass

m.linear = PyroLinear(m, n)

# Version 2. create a dynamic mixin class
m.linear = PyroModule[nn.Linear](m, n)
```
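The PyroModule[nn.Linear] syntax amounts to building a subclass of both classes on the fly. The hypothetical `Mixin` and `Linear` classes below sketch that idea in plain Python with `__class_getitem__` and `type()`, caching the generated class so that repeated lookups return the same object; this illustrates the pattern only and is not Pyro's actual implementation.

```python
from typing import Dict, Type


class Mixin:
    """Hypothetical stand-in for PyroModule used as a mixin base."""

    # Cache so that Mixin[Base] returns the same class object every time.
    _cache: Dict[type, type] = {}

    def __class_getitem__(cls, base: Type) -> Type:
        if base not in cls._cache:
            # Equivalent to writing `class MixinBase(base, Mixin): pass`.
            cls._cache[base] = type(f"Mixin{base.__name__}", (base, cls), {})
        return cls._cache[base]


class Linear:
    """Hypothetical plain class playing the role of nn.Linear."""

    def __init__(self, in_features: int, out_features: int) -> None:
        self.in_features = in_features
        self.out_features = out_features


layer = Mixin[Linear](3, 2)  # dynamic mixin, like PyroModule[nn.Linear](m, n)
print(isinstance(layer, Linear), isinstance(layer, Mixin))  # True True
```

Putting the wrapped class first in the bases keeps its `__init__` and attributes in charge, while the mixin only adds behavior on top.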
This notation can be used recursively to create Bayesian modules, e.g.:
```python
import pyro.distributions as dist
from torch import nn
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.nn import PyroModule, PyroSample

model = PyroModule[nn.Sequential](
    PyroModule[nn.Linear](28 * 28, 100),
    PyroModule[nn.Sigmoid](),
    PyroModule[nn.Linear](100, 100),
    PyroModule[nn.Sigmoid](),
    PyroModule[nn.Linear](100, 10),
)
assert isinstance(model, nn.Sequential)
assert isinstance(model, PyroModule)

# Now we can be Bayesian about weights in the first layer.
# nn.Linear stores its weight as (out_features, in_features), i.e. (100, 28 * 28).
model[0].weight = PyroSample(
    prior=dist.Normal(0, 1).expand([100, 28 * 28]).to_event(2))
guide = AutoDiagonalNormal(model)
```
Note that PyroModule[...] does not recursively mix in PyroModule to submodules of the input Module; hence we needed to wrap each submodule of the nn.Sequential above.
- Parameters:
- name (str)
Optional name for a root PyroModule. This is ignored in sub-PyroModules of another PyroModule.
- train(n_epochs: int = 500)
Set the module in training mode.
This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.
- Parameters:
- mode (bool)
Whether to set training mode (True) or evaluation mode (False). Default: True.
- Returns:
self
- Return type:
Module
- spateo.segmentation.vi.conditionals(X: numpy.ndarray, vi_results: Dict[int, Dict[str, float]] | Dict[str, float], bins: numpy.ndarray | None = None) → Tuple[numpy.ndarray, numpy.ndarray]
Compute the conditional probabilities, for each pixel, of observing the observed number of UMIs given that the pixel is background/foreground.
- Parameters:
- X
UMI counts per pixel
- vi_results
Return value of run_vi().
- bins
Pixel bins, as passed to run_vi().
- Returns:
Two NumPy arrays, the first corresponding to the background conditional probabilities, and the second to the foreground conditional probabilities.
- Raises:
SegmentationError – If vi_results is a dictionary but bins was not provided.
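For intuition, the per-pixel quantities that conditionals() describes can be sketched as negative binomial likelihoods evaluated under each mixture component. This sketch assumes the moment parameterization (mean mu, variance var > mu), assumes that index 0 of each parameter tuple is background and index 1 is foreground (mirroring the layout of run_vi()'s default params), and ignores zero inflation and binning; `nb_log_pmf` and `nb_conditionals_sketch` are hypothetical helpers, not spateo functions.

```python
import math

import numpy as np

# Vectorised log-gamma from the standard library (avoids a SciPy dependency).
_lgamma = np.vectorize(math.lgamma, otypes=[float])


def nb_log_pmf(x, mu: float, var: float) -> np.ndarray:
    """Negative binomial log-pmf parameterised by mean and variance (var > mu)."""
    r = mu**2 / (var - mu)  # size (dispersion) parameter
    p = mu / var            # success probability; gives mean mu and variance var
    x = np.asarray(x, dtype=float)
    return (_lgamma(x + r) - _lgamma(r) - _lgamma(x + 1)
            + r * math.log(p) + x * math.log1p(-p))


def nb_conditionals_sketch(X, params):
    """Hypothetical helper returning P(X | background), P(X | foreground) per pixel.

    Assumes params["mu"] and params["var"] are (background, foreground) tuples.
    """
    bg = np.exp(nb_log_pmf(X, params["mu"][0], params["var"][0]))
    fg = np.exp(nb_log_pmf(X, params["mu"][1], params["var"][1]))
    return bg, fg


X = np.array([[0, 5], [250, 320]])  # toy UMI counts per pixel
bg, fg = nb_conditionals_sketch(X, dict(mu=(10.0, 300.0), var=(20.0, 400.0)))
print(bg > fg)
```

Low counts come out more likely under the low-mean component and high counts under the high-mean one, which is what makes these two arrays useful for separating background from foreground.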
- spateo.segmentation.vi.run_vi(X: numpy.ndarray, downsample: int | float = 0.01, n_epochs: int = 500, bins: numpy.ndarray | None = None, params: Dict[str, Tuple[float, float]] | Dict[int, Dict[str, Tuple[float, float]]] = dict(w=(0.5, 0.5), mu=(10.0, 300.0), var=(20.0, 400.0)), zero_inflated: bool = False, seed: int | None = None) → Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]] | Dict[int, Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]]]
Run negative binomial mixture variational inference.
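As a rough picture of what the default params describe, the sketch below evaluates the log-likelihood of counts under a two-component negative binomial mixture. It assumes w are mixture weights and mu/var are per-component means and variances (moment parameterization), and it ignores zero_inflated, bins and downsample; `nb_log_pmf` and `mixture_log_density` are hypothetical helpers. It only evaluates the likelihood term and does not perform the Pyro-based variational inference that run_vi() runs.

```python
import math

import numpy as np

# Vectorised log-gamma from the standard library (avoids a SciPy dependency).
_lgamma = np.vectorize(math.lgamma, otypes=[float])


def nb_log_pmf(x, mu: float, var: float) -> np.ndarray:
    """NB log-pmf with mean mu and variance var (requires var > mu)."""
    r = mu**2 / (var - mu)  # size (dispersion) parameter
    p = mu / var            # success probability; gives mean mu, variance var
    x = np.asarray(x, dtype=float)
    return (_lgamma(x + r) - _lgamma(r) - _lgamma(x + 1)
            + r * math.log(p) + x * math.log1p(-p))


def mixture_log_density(x, w, mu, var) -> np.ndarray:
    """Per-count log density of sum_k w[k] * NB(mu[k], var[k])."""
    logs = np.stack([math.log(wk) + nb_log_pmf(x, mk, vk)
                     for wk, mk, vk in zip(w, mu, var)])
    # Log-sum-exp across components for numerical stability.
    top = logs.max(axis=0)
    return top + np.log(np.exp(logs - top).sum(axis=0))


params = dict(w=(0.5, 0.5), mu=(10.0, 300.0), var=(20.0, 400.0))
counts = np.array([0, 3, 12, 280, 310])  # toy UMI counts
log_likelihood = mixture_log_density(counts, **params).sum()
print(log_likelihood)
```

Working in log space keeps the very small foreground probabilities of low counts (and vice versa) from underflowing to zero.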
- Parameters:
- X
- downsample
- n_epochs
- bins
- params
- zero_inflated
- seed
Returns: