from __future__ import annotations
import logging
import warnings
from typing import Dict, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Distribution, Normal, kl_divergence
try:
from torch_geometric.nn import GATv2Conv
except ImportError:
try:
from torch_geometric.nn.conv import GATv2Conv
except ImportError:
raise ImportError("Failed to import GATv2Conv, please install PyTorch Geometric")
from scvi import REGISTRY_KEYS
try:
from scvi.module.base import LossOutput, auto_move_data
except ImportError:
try:
from scvi.module.base import LossOutput
from scvi.nn.base import auto_move_data
except ImportError:
try:
from scvi.model.base import LossOutput
from scvi.nn import auto_move_data
except ImportError:
raise ImportError("Failed to import auto_move_data and LossOutput, please check scvi-tools version")
try:
    from scvi.utils import unsupported_if_adata_minified
except ImportError:
    try:
        from scvi.model.base import unsupported_if_adata_minified
    except ImportError:
        # Neither import location exists in the installed scvi-tools version:
        # fall back to an identity decorator so decorated methods run unchanged.
        def unsupported_if_adata_minified(fn):
            """Return ``fn`` unchanged (no-op fallback decorator)."""
            return fn
from anndata import AnnData
from scvi.module import VAE
from scvi.nn import Decoder, Encoder, FCLayers
# Import SpatialEncoder and SpatialVAE
from .scvi_spatial_module import SpatialEncoder, SpatialVAE
# Module-level logger used by the representation helpers of MultiModalSpatialVAE.
logger = logging.getLogger(__name__)
class MultiModalSpatialVAE(SpatialVAE):
    """Multi-modal spatial variational autoencoder.

    Models a spatial modality (scVI ``VAE`` encoder/decoder plus a graph
    ``SpatialEncoder`` over ``edge_index``) together with an optional
    non-spatial modality that has its own encoder and decoder. When both
    modalities are given, the two latent samples are fused by a weighted
    average (``modality_weights``) and the fused ``z`` drives both decoders.

    Parameters
    ----------
    n_input_spatial
        Number of input features for spatial modality
    n_input_nonspatial
        Number of input features for non-spatial modality
    n_batch_spatial
        Number of batches for spatial modality
    n_batch_nonspatial
        Number of batches for non-spatial modality
    n_labels_spatial
        Number of labels for spatial modality
    n_labels_nonspatial
        Number of labels for non-spatial modality (not used by this class)
    n_hidden
        Number of nodes in hidden layers
    n_latent
        Dimension of latent space
    n_spatial
        Dimension of spatial features
    n_layers
        Number of hidden layers
    dropout_rate
        Dropout rate
    dispersion
        Dispersion parameter type of the spatial modality. The non-spatial
        modality uses per-cell dispersion when this is ``"gene-cell"`` and a
        learned gene-wise dispersion otherwise.
    gene_likelihood
        Gene likelihood distribution type (shared by both modalities)
    latent_distribution
        Latent distribution type
    use_observed_lib_size
        If ``True`` the observed log library size is used; otherwise the
        non-spatial library size is sampled from a learned per-batch prior.
    edge_index
        Spatial graph, expected as a ``[2, num_edges]`` tensor. Anything
        accepted by ``torch.tensor`` is converted; invalid input is replaced
        by ``None`` with a ``UserWarning``.
    attention_heads
        Number of attention heads passed to the spatial encoder
    spatial_kl_weight
        Multiplier of the spatial-feature KL term
    modality_weights
        ``{"spatial": w, "nonspatial": w}`` weights used for latent fusion and
        for combining the per-modality losses. Missing entries default to 1.0;
        the two weights must sum to a positive value.
    cats_per_cov_spatial
        Categories per categorical covariate, forwarded to ``VAE``
    cats_per_cov_nonspatial
        Not used by this class
    use_size_factor_spatial
        Forwarded to ``VAE`` as ``use_size_factor_key``
    use_size_factor_nonspatial
        Not used by this class
    library_log_means_spatial, library_log_vars_spatial
        Library-size prior of the spatial modality, forwarded to ``VAE``
    library_log_means_nonspatial, library_log_vars_nonspatial
        Initial values of the learnable non-spatial library prior (mean and
        log-variance per batch); only used when ``use_observed_lib_size`` is
        ``False``. Flattened to 1-D.
    var_eps
        Forwarded to ``SpatialEncoder``; also used as the variance of the
        zero fallback spatial features.
    **kwargs
        Additional parameters forwarded to ``VAE.__init__``
    """

    def __init__(
        self,
        n_input_spatial: int,
        n_input_nonspatial: int,
        n_batch_spatial: int = 0,
        n_batch_nonspatial: int = 0,
        n_labels_spatial: int = 0,
        n_labels_nonspatial: int = 0,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_spatial: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
        gene_likelihood: Literal["zinb", "nb", "poisson", "normal"] = "zinb",
        latent_distribution: Literal["normal", "ln"] = "normal",
        use_observed_lib_size: bool = True,
        edge_index: Optional[torch.Tensor] = None,
        attention_heads: int = 1,
        spatial_kl_weight: float = 0.01,
        modality_weights: Optional[Dict[str, float]] = None,
        cats_per_cov_spatial: Optional[List[int]] = None,
        cats_per_cov_nonspatial: Optional[List[int]] = None,
        use_size_factor_spatial: bool = False,
        use_size_factor_nonspatial: bool = False,
        library_log_means_spatial: Optional[torch.Tensor] = None,
        library_log_vars_spatial: Optional[torch.Tensor] = None,
        library_log_means_nonspatial: Optional[torch.Tensor] = None,
        library_log_vars_nonspatial: Optional[torch.Tensor] = None,
        var_eps: float = 1e-4,
        **kwargs,
    ):
        # SpatialVAE.__init__ is bypassed on purpose: only the plain scVI VAE is
        # initialised here, and the spatial encoder is built below.
        VAE.__init__(
            self,
            n_input=n_input_spatial,
            n_batch=n_batch_spatial,
            n_labels=n_labels_spatial,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            dispersion=dispersion,
            gene_likelihood=gene_likelihood,
            latent_distribution=latent_distribution,
            use_observed_lib_size=use_observed_lib_size,
            n_cats_per_cov=cats_per_cov_spatial,
            use_size_factor_key=use_size_factor_spatial,
            library_log_means=library_log_means_spatial,
            library_log_vars=library_log_vars_spatial,
            **kwargs,
        )

        self.n_spatial = n_spatial
        self.spatial_kl_weight = spatial_kl_weight
        # Read by the zero-feature fallback in ``inference_spatial``.
        self.var_eps = var_eps

        # Private copy merged over the defaults: instances never share a dict
        # (the old mutable default was shared) and missing keys mean 1.0.
        weights = {"spatial": 1.0, "nonspatial": 1.0}
        if modality_weights is not None:
            weights.update(modality_weights)
        if weights["spatial"] + weights["nonspatial"] <= 0:
            raise ValueError(f"modality_weights must sum to a positive value, got {weights}")
        self.modality_weights = weights

        self.spatial_encoder = SpatialEncoder(
            n_latent=n_latent,
            n_spatial=n_spatial,
            attention_heads=attention_heads,
            dropout_rate=dropout_rate,
            var_eps=var_eps,
        )

        nonspatial_cats = [n_batch_nonspatial] if n_batch_nonspatial > 0 else None
        self.nonspatial_encoder = Encoder(
            n_input=n_input_nonspatial,
            n_output=n_latent,
            n_cat_list=nonspatial_cats,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            distribution=latent_distribution,
        )

        # Count likelihoods need DecoderSCVI (rate = exp(library) * scale);
        # the plain Decoder returns (mean, variance) and serves "normal".
        if gene_likelihood in ("zinb", "nb", "poisson"):
            from scvi.nn import DecoderSCVI

            self.nonspatial_decoder = DecoderSCVI(
                n_input=n_latent,
                n_output=n_input_nonspatial,
                n_cat_list=nonspatial_cats,
                n_layers=n_layers,
                n_hidden=n_hidden,
            )
        else:
            self.nonspatial_decoder = Decoder(
                n_input=n_latent,
                n_output=n_input_nonspatial,
                n_cat_list=nonspatial_cats,
                n_layers=n_layers,
                n_hidden=n_hidden,
            )

        # Gene-wise log inverse-dispersion for the non-spatial genes. The
        # spatial ``px_r`` is sized for n_input_spatial and cannot be reused.
        if gene_likelihood in ("zinb", "nb") and dispersion != "gene-cell":
            self.nonspatial_px_r = torch.nn.Parameter(torch.randn(n_input_nonspatial))

        # Learnable per-batch prior of the non-spatial log library size.
        if not use_observed_lib_size:
            n_library = n_batch_nonspatial if n_batch_nonspatial > 0 else 1
            self.nonspatial_l_mean = torch.nn.Parameter(
                self._library_prior_init(library_log_means_nonspatial, n_library)
            )
            # NOTE: interpreted as a log-variance (see _sample_nonspatial_library).
            self.nonspatial_l_var = torch.nn.Parameter(
                self._library_prior_init(library_log_vars_nonspatial, n_library)
            )

        edge_index = self._validate_edge_index(edge_index)
        # Plain attribute (moved to the latent's device on demand in
        # inference_spatial) plus a registered buffer copy.
        self.edge_index = edge_index
        self.register_buffer("_edge_index", edge_index)

    @staticmethod
    def _library_prior_init(value, size: int) -> torch.Tensor:
        """Return a flat float32 initial value for a library-prior parameter.

        ``None`` yields zeros of length ``size``. Arrays/tensors (e.g. scVI's
        ``(1, n_batch)`` statistics) are flattened and copied so the parameter
        does not alias caller-owned memory.
        """
        if value is None:
            return torch.zeros(size)
        return torch.as_tensor(value, dtype=torch.float32).reshape(-1).clone()

    @staticmethod
    def _validate_edge_index(edge_index) -> Optional[torch.Tensor]:
        """Convert ``edge_index`` to a ``long`` tensor of shape ``[2, num_edges]``.

        Returns ``None`` (with a ``UserWarning``) when conversion fails or the
        shape is wrong; ``None`` input is returned unchanged.
        """
        if edge_index is None:
            return None
        if not isinstance(edge_index, torch.Tensor):
            try:
                edge_index = torch.tensor(edge_index, dtype=torch.long)
            except Exception as e:
                warnings.warn(f"Failed to convert edge_index to tensor: {str(e)}, will set to None", UserWarning)
                return None
        elif edge_index.dtype != torch.long:
            try:
                edge_index = edge_index.long()
            except Exception as e:
                warnings.warn(f"Failed to convert edge_index to long type: {str(e)}, will set to None", UserWarning)
                return None
        if len(edge_index.shape) != 2 or edge_index.shape[0] != 2:
            warnings.warn(
                f"edge_index shape error: {edge_index.shape}, should be [2, num_edges], will set to None",
                UserWarning,
            )
            return None
        return edge_index

    @auto_move_data
    def inference_spatial(
        self, x: torch.Tensor, batch_index: Optional[torch.Tensor] = None, **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Spatial modality inference.

        Runs ``VAE.inference`` and then the spatial encoder on its ``z`` and
        ``edge_index``.

        Parameters
        ----------
        x
            Input data of spatial modality
        batch_index
            Batch index
        **kwargs
            Forwarded to ``VAE.inference``

        Returns
        -------
        dict
            ``VAE.inference`` outputs plus ``spatial_mean``, ``spatial_var`` and
            ``spatial_sample``. If the spatial encoder raises, a warning is
            issued and zeros (variance ``var_eps``) are returned instead.
        """
        outputs = VAE.inference(self, x, batch_index, **kwargs)
        z = outputs["z"]

        # edge_index is a plain attribute, so Module.to() does not move it.
        if self.edge_index is not None and z.device != self.edge_index.device:
            self.edge_index = self.edge_index.to(z.device)

        try:
            spatial_mean, spatial_var, spatial_sample = self.spatial_encoder(z, self.edge_index)
        except Exception as e:
            # Deliberate best-effort fallback. NOTE(review): a graph indexed over
            # all cells combined with a minibatch ``z`` may trigger this — confirm.
            warnings.warn(
                f"Spatial encoder processing failed: {str(e)}. Will return zero tensor as spatial feature.", UserWarning
            )
            n_obs = z.size(0)
            spatial_mean = torch.zeros(n_obs, self.n_spatial, device=z.device)
            spatial_var = torch.ones(n_obs, self.n_spatial, device=z.device) * self.var_eps
            spatial_sample = torch.zeros(n_obs, self.n_spatial, device=z.device)

        outputs.update(
            {
                "spatial_mean": spatial_mean,
                "spatial_var": spatial_var,
                "spatial_sample": spatial_sample,
            }
        )
        return outputs

    @auto_move_data
    def inference_nonspatial(
        self, x: torch.Tensor, batch_index: Optional[torch.Tensor] = None, **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Non-spatial modality inference.

        Parameters
        ----------
        x
            Input data of non-spatial modality
        batch_index
            Batch index (required when the encoder was built with batches)

        Returns
        -------
        dict
            ``{"qz_m": mean, "qz_v": variance, "z": sample}`` of the
            non-spatial posterior.
        """
        # Mirror VAE.inference, which encodes log1p counts when log_variational is set.
        x_encoded = torch.log1p(x) if getattr(self, "log_variational", False) else x
        encoded = self.nonspatial_encoder(x_encoded, batch_index)
        if len(encoded) == 2:
            # Encoder(return_dist=True) variant: (Normal, z).
            qz, z = encoded
            qz_m, qz_v = qz.loc, qz.scale.pow(2)
        else:
            qz_m, qz_v, z = encoded
        return {"qz_m": qz_m, "qz_v": qz_v, "z": z}

    @auto_move_data
    def inference(
        self,
        x_spatial: torch.Tensor,
        x_nonspatial: Optional[torch.Tensor] = None,
        batch_index_spatial: Optional[torch.Tensor] = None,
        batch_index_nonspatial: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Joint inference over the spatial and (optional) non-spatial modality.

        Parameters
        ----------
        x_spatial
            Input data of spatial modality
        x_nonspatial
            Input data of non-spatial modality, optional
        batch_index_spatial
            Batch index of spatial modality
        batch_index_nonspatial
            Batch index of non-spatial modality
        **kwargs
            Forwarded to the spatial inference

        Returns
        -------
        dict
            Spatial inference outputs. With non-spatial input, also
            ``nonspatial_qz_m``/``nonspatial_qz_v``/``nonspatial_z``, the
            spatial-only sample as ``spatial_z`` and the weighted fusion as both
            ``fused_z`` and ``z``.
        """
        outputs = self.inference_spatial(x_spatial, batch_index_spatial, **kwargs)
        if x_nonspatial is None:
            return outputs

        nonspatial_outputs = self.inference_nonspatial(x_nonspatial, batch_index_nonspatial)

        w_spatial = self.modality_weights.get("spatial", 1.0)
        w_nonspatial = self.modality_weights.get("nonspatial", 1.0)
        # Element-wise weighted average; presumably both modalities describe the
        # same cells in the same order (shapes must match) — confirm with the data loader.
        fused_z = (w_spatial * outputs["z"] + w_nonspatial * nonspatial_outputs["z"]) / (w_spatial + w_nonspatial)

        outputs.update(
            {
                "spatial_z": outputs["z"],
                "nonspatial_qz_m": nonspatial_outputs["qz_m"],
                "nonspatial_qz_v": nonspatial_outputs["qz_v"],
                "nonspatial_z": nonspatial_outputs["z"],
                "fused_z": fused_z,
                # The fused latent is the main representation downstream.
                "z": fused_z,
            }
        )
        return outputs

    @auto_move_data
    def generative_spatial(
        self,
        z: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        library_size: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Spatial modality generative process (``VAE.generative``).

        Parameters
        ----------
        z
            Latent representation
        batch_index
            Batch index
        library_size
            Log library size (``VAE.inference`` returns it as ``"library"``).
            A ``library`` keyword is accepted as an alias.
        **kwargs
            Forwarded to ``VAE.generative``

        Returns
        -------
        dict
            Output of ``VAE.generative``

        Raises
        ------
        ValueError
            If no library size is supplied.
        """
        if library_size is None:
            library_size = kwargs.pop("library", None)
        if library_size is None:
            raise ValueError("generative_spatial requires `library_size` (log library size from inference)")
        # Keywords, not positions: VAE.generative's second parameter is `library`.
        return VAE.generative(self, z=z, library=library_size, batch_index=batch_index, **kwargs)

    def _sample_nonspatial_library(self, batch_index: Optional[torch.Tensor], n_obs: int) -> torch.Tensor:
        """Sample a ``(n_obs, 1)`` log library size from the learned non-spatial prior.

        The prior for batch ``b`` is ``Normal(l_mean[b], sqrt(exp(l_var[b]) + 1e-4))``.
        """
        n_library = self.nonspatial_l_mean.shape[0]
        variances = torch.exp(self.nonspatial_l_var) + 1e-4
        if n_library > 1:
            if batch_index is None:
                raise ValueError("No batch_index provided, but model has multiple batches")
            one_hot = F.one_hot(batch_index.reshape(-1).long(), n_library).to(self.nonspatial_l_mean.dtype)
            loc = one_hot @ self.nonspatial_l_mean
            var = one_hot @ variances
        else:
            loc = self.nonspatial_l_mean.expand(n_obs)
            var = variances.expand(n_obs)
        return Normal(loc, var.sqrt()).rsample().unsqueeze(-1)

    @auto_move_data
    def generative_nonspatial(
        self,
        z: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        library_size: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Non-spatial modality generative process.

        Parameters
        ----------
        z
            Latent representation
        batch_index
            Batch index
        library_size
            Log library size of the non-spatial counts. Required for count
            likelihoods when ``use_observed_lib_size`` is ``True``; otherwise it
            is sampled from the learned prior when omitted.

        Returns
        -------
        dict
            ``"normal"``: ``px_rate`` (mean) and ``px_var`` (variance).
            Count likelihoods: ``px_scale``, ``px_rate`` and ``library``, plus
            ``px_r`` (inverse dispersion, already exponentiated) for nb/zinb and
            ``px_dropout`` (zero-inflation *logits*) for zinb.

        Raises
        ------
        ValueError
            If a required library size or batch index is missing.
        """
        cat_list = () if batch_index is None else (batch_index,)

        if self.gene_likelihood == "normal":
            px_mean, px_var = self.nonspatial_decoder(z, *cat_list)
            return {"px_rate": px_mean, "px_var": px_var}

        if library_size is not None:
            library = library_size
        elif not self.use_observed_lib_size:
            library = self._sample_nonspatial_library(batch_index, z.shape[-2])
        else:
            raise ValueError(
                "library_size (log library size of the non-spatial counts) is required "
                "when use_observed_lib_size=True"
            )

        decoder_dispersion = "gene-cell" if self.dispersion == "gene-cell" else "gene"
        px_scale, px_r, px_rate, px_dropout = self.nonspatial_decoder(decoder_dispersion, z, library, *cat_list)

        outputs = {"px_scale": px_scale, "px_rate": px_rate, "library": library}
        if self.gene_likelihood in ("zinb", "nb"):
            if px_r is None:
                px_r = self.nonspatial_px_r
            # Decoder head / parameter hold the log inverse-dispersion.
            outputs["px_r"] = torch.exp(px_r)
        if self.gene_likelihood == "zinb":
            outputs["px_dropout"] = px_dropout
        return outputs

    @auto_move_data
    def generative(
        self,
        z: torch.Tensor,
        batch_index_spatial: Optional[torch.Tensor] = None,
        batch_index_nonspatial: Optional[torch.Tensor] = None,
        library_size_spatial: Optional[torch.Tensor] = None,
        library_size_nonspatial: Optional[torch.Tensor] = None,
        decode_nonspatial: Optional[bool] = None,
        **kwargs,
    ) -> Dict[str, Dict[str, torch.Tensor]]:
        """Joint generative process.

        Parameters
        ----------
        z
            Latent representation
        batch_index_spatial
            Batch index of spatial modality
        batch_index_nonspatial
            Batch index of non-spatial modality
        library_size_spatial
            Log library size of spatial modality
        library_size_nonspatial
            Log library size of non-spatial modality
        decode_nonspatial
            Whether to also decode the non-spatial modality. ``None`` (default)
            decodes it when a non-spatial batch index or library size is given.
        **kwargs
            Forwarded to both modality-specific generative methods

        Returns
        -------
        dict
            ``{"spatial": ...}`` and, if decoded, ``"nonspatial": ...``
        """
        spatial_outputs = self.generative_spatial(z, batch_index_spatial, library_size=library_size_spatial, **kwargs)

        if decode_nonspatial is None:
            decode_nonspatial = batch_index_nonspatial is not None or library_size_nonspatial is not None
        if not decode_nonspatial:
            return {"spatial": spatial_outputs}

        nonspatial_outputs = self.generative_nonspatial(
            z, batch_index_nonspatial, library_size=library_size_nonspatial, **kwargs
        )
        return {"spatial": spatial_outputs, "nonspatial": nonspatial_outputs}

    @auto_move_data
    def forward(
        self,
        tensors: Dict[str, torch.Tensor],
        inference_kwargs: Optional[Dict] = None,
        compute_loss: bool = True,
        **kwargs,
    ) -> Tuple:
        """Run inference, generative process and (optionally) the loss.

        Parameters
        ----------
        tensors
            Minibatch dict. Reads ``"X"``, ``"X_nonspatial"``, the spatial batch
            key (``"batch_indices"`` or scvi's ``REGISTRY_KEYS.BATCH_KEY``) and
            ``"batch_indices_nonspatial"``. It is not modified.
        inference_kwargs
            Extra arguments for ``inference`` (override the derived inputs)
        compute_loss
            Whether to compute loss
        **kwargs
            ``generative_kwargs`` and ``loss_kwargs`` (e.g. ``{"kl_weight": w}``
            as passed by scvi-tools training plans) are honoured; other keys
            are ignored.

        Returns
        -------
        tuple
            ``(inference_outputs, generative_outputs, loss)`` or, when
            ``compute_loss`` is False, ``(inference_outputs, generative_outputs)``.
        """
        inference_kwargs = dict(inference_kwargs or {})
        generative_kwargs = dict(kwargs.get("generative_kwargs") or {})
        loss_kwargs = dict(kwargs.get("loss_kwargs") or {})

        # Shallow copy so the key renaming below does not mutate the caller's dict.
        tensors = dict(tensors)
        if "cont_covariates" in tensors and "cont_covs" not in tensors:
            tensors["cont_covs"] = tensors.pop("cont_covariates")
        if "cat_covariates" in tensors and "cat_covs" not in tensors:
            tensors["cat_covs"] = tensors.pop("cat_covariates")

        x_spatial = tensors.get("X", None)
        x_nonspatial = tensors.get("X_nonspatial", None)
        # Accept both the literal "batch_indices" key and scvi-tools' registry batch key.
        batch_index_spatial = tensors.get("batch_indices", tensors.get(REGISTRY_KEYS.BATCH_KEY))
        batch_index_nonspatial = tensors.get("batch_indices_nonspatial", None)
        has_nonspatial = x_nonspatial is not None

        inference_inputs = {
            "x_spatial": x_spatial,
            "batch_index_spatial": batch_index_spatial,
        }
        if has_nonspatial:
            inference_inputs["x_nonspatial"] = x_nonspatial
            inference_inputs["batch_index_nonspatial"] = batch_index_nonspatial
        inference_inputs.update(inference_kwargs)
        inference_outputs = self.inference(**inference_inputs)

        generative_inputs = {
            "z": inference_outputs["z"],  # fused latent when both modalities are present
            "batch_index_spatial": batch_index_spatial,
            "library_size_spatial": inference_outputs.get("library"),
            "decode_nonspatial": has_nonspatial,
        }
        if has_nonspatial:
            generative_inputs["batch_index_nonspatial"] = batch_index_nonspatial
            if self.use_observed_lib_size:
                generative_inputs["library_size_nonspatial"] = torch.log(x_nonspatial.sum(dim=1, keepdim=True))
        generative_inputs.update(generative_kwargs)
        generative_outputs = self.generative(**generative_inputs)

        if not compute_loss:
            return inference_outputs, generative_outputs
        losses = self.loss(tensors, inference_outputs, generative_outputs, **loss_kwargs)
        return inference_outputs, generative_outputs, losses

    @unsupported_if_adata_minified
    def loss(
        self,
        tensors: Dict[str, torch.Tensor],
        inference_outputs: Dict[str, torch.Tensor | Distribution | None],
        generative_outputs: Dict[str, Dict[str, Distribution | None]],
        kl_weight: torch.Tensor | float = 1.0,
    ) -> LossOutput:
        """Joint loss of both modalities.

        Spatial loss = reconstruction + ``kl_weight`` * (latent KL +
        ``spatial_kl_weight`` * spatial-feature KL). Non-spatial loss =
        reconstruction + ``kl_weight`` * latent KL. When both are present they
        are combined as a ``modality_weights``-weighted average.

        Parameters
        ----------
        tensors
            Input tensors
        inference_outputs
            Inference process outputs
        generative_outputs
            Generative process outputs (``{"spatial": ..., "nonspatial": ...}``)
        kl_weight
            KL divergence weight

        Returns
        -------
        LossOutput
            Reconstruction loss reported is the spatial one. Extra metrics are
            always 0-d tensors (zeros when the non-spatial modality is absent).
        """
        w_spatial = self.modality_weights.get("spatial", 1.0)
        w_nonspatial = self.modality_weights.get("nonspatial", 1.0)
        x_spatial = tensors.get("X", None)
        x_nonspatial = tensors.get("X_nonspatial", None)
        zero = torch.tensor(0.0, device=self.device)

        # Spatial-feature KL against a standard normal prior.
        if "spatial_mean" in inference_outputs:
            spatial_mean = inference_outputs["spatial_mean"]
            q_s = Normal(spatial_mean, inference_outputs["spatial_var"].sqrt())
            p_s = Normal(torch.zeros_like(spatial_mean), torch.ones_like(spatial_mean))
            spatial_kl = kl_divergence(q_s, p_s).sum(dim=-1)
        else:
            spatial_kl = zero

        spatial_reconst_loss, kl_local = self._get_reconstruction_loss(x_spatial, generative_outputs["spatial"])
        kl_local.update({"kl_divergence_s": spatial_kl})
        # NOTE(review): assumes kl_divergence_z carries a trailing latent dim that
        # is summed here — confirm against _get_reconstruction_loss.
        spatial_kl_weighted = torch.mean(torch.sum(kl_local["kl_divergence_z"], dim=-1)) + self.spatial_kl_weight * torch.mean(
            spatial_kl
        )
        spatial_loss = spatial_reconst_loss + kl_weight * spatial_kl_weighted

        nonspatial_reconst_loss = zero
        nonspatial_kl_weighted = zero
        if x_nonspatial is not None and "nonspatial" in generative_outputs:
            qz_m = inference_outputs["nonspatial_qz_m"]
            qz_v = inference_outputs["nonspatial_qz_v"]
            qz = Normal(qz_m, qz_v.sqrt())
            pz = Normal(torch.zeros_like(qz_m), torch.ones_like(qz_m))
            kl_divergence_z_nonspatial = kl_divergence(qz, pz).sum(dim=-1)

            nonspatial_reconst_loss = self._get_reconstruction_loss_nonspatial(
                x_nonspatial, generative_outputs["nonspatial"]
            )
            nonspatial_kl_weighted = torch.mean(kl_divergence_z_nonspatial)
            nonspatial_loss = nonspatial_reconst_loss + kl_weight * nonspatial_kl_weighted

            total_loss = (w_spatial * spatial_loss + w_nonspatial * nonspatial_loss) / (w_spatial + w_nonspatial)
            kl_local.update({"kl_divergence_z_nonspatial": kl_divergence_z_nonspatial})
        else:
            total_loss = spatial_loss

        extra_metrics = {
            "spatial_reconstruction_loss": spatial_reconst_loss,
            "nonspatial_reconstruction_loss": nonspatial_reconst_loss,
            "spatial_kl": spatial_kl_weighted,
            "nonspatial_kl": nonspatial_kl_weighted,
        }
        return LossOutput(
            loss=total_loss,
            reconstruction_loss=spatial_reconst_loss,  # spatial reconstruction is the main one reported
            kl_local=kl_local,
            extra_metrics=extra_metrics,
        )

    def _get_reconstruction_loss_nonspatial(
        self, x: torch.Tensor, generative_outputs: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """Negative log-likelihood of the non-spatial counts.

        Parameters
        ----------
        x
            Input data
        generative_outputs
            Output of ``generative_nonspatial``

        Returns
        -------
        torch.Tensor
            Scalar: NLL summed over features and averaged over cells.
        """
        px_rate = generative_outputs["px_rate"]
        if self.gene_likelihood == "zinb":
            # DecoderSCVI emits zero-inflation logits; the helper expects probabilities.
            pi = torch.sigmoid(generative_outputs["px_dropout"])
            log_lik = log_zinb_positive(x, px_rate, generative_outputs["px_r"], pi)
        elif self.gene_likelihood == "nb":
            log_lik = log_nb_positive(x, px_rate, generative_outputs["px_r"])
        elif self.gene_likelihood == "poisson":
            log_lik = log_poisson(x, px_rate)
        else:  # normal
            px_var = generative_outputs.get("px_var")
            if px_var is None:
                px_var = torch.ones_like(px_rate)
            log_lik = log_normal(x, px_rate, px_var)
        return torch.mean(-log_lik.sum(dim=-1))

    @torch.inference_mode()
    def get_latent_representation_by_modality(
        self, adata=None, indices=None, batch_size=None, modality="spatial"
    ) -> np.ndarray:
        """Get latent representation of a specific modality.

        NOTE(review): delegates to ``self.get_latent_representation``, which is
        not defined in this class — presumably provided by a parent/wrapper;
        confirm.

        Parameters
        ----------
        adata
            AnnData object, optional
        indices
            Index to get representation, optional
        batch_size
            Batch processing size, optional
        modality
            ``"spatial"``, ``"nonspatial"`` or ``"fused"``

        Returns
        -------
        np.ndarray
            Latent representation

        Raises
        ------
        NotImplementedError
            For ``"nonspatial"`` when a non-spatial encoder exists.
        ValueError
            For an unknown ``modality``.
        """
        if modality == "spatial":
            return self.get_latent_representation(adata, indices, batch_size)
        if modality == "nonspatial":
            if not hasattr(self, "nonspatial_encoder"):
                logger.warning(
                    "Model does not have non-spatial encoder, will return latent representation of spatial modality"
                )
                return self.get_latent_representation(adata, indices, batch_size)
            raise NotImplementedError("Non-spatial latent representation retrieval not implemented yet")
        if modality == "fused":
            return self.get_fused_representation(adata, indices, batch_size)
        raise ValueError(f"Unsupported modality: {modality}, valid values are 'spatial', 'nonspatial' or 'fused'")

    @torch.inference_mode()
    def get_fused_representation(self, adata=None, indices=None, batch_size=None) -> np.ndarray:
        """Get fused latent representation.

        Fusion is not implemented yet: the spatial-modality representation from
        ``get_latent_representation`` is returned and a warning is logged.

        Parameters
        ----------
        adata
            AnnData object, optional
        indices
            Index to get representation, optional
        batch_size
            Batch processing size, optional

        Returns
        -------
        np.ndarray
            Latent representation (currently the spatial one)
        """
        if not hasattr(self, "nonspatial_encoder"):
            logger.warning(
                "Model does not have non-spatial encoder, will return latent representation of spatial modality"
            )
            return self.get_latent_representation(adata, indices, batch_size)
        # Do not silently mislabel the spatial representation as a fused one.
        logger.warning(
            "Fused representation retrieval is not implemented yet, will return latent representation of spatial modality"
        )
        return self.get_latent_representation(adata, indices, batch_size)

    @torch.inference_mode()
    def get_nonspatial_specific_features(self, adata=None, indices=None, batch_size=None) -> np.ndarray:
        """Get non-spatial modality specific features.

        Not implemented yet: always returns ``None`` and logs a warning.

        Parameters
        ----------
        adata
            AnnData object, optional
        indices
            Index to get representation, optional
        batch_size
            Batch processing size, optional

        Returns
        -------
        None
            Placeholder until feature extraction is implemented
        """
        if not hasattr(self, "nonspatial_encoder"):
            logger.warning("Model does not have non-spatial encoder, cannot get non-spatial features")
            return None
        logger.warning("Non-spatial specific feature extraction is not implemented yet, returning None")
        return None
# Helper log-likelihood functions used by the non-spatial reconstruction loss
def log_zinb_positive(x, mu, theta, pi, eps=1e-8):
    """Element-wise log-likelihood of a zero-inflated negative binomial.

    Parameters
    ----------
    x
        Observed counts.
    mu
        Negative-binomial mean (expected non-negative).
    theta
        Inverse dispersion (expected positive).
    pi
        Zero-inflation *probability* in [0, 1] (not logits).
    eps
        Small constant added inside logarithms so that ``mu == 0`` or
        ``pi == 1`` give finite values and finite gradients.

    Returns
    -------
    torch.Tensor
        Log-likelihood with the broadcast shape of the inputs. Entries with
        ``x < eps`` use the zero-count branch.
    """
    log_theta_mu_eps = torch.log(theta + mu + eps)
    # log((theta / (theta + mu)) ** theta): NB probability of observing zero.
    log_nb_zero = theta * (torch.log(theta + eps) - log_theta_mu_eps)
    case_zero = torch.log(pi + (1 - pi) * torch.exp(log_nb_zero) + eps)
    case_non_zero = (
        torch.log(1 - pi + eps)
        + torch.lgamma(theta + x)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
        + theta * torch.log(theta + eps)
        + x * torch.log(mu + eps)
        - (x + theta) * log_theta_mu_eps
    )
    # torch.where evaluates both branches; eps keeps the unselected one finite
    # so it cannot inject NaN into the backward pass.
    return torch.where(x < eps, case_zero, case_non_zero)
def log_nb_positive(x, mu, theta, eps=1e-8):
    """Element-wise log-likelihood of a negative binomial (mean/inverse-dispersion form).

    Parameters
    ----------
    x
        Observed counts.
    mu
        Mean (expected non-negative).
    theta
        Inverse dispersion (expected positive).
    eps
        Small constant added inside logarithms so ``mu == 0`` stays finite.

    Returns
    -------
    torch.Tensor
        Log-likelihood with the broadcast shape of the inputs.
    """
    log_theta_mu_eps = torch.log(theta + mu + eps)
    return (
        torch.lgamma(theta + x)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
        + theta * torch.log(theta + eps)
        + x * torch.log(mu + eps)
        - (x + theta) * log_theta_mu_eps
    )
def log_poisson(x, mu, eps=1e-8):
    """Element-wise log-likelihood of a Poisson distribution.

    Parameters
    ----------
    x
        Observed counts.
    mu
        Rate (expected non-negative).
    eps
        Added to ``mu`` inside the logarithm so ``x == 0, mu == 0`` gives 0
        instead of NaN.

    Returns
    -------
    torch.Tensor
        Log-likelihood with the broadcast shape of the inputs.
    """
    return x * torch.log(mu + eps) - mu - torch.lgamma(x + 1)
def log_normal(x, mu, var, eps=1e-8):
    """Element-wise log density of a normal distribution.

    Parameters
    ----------
    x
        Observations.
    mu
        Mean.
    var
        Variance (not standard deviation); tensors or Python scalars accepted.
    eps
        Added to ``var`` so a zero variance does not produce inf/NaN.

    Returns
    -------
    torch.Tensor
        Log density with the broadcast shape of the inputs.
    """
    var = torch.as_tensor(var) + eps
    return -0.5 * torch.log(2 * np.pi * var) - 0.5 * torch.pow(x - mu, 2) / var