spateo.external.MERFISHVI.multimodal_spatial_vae
Attributes
Classes
MultiModalSpatialVAE | Multi-modal spatial variational autoencoder.
Functions
log_zinb_positive | Log likelihood of zero-inflated negative binomial distribution.
log_nb_positive | Log likelihood of negative binomial distribution.
| Log likelihood of Poisson distribution.
| Log likelihood of normal distribution.
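The Poisson and normal log-likelihoods listed above have standard closed forms. The sketch below writes them out with the Python standard library for scalar inputs. `log_poisson_sketch` and `log_normal_sketch` are hypothetical names, not the module's functions, and the normal case assumes `sigma` is a standard deviation rather than a variance.

```python
import math


def log_poisson_sketch(x: float, mu: float, eps: float = 1e-8) -> float:
    """log P(x | mu) for a Poisson with rate mu (hypothetical helper)."""
    # log(mu^x * exp(-mu) / x!) = x * log(mu) - mu - log Gamma(x + 1)
    return x * math.log(mu + eps) - mu - math.lgamma(x + 1)


def log_normal_sketch(x: float, mu: float, sigma: float) -> float:
    """log density of N(mu, sigma^2); sigma is assumed to be a standard deviation."""
    return -0.5 * math.log(2 * math.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)


# Poisson probabilities over a long enough support should sum to about 1.
total = sum(math.exp(log_poisson_sketch(k, 4.0)) for k in range(100))
print(round(total, 6))  # 1.0
```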
Module Contents
- class spateo.external.MERFISHVI.multimodal_spatial_vae.MultiModalSpatialVAE(n_input_spatial: int, n_input_nonspatial: int, n_batch_spatial: int = 0, n_batch_nonspatial: int = 0, n_labels_spatial: int = 0, n_labels_nonspatial: int = 0, n_hidden: int = 128, n_latent: int = 10, n_spatial: int = 10, n_layers: int = 1, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', gene_likelihood: Literal['zinb', 'nb', 'poisson', 'normal'] = 'zinb', latent_distribution: Literal['normal', 'ln'] = 'normal', use_observed_lib_size: bool = True, edge_index: torch.Tensor | None = None, attention_heads: int = 1, spatial_kl_weight: float = 0.01, modality_weights: Dict[str, float] = {'spatial': 1.0, 'nonspatial': 1.0}, cats_per_cov_spatial: List[int] | None = None, cats_per_cov_nonspatial: List[int] | None = None, use_size_factor_spatial: bool = False, use_size_factor_nonspatial: bool = False, library_log_means_spatial: torch.Tensor | None = None, library_log_vars_spatial: torch.Tensor | None = None, library_log_means_nonspatial: torch.Tensor | None = None, library_log_vars_nonspatial: torch.Tensor | None = None, var_eps: float = 0.0001, **kwargs)
Bases: spateo.external.MERFISHVI.scvi_spatial_module.SpatialVAE
Multi-modal spatial variational autoencoder.
Jointly models a modality that carries spatial information and a modality that does not, embedding both in a shared latent space.
- Parameters:
- n_input_spatial
Number of input features in the spatial modality
- n_input_nonspatial
Number of input features in the non-spatial modality
- n_batch_spatial
Number of batches in the spatial modality (default 0)
- n_batch_nonspatial
Number of batches in the non-spatial modality (default 0)
- n_labels_spatial
Number of labels in the spatial modality (default 0)
- n_labels_nonspatial
Number of labels in the non-spatial modality (default 0)
- n_hidden
Number of nodes per hidden layer (default 128)
- n_latent
Dimension of the shared latent space (default 10)
- n_spatial
Dimension of the spatial features (default 10)
- n_layers
Number of hidden layers (default 1)
- dropout_rate
Dropout rate (default 0.1)
- dispersion
Dispersion parameter type: 'gene', 'gene-batch', 'gene-label', or 'gene-cell' (default 'gene')
- gene_likelihood
Likelihood used for gene expression: 'zinb', 'nb', 'poisson', or 'normal' (default 'zinb')
- latent_distribution
Latent distribution: 'normal' or 'ln' (logistic normal) (default 'normal')
- **kwargs
Additional keyword arguments
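To illustrate the `dispersion` options, the sketch below returns the shape of the learned inverse-dispersion parameter under the conventions of the scvi-tools `VAE` module. It is an assumption that `MultiModalSpatialVAE` follows the same conventions for each modality, and `dispersion_param_shape` is a hypothetical helper.

```python
from typing import Optional, Tuple


def dispersion_param_shape(
    dispersion: str, n_input: int, n_batch: int = 0, n_labels: int = 0
) -> Optional[Tuple[int, ...]]:
    """Shape of the learned inverse-dispersion parameter, scvi-tools style (assumed)."""
    if dispersion == "gene":
        return (n_input,)  # one value per gene, shared by all cells
    if dispersion == "gene-batch":
        return (n_input, n_batch)  # one value per gene and batch
    if dispersion == "gene-label":
        return (n_input, n_labels)  # one value per gene and label
    if dispersion == "gene-cell":
        return None  # predicted per cell by the decoder; no free parameter
    raise ValueError(f"unknown dispersion: {dispersion!r}")


print(dispersion_param_shape("gene-batch", n_input=300, n_batch=2))  # (300, 2)
```

With two modalities, the same rule would apply separately to `n_input_spatial`/`n_batch_spatial` and `n_input_nonspatial`/`n_batch_nonspatial`.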
- inference_spatial(x: torch.Tensor, batch_index: torch.Tensor | None = None, **kwargs) → Dict[str, torch.Tensor]
Inference (encoding) step for the spatial modality.
- Parameters:
- x
Input data of the spatial modality
- batch_index
Batch indices, optional
- Returns:
Dictionary containing the inference outputs
- Return type:
Dict[str, torch.Tensor]
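In scvi-tools-style encoders, inference maps the input to the mean and variance of q(z | x) and draws z with the reparameterization trick; `var_eps` is added to the variance for numerical stability, and `latent_distribution='ln'` applies a softmax to the sample. The sketch below shows that step for a single cell using plain Python lists. It assumes spateo's encoders follow the scvi-tools `Encoder`; `sample_latent` and its arguments are hypothetical.

```python
import math
import random
from typing import List, Optional


def sample_latent(
    q_m: List[float],
    q_v_raw: List[float],
    var_eps: float = 1e-4,
    latent_distribution: str = "normal",
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Draw one latent sample z ~ q(z | x) by reparameterization (sketch)."""
    rng = rng if rng is not None else random.Random(0)
    # Keep the variance strictly positive: exp(raw output) + var_eps.
    q_v = [math.exp(v) + var_eps for v in q_v_raw]
    # Reparameterization trick: z = mean + std * noise, noise ~ N(0, 1).
    z = [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(q_m, q_v)]
    if latent_distribution == "ln":
        # Logistic normal: map the Gaussian sample onto the simplex via softmax.
        shift = max(z)
        exps = [math.exp(zi - shift) for zi in z]
        total = sum(exps)
        z = [e / total for e in exps]
    return z


z = sample_latent([0.0, 1.0, -1.0], [0.0, 0.0, 0.0], latent_distribution="ln")
print(round(sum(z), 6))  # 1.0
```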
- inference_nonspatial(x: torch.Tensor, batch_index: torch.Tensor | None = None, **kwargs) → Dict[str, torch.Tensor]
Inference (encoding) step for the non-spatial modality.
- Parameters:
- x
Input data of the non-spatial modality
- batch_index
Batch indices, optional
- Returns:
Dictionary containing the inference outputs
- Return type:
Dict[str, torch.Tensor]
- inference(x_spatial: torch.Tensor, x_nonspatial: torch.Tensor | None = None, batch_index_spatial: torch.Tensor | None = None, batch_index_nonspatial: torch.Tensor | None = None, **kwargs) → Dict[str, torch.Tensor]
Joint inference step over the spatial modality and, if provided, the non-spatial modality.
- Parameters:
- x_spatial
Input data of the spatial modality
- x_nonspatial
Input data of the non-spatial modality, optional
- batch_index_spatial
Batch indices of the spatial modality, optional
- batch_index_nonspatial
Batch indices of the non-spatial modality, optional
- Returns:
Dictionary containing the joint inference outputs
- Return type:
Dict[str, torch.Tensor]
- generative_spatial(z: torch.Tensor, batch_index: torch.Tensor | None = None, **kwargs) → Dict[str, torch.Tensor]
Generative (decoding) step for the spatial modality.
- Parameters:
- z
Latent representation
- batch_index
Batch indices, optional
- Returns:
Generative outputs
- Return type:
Dict[str, torch.Tensor]
- generative_nonspatial(z: torch.Tensor, batch_index: torch.Tensor | None = None, library_size: torch.Tensor | None = None, **kwargs) → Dict[str, torch.Tensor]
Generative (decoding) step for the non-spatial modality.
- Parameters:
- z
Latent representation
- batch_index
Batch indices, optional
- library_size
Library size, optional
- Returns:
Generative outputs
- Return type:
Dict[str, torch.Tensor]
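In scvi-tools decoders, the expected counts are the softmax-normalized decoder output scaled by the library size, and with `use_observed_lib_size=True` the library is the log of each cell's total counts. The sketch below shows this for one cell. It assumes `library_size` is on the log scale as in scvi-tools; `observed_log_library` and `decoder_rate` are hypothetical helpers.

```python
import math
from typing import List


def observed_log_library(counts: List[float]) -> float:
    """Log of one cell's total counts (the observed library size)."""
    return math.log(sum(counts))


def decoder_rate(h: List[float], log_library: float) -> List[float]:
    """Turn decoder outputs into expected counts: exp(library) * softmax(h)."""
    shift = max(h)
    exps = [math.exp(v - shift) for v in h]
    total = sum(exps)
    px_scale = [e / total for e in exps]  # normalized expression, sums to 1
    return [math.exp(log_library) * s for s in px_scale]


counts = [3.0, 0.0, 7.0, 10.0]
rate = decoder_rate([0.2, -1.0, 0.5, 1.3], observed_log_library(counts))
print(round(sum(rate), 6))  # 20.0
```

Because the softmax output sums to 1, the expected counts always sum to the cell's library size, so the decoder only has to model relative expression.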
- generative(z: torch.Tensor, batch_index_spatial: torch.Tensor | None = None, batch_index_nonspatial: torch.Tensor | None = None, library_size_spatial: torch.Tensor | None = None, library_size_nonspatial: torch.Tensor | None = None, **kwargs) → Dict[str, Dict[str, torch.Tensor]]
Joint generative step for both modalities.
- Parameters:
- z
Latent representation
- batch_index_spatial
Batch indices of the spatial modality, optional
- batch_index_nonspatial
Batch indices of the non-spatial modality, optional
- library_size_spatial
Library size of the spatial modality, optional
- library_size_nonspatial
Library size of the non-spatial modality, optional
- Returns:
Dictionary with the generative outputs of each modality
- Return type:
Dict[str, Dict[str, torch.Tensor]]
- forward(tensors: Dict[str, torch.Tensor], inference_kwargs: Dict = None, compute_loss: bool = True, **kwargs) → Tuple
Forward pass: inference, generation and, optionally, loss computation.
- Parameters:
- tensors
Dictionary of input tensors containing the data of both modalities
- inference_kwargs
Keyword arguments passed to the inference function
- compute_loss
Whether to compute the loss
- Returns:
Inference outputs, generative outputs, and loss
- Return type:
Tuple
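The call order behind `forward` can be sketched with a toy module. The sketch assumes the scvi-tools base-module pattern (`_get_inference_input`, then `inference`, then `_get_generative_input`, then `generative`, then `loss`); `ToyMultiModalModule`, its tensor keys, and its placeholder encoder and decoder are hypothetical and carry no real model.

```python
from typing import Any, Dict, Optional


class ToyMultiModalModule:
    """Hypothetical stand-in that only mirrors the call order of forward()."""

    def _get_inference_input(self, tensors: Dict[str, Any]) -> Dict[str, Any]:
        # "X_spatial" / "X_nonspatial" are illustrative keys, not real registry keys.
        return {"x_spatial": tensors["X_spatial"], "x_nonspatial": tensors.get("X_nonspatial")}

    def inference(self, x_spatial, x_nonspatial=None, **kwargs) -> Dict[str, Any]:
        # Placeholder "encoder": the latent is the mean of the spatial features.
        return {"z": sum(x_spatial) / len(x_spatial)}

    def _get_generative_input(self, tensors, inference_outputs) -> Dict[str, Any]:
        return {"z": inference_outputs["z"]}

    def generative(self, z, **kwargs) -> Dict[str, Dict[str, Any]]:
        # One output dict per modality, matching the documented nesting.
        return {"spatial": {"px_rate": z}, "nonspatial": {"px_rate": z}}

    def loss(self, tensors, inference_outputs, generative_outputs) -> float:
        rate = generative_outputs["spatial"]["px_rate"]
        return sum((v - rate) ** 2 for v in tensors["X_spatial"])

    def forward(self, tensors, inference_kwargs: Optional[Dict] = None, compute_loss: bool = True):
        inference_kwargs = inference_kwargs or {}
        inference_outputs = self.inference(**self._get_inference_input(tensors), **inference_kwargs)
        generative_inputs = self._get_generative_input(tensors, inference_outputs)
        generative_outputs = self.generative(**generative_inputs)
        if not compute_loss:
            return inference_outputs, generative_outputs
        losses = self.loss(tensors, inference_outputs, generative_outputs)
        return inference_outputs, generative_outputs, losses


inf, gen, loss = ToyMultiModalModule().forward({"X_spatial": [1.0, 3.0]})
print(inf["z"], loss)  # 2.0 2.0
```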
- loss(tensors: Dict[str, torch.Tensor], inference_outputs: Dict[str, torch.Tensor | torch.distributions.Distribution | None], generative_outputs: Dict[str, Dict[str, torch.distributions.Distribution | None]], kl_weight: torch.tensor | float = 1.0) → scvi.module.base.LossOutput
Compute the joint loss over both modalities.
- Parameters:
- tensors
Input tensors
- inference_outputs
Outputs of the inference step
- generative_outputs
Outputs of the generative step
- kl_weight
Weight of the KL divergence term (default 1.0)
- Returns:
Loss output object
- Return type:
LossOutput
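How the individual terms are weighted is not documented here. As a hypothetical illustration only, the sketch below assumes the joint loss is a `modality_weights`-weighted sum of per-modality reconstruction losses plus a `kl_weight`-scaled KL term on z and a `spatial_kl_weight`-scaled spatial KL term; the actual spateo objective may differ, and `weighted_joint_loss` is not part of the API.

```python
from typing import Dict, Optional


def weighted_joint_loss(
    reconst: Dict[str, float],
    kl_z: float,
    kl_weight: float = 1.0,
    modality_weights: Optional[Dict[str, float]] = None,
    kl_spatial: float = 0.0,
    spatial_kl_weight: float = 0.01,
) -> float:
    """Hypothetical weighted objective: sum_m w_m * reconst_m + weighted KL terms."""
    # None instead of a dict default avoids sharing one mutable dict across calls.
    weights = modality_weights if modality_weights is not None else {"spatial": 1.0, "nonspatial": 1.0}
    # Only modalities present in `reconst` contribute; unknown names get weight 1.0.
    recon_term = sum(weights.get(m, 1.0) * value for m, value in reconst.items())
    return recon_term + kl_weight * kl_z + spatial_kl_weight * kl_spatial


print(weighted_joint_loss({"spatial": 10.0, "nonspatial": 20.0}, kl_z=5.0, kl_weight=0.5,
                          modality_weights={"spatial": 1.0, "nonspatial": 0.5}))  # 22.5
```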
- _get_reconstruction_loss_nonspatial(x: torch.Tensor, generative_outputs: Dict[str, torch.Tensor]) → torch.Tensor
Compute the reconstruction loss of the non-spatial modality.
- Parameters:
- x
Input data of the non-spatial modality
- generative_outputs
Outputs of the generative step
- Returns:
Reconstruction loss
- Return type:
torch.Tensor
- get_latent_representation_by_modality(adata=None, indices=None, batch_size=None, modality='spatial') → numpy.ndarray
Get the latent representation of a specific modality.
- Parameters:
- adata
AnnData object, optional
- indices
Indices of the observations to use, optional
- batch_size
Minibatch size, optional
- modality
Representation to return: “spatial”, “nonspatial”, or “fused” (default “spatial”)
- Returns:
Latent representation
- Return type:
np.ndarray
- get_fused_representation(adata=None, indices=None, batch_size=None) → numpy.ndarray
Get the fused latent representation.
- Parameters:
- adata
AnnData object, optional
- indices
Indices of the observations to use, optional
- batch_size
Minibatch size, optional
- Returns:
Fused latent representation
- Return type:
np.ndarray
- get_nonspatial_specific_features(adata=None, indices=None, batch_size=None) → numpy.ndarray
Get the features specific to the non-spatial modality.
- Parameters:
- adata
AnnData object, optional
- indices
Indices of the observations to use, optional
- batch_size
Minibatch size, optional
- Returns:
Features specific to the non-spatial modality
- Return type:
np.ndarray
- _get_inference_input(tensors: Dict[str, torch.Tensor | None], full_forward_pass: bool = False) → Dict[str, torch.Tensor | None]
Collect the input tensors for the inference step. Overrides the parent method to handle multi-modal data.
- Parameters:
- tensors
Input data tensors
- full_forward_pass
Whether to run a full forward pass
- Returns:
Input dictionary for the inference step
- Return type:
Dict
- _get_generative_input(tensors: Dict[str, torch.Tensor], inference_outputs: Dict[str, torch.Tensor | torch.distributions.Distribution | None]) → Dict[str, torch.Tensor | None]
Collect the input tensors for the generative step. Overrides the parent method to handle multi-modal data.
- Parameters:
- tensors
Input data tensors
- inference_outputs
Outputs of the inference step
- Returns:
Input dictionary for the generative step
- Return type:
Dict
- spateo.external.MERFISHVI.multimodal_spatial_vae.log_zinb_positive(x, mu, theta, pi, eps=1e-08)
Log likelihood of zero-inflated negative binomial distribution.
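A minimal scalar sketch of this log-likelihood, following the scvi-tools formulation and assuming that `mu` is the mean, `theta` the inverse dispersion, and `pi` the logit of the zero-inflation probability. `log_zinb_positive_sketch` is a hypothetical name; the spateo function operates on tensors.

```python
import math


def softplus(v: float) -> float:
    """Numerically stable log(1 + exp(v))."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def log_zinb_positive_sketch(x: float, mu: float, theta: float, pi: float, eps: float = 1e-8) -> float:
    softplus_pi = softplus(-pi)  # equals -log(sigmoid(pi)), i.e. -log P(dropout)
    log_theta_eps = math.log(theta + eps)
    log_theta_mu_eps = math.log(theta + mu + eps)
    # -pi + log NB(0 | mu, theta)
    pi_theta_log = -pi + theta * (log_theta_eps - log_theta_mu_eps)
    if x < eps:
        # Zero count: log(P(dropout) + P(no dropout) * NB(0))
        return softplus(pi_theta_log) - softplus_pi
    # Positive count: log(P(no dropout)) + log NB(x)
    return (
        -softplus_pi
        + pi_theta_log
        + x * (math.log(mu + eps) - log_theta_mu_eps)
        + math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
    )


total = sum(math.exp(log_zinb_positive_sketch(k, 3.0, 2.0, 0.0)) for k in range(300))
print(round(total, 6))  # 1.0
```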
- spateo.external.MERFISHVI.multimodal_spatial_vae.log_nb_positive(x, mu, theta, eps=1e-08)
Log likelihood of negative binomial distribution.
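A minimal scalar sketch of the negative binomial log-likelihood in the mean/inverse-dispersion parameterization used by scvi-tools, assuming `mu` is the mean and `theta` the inverse dispersion. `log_nb_positive_sketch` is a hypothetical name; the spateo function operates on tensors.

```python
import math


def log_nb_positive_sketch(x: float, mu: float, theta: float, eps: float = 1e-8) -> float:
    log_theta_mu_eps = math.log(theta + mu + eps)
    return (
        theta * (math.log(theta + eps) - log_theta_mu_eps)  # log (theta / (theta + mu))^theta
        + x * (math.log(mu + eps) - log_theta_mu_eps)  # log (mu / (theta + mu))^x
        + math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)  # log binomial coefficient Gamma(x+theta) / (Gamma(theta) x!)
    )


total = sum(math.exp(log_nb_positive_sketch(k, 3.0, 2.0)) for k in range(300))
print(round(total, 6))  # 1.0
```

As `theta` grows, the distribution approaches a Poisson with rate `mu`, which is why a large inverse dispersion corresponds to little overdispersion.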