Source code for UQpy.scientific_machine_learning.neural_networks.FeedForwardNeuralNetwork

import torch.nn as nn
import logging
from UQpy.scientific_machine_learning.baseclass import NeuralNetwork


[docs]class FeedForwardNeuralNetwork(NeuralNetwork): def __init__(self, network: nn.Module): """Initialize a typical feed-forward neural network using the architecture provided by ``network`` :param network: Network defining the function from :math:`f(x)=y` """ super().__init__() self.network: nn.Module = network """Neural network architecture defined as a :py:class:`torch.nn.Module`""" # set all layers to the same training, dropping, and sampling mode training = False dropping = False sampling = False for m in self.network.modules(): training = training or m.training if hasattr(m, "dropping"): dropping = dropping or m.dropping if hasattr(m, "sampling"): sampling = sampling or m.sampling self.train(training) self.drop(dropping) self.sample(sampling) self.logger = logging.getLogger(__name__)
[docs] def forward(self, *args, **kwargs): """Forward call of the neural network :param args: Input arguments pass to ``network`` :param kwargs: Keyword arguments passed to ``network`` :return: Output arguments """ return self.network(*args, **kwargs)