BUG: Regression in jax translation from 5.12 -> 5.13 #7263
Does using this helper first fix the problem?

pymc/pymc/model/transform/optimization.py (line 23 in 4bc8439)

How would one use it?
```python
with pm.Model() as model:
    ...

frozen_model = freeze_dims_and_data(model)
with frozen_model:
    pm.sample(nuts_sampler="numpyro")
```
Yes, that works, both for the toy example and for the real model.
No
Well, there seems to be one. It is now throwing errors if I add initvals to the models. Any workarounds for that?
Can you provide a minimum working example?
throws
Right, we don't support custom initial values in the model transformations. You should be able to specify them when calling the sampler.
Or specify them after freezing the model.
OK. Maybe it makes sense to deprecate the initval parameter on the RVs then?
I think there was resistance to that, but it would be my preference.
Um. Resistance to removing something that no longer works? Or do I misunderstand something?
It works, just not for model transformations like freeze_dims_and_data.
That's why it's a …
And the JAX sampler only works with frozen dims if we want more complex PyTensor manipulations?
Most models should sample fine in JAX without frozen dims, but we expect some hiccups like the one you found, hence why that helper was added. That helper does not work with custom initvals, but you can pass custom initvals directly to the sampler anyway. It's not a final solution, but everyone should be able to do their thing right now.
Your model wouldn't have worked before the changes with mutable dims anyway. You can change this line:

```python
pt.zeros((len(ns), model.dim_lengths['mw']))
```

to:

```python
pt.zeros((len(ns), len(model.coords['mw'])))
```

That will freeze the second dimension instead of linking it to the mutable dim length.
Ok. I guess I have my answers, and you are right: I have all the tools needed to make it work. Thank you for your thorough answers, @ricardoV94.
The out-of-sample model pattern breaks in a JAX workflow due to the issues reported here. For example:

```python
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model(coords={'a': [1]}) as m:
    mu = pm.Normal('mu', dims=['a'])
    x = pm.Normal('x', mu=mu, sigma=1, dims=['a'])
    idata = pm.sample(nuts_sampler='numpyro')

with pm.Model(coords={'a': [1]}) as new_m:
    mu = pm.Flat('mu', dims=['a'])
    x = pm.Normal('x', mu, sigma=2, dims=['a'])

frozen_new_m = freeze_dims_and_data(new_m)
with frozen_new_m:
    idata_pred = pm.sample_posterior_predictive(
        idata,
        var_names=['x'],
        predictions=True,
        compile_kwargs={'mode': 'JAX'},
    )
```

I guess there are some automatic initvals being silently set when we use …
Describe the issue:
I have a few models where I have to do some rather complex tensor manipulation, and moving from 5.12 to 5.13 quite a few of them broke down with JAX errors.
As the models themselves are big and unwieldy, I have tried to re-create the same issue with a toy example. As you can see, it needs to be quite convoluted to elicit the error (requiring a model dimension, a call to pt.concatenate, and pt.set_subtensor), but I do run into it with more complex actual use cases as well.
I have managed to work around it in some cases by avoiding pt.concatenate and instead creating an empty tensor and setting its parts via set_subtensor, but I have one model where even that runs into issues. So it would be very nice if it worked like it used to before :)
The facts of the case:
Reproducible code example:
Error message:
PyMC version information:
Fails on 5.13.1
Context for the issue:
No response