# Source code for spateo.tdr.interpolations.interpolation_gaussianprocess.gp_train
import gpytorch
import torch
from tqdm import tqdm
def _to_cuda_if_tensor(value, use_cuda):
    """Return ``value`` on the current CUDA device when ``use_cuda`` is set.

    Non-tensor values, and every value when ``use_cuda`` is false, are
    returned unchanged.
    """
    if use_cuda and torch.is_tensor(value):
        return value.cuda()
    return value


def gp_train(model, likelihood, train_loader, train_epochs, method, N, device):
    """Optimise a GPyTorch model and its likelihood with Adam (lr=0.01).

    Args:
        model: Model to train. It is called on the training inputs and must
            provide ``cuda()``, ``train()`` and ``parameters()``.
        likelihood: Likelihood trained jointly with ``model``; same
            requirements as ``model``.
        train_loader: When ``method == "SVGP"``, an iterable that yields
            ``(x_batch, y_batch)`` pairs and is re-iterated every epoch.
            Otherwise an object indexable with ``"train_x"`` and
            ``"train_y"``.
        train_epochs: Number of iterations of the outer training loop.
        method: ``"SVGP"`` selects ``gpytorch.mlls.VariationalELBO`` with one
            optimiser step per mini-batch. Any other value selects
            ``gpytorch.mlls.ExactMarginalLogLikelihood`` with one full-batch
            step per epoch.
        N: Forwarded as ``num_data`` to ``VariationalELBO``; unused for other
            methods. Presumably the total number of training points
            (TODO confirm against the caller).
        device: If this is not ``"cpu"`` and CUDA is available, the model,
            the likelihood and the training tensors are moved to the current
            CUDA device. The value is not otherwise interpreted.

    Returns:
        None. The parameters of ``model`` and ``likelihood`` are updated by
        the optimiser, and both objects are left in training mode.
    """
    use_cuda = torch.cuda.is_available() and device != "cpu"
    if use_cuda:
        model = model.cuda()
        likelihood = likelihood.cuda()

    model.train()
    likelihood.train()

    # Define the marginal log likelihood; its negation is the training loss.
    is_svgp = method == "SVGP"
    if is_svgp:
        mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=N)
    else:
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    optimizer = torch.optim.Adam(
        [
            {"params": model.parameters()},
            {"params": likelihood.parameters()},
        ],
        lr=0.01,
    )

    for _ in tqdm(range(train_epochs), desc="Epoch"):
        if is_svgp:
            # Within each epoch, go over every mini-batch of data.
            minibatch_iter = tqdm(train_loader, desc="Minibatch", leave=True)
            for x_batch, y_batch in minibatch_iter:
                # Keep the batch on the same device as the model; this is a
                # no-op when the loader already yields CUDA tensors.
                x_batch = _to_cuda_if_tensor(x_batch, use_cuda)
                y_batch = _to_cuda_if_tensor(y_batch, use_cuda)

                optimizer.zero_grad()
                output = model(x_batch)
                loss = -mll(output, y_batch)
                minibatch_iter.set_postfix(loss=loss.item())
                loss.backward()
                optimizer.step()
        else:
            # Full-batch step: the training tensors must live on the same
            # device as the model, otherwise the forward pass or the loss
            # fails with a device mismatch.
            train_x = _to_cuda_if_tensor(train_loader["train_x"], use_cuda)
            train_y = _to_cuda_if_tensor(train_loader["train_y"], use_cuda)

            # Zero gradients from the previous iteration.
            optimizer.zero_grad()
            output = model(train_x)
            # Compute the loss and backpropagate gradients.
            loss = -mll(output, train_y)
            loss.backward()
            optimizer.step()