Source code for UQpy.scientific_machine_learning.layers.RangeNormalizer

import torch
from UQpy.scientific_machine_learning.baseclass import Layer
from UQpy.utilities.ValidationTypes import NonNegativeInteger
from beartype import beartype
from beartype.vale import Is
from typing import Union, Annotated


@beartype
class RangeNormalizer(Layer):

    #: Multiplicative factor to rescale range of x to interval width
    scale: torch.Tensor
    #: Additive factor to make interval start at ``self.low``
    shift: torch.Tensor

    def __init__(
        self,
        x: torch.Tensor,
        encoding: bool = True,
        low: Union[int, float] = 0.0,
        high: Union[int, float] = 1.0,
        dim: Union[
            NonNegativeInteger,
            Annotated[tuple, Is[lambda x: all([isinstance(d, int) for d in x])]],
            None,
        ] = None,
    ):
        r"""Normalize a tensor to fall within the range :math:`[\text{low}, \text{high}]`

        .. note::
            Due to machine precision, normalized values may be outside of range by errors on the order of
            :math:`10^{-8}`.

        :param x: Tensor of any shape
        :param encoding: If ``True``, scale and shift a tensor to be within :math:`[\text{low}, \text{high}]`.
         If ``False``, scale and shift from :math:`[\text{low}, \text{high}]` to the original range of ``x``.
         Default: ``True``
        :param low: Lower bound of the normalized range
        :param high: Upper bound of the normalized range
        :param dim: Dimensions to be reduced in :math:`\min(x), \max(x)`.
         If :code:`None`, reduce all dimensions for scalar min and max. Default: None

        :raises ValueError: If ``low`` is greater than or equal to ``high``
        :raises RuntimeError: If :math:`\min(x)` or :math:`\max(x)` is ``nan`` or ``inf`` over any dimension
         to be reduced.
        :raises RuntimeError: If :math:`\min(x)` equals :math:`\max(x)` over any dimension to be reduced.
         This is to prevent a :code:`ZeroDivisionError` from occurring in the computation of :code:`scale`

        Shape:

        - Input: Any shape
        - Output: Any shape (same shape as input)

        Example:
            >>> # use one instance and change mode
            >>> torch.manual_seed(0)  # for reproducibility
            >>> x = torch.normal(0, 1, (100, 100))
            >>> normalizer = sml.RangeNormalizer(x)
            >>> y = normalizer(x)
            >>> normalizer.decode()  # equivalent to normalizer.encode(False)
            >>> x_reconstruction = normalizer(y)
            >>> print(x.min(), y.min(), x_reconstruction.min())
            >>> print(x.max(), y.max(), x_reconstruction.max())
            tensor(-4.3433) tensor(0.) tensor(-4.3433)
            tensor(4.1015) tensor(1.) tensor(4.1015)

            >>> # use two instances with different modes
            >>> torch.manual_seed(0)  # for reproducibility
            >>> x = torch.normal(0, 1, (100, 100))
            >>> encoder = sml.RangeNormalizer(x)
            >>> decoder = sml.RangeNormalizer(x, encoding=False)
            >>> y = encoder(x)
            >>> x_reconstruction = decoder(y)
            >>> print(x.min(), y.min(), x_reconstruction.min())
            >>> print(x.max(), y.max(), x_reconstruction.max())
            tensor(-4.3433) tensor(0.) tensor(-4.3433)
            tensor(4.1015) tensor(1.) tensor(4.1015)
        """
        super().__init__()
        if low >= high:
            raise ValueError(
                f"UQpy: Invalid combination low={low} and high={high}. "
                f"`high` must be strictly greater than `low`"
            )
        self.x = x
        self.low = low
        self.high = high
        self.encoding = encoding
        self.dim = dim

        # handle cases if dim is None, an integer, or a tuple
        if self.dim is None:
            x_min = torch.min(x)
            x_max = torch.max(x)
        elif isinstance(self.dim, int):
            x_min, _ = torch.min(x, dim=self.dim, keepdim=True)
            x_max, _ = torch.max(x, dim=self.dim, keepdim=True)
        elif isinstance(self.dim, tuple):
            # Reduce one dimension at a time. keepdim=True keeps the remaining
            # dimension indices valid and lets scale/shift broadcast against x.
            x_min = x
            x_max = x
            for d in self.dim:
                x_min, _ = torch.min(x_min, dim=d, keepdim=True)
                x_max, _ = torch.max(x_max, dim=d, keepdim=True)
        else:  # else statement is redundant due to beartype hinting, but better to be safe than sorry
            raise ValueError(
                f"UQpy: Invalid dim={dim}. Must be one of None, int, or tuple of ints."
            )

        # Check finiteness first so that a slice made entirely of `inf` is reported
        # as an invalid value rather than as a degenerate (min == max) range.
        if not torch.all(torch.isfinite(x_min)):
            raise RuntimeError(
                "UQpy: Invalid value for min(x) in dimension to be reduced. The min cannot be `nan` or `inf`."
            )
        if not torch.all(torch.isfinite(x_max)):
            raise RuntimeError(
                "UQpy: Invalid value for max(x) in dimension to be reduced. The max cannot be `nan` or `inf`."
            )
        if torch.any(x_min == x_max):
            # if x_min equals x_max, a divide by zero error will occur when computing scale
            raise RuntimeError(
                "UQpy: RangeNormalizer is not defined if min(x) is equal to max(x) over any dimension to be reduced."
            )

        # Register as non-persistent buffers so scale/shift follow the module through
        # ``.to()``, ``.cuda()``, ``.double()``, etc., while staying out of ``state_dict``
        # (as plain attributes they were never saved either).
        self.register_buffer(
            "scale", (self.high - self.low) / (x_max - x_min), persistent=False
        )
        self.register_buffer("shift", self.low - (self.scale * x_min), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Scale and shift ``x`` to fall within a new range.

        If :code:`self.encoding` is :code:`True`, return :math:`(x \times \text{scale}) + \text{shift}`.
        If :code:`self.encoding` is :code:`False`, return :math:`\frac{x - \text{shift}}{\text{scale}}`

        :param x: Tensor of any shape
        :return: Tensor of same shape as ``x``
        """
        if self.encoding:
            return (x * self.scale) + self.shift
        else:
            return (x - self.shift) / self.scale

    def encode(self, mode: bool = True):
        r"""Set the normalizer to scale and shift a tensor to fall within range :math:`[\text{low}, \text{high}]`

        :param mode: Value assigned to ``self.encoding``. ``True`` enables encoding;
         ``False`` switches the normalizer to decoding. Default: ``True``
        """
        self.encoding = mode

    def decode(self, mode: bool = True):
        """Set the normalizer to restore a tensor to its original range

        :param mode: If ``True``, set ``self.encoding`` to ``False`` (decoding).
         If ``False``, set ``self.encoding`` to ``True`` (encoding). Default: ``True``
        """
        self.encoding = not mode

    def extra_repr(self) -> str:
        # Only show low/high/dim when they differ from their constructor defaults.
        s = "encoding={encoding}"
        if self.low != 0.0:
            s += ", low={low}"
        if self.high != 1.0:
            s += ", high={high}"
        if self.dim is not None:
            s += ", dim={dim}"
        return s.format(**self.__dict__)