From 962bdbefbcc3fe1b9b725143b14fa7ed9e0356bf Mon Sep 17 00:00:00 2001 From: DistraxDev Date: Wed, 30 Aug 2023 12:50:27 -0700 Subject: [PATCH] Add `log_survival_function` for `Laplace` distribution. PiperOrigin-RevId: 561422459 --- distrax/_src/distributions/laplace.py | 18 ++++++++++++++---- distrax/_src/distributions/laplace_test.py | 20 ++++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/distrax/_src/distributions/laplace.py b/distrax/_src/distributions/laplace.py index 5577a9a..7d18ace 100644 --- a/distrax/_src/distributions/laplace.py +++ b/distrax/_src/distributions/laplace.py @@ -32,6 +32,14 @@ EventT = distribution.EventT +def log_cdf_laplace(norm_value: EventT) -> Array: + """Log CDF of a standardized Laplace distribution.""" + lower_value = norm_value - math.log(2.0) + exp_neg_norm_value = jnp.exp(-jnp.abs(norm_value)) + upper_value = jnp.log1p(-0.5 * exp_neg_norm_value) + return jnp.where(jnp.less_equal(norm_value, 0.), lower_value, upper_value) + + class Laplace(distribution.Distribution): """Laplace distribution with location `loc` and `scale` parameters.""" @@ -107,10 +115,12 @@ def _standardize(self, value: Array) -> Array: def log_cdf(self, value: EventT) -> Array: """See `Distribution.log_cdf`.""" norm_value = self._standardize(value) - lower_value = norm_value - math.log(2.) - exp_neg_norm_value = jnp.exp(-jnp.abs(norm_value)) - upper_value = jnp.log1p(-0.5 * exp_neg_norm_value) - return jnp.where(jnp.less_equal(norm_value, 0.), lower_value, upper_value) + return log_cdf_laplace(norm_value) + + def log_survival_function(self, value: EventT) -> Array: + """See `Distribution.log_survival_function`.""" + norm_value = self._standardize(value) + return log_cdf_laplace(-norm_value) def mean(self) -> Array: """Calculates the mean.""" diff --git a/distrax/_src/distributions/laplace_test.py b/distrax/_src/distributions/laplace_test.py index 80baf68..01a71f7 100644 --- a/distrax/_src/distributions/laplace_test.py +++ b/distrax/_src/distributions/laplace_test.py @@ -175,6 +175,26 @@ def test_log_cdf(self, distr_params, value): call_args=(value,), assertion_fn=self.assertion_fn(rtol=2e-2)) + @chex.all_variants + @parameterized.named_parameters( + ('1d dist, 1d value', (0, 1), 1), + ('1d dist, 2d value', (0.5, 0.1), np.array([1, 2])), + ('1d dist, 2d value as list', (0.5, 0.1), [1, 2]), + ('2d dist, 1d value', (0.5 + np.zeros(2), 0.3 * np.ones(2)), 1), + ('2d broadcasted dist, 1d value', (np.zeros(2), 0.8), 1), + ('2d dist, 2d value', ([0.1, -0.5], 0.9 * np.ones(2)), np.array([1, 2])), + ('1d dist, 1d value, edge case', (0, 1), -200), + ) + def test_log_survival_function(self, distr_params, value): + distr_params = (np.asarray(distr_params[0], dtype=np.float32), + np.asarray(distr_params[1], dtype=np.float32)) + value = np.asarray(value, dtype=np.float32) + super()._test_attribute( + attribute_string='log_survival_function', + dist_args=distr_params, + call_args=(value,), + assertion_fn=self.assertion_fn(rtol=2e-2)) + @chex.all_variants(with_pmap=False) @parameterized.named_parameters( ('entropy', ([0., 1., -0.5], [0.5, 1., 1.5]), 'entropy'),