spateo.external.MERFISHVI.multimodal_spatial_vae
================================================

.. py:module:: spateo.external.MERFISHVI.multimodal_spatial_vae


Attributes
----------

.. autoapisummary::

   spateo.external.MERFISHVI.multimodal_spatial_vae.logger


Classes
-------

.. autoapisummary::

   spateo.external.MERFISHVI.multimodal_spatial_vae.MultiModalSpatialVAE


Functions
---------

.. autoapisummary::

   spateo.external.MERFISHVI.multimodal_spatial_vae.unsupported_if_adata_minified
   spateo.external.MERFISHVI.multimodal_spatial_vae.log_zinb_positive
   spateo.external.MERFISHVI.multimodal_spatial_vae.log_nb_positive
   spateo.external.MERFISHVI.multimodal_spatial_vae.log_poisson
   spateo.external.MERFISHVI.multimodal_spatial_vae.log_normal


Module Contents
---------------

.. py:function:: unsupported_if_adata_minified(fn)

.. py:data:: logger
   :value: None

.. py:class:: MultiModalSpatialVAE(n_input_spatial: int, n_input_nonspatial: int, n_batch_spatial: int = 0, n_batch_nonspatial: int = 0, n_labels_spatial: int = 0, n_labels_nonspatial: int = 0, n_hidden: int = 128, n_latent: int = 10, n_spatial: int = 10, n_layers: int = 1, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', gene_likelihood: Literal['zinb', 'nb', 'poisson', 'normal'] = 'zinb', latent_distribution: Literal['normal', 'ln'] = 'normal', use_observed_lib_size: bool = True, edge_index: Optional[torch.Tensor] = None, attention_heads: int = 1, spatial_kl_weight: float = 0.01, modality_weights: Dict[str, float] = {'spatial': 1.0, 'nonspatial': 1.0}, cats_per_cov_spatial: Optional[List[int]] = None, cats_per_cov_nonspatial: Optional[List[int]] = None, use_size_factor_spatial: bool = False, use_size_factor_nonspatial: bool = False, library_log_means_spatial: Optional[torch.Tensor] = None, library_log_vars_spatial: Optional[torch.Tensor] = None, library_log_means_nonspatial: Optional[torch.Tensor] = None, library_log_vars_nonspatial: Optional[torch.Tensor] = None, var_eps: float = 0.0001, **kwargs)

   Bases: :py:obj:`spateo.external.MERFISHVI.scvi_spatial_module.SpatialVAE`

   Multi-modal spatial variational autoencoder.

   Processes a modality that carries spatial information together with a modality that does not, and encodes both into a shared latent space for joint modeling.

   :param n_input_spatial: Number of input features of the spatial modality
   :param n_input_nonspatial: Number of input features of the non-spatial modality
   :param n_batch_spatial: Number of batches in the spatial modality
   :param n_batch_nonspatial: Number of batches in the non-spatial modality
   :param n_labels_spatial: Number of labels in the spatial modality
   :param n_labels_nonspatial: Number of labels in the non-spatial modality
   :param n_hidden: Number of nodes per hidden layer
   :param n_latent: Dimension of the latent space
   :param n_spatial: Dimension of the spatial features
   :param n_layers: Number of hidden layers
   :param dropout_rate: Dropout rate
   :param dispersion: Type of dispersion parameter
   :param gene_likelihood: Distribution used for the gene likelihood
   :param latent_distribution: Distribution of the latent space
   :param \*\*kwargs: Additional keyword arguments

   .. py:attribute:: n_spatial
      :value: 10

   .. py:attribute:: spatial_kl_weight
      :value: 0.01

   .. py:attribute:: modality_weights

   .. py:attribute:: var_eps
      :value: 0.0001

   .. py:attribute:: spatial_encoder

   .. py:attribute:: nonspatial_encoder

   .. py:attribute:: edge_index
      :value: None

   .. py:method:: inference_spatial(x: torch.Tensor, batch_index: Optional[torch.Tensor] = None, **kwargs) -> Dict[str, torch.Tensor]

      Inference process for the spatial modality.

      :param x: Input data of the spatial modality
      :param batch_index: Batch index
      :returns: Dictionary containing the inference results
      :rtype: dict

   .. py:method:: inference_nonspatial(x: torch.Tensor, batch_index: Optional[torch.Tensor] = None, **kwargs) -> Dict[str, torch.Tensor]

      Inference process for the non-spatial modality.
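
      The ``Encoder`` class in scvi-tools models the posterior as a Gaussian, adds a small ``var_eps`` to the variance, and draws the latent sample with the reparameterization trick. The framework-free sketch below illustrates that step, assuming (not confirmed on this page) that the non-spatial encoder follows the same pattern and that this class's ``var_eps`` attribute, documented with a value of ``0.0001``, plays the same role. ``sample_latent`` is a hypothetical helper; the real method operates on torch tensors.

      .. code-block:: python

         import math
         import random


         def sample_latent(q_m, q_v_raw, var_eps=1e-4, rng=None):
             """Reparameterized draw z = mu + sigma * noise (hypothetical helper)."""
             rng = rng or random.Random()
             z = []
             for mu, raw in zip(q_m, q_v_raw):
                 var = math.exp(raw) + var_eps  # strictly positive variance with a small floor
                 noise = rng.gauss(0.0, 1.0)  # drawn independently of mu and var
                 z.append(mu + math.sqrt(var) * noise)
             return z

      Because the noise does not depend on ``mu`` or ``var``, an autograd framework can differentiate the sample with respect to both, which is what allows the encoder to be trained by gradient descent.
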

      :param x: Input data of the non-spatial modality
      :param batch_index: Batch index
      :returns: Dictionary containing the inference results
      :rtype: dict

   .. py:method:: inference(x_spatial: torch.Tensor, x_nonspatial: Optional[torch.Tensor] = None, batch_index_spatial: Optional[torch.Tensor] = None, batch_index_nonspatial: Optional[torch.Tensor] = None, **kwargs) -> Dict[str, torch.Tensor]

      Joint inference process that handles both the spatial and the non-spatial modality.

      :param x_spatial: Input data of the spatial modality
      :param x_nonspatial: Input data of the non-spatial modality, optional
      :param batch_index_spatial: Batch index of the spatial modality
      :param batch_index_nonspatial: Batch index of the non-spatial modality
      :returns: Dictionary containing the joint inference results
      :rtype: dict

   .. py:method:: generative_spatial(z: torch.Tensor, batch_index: Optional[torch.Tensor] = None, **kwargs) -> Dict[str, torch.Tensor]

      Generative process for the spatial modality.

      :param z: Latent representation
      :param batch_index: Batch index
      :returns: Generative outputs
      :rtype: dict

   .. py:method:: generative_nonspatial(z: torch.Tensor, batch_index: Optional[torch.Tensor] = None, library_size: Optional[torch.Tensor] = None, **kwargs) -> Dict[str, torch.Tensor]

      Generative process for the non-spatial modality.

      :param z: Latent representation
      :param batch_index: Batch index
      :param library_size: Library size
      :returns: Generative outputs
      :rtype: dict

   .. py:method:: generative(z: torch.Tensor, batch_index_spatial: Optional[torch.Tensor] = None, batch_index_nonspatial: Optional[torch.Tensor] = None, library_size_spatial: Optional[torch.Tensor] = None, library_size_nonspatial: Optional[torch.Tensor] = None, **kwargs) -> Dict[str, Dict[str, torch.Tensor]]

      Joint generative process.
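
      In scVI-style decoders, the mean of each gene's count distribution is the decoder's per-gene proportion (a softmax over genes) multiplied by the cell's library size. The sketch below illustrates that convention for a single cell, assuming (not confirmed on this page) that ``library_size_spatial`` and ``library_size_nonspatial`` play this role and are given on the count scale rather than the log scale. ``expected_counts`` is a hypothetical helper.

      .. code-block:: python

         import math


         def expected_counts(decoder_logits, library_size):
             """Map decoder outputs for one cell to per-gene expected counts (hypothetical helper)."""
             top = max(decoder_logits)
             weights = [math.exp(v - top) for v in decoder_logits]  # subtract the max for stability
             total = sum(weights)
             proportions = [w / total for w in weights]  # softmax: non-negative, sums to 1
             return [library_size * p for p in proportions]  # mean count per gene

      The softmax keeps every mean non-negative, and the library size absorbs differences in sequencing depth between cells.
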

      :param z: Latent representation
      :param batch_index_spatial: Batch index of the spatial modality
      :param batch_index_nonspatial: Batch index of the non-spatial modality
      :param library_size_spatial: Library size of the spatial modality
      :param library_size_nonspatial: Library size of the non-spatial modality
      :returns: Dictionary containing the generative outputs of both modalities
      :rtype: dict

   .. py:method:: forward(tensors: Dict[str, torch.Tensor], inference_kwargs: Dict = None, compute_loss: bool = True, **kwargs) -> Tuple

      Forward pass.

      :param tensors: Dictionary of input tensors containing data from both modalities
      :param inference_kwargs: Keyword arguments passed to the inference function
      :param compute_loss: Whether to compute the loss
      :returns: Inference outputs, generative outputs and loss
      :rtype: tuple

   .. py:method:: loss(tensors: Dict[str, torch.Tensor], inference_outputs: Dict[str, torch.Tensor | torch.distributions.Distribution | None], generative_outputs: Dict[str, Dict[str, torch.distributions.Distribution | None]], kl_weight: torch.tensor | float = 1.0) -> scvi.module.base.LossOutput

      Compute the joint loss.

      :param tensors: Input tensors
      :param inference_outputs: Outputs of the inference process
      :param generative_outputs: Outputs of the generative process
      :param kl_weight: Weight of the KL divergence term
      :returns: Loss output object
      :rtype: LossOutput

   .. py:method:: _get_reconstruction_loss_nonspatial(x: torch.Tensor, generative_outputs: Dict[str, torch.Tensor]) -> torch.Tensor

      Compute the reconstruction loss of the non-spatial modality.

      :param x: Input data
      :param generative_outputs: Outputs of the generative process
      :returns: Reconstruction loss
      :rtype: torch.Tensor

   .. py:method:: get_latent_representation_by_modality(adata=None, indices=None, batch_size=None, modality='spatial') -> numpy.ndarray

      Get the latent representation of a specific modality.
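
      The ``indices`` and ``batch_size`` arguments suggest that the selected cells are encoded in minibatches and the results concatenated. The sketch below shows that general pattern under this assumption; ``encode_in_batches`` and ``encode_fn`` are hypothetical names, and the defaults chosen here for ``None`` are illustrative, not necessarily those of the method.

      .. code-block:: python

         def encode_in_batches(encode_fn, data, indices=None, batch_size=None):
             """Encode selected rows chunk by chunk and concatenate (hypothetical helper)."""
             if indices is None:
                 indices = list(range(len(data)))  # default: every cell
             if batch_size is None:
                 batch_size = len(indices) or 1  # default: one chunk
             latent = []
             for start in range(0, len(indices), batch_size):
                 chunk = [data[i] for i in indices[start:start + batch_size]]
                 latent.extend(encode_fn(chunk))  # one latent vector per cell, in order
             return latent
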

      :param adata: AnnData object, optional
      :param indices: Indices of the cells to use, optional
      :param batch_size: Minibatch size, optional
      :param modality: Modality to return; one of ``"spatial"``, ``"nonspatial"`` or ``"fused"``
      :returns: Latent representation
      :rtype: np.ndarray

   .. py:method:: get_fused_representation(adata=None, indices=None, batch_size=None) -> numpy.ndarray

      Get the fused latent representation.

      :param adata: AnnData object, optional
      :param indices: Indices of the cells to use, optional
      :param batch_size: Minibatch size, optional
      :returns: Fused latent representation
      :rtype: np.ndarray

   .. py:method:: get_nonspatial_specific_features(adata=None, indices=None, batch_size=None) -> numpy.ndarray

      Get the features specific to the non-spatial modality.

      :param adata: AnnData object, optional
      :param indices: Indices of the cells to use, optional
      :param batch_size: Minibatch size, optional
      :returns: Features specific to the non-spatial modality
      :rtype: np.ndarray

   .. py:method:: _get_inference_input(tensors: Dict[str, torch.Tensor | None], full_forward_pass: bool = False) -> Dict[str, torch.Tensor | None]

      Get the input tensors required by the inference process. Overrides the parent method to handle multi-modal data.

      :param tensors: Input data tensors
      :param full_forward_pass: Whether to run a full forward pass
      :returns: Input dictionary for the inference process
      :rtype: Dict

   .. py:method:: _get_generative_input(tensors: Dict[str, torch.Tensor], inference_outputs: Dict[str, torch.Tensor | torch.distributions.Distribution | None]) -> Dict[str, torch.Tensor | None]

      Get the input tensors required by the generative process. Overrides the parent method to handle multi-modal data.

      :param tensors: Input data tensors
      :param inference_outputs: Outputs of the inference process
      :returns: Input dictionary for the generative process
      :rtype: Dict

.. py:function:: log_zinb_positive(x, mu, theta, pi, eps=1e-08)

   Log-likelihood of the zero-inflated negative binomial distribution.

.. py:function:: log_nb_positive(x, mu, theta, eps=1e-08)

   Log-likelihood of the negative binomial distribution.

.. py:function:: log_poisson(x, mu, eps=1e-08)

   Log-likelihood of the Poisson distribution.

.. py:function:: log_normal(x, mu, var, eps=1e-08)

   Log-likelihood of the normal distribution.
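
The four likelihood helpers above are documented only by name. The following framework-free sketch writes their formulas out for scalar inputs, assuming the parameterization used by the functions of the same names in scvi-tools: ``mu`` is the mean, ``theta`` the inverse dispersion, ``pi`` the logit of the probability of an extra zero, and ``var`` the variance. The module's own functions operate on torch tensors; the names ``nb_loglik``, ``zinb_loglik``, ``poisson_loglik`` and ``normal_loglik`` are hypothetical, and where ``eps`` is applied in the Poisson and normal cases is a guess.

.. code-block:: python

   import math


   def _softplus(v):
       # Numerically stable log(1 + exp(v)).
       return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


   def nb_loglik(x, mu, theta, eps=1e-8):
       """log NB(x | mean mu, inverse dispersion theta)."""
       log_theta_mu_eps = math.log(theta + mu + eps)
       return (
           theta * (math.log(theta + eps) - log_theta_mu_eps)
           + x * (math.log(mu + eps) - log_theta_mu_eps)
           + math.lgamma(x + theta)
           - math.lgamma(theta)
           - math.lgamma(x + 1)
       )


   def zinb_loglik(x, mu, theta, pi_logit, eps=1e-8):
       """log ZINB; sigmoid(pi_logit) is the probability of an extra zero."""
       softplus_pi = _softplus(-pi_logit)  # equals -log(sigmoid(pi_logit))
       log_theta_mu_eps = math.log(theta + mu + eps)
       # log of exp(-pi_logit) * NB(0 | mu, theta)
       pi_theta_log = -pi_logit + theta * (math.log(theta + eps) - log_theta_mu_eps)
       if x < eps:
           # A zero comes either from the zero-inflation component or from the NB itself.
           return _softplus(pi_theta_log) - softplus_pi
       # A positive count comes only from the NB component, weighted by 1 - p_zero.
       return (
           -softplus_pi
           + pi_theta_log
           + x * (math.log(mu + eps) - log_theta_mu_eps)
           + math.lgamma(x + theta)
           - math.lgamma(theta)
           - math.lgamma(x + 1)
       )


   def poisson_loglik(x, mu, eps=1e-8):
       """log Poisson(x | rate mu)."""
       return x * math.log(mu + eps) - mu - math.lgamma(x + 1)


   def normal_loglik(x, mu, var, eps=1e-8):
       """log Normal(x | mean mu, variance var), with eps guarding against var == 0."""
       var = var + eps
       return -0.5 * (math.log(2.0 * math.pi * var) + (x - mu) ** 2 / var)

Treating zero and non-zero counts separately, and working with ``softplus`` of logits instead of logs of probabilities, keeps the zero-inflated likelihood stable when the extra-zero probability is close to 0 or 1.
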