spateo.external.MERFISHVI.scvi_spatial_module
=============================================

.. py:module:: spateo.external.MERFISHVI.scvi_spatial_module


Attributes
----------

.. autoapisummary::

   spateo.external.MERFISHVI.scvi_spatial_module.logger


Classes
-------

.. autoapisummary::

   spateo.external.MERFISHVI.scvi_spatial_module.AnnTorchDataset
   spateo.external.MERFISHVI.scvi_spatial_module.SpatialEncoder
   spateo.external.MERFISHVI.scvi_spatial_module.SpatialVAE


Functions
---------

.. autoapisummary::

   spateo.external.MERFISHVI.scvi_spatial_module.unsupported_if_adata_minified


Module Contents
---------------

.. py:function:: unsupported_if_adata_minified(fn)

.. py:class:: AnnTorchDataset(*args, **kwargs)

.. py:data:: logger
   :value: None


.. py:class:: SpatialEncoder(n_latent: int, n_spatial: int, attention_heads: int = 1, dropout_rate: float = 0.1, var_eps: float = 0.0001)

   Bases: :py:obj:`torch.nn.Module`


   Spatial encoder that uses graph attention networks to process spatial information.

   Applies graph attention network to latent representations to obtain spatial features.

   :param n_latent: Dimension of the latent space
   :param n_spatial: Dimension of the spatial features
   :param attention_heads: Number of attention heads
   :param dropout_rate: Dropout ratio
   :param var_eps: Minimum value for variance to ensure numerical stability


   .. py:attribute:: gat


   .. py:attribute:: mean_encoder


   .. py:attribute:: var_encoder


   .. py:attribute:: n_spatial


   .. py:attribute:: var_eps
      :value: 0.0001



   .. py:method:: forward(z: torch.Tensor, edge_index: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]

      Forward pass, calculate spatial feature distribution.

      :param z: Latent representation, shape [batch_size, n_latent]
      :param edge_index: Graph edge indices, shape [2, num_edges]

      :returns: Mean, variance and sampled value of the spatial feature distribution
      :rtype: tuple



.. py:class:: SpatialVAE(n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128, n_latent: int = 10, n_spatial: int = 10, n_layers: int = 1, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', gene_likelihood: Literal['zinb', 'nb', 'poisson', 'normal'] = 'zinb', latent_distribution: Literal['normal', 'ln'] = 'normal', edge_index: Optional[torch.Tensor] = None, attention_heads: int = 1, spatial_kl_weight: float = 0.01, var_eps: float = 0.0001, **kwargs)

   Bases: :py:obj:`scvi.module.VAE`


   Variational autoencoder with spatial information support.

   Extends standard VAE to include spatial information processing.
   Uses graph attention networks to capture spatial relationships between cells.

   :param n_input: Number of input features
   :param n_batch: Number of batches
   :param n_labels: Number of labels
   :param n_hidden: Number of nodes in hidden layers
   :param n_latent: Dimension of latent space
   :param n_spatial: Dimension of spatial features
   :param n_layers: Number of hidden layers
   :param dropout_rate: Dropout rate
   :param dispersion: Dispersion parameter type
   :param gene_likelihood: Gene likelihood distribution type
   :param latent_distribution: Latent distribution type
   :param \*\*kwargs: Additional parameters


   .. py:attribute:: n_spatial
      :value: 10



   .. py:attribute:: spatial_kl_weight
      :value: 0.01



   .. py:attribute:: spatial_encoder


   .. py:attribute:: edge_index
      :value: None



   .. py:method:: inference(x: torch.Tensor, batch_index: torch.Tensor, cont_covs: torch.Tensor | None = None, cat_covs: torch.Tensor | None = None, cont_covariates: torch.Tensor | None = None, cat_covariates: torch.Tensor | None = None, **kwargs) -> dict[str, torch.Tensor]

      Inference process, computes latent representation and spatial features.

      :param x: Input data
      :param batch_index: Batch indices
      :param cont_covs: Continuous covariates (VAE parameter naming)
      :param cat_covs: Categorical covariates (VAE parameter naming)
      :param cont_covariates: Continuous covariates (compatible format)
      :param cat_covariates: Categorical covariates (compatible format)

      :returns: Dictionary containing latent representation and spatial features
      :rtype: dict



   .. py:method:: forward(tensors, inference_kwargs=None, compute_loss=True, **kwargs)

      Forward pass process.

      :param tensors: Input tensor dictionary
      :param inference_kwargs: Parameters passed to inference function
      :param compute_loss: Whether to compute loss

      :returns: Inference outputs, generative outputs and loss
      :rtype: tuple



   .. py:method:: loss(tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor | torch.distributions.Distribution | None], generative_outputs: dict[str, torch.distributions.Distribution | None], kl_weight: torch.tensor | float = 1.0) -> scvi.module.base.LossOutput

      Calculate loss function, including KL divergence of spatial features.

      :param tensors: Input tensors
      :param inference_outputs: Inference process outputs
      :param generative_outputs: Generative process outputs
      :param kl_weight: KL divergence weight

      :returns: Loss output object
      :rtype: LossOutput



   .. py:method:: get_latent_representation(adata, indices, batch_size)


   .. py:method:: get_spatial_representation(adata=None, indices=None, batch_size=None) -> numpy.ndarray

      Get spatial feature representation.

      :param adata: AnnData object, optional
      :param indices: Indices to get representation for, optional
      :param batch_size: Batch size, optional

      :returns: Spatial feature representation
      :rtype: np.ndarray



   .. py:method:: _get_inference_input(tensors: dict[str, torch.Tensor | None], full_forward_pass: bool = False) -> dict[str, torch.Tensor | None]

      Get tensors needed for inference process, overrides parent method to fix parameter name mismatch issues.

      :param tensors: Input data tensors
      :param full_forward_pass: Whether to perform full forward pass

      :returns: Dictionary of inputs for inference process



   .. py:method:: _get_generative_input(tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor | torch.distributions.Distribution | None]) -> dict[str, torch.Tensor | None]

      Get tensors for generative process, overrides parent method to fix parameter name mismatch issues.

      :param tensors: Original data tensors
      :param inference_outputs: Outputs from inference process

      :returns: Dictionary of inputs needed for generative process



   .. py:method:: setup_spatial_graph(adata: anndata.AnnData, spatial_key: str = 'spatial', batch_key: Optional[str] = None, method: str = 'knn', n_neighbors: int = 10)

      Set up spatial graph for spatial information processing.

      Constructs a spatial graph based on spatial coordinates in adata.obsm[spatial_key],
      using either K-nearest neighbors or Delaunay triangulation.

      :param adata: AnnData object containing spatial coordinates in adata.obsm[spatial_key]
      :param spatial_key: obsm key storing spatial coordinates, default is 'spatial'
      :param batch_key: obs key for batch information, if provided, graph will be constructed per batch
      :param method: Method for constructing the graph, can be 'knn' or 'delaunay'
      :param n_neighbors: Number of neighbors for KNN method



   .. py:method:: process_in_batches(edge_index: torch.Tensor, max_edges_per_batch: int = 100000, combine_results: bool = True, adata: Optional[anndata.AnnData] = None)

      Process edge indices in batches, suitable for processing large graph structures.

      Divides edge_index into smaller batches for processing to avoid memory overflow errors.

      :param edge_index: Edge index tensor, shape [2, num_edges]
      :param max_edges_per_batch: Maximum number of edges per batch
      :param combine_results: Whether to combine results from all batches
      :param adata: AnnData object, if None uses the AnnData object from training

      :returns: If combine_results is True, returns combined result dictionary;
                otherwise returns a list of results for each batch
      :rtype: dict or list



   .. py:method:: process_edges(edge_index: torch.Tensor, adata: Optional[anndata.AnnData] = None)

      Process a single batch of edge indices.

      This is a utility method for processing a single batch of edge indices in batch processing.

      :param edge_index: Edge index tensor, shape [2, num_edges]
      :param adata: AnnData object, if None uses the AnnData object from training

      :returns: Processing results
      :rtype: dict
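
The ``[2, num_edges]`` edge layout and the edge-batching idea described for ``setup_spatial_graph`` and ``process_in_batches`` can be illustrated without the model itself. The following is a minimal pure-Python sketch, not the module's actual implementation: the helpers ``knn_edge_index`` and ``split_edge_index`` are hypothetical names introduced here, the brute-force neighbor search is an assumption, and edge indices are plain nested lists rather than ``torch.Tensor`` objects.

```python
import math


def knn_edge_index(coords, n_neighbors):
    """Build a [2, num_edges] edge list by brute-force k-nearest neighbors.

    Hypothetical helper for illustration only: each cell i gets one
    directed edge (i -> j) to each of its n_neighbors closest cells j.
    """
    sources, targets = [], []
    for i, point in enumerate(coords):
        # Distances from cell i to every other cell, sorted ascending.
        others = sorted(
            (math.dist(point, other), j)
            for j, other in enumerate(coords)
            if j != i
        )
        for _, j in others[:n_neighbors]:
            sources.append(i)
            targets.append(j)
    return [sources, targets]


def split_edge_index(edge_index, max_edges_per_batch):
    """Split a [2, num_edges] edge list into chunks along the edge axis.

    Hypothetical helper: every chunk keeps the two-row layout and holds
    at most max_edges_per_batch edges.
    """
    if max_edges_per_batch < 1:
        raise ValueError("max_edges_per_batch must be positive")
    sources, targets = edge_index
    return [
        [
            sources[start:start + max_edges_per_batch],
            targets[start:start + max_edges_per_batch],
        ]
        for start in range(0, len(sources), max_edges_per_batch)
    ]


# Four cells in 2D; three clustered near the origin and one far away.
coords = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (5.0, 5.0)]
edge_index = knn_edge_index(coords, n_neighbors=2)
batches = split_edge_index(edge_index, max_edges_per_batch=3)

# 4 cells x 2 neighbors = 8 edges, split into chunks of 3, 3 and 2.
print(len(edge_index[0]), [len(batch[0]) for batch in batches])
```

Each chunk keeps the same two-row layout as the full edge list, so in principle it could be converted to a tensor and handed to a per-batch routine such as ``process_edges``; how the real method combines per-batch results is not specified in this reference.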