Source code for spateo.tdr.interpolations.interpolation_gaussianprocess.gp_train
import gpytorch
import torch
from tqdm import tqdm
def gp_train(model, likelihood, train_loader, train_epochs, method, N, device):
    """Train a gpytorch GP model and its likelihood.

    Args:
        model: The GP model (an approximate GP for "SVGP", otherwise an exact GP).
        likelihood: The gpytorch likelihood paired with the model.
        train_loader: For "SVGP", a minibatch DataLoader; otherwise a dict with
            "train_x" and "train_y" tensors for full-batch training.
        train_epochs: Number of training epochs.
        method: "SVGP" trains with the variational ELBO; any other value trains
            with the exact marginal log likelihood.
        N: Total number of training points (num_data for the variational ELBO).
        device: Training runs on the GPU unless this is "cpu" or CUDA is unavailable.
    """
    if torch.cuda.is_available() and device != "cpu":
        model = model.cuda()
        likelihood = likelihood.cuda()

    model.train()
    likelihood.train()

    # Define the marginal log likelihood (the training loss)
    if method == "SVGP":
        mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=N)
    else:
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    optimizer = torch.optim.Adam(
        [
            {"params": model.parameters()},
            {"params": likelihood.parameters()},
        ],
        lr=0.01,
    )

    epochs_iter = tqdm(range(train_epochs), desc="Epoch")
    for _ in epochs_iter:
        if method == "SVGP":
            # Within each epoch, iterate over each minibatch of data.
            # (Assumes the loader yields tensors on the same device as the model.)
            minibatch_iter = tqdm(train_loader, desc="Minibatch", leave=True)
            for x_batch, y_batch in minibatch_iter:
                optimizer.zero_grad()
                output = model(x_batch)
                loss = -mll(output, y_batch)
                minibatch_iter.set_postfix(loss=loss.item())
                loss.backward()
                optimizer.step()
        else:
            # Zero gradients from the previous iteration
            optimizer.zero_grad()
            # Full-batch output from the model
            output = model(train_loader["train_x"])
            # Compute the loss and backpropagate gradients
            loss = -mll(output, train_loader["train_y"])
            loss.backward()
            optimizer.step()
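For reference, a minimal usage sketch of the exact-GP branch (any method other than "SVGP"), assuming a toy 1-D dataset and a standard gpytorch exact-GP model; ExactGPModel, the training tensors, and the import path are illustrative assumptions, not part of this module:

import math

import gpytorch
import torch

from spateo.tdr.interpolations.interpolation_gaussianprocess.gp_train import gp_train

# Toy 1-D regression data (illustrative only)
train_x = torch.linspace(0, 1, 100)
train_y = torch.sin(2 * math.pi * train_x) + 0.1 * torch.randn(100)


class ExactGPModel(gpytorch.models.ExactGP):
    """A standard exact-GP regression model with an RBF kernel."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )


likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

# Any method other than "SVGP" takes the full-batch dict path,
# so train_loader is a dict rather than a DataLoader here.
gp_train(
    model=model,
    likelihood=likelihood,
    train_loader={"train_x": train_x, "train_y": train_y},
    train_epochs=100,
    method="ExactGP",
    N=train_x.size(0),
    device="cpu",
)

For "SVGP", one would instead pass a gpytorch ApproximateGP model and a torch DataLoader yielding (x_batch, y_batch) pairs, with N set to the total number of training points.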