BUG: Jax-based samplers crash at transformation stage #6744
Comments
Setting …
I think this was solved by switching to `scan` as the default.
I'm still getting out-of-memory crashes after sampling, even when using v5.10. Is it still possible to set …
The options are now `scan` or `vmap`; `scan` is the default and is more memory-conscious (see line 188 in c53277b).
Yeah, I saw that. I still get crashes during postprocessing on GPU for large models (even with …
This looks like it might help, though it is not implemented in JAX yet. We should probably keep the option for using …
We are already using `scan` by default, so I don't think it would help.
I'm running into the same OOM issue in postprocessing with the default …
IIRC `postprocessing_chunks` just uses `scan` under the hood anyway, so it shouldn't help. Can you check whether it actually helps in your case? We need an example to investigate this issue, but if you see a difference, we can consider temporarily reverting while we figure it out.
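To make the `vmap` vs. `scan` memory trade-off in the comments concrete, here is a minimal JAX sketch. It is not PyMC's actual `_postprocess_samples` code; `transform_draw`, `postprocess_vmap`, `postprocess_scan` and the `log_sigma`/`logit_p` variables are hypothetical, assuming posterior draws stored as a dict of arrays whose leading axis is the draw index.

```python
import jax
import jax.numpy as jnp


def transform_draw(draw):
    """Hypothetical per-draw transform from unconstrained to constrained space."""
    return {
        "sigma": jnp.exp(draw["log_sigma"]),    # log scale -> positive scale
        "p": jax.nn.sigmoid(draw["logit_p"]),   # logit -> probability in (0, 1)
    }


def postprocess_vmap(draws):
    # Batches the transform over the leading (draw) axis in one vectorized
    # computation, so per-draw intermediates for every draw can be live in
    # memory at the same time.
    return jax.vmap(transform_draw)(draws)


def postprocess_scan(draws):
    # Applies the transform one draw at a time inside lax.scan, so only one
    # draw's intermediates are needed per step, at the cost of a sequential
    # loop. The stacked outputs are the same size in both versions.
    def step(carry, draw):
        return carry, transform_draw(draw)

    _, out = jax.lax.scan(step, None, draws)
    return out


if __name__ == "__main__":
    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    draws = {
        "log_sigma": jax.random.normal(key_a, (1000, 3)),
        "logit_p": jax.random.normal(key_b, (1000, 200)),
    }
    a = postprocess_vmap(draws)
    b = postprocess_scan(draws)
    print(a["sigma"].shape, a["p"].shape)
```

Both functions return the same values; they differ only in how the work is scheduled. With a transform this simple the difference is small, but for models with many parameters and expensive transforms, the batched `vmap` version can need far more peak memory than the sequential `scan` version, which in turn may run slower on GPU.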
Describe the issue:
The JAX-based samplers crash after sampling, following the "Transforming variables..." message, on medium-to-large models (thousands of rows, hundreds of parameters). This occurs on both GPU and CPU systems, with either the numpyro or the blackjax sampler. On GPU, the failure returns a backtrace that isolates the issue at the `vmap` in `_postprocess_samples`. On a CPU (MacBook Pro M1), the process is simply killed without any error message. I have tried running the GPU model with the `postprocessing_backend="cpu"` argument for the numpyro sampler, but this does not seem to make a difference. Should it be using `vmap` when the postprocessing backend is CPU?

Reproducible code example:
Will add an example when I can come up with one.
Error message:
PyMC version information:
PyMC 5.3.0
PyTensor 2.11.1
Context for the issue:
The numpyro sampler is currently unusable for moderate-sized models due to this issue.