Bayes By Backpropagation Trainer (BBBTrainer)

Class to train a Bayesian neural network using the Bayes by Backpropagation [45] method and a PyTorch optimization algorithm.

The BBBTrainer class is imported using the following command:

>>> from UQpy.scientific_machine_learning.trainers.BBBTrainer import BBBTrainer

Methods

class BBBTrainer(model, optimizer, scheduler=None, loss_function=MSELoss(), divergence=GaussianKullbackLeiblerDivergence())

Prepare to train a Bayesian neural network using Bayes by Backpropagation.

Parameters:
  • model (Module) – Bayesian neural network model to be trained

  • optimizer (Optimizer) – Optimization algorithm used to update model parameters

  • scheduler (Union[LRScheduler, list, None]) – Scheduler used to adjust the learning rate of the optimizer. Schedulers may be chained together by passing a list of schedulers. Default: None

  • loss_function (Module) – Function used to compute the negative log-likelihood of the data during training. Default: MSELoss()

  • divergence (Module) – Divergence between the prior and posterior distributions of the Bayesian layers. Default: sml.GaussianKullbackLeiblerDivergence()
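As its name suggests, the default divergence targets Gaussian distributions, for which the Kullback-Leibler divergence has a closed form. The sketch below is an illustration only, not UQpy's implementation: it assumes each weight has an independent Gaussian posterior N(mu_q, sigma_q^2), that all weights share one Gaussian prior N(mu_p, sigma_p^2), and that the divergence is KL(posterior || prior) summed over weights. The function name gaussian_kl is hypothetical.

```python
import math
from typing import Sequence


def gaussian_kl(
    mu_q: Sequence[float],
    sigma_q: Sequence[float],
    mu_p: float = 0.0,
    sigma_p: float = 1.0,
) -> float:
    """Hypothetical helper: KL(q || p) summed over independent weights.

    q is a factorized Gaussian posterior N(mu_q[i], sigma_q[i]**2);
    p is a shared Gaussian prior N(mu_p, sigma_p**2) (assumption).
    """
    total = 0.0
    for m, s in zip(mu_q, sigma_q):
        # Closed-form KL between two univariate Gaussians
        total += (
            math.log(sigma_p / s)
            + (s**2 + (m - mu_p) ** 2) / (2.0 * sigma_p**2)
            - 0.5
        )
    return total


identical = gaussian_kl([0.0, 0.0], [1.0, 1.0])  # posterior equals prior: 0.0
shifted = gaussian_kl([1.0], [1.0])  # mean moved by one prior std: 0.5
```

The divergence is zero only when the posterior matches the prior, and it grows as the posterior mean drifts or its spread departs from the prior's, which is what lets it act as a regularizer on the Bayesian layers.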

run(train_data=None, test_data=None, epochs=100, num_samples=1, tolerance=0.0, beta=1.0)

Run the optimization algorithm to learn the model parameters that fit train_data.

Parameters:
  • train_data (Optional[DataLoader]) – Data used to compute model loss

  • test_data (Optional[DataLoader]) – Data used to validate the performance of the model

  • epochs (int) – Maximum number of epochs to run the optimizer for. Default: 100

  • num_samples (int) – Number of Monte Carlo samples used to approximate the loss. Default: 1

  • tolerance (float) – Optimization terminates early if the average training loss falls below tolerance. Default: 0.0

  • beta (float) – Weight applied to the divergence term in the ELBO loss. Default: 1.0
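The roles of num_samples and beta in the objective can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the BBBTrainer source: a hypothetical one-weight model y = w * x with Gaussian posterior N(mu, sigma^2) and a standard normal prior, mean squared error standing in for the negative log-likelihood (mirroring the default loss_function), weights drawn with the reparameterization w = mu + sigma * eps, and the loss taken as the Monte Carlo average of the NLL plus beta times the divergence. The helpers elbo_loss, mse and kl_to_standard_normal are hypothetical.

```python
import math
import random
from typing import List, Optional, Tuple


def mse(pred: List[float], target: List[float]) -> float:
    """Mean squared error, standing in for the default MSELoss."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def kl_to_standard_normal(mu: float, sigma: float) -> float:
    """Closed-form KL(N(mu, sigma^2) || N(0, 1)) for a single weight."""
    return -math.log(sigma) + (sigma**2 + mu**2) / 2.0 - 0.5


def elbo_loss(
    mu: float,
    sigma: float,
    x: List[float],
    y: List[float],
    num_samples: int = 1,
    beta: float = 1.0,
    rng: Optional[random.Random] = None,
) -> Tuple[float, float, float]:
    """Return (loss, nll, divergence) for a hypothetical model y = w * x."""
    if rng is None:
        rng = random.Random(0)  # fixed seed for reproducibility
    nll = 0.0
    for _ in range(num_samples):
        # Reparameterization: sample eps ~ N(0, 1), then shift and scale it
        w = mu + sigma * rng.gauss(0.0, 1.0)
        nll += mse([w * xi for xi in x], y)
    nll /= num_samples  # Monte Carlo average of the data term
    # Closed form here, so the divergence does not depend on the samples
    divergence = kl_to_standard_normal(mu, sigma)
    return nll + beta * divergence, nll, divergence


x = [0.0, 1.0, 2.0]
y = [0.0, 2.0, 4.0]  # toy data generated with w = 2
loss, nll, kl = elbo_loss(mu=2.0, sigma=0.1, x=x, y=y, num_samples=8, beta=0.5)
```

Setting beta=0 removes the divergence term, so the objective reduces to the averaged data fit; increasing num_samples lowers the variance of the NLL estimate at the cost of extra forward passes.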

Raises:

RuntimeError – Raised if neither train_data nor test_data is provided.

Attributes

BBBTrainer.history: dict

Record of the loss during training and validation. Note: if training ends early, some entries may be NaN, because the histories are initialized with NaN.

  • history["train_loss"] contains the training loss history as a torch.Tensor.

  • history["train_divergence"] contains the training divergence as a torch.Tensor.

  • history["train_nll"] contains the training negative log-likelihood as a torch.Tensor.

  • history["test_nll"] contains the testing negative log-likelihood as a torch.Tensor.
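The NaN padding in the history can be illustrated with a small pure-Python sketch. It assumes, without reproducing the UQpy source, that the history is pre-filled with NaN for every epoch and that the loop stops as soon as the loss drops below tolerance. The function run_with_tolerance and the halving loss curve are hypothetical.

```python
import math
from typing import Callable, Dict, List


def run_with_tolerance(
    step: Callable[[int], float], epochs: int = 100, tolerance: float = 0.0
) -> Dict[str, List[float]]:
    """Hypothetical loop: record per-epoch loss, stop early below tolerance."""
    # Pre-fill with NaN so epochs that never run remain visibly empty
    history = {"train_loss": [math.nan] * epochs}
    for epoch in range(epochs):
        loss = step(epoch)
        history["train_loss"][epoch] = loss
        if loss < tolerance:
            break  # the remaining entries stay NaN
    return history


# Hypothetical loss curve that halves every epoch: 1.0, 0.5, 0.25, ...
hist = run_with_tolerance(lambda e: 1.0 / 2**e, epochs=10, tolerance=0.1)
completed = [v for v in hist["train_loss"] if not math.isnan(v)]
```

Here training stops after five epochs, leaving five NaN entries. With the actual torch.Tensor histories, the same filtering can be done with a boolean mask such as history["train_loss"][~torch.isnan(history["train_loss"])].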

Examples