Source code for UQpy.scientific_machine_learning.neural_networks.Unet

import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from UQpy.scientific_machine_learning.baseclass.NeuralNetwork import NeuralNetwork
from UQpy.utilities.ValidationTypes import PositiveInteger


class Unet(NeuralNetwork):
    def __init__(
        self,
        n_filters: list[PositiveInteger],
        kernel_size: PositiveInteger,
        out_channels: PositiveInteger,
        layer_type: nn.Module = nn.Conv2d,
    ):
        r"""Construct U-net convolutional neural network for mean-field prediction

        :param n_filters: A list of positive integers. ``n_filters[0]`` is the number of input channels
         :math:`C_\text{in}`, and ``n_filters[i]`` (for :math:`i \geq 1`) is the number of filters used by both
         convolutional layers of encoder block ``i``. The length of the list determines the depth of the U-Net
         and must be at least 2.
        :param kernel_size: The size of the convolutional kernels. This value is used for all convolutional layers.
         Use an odd kernel size (e.g. 3, 5, 7, or 9): with padding ``kernel_size // 2``, only odd kernels preserve
         the spatial dimensions, which the skip connections and the output shape rely on.
        :param out_channels: The number of output channels in the final convolutional layer.
        :param layer_type: The type of convolutional layer to use. The default is ``nn.Conv2d``.
         It can be replaced with Bayesian layers for performing uncertainty quantification.

        :raises ValueError: If ``n_filters`` has fewer than two entries.

        .. note::
            A default value ``stride=2`` is used for the max pooling layers.
            A default padding value of ``kernel_size // 2`` is used for all convolutional layers.
            The input height and width must be divisible by :math:`2^{\text{len}(\texttt{n_filters}) - 2}` so that
            the upsampled decoder features match the spatial size of the encoder skip connections.

        Shape:

        - Input: Tensor of shape :math:`(N, C_\text{in}, H, W)` with :math:`C_\text{in} = \texttt{n_filters[0]}`
        - Output: Tensor of shape :math:`(N, C_\text{out}, H, W)`

        Attributes:
            Encoder Layers:
                The encoding blocks are created during initialization from the ``n_filters`` list for indices
                :math:`i=1, \dots, \text{len}(\texttt{n_filters}) - 1`.

                - **encoder_maxpool_i** (:py:class:`torch.nn.MaxPool2d`): Max pooling layer for downsampling at
                  encoder layer ``i`` (only for ``i > 1``).
                - **encoder_conv_1_i** (``layer_type``, default :py:class:`torch.nn.Conv2d`): First convolutional
                  layer at encoder layer ``i``.
                - **encoder_bn_1_i** (:py:class:`torch.nn.BatchNorm2d`): Batch normalization layer after
                  ``encoder_conv_1_i``.
                - **encoder_conv_2_i** (``layer_type``, default :py:class:`torch.nn.Conv2d`): Second convolutional
                  layer at encoder layer ``i``.
                - **encoder_bn_2_i** (:py:class:`torch.nn.BatchNorm2d`): Batch normalization layer after
                  ``encoder_conv_2_i``.

            Decoder Layers:
                The decoding blocks are created for indices
                :math:`i=\text{len}(\texttt{n_filters}) - 1, \dots, 2`.

                - **decoder_upsample_i** (:py:class:`torch.nn.Upsample`): Upsampling layer at decoder layer ``i``.
                - **decoder_conv_1_i** (``layer_type``, default :py:class:`torch.nn.Conv2d`): First convolutional
                  layer at decoder layer ``i``.
                - **decoder_bn_1_i** (:py:class:`torch.nn.BatchNorm2d`): Batch normalization layer after
                  ``decoder_conv_1_i``.
                - **decoder_conv_2_i** (``layer_type``, default :py:class:`torch.nn.Conv2d`): Second convolutional
                  layer at decoder layer ``i``.
                - **decoder_bn_2_i** (:py:class:`torch.nn.BatchNorm2d`): Batch normalization layer after
                  ``decoder_conv_2_i``.

            Final Convolution Layer:
                - **final_conv** (:py:class:`torch.nn.Conv2d`): :math:`1 \times 1` convolutional layer applied after
                  the last decoder block. It maps the output to the desired number of channels.
        """
        super().__init__()
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.layer_type = layer_type
        self.logger = logging.getLogger(__name__)

        # final_conv reads n_filters[1]; fail with a clear message instead of an IndexError.
        if len(self.n_filters) < 2:
            raise ValueError(
                "n_filters must contain at least two entries "
                "(input channels followed by at least one filter count)"
            )
        # Conv output size is H + 2 * (k // 2) - k + 1, which equals H only for odd k.
        if self.kernel_size % 2 == 0:
            self.logger.warning(
                "Unet: even kernel_size=%d with padding kernel_size // 2 does not preserve "
                "spatial dimensions; skip connections may fail to concatenate.",
                self.kernel_size,
            )

        # Encoder: block i maps n_filters[i - 1] -> n_filters[i] channels.
        for i in range(1, len(self.n_filters)):
            if i != 1:
                setattr(
                    self, f"encoder_maxpool_{i}", nn.MaxPool2d(kernel_size=2, stride=2)
                )
            self._register_double_conv(
                "encoder", i, self.n_filters[i - 1], self.n_filters[i]
            )

        # Decoder: block i upsamples level-i features, concatenates the level-(i - 1)
        # skip connection, and maps back to n_filters[i - 1] channels.
        for i in range(len(self.n_filters) - 1, 1, -1):
            setattr(
                self,
                f"decoder_upsample_{i}",
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            )
            self._register_double_conv(
                "decoder",
                i,
                self.n_filters[i] + self.n_filters[i - 1],
                self.n_filters[i - 1],
            )

        self.final_conv = nn.Conv2d(
            self.n_filters[1], self.out_channels, kernel_size=1, padding=0
        )

    def _register_double_conv(
        self, prefix: str, i: int, block_in_channels: int, block_out_channels: int
    ):
        """Register ``{prefix}_conv_1_{i}``, ``{prefix}_bn_1_{i}``, ``{prefix}_conv_2_{i}`` and ``{prefix}_bn_2_{i}``.

        Modules are created in the same order as the layers are applied, so parameter
        initialization order and ``state_dict`` keys are stable.

        :param prefix: Either ``"encoder"`` or ``"decoder"``.
        :param i: Index of the block.
        :param block_in_channels: Number of channels entering the first convolution.
        :param block_out_channels: Number of channels produced by both convolutions.
        """
        padding = self.kernel_size // 2
        setattr(
            self,
            f"{prefix}_conv_1_{i}",
            self.layer_type(
                block_in_channels,
                block_out_channels,
                kernel_size=self.kernel_size,
                padding=padding,
            ),
        )
        setattr(self, f"{prefix}_bn_1_{i}", nn.BatchNorm2d(block_out_channels))
        setattr(
            self,
            f"{prefix}_conv_2_{i}",
            self.layer_type(
                block_out_channels,
                block_out_channels,
                kernel_size=self.kernel_size,
                padding=padding,
            ),
        )
        setattr(self, f"{prefix}_bn_2_{i}", nn.BatchNorm2d(block_out_channels))

    def _apply_double_conv(self, x: torch.Tensor, prefix: str, i: int) -> torch.Tensor:
        """Apply (convolution -> batch normalization -> ReLU) twice using the layers of block ``i``.

        :param x: Input tensor
        :param prefix: Either ``"encoder"`` or ``"decoder"``.
        :param i: Index of the block.
        :return: Output tensor
        """
        for j in (1, 2):
            x = getattr(self, f"{prefix}_conv_{j}_{i}")(x)
            x = getattr(self, f"{prefix}_bn_{j}_{i}")(x)
            x = F.relu(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Forward pass through the U-Net model.

        The output is computed by passing the input through each encoding and decoding block
        together with the skip connections.

        :param x: Tensor of shape :math:`(N, C_\text{in}, H, W)`
        :return: Tensor of shape :math:`(N, C_\text{out}, H, W)`
        """
        # encoder_outputs[k] holds the output of encoder block k + 1.
        encoder_outputs = []
        for i in range(1, len(self.n_filters)):
            x = self.optional_step_en(x, i)
            if i != 1:
                x = getattr(self, f"encoder_maxpool_{i}")(x)
            x = self._apply_double_conv(x, "encoder", i)
            encoder_outputs.append(x)

        for i in range(len(self.n_filters) - 1, 1, -1):
            x = getattr(self, f"decoder_upsample_{i}")(x)
            # Skip connection from encoder block i - 1 (stored at index i - 2).
            skip_input = encoder_outputs[i - 2]
            x = torch.cat([x, skip_input], dim=1)
            # Optional hook, e.g. Monte Carlo dropout in subclasses.
            x = self.optional_step_dec(x, i)
            x = self._apply_double_conv(x, "decoder", i)

        return self.final_conv(x)

    def optional_step_en(self, x: torch.Tensor, i: int):
        """Optional method for additional operations during encoding.

        Called at the start of encoder block ``i``, before max pooling. The default
        implementation returns ``x`` unchanged. It is intended to be overridden by
        subclasses to apply operations such as Monte Carlo dropout based on the index ``i``.

        :param x: Input tensor
        :param i: Index of the encoding block
        :return: Output tensor
        """
        return x

    def optional_step_dec(self, x: torch.Tensor, i: int):
        """Optional method for additional operations during decoding.

        Called in decoder block ``i`` after the skip connection has been concatenated and
        before the convolutions. The default implementation returns ``x`` unchanged. It is
        intended to be overridden by subclasses to apply operations such as Monte Carlo
        dropout based on the index ``i``.

        :param x: Input tensor
        :param i: Index of the decoding block
        :return: Output tensor
        """
        return x