spateo.tools.spatial_impute.impute_model

Module Contents

Classes

Discriminator

Module that learns associations between graph embeddings and their positively-labeled augmentations

AvgReadout

Aggregates graph embedding information over graph neighborhoods to obtain a global representation of the graph

Encoder

Representation learning for spatial transcriptomics data

class spateo.tools.spatial_impute.impute_model.Discriminator(nf: int)

Bases: torch.nn.Module

Module that learns associations between graph embeddings and their positively-labeled augmentations

Parameters
nf

Dimensionality (along the feature axis) of the input array

weights_init(m)
forward(g_repr: torch.FloatTensor, g_pos: torch.FloatTensor, g_neg: torch.FloatTensor)

Feeds data forward through the network and scores the positive and negative representations against the source graph representation

Parameters
g_repr

Representation of source graph, with aggregated neighborhood representations

g_pos

Representation of augmentation of the source graph that can be considered a positive pairing, with aggregated neighborhood representations

g_neg

Representation of augmentation of the source graph that can be considered a negative pairing, with aggregated neighborhood representations

Returns

logits

Similarity scores for the positive and negative paired graphs
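The scoring step can be sketched in NumPy, assuming the discriminator follows the Deep Graph Infomax (DGI) pattern of a bilinear score between each embedding and the summary representation. The weight matrix `W`, bias `b`, and the helpers `discriminator_logits` and `bce_with_logits` are hypothetical names for illustration, not part of spateo's API.

```python
import numpy as np


def bilinear_score(h: np.ndarray, c: np.ndarray, W: np.ndarray, b: float) -> np.ndarray:
    """Score each row of h against the matching row of c: h_i^T W c_i + b."""
    # einsum yields one scalar per row: sum_jk h[i, j] * W[j, k] * c[i, k]
    return np.einsum("ij,jk,ik->i", h, W, c) + b


def discriminator_logits(g_repr, g_pos, g_neg, W, b=0.0):
    """Hypothetical DGI-style discriminator forward pass (assumed, not spateo's code).

    g_repr: (n, nf) summary representation per spot
    g_pos, g_neg: (n, nf) positive and negative embeddings
    Returns (n, 2): column 0 = positive scores, column 1 = negative scores.
    """
    pos = bilinear_score(g_pos, g_repr, W, b)
    neg = bilinear_score(g_neg, g_repr, W, b)
    return np.stack([pos, neg], axis=1)


def bce_with_logits(logits, labels):
    """Numerically stable binary cross-entropy on raw logits, averaged over entries."""
    return float(np.mean(np.maximum(logits, 0) - logits * labels
                         + np.log1p(np.exp(-np.abs(logits)))))


rng = np.random.default_rng(0)
n, nf = 5, 4
g_repr = rng.normal(size=(n, nf))
g_pos = rng.normal(size=(n, nf))
g_neg = rng.normal(size=(n, nf))
W = rng.normal(size=(nf, nf))
logits = discriminator_logits(g_repr, g_pos, g_neg, W)
# Positive pairings are labeled 1, negative pairings 0
labels = np.concatenate([np.ones((n, 1)), np.zeros((n, 1))], axis=1)
loss = bce_with_logits(logits, labels)
```

Under this assumed objective, training pushes the positive column toward high scores and the negative column toward low scores, so the network learns which augmentations belong with the source graph.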

class spateo.tools.spatial_impute.impute_model.AvgReadout

Bases: torch.nn.Module

Aggregates graph embedding information over graph neighborhoods to obtain a global representation of the graph

forward(emb: torch.FloatTensor, mask: torch.FloatTensor)
Parameters
emb

Graph embedding (float tensor)

mask

Float tensor selecting which elements to aggregate for each row
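A masked-average readout can be sketched in NumPy as follows. This assumes each row of `mask` weights the spots whose embeddings are averaged for that row, and that the result is L2-normalized per row as in common DGI-style readouts; the function name `avg_readout` and the normalization step are assumptions, not confirmed spateo behavior.

```python
import numpy as np


def avg_readout(emb: np.ndarray, mask: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Hypothetical masked-average readout (assumed, not spateo's code).

    emb:  (n, d) embedding per spot
    mask: (n, n) weights; row i selects the spots aggregated for spot i
    Returns (n, d): per-row masked mean, then L2-normalized.
    """
    summed = mask @ emb                        # weighted sum of the selected embeddings
    counts = mask.sum(axis=1, keepdims=True)   # total weight per row
    mean = summed / np.maximum(counts, eps)    # guard against rows that select nothing
    norms = np.linalg.norm(mean, axis=1, keepdims=True)
    return mean / np.maximum(norms, eps)       # unit-length summary per row


emb = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
mask = np.array([[1.0, 1.0, 0.0], [0.0, 1.0, 1.0], [1.0, 0.0, 1.0]])
global_emb = avg_readout(emb, mask)
```

Normalizing the summary keeps its scale independent of neighborhood size, which makes the downstream similarity scores comparable across spots.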

class spateo.tools.spatial_impute.impute_model.Encoder(in_features: int, out_features: int, graph_neigh: torch.FloatTensor, dropout: float = 0.0, act=F.relu, clip: Union[None, float] = None)

Bases: torch.nn.modules.module.Module

Representation learning for spatial transcriptomics data

Parameters
in_features

Number of features in the dataset

out_features

Size of the desired encoding

graph_neigh

Pairwise adjacency matrix indicating which spots are neighbors of which other spots

dropout

Dropout probability, i.e. the proportion of each layer's input elements randomly set to 0 during training

act

Activation function applied in each encoder layer, typically a function from torch.nn.functional. Default: F.relu

clip

Threshold below which imputed feature values will be set to 0, as a percentile of the max value
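The clipping rule can be illustrated with a short sketch. It assumes `clip` is a fraction in [0, 1] of the maximum value of the imputed matrix and that `None` disables clipping; both readings are assumptions, since the exact interpretation used by spateo is not stated here. The function name `clip_low_values` is hypothetical.

```python
import numpy as np


def clip_low_values(x: np.ndarray, clip=None) -> np.ndarray:
    """Hypothetical sketch: zero out entries below clip * max(x).

    ASSUMPTION: clip is a fraction of the matrix-wide maximum; spateo's
    actual interpretation of `clip` may differ.
    """
    if clip is None:
        return x                         # assumed: no clipping when clip is None
    threshold = clip * x.max()           # cutoff relative to the largest value
    return np.where(x < threshold, 0.0, x)


x = np.array([[0.1, 0.5], [1.0, 0.05]])
clipped = clip_low_values(x, clip=0.2)   # cutoff is 0.2 * 1.0 = 0.2
```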

reset_parameters()
forward(feat: torch.FloatTensor, feat_a: torch.FloatTensor, adj: torch.FloatTensor)
Parameters
feat

Counts matrix

feat_a

Counts matrix following permutation and augmentation

adj

Pairwise distance matrix
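One way the pieces could fit together in a single forward pass is sketched below in NumPy. It assumes a DGI-style design: a one-layer graph convolution shared between `feat` and `feat_a`, a masked-mean readout over `graph_neigh`, and a bilinear discriminator applied in both directions. The weight names `W_enc`, `W_dec`, `W_disc`, the reconstruction step, and the sigmoid after the readout are assumptions for illustration; dropout is omitted, as in evaluation mode.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def readout(emb, mask, eps=1e-12):
    # Masked mean over each spot's neighborhood, L2-normalized, then squashed to (0, 1)
    mean = (mask @ emb) / np.maximum(mask.sum(1, keepdims=True), eps)
    mean = mean / np.maximum(np.linalg.norm(mean, axis=1, keepdims=True), eps)
    return sigmoid(mean)


def disc(c, h_pos, h_neg, W_disc):
    # Bilinear scores h_i^T W c_i: positives in column 0, negatives in column 1
    pos = np.einsum("ij,jk,ik->i", h_pos, W_disc, c)
    neg = np.einsum("ij,jk,ik->i", h_neg, W_disc, c)
    return np.stack([pos, neg], axis=1)


def encoder_forward(feat, feat_a, adj, graph_neigh, W_enc, W_dec, W_disc, act=relu):
    """Hypothetical sketch of one forward pass (assumed design, not spateo's code)."""
    z = adj @ (feat @ W_enc)               # graph convolution: project, then propagate
    recon = adj @ (z @ W_dec)              # assumed: map the encoding back to feature space
    emb = act(z)
    emb_a = act(adj @ (feat_a @ W_enc))    # same weights applied to the augmented input
    g = readout(emb, graph_neigh)
    g_a = readout(emb_a, graph_neigh)
    ret = disc(g, emb, emb_a, W_disc)      # original embeddings are the positives
    ret_a = disc(g_a, emb_a, emb, W_disc)  # roles swapped for the augmented view
    return z, recon, ret, ret_a


rng = np.random.default_rng(0)
n, in_features, out_features = 6, 8, 3
feat = rng.random((n, in_features))
feat_a = feat[rng.permutation(n)]          # augmentation: shuffle rows across spots
A = np.eye(n) + np.eye(n, k=1) + np.eye(n, k=-1)  # chain graph with self-loops
adj = A / A.sum(1, keepdims=True)          # row-normalized propagation matrix
z, recon, ret, ret_a = encoder_forward(
    feat, feat_a, adj, A,
    W_enc=rng.normal(size=(in_features, out_features)),
    W_dec=rng.normal(size=(out_features, in_features)),
    W_disc=rng.normal(size=(out_features, out_features)),
)
```

Sharing `W_enc` between the original and permuted inputs is what makes the contrast meaningful: the discriminator can only separate the two views through the spatial context that the shuffle destroys.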