def quantile(self, level: Tensor) -> Tensor:
    r"""
    Evaluate the quantile function of the distribution at given levels.

    This base implementation only fixes the contract and always raises;
    distributions that support quantiles provide their own version.

    Parameters
    ----------
    level
        One-dimensional tensor holding the levels, each between 0 and 1,
        at which quantiles are requested.

    Returns
    -------
    quantiles
        Quantile values, one slice per entry of `level`.
        The return shape is

            (num_levels, ...DISTRIBUTION_SHAPE...),

        where DISTRIBUTION_SHAPE is the shape of the underlying distribution.

    Raises
    ------
    NotImplementedError
        Always, in this base implementation.
    """
    raise NotImplementedError()
def quantile(self, level: Tensor) -> Tensor:
    r"""
    Compute quantiles of the Laplace distribution for the given levels.

    With location `mu` and scale `b`, the quantile at level `p` is

        mu + b * log(2 * p)        if p <= 0.5,
        mu - b * log(2 - 2 * p)    if p >  0.5.

    Parameters
    ----------
    level
        1d tensor of level values between 0 and 1.

    Returns
    -------
    quantiles
        Quantile values; the level axis comes first, followed by the axes
        that `mu` and `b` broadcast to.
    """
    F = self.F
    # the level axis is independent of the distribution axes, so expand it
    # from (num_levels,) to (num_levels, 1, 1, ...)
    for _ in range(self.all_dim):
        level = level.expand_dims(axis=-1)

    above_median = F.broadcast_greater(level, level.zeros_like() + 0.5)
    # Levels above the median take the upper-tail formula, the rest take the
    # lower-tail one. Swapping the two branches keeps the sign right but
    # yields wrong magnitudes everywhere except at the median.
    u = F.where(
        above_median, -F.log(2.0 - 2.0 * level), F.log(2.0 * level)
    )

    return F.broadcast_add(self.mu, F.broadcast_mul(self.b, u))
def quantile(self, level: Tensor) -> Tensor:
    r"""
    Compute quantiles of the transformed distribution.

    The signs of all transforms are multiplied together. Where the product
    is positive the base distribution is queried at `level`, elsewhere at
    `1 - level`. The base quantiles are then passed through the forward map
    `f` of every transform, in order.

    Parameters
    ----------
    level
        Tensor of level values to compute quantiles for.

    Returns
    -------
    quantiles
        Base-distribution quantiles mapped through all transforms.
    """
    F = getF(level)

    # accumulate the overall direction of the transform chain
    total_sign = 1.0
    for transform in self.transforms:
        total_sign = total_sign * transform.sign

    # a plain Python number is lifted to a tensor shaped like `level`
    # NOTE(review): if a transform's sign is a tensor with the distribution's
    # batch shape, the products below rely on it broadcasting against
    # `level` -- confirm the shapes are compatible.
    is_tensor = isinstance(total_sign, (mx.nd.NDArray, mx.sym.Symbol))
    if not is_tensor:
        total_sign = total_sign + level.zeros_like()

    increasing = F.broadcast_greater(total_sign, total_sign.zeros_like())
    kept = F.broadcast_mul(increasing, level)
    mirrored = F.broadcast_mul(1.0 - increasing, 1.0 - level)
    base_level = kept + mirrored

    result = self.base_distribution.quantile(base_level)
    for transform in self.transforms:
        result = transform.f(result)
    return result
def erfinv(F, x: Tensor) -> Tensor:
    """
    Inverse error function of `x`.

    Delegates to `F.erfinv` when `MXNET_HAS_ERFINV` is set. Otherwise it
    evaluates one of two degree-8 polynomials in a variable derived from
    `-log((1 - x) * (1 + x))` and multiplies the result by `x`.

    NOTE(review): the coefficients look like a published single-precision
    approximation -- confirm the source before changing them.
    """
    if MXNET_HAS_ERFINV:
        return F.erfinv(x)

    zero = x.zeros_like()

    arg = -F.log(F.broadcast_mul((1.0 - x), (1.0 + x)))
    is_central = F.broadcast_lesser(arg, zero + 5.0)

    # each region evaluates its polynomial in a differently shifted variable
    arg = F.where(is_central, arg - 2.5, F.sqrt(arg) - 3.0)

    # (central-region, tail-region) coefficient pairs, highest order first
    coefficient_pairs = (
        (2.81022636e-08, -0.000200214257),
        (3.43273939e-07, 0.000100950558),
        (-3.5233877e-06, 0.00134934322),
        (-4.39150654e-06, -0.00367342844),
        (0.00021858087, 0.00573950773),
        (-0.00125372503, -0.0076224613),
        (-0.00417768164, 0.00943887047),
        (0.246640727, 1.00167406),
        (1.50140941, 2.83297682),
    )

    def pick(central, tail):
        # elementwise choice of the coefficient for the active region
        return F.where(is_central, central + zero, tail + zero)

    (lead_central, lead_tail), *remaining = coefficient_pairs
    poly = pick(lead_central, lead_tail)

    # Horner's scheme, selecting the coefficient per element at every step
    for central, tail in remaining:
        poly = pick(central, tail) + F.broadcast_mul(poly, arg)

    return F.broadcast_mul(poly, x)
pwl_sqf_hat.quantile_internal( mx.nd.array(quantile_levels).expand_dims(axis=0), axis=1 ).asnumpy() ) @@ -517,7 +517,7 @@ def test_piecewise_linear( # Compute quantiles with the original parameters # Since params is replicated across samples we take only the first entry quantiles = np.squeeze( - pwl_sqf.quantile( + pwl_sqf.quantile_internal( mx.nd.array(quantile_levels) .expand_dims(axis=0) .repeat(axis=0, repeats=num_samples), diff --git a/test/distribution/test_distribution_sampling.py b/test/distribution/test_distribution_sampling.py index 46c38cf93a..8925424885 100644 --- a/test/distribution/test_distribution_sampling.py +++ b/test/distribution/test_distribution_sampling.py @@ -82,6 +82,7 @@ DISTRIBUTIONS_WITH_CDF = [Gaussian, Uniform, Laplace, Binned] +DISTRIBUTIONS_WITH_QUANTILE_FUNCTION = [Gaussian, Uniform, Laplace] @pytest.mark.parametrize("distr_class, params", test_cases) @@ -111,6 +112,12 @@ def test_sampling(distr_class, params, serialize_fn) -> None: calc_cdf = distr.cdf(mx.nd.array(edges)).asnumpy() assert np.allclose(calc_cdf[1:, :], emp_cdf, atol=1e-2) + if distr_class in DISTRIBUTIONS_WITH_QUANTILE_FUNCTION: + levels = np.linspace(1.0e-3, 1.0 - 1.0e-3, 100) + emp_qfunc = np.percentile(np_samples, levels * 100, axis=0) + calc_qfunc = distr.quantile(mx.nd.array(levels)).asnumpy() + assert np.allclose(calc_qfunc, emp_qfunc, rtol=1e-2) + test_cases_multivariate = [ ( diff --git a/test/distribution/test_distribution_shapes.py b/test/distribution/test_distribution_shapes.py index fc496abe58..1e0c82910c 100644 --- a/test/distribution/test_distribution_shapes.py +++ b/test/distribution/test_distribution_shapes.py @@ -214,3 +214,14 @@ def test_distribution_shapes( x3 = distr.sample(num_samples=3) assert x3.shape == (3,) + distr.batch_shape + distr.event_shape + + def has_quantile(d): + return isinstance(d, (Uniform, Gaussian, Laplace)) + + if ( + has_quantile(distr) + or isinstance(distr, TransformedDistribution) + and has_quantile(distr.base_distribution) 
def exp_quantile(level: np.ndarray) -> np.ndarray:
    """
    Reference quantile function of the unit-rate exponential distribution.

    Returns ``-log(1 - level)``. It is computed with ``log1p`` so that very
    small levels do not lose precision when ``1 - level`` rounds to 1.
    """
    return -np.log1p(-level)
def test_erfinv() -> None:
    """Compare `util.erfinv` with SciPy's `erfinv` on a grid inside (-1, 1)."""
    try:
        from scipy.special import erfinv as scipy_erfinv
    except ImportError:
        # skip only when SciPy is missing; any other failure must surface
        # instead of being silently turned into a skipped test
        pytest.skip("scipy not installed skipping test for erfinv")

    # stay strictly inside (-1, 1): the grid ends 1e-4 short of each endpoint
    x = np.linspace(-1.0 + 1.0e-4, 1 - 1.0e-4, 11)
    y_mxnet = util.erfinv(mx.nd, mx.nd.array(x)).asnumpy()
    y_scipy = scipy_erfinv(x)
    assert np.allclose(y_mxnet, y_scipy, rtol=1e-3)