Skip to content

Commit 9796f9a

Browse files
pscemama-mitrebrandonwillard
authored andcommitted
Add an InvGamma JAX implementation
1 parent 4a687c0 commit 9796f9a

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

aesara/link/jax/dispatch/random.py

+21
Original file line numberDiff line numberDiff line change
@@ -453,3 +453,24 @@ def sample_fn(rng, size, dtype, *parameters):
453453
return (rng, samples)
454454

455455
return sample_fn
456+
457+
458+
@jax_sample_fn.register(aer.InvGammaRV)
459+
def jax_sample_fn_invgamma(op):
460+
"""JAX implementation of `InvGamma`."""
461+
462+
def sample_fn(rng, size, dtype, *parameters):
463+
rng_key = rng["jax_state"]
464+
rng_key, sampling_key = jax.random.split(rng_key, 2)
465+
466+
(
467+
shape,
468+
scale,
469+
) = parameters
470+
# InvGamma[shape, scale] <-> 1 / Gamma[shape, 1 / scale]
471+
samples = 1 / (jax.random.gamma(sampling_key, shape, size, dtype) / scale)
472+
473+
rng["jax_state"] = rng_key
474+
return (rng, samples)
475+
476+
return sample_fn

tests/link/jax/test_random.py

+13
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,19 @@ def test_random_dirichlet(parameter, size):
467467
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
468468

469469

470+
@pytest.mark.parametrize(
471+
"shape, scale",
472+
[(3, 3), (2, 1), (2, 5)],
473+
)
474+
def test_random_invgamma(shape, scale):
475+
rng = shared(np.random.RandomState(123))
476+
g = at.random.invgamma(shape, scale, size=(100000,), rng=rng)
477+
g_fn = function([], g, mode=jax_mode)
478+
samples = g_fn()
479+
# mean = scale / (shape - 1) only exists for shape > 1
480+
np.testing.assert_allclose(samples.mean(), scale / (shape - 1), rtol=1e-01)
481+
482+
470483
def test_random_choice():
471484
# Elements are picked at equal frequency
472485
num_samples = 10000

0 commit comments

Comments
 (0)