Trainer
Class to train a neural network using a PyTorch optimization algorithm.
The Trainer class is imported using the following command:
>>> from UQpy.scientific_machine_learning import Trainer
Methods
- class Trainer(model, optimizer, loss_function=MSELoss(), scheduler=None)
Prepare to train a neural network.
- Parameters:
  - model (Module) – Neural network model to be trained
  - optimizer (Optimizer) – Optimization algorithm used to update model parameters
  - loss_function (Module) – Function used to compute loss during training. Default: MSELoss()
  - scheduler (Union[LRScheduler, list, None]) – Scheduler used to adjust the learning rate of the optimizer. Schedulers may be chained together by passing a list of schedulers. Default: None
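Chaining can be illustrated with plain PyTorch objects. The snippet below is a sketch that assumes each scheduler in the list is stepped once per epoch, in list order, after the optimizer update; the model and learning-rate values are arbitrary. With UQpy installed, the same objects would be passed as Trainer(model, optimizer, scheduler=schedulers).

```python
import torch
import torch.nn as nn

model = nn.Linear(1, 1)  # placeholder model; any torch.nn.Module works
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

# Two schedulers bound to the same optimizer. A list like this is the
# chained form that the scheduler parameter accepts.
schedulers = [
    torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5),
    torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9),
]

# One illustrative epoch: a single optimizer update, then every scheduler
# in the list is stepped in order (assumed, not confirmed, to match how
# the Trainer consumes the list).
optimizer.zero_grad()
loss = model(torch.ones(1, 1)).pow(2).sum()
loss.backward()
optimizer.step()
for scheduler in schedulers:
    scheduler.step()

lr = optimizer.param_groups[0]["lr"]
print(lr)  # 0.45, i.e. 1.0 * 0.5 (StepLR) * 0.9 (ExponentialLR)
```

Because PyTorch schedulers update the learning rate multiplicatively from its current value, stepping both in sequence composes their factors.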
- run(train_data=None, test_data=None, epochs=100, tolerance=0.0)
Run the optimizer algorithm to learn the parameters of the model that fit train_data.
- Parameters:
  - train_data (Optional[DataLoader]) – Data used to compute model loss
  - test_data (Optional[DataLoader]) – Data used to validate the performance of the model
  - epochs (int) – Maximum number of epochs to run the optimizer for. Default: 100
  - tolerance (float) – Optimization terminates early if the average training loss is below tolerance. Default: 0.0
- Raises:
  RuntimeError – If neither train_data nor test_data is provided.
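The loop below is a minimal sketch of the behavior documented for run, written in plain PyTorch; it is not UQpy's implementation. The function name run_sketch is hypothetical, and it assumes that the per-epoch training loss is the mean of the batch losses, that test loss is computed without gradients, and that scheduler stepping can be left out for clarity.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def run_sketch(model, optimizer, loss_function, train_data=None, test_data=None,
               epochs=100, tolerance=0.0):
    """Hypothetical stand-in mirroring the documented behavior of Trainer.run."""
    if train_data is None and test_data is None:
        raise RuntimeError("At least one of train_data or test_data is required")
    # Histories start as NaN, so epochs skipped by early stopping stay NaN.
    history = {
        "train_loss": torch.full((epochs,), float("nan")),
        "test_loss": torch.full((epochs,), float("nan")),
    }
    for epoch in range(epochs):
        if train_data is not None:
            model.train()
            batch_losses = []
            for x, y in train_data:
                optimizer.zero_grad()
                loss = loss_function(model(x), y)
                loss.backward()
                optimizer.step()
                batch_losses.append(loss.item())
            history["train_loss"][epoch] = sum(batch_losses) / len(batch_losses)
        if test_data is not None:
            model.eval()
            with torch.no_grad():  # validation only, no parameter updates
                batch_losses = [loss_function(model(x), y).item() for x, y in test_data]
            history["test_loss"][epoch] = sum(batch_losses) / len(batch_losses)
        # Stop early once the average training loss falls below tolerance.
        if train_data is not None and history["train_loss"][epoch].item() < tolerance:
            break
    return history


# Noise-free linear data, so the loss can reach a small tolerance.
torch.manual_seed(0)
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 3.0 * x + 0.5
train_loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
test_loader = DataLoader(TensorDataset(x, y), batch_size=64)

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
history = run_sketch(model, optimizer, nn.MSELoss(), train_data=train_loader,
                     test_data=test_loader, epochs=500, tolerance=1e-4)

ran = ~torch.isnan(history["train_loss"])
print(f"stopped after {int(ran.sum())} of 500 epochs")
```

With UQpy installed, the documented equivalent call would be trainer.run(train_data=train_loader, test_data=test_loader, epochs=500, tolerance=1e-4).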
Attributes
- Trainer.history: dict
  Record of the loss during training and validation. If training ends early, some entries may be NaN, because the histories are initialized with NaN.
  history["train_loss"] contains the training history as a torch.Tensor.
  history["test_loss"] contains the testing history as a torch.Tensor.
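Because epochs skipped by early termination keep their NaN placeholders, a common post-processing step is to mask them out. The sketch below builds a hypothetical history dictionary by hand (the loss values are illustrative, not output from UQpy) and keeps only the completed epochs; the helper completed_epochs is likewise hypothetical.

```python
import torch

nan = float("nan")
# Hypothetical history for a run with epochs=6 that stopped after 3 epochs.
history = {
    "train_loss": torch.tensor([0.9, 0.2, 0.05, nan, nan, nan]),
    "test_loss": torch.tensor([1.1, 0.3, 0.08, nan, nan, nan]),
}


def completed_epochs(loss_history: torch.Tensor) -> torch.Tensor:
    """Drop the NaN placeholders left behind by early termination."""
    return loss_history[~torch.isnan(loss_history)]


train = completed_epochs(history["train_loss"])
test = completed_epochs(history["test_loss"])
best_epoch = int(torch.argmin(test))  # index of the lowest validation loss
print(len(train), best_epoch)  # 3 2
```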