# Source code for UQpy.distributions.collection.MultivariateNormal

from typing import Union

import numpy as np
import scipy.stats as stats
from beartype import beartype

from UQpy.distributions.baseclass import DistributionND


class MultivariateNormal(DistributionND):
    """Multivariate normal distribution backed by :data:`scipy.stats.multivariate_normal`."""

    @beartype
    def __init__(
        self,
        mean: Union[None, np.ndarray, list],
        cov: Union[None, int, float, np.ndarray, list[list]] = 1.0,
    ):
        """
        :param mean: mean vector, :class:`numpy.ndarray` of shape :code:`(dimension, )`
        :param cov: covariance, :any:`float` or :class:`numpy.ndarray` of shape :code:`(dimension, )` or
         :code:`(dimension, dimension)`. Default is 1.

        A parameter set to :any:`None` is treated as free and can be estimated with :meth:`fit`.

        :raises ValueError: if both parameters are given and a non-scalar `cov` is neither 1-D nor 2-D, or any of
         its axes differs in length from `mean`.
        """
        # Shape checks only make sense when both parameters are known; scalar covariances skip them.
        if mean is not None and cov is not None and not isinstance(cov, (int, float)):
            cov_shape = np.shape(cov)
            if not (len(cov_shape) in (1, 2) and all(size == len(mean) for size in cov_shape)):
                raise ValueError("Input covariance must be a float or ndarray of appropriate dimensions.")
        super().__init__(mean=mean, cov=cov, ordered_parameters=["mean", "cov"])

    def cdf(self, x):
        """
        Evaluate the cumulative distribution function at `x`.

        :param x: point(s) at which to evaluate, in any form accepted by :meth:`scipy.stats.multivariate_normal.cdf`
        :return: CDF value(s) as an array of at least one dimension
        """
        cdf_val = stats.multivariate_normal.cdf(x=x, **self.parameters)
        return np.atleast_1d(cdf_val)

    def pdf(self, x):
        """
        Evaluate the probability density function at `x`.

        :param x: point(s) at which to evaluate, in any form accepted by :meth:`scipy.stats.multivariate_normal.pdf`
        :return: PDF value(s) as an array of at least one dimension
        """
        pdf_val = stats.multivariate_normal.pdf(x=x, **self.parameters)
        return np.atleast_1d(pdf_val)

    def log_pdf(self, x):
        """
        Evaluate the natural logarithm of the probability density function at `x`.

        :param x: point(s) at which to evaluate, in any form accepted by
         :meth:`scipy.stats.multivariate_normal.logpdf`
        :return: log-PDF value(s) as an array of at least one dimension
        """
        logpdf_val = stats.multivariate_normal.logpdf(x=x, **self.parameters)
        return np.atleast_1d(logpdf_val)

    def rvs(self, nsamples=1, random_state=None):
        """
        Draw random samples from the distribution.

        :param nsamples: number of samples; a positive Python or NumPy integer
        :param random_state: seed or random generator forwarded to :meth:`scipy.stats.multivariate_normal.rvs`
        :return: samples reshaped to :code:`(nsamples, -1)`, i.e. one sample per row
        :raises ValueError: if `nsamples` is not an integer >= 1
        """
        # Accept NumPy integers too (e.g. a count computed with NumPy arithmetic).
        if not (isinstance(nsamples, (int, np.integer)) and nsamples >= 1):
            raise ValueError("Input nsamples must be an integer > 0.")
        nsamples = int(nsamples)
        samples = stats.multivariate_normal.rvs(size=nsamples, random_state=random_state, **self.parameters)
        # scipy squeezes its output (to a scalar for a single 1-D draw); np.reshape restores one row per sample.
        return np.reshape(samples, (nsamples, -1))

    def fit(self, data):
        """
        Compute maximum-likelihood estimates of the parameters that are free (set to :any:`None`).

        Parameters that were fixed at construction are returned unchanged.

        :param data: iid samples, validated by :meth:`check_x_dimension`
        :return: :any:`dict` with keys :code:`"mean"` and :code:`"cov"`
        """
        # NOTE(review): assumes check_x_dimension returns an (nsamples, dimension) array — TODO confirm.
        data = self.check_x_dimension(data)
        mle_mu, mle_cov = self.parameters["mean"], self.parameters["cov"]
        if mle_mu is None:
            mle_mu = np.mean(data, axis=0)
        if mle_cov is None:
            # The Gaussian MLE of the covariance is normalised by n (not n - 1, as np.cov does). It is centred on
            # the mean the model actually uses: the fixed mean if one was given, else the sample mean above.
            samples = np.asarray(data, dtype=float)
            deviations = samples - np.asarray(mle_mu, dtype=float)
            mle_cov = deviations.T @ deviations / samples.shape[0]
        return {"mean": mle_mu, "cov": mle_cov}

    def moments(self, moments2return="mv"):
        """
        Return the mean and/or covariance parameters of the distribution.

        :param moments2return: :code:`"m"` for the mean, :code:`"v"` for the covariance, or :code:`"mv"` for both
        :return: the requested parameter(s) as stored; for :code:`"mv"`, a tuple :code:`(mean, cov)`
        :raises ValueError: if `moments2return` is not one of :code:`"m"`, :code:`"v"`, :code:`"mv"`
        """
        if moments2return not in ("m", "v", "mv"):
            raise ValueError('UQpy: moments2return must be "m", "v" or "mv".')
        parameters = self.get_parameters()
        if moments2return == "m":
            return parameters["mean"]
        if moments2return == "v":
            return parameters["cov"]
        return parameters["mean"], parameters["cov"]