
Commit

include quantile function in distribution
vafl committed Sep 16, 2019
1 parent c2fbb47 commit adcffaf
Showing 14 changed files with 203 additions and 11 deletions.
4 changes: 4 additions & 0 deletions src/gluonts/distribution/box_cox_tranform.py
@@ -142,6 +142,10 @@ def args(self) -> List:
def event_dim(self) -> int:
return 0

@property
def sign(self) -> Tensor:
return 1.0

def f(self, z: Tensor) -> Tensor:
r"""
Forward transformation of observations `z`
30 changes: 30 additions & 0 deletions src/gluonts/distribution/distribution.py
@@ -170,6 +170,13 @@ def batch_dim(self) -> int:
"""
return len(self.batch_shape)

@property
def all_dim(self) -> int:
r"""
Number of overall dimensions.
"""
return self.batch_dim + self.event_dim

def sample(self, num_samples: Optional[int] = None) -> Tensor:
r"""
Draw samples from the distribution.
@@ -223,6 +230,29 @@ def cdf(self, x: Tensor) -> Tensor:
"""
raise NotImplementedError()

def quantile(self, level: Tensor) -> Tensor:
r"""
Calculates quantiles for the given levels.
Parameters
----------
level
Level values to use for computing the quantiles.
`level` should be a 1d tensor of level values between 0 and 1.
Returns
-------
quantiles
Quantile values corresponding to the levels passed.
The return shape is
(num_levels, ...DISTRIBUTION_SHAPE...),
where DISTRIBUTION_SHAPE is the shape of the underlying distribution.
"""
raise NotImplementedError()


def _expand_param(p: Tensor, num_samples: Optional[int] = None) -> Tensor:
"""
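Note: to make the shape contract of the new `Distribution.quantile` method concrete, here is a minimal usage sketch (not part of the commit; it assumes an MXNet environment and uses the Gaussian implementation added below). The level values form the leading axis of the result.

import mxnet as mx
from gluonts.distribution.gaussian import Gaussian

# Batch of 3 independent Gaussians: batch_shape == (3,), event_shape == ().
distr = Gaussian(
    mu=mx.nd.array([0.0, 1.0, 5.0]), sigma=mx.nd.array([1.0, 2.0, 0.5])
)

levels = mx.nd.array([0.1, 0.5, 0.9])  # 1d tensor of quantile levels
q = distr.quantile(levels)

# Levels become the leading axis: shape == (num_levels,) + batch_shape.
assert q.shape == (3, 3)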
20 changes: 17 additions & 3 deletions src/gluonts/distribution/gaussian.py
@@ -17,7 +17,7 @@

# First-party imports
from gluonts.model.common import Tensor
from gluonts.support.util import erf
from gluonts.support.util import erf, erfinv
from gluonts.core.component import validated

# Relative imports
@@ -78,8 +78,8 @@ def stddev(self) -> Tensor:

def cdf(self, x):
F = self.F
u = self.F.broadcast_div(
self.F.broadcast_minus(x, self.mu), self.sigma * math.sqrt(2.0)
u = F.broadcast_div(
F.broadcast_minus(x, self.mu), self.sigma * math.sqrt(2.0)
)
return (erf(F, u) + 1.0) / 2.0

@@ -102,6 +102,20 @@ def s(mu: Tensor, sigma: Tensor) -> Tensor:
s, mu=self.mu, sigma=self.sigma, num_samples=num_samples
)

def quantile(self, level: Tensor) -> Tensor:
F = self.F
# we consider level to be an independent axis and so expand it
# to shape (num_levels, 1, 1, ...)
for _ in range(self.all_dim):
level = level.expand_dims(axis=-1)

return F.broadcast_add(
self.mu,
F.broadcast_mul(
self.sigma, math.sqrt(2.0) * erfinv(F, 2.0 * level - 1.0)
),
)


class GaussianOutput(DistributionOutput):
args_dim: Dict[str, int] = {"mu": 1, "sigma": 1}
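Note: the quantile formula above is just the inverted CDF from the same file: since cdf(x) = (erf((x - mu) / (sigma * sqrt(2))) + 1) / 2, solving cdf(x) = level for x gives x = mu + sigma * sqrt(2) * erfinv(2 * level - 1). A standalone NumPy/SciPy sanity check of that identity (illustrative only, assuming scipy is installed):

import numpy as np
from scipy.special import erfinv
from scipy.stats import norm

mu, sigma = 1.5, 0.7
level = np.linspace(0.01, 0.99, 99)

q_closed_form = mu + sigma * np.sqrt(2.0) * erfinv(2.0 * level - 1.0)
q_reference = norm.ppf(level, loc=mu, scale=sigma)  # scipy's Gaussian quantile

assert np.allclose(q_closed_form, q_reference)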
10 changes: 10 additions & 0 deletions src/gluonts/distribution/laplace.py
@@ -90,6 +90,16 @@ def s(mu: Tensor, b: Tensor) -> Tensor:
s, mu=self.mu, b=self.b, num_samples=num_samples
)

def quantile(self, level: Tensor) -> Tensor:
F = self.F
for _ in range(self.all_dim):
level = level.expand_dims(axis=-1)

condition = F.broadcast_greater(level, level.zeros_like() + 0.5)
u = F.where(condition, -F.log(2.0 - 2.0 * level), F.log(2.0 * level))

return F.broadcast_add(self.mu, F.broadcast_mul(self.b, u))


class LaplaceOutput(DistributionOutput):
args_dim: Dict[str, int] = {"mu": 1, "b": 1}
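Note: the `F.where` above implements the piecewise inverse CDF of the Laplace distribution: q(level) = mu + b * log(2 * level) for level <= 0.5 and q(level) = mu - b * log(2 - 2 * level) for level > 0.5. A standalone NumPy check against scipy (illustrative only, assuming scipy is installed):

import numpy as np
from scipy.stats import laplace

mu, b = 0.3, 2.0
level = np.linspace(0.01, 0.99, 99)

# Upper-tail branch for level > 0.5, lower-tail branch otherwise.
u = np.where(level > 0.5, -np.log(2.0 - 2.0 * level), np.log(2.0 * level))
assert np.allclose(mu + b * u, laplace.ppf(level, loc=mu, scale=b))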
14 changes: 11 additions & 3 deletions src/gluonts/distribution/piecewise_linear.py
@@ -132,7 +132,10 @@ def sample(self, num_samples: Optional[int] = None) -> Tensor:
)
)

sample = self.quantile(u, axis=None if num_samples is None else 0)
sample = self.quantile(u)

if num_samples is None:
sample = F.squeeze(sample, axis=0)

return sample

@@ -200,7 +203,7 @@ def cdf(self, x: Tensor) -> Tensor:
F = self.F
gamma, b, knot_positions = self.gamma, self.b, self.knot_positions

quantiles_at_knots = self.quantile(knot_positions, axis=-2)
quantiles_at_knots = self.quantile_internal(knot_positions, axis=-2)

# Mask to nullify the terms corresponding to knots larger than l_0, which is the largest knot
# (quantile level) such that the quantile at l_0, s(l_0) < x.
@@ -229,7 +232,12 @@ def cdf(self, x: Tensor) -> Tensor:

return a_tilde

def quantile(self, x: Tensor, axis: Optional[int] = None) -> Tensor:
def quantile(self, level: Tensor) -> Tensor:
return self.quantile_internal(level, axis=0)

def quantile_internal(
self, x: Tensor, axis: Optional[int] = None
) -> Tensor:
r"""
Evaluates the quantile function at the quantile levels contained in `x`.
21 changes: 21 additions & 0 deletions src/gluonts/distribution/transformed_distribution.py
@@ -16,6 +16,7 @@

# Third-party imports
from mxnet import autograd
import mxnet as mx

# First-party imports
from gluonts.model.common import Tensor
@@ -113,6 +114,26 @@ def cdf(self, y: Tensor) -> Tensor:
f = self.base_distribution.cdf(x)
return sign * (f - 0.5) + 0.5

def quantile(self, level: Tensor) -> Tensor:
F = getF(level)

sign = 1.0
for t in self.transforms:
sign = sign * t.sign

if not isinstance(sign, (mx.nd.NDArray, mx.sym.Symbol)):
sign = sign + level.zeros_like()

cond = F.broadcast_greater(sign, sign.zeros_like())
level = F.broadcast_mul(cond, level) + F.broadcast_mul(
1.0 - cond, 1.0 - level
)

q = self.base_distribution.quantile(level)
for t in self.transforms:
q = t.f(q)
return q


def sum_trailing_axes(F, x: Tensor, k: int) -> Tensor:
for _ in range(k):
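Note: the sign bookkeeping above handles transform chains that are monotonically decreasing overall. For a decreasing transform T, the alpha-quantile of T(X) equals T applied to the (1 - alpha)-quantile of X, so the `cond`-based blend swaps level for 1 - level before querying the base distribution, and the resulting base quantile is then pushed through the transforms via `t.f`. A standalone NumPy illustration of the identity (not gluonts code):

import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(size=100_000)  # X ~ Uniform(0, 1)
y = -x                         # T(x) = -x, a decreasing transform

alpha = 0.9
empirical = np.quantile(y, alpha)   # empirical 0.9-quantile of T(X)
theoretical = -(1.0 - alpha)        # T applied to the (1 - alpha)-quantile of X
assert abs(empirical - theoretical) < 1e-2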
8 changes: 8 additions & 0 deletions src/gluonts/distribution/uniform.py
@@ -92,6 +92,14 @@ def s(low: Tensor, high: Tensor) -> Tensor:
def cdf(self, x: Tensor) -> Tensor:
return self.F.broadcast_div(x - self.low, self.high - self.low)

def quantile(self, level: Tensor) -> Tensor:
F = self.F
for _ in range(self.all_dim):
level = level.expand_dims(axis=-1)
return F.broadcast_add(
F.broadcast_mul(self.high - self.low, level), self.low
)


class UniformOutput(DistributionOutput):
args_dim: Dict[str, int] = {"low": 1, "width": 1}
51 changes: 51 additions & 0 deletions src/gluonts/support/util.py
@@ -29,6 +29,7 @@


MXNET_HAS_ERF = hasattr(mx.nd, "erf")
MXNET_HAS_ERFINV = hasattr(mx.nd, "erfinv")


class Timer:
@@ -479,6 +480,56 @@ def erf(F, x: Tensor):
return F.where(F.broadcast_greater_equal(x, zeros), res, -1.0 * res)


def erfinv(F, x: Tensor) -> Tensor:
if MXNET_HAS_ERFINV:
return F.erfinv(x)

zeros = x.zeros_like()

w = -F.log(F.broadcast_mul((1.0 - x), (1.0 + x)))
mask_lesser = F.broadcast_lesser(w, zeros + 5.0)

w = F.where(mask_lesser, w - 2.5, F.sqrt(w) - 3.0)

coefficients_lesser = [
2.81022636e-08,
3.43273939e-07,
-3.5233877e-06,
-4.39150654e-06,
0.00021858087,
-0.00125372503,
-0.00417768164,
0.246640727,
1.50140941,
]

coefficients_greater_equal = [
-0.000200214257,
0.000100950558,
0.00134934322,
-0.00367342844,
0.00573950773,
-0.0076224613,
0.00943887047,
1.00167406,
2.83297682,
]

p = F.where(
mask_lesser,
coefficients_lesser[0] + zeros,
coefficients_greater_equal[0] + zeros,
)

for c_l, c_ge in zip(
coefficients_lesser[1:], coefficients_greater_equal[1:]
):
c = F.where(mask_lesser, c_l + zeros, c_ge + zeros)
p = c + F.broadcast_mul(p, w)

return F.broadcast_mul(p, x)


def get_download_path() -> Path:
"""
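Note: when `mx.nd.erfinv` is not available, the fallback above evaluates a piecewise polynomial approximation in w = -log(1 - x^2), using one coefficient set for w < 5 and another for the tails; the `p = c + F.broadcast_mul(p, w)` loop is Horner's scheme. A minimal usage sketch (illustrative only; it assumes MXNet is installed and that `util` refers to `gluonts.support.util`):

import mxnet as mx
import numpy as np
from gluonts.support import util

x = mx.nd.array(np.linspace(-0.999, 0.999, 7))
y = util.erfinv(mx.nd, x)  # elementwise inverse error function
print(y.asnumpy())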
4 changes: 2 additions & 2 deletions test/distribution/test_distribution_inference.py
@@ -509,15 +509,15 @@ def test_piecewise_linear(

# Compute quantiles with the estimated parameters
quantiles_hat = np.squeeze(
pwl_sqf_hat.quantile(
pwl_sqf_hat.quantile_internal(
mx.nd.array(quantile_levels).expand_dims(axis=0), axis=1
).asnumpy()
)

# Compute quantiles with the original parameters
# Since params is replicated across samples we take only the first entry
quantiles = np.squeeze(
pwl_sqf.quantile(
pwl_sqf.quantile_internal(
mx.nd.array(quantile_levels)
.expand_dims(axis=0)
.repeat(axis=0, repeats=num_samples),
7 changes: 7 additions & 0 deletions test/distribution/test_distribution_sampling.py
Expand Up @@ -82,6 +82,7 @@


DISTRIBUTIONS_WITH_CDF = [Gaussian, Uniform, Laplace, Binned]
DISTRIBUTIONS_WITH_QUANTILE_FUNCTION = [Gaussian, Uniform, Laplace]


@pytest.mark.parametrize("distr_class, params", test_cases)
@@ -111,6 +112,12 @@ def test_sampling(distr_class, params, serialize_fn) -> None:
calc_cdf = distr.cdf(mx.nd.array(edges)).asnumpy()
assert np.allclose(calc_cdf[1:, :], emp_cdf, atol=1e-2)

if distr_class in DISTRIBUTIONS_WITH_QUANTILE_FUNCTION:
levels = np.linspace(1.0e-3, 1.0 - 1.0e-3, 100)
emp_qfunc = np.percentile(np_samples, levels * 100, axis=0)
calc_qfunc = distr.quantile(mx.nd.array(levels)).asnumpy()
assert np.allclose(calc_qfunc, emp_qfunc, rtol=1e-2)


test_cases_multivariate = [
(
11 changes: 11 additions & 0 deletions test/distribution/test_distribution_shapes.py
@@ -214,3 +214,14 @@ def test_distribution_shapes(
x3 = distr.sample(num_samples=3)

assert x3.shape == (3,) + distr.batch_shape + distr.event_shape

def has_quantile(d):
return isinstance(d, (Uniform, Gaussian, Laplace))

if (
has_quantile(distr)
or isinstance(distr, TransformedDistribution)
and has_quantile(distr.base_distribution)
):
qs1 = distr.quantile(mx.nd.array([0.5]))
assert qs1.shape == (1,) + distr.batch_shape + distr.event_shape
9 changes: 6 additions & 3 deletions test/distribution/test_piecewise_linear.py
@@ -119,17 +119,20 @@ def test_shapes(
assert distr.crps(target).shape == batch_shape

# assert that the quantile shape is correct when computing the quantile values at knot positions - used for a_tilde
assert distr.quantile(knot_spacings, axis=-2).shape == (
assert distr.quantile_internal(knot_spacings, axis=-2).shape == (
*batch_shape,
num_pieces,
)

# assert that the shapes of the samples and of the quantile values are correct when num_samples is None
samples = distr.sample()
assert samples.shape == batch_shape
assert distr.quantile(samples).shape == batch_shape
assert distr.quantile_internal(samples).shape == batch_shape

# assert that the shapes of the samples and of the quantile values are correct when num_samples is not None
samples = distr.sample(num_samples)
assert samples.shape == (num_samples, *batch_shape)
assert distr.quantile(samples, axis=0).shape == (num_samples, *batch_shape)
assert distr.quantile_internal(samples, axis=0).shape == (
num_samples,
*batch_shape,
)
13 changes: 13 additions & 0 deletions test/distribution/test_transformed_distribution.py
@@ -32,6 +32,10 @@ def exp_cdf(x: np.ndarray) -> np.ndarray:
return 1.0 - np.exp(-x)


def exp_quantile(level: np.ndarray) -> np.ndarray:
return -np.log(1.0 - level)


@pytest.mark.parametrize("serialize_fn", serialize_fn_list)
def test_transformed_distribution(serialize_fn) -> None:
zero = nd.zeros(1)
@@ -51,6 +55,12 @@ def test_transformed_distribution(serialize_fn) -> None:
v = np.linspace(0, 5, 101)
assert np.allclose(exponential.cdf(nd.array(v)).asnumpy(), exp_cdf(v))

level = np.linspace(1.0e-5, 1.0 - 1.0e-5, 101)

qs_calc = exponential.quantile(nd.array(level)).asnumpy()[:, 0]
qs_theo = exp_quantile(level)
assert np.allclose(qs_calc, qs_theo, atol=1.0e-2)

# If Y ~ Exponential(1), then U = 1 - e^{-Y} has Uniform(0, 1) distribution
uniform = TransformedDistribution(
exponential,
@@ -67,3 +77,6 @@

v = np.linspace(0, 1, 101)
assert np.allclose(uniform.cdf(nd.array(v)).asnumpy(), v)

qs_calc = uniform.quantile(nd.array(level)).asnumpy()[:, 0]
assert np.allclose(qs_calc, level, atol=1.0e-2)
12 changes: 12 additions & 0 deletions test/support/test_util.py
@@ -78,3 +78,15 @@ def test_erf() -> None:
y_mxnet = util.erf(mx.nd, mx.nd.array(x)).asnumpy()
y_scipy = scipy_erf(x)
assert np.allclose(y_mxnet, y_scipy)


def test_erfinv() -> None:
try:
from scipy.special import erfinv as scipy_erfinv
except ImportError:
pytest.skip("scipy not installed, skipping test for erfinv")

x = np.linspace(-1.0 + 1.0e-4, 1 - 1.0e-4, 11)
y_mxnet = util.erfinv(mx.nd, mx.nd.array(x)).asnumpy()
y_scipy = scipy_erfinv(x)
assert np.allclose(y_mxnet, y_scipy, rtol=1e-3)
