Skip to content

Commit bfcfe4b

Browse files
theorashidrlouf
authored andcommitted
Add HalfNormalRV JAX implementation (#1362)
1 parent 94f3f32 commit bfcfe4b

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

aesara/link/jax/dispatch/random.py

+22
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,28 @@ def sample_fn(rng, size, dtype, *parameters):
260260
return sample_fn
261261

262262

263+
@jax_sample_fn.register(aer.HalfNormalRV)
264+
def jax_sample_fn_halfnormal(op):
265+
"""JAX implementation of `HalfNormalRV`."""
266+
267+
def sample_fn(rng, size, dtype, *parameters):
268+
rng_key = rng["jax_state"]
269+
rng_key, sampling_key = jax.random.split(rng_key, 2)
270+
(
271+
loc,
272+
scale,
273+
) = parameters
274+
sample = (
275+
loc
276+
+ jax.random.truncated_normal(sampling_key, 0.0, jax.numpy.inf, size, dtype)
277+
* scale
278+
)
279+
rng["jax_state"] = rng_key
280+
return (rng, sample)
281+
282+
return sample_fn
283+
284+
263285
@jax_sample_fn.register(aer.ChoiceRV)
264286
def jax_funcify_choice(op):
265287
"""JAX implementation of `ChoiceRV`."""

tests/link/jax/test_random.py

+16
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,22 @@ def test_random_updates(rng_ctor):
280280
"uniform",
281281
lambda *args: args,
282282
),
283+
(
284+
aer.halfnormal,
285+
[
286+
set_test_value(
287+
at.dvector(),
288+
np.array([-1.0, 2.0], dtype=np.float64),
289+
),
290+
set_test_value(
291+
at.dscalar(),
292+
np.array(1000.0, dtype=np.float64),
293+
),
294+
],
295+
(2,),
296+
"halfnorm",
297+
lambda *args: args,
298+
),
283299
],
284300
)
285301
def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv):

0 commit comments

Comments
 (0)