NUTS performance concerns on GPU #597
Comments
This is not my experience; what is your environment? Also, an important note on benchmarking JAX: https://jax.readthedocs.io/en/latest/async_dispatch.html
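Because of asynchronous dispatch, a JAX timing harness should block on the result before stopping the clock. It should also report the first call separately, since that call includes tracing and XLA compilation. The sketch below assumes nothing about the script in this issue. `time_jitted` is a hypothetical helper, and the matrix-multiply workload is a stand-in for the sampler loop being measured. Passing `device=` commits the inputs with `jax.device_put`, so the jitted function runs on that device, and the same code can be timed on CPU and on the default accelerator.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, device=None, n_repeats: int = 5):
    """Hypothetical helper: return (first_call_s, best_steady_state_s) for jax.jit(fn).

    Every call is wrapped in jax.block_until_ready, because JAX dispatches work
    asynchronously and the timer could otherwise stop before the device finishes.
    The first call includes tracing and compilation, so it is reported separately.
    """
    if device is not None:
        # Commit the inputs to the requested device; jit then executes there.
        args = jax.device_put(args, device)
    jitted = jax.jit(fn)

    start = time.perf_counter()
    jax.block_until_ready(jitted(*args))
    first_call = time.perf_counter() - start

    steady = []
    for _ in range(n_repeats):
        start = time.perf_counter()
        jax.block_until_ready(jitted(*args))
        steady.append(time.perf_counter() - start)
    return first_call, min(steady)


def workload(a):
    # Stand-in computation; replace with the code actually being benchmarked.
    return (a @ a).sum()


if __name__ == "__main__":
    x = jnp.ones((512, 512))
    devices = [jax.devices()[0]]
    cpu = jax.devices("cpu")[0]
    if cpu not in devices:
        devices.append(cpu)
    for device in devices:
        first, best = time_jitted(workload, x, device=device)
        print(f"{device}: first call {first:.4f}s, steady state {best:.4f}s")
```

Taking the minimum over several steady-state runs reduces noise from other processes; comparing it with the first-call time shows how much of a short run is compilation rather than sampling.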
@junpenglao I will check back on this later this weekend; it's possible that it is an environment problem. I'll get back to you then.
This is replicable on Colab, so I don't think it's an environment issue.
I think this is more or less expected behavior, though, when the problem is rather small and doesn't include operations GPUs are particularly good at. There was a similar discussion for NumPyro, with the takeaway being that JAX is particularly efficient on CPU and GPU acceleration only makes sense for certain problems.
Note that NUTS is control-flow heavy, which makes it hard to run fast on a GPU. See the CHEES algorithm, implemented in BlackJax, for a NUTS-like sampler that avoids this problem.
Describe the issue as clearly as possible:
On a trivial example (the one in quickstart.md), I'm seeing what appears to be a strange bug with the NUTS sampler when running on a GPU.
When I run the script (which I copy below) with a GPU for 200 steps I get
Jax sees these devices: [gpu(id=0)] Starting to run nuts for 200 steps Nuts took 0.050431712468465166 minutes
When I run the script with a GPU for 300 steps I get
Jax sees these devices: [gpu(id=0)] Starting to run nuts for 300 steps Nuts took 0.8048396507898966 minutes
When I run the script with a GPU for 500 steps I get
Jax sees these devices: [gpu(id=0)] Starting to run nuts for 500 steps Nuts took 1.2937044938405355 minutes
When I run the script on CPU with 1000 steps I get
Jax sees these devices: [CpuDevice(id=0)] Starting to run nuts for 1000 steps Nuts took 0.06121724049250285 minutes
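To compare the runs directly, the reported wall-clock times can be converted to milliseconds per step. This snippet only does arithmetic on the numbers quoted above; it does not rerun the sampler, and `ms_per_step` is a hypothetical helper name.

```python
# Wall-clock times reported in this issue, in minutes, keyed by (device, steps).
reported_minutes = {
    ("gpu", 200): 0.050431712468465166,
    ("gpu", 300): 0.8048396507898966,
    ("gpu", 500): 1.2937044938405355,
    ("cpu", 1000): 0.06121724049250285,
}


def ms_per_step(minutes: float, steps: int) -> float:
    """Convert a total wall-clock time in minutes to milliseconds per NUTS step."""
    return minutes * 60_000 / steps


for (device, steps), minutes in reported_minutes.items():
    print(f"{device} {steps:>5} steps: {ms_per_step(minutes, steps):7.2f} ms/step")
```

This works out to roughly 15 ms/step for the 200-step GPU run, about 161 and 155 ms/step for the 300- and 500-step GPU runs, and under 4 ms/step for the 1000-step CPU run. Since these are end-to-end totals, they may include compilation time and, if the script did not block on results, may not reflect when the device actually finished.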
Steps/code to reproduce the bug:
Expected result:
Error message:
No response
Blackjax/JAX/jaxlib/Python version information:
Context for the issue:
No response