U-net Convolutional Neural Network
The Unet class implements the U-Net neural network originally introduced by Ronneberger et al. [44].
The network was originally designed for image segmentation tasks, but it can also be generalized to perform image-to-image regression.
The architecture comprises a series of encoding blocks and decoding blocks, as shown below.
The architecture of the U-Net neural network.
Each encoding block consists of a repeated set of a convolutional layer of a given kernel size, a batch normalization layer, and a nonlinear activation function, here the rectified linear unit (ReLU), together with a downsampling layer. The downsampling layer uses the maximum pooling (max-pooling) operation. The decoding blocks have the same structure as their encoding counterparts, except that the downsampling layers are replaced by upsampling layers. The last convolutional layer combines the features of the last multi-channel output into a single prediction. The network also includes a number of skip connections between the contracting and expanding paths, aimed at combining high-resolution features from the encoding path with more abstract feature representations.
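The block structure described above can be sketched in plain PyTorch. This is a minimal illustration and not the UQpy implementation: the class name EncodingBlockSketch, the placement of the pooling layer before the convolutions, the pooling kernel size of 2, and the use of two convolution stages per block are assumptions made for the example.

```python
import torch
import torch.nn as nn


class EncodingBlockSketch(nn.Module):
    """Hypothetical encoding block: optional max-pool, then two conv -> batch norm -> ReLU stages."""

    def __init__(self, in_channels, out_channels, kernel_size=3, downsample=True):
        super().__init__()
        padding = kernel_size // 2  # keeps H and W unchanged for odd kernel sizes
        # Assumed pooling window of 2 with stride 2, which halves H and W.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2) if downsample else nn.Identity()
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.bn_1 = nn.BatchNorm2d(out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding)
        self.bn_2 = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.pool(x)
        x = self.activation(self.bn_1(self.conv_1(x)))
        x = self.activation(self.bn_2(self.conv_2(x)))
        return x


block = EncodingBlockSketch(in_channels=1, out_channels=8)
x = torch.rand(2, 1, 32, 32)
y = block(x)
print(y.shape)  # channels go from 1 to 8, spatial size is halved by the pooling layer
```

A decoding block would follow the same pattern with the pooling layer swapped for an upsampling layer such as torch.nn.Upsample.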
The Unet class is imported using the following command:
>>> from UQpy.scientific_machine_learning import Unet
Methods
- class Unet(n_filters, kernel_size, out_channels, layer_type=<class 'torch.nn.modules.conv.Conv2d'>)
Construct a U-Net convolutional neural network for mean-field prediction.
- Parameters:
n_filters (list[int]) – A list of positive integers specifying the number of filters for each convolutional layer in the encoding and decoding paths. The length of the list determines the depth of the U-Net.
kernel_size (int) – The size of the convolutional kernels. This value is used for all convolutional layers. Standard kernel size options are 3, 6, or 9.
out_channels (int) – The number of output channels in the final convolutional layer.
layer_type (Module) – The type of convolutional layer to use. The default is torch.nn.Conv2d. It can be replaced with Bayesian layers to perform uncertainty quantification.
Note
A default value of stride=2 is used for the max pooling layers. A default padding value of kernel_size // 2 is used for all convolutional layers.
Shape:
Input: Tensor of shape \((N, C_\text{in}, H, W)\)
Output: Tensor of shape \((N, C_\text{out}, H, W)\)
Attributes:
Encoder Layers: The encoding blocks are created during initialization from the n_filters list for indices \(i=1, \dots, \text{len}(\texttt{n_filters}) - 1\).
encoder_maxpool_i (torch.nn.MaxPool2d): Max pooling layer for downsampling at encoder layer i (for i > 1).
encoder_conv_1_i (torch.nn.Conv2d): First convolutional layer at encoder layer i.
encoder_bn_1_i (torch.nn.BatchNorm2d): Batch normalization layer after encoder_conv_1_i.
encoder_conv_2_i (torch.nn.Conv2d): Second convolutional layer at encoder layer i.
encoder_bn_2_i (torch.nn.BatchNorm2d): Batch normalization layer after encoder_conv_2_i.
Decoder Layers:
decoder_upsample_i (torch.nn.Upsample): Upsampling layer at decoder layer i.
decoder_conv_1_i (torch.nn.Conv2d): First convolutional layer at decoder layer i.
decoder_bn_1_i (torch.nn.BatchNorm2d): Batch normalization layer after decoder_conv_1_i.
decoder_conv_2_i (torch.nn.Conv2d): Second convolutional layer at decoder layer i.
decoder_bn_2_i (torch.nn.BatchNorm2d): Batch normalization layer after decoder_conv_2_i.
Final Convolution Layer:
final_conv (torch.nn.Conv2d): Convolutional layer applied after the last decoder block. It maps the output to the desired number of channels.
- forward(x)
Forward pass through the U-Net model.
The output is computed by passing the input through each encoding and decoding block together with the skip connections.
- Parameters:
x (Tensor) – Tensor of shape \((N, C_\text{in}, H, W)\)
- Return type:
Tensor
- Returns:
Tensor of shape \((N, C_\text{out}, H, W)\)
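To illustrate how a forward pass can combine encoding, decoding, and a skip connection while preserving the spatial shape, here is a minimal one-level sketch in plain PyTorch. It is not the Unet implementation: the class TinyUNetSketch, the channel counts, the concatenation-based skip connection (as in the original U-Net paper), and the 1x1 final convolution are all assumptions made for the example.

```python
import torch
import torch.nn as nn


class TinyUNetSketch(nn.Module):
    """Hypothetical one-level U-Net used only to illustrate a skip connection."""

    def __init__(self, in_channels=1, features=4, out_channels=3):
        super().__init__()
        self.encode = nn.Conv2d(in_channels, features, 3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bottleneck = nn.Conv2d(features, 2 * features, 3, padding=1)
        self.upsample = nn.Upsample(scale_factor=2)
        # The decoder sees the upsampled features plus the skipped encoder features.
        self.decode = nn.Conv2d(2 * features + features, features, 3, padding=1)
        self.final_conv = nn.Conv2d(features, out_channels, 1)

    def forward(self, x):
        skip = torch.relu(self.encode(x))                  # (N, features, H, W)
        x = torch.relu(self.bottleneck(self.pool(skip)))   # (N, 2*features, H/2, W/2)
        x = self.upsample(x)                               # (N, 2*features, H, W)
        x = torch.cat([x, skip], dim=1)                    # skip connection by concatenation
        x = torch.relu(self.decode(x))                     # (N, features, H, W)
        return self.final_conv(x)                          # (N, out_channels, H, W)


net = TinyUNetSketch()
x = torch.rand(2, 1, 16, 16)
y = net(x)
print(x.shape, y.shape)
```

Because the pooling halves the spatial size and the upsampling doubles it again, H and W should be divisible by 2 for each level of depth in a sketch like this one.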
- optional_step_en(x, i)
Optional method for additional operations during encoding.
Intended to be overridden by subclasses to apply operations like Monte Carlo Dropout based on the layer index i.
- Parameters:
x (Tensor) – Input tensor
i (int) – Index of the encoding block
- Returns:
Output tensor
- optional_step_dec(x, i)
Optional method for additional operations during decoding.
Intended to be overridden by subclasses to apply operations like Monte Carlo Dropout based on the layer index i.
- Parameters:
x (Tensor) – Input tensor
i (int) – Index of the decoding block
- Returns:
Output tensor
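The override pattern for these two hooks can be sketched as follows. To keep the example self-contained, it uses a hypothetical stand-in base class, HookedNetSketch, that calls optional_step_en and optional_step_dec after each block. How and where Unet itself invokes the hooks is an assumption here, as are the layer sizes and the dropout rate.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HookedNetSketch(nn.Module):
    """Hypothetical stand-in for a network with optional_step_en / optional_step_dec hooks."""

    def __init__(self):
        super().__init__()
        self.encoders = nn.ModuleList(
            [nn.Conv2d(1, 4, 3, padding=1), nn.Conv2d(4, 4, 3, padding=1)]
        )
        self.decoders = nn.ModuleList(
            [nn.Conv2d(4, 4, 3, padding=1), nn.Conv2d(4, 1, 3, padding=1)]
        )

    def optional_step_en(self, x, i):
        return x  # default hook: identity

    def optional_step_dec(self, x, i):
        return x  # default hook: identity

    def forward(self, x):
        for i, conv in enumerate(self.encoders, start=1):
            x = self.optional_step_en(torch.relu(conv(x)), i)
        for i, conv in enumerate(self.decoders, start=1):
            x = self.optional_step_dec(conv(x), i)
        return x


class MCDropoutSketch(HookedNetSketch):
    """Overrides both hooks to apply dropout that stays active at inference time."""

    def __init__(self, drop_rate=0.5):
        super().__init__()
        self.drop_rate = drop_rate

    def optional_step_en(self, x, i):
        # Skip the first encoding block; apply dropout after the deeper ones.
        if i > 1:
            return F.dropout(x, p=self.drop_rate, training=True)
        return x

    def optional_step_dec(self, x, i):
        # training=True keeps dropout stochastic even when the model is in eval mode.
        return F.dropout(x, p=self.drop_rate, training=True)


torch.manual_seed(0)
x = torch.rand(1, 1, 16, 16)
base = HookedNetSketch()
mc = MCDropoutSketch()
with torch.no_grad():
    same = torch.equal(base(x), base(x))       # identity hooks: repeated passes agree
    different = not torch.equal(mc(x), mc(x))  # dropout hooks: repeated passes vary
    out_shape = mc(x).shape
print(same, different, out_shape)
```

Repeating the stochastic forward pass many times and aggregating the predictions is the usual way such a subclass is used to estimate predictive uncertainty.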
- summary(**kwargs)
Call torchinfo.summary() on self. See the torchinfo documentation for details.
- Parameters:
kwargs – Keyword arguments passed to torchinfo.summary.
- Returns:
Model statistics
- count_parameters()
Get the total number of parameters in the model that require gradient computation.
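A count of this kind is commonly computed in PyTorch by summing the sizes of all parameters with requires_grad set. The sketch below shows that idiom on a single convolutional layer. The helper name count_trainable_parameters is hypothetical, and whether count_parameters is implemented exactly this way is an assumption.

```python
import torch.nn as nn


def count_trainable_parameters(model: nn.Module) -> int:
    """Sum the number of elements over all parameters that require a gradient."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


layer = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3)
n_all = count_trainable_parameters(layer)          # 64*1*3*3 weights + 64 biases = 640

layer.bias.requires_grad_(False)                   # freeze the bias
n_frozen_bias = count_trainable_parameters(layer)  # only the 576 weights remain
print(n_all, n_frozen_bias)
```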
Note on Convolutional Layer Parameters:
Each convolutional layer accepts an input volume with width \(W_1\), height \(H_1\), and depth \(D_1\).
- Input Dimensions: \(W_1 \times H_1 \times D_1\)
Number of Filters (K): Determines the number of filters (or kernels) used, affecting the depth \(D_2\) of the output volume.
Spatial Extent (F): The size of each filter, typically a square (e.g., 3x3).
Stride (S): The step size with which the filters are moved across the input volume.
Zero Padding (P): Number of pixels added to the border of the input volume, enabling control over the spatial dimensions of the output volume.
- Output Dimensions: \(W_2 \times H_2 \times D_2\)
Output Width: \(W_2 = \left( \frac{W_1 - F + 2P}{S} \right) + 1\)
Output Height: \(H_2 = \left( \frac{H_1 - F + 2P}{S} \right) + 1\)
Output Depth: \(D_2 = K\), corresponding to the number of filters.
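The output-size formula above can be checked with a few lines of Python. The helper conv_output_size is a hypothetical name introduced for this example, and the floor division is an assumption that matches how PyTorch rounds when the division is not exact. The pooling window of 2 in the second call is also an assumption; the note above only fixes the stride at 2.

```python
def conv_output_size(size, kernel, stride=1, padding=0):
    """Output width or height: (W1 - F + 2P) / S + 1, using floor division."""
    return (size - kernel + 2 * padding) // stride + 1


# Convolution with F=3, S=1 and P = kernel_size // 2 = 1 preserves the spatial size.
same = conv_output_size(512, kernel=3, stride=1, padding=3 // 2)

# The same holds for F=9 with P = 9 // 2 = 4.
same_9 = conv_output_size(512, kernel=9, stride=1, padding=9 // 2)

# A pooling window with F=2, S=2 and no padding halves the spatial size.
pooled = conv_output_size(512, kernel=2, stride=2, padding=0)

print(same, same_9, pooled)  # 512 512 256
```

This is why, with the default padding of kernel_size // 2 and an odd kernel size, only the pooling and upsampling layers change H and W.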
Examples
import torch
import UQpy.scientific_machine_learning as sml

# One input channel, then 64 and 128 filters in the encoding and decoding paths
n_filters = [1, 64, 128]
kernel_size = 3
out_channels = 3
unet = sml.Unet(n_filters, kernel_size, out_channels)

# A single one-channel 512 x 512 image
x = torch.rand(1, 1, 512, 512)
y = unet(x)

print(f"Input shape: {x.shape}")  # (N, in_channels, H, W)
print(f"Prediction shape: {y.shape}")  # (N, out_channels, H, W)