diff --git a/aesara/link/jax/dispatch/random.py b/aesara/link/jax/dispatch/random.py index 6fb9c28c84..2212662227 100644 --- a/aesara/link/jax/dispatch/random.py +++ b/aesara/link/jax/dispatch/random.py @@ -370,7 +370,6 @@ def jax_sample_fn_wald(op): def sample_fn(rng, size, dtype, *parameters): rng_key = rng["jax_state"] rng_key, sampling_key = jax.random.split(rng_key, 2) - mean, scale = parameters key1, key2 = jax.random.split(sampling_key, 2) @@ -391,6 +390,21 @@ def sample_fn(rng, size, dtype, *parameters): return sample_fn +@jax_sample_fn.register(aer.ChiSquareRV) +def jax_sample_fn_chisquare(op): + """JAX implementation of `ChiSquareRV`""" + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["jax_state"] + rng_key, sampling_key = jax.random.split(rng_key, 2) + (df,) = parameters + sample = jax.random.gamma(sampling_key, df / 2, size, dtype) * 2 + rng["jax_state"] = rng_key + return (rng, sample) + + return sample_fn + + @jax_sample_fn.register(aer.GeometricRV) def jax_sample_fn_geometric(op): """JAX implementation of `GeometricRV`.""" diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 4adc7f2241..ade208754d 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -97,6 +97,19 @@ def test_random_updates(rng_ctor): lambda *args: args, None, ), + ( + aer.chisquare, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ) + ], + (2,), + "chi2", + lambda *args: args, + None, + ), ( aer.exponential, [