Skip to content

Commit

Permalink
Add log_survival_function for Laplace distribution.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 561422459
  • Loading branch information
DistraxDev authored and DistraxDev committed Sep 4, 2023
1 parent 93c54a8 commit 962bdbe
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
18 changes: 14 additions & 4 deletions distrax/_src/distributions/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@
EventT = distribution.EventT


def log_cdf_laplace(norm_value: EventT) -> Array:
"""Log CDF of a standardized Laplace distribution."""
lower_value = norm_value - math.log(2.0)
exp_neg_norm_value = jnp.exp(-jnp.abs(norm_value))
upper_value = jnp.log1p(-0.5 * exp_neg_norm_value)
return jnp.where(jnp.less_equal(norm_value, 0.), lower_value, upper_value)


class Laplace(distribution.Distribution):
"""Laplace distribution with location `loc` and `scale` parameters."""

Expand Down Expand Up @@ -107,10 +115,12 @@ def _standardize(self, value: Array) -> Array:
def log_cdf(self, value: EventT) -> Array:
"""See `Distribution.log_cdf`."""
norm_value = self._standardize(value)
lower_value = norm_value - math.log(2.)
exp_neg_norm_value = jnp.exp(-jnp.abs(norm_value))
upper_value = jnp.log1p(-0.5 * exp_neg_norm_value)
return jnp.where(jnp.less_equal(norm_value, 0.), lower_value, upper_value)
return log_cdf_laplace(norm_value)

def log_survival_function(self, value: EventT) -> Array:
"""See `Distribution.log_survival_function`."""
norm_value = self._standardize(value)
return log_cdf_laplace(-norm_value)

def mean(self) -> Array:
"""Calculates the mean."""
Expand Down
20 changes: 20 additions & 0 deletions distrax/_src/distributions/laplace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,26 @@ def test_log_cdf(self, distr_params, value):
call_args=(value,),
assertion_fn=self.assertion_fn(rtol=2e-2))

@chex.all_variants
@parameterized.named_parameters(
('1d dist, 1d value', (0, 1), 1),
('1d dist, 2d value', (0.5, 0.1), np.array([1, 2])),
('1d dist, 2d value as list', (0.5, 0.1), [1, 2]),
('2d dist, 1d value', (0.5 + np.zeros(2), 0.3 * np.ones(2)), 1),
('2d broadcasted dist, 1d value', (np.zeros(2), 0.8), 1),
('2d dist, 2d value', ([0.1, -0.5], 0.9 * np.ones(2)), np.array([1, 2])),
('1d dist, 1d value, edge case', (0, 1), -200),
)
def test_log_survival_function(self, distr_params, value):
distr_params = (np.asarray(distr_params[0], dtype=np.float32),
np.asarray(distr_params[1], dtype=np.float32))
value = np.asarray(value, dtype=np.float32)
super()._test_attribute(
attribute_string='log_survival_function',
dist_args=distr_params,
call_args=(value,),
assertion_fn=self.assertion_fn(rtol=2e-2))

@chex.all_variants(with_pmap=False)
@parameterized.named_parameters(
('entropy', ([0., 1., -0.5], [0.5, 1., 1.5]), 'entropy'),
Expand Down

0 comments on commit 962bdbe

Please sign in to comment.