From abb21757233743f5b02d87c6f6a172fc0fedf0d9 Mon Sep 17 00:00:00 2001 From: Dan Manela Date: Wed, 27 Sep 2023 18:10:12 +0100 Subject: [PATCH 1/2] Add ICDF and CDF to BernoulliProbs --- numpyro/distributions/discrete.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index e5fbb2f88..633091969 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -102,6 +102,13 @@ def mean(self): def variance(self): return self.probs * (1 - self.probs) + def cdf(self, value): + return ((1 - self.probs) * jnp.heaviside(value, 1) + + self.probs * jnp.heaviside(value - 1, 1)) + + def icdf(self, q): + return jnp.heaviside(q - (1 - self.probs), 1) + def enumerate_support(self, expand=True): values = jnp.arange(2).reshape((-1,) + (1,) * len(self.batch_shape)) if expand: From 3ce5ea42b2c534536142dad50d7f51865f28ec2a Mon Sep 17 00:00:00 2001 From: Dan Manela Date: Wed, 27 Sep 2023 18:50:11 +0100 Subject: [PATCH 2/2] Add tests for BernoulliProbs ICDF and CDF --- test/test_distributions.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index 390880cac..a7fd91d4f 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1387,7 +1387,7 @@ def test_mixture_log_prob(): @pytest.mark.parametrize( "jax_dist, sp_dist, params", # TODO: add more complete pattern for Discrete.cdf - CONTINUOUS + [T(dist.Poisson, 2.0), T(dist.Poisson, np.array([2.0, 3.0, 5.0]))], + CONTINUOUS + [T(dist.Poisson, 2.0), T(dist.Poisson, np.array([2.0, 3.0, 5.0])), T(dist.BernoulliProbs, 0.2)], ) @pytest.mark.filterwarnings("ignore:overflow encountered:RuntimeWarning") def test_cdf_and_icdf(jax_dist, sp_dist, params): @@ -1411,8 +1411,15 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params): atol=1e-5, rtol=rtol, ) - assert_allclose(d.cdf(d.icdf(quantiles)), quantiles, atol=1e-5, rtol=1e-5) - assert_allclose(d.icdf(d.cdf(samples)), samples, atol=1e-5, rtol=rtol) + if jax_dist is dist.BernoulliProbs: + assert pytest.approx(d.icdf(quantiles).mean(), abs=0.1) == d.probs + prop_of_ones = (d.cdf(d.icdf(quantiles)) == 1).mean() + prop_of_zeros = (d.cdf(d.icdf(quantiles)) == (1 - d.probs)).mean() + assert pytest.approx(prop_of_ones, abs=0.1) == d.probs + assert pytest.approx(prop_of_zeros, abs=0.1) == (1 - d.probs) + else: + assert_allclose(d.cdf(d.icdf(quantiles)), quantiles, atol=1e-5, rtol=1e-5) + assert_allclose(d.icdf(d.cdf(samples)), samples, atol=1e-5, rtol=rtol) except NotImplementedError: pass