Source code for spateo.tdr.interpolations.interpolation_deeplearn.interpolation_nn

import numpy as np
import torch
import torch.nn as nn

###########################################
# THE HIGH-TO-LOW-DIMENSIONAL TRANSFORMER #
###########################################


class A(nn.Module):
    def __init__(
        self,
        network_dim,
        data_dim,
        hidden_features=256,
        hidden_layers=1,
        activation_function=torch.nn.functional.leaky_relu,
    ):
        super(A, self).__init__()  # Call to the super-class is necessary
        self.f = activation_function
        self.name = "model/A"

        self.layer1 = nn.Linear(data_dim, hidden_features)
        self.net = []
        for i in range(hidden_layers):
            self.net.append(nn.Linear(hidden_features, hidden_features))
        self.hidden_layers = nn.Sequential(*self.net)
        # Use hidden_features here (rather than a hard-coded 256) so the output layer
        # stays consistent when a different hidden width is requested.
        self.outlayer = nn.Linear(hidden_features, network_dim)
        # torch.nn.init.normal_(self.layer1.weight, std=.02)
        # torch.nn.init.normal_(self.layer2.weight, std=.02)
        # torch.nn.init.normal_(self.layer3.weight, std=.02)
    def forward(self, inp):
        out = self.f(self.layer1(inp), negative_slope=0.2)
        out = self.f(self.hidden_layers(out), negative_slope=0.2)
        out = self.outlayer(out)
        return out
###########################################
# THE LOW-TO-HIGH-DIMENSIONAL TRANSFORMER #
###########################################
class B(nn.Module):
    def __init__(
        self,
        network_dim,
        data_dim,
        hidden_features=256,
        hidden_layers=3,
        activation_function=torch.nn.functional.leaky_relu,
    ):
        super(B, self).__init__()  # Call to the super-class is necessary
        self.f = activation_function
        self.name = "model/B"

        self.layer1 = nn.Linear(network_dim, hidden_features)
        self.net = []
        for i in range(hidden_layers):
            self.net.append(nn.Linear(hidden_features, hidden_features))
        self.hidden_layers = nn.Sequential(*self.net)
        self.outlayer = nn.Linear(hidden_features, data_dim)
        # torch.nn.init.normal_(self.layer1.weight, std=.02)
        # torch.nn.init.normal_(self.layer2.weight, std=.02)
        # torch.nn.init.normal_(self.layer3.weight, std=.02)
    def forward(self, inp):
        out = self.f(self.layer1(inp), negative_slope=0.2)
        out = self.f(self.hidden_layers(out), negative_slope=0.2)
        out = self.outlayer(out)
        return out
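
# Illustrative sketch (not part of the original module): how A and B are meant to pair up, with A
# mapping data-space vectors down to the network dimension and B mapping low-dimensional vectors
# back up. The dimensions used here (data_dim=50, network_dim=3) are arbitrary example values.
def _example_A_B_roundtrip():
    data_dim, network_dim = 50, 3
    encoder = A(network_dim=network_dim, data_dim=data_dim)
    decoder = B(network_dim=network_dim, data_dim=data_dim)

    x = torch.randn(8, data_dim)  # a batch of 8 high-dimensional points
    x_low = encoder(x)  # shape: (8, network_dim)
    x_rec = decoder(x_low)  # shape: (8, data_dim)
    return x_low.shape, x_rec.shape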
class SineLayer(nn.Module):
    # From https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb#scrollTo=uTQfrFvah3Zc
    # See the SIREN paper, Sec. 3.2 (final paragraph) and supplement Sec. 1.5 for a discussion of omega_0.
    # If is_first=True, omega_0 is a frequency factor that simply multiplies the activations before the
    # nonlinearity. Different signals may require a different omega_0 in the first layer - this is a
    # hyperparameter.
    # If is_first=False, the weights are divided by omega_0 so as to keep the magnitude of the
    # activations constant, while boosting gradients to the weight matrix (see supplement Sec. 1.5).
    """As discussed above, we aim to provide each sine nonlinearity with activations that are standard
    normal distributed, except in the case of the first layer, where we introduce a factor ω_0 that
    increases the spatial frequency of the first layer to better match the frequency spectrum of the
    signal. However, training of a SIREN can be accelerated by leveraging a factor ω_0 in all layers,
    factorizing the weight matrix W as W = Ŵ * ω_0. This keeps the distribution of activations constant,
    but boosts gradients to the weight matrix Ŵ by the factor ω_0 while leaving gradients w.r.t. the
    input of the sine neuron unchanged.
    """

    def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30.0):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features

        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()
    def init_weights(self):
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features)
            else:
                self.linear.weight.uniform_(
                    -np.sqrt(6 / self.in_features) / self.omega_0,
                    np.sqrt(6 / self.in_features) / self.omega_0,
                )
    def forward(self, input):
        return torch.sin(self.omega_0 * self.linear(input))
    def forward_with_intermediate(self, input):
        # For visualization of activation distributions
        intermediate = self.omega_0 * self.linear(input)
        return torch.sin(intermediate), intermediate
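
# Illustrative sketch (not part of the original module): a standalone SineLayer computes
# sin(omega_0 * (W x + b)); forward_with_intermediate additionally exposes the pre-sine activation,
# which is what the comments above suggest inspecting when tuning omega_0.
def _example_sine_layer():
    layer = SineLayer(in_features=3, out_features=16, is_first=True, omega_0=30.0)
    coords = torch.rand(8, 3)  # e.g. normalized spatial coordinates
    out = layer(coords)  # values lie in [-1, 1] because of the sine nonlinearity
    _, pre_activation = layer.forward_with_intermediate(coords)
    return out.shape, pre_activation.shape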
class h(nn.Module):
    def __init__(
        self,
        input_network_dim,
        output_network_dim,
        hidden_features=256,
        hidden_layers=3,
        sirens=False,
        first_omega_0=30.0,
        hidden_omega_0=30.0,
    ):
        self.sirens, self.first_omega_0, self.hidden_omega_0 = (
            sirens,
            first_omega_0,
            hidden_omega_0,
        )
        super(h, self).__init__()  # Call to the super-class is necessary
        self.f = torch.sin if self.sirens else torch.nn.functional.leaky_relu
        self.name = "model/h"

        self.layer1 = nn.Linear(input_network_dim, hidden_features)
        if sirens:
            torch.nn.init.uniform_(self.layer1.weight, -1 / input_network_dim, 1 / input_network_dim)

        self.net = []
        for i in range(hidden_layers):
            if sirens:
                self.net.append(
                    SineLayer(
                        hidden_features,
                        hidden_features,
                        is_first=False,
                        omega_0=self.hidden_omega_0,
                    )
                )
            else:
                self.net.append(nn.Linear(hidden_features, hidden_features))
        self.hidden_layers = nn.Sequential(*self.net)

        self.outlayer = nn.Linear(hidden_features, output_network_dim)
        if sirens:
            torch.nn.init.uniform_(
                self.outlayer.weight,
                -np.sqrt(6 / hidden_features) / self.hidden_omega_0,
                np.sqrt(6 / hidden_features) / self.hidden_omega_0,
            )
    def forward(self, inp):
        out = (
            self.f(self.first_omega_0 * self.layer1(inp))
            if self.sirens
            else self.f(self.layer1(inp), negative_slope=0.2)
        )
        out = self.hidden_layers(out) if self.sirens else self.f(self.hidden_layers(out), negative_slope=0.2)
        out = self.outlayer(out)
        return out
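
# Illustrative sketch (not part of the original module): constructing the core network h in both
# modes. With sirens=False it is a plain leaky-ReLU MLP; with sirens=True the hidden layers are
# SineLayers and the first and last linear layers use the SIREN uniform initialization.
def _example_h_modes():
    mlp = h(input_network_dim=3, output_network_dim=3, sirens=False)
    siren = h(input_network_dim=3, output_network_dim=3, sirens=True)
    x_low = torch.rand(8, 3)
    return mlp(x_low).shape, siren(x_low).shape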
class MainFlow(torch.nn.Module):
    def __init__(self, h, A=None, B=None, enforce_positivity=False):
        super(MainFlow, self).__init__()
        self.A = A
        self.B = B
        self.h = h
        self.enforce_positivity = enforce_positivity
    def forward(self, t, x, freeze=None):
        x_low = self.A(x) if self.A is not None else x
        e_low = self.h.forward(x_low)
        e_hat = self.B(e_low) if self.B is not None else e_low

        if freeze is not None:
            for i in freeze:
                if len(e_hat.shape) == 1:
                    e_hat[i] = 0
                elif len(e_hat.shape) == 2:
                    e_hat[:, i] = 0
                else:
                    raise ValueError("Invalid output data shape. Please debug.")

        # Forcing x to remain positive: set the velocity to 0 if x <= 0 and v < 0.
        if self.enforce_positivity:
            e_hat *= ~(e_hat < 0)  # ~((x <= 0) * (v_hat < 0))

        return e_hat
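
# Illustrative sketch (not part of the original module): the full high-dim -> low-dim -> high-dim
# pipeline wired through MainFlow. The `t` argument is not used by this forward pass but is kept in
# the signature (e.g. for ODE-solver-style interfaces); `freeze` zeroes selected output dimensions.
def _example_main_flow():
    data_dim, network_dim = 50, 3
    flow = MainFlow(
        h=h(input_network_dim=network_dim, output_network_dim=network_dim, sirens=True),
        A=A(network_dim=network_dim, data_dim=data_dim),
        B=B(network_dim=network_dim, data_dim=data_dim),
        enforce_positivity=False,
    )
    x = torch.randn(8, data_dim)
    e_hat = flow(t=None, x=x, freeze=[0])  # column 0 of the prediction is forced to zero
    return e_hat.shape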