Bayes By Backpropagation Trainer (BBBTrainer)
Class to train a neural network using the Bayes by Backpropagation [45] method and a PyTorch optimization algorithm.
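For orientation, the objective minimized by Bayes by backpropagation is commonly written as a negative evidence lower bound (ELBO). The form below is a sketch under stated assumptions: a variational posterior $q(w \mid \theta)$ over the network weights, a prior $p(w)$, $S$ standing for the `num_samples` argument of `run` (Monte Carlo draws used for the data term), and $\beta$ standing for the `beta` argument scaling the divergence. Exactly how `BBBTrainer` averages over batches and samples is not stated on this page.

$$
\mathcal{L}(\theta) \approx \frac{1}{S}\sum_{s=1}^{S} \Big[-\log p\big(\mathcal{D} \mid w^{(s)}\big)\Big] + \beta\,\mathrm{KL}\big[q(w \mid \theta)\,\|\,p(w)\big], \qquad w^{(s)} \sim q(w \mid \theta)
$$

With $\beta = 1$ this is the standard negative ELBO; smaller values of $\beta$ down-weight the pull of the posterior toward the prior.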
The BBBTrainer class is imported using the following command:
>>> from UQpy.scientific_machine_learning.trainers.BBBTrainer import BBBTrainer
Methods
- class BBBTrainer(model, optimizer, scheduler=None, loss_function=MSELoss(), divergence=GaussianKullbackLeiblerDivergence())
Prepare to train a Bayesian neural network using Bayes by backpropagation.
- Parameters:
model (Module) – Bayesian neural network model to be trained
optimizer (Optimizer) – Optimization algorithm used to update the model parameters
scheduler (Union[LRScheduler, list, None]) – Scheduler used to adjust the learning rate of the optimizer. Schedulers may be chained together by passing a list of schedulers. Default: None
loss_function (Module) – Function used to compute the negative log likelihood of the data during training. Default: MSELoss()
divergence (Module) – Divergence measured between the prior and posterior distributions of the Bayesian layers. Default: GaussianKullbackLeiblerDivergence()
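As its name suggests, the default divergence compares Gaussian posteriors with Gaussian priors. As a rough sketch, assuming independent (mean-field) Gaussian parameters and a $N(\mu_p, \sigma_p^2)$ prior, the closed-form Kullback-Leibler divergence summed over all parameters can be written in plain PyTorch as below. The helper `gaussian_kl` and its default prior values are hypothetical; UQpy's class may reduce over parameters or parameterize the prior differently.

```python
import math

import torch


def gaussian_kl(posterior_mu, posterior_sigma, prior_mu=0.0, prior_sigma=1.0):
    """Hypothetical helper: KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)) summed over parameters."""
    # Per-parameter closed form:
    # log(sigma_p / sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2) / (2 sigma_p^2) - 1/2
    kl = (
        math.log(prior_sigma)
        - torch.log(posterior_sigma)
        + (posterior_sigma**2 + (posterior_mu - prior_mu) ** 2) / (2 * prior_sigma**2)
        - 0.5
    )
    # Parameters are assumed independent, so the total divergence is a sum
    return kl.sum()


# One parameter identical to the N(0, 1) prior, one narrow posterior centered at 0.5
mu = torch.tensor([0.0, 0.5])
sigma = torch.tensor([1.0, 0.1])
print(gaussian_kl(mu, sigma))
```

A parameter whose posterior equals the prior contributes zero, so the divergence only grows as the posterior moves away from (or becomes narrower than) the prior.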
- run(train_data=None, test_data=None, epochs=100, num_samples=1, tolerance=0.0, beta=1.0)
Run the optimizer algorithm to learn the parameters of the model that fit train_data.
- Parameters:
train_data (Optional[DataLoader]) – Data used to compute the model loss. Default: None
test_data (Optional[DataLoader]) – Data used to validate the performance of the model. Default: None
epochs (int) – Maximum number of epochs to run the optimizer for. Default: 100
num_samples (int) – Number of Monte Carlo samples used to approximate the loss. Default: 1
tolerance (float) – Optimization terminates early if the average training loss is below tolerance. Default: 0.0
beta (float) – Weighting of the divergence term in the ELBO loss. Default: 1.0
- Raises:
RuntimeError – If neither train_data nor test_data is provided.
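To show how `num_samples`, `beta`, and `tolerance` can interact in one training loop, here is a minimal plain-PyTorch sketch. It is not UQpy's implementation: `ToyBayesianLinear`, `train_bbb`, the softplus parameterization of the standard deviation, and the choice to average the data term over `num_samples` weight draws while adding the closed-form divergence once per batch are all assumptions made for illustration.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class ToyBayesianLinear(nn.Module):
    """Hypothetical mean-field Gaussian linear layer (illustration only, not UQpy's layer)."""

    def __init__(self, in_features, out_features, prior_sigma=1.0):
        super().__init__()
        self.weight_mu = nn.Parameter(torch.zeros(out_features, in_features))
        self.weight_rho = nn.Parameter(torch.full((out_features, in_features), -3.0))
        self.bias_mu = nn.Parameter(torch.zeros(out_features))
        self.bias_rho = nn.Parameter(torch.full((out_features,), -3.0))
        self.prior_sigma = prior_sigma

    def forward(self, x):
        # Reparameterization trick: w = mu + softplus(rho) * eps with eps ~ N(0, 1)
        weight_sigma = nn.functional.softplus(self.weight_rho)
        bias_sigma = nn.functional.softplus(self.bias_rho)
        weight = self.weight_mu + weight_sigma * torch.randn_like(weight_sigma)
        bias = self.bias_mu + bias_sigma * torch.randn_like(bias_sigma)
        return nn.functional.linear(x, weight, bias)

    def kl(self):
        # Closed-form KL between the Gaussian posterior and a N(0, prior_sigma^2) prior
        total = torch.tensor(0.0)
        for mu, rho in ((self.weight_mu, self.weight_rho), (self.bias_mu, self.bias_rho)):
            q = torch.distributions.Normal(mu, nn.functional.softplus(rho))
            p = torch.distributions.Normal(torch.zeros_like(mu), torch.full_like(mu, self.prior_sigma))
            total = total + torch.distributions.kl_divergence(q, p).sum()
        return total


def train_bbb(model, optimizer, train_data, epochs=100, num_samples=1, tolerance=0.0, beta=1.0):
    # MSE equals a fixed-variance Gaussian NLL up to scale and additive constants
    loss_function = nn.MSELoss()
    history = {"train_loss": torch.full((epochs,), float("nan"))}
    for epoch in range(epochs):
        epoch_loss = 0.0
        for x, y in train_data:
            optimizer.zero_grad()
            # Monte Carlo estimate of the data term: each forward pass draws new weights
            nll = sum(loss_function(model(x), y) for _ in range(num_samples)) / num_samples
            # Divergence summed over every Bayesian layer, weighted by beta
            divergence = sum(m.kl() for m in model.modules() if isinstance(m, ToyBayesianLinear))
            loss = nll + beta * divergence
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        history["train_loss"][epoch] = epoch_loss / len(train_data)
        if history["train_loss"][epoch] < tolerance:
            break  # early stop: later entries keep their NaN initial value
    return history


# Usage on synthetic data y = 2x + noise
torch.manual_seed(0)
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 2.0 * x + 0.1 * torch.randn_like(x)
train_data = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
model = nn.Sequential(ToyBayesianLinear(1, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
history = train_bbb(model, optimizer, train_data, epochs=200, num_samples=3, beta=1e-3)
print(history["train_loss"][-1])
```

Raising `num_samples` lowers the variance of the data-term estimate at the cost of extra forward passes, while a small `beta` keeps the divergence from dominating when the dataset is small relative to the number of weights.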
Attributes
- BBBTrainer.history: dict
Record of the loss during training and validation. Note that if training ends early, there may be NaN values, as the histories are initialized with NaN.
history["train_loss"] contains the training loss history as a torch.Tensor.
history["train_divergence"] contains the training divergence history as a torch.Tensor.
history["train_nll"] contains the training negative log likelihood loss as a torch.Tensor.
history["test_nll"] contains the testing negative log likelihood loss as a torch.Tensor.
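Because the histories are initialized with NaN, entries for epochs that never ran stay NaN after an early stop. A small sketch for keeping only the recorded values, assuming each history entry is a 1-D `torch.Tensor` as described above; the dictionary is filled by hand here rather than produced by a real `BBBTrainer.run`, and `completed_epochs` is a hypothetical helper.

```python
import torch

# Hand-made history mimicking a run that stopped after 3 of 6 epochs:
# the unfilled epochs keep their NaN initial value
epochs = 6
history = {"train_loss": torch.full((epochs,), float("nan"))}
history["train_loss"][:3] = torch.tensor([0.9, 0.4, 0.2])


def completed_epochs(loss_history):
    """Return only the entries that were actually written (drop NaN placeholders)."""
    return loss_history[~torch.isnan(loss_history)]


print(completed_epochs(history["train_loss"]))
```

Filtering with `torch.isnan` rather than slicing by an epoch count also works for keys such as `test_nll`, which may hold NaN throughout if no test_data was supplied.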