Add JAX implementation for MultinomialRV
#1326
Comments
Much appreciated, @GStechschulte!
Here is an explanation of how to go about adding an implementation: #1335 (comment)
After using NumPyro, I remembered that it has a JAX implementation of the Multinomial distribution, although it follows the design of the PyTorch distributions module. I therefore adapted that code to match the parameter arguments passed to the respective RVs. I still need to add the tests. As this is my first time contributing, how should I/we handle reusing a code snippet from another library? In this case, I used three functions from the NumPyro library to meet the needs of this PR. My idea is to add a reference in a docstring. Thanks! Here is the link to my branch of the JAX implementation of MultinomialRV.
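For context on what a JAX multinomial sampler involves, here is a minimal sketch. It is not the NumPyro code or the code on the branch mentioned above; `multinomial_sample` is a hypothetical helper, and the sketch assumes a 1-D probability vector `p`, a static Python integer `n`, and a tuple `size` for the batch shape. It draws `n` categorical indices per sample and counts how often each category occurs.

```python
import jax
import jax.numpy as jnp


def multinomial_sample(key, n, p, size):
    """Hypothetical sketch: draw multinomial counts of shape (*size, len(p)).

    Assumes p is 1-D and n is a static Python int.
    """
    k = p.shape[-1]
    # Draw n category indices for every sample in the batch: shape (*size, n).
    # log(0) gives -inf logits, so zero-probability categories are never chosen.
    idx = jax.random.categorical(key, jnp.log(p), shape=(*size, n))
    # One-hot encode each draw, shape (*size, n, k), then sum over the n draws.
    return jax.nn.one_hot(idx, k, dtype=jnp.int32).sum(axis=-2)


key = jax.random.PRNGKey(0)
counts = multinomial_sample(key, 10, jnp.array([0.2, 0.3, 0.5]), (4,))
print(counts.shape)  # (4, 3); each row sums to 10
```

The one-hot counting uses memory proportional to `n * len(p)` per sample, which is fine for small `n`; a production implementation may prefer a more memory-efficient way to accumulate counts and should also handle batched `p`.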
Feel free to create a PR for that branch. If you're still working on it, no worries; you can make it a draft PR.
We'll need to make sure that the licenses are compatible and see what they require.