Skip to content

Commit a064bcb

Browse files
theorashidrlouf
authored andcommitted
Add HalfNormalRV JAX implementation (aesara-devs#1362)
1 parent 2434cb4 commit a064bcb

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

aesara/link/jax/dispatch/random.py

+22
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,28 @@ def sample_fn(rng, size, dtype, *parameters):
251251
return sample_fn
252252

253253

254+
@jax_sample_fn.register(aer.HalfNormalRV)
255+
def jax_sample_fn_halfnormal(op):
256+
"""JAX implementation of `HalfNormalRV`."""
257+
258+
def sample_fn(rng, size, dtype, *parameters):
259+
rng_key = rng["jax_state"]
260+
rng_key, sampling_key = jax.random.split(rng_key, 2)
261+
(
262+
loc,
263+
scale,
264+
) = parameters
265+
sample = (
266+
loc
267+
+ jax.random.truncated_normal(sampling_key, 0.0, jax.numpy.inf, size, dtype)
268+
* scale
269+
)
270+
rng["jax_state"] = rng_key
271+
return (rng, sample)
272+
273+
return sample_fn
274+
275+
254276
@jax_sample_fn.register(aer.ChoiceRV)
255277
def jax_funcify_choice(op):
256278
"""JAX implementation of `ChoiceRV`."""

tests/link/jax/test_random.py

+16
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,22 @@ def test_random_updates(rng_ctor):
280280
"uniform",
281281
lambda *args: args,
282282
),
283+
(
284+
aer.halfnormal,
285+
[
286+
set_test_value(
287+
at.dvector(),
288+
np.array([-1.0, 2.0], dtype=np.float64),
289+
),
290+
set_test_value(
291+
at.dscalar(),
292+
np.array(1000.0, dtype=np.float64),
293+
),
294+
],
295+
(2,),
296+
"halfnorm",
297+
lambda *args: args,
298+
),
283299
],
284300
)
285301
def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv):

0 commit comments

Comments
 (0)