Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HalfNormalRV JAX implementation #1362

Merged
merged 1 commit into from
Dec 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions aesara/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,28 @@ def sample_fn(rng, size, dtype, *parameters):
return sample_fn


@jax_sample_fn.register(aer.HalfNormalRV)
def jax_sample_fn_halfnormal(op):
"""JAX implementation of `HalfNormalRV`."""

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(
loc,
scale,
) = parameters
sample = (
loc
+ jax.random.truncated_normal(sampling_key, 0.0, jax.numpy.inf, size, dtype)
* scale
)
rng["jax_state"] = rng_key
return (rng, sample)

return sample_fn


@jax_sample_fn.register(aer.ChoiceRV)
def jax_funcify_choice(op):
"""JAX implementation of `ChoiceRV`."""
Expand Down
16 changes: 16 additions & 0 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,22 @@ def test_random_updates(rng_ctor):
"uniform",
lambda *args: args,
),
(
aer.halfnormal,
[
set_test_value(
at.dvector(),
np.array([-1.0, 2.0], dtype=np.float64),
),
set_test_value(
at.dscalar(),
np.array(1000.0, dtype=np.float64),
),
],
(2,),
"halfnorm",
lambda *args: args,
),
],
)
def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv):
Expand Down