Source code for UQpy.sampling.stratified_sampling.refinement.baseclass.Refinement

from abc import ABC, abstractmethod
import numpy as np

from UQpy.utilities.ValidationTypes import RandomStateType


[docs]class Refinement(ABC): """ Baseclass of all available strata refinement methods. Provides the methods that each existing and new refinement algorithm must implement in order to be used in the :class:`.RefinedStratifiedSampling` class. """
[docs] @abstractmethod def update_samples( self, nsamples: int, samples_per_iteration: int, random_state: RandomStateType, index: int, dimension: int, samples_u01: np.ndarray, training_points: np.ndarray, ): """ Method that need to be overridden in case of new :class:`.Refinement` techniques. :param nsamples: Total number of samples to be drawn :param samples_per_iteration: New samples to be drawn at each refinement iteration. :param random_state: Random seed used to initialize the pseudo-random number generator. Default is :class:`None`. If an :any:`int` is provided, this sets the seed for an object of :class:`numpy.random.RandomState`. Otherwise, the object itself can be passed directly. :param index: Iteration index :param dimension: Number of dimension of the sampling problem. :param samples_u01: Existing samples drawn at the unit hypercube space. :param training_points: Training points required in case of advanced refinement techniques. """ pass
def initialize(self, samples_number, training_points, samples): pass def finalize(self, samples, samples_per_iteration): pass @staticmethod def identify_bins(strata_metrics, points_to_add, random_state): bins2break = np.array([]) points_left = points_to_add while np.where(strata_metrics == strata_metrics.max())[0].shape[0] < points_left: bin = np.where(strata_metrics == strata_metrics.max())[0] bins2break = np.hstack([bins2break, bin]) strata_metrics[bin] = 0 points_left -= bin.shape[0] bin_for_remaining_points = random_state.choice( np.where(strata_metrics == strata_metrics.max())[0], points_left, replace=False,) bins2break = np.hstack([bins2break, bin_for_remaining_points]) bins2break = list(map(int, bins2break)) return bins2break @abstractmethod def update_strata(self, samplesU01): pass