spateo.segmentation.vi#

Variational inference implementation of a negative binomial mixture model using Pyro.

Module Contents#

Classes#

NegativeBinomialMixture

Subclass of torch.nn.Module whose attributes can be modified by Pyro effects.

Functions#

conditionals(X, vi_results[, bins]) → Tuple[numpy.ndarray, numpy.ndarray]

Compute the conditional probabilities, for each pixel, of observing the observed number of UMIs given that the pixel is background/foreground.

run_vi(X[, downsample, n_epochs, bins, ...])

Run negative binomial mixture variational inference.

class spateo.segmentation.vi.NegativeBinomialMixture(x: numpy.ndarray, n: int = 2, n_init: int = 5, w: numpy.ndarray | None = None, mu: numpy.ndarray | None = None, var: numpy.ndarray | None = None, zero_inflated: bool = False, seed: int | None = None)[source]#

Bases: pyro.nn.PyroModule

Subclass of torch.nn.Module whose attributes can be modified by Pyro effects. Attributes can be set using the helpers PyroParam and PyroSample, and methods can be decorated by pyro_method().

Parameters

To create a Pyro-managed parameter attribute, set that attribute using either torch.nn.Parameter (for unconstrained parameters) or PyroParam (for constrained parameters). Reading that attribute will then trigger a pyro.param statement. For example:

import torch
import torch.nn as nn
from torch.distributions import constraints
from pyro.nn import PyroModule, PyroParam

# Create Pyro-managed parameter attributes.
my_module = PyroModule()
my_module.loc = nn.Parameter(torch.tensor(0.))
my_module.scale = PyroParam(torch.tensor(1.),
                            constraint=constraints.positive)
# Read the attributes.
loc = my_module.loc  # Triggers a pyro.param statement.
scale = my_module.scale  # Triggers another pyro.param statement.

Note that, unlike normal torch.nn.Modules, PyroModules should not be registered with pyro.module statements. PyroModules can contain other PyroModules and normal torch.nn.Modules. Accessing a normal torch.nn.Module attribute of a PyroModule triggers a pyro.module statement. If multiple PyroModules appear in a single Pyro model or guide, they should be included in a single root PyroModule for that model.

PyroModules synchronize data with the param store at each setattr, getattr, and delattr event, based on the nested name of an attribute:

  • Setting mod.x = x_init tries to read x from the param store. If a value is found in the param store, that value is copied into mod and x_init is ignored; otherwise x_init is copied into both mod and the param store.

  • Reading mod.x tries to read x from the param store. If a value is found in the param store, that value is copied into mod; otherwise mod’s value is copied into the param store. Finally mod and the param store agree on a single value to return.

  • Deleting del mod.x removes a value from both mod and the param store.

Note that two PyroModules of the same name will both synchronize with the global param store and thus contain the same data. If you create a PyroModule, delete it, and then create another with the same name, the latter will be populated with the former's data from the param store. To avoid this persistence, either call pyro.clear_param_store() or call clear() before deleting a PyroModule.

PyroModules can be saved and loaded either directly using torch.save() / torch.load() or indirectly using the param store's save() / load(). Note that values loaded by torch.load() will be overridden by any values already in the param store, so it is safest to call pyro.clear_param_store() before loading.
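
As a minimal sketch of this persistence behavior (the nn.Linear module and the file name net.pt below are illustrative, not part of spateo), clearing the param store before loading avoids stale values overriding the loaded weights:

import torch
import pyro
from pyro.nn import PyroModule

# An illustrative PyroModule; "net.pt" is a hypothetical file name.
net = PyroModule[torch.nn.Linear](4, 2)
torch.save(net.state_dict(), "net.pt")

# Later, before loading: clear the param store so any values left in it
# do not override the freshly loaded weights.
pyro.clear_param_store()
net = PyroModule[torch.nn.Linear](4, 2)
net.load_state_dict(torch.load("net.pt"))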

Samples

To create a Pyro-managed random attribute, set that attribute using the PyroSample helper, specifying a prior distribution. Reading that attribute will then trigger a pyro.sample statement. For example:

import pyro.distributions as dist
from pyro.nn import PyroSample

# Create Pyro-managed random attributes.
my_module.x = PyroSample(dist.Normal(0, 1))
my_module.y = PyroSample(lambda self: dist.Normal(self.loc, self.scale))

# Sample the attributes.
x = my_module.x  # Triggers a pyro.sample statement.
y = my_module.y  # Triggers one pyro.sample + two pyro.param statements.

Sampling is cached within each invocation of .__call__() or of a method decorated by pyro_method(). Because sample statements can appear only once in a Pyro trace, you should ensure that traced access to sample attributes is wrapped in a single invocation of .__call__() or of a method decorated by pyro_method().
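
As a hedged illustration of this caching (the Model class below is hypothetical, not part of spateo), wrapping repeated reads of a sample attribute in a single pyro_method()-decorated method ensures the attribute is sampled only once:

import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample, pyro_method

class Model(PyroModule):
    def __init__(self):
        super().__init__()
        self.x = PyroSample(dist.Normal(0., 1.))

    @pyro_method
    def summed(self):
        # Both reads occur inside one traced invocation, so self.x is
        # sampled once and the cached value is reused.
        return self.x + self.x

model = Model()
value = model.summed()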

To make an existing module probabilistic, you can create a subclass and overwrite some parameters with PyroSamples:

class RandomLinear(nn.Linear, PyroModule):  # used as a mixin
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.weight = PyroSample(
            lambda self: dist.Normal(0, 1)
                             .expand([self.out_features,
                                      self.in_features])
                             .to_event(2))

Mixin classes

PyroModule can be used as a mixin class, and supports simple syntax for dynamically creating mixins, for example the following are equivalent:

# Version 1. create a named mixin class
class PyroLinear(nn.Linear, PyroModule):
    pass

m.linear = PyroLinear(m, n)

# Version 2. create a dynamic mixin class
m.linear = PyroModule[nn.Linear](m, n)

This notation can be used recursively to create Bayesian modules, e.g.:

from pyro.infer.autoguide import AutoDiagonalNormal

model = PyroModule[nn.Sequential](
    PyroModule[nn.Linear](28 * 28, 100),
    PyroModule[nn.Sigmoid](),
    PyroModule[nn.Linear](100, 100),
    PyroModule[nn.Sigmoid](),
    PyroModule[nn.Linear](100, 10),
)
assert isinstance(model, nn.Sequential)
assert isinstance(model, PyroModule)

# Now we can be Bayesian about weights in the first layer.
model[0].weight = PyroSample(
    prior=dist.Normal(0, 1).expand([28 * 28, 100]).to_event(2))
guide = AutoDiagonalNormal(model)

Note that PyroModule[...] does not recursively mix in PyroModule to submodules of the input Module; hence we needed to wrap each submodule of the nn.Sequential above.
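
If you want the recursive behavior, recent Pyro versions provide to_pyro_module_() in pyro.nn.module, which converts an existing nn.Module and its submodules in place; the following sketch assumes that helper is available in your Pyro version:

import torch.nn as nn
from pyro.nn.module import to_pyro_module_

seq = nn.Sequential(nn.Linear(28 * 28, 100), nn.Sigmoid(), nn.Linear(100, 10))
to_pyro_module_(seq)  # seq and its submodules are now PyroModules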

Parameters:
name str

Optional name for a root PyroModule. This is ignored in sub-PyroModules of another PyroModule.

assignment(train=False)[source]#
dist(assignment, train=False)[source]#
init_best_params(n_init)[source]#
init_mean_variance(w, mu, var)[source]#
optimizer()[source]#
get_params(train=False, transform=True)[source]#
forward(x)[source]#
train(n_epochs: int = 500)[source]#

Fit the mixture model by running the variational inference optimization loop for the given number of epochs.

Note that this overrides torch.nn.Module.train(), which would otherwise only toggle training/evaluation mode.

Parameters:
n_epochs int

Number of training epochs. Default: 500.

static conditionals(params, x, use_weights=False)[source]#
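
A minimal usage sketch for this class (the counts array below is synthetic, and the exact structure returned by get_params() is not documented here, so treat the comments as assumptions):

import numpy as np
from spateo.segmentation.vi import NegativeBinomialMixture

# Synthetic flattened per-pixel UMI counts; in practice these come from
# the spatial expression data being segmented.
x = np.random.default_rng(0).negative_binomial(10, 0.5, size=10000)

nbm = NegativeBinomialMixture(x, n=2, n_init=5, zero_inflated=False, seed=0)
nbm.train(n_epochs=500)       # run the variational inference loop
params = nbm.get_params()     # estimated mixture parameters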
spateo.segmentation.vi.conditionals(X: numpy.ndarray, vi_results: Dict[int, Dict[str, float]] | Dict[str, float], bins: numpy.ndarray | None = None) Tuple[numpy.ndarray, numpy.ndarray][source]#

Compute the conditional probabilities, for each pixel, of observing the observed number of UMIs given that the pixel is background/foreground.

Parameters:
X

UMI counts per pixel

vi_results

Return value of run_vi().

bins

Pixel bins, as was passed to run_vi().

Returns:

Two Numpy arrays, the first corresponding to the background conditional probabilities, and the second to the foreground conditional probabilities

Raises:

SegmentationError – If vi_results contains per-bin results but bins was not provided.

spateo.segmentation.vi.run_vi(X: numpy.ndarray, downsample: int | float = 0.01, n_epochs: int = 500, bins: numpy.ndarray | None = None, params: Dict[str, Tuple[float, float]] | Dict[int, Dict[str, Tuple[float, float]]] = dict(w=(0.5, 0.5), mu=(10.0, 300.0), var=(20.0, 400.0)), zero_inflated: bool = False, seed: int | None = None) Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]] | Dict[int, Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]]][source]#

Run negative binomial mixture variational inference.

Parameters:
X

UMI counts per pixel.

downsample

Downsample the data to this fraction (float) or this number (int) of pixels before fitting, to speed up inference.

n_epochs

Number of epochs to run variational inference.

bins

Pixel bins, to fit a separate mixture model for each bin.

params

Initial (background, foreground) values for the mixture weights w, means mu, and variances var, either as a single dictionary or keyed by bin label.

zero_inflated

Whether to fit a zero-inflated negative binomial mixture.

seed

Random seed.

Returns:

Estimated w, mu, and var parameters, each as a (background, foreground) tuple, or a dictionary of such tuples keyed by bin label when bins is provided.
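
A hedged end-to-end sketch combining run_vi() and conditionals() (the counts array is synthetic; passing the run_vi() output directly to conditionals() follows the documentation above, which describes vi_results as the return value of run_vi()):

import numpy as np
from spateo.segmentation import vi

# Synthetic per-pixel UMI counts standing in for a spatial expression image.
X = np.random.default_rng(0).negative_binomial(10, 0.5, size=(256, 256))

# Fit the background/foreground negative binomial mixture.
vi_results = vi.run_vi(X, downsample=0.01, n_epochs=500, seed=0)

# Per-pixel conditional probabilities of the observed counts under the
# background and foreground components.
background_cond, foreground_cond = vi.conditionals(X, vi_results)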