# Source code for spateo.external.CAST.model.model_GCNII

# import torch.nn.functional as F
from dataclasses import dataclass, field

import torch
import torch.nn as nn

# from dgl.nn import GCN2Conv, GraphConv


@dataclass
class Args:
    """Hyperparameter container whose ``device`` field is derived from ``gpu``.

    ``device`` is not an ``__init__`` argument. ``__post_init__`` sets it to
    ``"cuda:<gpu>"`` when ``gpu != -1`` and CUDA is available, and to ``"cpu"``
    otherwise.
    """

    dataname: str
    gpu: int = 0  # CUDA device index; -1 forces the CPU
    epochs: int = 1000
    lr1: float = 1e-3  # NOTE(review): presumably the learning rate — confirm against the trainer
    wd1: float = 0.0  # NOTE(review): presumably weight decay — confirm against the trainer
    lambd: float = 1e-3  # NOTE(review): not used in this file; meaning defined by the caller
    n_layers: int = 9
    der: float = 0.2  # NOTE(review): likely an edge-drop rate for augmentation — confirm
    dfr: float = 0.2  # NOTE(review): likely a feature-drop rate for augmentation — confirm
    device: str = field(init=False)  # filled in by __post_init__
    encoder_dim: int = 256
    use_encoder: bool = False

    def __post_init__(self):
        # Short-circuits: CUDA availability is only queried when a GPU was requested.
        use_cuda = self.gpu != -1 and torch.cuda.is_available()
        self.device = f"cuda:{self.gpu}" if use_cuda else "cpu"
# Clamping the per-column standard deviation to at least ``eps`` fixes the
# divide-by-zero bug for constant columns (Shuchen Luo, 20220217).
def standardize(x, eps=1e-12):
    """Center each column of ``x`` to zero mean and scale it by its standard deviation.

    The standard deviation uses ``Tensor.std`` defaults (Bessel-corrected) along
    dim 0. It is clamped from below at ``eps``, so a constant column maps to 0
    instead of NaN.

    NOTE(review): with a single row the std is NaN and the clamp does not remove
    it; the result is then NaN.
    """
    centered = x - x.mean(dim=0)
    scale = torch.clamp(x.std(dim=0), min=eps)
    return centered / scale
class Encoder(nn.Module):
    """A single fully connected layer followed by ReLU.

    Args:
        in_dim: Number of input features.
        encoder_dim: Number of output features.
    """

    def __init__(self, in_dim: int, encoder_dim: int):
        super().__init__()
        # Attribute names ``layer``/``relu`` are kept stable: they determine state_dict keys.
        self.layer = nn.Linear(in_features=in_dim, out_features=encoder_dim, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(layer(x))``."""
        projected = self.layer(x)
        return self.relu(projected)
# GCN2Conv(in_feats, layer, alpha=0.1, lambda_=1, project_initial_features=True, allow_zero_in_degree=False, bias=True, activation=None)
class GCNII(nn.Module):
    """Backbone of an optional linear ``Encoder`` followed by ``n_layers`` DGL ``GCN2Conv`` layers.

    Every ``GCN2Conv`` receives the current node features together with
    ``feat0``, the features as they were before the first conv (encoder output,
    or the raw input when no encoder is used). Each layer's output is passed
    through ReLU.

    Args:
        in_dim: Dimensionality of the input node features.
        encoder_dim: Output size of the encoder; only used when ``use_encoder`` is True.
        n_layers: Number of ``GCN2Conv`` layers.
        alpha: Per-layer ``alpha`` for ``GCN2Conv``. Either a sequence with at
            least ``n_layers`` entries or a single scalar shared by all layers.
            Defaults to 0.1 for every layer.
        lambda_: Per-layer ``lambda_`` for ``GCN2Conv``, same conventions as
            ``alpha``. Defaults to 1.0 for every layer.
        use_encoder: If True, project inputs with ``Encoder`` before the convs.

    Raises:
        ValueError: If ``alpha`` or ``lambda_`` is a sequence with fewer than
            ``n_layers`` entries.
    """

    def __init__(self, in_dim: int, encoder_dim: int, n_layers: int, alpha=None, lambda_=None, use_encoder=False):
        super().__init__()
        # Imported lazily (presumably to keep dgl an optional dependency of the package).
        from dgl.nn import GCN2Conv

        self.n_layers = n_layers
        self.use_encoder = use_encoder
        self.alpha = self._per_layer(alpha, 0.1, n_layers, "alpha")
        self.lambda_ = self._per_layer(lambda_, 1.0, n_layers, "lambda_")
        if self.use_encoder:
            self.encoder = Encoder(in_dim, encoder_dim)
            self.hid_dim = encoder_dim
        else:
            self.hid_dim = in_dim
        self.relu = nn.ReLU()
        # GCN2Conv's second argument is the 1-based layer index.
        self.convs = nn.ModuleList(
            GCN2Conv(self.hid_dim, i + 1, alpha=self.alpha[i], lambda_=self.lambda_[i], activation=None)
            for i in range(n_layers)
        )

    @staticmethod
    def _per_layer(value, default, n_layers, name):
        """Expand a per-layer hyperparameter into a list with (at least) one entry per layer."""
        if value is None:
            return [default] * n_layers
        try:
            values = list(value)
        except TypeError:
            # Not iterable: treat it as a scalar shared by every layer.
            return [value] * n_layers
        if len(values) < n_layers:
            raise ValueError(
                f"{name} must provide at least {n_layers} values (one per layer), got {len(values)}"
            )
        return values

    def forward(self, graph, x):
        """Apply the optional encoder, then every GCN2Conv + ReLU layer; return the final features."""
        if self.use_encoder:
            x = self.encoder(x)
        feat0 = x  # passed unchanged to every layer alongside the current features
        for conv in self.convs:
            x = self.relu(conv(graph, x, feat0))
        return x
class GCN(nn.Module):
    """Backbone of an optional linear ``Encoder`` followed by ``n_layers`` DGL ``GraphConv`` layers.

    Every conv maps ``hid_dim -> hid_dim`` and its output is passed through ReLU.

    Args:
        in_dim: Dimensionality of the input node features.
        encoder_dim: Output size of the encoder; only used when ``use_encoder`` is True.
        n_layers: Number of ``GraphConv`` layers.
        use_encoder: If True, project inputs with ``Encoder`` before the convs.
    """

    def __init__(self, in_dim: int, encoder_dim: int, n_layers: int, use_encoder=False):
        # Imported lazily (presumably to keep dgl an optional dependency); only GraphConv is needed.
        from dgl.nn import GraphConv

        super().__init__()
        self.n_layers = n_layers
        self.use_encoder = use_encoder
        if self.use_encoder:
            self.encoder = Encoder(in_dim, encoder_dim)
            self.hid_dim = encoder_dim
        else:
            self.hid_dim = in_dim
        self.relu = nn.ReLU()
        self.convs = nn.ModuleList(
            GraphConv(self.hid_dim, self.hid_dim, activation=None) for _ in range(n_layers)
        )

    def forward(self, graph, x):
        """Apply the optional encoder, then every GraphConv + ReLU layer; return the final features."""
        if self.use_encoder:
            x = self.encoder(x)
        for conv in self.convs:
            x = self.relu(conv(graph, x))
        return x
class CCA_SSG(nn.Module):
    """Wrap a graph backbone and return column-standardized embeddings for two graph views.

    Args:
        in_dim: Input feature dimensionality, forwarded to the backbone.
        encoder_dim: Encoder output size, forwarded to the backbone.
        n_layers: Number of backbone layers.
        backbone: ``"GCNII"`` or ``"GCN"``.
        alpha: Forwarded to ``GCNII`` only (ignored for ``"GCN"``).
        lambda_: Forwarded to ``GCNII`` only (ignored for ``"GCN"``).
        use_encoder: Forwarded to the backbone.

    Raises:
        ValueError: If ``backbone`` is not ``"GCNII"`` or ``"GCN"``.
    """

    def __init__(self, in_dim, encoder_dim, n_layers, backbone="GCNII", alpha=None, lambda_=None, use_encoder=False):
        super().__init__()
        if backbone == "GCNII":
            self.backbone = GCNII(in_dim, encoder_dim, n_layers, alpha, lambda_, use_encoder)
        elif backbone == "GCN":
            self.backbone = GCN(in_dim, encoder_dim, n_layers, use_encoder)
        else:
            # Fail fast: otherwise ``self.backbone`` stays unset and the error only
            # surfaces later as an AttributeError on first use.
            raise ValueError(f"Unknown backbone {backbone!r}; expected 'GCNII' or 'GCN'.")

    def get_embedding(self, graph, feat):
        """Return the backbone's (unstandardized) embeddings, detached from autograd."""
        # No autograd graph is needed since the result is detached anyway.
        with torch.no_grad():
            out = self.backbone(graph, feat)
        return out.detach()

    def forward(self, graph1, feat1, graph2, feat2):
        """Encode both views with the shared backbone and standardize each result per column.

        Returns:
            Tuple ``(z1, z2)`` of standardized embeddings for the two views.
        """
        h1 = self.backbone(graph1, feat1)
        h2 = self.backbone(graph2, feat2)
        z1 = standardize(h1)
        z2 = standardize(h2)
        return z1, z2