Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JAX: Implement geometric RV sampling #1444

Merged
merged 2 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions aesara/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,20 @@ def sample_fn(rng, size, dtype, *parameters):
return (rng, samples)

return sample_fn


@jax_sample_fn.register(aer.GeometricRV)
def jax_sample_fn_geometric(op):
"""JAX implementation of `GeometricRV`."""

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
p = parameters[0]
sample_num = jax.numpy.log(jax.random.uniform(sampling_key, size))
sample = sample_num / jax.numpy.log1p(-p)
sample_ceil = jax.numpy.ceil(sample)
Comment on lines +402 to +404
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rng["jax_state"] = rng_key
return (rng, sample_ceil)

return sample_fn
15 changes: 15 additions & 0 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,21 @@ def test_random_bernoulli(size):
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)


@pytest.mark.parametrize(
"p, size",
[
(0.6, ()),
(0.2, (4,)),
],
)
def test_random_geometric(p, size):
rng = shared(np.random.RandomState(123))
g = at.random.geometric(p, size=(1000,) + size, rng=rng)
g_fn = function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples.mean(), 1 / p, atol=0.1)


def test_random_mvnormal():
rng = shared(np.random.RandomState(123))

Expand Down