Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ChiSquareRV JAX implementation #1363

Merged
merged 2 commits into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion aesara/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,6 @@ def jax_sample_fn_wald(op):
def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)

mean, scale = parameters

key1, key2 = jax.random.split(sampling_key, 2)
Expand All @@ -391,6 +390,21 @@ def sample_fn(rng, size, dtype, *parameters):
return sample_fn


@jax_sample_fn.register(aer.ChiSquareRV)
def jax_sample_fn_chisquare(op):
"""JAX implementation of `ChiSquareRV`"""

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(df,) = parameters
sample = jax.random.gamma(sampling_key, df / 2, size, dtype) * 2
rng["jax_state"] = rng_key
return (rng, sample)

return sample_fn


@jax_sample_fn.register(aer.GeometricRV)
def jax_sample_fn_geometric(op):
"""JAX implementation of `GeometricRV`."""
Expand Down
21 changes: 16 additions & 5 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,19 @@ def test_random_updates(rng_ctor):
lambda *args: args,
None,
),
(
aer.chisquare,
[
set_test_value(
at.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
)
],
(2,),
"chi2",
lambda *args: args,
50_000,
),
(
aer.exponential,
[
Expand Down Expand Up @@ -587,12 +600,10 @@ def test_random_concrete_shape_subtensor_tuple():
assert jax_fn(np.ones((2, 3))).shape == (2,)


@pytest.mark.xfail(
reason="`size_at` should be specified as a static argument", strict=True
)
def test_random_concrete_shape_graph_input():
"""JAX cannot JIT-compile random variables whose `size` argument is not static."""
rng = shared(np.random.RandomState(123))
size_at = at.scalar()
out = at.random.normal(0, 1, size=size_at, rng=rng)
jax_fn = function([size_at], out, mode=jax_mode)
assert jax_fn(10).shape == (10,)
with pytest.raises(NotImplementedError, match=r".* concrete values .*"):
function([size_at], out, mode=jax_mode)