ENH: Support forward sampling via JAX #7348
#7268 may be a bit of a blocker. Right now we can't do model transformations inside another model context (because I cannot create a distinct model inside another model). I would love to get rid of this behavior, but that means breaking the nested model API. The nested model API is overkill for what is basically automatic variable-name prefixing... Another point to consider is #7177. For caching to be usable, users have to have control over when and how to freeze their models, and we shouldn't interfere too much by transforming models under the hood on behalf of users. We could perhaps establish a compromise where we transform if the backend is JAX, and not otherwise. Or perhaps we re-introduce […]
I had in mind that it should be possible to make the necessary replacements on the forward sampling graph directly, like model -> forward_graph -> frozen_forward_graph -> sampling_function, without making the round trip model -> fgraph -> frozen_model -> frozen_forward_graph -> sampling_function. Is this not possible? I was thinking we just need to extract all the necessary shape information from the model, then apply rewrites that are essentially just […]

Re: […] For me, #7177 is more about allowing the use of […]
I'm afraid of adding special logic inside the forward samplers that is JAX-specific. I preferred the […]

Re: static argnums, we should just try it. Basically we need to replace a vector input by a tuple of scalars, all of which are static_argnums in the compiled function (since numpy arrays aren't hashable and thus not accepted as static_argnums). I think that may still be a clean solution that doesn't require us to do anything here.
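The static_argnums constraint mentioned above can be shown with a small standalone JAX sketch (illustration only, not PyMC code): values marked static must be hashable, so a tuple of scalar ints works where a numpy array does not.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


# Shape information passed as a tuple of ints is hashable, so it can be
# a static argument; a numpy array in the same position is rejected,
# because static_argnums values must be hashable.
@partial(jax.jit, static_argnums=(0,))
def make_zeros(shape):
    return jnp.zeros(shape)


out = make_zeros((2, 3))  # OK: tuple of scalars is hashable

try:
    make_zeros(np.array([2, 3]))  # rejected: ndarray is not hashable
    static_array_ok = True
except (TypeError, ValueError):
    static_array_ok = False
```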
I guess I had in mind that this hypothetical graph operation could just replace […]

But I agree that the whole nesting thing is a bit of a PITA, and it might just be better to attack that somehow.
#7352 makes it possible to apply […]

The question of why to go back from fgraph to model is something we can start tackling in #7268. Potentially […]
I love the idea @jessegrabowski! That would definitely make that part of the API clearer to users, and I think it's welcome. I wanna make sure I understand what […]
It would fail with a shape error, since the model was first frozen, then […]
Actually now that #7352 is merged, you can even do this:
Which is essentially what I envision happening automatically inside […]
Should be fixed by pymc-devs/pytensor#1029
Before
After
Context for the issue:
For models involving scan or multivariate-normal distributions, you get big speedups by passing `compile_kwargs={'mode': 'JAX'}` to `pm.sample_prior_predictive` and `pm.sample_posterior_predictive`. This has already proven useful in statespace modeling (in pymc-experimental, pymc-devs/pymc-extras#346) and instrumental variable modeling (in causalpy, pymc-labs/CausalPy#345). In each of these cases the JAX backend offers significant speedups, and it is a highly desirable feature.

This was technically never a supported feature, but it could be made to work by consciously specifying the whole model to be static (e.g. using `pm.ConstantData` and avoiding `mutable_kwargs`). After #7047 this is obviously no longer possible. The work-around is to use `freeze_dims_and_data`, but this is somewhat cumbersome, especially with prior predictive sampling, where a typical workflow calls `pm.sample_prior_predictive` in the model block at construction time. I have also come up with cases where `freeze_dims_and_data` fails. A trivial example is predictive modeling using `pm.Flat` dummies -- this adds non-`None` entries to `model.rvs_to_initial_values`, causing `model_to_fgraph` to fail.

My proposal would be to simply add a "freezing" step to `compile_forward_sampling_function`. This would alleviate the need for users to be aware of the `freeze_dims_and_data` helper function, allow JAX forward sampling without breaking out of a single model context, and also support any future backend that requires all shape information to be known.

I would also propose to officially support and expose alternative forward sampling backends by promoting `backend=` or `mode=` to a kwarg in `pm.sample_*_predictive`, rather than hiding it inside `compile_kwargs`.