From 8d3089d776ee5e3a1b2fa6e34c71e39a4dc07cf9 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 9 Dec 2022 20:52:42 +0100 Subject: [PATCH] jax lognormal --- aesara/link/jax/dispatch/random.py | 15 +++++++++++++++ tests/link/jax/test_random.py | 16 ++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/aesara/link/jax/dispatch/random.py b/aesara/link/jax/dispatch/random.py index d7d5b8c515..ccafdff129 100644 --- a/aesara/link/jax/dispatch/random.py +++ b/aesara/link/jax/dispatch/random.py @@ -277,3 +277,18 @@ def sample_fn(rng, size, dtype, *parameters): return (rng, sample) return sample_fn + + +@jax_sample_fn.register(aer.LogNormalRV) +def jax_sample_fn_lognormal(op): + """JAX implementation of `LogNormalRV`.""" + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["jax_state"] + loc, scale = parameters + sample = loc + jax.random.normal(rng_key, size, dtype) * scale + sample_exp = jax.numpy.exp(sample) + rng["jax_state"] = jax.random.split(rng_key, num=1)[0] + return (rng, sample_exp) + + return sample_fn diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 1a87ddb196..9934b7ae24 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -165,6 +165,22 @@ def test_random_updates(rng_ctor): "logistic", lambda *args: args, ), + ( + aer.lognormal, + [ + set_test_value( + at.lvector(), + np.array([0, 0], dtype=np.int64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + (2,), + "lognorm", + lambda *args: args, + ), ( aer.normal, [