spateo.external.MERFISHVI.scvi_spatial_module

Attributes

logger

Classes

AnnTorchDataset

SpatialEncoder

Spatial encoder that uses graph attention networks to process spatial information.

SpatialVAE

Variational autoencoder with spatial information support.

Functions

unsupported_if_adata_minified(fn)

Module Contents

spateo.external.MERFISHVI.scvi_spatial_module.unsupported_if_adata_minified(fn)

class spateo.external.MERFISHVI.scvi_spatial_module.AnnTorchDataset(*args, **kwargs)

spateo.external.MERFISHVI.scvi_spatial_module.logger = None

class spateo.external.MERFISHVI.scvi_spatial_module.SpatialEncoder(n_latent: int, n_spatial: int, attention_heads: int = 1, dropout_rate: float = 0.1, var_eps: float = 0.0001)

Bases: torch.nn.Module

Spatial encoder that uses graph attention networks to process spatial information.

Applies a graph attention network to the latent representations to obtain spatial features.

Parameters:
n_latent

Dimension of the latent space

n_spatial

Dimension of the spatial features

attention_heads

Number of attention heads

dropout_rate

Dropout rate

var_eps

Minimum value for variance to ensure numerical stability
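The attention step can be pictured with a minimal single-head GAT layer written in NumPy. This is a sketch in the style of the original graph attention formulation, not spateo's implementation: the function name gat_layer, the weight shapes and the LeakyReLU slope are assumptions, and the real gat submodule may differ (multiple heads, dropout, bias). edge_index follows the [2, num_edges] layout documented for forward; reading row 0 as source and row 1 as target is an assumption.

```python
import numpy as np


def gat_layer(h, edge_index, W, a, negative_slope=0.2):
    """Hypothetical single-head graph attention layer using dense NumPy arrays.

    h          : [num_nodes, in_dim] node features (for example the latent z)
    edge_index : [2, num_edges] integer array; row 0 = source j, row 1 = target i
    W          : [in_dim, out_dim] shared linear projection
    a          : [2 * out_dim] attention vector
    """
    wh = h @ W                                    # project every node
    src, dst = edge_index
    out_dim = wh.shape[1]
    # Unnormalised score e_ij = LeakyReLU(a^T [W h_i || W h_j]) for edge j -> i
    scores = wh[dst] @ a[:out_dim] + wh[src] @ a[out_dim:]
    scores = np.where(scores > 0, scores, negative_slope * scores)
    # Softmax over the incoming edges of each target node; subtracting each
    # node's maximum score first keeps exp() numerically stable
    node_max = np.full(h.shape[0], -np.inf)
    np.maximum.at(node_max, dst, scores)
    weights = np.exp(scores - node_max[dst])
    denom = np.zeros(h.shape[0])
    np.add.at(denom, dst, weights)
    alpha = weights / denom[dst]
    # Aggregate neighbour messages weighted by the attention coefficients
    out = np.zeros((h.shape[0], out_dim))
    np.add.at(out, dst, alpha[:, None] * wh[src])
    return out, alpha


rng = np.random.default_rng(0)
z = rng.normal(size=(4, 3))                       # 4 cells, n_latent = 3
# A chain 0-1-2-3 in both directions plus self-loops
edge_index = np.array([[0, 1, 1, 2, 2, 3, 0, 1, 2, 3],
                       [1, 0, 2, 1, 3, 2, 0, 1, 2, 3]])
W = rng.normal(size=(3, 2))                       # project to 2 output features
a = rng.normal(size=4)
out, alpha = gat_layer(z, edge_index, W, a)
```

The self-loops let each cell attend to its own latent vector as well as to its neighbours; the attention weights of the edges entering any one cell sum to 1.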

gat
mean_encoder
var_encoder
n_spatial
var_eps = 0.0001

forward(z: torch.Tensor, edge_index: torch.Tensor) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Forward pass; computes the distribution of the spatial features.

Parameters:
z

Latent representation, shape [batch_size, n_latent]

edge_index

Graph edge indices, shape [2, num_edges]

Returns:

Mean, variance, and a sample drawn from the spatial feature distribution

Return type:

tuple
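The mean_encoder, var_encoder and var_eps attributes suggest the usual scvi-tools encoder pattern: a linear mean head, a variance head passed through exp with var_eps added as a floor, and a reparameterised sample. The NumPy sketch below assumes that pattern; the function name spatial_head, the purely linear heads and the exp link are assumptions not confirmed by this page.

```python
import numpy as np


def spatial_head(h, W_mu, b_mu, W_var, b_var, var_eps=1e-4, rng=None):
    """Hypothetical mean/variance heads plus a reparameterised sample.

    Returns (mean, var, sample), the same order as SpatialEncoder.forward.
    """
    rng = np.random.default_rng() if rng is None else rng
    mean = h @ W_mu + b_mu
    # exp keeps the variance positive; var_eps adds a floor for numerical stability
    var = np.exp(h @ W_var + b_var) + var_eps
    # Reparameterisation trick: sample = mean + sqrt(var) * eps with eps ~ N(0, I)
    eps = rng.standard_normal(mean.shape)
    sample = mean + np.sqrt(var) * eps
    return mean, var, sample


rng = np.random.default_rng(0)
h = rng.normal(size=(5, 3))                       # e.g. GAT output for 5 cells
n_spatial = 2
mean, var, sample = spatial_head(
    h,
    rng.normal(size=(3, n_spatial)), np.zeros(n_spatial),
    rng.normal(size=(3, n_spatial)), np.zeros(n_spatial),
    var_eps=1e-4, rng=rng,
)
```

Sampling through mean + sqrt(var) * eps keeps the sample differentiable with respect to the mean and variance, which is what allows the encoder to be trained by gradient descent.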

class spateo.external.MERFISHVI.scvi_spatial_module.SpatialVAE(n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128, n_latent: int = 10, n_spatial: int = 10, n_layers: int = 1, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', gene_likelihood: Literal['zinb', 'nb', 'poisson', 'normal'] = 'zinb', latent_distribution: Literal['normal', 'ln'] = 'normal', edge_index: torch.Tensor | None = None, attention_heads: int = 1, spatial_kl_weight: float = 0.01, var_eps: float = 0.0001, **kwargs)

Bases: scvi.module.VAE

Variational autoencoder with spatial information support.

Extends the standard scvi-tools VAE with spatial information processing, using graph attention networks to capture spatial relationships between cells.

Parameters:
n_input

Number of input features

n_batch

Number of batches

n_labels

Number of labels

n_hidden

Number of nodes per hidden layer

n_latent

Dimension of latent space

n_spatial

Dimension of spatial features

n_layers

Number of hidden layers

dropout_rate

Dropout rate

dispersion

Dispersion parameter type

gene_likelihood

Gene likelihood distribution type

latent_distribution

Latent distribution type

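edge_index

Graph edge indices, shape [2, num_edges]; optional

attention_heads

Number of attention heads in the spatial encoder

spatial_kl_weight

Weight of the KL divergence term for the spatial features

var_eps

Minimum value for variance to ensure numerical stability

**kwargs

Additional keyword arguments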

n_spatial = 10
spatial_kl_weight = 0.01
spatial_encoder
edge_index = None

inference(x: torch.Tensor, batch_index: torch.Tensor, cont_covs: torch.Tensor | None = None, cat_covs: torch.Tensor | None = None, cont_covariates: torch.Tensor | None = None, cat_covariates: torch.Tensor | None = None, **kwargs) → dict[str, torch.Tensor]

Inference step: computes the latent representation and the spatial features.

Parameters:
x

Input data

batch_index

Batch indices

cont_covs

Continuous covariates (naming used by the scvi-tools VAE)

cat_covs

Categorical covariates (naming used by the scvi-tools VAE)

cont_covariates

Continuous covariates (alternative naming, accepted for compatibility)

cat_covariates

Categorical covariates (alternative naming, accepted for compatibility)

Returns:

Dictionary containing the latent representation and the spatial features

Return type:

dict
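inference accepts covariates under two names. One plausible way to reconcile them is to prefer the VAE-style names and fall back to the alternative ones; the helper resolve_covariates below is hypothetical and its precedence rule is an assumption, not documented behaviour.

```python
from typing import Any, Optional


def resolve_covariates(
    cont_covs: Optional[Any] = None,
    cat_covs: Optional[Any] = None,
    cont_covariates: Optional[Any] = None,
    cat_covariates: Optional[Any] = None,
) -> dict:
    """Hypothetical helper: map either naming convention onto the VAE names."""
    # Prefer the VAE-style argument; fall back to the alternative name if it is None
    return {
        "cont_covs": cont_covs if cont_covs is not None else cont_covariates,
        "cat_covs": cat_covs if cat_covs is not None else cat_covariates,
    }


print(resolve_covariates(cont_covariates=[0.5], cat_covs=[1]))
# {'cont_covs': [0.5], 'cat_covs': [1]}
```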

forward(tensors, inference_kwargs=None, compute_loss=True, **kwargs)

Forward pass.

Parameters:
tensors

Dictionary of input tensors

inference_kwargs

Keyword arguments passed to inference

compute_loss

Whether to compute loss

Returns:

Inference outputs, generative outputs and loss

Return type:

tuple

loss(tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor | torch.distributions.Distribution | None], generative_outputs: dict[str, torch.distributions.Distribution | None], kl_weight: torch.tensor | float = 1.0) → scvi.module.base.LossOutput

Computes the loss, including the KL divergence of the spatial features.

Parameters:
tensors

Input tensors

inference_outputs

Inference process outputs

generative_outputs

Generative process outputs

kl_weight

KL divergence weight

Returns:

Loss output object

Return type:

LossOutput
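The loss adds a KL term for the spatial features to the usual VAE objective. A minimal sketch, assuming diagonal Gaussian posteriors, a standard normal prior for both the latent and the spatial features, and a spatial term scaled by spatial_kl_weight only (whether kl_weight also multiplies it is not documented here). The function names are hypothetical.

```python
import numpy as np


def kl_to_standard_normal(mean, var):
    """Closed-form KL(N(mean, var) || N(0, I)), summed over the feature axis."""
    return 0.5 * np.sum(var + mean**2 - 1.0 - np.log(var), axis=-1)


def combined_loss(recon_loss, mean_z, var_z, mean_s, var_s,
                  kl_weight=1.0, spatial_kl_weight=0.01):
    """Hypothetical per-batch objective: reconstruction plus weighted KL terms."""
    kl_z = kl_to_standard_normal(mean_z, var_z)   # latent KL, one value per cell
    kl_s = kl_to_standard_normal(mean_s, var_s)   # spatial KL, one value per cell
    per_cell = recon_loss + kl_weight * kl_z + spatial_kl_weight * kl_s
    return per_cell.mean()


recon = np.array([2.0, 4.0])                      # per-cell reconstruction loss
mu, sigma2 = np.zeros((2, 3)), np.ones((2, 3))    # posterior equal to the prior
print(combined_loss(recon, mu, sigma2, mu, sigma2))  # KL terms vanish -> 3.0
```

The small default spatial_kl_weight (0.01) means the spatial KL acts as a light regulariser next to the reconstruction and latent KL terms.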

get_latent_representation(adata, indices, batch_size)

get_spatial_representation(adata=None, indices=None, batch_size=None) → numpy.ndarray

Get spatial feature representation.

Parameters:
adata

AnnData object, optional

indices

Indices to get representation for, optional

batch_size

Batch size, optional

Returns:

Spatial feature representation

Return type:

np.ndarray

_get_inference_input(tensors: dict[str, torch.Tensor | None], full_forward_pass: bool = False) → dict[str, torch.Tensor | None]

Gets the tensors needed for the inference step. Overrides the parent method to resolve parameter-name mismatches.

Parameters:
tensors

Input data tensors

full_forward_pass

Whether to perform a full forward pass

Returns:

Dictionary of inputs for the inference process

_get_generative_input(tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor | torch.distributions.Distribution | None]) → dict[str, torch.Tensor | None]

Gets the tensors needed for the generative step. Overrides the parent method to resolve parameter-name mismatches.

Parameters:
tensors

Original data tensors

inference_outputs

Outputs from inference process

Returns:

Dictionary of inputs needed for the generative process

setup_spatial_graph(adata: anndata.AnnData, spatial_key: str = 'spatial', batch_key: str | None = None, method: str = 'knn', n_neighbors: int = 10)

Sets up the spatial graph used for spatial information processing.

Constructs a spatial graph from the coordinates in adata.obsm[spatial_key], using either k-nearest neighbors (KNN) or Delaunay triangulation.

Parameters:
adata

AnnData object containing spatial coordinates in adata.obsm[spatial_key]

spatial_key

Key in adata.obsm that stores the spatial coordinates; default is 'spatial'

batch_key

Key in adata.obs holding batch information; if provided, a separate graph is constructed for each batch

method

Method used to construct the graph: 'knn' or 'delaunay'

n_neighbors

Number of neighbors for the KNN method
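For the 'knn' method, a brute-force NumPy sketch of building a [2, num_edges] edge index, optionally restricted to cells from the same batch. Assumptions: Euclidean distance, edges directed from neighbour to cell, no self-loops and no symmetrisation; the function name knn_edge_index is hypothetical. The 'delaunay' method is not shown; scipy.spatial.Delaunay is one common way to obtain the triangulation.

```python
import numpy as np


def knn_edge_index(coords, n_neighbors=10, batch=None):
    """Hypothetical brute-force KNN graph as a [2, num_edges] array.

    Row 0 holds source (neighbour) indices, row 1 the target cell indices.
    If `batch` is given, neighbours are searched only within the same batch.
    """
    coords = np.asarray(coords, dtype=float)
    n = coords.shape[0]
    batch = np.zeros(n, dtype=int) if batch is None else np.asarray(batch)
    sources, targets = [], []
    for b in np.unique(batch):
        idx = np.flatnonzero(batch == b)          # global indices of this batch
        k = min(n_neighbors, len(idx) - 1)        # at most batch size - 1 neighbours
        if k <= 0:
            continue
        diff = coords[idx, None, :] - coords[None, idx, :]
        dist = np.sqrt((diff ** 2).sum(-1))       # pairwise Euclidean distances
        np.fill_diagonal(dist, np.inf)            # exclude self-matches
        nn = np.argsort(dist, axis=1)[:, :k]      # k nearest per cell
        for row, cols in enumerate(nn):
            sources.extend(idx[cols])             # neighbour -> cell
            targets.extend([idx[row]] * k)
    return np.array([sources, targets], dtype=np.int64)


xy = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [10.0, 10.0], [11.0, 10.0]])
ei = knn_edge_index(xy, n_neighbors=1, batch=[0, 0, 0, 1, 1])
print(ei)
```

The quadratic distance matrix keeps the sketch short; for real tissue sections a tree-based search (for example scikit-learn's NearestNeighbors) scales far better.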

process_in_batches(edge_index: torch.Tensor, max_edges_per_batch: int = 100000, combine_results: bool = True, adata: anndata.AnnData | None = None)

Processes edge indices in batches; suited to large graph structures.

Splits edge_index into smaller batches to avoid out-of-memory errors.

Parameters:
edge_index

Edge index tensor, shape [2, num_edges]

max_edges_per_batch

Maximum number of edges per batch

combine_results

Whether to combine results from all batches

adata

AnnData object; if None, the AnnData object used for training is used

Returns:

If combine_results is True, a combined result dictionary; otherwise, a list with one result per batch

Return type:

dict or list
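process_in_batches splits the columns of edge_index into chunks of at most max_edges_per_batch edges. The sketch below shows the chunking and one possible merge strategy; process_fn stands in for process_edges, the helper names are hypothetical, and concatenating per-key arrays is an assumption about how results are combined.

```python
import numpy as np


def split_edges(edge_index, max_edges_per_batch=100_000):
    """Yield consecutive column slices of a [2, num_edges] array."""
    num_edges = edge_index.shape[1]
    for start in range(0, num_edges, max_edges_per_batch):
        yield edge_index[:, start:start + max_edges_per_batch]


def process_in_batches_sketch(edge_index, process_fn, max_edges_per_batch=100_000,
                              combine_results=True):
    """Hypothetical driver: run process_fn on each chunk, optionally merging results."""
    results = [process_fn(chunk) for chunk in split_edges(edge_index, max_edges_per_batch)]
    if not combine_results:
        return results                            # one result dict per chunk
    if not results:
        return {}
    # Assumed merge strategy: concatenate each key's arrays along the first axis
    return {key: np.concatenate([r[key] for r in results]) for key in results[0]}


def edge_score(chunk):
    """Toy per-edge statistic: the sum of the two endpoint indices."""
    return {"score": chunk.sum(axis=0)}


edge_index = np.arange(10).reshape(2, 5)          # 5 toy edges
combined = process_in_batches_sketch(edge_index, edge_score, max_edges_per_batch=2)
per_chunk = process_in_batches_sketch(edge_index, edge_score, max_edges_per_batch=2,
                                      combine_results=False)
```

Only one chunk's worth of intermediate results needs to be alive inside process_fn at a time, which is what bounds peak memory on large graphs.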

process_edges(edge_index: torch.Tensor, adata: anndata.AnnData | None = None)

Processes a single batch of edge indices.

Utility method that handles one batch of edge indices during batch processing.

Parameters:
edge_index

Edge index tensor, shape [2, num_edges]

adata

AnnData object; if None, the AnnData object used for training is used

Returns:

Processing results

Return type:

dict