spateo.external.MERFISHVI.scvi_spatial_module
Attributes
Classes
SpatialEncoder | Spatial encoder that uses graph attention networks to process spatial information.
SpatialVAE | Variational autoencoder with spatial information support.
Functions
Module Contents
- class spateo.external.MERFISHVI.scvi_spatial_module.SpatialEncoder(n_latent: int, n_spatial: int, attention_heads: int = 1, dropout_rate: float = 0.1, var_eps: float = 0.0001)
Bases: torch.nn.Module
Spatial encoder that uses graph attention networks to process spatial information.
Applies a graph attention network to the latent representations to obtain spatial features.
- Parameters:
- n_latent
Dimension of the latent space
- n_spatial
Dimension of the spatial features
- attention_heads
Number of attention heads
- dropout_rate
Dropout rate
- var_eps
Minimum value for variance to ensure numerical stability
- forward(z: torch.Tensor, edge_index: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Forward pass; computes the distribution of the spatial features.
- Parameters:
- z
Latent representation, shape [batch_size, n_latent]
- edge_index
Graph edge indices, shape [2, num_edges]
- Returns:
Mean, variance, and a sample drawn from the spatial feature distribution
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
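The forward pass can be pictured with the NumPy sketch below. It is an illustration, not spateo's PyTorch implementation, and it assumes: a single GAT-style attention head (LeakyReLU edge scores, softmax over each cell's incoming edges), messages flowing from edge_index[0] to edge_index[1], a log-variance head with var = exp(logvar) + var_eps, and hypothetical randomly initialised weights W, a, W_mu and W_logvar. Dropout and multiple attention heads are omitted.

```python
import numpy as np


def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)


def gat_head(z, edge_index, W, a):
    """One GAT-style attention head; cell i aggregates messages along edges j -> i."""
    h = z @ W  # project latents: [n_cells, n_spatial]
    src, dst = edge_index  # edge_index has shape [2, num_edges]
    # Unnormalised attention score for every edge j -> i
    scores = leaky_relu(np.concatenate([h[dst], h[src]], axis=1) @ a)
    out = np.zeros_like(h)  # cells without incoming edges stay at zero
    for node in np.unique(dst):
        mask = dst == node
        w = np.exp(scores[mask] - scores[mask].max())  # numerically stable softmax
        w /= w.sum()  # attention weights over this cell's neighbours
        out[node] = w @ h[src[mask]]
    return out


def spatial_encoder_forward(z, edge_index, params, var_eps=1e-4, rng=None):
    """Returns (mean, variance, sample) of the spatial features."""
    rng = np.random.default_rng() if rng is None else rng
    h = gat_head(z, edge_index, params["W"], params["a"])
    mean = h @ params["W_mu"]
    var = np.exp(h @ params["W_logvar"]) + var_eps  # variance floor of var_eps
    sample = mean + np.sqrt(var) * rng.standard_normal(mean.shape)  # reparameterisation
    return mean, var, sample


# Example: 5 cells on a ring, plus self-loops so each cell also attends to itself
rng = np.random.default_rng(0)
n_cells, n_latent, n_spatial = 5, 4, 3
z = rng.standard_normal((n_cells, n_latent))
edge_index = np.array([[1, 2, 3, 4, 0, 0, 1, 2, 3, 4],
                       [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]])
params = {
    "W": rng.standard_normal((n_latent, n_spatial)),
    "a": rng.standard_normal(2 * n_spatial),
    "W_mu": rng.standard_normal((n_spatial, n_spatial)),
    "W_logvar": 0.1 * rng.standard_normal((n_spatial, n_spatial)),
}
mean, var, sample = spatial_encoder_forward(z, edge_index, params, rng=rng)
```

Because the attention weights of each cell sum to one, a cell whose neighbours all carry identical projected features simply receives that feature vector back.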
- class spateo.external.MERFISHVI.scvi_spatial_module.SpatialVAE(n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128, n_latent: int = 10, n_spatial: int = 10, n_layers: int = 1, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', gene_likelihood: Literal['zinb', 'nb', 'poisson', 'normal'] = 'zinb', latent_distribution: Literal['normal', 'ln'] = 'normal', edge_index: torch.Tensor | None = None, attention_heads: int = 1, spatial_kl_weight: float = 0.01, var_eps: float = 0.0001, **kwargs)
Bases: scvi.module.VAE
Variational autoencoder with spatial information support.
Extends the standard VAE with spatial information processing, using graph attention networks to capture spatial relationships between cells.
- Parameters:
- n_input
Number of input features
- n_batch
Number of batches
- n_labels
Number of labels
- n_hidden
Number of nodes in hidden layers
- n_latent
Dimension of latent space
- n_spatial
Dimension of spatial features
- n_layers
Number of hidden layers
- dropout_rate
Dropout rate
- dispersion
Dispersion parameter type
- gene_likelihood
Gene likelihood distribution type
- latent_distribution
Latent distribution type
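- edge_index
Graph edge indices, shape [2, num_edges], optional
- attention_heads
Number of attention heads
- spatial_kl_weight
Weight of the KL divergence term for the spatial features
- var_eps
Minimum value for variance to ensure numerical stability
- **kwargs
Additional keyword arguments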
- inference(x: torch.Tensor, batch_index: torch.Tensor, cont_covs: torch.Tensor | None = None, cat_covs: torch.Tensor | None = None, cont_covariates: torch.Tensor | None = None, cat_covariates: torch.Tensor | None = None, **kwargs) -> dict[str, torch.Tensor]
Inference step; computes the latent representation and the spatial features.
- Parameters:
- x
Input data
- batch_index
Batch indices
- cont_covs
Continuous covariates (VAE parameter naming)
- cat_covs
Categorical covariates (VAE parameter naming)
- cont_covariates
Continuous covariates (compatible format)
- cat_covariates
Categorical covariates (compatible format)
- Returns:
Dictionary containing the latent representation and the spatial features
- Return type:
dict[str, torch.Tensor]
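The paired covariate arguments suggest that inference accepts either naming convention. A minimal sketch of how such aliases could be reconciled, assuming the VAE-style names take precedence and the "compatible format" names are used only as a fallback; resolve_covariates is a hypothetical helper, not part of spateo.

```python
def resolve_covariates(cont_covs=None, cat_covs=None, cont_covariates=None, cat_covariates=None):
    """Map either naming convention onto the VAE-style (cont_covs, cat_covs) pair."""
    # Prefer the VAE-style names; fall back to the compatible-format aliases
    cont = cont_covs if cont_covs is not None else cont_covariates
    cat = cat_covs if cat_covs is not None else cat_covariates
    return cont, cat


# Example: a caller that only uses the compatible-format names
cont, cat = resolve_covariates(cont_covariates="continuous", cat_covariates="categorical")
```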
- forward(tensors, inference_kwargs=None, compute_loss=True, **kwargs)
Forward pass.
- Parameters:
- tensors
Input tensor dictionary
- inference_kwargs
Keyword arguments passed to the inference method
- compute_loss
Whether to compute the loss
- Returns:
Inference outputs, generative outputs and loss
- Return type:
- loss(tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor | torch.distributions.Distribution | None], generative_outputs: dict[str, torch.distributions.Distribution | None], kl_weight: torch.tensor | float = 1.0) -> scvi.module.base.LossOutput
Computes the loss, including the KL divergence of the spatial features.
- Parameters:
- tensors
Input tensors
- inference_outputs
Inference process outputs
- generative_outputs
Generative process outputs
- kl_weight
Weight of the KL divergence term
- Returns:
Loss output object
- Return type:
LossOutput
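The spatial KL term can be sketched in closed form. Assuming the spatial features have a diagonal Gaussian posterior N(mean, var), as returned by SpatialEncoder.forward, and a standard normal prior, the KL divergence is 0.5 * sum(var + mean^2 - 1 - log var). How that term enters the total loss is not stated here; total_loss below is a hypothetical combination that adds it, scaled by spatial_kl_weight, to a scVI-style reconstruction and latent KL.

```python
import math


def gaussian_kl_to_standard_normal(mean, var):
    """KL( N(mean, diag(var)) || N(0, I) ) for one cell, summed over dimensions."""
    return 0.5 * sum(v + m * m - 1.0 - math.log(v) for m, v in zip(mean, var))


def total_loss(reconstruction_nll, kl_latent, kl_spatial, kl_weight=1.0, spatial_kl_weight=0.01):
    # Hypothetical combination: scVI-style ELBO terms plus a down-weighted spatial KL term
    return reconstruction_nll + kl_weight * kl_latent + spatial_kl_weight * kl_spatial


# Example: a posterior equal to the prior contributes no spatial KL
kl_zero = gaussian_kl_to_standard_normal([0.0, 0.0], [1.0, 1.0])
```

The small default spatial_kl_weight (0.01) keeps the spatial regulariser from dominating the expression-reconstruction objective in this sketch.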
- get_spatial_representation(adata=None, indices=None, batch_size=None) -> numpy.ndarray
Returns the spatial feature representation.
- Parameters:
- adata
AnnData object, optional
- indices
Indices of the cells to compute the representation for, optional
- batch_size
Batch size, optional
- Returns:
Spatial feature representation
- Return type:
np.ndarray
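The indices and batch_size parameters point to minibatched evaluation. A sketch of that pattern, assuming a hypothetical per-batch function encode_batch that maps an array of cell indices to their spatial features; the real method works on an AnnData object and the trained module instead.

```python
import numpy as np


def batched_representation(encode_batch, n_obs, indices=None, batch_size=None):
    """Evaluate encode_batch over cell indices in chunks and stack the results."""
    indices = np.arange(n_obs) if indices is None else np.asarray(indices)
    if len(indices) == 0:
        raise ValueError("indices must not be empty")
    batch_size = len(indices) if batch_size is None else batch_size  # None -> one batch
    chunks = [
        encode_batch(indices[start:start + batch_size])
        for start in range(0, len(indices), batch_size)
    ]
    return np.concatenate(chunks, axis=0)  # [len(indices), n_spatial], in input order


# Example with a stand-in encoder that returns a 2-D "spatial feature" per cell
def fake_encode(idx):
    return np.stack([idx, 2 * idx], axis=1).astype(float)


rep = batched_representation(fake_encode, n_obs=5, batch_size=2)
subset = batched_representation(fake_encode, n_obs=5, indices=[4, 1])
```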
- _get_inference_input(tensors: dict[str, torch.Tensor | None], full_forward_pass: bool = False) -> dict[str, torch.Tensor | None]
Gets the tensors needed for the inference step. Overrides the parent method to fix parameter-name mismatches.
- Parameters:
- tensors
Input data tensors
- full_forward_pass
Whether to perform a full forward pass
- Returns:
Dictionary of inputs for the inference step
- _get_generative_input(tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor | torch.distributions.Distribution | None]) -> dict[str, torch.Tensor | None]
Gets the tensors needed for the generative step. Overrides the parent method to fix parameter-name mismatches.
- Parameters:
- tensors
Original data tensors
- inference_outputs
Outputs of the inference step
- Returns:
Dictionary of inputs needed for the generative step
- setup_spatial_graph(adata: anndata.AnnData, spatial_key: str = 'spatial', batch_key: str | None = None, method: str = 'knn', n_neighbors: int = 10)
Sets up the spatial graph used for spatial information processing.
Constructs a spatial graph from the spatial coordinates in adata.obsm[spatial_key], using either k-nearest neighbors or Delaunay triangulation.
- Parameters:
- adata
AnnData object containing spatial coordinates in adata.obsm[spatial_key]
- spatial_key
obsm key storing spatial coordinates, default is ‘spatial’
- batch_key
obs key for batch information; if provided, a separate graph is constructed for each batch
- method
Method for constructing the graph, can be ‘knn’ or ‘delaunay’
- n_neighbors
Number of neighbors for KNN method
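Both construction methods can be sketched without spateo. The NumPy sketch below assumes Euclidean distances, directed edges from each of a cell's n_neighbors nearest neighbours to that cell (no self-loops), and per-batch construction that keeps global cell indices so no edge crosses batches; knn_edge_index and delaunay_edges are hypothetical helpers. For the Delaunay option the triangles would typically come from scipy.spatial.Delaunay(coords).simplices; the helper below only turns triangles into a symmetric edge list.

```python
import numpy as np


def knn_edge_index(coords, n_neighbors=10, batch_labels=None):
    """Directed kNN graph as edge_index [2, num_edges]: edge j -> i for each neighbour j of cell i."""
    coords = np.asarray(coords, dtype=float)
    labels = np.zeros(len(coords)) if batch_labels is None else np.asarray(batch_labels)
    src, dst = [], []
    for label in np.unique(labels):
        idx = np.flatnonzero(labels == label)  # global indices, so batches never share edges
        pts = coords[idx]
        d = np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=-1)
        np.fill_diagonal(d, np.inf)  # a cell is not its own neighbour
        k = min(n_neighbors, len(idx) - 1)
        nearest = np.argsort(d, axis=1)[:, :k]  # positions within this batch
        src.append(idx[nearest].ravel())
        dst.append(np.repeat(idx, k))
    return np.stack([np.concatenate(src), np.concatenate(dst)])


def delaunay_edges(simplices):
    """Symmetric edge_index from triangles, e.g. scipy.spatial.Delaunay(coords).simplices."""
    pairs = set()
    for a, b, c in np.asarray(simplices):
        for u, v in ((a, b), (b, c), (a, c)):
            pairs.update({(int(u), int(v)), (int(v), int(u))})
    if not pairs:
        return np.empty((2, 0), dtype=int)
    src, dst = zip(*sorted(pairs))
    return np.array([src, dst])


# Example: four cells on a line, first as one section, then as two batches
coords = [[0, 0], [1, 0], [3, 0], [10, 0]]
single = knn_edge_index(coords, n_neighbors=1)
batched = knn_edge_index(coords, n_neighbors=5, batch_labels=["a", "a", "b", "b"])
```

In the batched call, n_neighbors is capped at the batch size minus one, so each two-cell batch yields one edge in each direction.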
- process_in_batches(edge_index: torch.Tensor, max_edges_per_batch: int = 100000, combine_results: bool = True, adata: anndata.AnnData | None = None)
Processes edge indices in batches, for use with large graphs.
Splits edge_index into smaller batches to avoid out-of-memory errors.
- Parameters:
- edge_index
Edge index tensor, shape [2, num_edges]
- max_edges_per_batch
Maximum number of edges per batch
- combine_results
Whether to combine results from all batches
- adata
AnnData object; if None, the AnnData object used for training is used
- Returns:
If combine_results is True, a combined result dictionary; otherwise, a list with one result per batch
- Return type:
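Batching an edge list amounts to slicing edge_index along its second axis. A minimal sketch, assuming consecutive slices of at most max_edges_per_batch edges, a hypothetical per-chunk callback process_fn standing in for process_edges, and per-chunk results that are dictionaries of arrays combined by concatenation (all assumptions; the actual combination rule is not documented here).

```python
import numpy as np


def process_in_chunks(edge_index, process_fn, max_edges_per_batch=100_000, combine_results=True):
    """Apply process_fn to consecutive slices of at most max_edges_per_batch edges."""
    edge_index = np.asarray(edge_index)
    num_edges = edge_index.shape[1]
    results = [
        process_fn(edge_index[:, start:start + max_edges_per_batch])
        for start in range(0, num_edges, max_edges_per_batch)
    ]
    if not combine_results:
        return results  # one result per chunk
    if not results:
        return {}
    # Combine per-chunk dictionaries by concatenating each entry along axis 0
    return {key: np.concatenate([r[key] for r in results]) for key in results[0]}


def count_edges(chunk):
    # Stand-in for process_edges: report how many edges the chunk holds
    return {"n_edges": np.array([chunk.shape[1]])}


edge_index = np.arange(10).reshape(2, 5)  # 5 edges
combined = process_in_chunks(edge_index, count_edges, max_edges_per_batch=2)
per_chunk = process_in_chunks(edge_index, count_edges, max_edges_per_batch=2, combine_results=False)
```

Plain concatenation suits edge-level outputs; node-level outputs computed per chunk would instead need their partial results merged, because a node's edges can fall into several chunks.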
- process_edges(edge_index: torch.Tensor, adata: anndata.AnnData | None = None)
Processes a single batch of edge indices.
Helper method for the batched processing of edge indices.
- Parameters:
- edge_index
Edge index tensor, shape [2, num_edges]
- adata
AnnData object; if None, the AnnData object used for training is used
- Returns:
Processing results
- Return type: