spateo.external.MERFISHVI.multimodal_spatial_vae

Attributes

logger

Classes

MultiModalSpatialVAE

Multi-modal spatial variational autoencoder.

Functions

unsupported_if_adata_minified(fn)

log_zinb_positive(x, mu, theta, pi[, eps])

Log likelihood of zero-inflated negative binomial distribution.

log_nb_positive(x, mu, theta[, eps])

Log likelihood of negative binomial distribution.

log_poisson(x, mu[, eps])

Log likelihood of Poisson distribution.

log_normal(x, mu, var[, eps])

Log likelihood of normal distribution.

Module Contents

spateo.external.MERFISHVI.multimodal_spatial_vae.unsupported_if_adata_minified(fn)
spateo.external.MERFISHVI.multimodal_spatial_vae.logger = None
class spateo.external.MERFISHVI.multimodal_spatial_vae.MultiModalSpatialVAE(n_input_spatial: int, n_input_nonspatial: int, n_batch_spatial: int = 0, n_batch_nonspatial: int = 0, n_labels_spatial: int = 0, n_labels_nonspatial: int = 0, n_hidden: int = 128, n_latent: int = 10, n_spatial: int = 10, n_layers: int = 1, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', gene_likelihood: Literal['zinb', 'nb', 'poisson', 'normal'] = 'zinb', latent_distribution: Literal['normal', 'ln'] = 'normal', use_observed_lib_size: bool = True, edge_index: torch.Tensor | None = None, attention_heads: int = 1, spatial_kl_weight: float = 0.01, modality_weights: Dict[str, float] = {'spatial': 1.0, 'nonspatial': 1.0}, cats_per_cov_spatial: List[int] | None = None, cats_per_cov_nonspatial: List[int] | None = None, use_size_factor_spatial: bool = False, use_size_factor_nonspatial: bool = False, library_log_means_spatial: torch.Tensor | None = None, library_log_vars_spatial: torch.Tensor | None = None, library_log_means_nonspatial: torch.Tensor | None = None, library_log_vars_nonspatial: torch.Tensor | None = None, var_eps: float = 0.0001, **kwargs)

Bases: spateo.external.MERFISHVI.scvi_spatial_module.SpatialVAE

Multi-modal spatial variational autoencoder.

Jointly models a modality that carries spatial information and a modality that does not, encoding both into a shared latent space.

Parameters:
n_input_spatial

Number of input features for spatial modality

n_input_nonspatial

Number of input features for non-spatial modality

n_batch_spatial

Number of batches for spatial modality

n_batch_nonspatial

Number of batches for non-spatial modality

n_labels_spatial

Number of labels for spatial modality

n_labels_nonspatial

Number of labels for non-spatial modality

n_hidden

Number of nodes in hidden layers

n_latent

Dimension of latent space

n_spatial

Dimension of spatial features

n_layers

Number of hidden layers

dropout_rate

Dropout rate

dispersion

Dispersion parameterization: one of 'gene', 'gene-batch', 'gene-label' or 'gene-cell'

gene_likelihood

Likelihood of the gene expression data: 'zinb' (zero-inflated negative binomial), 'nb' (negative binomial), 'poisson' or 'normal'

latent_distribution

Distribution of the latent variables: 'normal' or 'ln' (logistic normal)

**kwargs

Additional parameters
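A minimal sketch of the shape of the learned dispersion parameter implied by each `dispersion` option, assuming this class follows the scvi-tools `VAE` convention (one inverse-dispersion value per gene, per gene and batch, per gene and label, or predicted per cell by the decoder). The helper `dispersion_param_shape` is hypothetical and not part of spateo.

```python
from typing import Optional, Tuple


def dispersion_param_shape(
    dispersion: str, n_input: int, n_batch: int = 0, n_labels: int = 0
) -> Optional[Tuple[int, ...]]:
    """Shape of the learned dispersion parameter (hypothetical helper).

    Assumes scvi-tools conventions; returns None when the dispersion is
    predicted by the decoder instead of being a free parameter.
    """
    if dispersion == "gene":
        return (n_input,)  # one value per gene, shared by all cells
    if dispersion == "gene-batch":
        return (n_input, n_batch)  # one value per gene and batch
    if dispersion == "gene-label":
        return (n_input, n_labels)  # one value per gene and label
    if dispersion == "gene-cell":
        return None  # the decoder outputs a value per gene and cell
    raise ValueError(f"Unknown dispersion option: {dispersion!r}")


# Example with hypothetical sizes for the spatial modality.
print(dispersion_param_shape("gene-batch", n_input=500, n_batch=3))  # (500, 3)
```

With `n_batch_spatial` and `n_batch_nonspatial` passed separately, each modality would presumably get its own parameter of this shape under the 'gene-batch' option.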

n_spatial = 10
spatial_kl_weight = 0.01
modality_weights
var_eps = 0.0001
spatial_encoder
nonspatial_encoder
edge_index = None
inference_spatial(x: torch.Tensor, batch_index: torch.Tensor | None = None, **kwargs) → Dict[str, torch.Tensor]

Spatial modality inference process.

Parameters:
x

Input data of spatial modality

batch_index

Batch index

Returns:

Dictionary containing inference results

Return type:

dict

inference_nonspatial(x: torch.Tensor, batch_index: torch.Tensor | None = None, **kwargs) → Dict[str, torch.Tensor]

Non-spatial modality inference process.

Parameters:
x

Input data of non-spatial modality

batch_index

Batch index

Returns:

Dictionary containing inference results

Return type:

dict

inference(x_spatial: torch.Tensor, x_nonspatial: torch.Tensor | None = None, batch_index_spatial: torch.Tensor | None = None, batch_index_nonspatial: torch.Tensor | None = None, **kwargs) → Dict[str, torch.Tensor]

Joint inference process for the spatial and the (optional) non-spatial modality.

Parameters:
x_spatial

Input data of spatial modality

x_nonspatial

Input data of non-spatial modality, optional

batch_index_spatial

Batch index of spatial modality

batch_index_nonspatial

Batch index of non-spatial modality

Returns:

Dictionary containing joint inference results

Return type:

dict
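The `latent_distribution` option controls how a latent sample is drawn during inference. Below is a minimal pure-Python sketch of the reparameterization step, assuming the scvi-tools convention that 'ln' applies a softmax to a normal sample (a logistic normal). `sample_latent` is a hypothetical helper; the real encoder works on batched torch tensors.

```python
import math
import random
from typing import List, Optional, Sequence


def sample_latent(
    mu: Sequence[float],
    var: Sequence[float],
    latent_distribution: str = "normal",
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Draw one latent sample via the reparameterization trick (sketch)."""
    if rng is None:
        rng = random.Random()
    # z = mu + sigma * eps with eps ~ N(0, 1): given eps, z is a deterministic
    # function of mu and var, which is what lets gradients flow through them.
    untran_z = [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(mu, var)]
    if latent_distribution == "normal":
        return untran_z
    if latent_distribution == "ln":
        # Logistic normal: map the normal sample onto the simplex with a softmax.
        shift = max(untran_z)  # subtract the max for numerical stability
        exps = [math.exp(u - shift) for u in untran_z]
        total = sum(exps)
        return [e / total for e in exps]
    raise ValueError(f"Unknown latent distribution: {latent_distribution!r}")


z = sample_latent([0.0, 1.0, -1.0], [1.0, 0.5, 0.2], "ln", rng=random.Random(0))
print(round(sum(z), 6))  # 1.0
```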

generative_spatial(z: torch.Tensor, batch_index: torch.Tensor | None = None, **kwargs) → Dict[str, torch.Tensor]

Spatial modality generative process.

Parameters:
z

Latent representation

batch_index

Batch index

Returns:

Generative output

Return type:

dict

generative_nonspatial(z: torch.Tensor, batch_index: torch.Tensor | None = None, library_size: torch.Tensor | None = None, **kwargs) → Dict[str, torch.Tensor]

Non-spatial modality generative process.

Parameters:
z

Latent representation

batch_index

Batch index

library_size

Library size

Returns:

Generative output

Return type:

dict
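A minimal sketch of how a library size typically enters the generative process of scVI-style models: the decoder produces normalized expression proportions (`px_scale`) through a softmax, and the mean of the count likelihood is `exp(library) * px_scale`, where `library` is the log library size (with `use_observed_lib_size=True`, the log of the observed total counts). Whether `library_size` in this method is given on the log scale is an assumption, and `expected_counts` is a hypothetical helper.

```python
import math
from typing import List, Sequence


def expected_counts(decoder_logits: Sequence[float], x_observed: Sequence[float]) -> List[float]:
    """Mean counts per gene for one cell (scVI-style sketch)."""
    # Normalized expression: softmax over genes, so the proportions sum to 1.
    shift = max(decoder_logits)
    exps = [math.exp(v - shift) for v in decoder_logits]
    total = sum(exps)
    px_scale = [e / total for e in exps]
    # Observed log library size (assumed behaviour of use_observed_lib_size=True).
    library = math.log(sum(x_observed))
    # Rescale the proportions by the library size to get the likelihood mean.
    return [math.exp(library) * s for s in px_scale]


rate = expected_counts([0.2, 1.5, -0.3], [3, 10, 1])
print(round(sum(rate), 6))  # 14.0
```

Because the proportions sum to one, the expected counts always add up to the library size, which separates sequencing depth from relative expression.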

generative(z: torch.Tensor, batch_index_spatial: torch.Tensor | None = None, batch_index_nonspatial: torch.Tensor | None = None, library_size_spatial: torch.Tensor | None = None, library_size_nonspatial: torch.Tensor | None = None, **kwargs) → Dict[str, Dict[str, torch.Tensor]]

Joint generative process.

Parameters:
z

Latent representation

batch_index_spatial

Batch index of spatial modality

batch_index_nonspatial

Batch index of non-spatial modality

library_size_spatial

Library size of spatial modality

library_size_nonspatial

Library size of non-spatial modality

Returns:

Dictionary containing the generative outputs of each of the two modalities

Return type:

dict

forward(tensors: Dict[str, torch.Tensor], inference_kwargs: Dict = None, compute_loss: bool = True, **kwargs) → Tuple

Forward propagation process.

Parameters:
tensors

Dictionary of input tensors containing the data of both modalities

inference_kwargs

Keyword arguments passed to the inference function

compute_loss

Whether to compute loss

Returns:

Inference outputs, generative outputs and loss

Return type:

tuple
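The forward pass chains the hooks documented on this page. The sketch below shows that control flow with a toy module, assuming the scvi-tools `BaseModuleClass` pattern in which the loss is computed, and returned as a third element, only when `compute_loss=True`. `ToyModule`, `forward_sketch` and the tensor key `X_spatial` are hypothetical stand-ins, not the spateo implementation.

```python
from typing import Any, Dict, Optional, Tuple


class ToyModule:
    """Hypothetical stand-in exposing the same hook names as the module."""

    def _get_inference_input(self, tensors: Dict[str, Any]) -> Dict[str, Any]:
        return {"x_spatial": tensors["X_spatial"], "x_nonspatial": tensors.get("X_nonspatial")}

    def inference(self, x_spatial, x_nonspatial=None, **kwargs) -> Dict[str, Any]:
        return {"z": [2 * v for v in x_spatial]}  # toy "encoder"

    def _get_generative_input(self, tensors, inference_outputs) -> Dict[str, Any]:
        return {"z": inference_outputs["z"]}

    def generative(self, z, **kwargs) -> Dict[str, Dict[str, Any]]:
        return {"spatial": {"recon": [v / 2 for v in z]}}  # toy "decoder"

    def loss(self, tensors, inference_outputs, generative_outputs) -> float:
        recon = generative_outputs["spatial"]["recon"]
        return sum((a - b) ** 2 for a, b in zip(tensors["X_spatial"], recon))


def forward_sketch(
    module, tensors: Dict[str, Any], inference_kwargs: Optional[Dict] = None, compute_loss: bool = True
) -> Tuple:
    inference_kwargs = inference_kwargs or {}
    # 1. Select the tensors the inference step needs, then encode.
    inference_inputs = module._get_inference_input(tensors)
    inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
    # 2. Build the generative inputs from the data and the latent outputs, then decode.
    generative_inputs = module._get_generative_input(tensors, inference_outputs)
    generative_outputs = module.generative(**generative_inputs)
    if not compute_loss:
        return inference_outputs, generative_outputs
    # 3. Score the reconstruction against the original tensors.
    losses = module.loss(tensors, inference_outputs, generative_outputs)
    return inference_outputs, generative_outputs, losses


outputs = forward_sketch(ToyModule(), {"X_spatial": [1.0, 3.0]})
print(outputs[2])  # 0.0
```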

loss(tensors: Dict[str, torch.Tensor], inference_outputs: Dict[str, torch.Tensor | torch.distributions.Distribution | None], generative_outputs: Dict[str, Dict[str, torch.distributions.Distribution | None]], kl_weight: torch.tensor | float = 1.0) → scvi.module.base.LossOutput

Calculate the joint loss.

Parameters:
tensors

Input tensors

inference_outputs

Inference process outputs

generative_outputs

Generative process outputs

kl_weight

KL divergence weight

Returns:

Loss output object

Return type:

LossOutput
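The constructor arguments suggest that the joint loss combines per-modality reconstruction terms, weighted by `modality_weights`, with KL terms scaled by `kl_weight` and `spatial_kl_weight`. The sketch below shows one such weighted combination on per-cell scalars; the exact terms and how spateo averages them are assumptions, and `joint_loss` is hypothetical.

```python
from typing import Dict, Optional


def joint_loss(
    recon_spatial: float,
    recon_nonspatial: float,
    kl_latent: float,
    kl_spatial: float,
    modality_weights: Optional[Dict[str, float]] = None,
    kl_weight: float = 1.0,
    spatial_kl_weight: float = 0.01,
) -> float:
    """Weighted sum of reconstruction and KL terms (hypothetical sketch)."""
    if modality_weights is None:
        # Mirrors the documented default of the constructor argument.
        modality_weights = {"spatial": 1.0, "nonspatial": 1.0}
    reconstruction = (
        modality_weights["spatial"] * recon_spatial
        + modality_weights["nonspatial"] * recon_nonspatial
    )
    # kl_weight is passed per call (and is often annealed during training);
    # spatial_kl_weight is set once at construction.
    return reconstruction + kl_weight * kl_latent + spatial_kl_weight * kl_spatial


print(joint_loss(10.0, 4.0, 2.0, 5.0))
```

Down-weighting one modality through `modality_weights` would, under this assumption, reduce its influence on the shared latent space without removing it from training.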

_get_reconstruction_loss_nonspatial(x: torch.Tensor, generative_outputs: Dict[str, torch.Tensor]) → torch.Tensor

Calculate the reconstruction loss of the non-spatial modality.

Parameters:
x

Input data

generative_outputs

Generative process outputs

Returns:

Reconstruction loss

Return type:

torch.Tensor

get_latent_representation_by_modality(adata=None, indices=None, batch_size=None, modality='spatial') → numpy.ndarray

Get the latent representation of a specific modality.

Parameters:
adata

AnnData object, optional

indices

Indices of the observations to get representations for, optional

batch_size

Minibatch size for data loading, optional

modality

Modality whose representation is returned: “spatial”, “nonspatial” or “fused”

Returns:

Latent representation

Return type:

np.ndarray

get_fused_representation(adata=None, indices=None, batch_size=None) → numpy.ndarray

Get the fused latent representation.

Parameters:
adata

AnnData object, optional

indices

Indices of the observations to get representations for, optional

batch_size

Minibatch size for data loading, optional

Returns:

Fused latent representation

Return type:

np.ndarray

get_nonspatial_specific_features(adata=None, indices=None, batch_size=None) → numpy.ndarray

Get the features specific to the non-spatial modality.

Parameters:
adata

AnnData object, optional

indices

Indices of the observations to get features for, optional

batch_size

Minibatch size for data loading, optional

Returns:

Features specific to the non-spatial modality

Return type:

np.ndarray

_get_inference_input(tensors: Dict[str, torch.Tensor | None], full_forward_pass: bool = False) → Dict[str, torch.Tensor | None]

Get the input tensors required for the inference process. Overrides the parent method to handle multi-modal data.

Parameters:
tensors

Input data tensors

full_forward_pass

Whether to execute a full forward pass

Returns:

Input dictionary for inference process

Return type:

Dict

_get_generative_input(tensors: Dict[str, torch.Tensor], inference_outputs: Dict[str, torch.Tensor | torch.distributions.Distribution | None]) → Dict[str, torch.Tensor | None]

Get the input tensors required for the generative process. Overrides the parent method to handle multi-modal data.

Parameters:
tensors

Input data tensors

inference_outputs

Outputs from inference process

Returns:

Input dictionary for generative process

Return type:

Dict

spateo.external.MERFISHVI.multimodal_spatial_vae.log_zinb_positive(x, mu, theta, pi, eps=1e-08)

Log likelihood of zero-inflated negative binomial distribution.
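A scalar pure-Python sketch of the ZINB log-likelihood, assuming the scvi-tools parameterization: `mu` is the mean, `theta` the inverse dispersion and `pi` the logit of the zero-inflation (dropout) probability. The spateo function operates on torch tensors; `log_zinb_positive_scalar` is a hypothetical illustration of the formula only.

```python
import math


def _softplus(v: float) -> float:
    """Numerically stable log(1 + exp(v))."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def log_zinb_positive_scalar(x: float, mu: float, theta: float, pi: float, eps: float = 1e-8) -> float:
    """log P(x) under a ZINB with mean mu, inverse dispersion theta, dropout logit pi."""
    softplus_pi = _softplus(-pi)  # equals -log(sigmoid(pi)), i.e. -log P(dropout)
    log_theta_eps = math.log(theta + eps)
    log_theta_mu_eps = math.log(theta + mu + eps)
    # -pi + log NB(0; mu, theta)
    pi_theta_log = -pi + theta * (log_theta_eps - log_theta_mu_eps)
    if x < eps:
        # Zero count: mixture of a structural (dropout) zero and an NB zero.
        return _softplus(pi_theta_log) - softplus_pi
    # Non-zero count: log(1 - P(dropout)) + log NB(x; mu, theta).
    return (
        -softplus_pi
        + pi_theta_log
        + x * (math.log(mu + eps) - log_theta_mu_eps)
        + math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
    )


# The probabilities over a long enough range of counts should sum to about 1.
total = sum(math.exp(log_zinb_positive_scalar(k, mu=3.0, theta=2.0, pi=0.5)) for k in range(200))
print(round(total, 6))  # 1.0
```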

spateo.external.MERFISHVI.multimodal_spatial_vae.log_nb_positive(x, mu, theta, eps=1e-08)

Log likelihood of negative binomial distribution.
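A scalar pure-Python sketch of the negative binomial log-likelihood in the mean/inverse-dispersion parameterization used by scvi-tools (`mu` is the mean and `theta` the inverse dispersion, so the variance is `mu + mu**2 / theta`). It assumes the spateo function follows the same formula, with `eps` added inside the logarithms for numerical stability; `log_nb_positive_scalar` is hypothetical.

```python
import math


def log_nb_positive_scalar(x: float, mu: float, theta: float, eps: float = 1e-8) -> float:
    """log P(x) under a negative binomial with mean mu and inverse dispersion theta."""
    log_theta_mu_eps = math.log(theta + mu + eps)
    return (
        theta * (math.log(theta + eps) - log_theta_mu_eps)  # theta * log(theta / (theta + mu))
        + x * (math.log(mu + eps) - log_theta_mu_eps)  # x * log(mu / (theta + mu))
        # Log of Gamma(x + theta) / (Gamma(theta) * x!), the generalized binomial coefficient.
        + math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
    )


# Cross-check against the textbook pmf C(x + r - 1, x) p**x (1 - p)**r,
# with r = theta and p = mu / (mu + theta).
x, mu, theta = 4, 3.0, 2.0
p = mu / (mu + theta)
pmf = math.comb(x + int(theta) - 1, x) * p**x * (1 - p) ** theta
print(abs(math.exp(log_nb_positive_scalar(x, mu, theta)) - pmf) < 1e-6)  # True
```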

spateo.external.MERFISHVI.multimodal_spatial_vae.log_poisson(x, mu, eps=1e-08)

Log likelihood of Poisson distribution.
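A scalar sketch of the Poisson log-likelihood, log P(x) = x log(mu) - mu - log(x!). Where exactly `eps` enters the spateo implementation is an assumption; here it guards the logarithm of the rate. `log_poisson_scalar` is hypothetical.

```python
import math


def log_poisson_scalar(x: float, mu: float, eps: float = 1e-8) -> float:
    """log P(x) = x log(mu) - mu - log(x!) for a Poisson with rate mu."""
    # lgamma(x + 1) equals log(x!) and also works for non-integer x.
    return x * math.log(mu + eps) - mu - math.lgamma(x + 1)


print(round(math.exp(log_poisson_scalar(0, 2.0)), 6))  # 0.135335
```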

spateo.external.MERFISHVI.multimodal_spatial_vae.log_normal(x, mu, var, eps=1e-08)

Log likelihood of normal distribution.
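A scalar sketch of the Gaussian log-density, log N(x; mu, var) = -0.5 (log(2 pi var) + (x - mu)**2 / var). Adding `eps` to the variance is an assumption about how the spateo function stabilizes it; `log_normal_scalar` is hypothetical.

```python
import math


def log_normal_scalar(x: float, mu: float, var: float, eps: float = 1e-8) -> float:
    """log N(x; mu, var), with eps added to the variance for stability (assumed)."""
    var = var + eps
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mu) ** 2 / var)


print(round(log_normal_scalar(0.0, 0.0, 1.0), 6))  # -0.918939
```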