Implement gamma sampler using core.Primitive interface #1790
Conversation
```diff
@@ -616,7 +618,6 @@ def beta(key, a, b, shape=None, dtype=onp.float64):
   dtype = dtypes.canonicalize_dtype(dtype)
   return _beta(key, a, b, shape, dtype)

-@partial(jit, static_argnums=(3, 4))
```
No need to jit this, because we don't need to compile `gamma_b`.
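The reasoning behind dropping the decorator can be illustrated with a small sketch (the helper here is a made-up stand-in, not the PR's actual `gamma_b`): a function called from inside a jit-compiled caller is traced and compiled as part of the caller's single XLA program, so wrapping it in `jit` separately buys nothing.

```python
import jax
import jax.numpy as jnp

def _helper(x):
    # hypothetical stand-in for an inner routine such as gamma_b;
    # deliberately NOT wrapped in jit
    return jnp.log1p(x) * 2.0

@jax.jit
def outer(x):
    # _helper is traced here and compiled into outer's single XLA program
    return _helper(x) + 1.0

print(outer(1.0))
```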
@mattjj Can we have this for the next JAX release? We need this fix to unblock some issues related to the gamma sampler in NumPyro.
Thanks for the ping. We’ve been slow the last few weeks because of NeurIPS then the holidays. I promise to follow up on this on Monday. And yes we’ll get it in the next release! Sorry for blocking you.
Please take your time! I saw a new release of JAX so I just wanted to make sure that you get notified about this PR. :)
LGTM, great work, and thanks!
My only suggestion is that I think we can get rid of `random_gamma_p.multiple_results = True` and then remove some singleton-tuplification in a few places. Let me know what you think of that, then let's merge!
I suppose we also need to resolve conflicts with master before being able to merge. Let me know if you need any help with that!
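For context, a minimal sketch of what `multiple_results` changes on a `core.Primitive` (the primitives and impls below are toy examples, not the PR's actual `random_gamma_p`): with the flag set, `bind` always returns a list of outputs, which is what forces singleton-tuplification when there is really only one result.

```python
# Primitive moved to jax.extend.core in newer JAX releases.
try:
    from jax.extend.core import Primitive  # newer JAX
except ImportError:
    from jax.core import Primitive         # older JAX

# A single-output primitive: bind returns a bare value.
single_p = Primitive("demo_single")
single_p.def_impl(lambda x: x + 1.0)

# A multi-output primitive: bind always returns a list,
# even when we only ever put one element in it.
multi_p = Primitive("demo_multi")
multi_p.multiple_results = True
multi_p.def_impl(lambda x: [x + 1.0])

print(single_p.bind(3.0))  # a bare value
print(multi_p.bind(3.0))   # a one-element list
```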
@fehiepsi just uploaded jax==0.1.57 to PyPI with this change!
Thanks a lot, @mattjj!!
This PR tries to address #1789. While working on a fix, I realized that `lax.map` outperforms `vmap` in compile time for most usage cases, so I believe the tests for gamma/beta/dirichlet will be much faster than before. With this, #552 and #1188 probably no longer need to be addressed.

However, I am having trouble with the translation rule. With the script, `jit(f)(np.ones(3))` throws an error at the printing function, while `jit(f)(np.ones(()))` throws the error. I guess the output of the translation rule is a tuple and I need to pack it into an array, but I am not sure how to proceed. I tried to mimic the `custom_transform` implementation to use `tree_flatten`, `flatten_fun_nokwargs`... but I don't know how to write a compatible batching rule. Several months ago, the following code worked, but the utilities it relied on (`jaxpr_computation`, `trace_unwrapped_to_jaxpr`) are no longer available, and I am not sure whether similar code will work now. I appreciate any help!
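The compile-time observation above can be illustrated with a small sketch (using a normal sampler as a stand-in for the per-sample gamma routine, which is not shown here): `lax.map` applies the function sequentially via a scan, which tends to give a smaller program to compile, while `vmap` vectorizes it into one batched computation; both produce the same values.

```python
import jax
import jax.numpy as jnp
from jax import lax

def sample_one(key):
    # stand-in for a per-sample routine such as a gamma sampler
    return jax.random.normal(key)

keys = jax.random.split(jax.random.PRNGKey(0), 8)

seq = lax.map(sample_one, keys)   # sequential: loops via lax.scan
bat = jax.vmap(sample_one)(keys)  # vectorized: one batched program

print(jnp.allclose(seq, bat))  # same values either way
```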
cc @mattjj @shoyer @jekbradbury
Updated: thanks to @jekbradbury, I am able to make this work. This PR resolves #1789, #1875 (comment), #1188, #552 (actually, the fourth one is not relevant, but with the change in this PR, getting samples from the gamma sampler now seems fast).

@mattjj It would be great to have this in the next JAX release. :)