# import torch.nn.functional as F
from dataclasses import dataclass, field
import torch
import torch.nn as nn
# from dgl.nn import GCN2Conv, GraphConv
@dataclass
class Args:
    # `gpu` is referenced in __post_init__ but was missing from the field list;
    # -1 (meaning "use CPU") is assumed as the default.
    gpu: int = -1
    device: str = field(init=False)
    use_encoder: bool = False

    def __post_init__(self):
        if self.gpu != -1 and torch.cuda.is_available():
            self.device = "cuda:{}".format(self.gpu)
        else:
            self.device = "cpu"
# fix the div-zero standard deviation bug, Shuchen Luo (20220217)
def standardize(x, eps=1e-12):
    # Column-wise standardization; clamp the std at `eps` to avoid dividing by zero.
    return (x - x.mean(0)) / x.std(0).clamp(min=eps)
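
# Illustration (a sketch, not part of the original module): `standardize` makes each
# feature column zero-mean / unit-variance, with `eps` guarding degenerate columns.
#
#     >>> x = torch.randn(8, 4)
#     >>> z = standardize(x)
#     >>> bool(torch.allclose(z.mean(0), torch.zeros(4), atol=1e-5))
#     True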

class Encoder(nn.Module):
    def __init__(self, in_dim: int, encoder_dim: int):
        super().__init__()
        self.layer = nn.Linear(in_dim, encoder_dim, bias=True)
        # `self.relu` is used in forward() but was never defined; register it here.
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.layer(x))
# GCN2Conv(in_feats, layer, alpha=0.1, lambda_=1, project_initial_features=True, allow_zero_in_degree=False, bias=True, activation=None)

class GCNII(nn.Module):
    def __init__(self, in_dim: int, encoder_dim: int, n_layers: int, alpha=None, lambda_=None, use_encoder=False):
        super().__init__()
        from dgl.nn import GCN2Conv

        self.n_layers = n_layers
        self.use_encoder = use_encoder
        # Per-layer GCNII hyperparameters; default to alpha=0.1 and lambda=1.0 for every layer.
        if alpha is None:
            self.alpha = [0.1] * self.n_layers
        else:
            self.alpha = alpha
        if lambda_ is None:
            self.lambda_ = [1.0] * self.n_layers
        else:
            self.lambda_ = lambda_
        if self.use_encoder:
            self.encoder = Encoder(in_dim, encoder_dim)
            self.hid_dim = encoder_dim
        else:
            self.hid_dim = in_dim
        self.relu = nn.ReLU()  # used in forward(); was missing from the original
        self.convs = nn.ModuleList()
        for i in range(n_layers):
            self.convs.append(
                GCN2Conv(self.hid_dim, i + 1, alpha=self.alpha[i], lambda_=self.lambda_[i], activation=None)
            )

    def forward(self, graph, x):
        if self.use_encoder:
            x = self.encoder(x)
        # print('GCNII forward: after encoder', torch.any(torch.isnan(x)))
        feat0 = x  # initial residual fed to every GCN2Conv layer
        for i in range(self.n_layers):
            x = self.relu(self.convs[i](graph, x, feat0))
            # print('GCNII layer', i + 1, 'is_nan', torch.any(torch.isnan(x)))
        return x
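
# Usage sketch (illustrative shapes only; `graph` is assumed to be a DGL graph with
# self-loops and `feat` a (num_nodes, 16) feature tensor):
#
#     gcnii = GCNII(in_dim=16, encoder_dim=8, n_layers=4, use_encoder=True)
#     h = gcnii(graph, feat)   # -> (num_nodes, 8)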

class GCN(nn.Module):
    def __init__(self, in_dim: int, encoder_dim: int, n_layers: int, use_encoder=False):
        super().__init__()
        from dgl.nn import GraphConv

        self.n_layers = n_layers
        self.use_encoder = use_encoder
        if self.use_encoder:
            self.encoder = Encoder(in_dim, encoder_dim)
            self.hid_dim = encoder_dim
        else:
            self.hid_dim = in_dim
        self.relu = nn.ReLU()  # used in forward(); was missing from the original
        self.convs = nn.ModuleList()
        for i in range(n_layers):
            self.convs.append(GraphConv(self.hid_dim, self.hid_dim, activation=None))

    def forward(self, graph, x):
        if self.use_encoder:
            x = self.encoder(x)
        # print('GCN forward: after encoder', torch.any(torch.isnan(x)))
        for i in range(self.n_layers):
            x = self.relu(self.convs[i](graph, x))
            # print('GCN layer', i + 1, 'is_nan', torch.any(torch.isnan(x)))
        return x
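
# Usage sketch (assumption: `graph` already contains self-loops, since GraphConv with
# its default normalization rejects zero-in-degree nodes):
#
#     gcn = GCN(in_dim=16, encoder_dim=8, n_layers=2, use_encoder=True)
#     h = gcn(graph, feat)     # -> (num_nodes, 8)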

class CCA_SSG(nn.Module):
    def __init__(self, in_dim, encoder_dim, n_layers, backbone="GCNII", alpha=None, lambda_=None, use_encoder=False):
        super().__init__()
        if backbone == "GCNII":
            self.backbone = GCNII(in_dim, encoder_dim, n_layers, alpha, lambda_, use_encoder)
        elif backbone == "GCN":
            self.backbone = GCN(in_dim, encoder_dim, n_layers, use_encoder)
        else:
            raise ValueError("Unknown backbone: {}".format(backbone))

    def get_embedding(self, graph, feat):
        out = self.backbone(graph, feat)
        return out.detach()

    def forward(self, graph1, feat1, graph2, feat2):
        h1 = self.backbone(graph1, feat1)
        h2 = self.backbone(graph2, feat2)
        # print('CCASSG forward: h1 is', torch.any(torch.isnan(h1)))
        # print('CCASSG forward: h2 is', torch.any(torch.isnan(h2)))
        z1 = standardize(h1)
        z2 = standardize(h2)
        # print('h1.std', h1.std(0))
        # print('h1-h1.mean(0)', h1 - h1.mean(0))
        # print('CCASSG forward: z1 is', torch.any(torch.isnan(z1)))
        # print('CCASSG forward: z2 is', torch.any(torch.isnan(z2)))
        return z1, z2
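

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: build a tiny self-looped DGL graph, run
    # CCA_SSG with the GCN backbone on two identical "views" (real training would
    # pass two augmented views of the graph), and print the embedding shapes.
    import dgl

    src = torch.tensor([0, 1, 2, 3])
    dst = torch.tensor([1, 2, 3, 0])
    graph = dgl.add_self_loop(dgl.graph((src, dst), num_nodes=4))
    feat = torch.randn(4, 16)

    model = CCA_SSG(in_dim=16, encoder_dim=8, n_layers=2, backbone="GCN", use_encoder=True)
    z1, z2 = model(graph, feat, graph, feat)
    print(z1.shape, z2.shape)  # expected: torch.Size([4, 8]) torch.Size([4, 8])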