# Source code for UQpy.scientific_machine_learning.neural_networks.DeepOperatorNetwork

import torch
import torch.nn as nn
from UQpy.scientific_machine_learning.baseclass.NeuralNetwork import NeuralNetwork
from UQpy.utilities.ValidationTypes import PositiveInteger


class DeepOperatorNetwork(NeuralNetwork):
    def __init__(
        self,
        branch_network: nn.Module,
        trunk_network: nn.Module,
        out_channels: PositiveInteger = 1,
    ):
        r"""Construct a Deep Operator Network via its branch and trunk networks

        :param branch_network: Encodes mapping of the function :math:`f(x)`
        :param trunk_network: Encodes mapping of the domain :math:`y` for :math:`\mathcal{L}f(y)`
        :param out_channels: Number of channels produced by the Deep Operator Network
        :raises ValueError: If :code:`out_channels` is less than 1.

        .. note::
            The last layer of the branch and trunk network must have the same number of neurons
            so the last dimension of their outputs match,
            i.e. both outputs have shape :math:`(*, \text{width})`.
            Additionally, :math:`\text{width}` must be divisible by :math:`C_\text{out}`.

        Shape:

        - Input: Two tensors representing :math:`x` and :math:`f(x)`
        - Branch Network (:math:`f(x)`): Any shape (can be different from trunk)
        - Trunk Network (:math:`x`): Any shape (can be different from branch)
        - Intermediary: The output from the branch and trunk network must be of shapes
          :math:`(*, \text{width})` and :math:`(*, \text{width})`,
          where :math:`*` refers to any broadcastable dimensions.
          Both tensors are viewed as :math:`(*, C_\text{out}, \frac{\text{width}}{C_\text{out}})`
          before the dot product is computed over the last dimension.
        - Output: Tensor of shape :math:`(*, C_\text{out})`
        """
        super().__init__()
        # Fail at construction time: a non-positive channel count would otherwise only
        # surface inside ``forward`` as a ZeroDivisionError or an invalid view size.
        if out_channels < 1:
            raise ValueError(
                f"UQpy: out_channels must be a positive integer, got {out_channels}"
            )
        self.branch_network: nn.Module = branch_network
        """Architecture of the branch neural network defined by a :py:class:`torch.nn.Module`"""
        self.trunk_network: nn.Module = trunk_network
        """Architecture of the trunk neural network defined by a :py:class:`torch.nn.Module`"""
        self.out_channels = out_channels
        """Number of channels produced by the Deep Operator Network"""

    def forward(
        self,
        x: torch.Tensor,
        f_x: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the dot product of branch and trunk outputs

        The last dimension of each network output (of size ``width``) is split into
        ``(out_channels, width // out_channels)`` and the two results are contracted over
        the final axis, leaving one value per output channel.

        :param x: Input to the :code:`trunk_network`
        :param f_x: Input to the :code:`branch_network`
        :return: Dot product of the branch and trunk outputs, with last dimension of size
         :code:`out_channels`
        :raises RuntimeError: If incompatible trunk and branch outputs are encountered,
         i.e. their last dimensions differ or are not divisible by :code:`out_channels`.
        """
        branch_output = self.branch_network(f_x)
        trunk_output = self.trunk_network(x)
        width = branch_output.size(-1)
        if width != trunk_output.size(-1):
            raise RuntimeError(
                f"UQpy: Incompatible trunk {trunk_output.shape} and branch {branch_output.shape} output shapes."
                f"\nTrunk and branch output must have shape (*, width). "
            )
        if width % self.out_channels != 0:
            raise RuntimeError(
                f"UQpy: Branch and trunk width {width} must be divisible by out_channels {self.out_channels}"
            )
        # Both outputs share the same trailing split; leading dimensions are left
        # untouched so the ellipsis in einsum can broadcast them.
        channel_shape = (self.out_channels, width // self.out_channels)
        return torch.einsum(
            "...i,...i",  # repeated index with implicit output: summed over ``i``
            branch_output.view(*branch_output.shape[:-1], *channel_shape),
            trunk_output.view(*trunk_output.shape[:-1], *channel_shape),
        )

    def extra_repr(self) -> str:
        """Return the non-default settings shown in the module's string representation

        :return: :code:`out_channels=<value>` when it differs from the default of 1,
         otherwise an empty string
        """
        if self.out_channels != 1:
            return f"out_channels={self.out_channels}"
        return ""