Source code for UQpy.scientific_machine_learning.trainers.BBBTrainer

import torch
import torch.nn as nn
import UQpy.scientific_machine_learning as sml
import logging
from beartype import beartype
from UQpy.utilities.ValidationTypes import PositiveInteger
from typing import Union


@beartype
class BBBTrainer:
    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: Union[torch.optim.lr_scheduler.LRScheduler, list, None] = None,
        loss_function: nn.Module = nn.MSELoss(),
        divergence: nn.Module = sml.GaussianKullbackLeiblerDivergence(),
    ):
        """Prepare to train a Bayesian neural network using Bayes by back propagation

        :param model: Bayesian Neural Network model to be trained
        :param optimizer: Optimization algorithm used to update ``model`` parameters
        :param scheduler: Scheduler used to adjust the learning rate of the ``optimizer``.
         Schedulers may be chained together by creating a list of schedulers
        :param loss_function: Function used to compute negative log likelihood of the data during training
        :param divergence: Divergence measured between prior and posterior distribution of Bayesian layers
         Default: ``sml.GaussianKullbackLeiblerDivergence()``
        """
        self.model = model
        self.optimizer = optimizer
        # Normalize a single scheduler to a one-element list so ``run`` can iterate uniformly.
        self.scheduler = (
            [scheduler]
            if isinstance(scheduler, torch.optim.lr_scheduler.LRScheduler)
            else scheduler
        )
        self.loss_function = loss_function
        self.divergence = divergence
        self.history: dict = {
            "train_loss": torch.tensor(torch.nan),
            "train_divergence": torch.tensor(torch.nan),
            "train_nll": torch.tensor(torch.nan),
            "test_nll": torch.tensor(torch.nan),
        }
        """Record of the loss during training and validation.
        Note if training ends early there may be ``NaN`` values, as the histories are initialized with ``NaN``.

        - ``history["train_loss"]`` contains the training history as a ``torch.Tensor``.
        - ``history["train_divergence"]`` contains the training divergence as a ``torch.Tensor``.
        - ``history["train_nll"]`` contains the training negative log likelihood loss as a ``torch.Tensor``.
        - ``history["test_nll"]`` contains the testing negative log likelihood loss as a ``torch.Tensor``.
        """
        self.logger = logging.getLogger(__name__)

    def run(
        self,
        train_data: Union[torch.utils.data.DataLoader, None] = None,
        test_data: Union[torch.utils.data.DataLoader, None] = None,
        epochs: PositiveInteger = 100,
        num_samples: PositiveInteger = 1,
        tolerance: float = 0.0,
        beta: float = 1.0,
    ):
        """Run the ``optimizer`` algorithm to learn the parameters of the ``model`` that fit ``train_data``

        If only ``test_data`` is given, the model is evaluated (without training) once per epoch.

        :param train_data: Data used to compute ``model`` loss
        :param test_data: Data used to validate the performance of the model
        :param epochs: Maximum number of epochs to run the ``optimizer`` for
        :param num_samples: Number of Monte Carlo samples to approximate the loss
        :param tolerance: Optimization terminates early if *average* training loss is below tolerance.
         Default: 0.0
        :param beta: Weighting for the divergence term in ELBO loss. Default: 1.0

        :raises RuntimeError: If neither ``train_data`` nor ``test_data`` is provided, a RuntimeError occurs.
        """
        model_name = self.model.__class__.__name__
        if train_data and not test_data:
            log_note = f"training {model_name}"
        elif not train_data and test_data:
            log_note = f"testing {model_name}"
        elif train_data and test_data:
            log_note = f"training and testing {model_name}"
        else:
            raise RuntimeError(
                "UQpy: At least one of `train_data` or `test_data` must be provided."
            )

        if train_data:
            self.history["train_loss"] = torch.full(
                [epochs], torch.nan, requires_grad=False
            )
            self.history["train_divergence"] = torch.full(
                [epochs], torch.nan, requires_grad=False
            )
            self.history["train_nll"] = torch.full(
                [epochs], torch.nan, requires_grad=False
            )
        if test_data:
            self.history["test_nll"] = torch.full(
                [epochs], torch.nan, requires_grad=False
            )

        self.logger.info("UQpy: Scientific Machine Learning: Beginning " + log_note)
        i = 0
        average_train_loss = torch.inf
        while i < epochs and average_train_loss >= tolerance:
            # Bound unconditionally so test-only runs can append to it as well.
            log_message = (
                f"UQpy: Scientific Machine Learning: Epoch {i + 1:,} / {epochs:,}"
            )
            if train_data:
                (
                    average_train_loss,
                    average_train_nll,
                    average_divergence_loss,
                ) = self._train_one_epoch(train_data, num_samples, beta)
                self._step_schedulers(average_train_loss)
                self.history["train_loss"][i] = average_train_loss
                self.history["train_nll"][i] = average_train_nll
                self.history["train_divergence"][i] = average_divergence_loss
                log_message += (
                    f" Train Loss {average_train_loss:.6e}"
                    f" Train NLL {average_train_nll:.6e}"
                    f" Train Divergence {average_divergence_loss:.6e}"
                )
            if test_data:
                average_test_nll = self._test_one_epoch(test_data)
                self.history["test_nll"][i] = average_test_nll
                log_message += f" Test NLL {average_test_nll:.6e}"
            self.logger.info(log_message)
            i += 1
        self.logger.info("UQpy: Scientific Machine Learning: Completed " + log_note)

    def _train_one_epoch(self, train_data, num_samples, beta):
        """Run one optimization pass over ``train_data``.

        :return: Batch-averaged ``(loss, nll, divergence)`` as Python floats.
        """
        self.model.train(True)
        self.model.sample(True)
        total_train_loss = 0.0
        total_nll_loss = 0.0
        total_divergence_loss = 0.0
        num_batches = 0
        for *x, y in train_data:
            # Monte Carlo estimate of the NLL over ``num_samples`` forward passes.
            # ``torch.stack`` keeps the samples on the loss's own device/dtype.
            nll_samples = torch.stack(
                [self.loss_function(self.model(*x), y) for _ in range(num_samples)]
            )
            divergence_loss = self.divergence(self.model)
            mean_nll = torch.mean(nll_samples)
            train_loss = mean_nll + beta * divergence_loss
            train_loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            # Convert to Python floats so no per-batch autograd graph outlives the batch.
            total_train_loss += train_loss.item()
            total_nll_loss += mean_nll.item()
            total_divergence_loss += float(divergence_loss)
            num_batches += 1
        self.model.train(False)
        return (
            total_train_loss / num_batches,
            total_nll_loss / num_batches,
            total_divergence_loss / num_batches,
        )

    def _step_schedulers(self, metric):
        """Advance every learning-rate scheduler once; ``ReduceLROnPlateau`` receives ``metric``."""
        if not self.scheduler:
            return
        for s in self.scheduler:
            if isinstance(s, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # Use the epoch-average loss rather than the noisy last-batch loss.
                s.step(metric)
            else:
                s.step()

    def _test_one_epoch(self, test_data):
        """Evaluate ``loss_function`` on ``test_data`` without gradients; return the batch-averaged NLL."""
        # Ensure eval mode even for test-only runs, where training never switched it off.
        # NOTE(review): ``sample`` mode is left as-is (sampling stays on after training),
        # so test predictions may be stochastic — confirm this is intended.
        self.model.train(False)
        total_test_nll = 0.0
        num_batches = 0
        with torch.no_grad():
            for *x, y in test_data:
                test_prediction = self.model(*x)
                total_test_nll += self.loss_function(test_prediction, y).item()
                num_batches += 1
        return total_test_nll / num_batches