
Add JAX implementation for HalfNormalRV #1335

Closed · rlouf opened this issue Dec 3, 2022 · 3 comments · Fixed by #1362
Labels: enhancement (New feature or request) · good first issue (Good for newcomers) · help wanted (Extra attention is needed) · JAX (Involves JAX transpilation) · random variables (Involves random variables and/or sampling)

Comments

@rlouf (Member) commented Dec 3, 2022

No description provided.

@rlouf added the JAX, enhancement, good first issue, help wanted, and random variables labels on Dec 3, 2022
@theorashid (Contributor) commented
Hey Rémi, I'd like to give this a shot, but I'm not sure where to start with the codebase as I'm new to Aesara.

I've done a search for HalfNormalRV in the codebase, and I can see there's a Numba version:

@_numba_funcify.register(aer.HalfNormalRV)
def numba_funcify_HalfNormalRV(op, node, **kwargs):
    def body_fn(a, b):
        return f"    return {a} + {b} * abs(np.random.normal(0, 1))"

    return create_numba_random_fn(op, node, body_fn)

and a scipy.stats version.
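If I'm reading body_fn right, it generates a draw of the form loc + scale * abs(N(0, 1)), which should line up with scipy.stats.halfnorm. A quick NumPy check I put together to confirm my understanding (my own sketch, not code from the repository):

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
loc, scale = 0.5, 1.5
# What the generated numba body computes, vectorized:
draws = loc + scale * np.abs(rng.normal(size=100_000))

# The empirical mean should match scipy.stats.halfnorm's analytical mean
print(draws.mean())                                 # ~1.697
print(stats.halfnorm(loc=loc, scale=scale).mean())  # loc + scale * sqrt(2 / pi)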

If there's a template or some steps I can follow, I'd be happy to implement this and a few other RVs. But if it's quite involved and requires really deep knowledge of Aesara's internals, then I don't think I have that yet.

@rlouf (Member, Author) commented Dec 7, 2022

Thank you for your interest! I'll come back to you here with an explanation once #1284 is merged; it should then be possible to add the implementation without deep knowledge of the internals.

@rlouf (Member, Author) commented Dec 9, 2022

Now that #1284 is merged, it should be a lot easier to contribute JAX implementations for random variables. You would need to add the implementation in this file; here, for instance, is the implementation for the Student's t random variable:

@jax_sample_fn.register(aer.StudentTRV)  # tells Aesara this is the implementation of `StudentTRV`
def jax_sample_fn_t(op):
    """JAX implementation of `StudentTRV`."""

    def sample_fn(rng, size, dtype, *parameters):
        rng_key = rng["jax_state"]
        df, loc, scale = parameters
        # Here we're lucky: a Student's t sampler is already implemented in
        # JAX. This is where your code would go.
        sample = loc + jax.random.t(rng_key, df, size, dtype) * scale
        # Update the RNG state so subsequent draws use a fresh key
        rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
        return (rng, sample)

    return sample_fn

The content of the `parameters` argument can be determined by looking at the implementation of the RandomVariable Op in this file, specifically the parameters passed to the `__call__` function.
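For HalfNormalRV itself, something along these lines could work. This is only a sketch, untested: it assumes the file's existing `import jax` / `import jax.numpy as jnp`, and the `loc`/`scale` parameter order is inferred from the numba version quoted above. JAX has no built-in half-normal sampler, so we fold a standard normal draw:

@jax_sample_fn.register(aer.HalfNormalRV)
def jax_sample_fn_halfnormal(op):
    """JAX implementation of `HalfNormalRV` (sketch, untested)."""

    def sample_fn(rng, size, dtype, *parameters):
        rng_key = rng["jax_state"]
        # `loc`/`scale` order is an assumption based on the numba version
        loc, scale = parameters
        # loc + scale * |N(0, 1)| is half-normal distributed
        sample = loc + scale * jnp.abs(jax.random.normal(rng_key, size, dtype))
        # Update the RNG state so subsequent draws use a fresh key
        rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
        return (rng, sample)

    return sample_fn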

You will also need to add a test in this file. It should fit in test_random_RandomVariable, since SciPy has an implementation for this distribution (scipy.stats.halfnorm).
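For a quick standalone sanity check outside the test suite (hypothetical, not the repository's test style), a Kolmogorov-Smirnov test against scipy.stats.halfnorm could look like:

import numpy as np
from scipy import stats

# Draw samples the way the implementation would: loc + scale * |N(0, 1)|
rng = np.random.default_rng(0)
loc, scale = 1.0, 2.0
samples = loc + scale * np.abs(rng.normal(size=50_000))

# Kolmogorov-Smirnov test against the scipy.stats.halfnorm CDF
statistic, pvalue = stats.kstest(samples, stats.halfnorm(loc=loc, scale=scale).cdf)
assert pvalue > 0.05  # samples are consistent with a half-normal distribution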

Everything else should be self-explanatory, but don't hesitate to ask if you have any questions.
