Add JAX implementation for MultinomialRV #1326

Open
Tracked by #1425
rlouf opened this issue Dec 3, 2022 · 6 comments
Labels
enhancement New feature or request good first issue Good for newcomers help wanted Extra attention is needed JAX Involves JAX transpilation random variables Involves random variables and/or sampling

Comments

@rlouf
Member

rlouf commented Dec 3, 2022

No description provided.

@rlouf rlouf added JAX Involves JAX transpilation enhancement New feature or request good first issue Good for newcomers help wanted Extra attention is needed random variables Involves random variables and/or sampling labels Dec 3, 2022
@GStechschulte

Hey @rlouf and others, I will give this issue a go. I am following #1335 and #1284 for explanations of how the implementations are done. I will also comment here with progress and questions.

@brandonwillard
Member

Much appreciated, @GStechschulte!

@rlouf
Copy link
Member Author

rlouf commented Dec 9, 2022

Here is an explanation of how to go about adding an implementation: #1335 (comment)
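Assuming the JAX backend maps each random-variable Op to its sampling function through `functools.singledispatch`, adding an implementation amounts to registering a new function for the Op's type. The minimal sketch below shows only that registration pattern; `RandomVariable`, `MultinomialRV`, and `jax_sample_fn` here are hypothetical stand-ins defined locally so the snippet runs without the project installed, not imports from it.

```python
from functools import singledispatch


class RandomVariable:
    """Hypothetical stand-in for the project's base random-variable Op."""


class MultinomialRV(RandomVariable):
    """Hypothetical stand-in for the multinomial random-variable Op."""


@singledispatch
def jax_sample_fn(op):
    # Fallback used for any Op type without a registered JAX implementation.
    raise NotImplementedError(f"No JAX implementation for {type(op).__name__}")


@jax_sample_fn.register(MultinomialRV)
def jax_sample_fn_multinomial(op):
    # Called once per Op during conversion; returns the function that samples.
    def sample_fn(rng, size, dtype, n, p):
        # Placeholder body: a real implementation would split the JAX key,
        # draw the multinomial sample, and return the updated state with it.
        return rng, None

    return sample_fn
```

With this pattern, Op types that have not been registered fail loudly with `NotImplementedError`, which makes missing implementations easy to spot in tests.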

@GStechschulte

After using NumPyro, I remembered that it has a JAX implementation of the Multinomial distribution, albeit one that follows the design of the PyTorch distributions module. I therefore adapted the code to match the parameters passed to the corresponding RV's __call__ function in this file.
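One simple way to sample a multinomial in JAX, shown here as a sketch and not as the NumPyro code referenced above, is to draw `n` categorical trials and count how many land in each category. The `sample_fn(rng, size, dtype, n, p)` wrapper and the `"jax_state"` dictionary key are assumptions about how the backend threads its random state, not a confirmed API; the sketch also assumes `n` is a Python int and `p` is a 1-D probability vector.

```python
import jax
import jax.numpy as jnp


def multinomial_sample(key, n, p, size=()):
    """Draw multinomial counts by summing `n` one-hot categorical draws.

    Assumes `n` is a static Python int and `p` is a 1-D probability vector.
    """
    # Zero probabilities become -inf logits and are never selected.
    logits = jnp.log(p)
    # Categorical draws with shape: batch dims (`size`) followed by n trials.
    draws = jax.random.categorical(key, logits, shape=tuple(size) + (n,))
    # One-hot encode each trial's category, then count trials per category.
    one_hot = jax.nn.one_hot(draws, p.shape[-1], dtype=jnp.int32)
    return one_hot.sum(axis=-2)


def sample_fn(rng, size, dtype, n, p):
    # Hypothetical wrapper: assumes `rng` is a dict holding a JAX PRNG key
    # under "jax_state", which is advanced and written back after each draw.
    rng_key, sampling_key = jax.random.split(rng["jax_state"])
    sample = multinomial_sample(sampling_key, n, p, size or ()).astype(dtype)
    rng["jax_state"] = rng_key
    return rng, sample
```

Because the number of trials becomes part of an array shape, `n` must be a concrete integer in this sketch; supporting traced or batched `n` would need a different approach, for example drawing up to a fixed maximum number of trials and masking the extras.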

I still need to add the tests. As this is my first time contributing, how should I/we handle "using" a code snippet from another library? In this case, I have used 3 functions from the NumPyro library to meet the needs of this PR. My idea is to add a reference in a doc string? Thanks!

Here is the link to my branch of the JAX implementation of MultinomialRV.

@brandonwillard
Member

> Here is the link to my branch of the JAX implementation of MultinomialRV.

Feel free to create a PR for that branch. If you're still working on it, no worries; you can make it a draft PR.

@brandonwillard
Member

> I still need to add the tests. As this is my first time contributing, how should I/we handle "using" a code snippet from another library? In this case, I have used 3 functions from the NumPyro library to meet the needs of this PR. My idea is to add a reference in a doc string? Thanks!

We'll need to make sure that the licenses are compatible and see what they require.

3 participants