From f5bd186adda3a3436c21192a7b649cc4dd42e8e9 Mon Sep 17 00:00:00 2001 From: Eike Petersen <1774207+e-pet@users.noreply.github.com> Date: Fri, 10 Nov 2023 18:05:39 +0100 Subject: [PATCH] Kumaraswamy distribution bug fixes (#1675) * Minor Kumaraswamy dist bug fixes * Removing intermediates from Kuma log_prob again because no longer necessary --------- Co-authored-by: Eike Petersen --- numpyro/distributions/continuous.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index f5b7b81e6..e8cff6ca6 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -765,7 +765,7 @@ def icdf(self, q): return self.loc - self.scale * jnp.log(-jnp.log(q)) -class Kumaraswamy(TransformedDistribution): +class Kumaraswamy(Distribution): arg_constraints = { "concentration1": constraints.positive, "concentration0": constraints.positive, @@ -786,13 +786,7 @@ def __init__(self, concentration1, concentration0, *, validate_args=None): batch_shape = lax.broadcast_shapes( jnp.shape(concentration1), jnp.shape(concentration0) ) - base_dist = Uniform(0, 1).expand(batch_shape) - transforms = [ - PowerTransform(1 / concentration0), - AffineTransform(1, -1), - PowerTransform(1 / concentration1), - ] - super().__init__(base_dist, transforms, validate_args=validate_args) + super().__init__(batch_shape=batch_shape, validate_args=validate_args) def sample(self, key, sample_shape=()): assert is_prng_key(key)