Add JAX implementation for HalfNormalRV
#1335
Hey Rémi, I'd like to give this a shot, but I'm not sure where to start with the codebase as I'm new to Aesara. I've done a search and found this Numba implementation:

```python
@_numba_funcify.register(aer.HalfNormalRV)
def numba_funcify_HalfNormalRV(op, node, **kwargs):
    def body_fn(a, b):
        return f"    return {a} + {b} * abs(np.random.normal(0, 1))"

    return create_numba_random_fn(op, node, body_fn)
```

and a `scipy.stats` version. If there's a template or some steps I can follow, I'd be happy to implement this and a few other RVs, but if it's quite involved and requires a really deep knowledge of Aesara, then I don't think I have that yet.
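The Numba `body_fn` above draws a standard normal, takes its absolute value, then scales and shifts it, so `a` acts as `loc` and `b` as `scale`. A minimal NumPy sketch of that same transformation follows; `half_normal` is a hypothetical helper for illustration, not part of Aesara, and the `loc`/`scale` reading of `a`/`b` is an assumption based on the body above.

```python
import numpy as np


def half_normal(rng, loc, scale, size):
    """Draw HalfNormal(loc, scale) samples as loc + scale * |Z| with Z ~ N(0, 1).

    Hypothetical helper mirroring the Numba `body_fn`; not an Aesara API.
    """
    z = rng.standard_normal(size)
    return loc + scale * np.abs(z)


rng = np.random.default_rng(42)
samples = half_normal(rng, loc=1.0, scale=2.0, size=100_000)

# Every sample lies at or above `loc`, because |Z| >= 0.
print(samples.min() >= 1.0)
# The half-normal mean is loc + scale * sqrt(2 / pi); compare with the sample mean.
print(samples.mean(), 1.0 + 2.0 * np.sqrt(2 / np.pi))
```

The same two checks (support bounded below by `loc`, mean near `loc + scale * sqrt(2 / pi)`) are a cheap way to sanity-check any port of this sampler to another backend.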
Thank you for your interest! I'll come back to you here with an explanation once #1284 is merged; it should then be possible to add the implementation without deep knowledge of the internals.
Now that #1284 is merged, it should be a lot easier to contribute JAX implementations for random variables. You would need to add the implementation in this file; for instance, here is the implementation for the Student T random variable:

```python
@jax_sample_fn.register(aer.StudentTRV)  # tells Aesara this is the implementation of `StudentTRV`
def jax_sample_fn_t(op):
    """JAX implementation of `StudentTRV`."""

    def sample_fn(rng, size, dtype, *parameters):
        rng_key = rng["jax_state"]
        (
            df,
            loc,
            scale,
        ) = parameters
        # here we're lucky, already implemented in JAX. This is where your code
        # would go.
        sample = loc + jax.random.t(rng_key, df, size, dtype) * scale
        rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
        return (rng, sample)

    return sample_fn
```

The content of the […] You will also need to add a test in this file. It should be able to fit in […] Everything else should be self-explanatory, but don't hesitate to ask if you have any questions.
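Following the `StudentTRV` template, a `HalfNormalRV` version could look like the sketch below. It assumes, as the Numba `body_fn(a, b)` suggests, that the two parameters are `loc` and `scale`, and it builds the sample from `jax.random.normal` since the template only shows `jax.random.t`. The factory name `jax_sample_fn_halfnormal` is hypothetical, and the registration decorator is left as a comment so the sketch runs without Aesara; a plain dict stands in for the `rng` object Aesara passes.

```python
import jax
import jax.numpy as jnp


# In Aesara's JAX dispatch file this would be registered like the StudentT example:
#     @jax_sample_fn.register(aer.HalfNormalRV)
# The decorator is omitted here so the sketch runs standalone.
def jax_sample_fn_halfnormal(op):
    """Sketch of a JAX implementation of `HalfNormalRV` (assumes (loc, scale) parameters)."""

    def sample_fn(rng, size, dtype, *parameters):
        rng_key = rng["jax_state"]
        loc, scale = parameters
        # Fold a standard normal draw: loc + scale * |Z|, as in the Numba version.
        sample = loc + jnp.abs(jax.random.normal(rng_key, size, dtype)) * scale
        # Advance the key the same way the StudentT implementation does.
        rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
        return (rng, sample)

    return sample_fn


# Usage: `op` is unused by the sketch, so None is passed in its place.
sample_fn = jax_sample_fn_halfnormal(op=None)
rng = {"jax_state": jax.random.PRNGKey(0)}
rng, sample = sample_fn(rng, (100_000,), jnp.float32, 1.0, 2.0)
print(sample.shape, float(sample.min()) >= 1.0)
```

Folding a normal draw keeps the implementation to one line of sampling code, which fits the template's "this is where your code would go" slot; the test in the test file could then compare samples against a reference such as `scipy.stats.halfnorm`.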