from __future__ import annotations
import logging
import warnings
from typing import TYPE_CHECKING, Dict, Optional
import numpy as np
import torch
from scipy.sparse import coo_matrix
# Import libraries needed for spatial graph processing
from scipy.spatial import Delaunay
from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data._constants import ADATA_MINIFY_TYPE
from scvi.data._utils import _get_adata_minify_type
from scvi.data.fields import (
CategoricalJointObsField,
CategoricalObsField,
LayerField,
NumericalJointObsField,
NumericalObsField,
)
from scvi.model._utils import _init_library_size
from scvi.model.base import (
ArchesMixin,
BaseMinifiedModeModelClass,
EmbeddingMixin,
RNASeqMixin,
UnsupervisedTrainingMixin,
VAEMixin,
)
from scvi.module import VAE
from scvi.utils import setup_anndata_dsp
from sklearn.neighbors import kneighbors_graph
from torch_geometric.utils import from_scipy_sparse_matrix
if TYPE_CHECKING:
from typing import Literal
from anndata import AnnData
[docs]logger = logging.getLogger(__name__)
[docs]class SpatialVI(
EmbeddingMixin,
RNASeqMixin,
VAEMixin,
ArchesMixin,
UnsupervisedTrainingMixin,
BaseMinifiedModeModelClass,
):
"""Single-cell Variational Inference model.
This model uses a variational autoencoder (VAE) to learn low-dimensional representations of single-cell data.
It can be used for batch effect correction, differential expression analysis, data imputation, and various other downstream tasks.
Parameters
----------
adata
AnnData object containing single-cell data, which must be registered through the setup_anndata method.
n_hidden
Number of nodes in hidden layers.
n_latent
Dimension of the latent space (dimension of low-dimensional representation).
n_spatial
Dimension of spatial features. Effective when use_spatial=True.
n_layers
Number of hidden layers in encoder and decoder networks.
dropout_rate
Ratio of neurons randomly dropped during training.
dispersion
How to model variance parameters:
* 'gene' - one parameter per gene
* 'gene-batch' - different parameters for each gene in each batch
* 'gene-label' - different parameters for each gene in each label group
* 'gene-cell' - different parameters for each gene in each cell
gene_likelihood
Distribution used to model gene expression:
* 'nb' - negative binomial distribution (for count data with overdispersion)
* 'zinb' - zero-inflated negative binomial distribution (for sparse count data)
* 'poisson' - Poisson distribution (for count data)
* 'normal' - normal distribution (experimental)
use_observed_lib_size
Whether to use observed library size as scaling factor.
latent_distribution
Type of latent distribution:
* 'normal' - standard normal distribution
* 'ln' - logistic normal distribution
use_spatial
Whether to use spatial information. If True, spatial coordinates will be read from adata.obsm['spatial'].
spatial_graph_type
Method for spatial graph construction:
* 'knn' - K-nearest neighbors graph
* 'delaunay' - Delaunay triangulation
n_neighbors
Number of neighbors for KNN graph when spatial_graph_type='knn'.
attention_heads
Number of attention heads in graph attention network.
spatial_kl_weight
Weight for KL divergence of spatial latent variables.
**kwargs
Additional parameters passed to the VAE module.
Examples
--------
>>> import anndata
>>> import scvi
>>> adata = anndata.read_h5ad("my_data.h5ad")
>>> scvi.model.SCVI.setup_anndata(adata, batch_key="batch")
>>> model = scvi.model.SCVI(adata, use_spatial=True)
>>> model.train()
>>> adata.obsm["X_scVI"] = model.get_latent_representation()
"""
# Key names for latent variable mean and variance
[docs] latent_mean_key = "scvi_latent_qzm"
[docs] latent_var_key = "scvi_latent_qzv"
def __init__(
self,
adata: AnnData | None = None,
n_hidden: int = 128,
n_latent: int = 10,
n_spatial: int = 10,
n_layers: int = 1,
dropout_rate: float = 0.1,
dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
gene_likelihood: Literal["zinb", "nb", "poisson", "normal"] = "zinb",
use_observed_lib_size: bool = True,
latent_distribution: Literal["normal", "ln"] = "normal",
use_spatial: bool = False,
spatial_graph_type: Literal["knn", "delaunay"] = "knn",
n_neighbors: int = 10,
attention_heads: int = 1,
spatial_kl_weight: float = 0.01,
**kwargs,
):
# Initialize parent class
super().__init__(adata)
# Store model configuration parameters
[docs] self._module_kwargs = {
"n_hidden": n_hidden,
"n_latent": n_latent,
"n_spatial": n_spatial,
"n_layers": n_layers,
"dropout_rate": dropout_rate,
"dispersion": dispersion,
"gene_likelihood": gene_likelihood,
"latent_distribution": latent_distribution,
**kwargs,
}
# Create model summary string
[docs] self._model_summary_string = (
"SpatialVI model parameter summary: \n"
f"Hidden layer nodes: {n_hidden}, Latent dimensions: {n_latent}, Spatial dimensions: {n_spatial}, "
f"Layers: {n_layers}, Dropout rate: {dropout_rate}, "
f"Dispersion type: {dispersion}, Gene likelihood: {gene_likelihood}, "
f"Use spatial info: {use_spatial}, Spatial KL weight: {spatial_kl_weight}"
)
# Spatial related parameters
[docs] self.use_spatial = use_spatial
[docs] self.spatial_graph_type = spatial_graph_type
[docs] self.n_neighbors = n_neighbors
[docs] self.attention_heads = attention_heads
[docs] self.spatial_kl_weight = spatial_kl_weight
[docs] self.edge_index = None # Edge indices for spatial graph
# If using spatial information, set up spatial graph
if use_spatial and adata is not None:
if "spatial" not in adata.obsm:
warnings.warn(
"Using spatial information but no spatial coordinates found in adata.obsm['spatial']. "
"Please provide coordinates in adata.obsm['spatial'].",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)
else:
self.setup_spatial_graph(adata)
# Handle case where adata is not provided during initialization
if self._module_init_on_train:
self.module = None
warnings.warn(
"No adata provided during model initialization. Module will be initialized when train() is called. "
"This is an experimental feature and may change in the future.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)
else:
# Get categorical covariate information (if available)
cat_covs_per_group = (
self.adata_manager.get_state_registry("categorical_covariates").n_cats_per_key
if "categorical_covariates" in self.adata_manager.data_registry
else None
)
# Get number of batches from data
num_batches = self.summary_stats.n_batch
# Check if size factor is provided in data
use_size_factor = "size_factor" in self.adata_manager.data_registry
# Initialize library size parameters if needed
lib_mean, lib_var = None, None
if (
not use_size_factor
and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
and not use_observed_lib_size
):
lib_mean, lib_var = _init_library_size(self.adata_manager, num_batches)
# CreateVAE module - add spatial related parameters
from .scvi_spatial_module import SpatialVAE # Import new spatial VAE module
if use_spatial:
# Use VAE module with spatial information
self.module = SpatialVAE(
n_input=self.summary_stats.n_vars,
n_batch=num_batches,
n_labels=self.summary_stats.n_labels,
n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0),
n_cats_per_cov=cat_covs_per_group,
n_hidden=n_hidden,
n_latent=n_latent,
n_spatial=n_spatial,
n_layers=n_layers,
dropout_rate=dropout_rate,
dispersion=dispersion,
gene_likelihood=gene_likelihood,
use_observed_lib_size=use_observed_lib_size,
latent_distribution=latent_distribution,
use_size_factor_key=use_size_factor,
library_log_means=lib_mean,
library_log_vars=lib_var,
edge_index=self.edge_index,
attention_heads=attention_heads,
spatial_kl_weight=spatial_kl_weight,
**kwargs,
)
else:
# Use standard VAE module
self.module = self._module_cls(
n_input=self.summary_stats.n_vars,
n_batch=num_batches,
n_labels=self.summary_stats.n_labels,
n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0),
n_cats_per_cov=cat_covs_per_group,
n_hidden=n_hidden,
n_latent=n_latent,
n_layers=n_layers,
dropout_rate=dropout_rate,
dispersion=dispersion,
gene_likelihood=gene_likelihood,
use_observed_lib_size=use_observed_lib_size,
latent_distribution=latent_distribution,
use_size_factor_key=use_size_factor,
library_log_means=lib_mean,
library_log_vars=lib_var,
**kwargs,
)
self.module.minified_data_type = self.minified_data_type
# Store initialization parameters
[docs] self.init_params_ = self._get_init_params(locals())
@classmethod
@setup_anndata_dsp.dedent
[docs] def setup_anndata(
cls,
adata: AnnData,
layer: str | None = None,
batch_key: str | None = None,
labels_key: str | None = None,
size_factor_key: str | None = None,
categorical_covariate_keys: list[str] | None = None,
continuous_covariate_keys: list[str] | None = None,
**kwargs,
):
"""Set up AnnData object for SpatialVI model training.
This method prepares for training by registering necessary data fields.
Parameters
----------
adata
AnnData object containing single-cell data
layer
If provided, use this layer's expression values instead of adata.X
batch_key
Key in adata.obs representing batch information. If None, assumes all cells are from same batch.
labels_key
Key in adata.obs representing cell type or other label information
size_factor_key
Key in adata.obs representing size factor information. Used for normalization.
categorical_covariate_keys
List of keys in adata.obs representing categorical covariates (like experimental conditions)
continuous_covariate_keys
List of keys in adata.obs representing continuous covariates (like quality control metrics)
"""
# Store setup parameters
setup_args = cls._get_setup_method_args(**locals())
# Define all data fields needed by the model
data_fields = [
# Main expression data (X)
LayerField("X", layer, is_count_data=True),
# Batch information, used for handling batch effects
CategoricalObsField("batch", batch_key),
# Cell type or other label information
CategoricalObsField("labels", labels_key),
# Size factor for normalization (optional)
NumericalObsField("size_factor", size_factor_key, required=False),
# Additional categorical and continuous covariates
CategoricalJointObsField("categorical_covariates", categorical_covariate_keys),
NumericalJointObsField("continuous_covariates", continuous_covariate_keys),
]
# Check if data is in minified format and add appropriate fields
data_type = _get_adata_minify_type(adata)
if data_type is not None:
# Add required fields for minified data format
data_fields += cls._get_fields_for_adata_minification(data_type)
# Create manager for data fields
adata_manager = AnnDataManager(fields=data_fields, setup_method_args=setup_args)
# Register all fields to AnnData object
adata_manager.register_fields(adata, **kwargs)
# Register manager to model class
cls.register_manager(adata_manager)
[docs] def setup_spatial_graph(self, adata: AnnData):
"""Set up spatial graph for spatial information processing.
Build spatial graph based on spatial coordinates in adata.obsm['spatial'],
choosing between K-nearest neighbors graph or Delaunay triangulation.
Parameters
----------
adata
AnnData object containing spatial coordinates in adata.obsm['spatial']
"""
try:
# Get spatial coordinates
if "spatial" not in adata.obsm:
raise ValueError("Spatial coordinates not found in adata.obsm['spatial']")
coordinates = adata.obsm["spatial"]
logger.info(
f"Retrieved spatial coordinates for {len(coordinates)} cells, dimension: {coordinates.shape[1]}"
)
# Check validity of coordinate data
if np.isnan(coordinates).any():
# Handle NaN values
logger.warning("Spatial coordinates contain NaN values, will use mean filling")
coordinates = np.nan_to_num(coordinates, nan=np.nanmean(coordinates, axis=0))
# Build spatial graph based on chosen method
if self.spatial_graph_type == "knn":
# Build K-nearest neighbors graph
logger.info(f"Building spatial graph using KNN method, neighbors: {self.n_neighbors}")
adj_matrix = kneighbors_graph(
coordinates,
n_neighbors=min(self.n_neighbors, len(coordinates) - 1),
mode="connectivity",
include_self=False,
n_jobs=-1, # Use all available CPU cores to speed up computation
)
# Ensure graph is symmetric
adj_matrix = adj_matrix + adj_matrix.T
adj_matrix.data = np.ones(adj_matrix.data.shape)
# Remove duplicate edges
adj_matrix = adj_matrix.tocsr()
adj_matrix.data = np.ones_like(adj_matrix.data)
elif self.spatial_graph_type == "delaunay":
# Build Delaunay triangulation
logger.info("Building spatial graph using Delaunay triangulation")
# Ensure coordinates are 2D or 3D, Delaunay needs at least 2D
if coordinates.shape[1] == 1:
# If only 1D coordinates, add a fake second dimension
logger.warning(
"Coordinates are only 1D, adding random small noise as second dimension for Delaunay triangulation"
)
coordinates_2d = np.column_stack(
[coordinates, np.random.normal(0, 0.01, size=(len(coordinates), 1))]
)
else:
coordinates_2d = coordinates
# Perform triangulation
tri = Delaunay(coordinates_2d)
# Get adjacency from triangulation
edges = set()
for simplex in tri.simplices:
for i in range(len(simplex)):
for j in range(i + 1, len(simplex)):
# Add undirected edge (both directions)
edges.add((simplex[i], simplex[j]))
edges.add((simplex[j], simplex[i]))
# Create sparse adjacency matrix
if edges:
rows, cols = zip(*edges)
data = np.ones(len(rows))
adj_matrix = coo_matrix((data, (rows, cols)), shape=(adata.n_obs, adata.n_obs)).tocsr()
else:
logger.warning("Delaunay triangulation did not generate any edges, will use fully connected graph")
# Create a sparse graph with a few random connections
adj_matrix = kneighbors_graph(
coordinates, n_neighbors=min(10, len(coordinates) - 1), mode="connectivity"
)
else:
raise ValueError(f"Unknown spatial graph type: {self.spatial_graph_type}")
# Check graph connectivity
n_components = 1 # Assume connected
if adj_matrix.nnz == 0:
logger.warning("Generated graph has no edges, will use fully connected graph")
# Create fully connected graph
rows, cols = np.where(np.ones((adata.n_obs, adata.n_obs)) - np.eye(adata.n_obs))
data = np.ones(len(rows))
adj_matrix = coo_matrix((data, (rows, cols)), shape=(adata.n_obs, adata.n_obs))
# Convert to PyTorch Geometric edge index format
edge_index, _ = from_scipy_sparse_matrix(adj_matrix)
# Move to appropriate device
if hasattr(self, "device"):
device = self.device
elif hasattr(self, "module") and hasattr(self.module, "device"):
device = self.module.device
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.edge_index = edge_index.to(device)
logger.info(
f"Successfully built spatial graph, type: {self.spatial_graph_type}, "
f"nodes: {adata.n_obs}, edges: {edge_index.shape[1] // 2}, "
f"device: {device}"
)
except Exception as e:
# Catch any errors, provide detailed error information
import traceback
logger.error(f"Error building spatial graph: {str(e)}\n{traceback.format_exc()}")
warnings.warn(
f"Error building spatial graph: {str(e)}. Will not use spatial information.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)
self.edge_index = None
self.use_spatial = False
[docs]class MERFISHVI(SpatialVI):
"""MERFISH spatial multimodal variational inference model.
This model extends SCVI to simultaneously process MERFISH data with spatial information and
other modality data (such as scRNA-seq) without spatial information. It uses a shared
latent space to jointly model both modalities.
Parameters
----------
adata_spatial
AnnData object containing spatial modality data
adata_nonspatial
AnnData object containing non-spatial modality data, optional
n_hidden
Number of nodes in hidden layers
n_latent
Dimension of latent space
n_spatial
Dimension of spatial features
n_layers
Number of hidden layers in encoder and decoder networks
dropout_rate
Dropout rate during training
dispersion
How to model variance parameters
gene_likelihood
Distribution used to model gene expression
latent_distribution
Type of latent distribution
spatial_graph_type
Spatial graph construction method, 'knn' or 'delaunay'
n_neighbors
Number of neighbors when spatial_graph_type='knn'
attention_heads
Number of attention heads in graph attention network
spatial_kl_weight
Weight for spatial feature KL divergence
modality_weights
Modality weights to balance contribution of each modality in the loss function
**kwargs
Additional parameters passed to VAE module
"""
def __init__(
self,
adata_spatial: AnnData,
adata_nonspatial: Optional[AnnData] = None,
n_hidden: int = 128,
n_latent: int = 10,
n_spatial: int = 10,
n_layers: int = 1,
dropout_rate: float = 0.1,
dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
gene_likelihood: Literal["zinb", "nb", "poisson", "normal"] = "zinb",
use_observed_lib_size: bool = True,
latent_distribution: Literal["normal", "ln"] = "normal",
spatial_graph_type: Literal["knn", "delaunay"] = "knn",
n_neighbors: int = 10,
attention_heads: int = 1,
spatial_kl_weight: float = 0.01,
modality_weights: Dict[str, float] = {"spatial": 1.0, "nonspatial": 1.0},
**kwargs,
):
# Store non-spatial modality data
[docs] self.adata_nonspatial = adata_nonspatial
[docs] self.modality_weights = modality_weights
# Initialize parent class using spatial modality data
super().__init__(adata_spatial)
# Store model configuration parameters
[docs] self._module_kwargs = {
"n_hidden": n_hidden,
"n_latent": n_latent,
"n_spatial": n_spatial,
"n_layers": n_layers,
"dropout_rate": dropout_rate,
"dispersion": dispersion,
"gene_likelihood": gene_likelihood,
"latent_distribution": latent_distribution,
**kwargs,
}
# Create model summary string
[docs] self._model_summary_string = (
"MERFISHVI model parameter summary: \n"
f"Hidden layer nodes: {n_hidden}, Latent dimensions: {n_latent}, Spatial dimensions: {n_spatial}, "
f"Layers: {n_layers}, Dropout rate: {dropout_rate}, "
f"Dispersion type: {dispersion}, Gene likelihood: {gene_likelihood}, "
f"Spatial graph type: {spatial_graph_type}, Spatial KL weight: {spatial_kl_weight}, "
f"Modality weights: {modality_weights}"
)
# Spatial related parameters
[docs] self.spatial_graph_type = spatial_graph_type
[docs] self.n_neighbors = n_neighbors
[docs] self.attention_heads = attention_heads
[docs] self.spatial_kl_weight = spatial_kl_weight
# Set up spatial graph for spatial modality
self.setup_spatial_graph(adata_spatial)
# Set up non-spatial modality data manager
if adata_nonspatial is not None:
self.setup_nonspatial_anndata(adata_nonspatial)
# Initialize model
if self._module_init_on_train:
self.module = None
warnings.warn(
"No adata provided during model initialization. Module will be initialized when train() is called. "
"This is an experimental feature and may change in the future.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)
else:
# Create model
self._create_module()
self.module.minified_data_type = self.minified_data_type
# Store initialization parameters
[docs] self.init_params_ = self._get_init_params(locals())
@classmethod
@setup_anndata_dsp.dedent
[docs] def setup_nonspatial_anndata(cls, adata: AnnData, **kwargs):
"""Set up non-spatial modality AnnData object.
Parameters
----------
adata
Non-spatial modality AnnData object
"""
# Create separate data manager for non-spatial modality
# Store setup parameters
setup_args = cls._get_setup_method_args(**locals())
# Define all data fields needed by the model
data_fields = [
LayerField("X", None, is_count_data=True),
CategoricalObsField("batch", None),
CategoricalObsField("labels", None),
NumericalObsField("size_factor", None, required=False),
CategoricalJointObsField("categorical_covariates", None),
NumericalJointObsField("continuous_covariates", None),
]
# Check if data is in minified format and add appropriate fields
data_type = _get_adata_minify_type(adata)
if data_type is not None:
# Add required fields for minified data format
data_fields += cls._get_fields_for_adata_minification(data_type)
# Create manager for data fields
adata_manager = AnnDataManager(fields=data_fields, setup_method_args=setup_args)
# Register all fields to AnnData object
adata_manager.register_fields(adata, **kwargs)
# Register manager to model class
cls.register_manager(adata_manager)
logger.info(f"Successfully set up non-spatial modality data, shape: {adata.shape}")
[docs] def _create_module(self):
"""Create model module"""
# Get spatial modality data information
n_vars_spatial = self.summary_stats.n_vars
n_batch_spatial = self.summary_stats.n_batch
n_labels_spatial = self.summary_stats.n_labels
# Get spatial modality categorical covariate information
cat_covs_per_group_spatial = (
self.adata_manager.get_state_registry("categorical_covariates").n_cats_per_key
if "categorical_covariates" in self.adata_manager.data_registry
else None
)
# Use size factor
use_size_factor_spatial = "size_factor" in self.adata_manager.data_registry
# Initialize library size parameters
lib_mean_spatial, lib_var_spatial = None, None
if (
not use_size_factor_spatial
and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
and not self._module_kwargs.get("use_observed_lib_size", True)
):
lib_mean_spatial, lib_var_spatial = _init_library_size(self.adata_manager, n_batch_spatial)
# Check if non-spatial modality is available
has_nonspatial = hasattr(self, "adata_manager_nonspatial")
if has_nonspatial:
# Get non-spatial modality data information
n_vars_nonspatial = self.adata_manager_nonspatial.summary_stats.n_vars
n_batch_nonspatial = self.adata_manager_nonspatial.summary_stats.n_batch
n_labels_nonspatial = self.adata_manager_nonspatial.summary_stats.n_labels
# Get non-spatial modality categorical covariate information
cat_covs_per_group_nonspatial = (
self.adata_manager_nonspatial.get_state_registry("categorical_covariates").n_cats_per_key
if "categorical_covariates" in self.adata_manager_nonspatial.data_registry
else None
)
# Use size factor
use_size_factor_nonspatial = "size_factor" in self.adata_manager_nonspatial.data_registry
# Initialize library size parameters
lib_mean_nonspatial, lib_var_nonspatial = None, None
if (
not use_size_factor_nonspatial
and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
and not self._module_kwargs.get("use_observed_lib_size", True)
):
lib_mean_nonspatial, lib_var_nonspatial = _init_library_size(
self.adata_manager_nonspatial, n_batch_nonspatial
)
# Import multimodal spatial VAE module
from .multimodal_spatial_vae import MultiModalSpatialVAE
# Create model module
if has_nonspatial:
self.module = MultiModalSpatialVAE(
n_input_spatial=n_vars_spatial,
n_input_nonspatial=n_vars_nonspatial,
n_batch_spatial=n_batch_spatial,
n_batch_nonspatial=n_batch_nonspatial,
n_labels_spatial=n_labels_spatial,
n_labels_nonspatial=n_labels_nonspatial,
n_hidden=self._module_kwargs["n_hidden"],
n_latent=self._module_kwargs["n_latent"],
n_spatial=self._module_kwargs["n_spatial"],
n_layers=self._module_kwargs["n_layers"],
dropout_rate=self._module_kwargs["dropout_rate"],
dispersion=self._module_kwargs["dispersion"],
gene_likelihood=self._module_kwargs["gene_likelihood"],
latent_distribution=self._module_kwargs["latent_distribution"],
use_observed_lib_size=self._module_kwargs.get("use_observed_lib_size", True),
edge_index=self.edge_index,
attention_heads=self.attention_heads,
spatial_kl_weight=self.spatial_kl_weight,
modality_weights=self.modality_weights,
cats_per_cov_spatial=cat_covs_per_group_spatial,
cats_per_cov_nonspatial=cat_covs_per_group_nonspatial,
use_size_factor_spatial=use_size_factor_spatial,
use_size_factor_nonspatial=use_size_factor_nonspatial,
library_log_means_spatial=lib_mean_spatial,
library_log_vars_spatial=lib_var_spatial,
library_log_means_nonspatial=lib_mean_nonspatial,
library_log_vars_nonspatial=lib_var_nonspatial,
**{
k: v
for k, v in self._module_kwargs.items()
if k
not in [
"n_hidden",
"n_latent",
"n_spatial",
"n_layers",
"dropout_rate",
"dispersion",
"gene_likelihood",
"latent_distribution",
"use_observed_lib_size",
]
},
)
else:
# Only use spatial modality
from .scvi_spatial_module import SpatialVAE
self.module = SpatialVAE(
n_input=n_vars_spatial,
n_batch=n_batch_spatial,
n_labels=n_labels_spatial,
n_hidden=self._module_kwargs["n_hidden"],
n_latent=self._module_kwargs["n_latent"],
n_spatial=self._module_kwargs["n_spatial"],
n_layers=self._module_kwargs["n_layers"],
dropout_rate=self._module_kwargs["dropout_rate"],
dispersion=self._module_kwargs["dispersion"],
gene_likelihood=self._module_kwargs["gene_likelihood"],
latent_distribution=self._module_kwargs["latent_distribution"],
use_observed_lib_size=self._module_kwargs.get("use_observed_lib_size", True),
edge_index=self.edge_index,
attention_heads=self.attention_heads,
spatial_kl_weight=self.spatial_kl_weight,
n_cats_per_cov=cat_covs_per_group_spatial,
use_size_factor_key=use_size_factor_spatial,
library_log_means=lib_mean_spatial,
library_log_vars=lib_var_spatial,
**{
k: v
for k, v in self._module_kwargs.items()
if k
not in [
"n_hidden",
"n_latent",
"n_spatial",
"n_layers",
"dropout_rate",
"dispersion",
"gene_likelihood",
"latent_distribution",
"use_observed_lib_size",
]
},
)
# Add methods for getting multimodal representations
@torch.inference_mode()
[docs] def get_spatial_representation(self, adata=None, indices=None, batch_size=None):
"""Get spatial feature representation"""
if hasattr(self.module, "get_spatial_representation"):
return self.module.get_spatial_representation(adata, indices, batch_size)
else:
raise AttributeError("Current model does not support getting spatial representation")
@torch.inference_mode()
[docs] def get_latent_representation(self, adata=None, indices=None, batch_size=None, modality="spatial"):
"""Get latent representation, can choose from spatial modality, non-spatial modality, or fused representation"""
if hasattr(self.module, "get_latent_representation_by_modality"):
return self.module.get_latent_representation_by_modality(adata, indices, batch_size, modality)
else:
return super().get_latent_representation(adata, indices, batch_size)
@torch.inference_mode()
[docs] def get_fused_representation(self, adata=None, indices=None, batch_size=None):
"""Get fused latent representation"""
if hasattr(self.module, "get_fused_representation"):
return self.module.get_fused_representation(adata, indices, batch_size)
else:
return self.get_latent_representation(adata, indices, batch_size)