From eda2f5053b139ec4b7002dedb85d9199be69deba Mon Sep 17 00:00:00 2001 From: Paul Scemama Date: Sun, 19 Mar 2023 17:12:03 -0400 Subject: [PATCH] Add an InvGamma JAX implementation --- aesara/link/jax/dispatch/random.py | 21 +++++++++++++++++++++ tests/link/jax/test_random.py | 13 +++++++++++++ 2 files changed, 34 insertions(+) diff --git a/aesara/link/jax/dispatch/random.py b/aesara/link/jax/dispatch/random.py index d4868641fa..0cabbd9cce 100644 --- a/aesara/link/jax/dispatch/random.py +++ b/aesara/link/jax/dispatch/random.py @@ -453,3 +453,24 @@ def sample_fn(rng, size, dtype, *parameters): return (rng, samples) return sample_fn + + +@jax_sample_fn.register(aer.InvGammaRV) +def jax_sample_fn_invgamma(op): + """JAX implementation of `InvGammaRV`.""" + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["jax_state"] + rng_key, sampling_key = jax.random.split(rng_key, 2) + + ( + shape, + scale, + ) = parameters + # InvGamma[shape, scale] <-> 1 / Gamma[shape, 1 / scale] + samples = 1 / (jax.random.gamma(sampling_key, shape, size, dtype) / scale) + + rng["jax_state"] = rng_key + return (rng, samples) + + return sample_fn diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 5df9fce07c..d7d7cb190c 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -467,6 +467,19 @@ def test_random_dirichlet(parameter, size): np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1) +@pytest.mark.parametrize( + "shape, scale", + [(3, 3), (2, 1), (2, 5)], +) +def test_random_invgamma(shape, scale): + rng = shared(np.random.RandomState(123)) + g = at.random.invgamma(shape, scale, size=(100000,), rng=rng) + g_fn = function([], g, mode=jax_mode) + samples = g_fn() + # mean = scale / (shape - 1) only exists for shape > 1 + np.testing.assert_allclose(samples.mean(), scale / (shape - 1), rtol=1e-01) + + def test_random_choice(): # Elements are picked at equal frequency num_samples = 10000