
ENH: Support forward sampling via JAX #7348

Open
jessegrabowski opened this issue Jun 4, 2024 · 8 comments

@jessegrabowski
Member

Before

import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model() as m:
    ...

with freeze_dims_and_data(m):
    idata_prior = pm.sample_prior_predictive(compile_kwargs={"mode":"JAX"})

with m:
    idata = pm.sample()

with freeze_dims_and_data(m):
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True, compile_kwargs = {"mode":"JAX"})

After

import pymc as pm

with pm.Model() as m:
    ...
    idata_prior = pm.sample_prior_predictive(mode="JAX")
    idata = pm.sample()
    idata = pm.sample_posterior_predictive(idata, mode="JAX", extend_inferencedata=True)

Context for the issue:

For models involving scan or multivariate normal distributions, passing compile_kwargs={'mode':'JAX'} to pm.sample_prior_predictive and pm.sample_posterior_predictive gives big speedups. This has already proven useful in statespace modeling (in pymc-experimental, pymc-devs/pymc-extras#346) and instrumental variable modeling (in CausalPy, pymc-labs/CausalPy#345). In both cases the JAX backend offers significant speedups and is a highly desirable feature.

This was technically never a supported feature, but it could be made to work by explicitly making the whole model static (e.g. using pm.ConstantData and avoiding mutable kwargs). After #7047 this is no longer possible. The work-around is to use freeze_dims_and_data, but this is somewhat cumbersome, especially for prior predictive sampling, where a typical workflow calls pm.sample_prior_predictive inside the model block at construction time. I have also come across cases where freeze_dims_and_data fails. A trivial example is predictive modeling using pm.Flat dummies -- this adds non-None entries to model.rvs_to_initial_values, causing model_to_fgraph to fail.

My proposal would be to simply add a "freezing" step to compile_forward_sampling_function. This would alleviate the need for users to be aware of the freeze_dims_and_data helper function, allow JAX forward sampling without breaking out of a single model context, and also support any future backend that requires all shape information to be known.
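
A minimal sketch of what that freezing step could look like (the helper below is hypothetical, not existing PyMC code). It simply reuses the existing freeze_dims_and_data transform whenever the requested backend is JAX, so callers of compile_forward_sampling_function would never have to apply it themselves:

import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

def maybe_freeze_for_backend(model, mode=None):
    # Hypothetical helper: pin dims/data to static shapes only when the
    # backend requires all shape information at compile time (currently JAX).
    if mode == "JAX":
        return freeze_dims_and_data(model)
    return model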

I would also propose to officially support and expose alternative forward sampling backends by promoting backend= or mode= to a kwarg in pm.sample_*_predictive, rather than hiding it inside compile_kwargs.

@ricardoV94
Member

ricardoV94 commented Jun 4, 2024

#7268 may be a bit of a blocker.

Right now we can't do model transformations inside another model context (because a distinct model cannot be created inside another model). I would love to get rid of this behavior, but that means breaking the nested model API. The nested model API is overkill for what is basically automatic variable-name prefixing...

Another point to consider is #7177. For caching to be usable, users have to have control over when and how to freeze their models, and we shouldn't interfere too much by transforming models under the hood on their behalf.

We could perhaps establish a compromise where we transform if the backend is JAX, and not otherwise.

Or perhaps we re-introduce pm.ConstantData. I would however not rename pm.Data to MutableData, and perhaps not reintroduce any mutable kwargs, so most users don't have to see or think about it. Those who want to sample with JAX can go directly to ConstantData. I agree that it's a bit cumbersome to force users who want that to first define the model and then call freeze_rv_and_dims. How would they find out about ConstantData, though?

The initval thing... Flat/HalfFlat variables should no longer need custom initvals? If we are still setting those we should remove them. finite_logp_point should work just fine for them. We can support models with custom initvals; I just dislike them at the model level and didn't want to bother representing them in the fgraph format. I still think they should just be passed to pm.sample when a custom initval is needed.

@jessegrabowski
Member Author

jessegrabowski commented Jun 4, 2024

I had in mind that it should be possible to do the necessary replacements on the forward sampling graph directly, like model -> forward_graph -> frozen_forward_graph -> sampling_function, without making the round trip model -> fgraph -> frozen_model -> frozen_forward_graph -> sampling_function. Is this not possible? I was thinking we just need to extract all the necessary shape information from the model and then apply rewrites that are essentially just pt.specify_shape.
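
A rough sketch of that idea, assuming we operate directly on the forward-sampling output variables (the helper below is hypothetical, not how PyMC currently does it): walk the graph and, wherever a shared data variable appears, pin its current shape with pt.specify_shape so downstream shape inference only sees static shapes.

import pytensor.tensor as pt
from pytensor import clone_replace
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import ancestors

def freeze_shapes_on_graph(outputs):
    # outputs: list of forward-sampling output variables.
    # Replace each shared (data) input with a specify_shape-wrapped version
    # carrying the shape of its current value.
    replacements = {}
    for var in ancestors(outputs):
        if isinstance(var, SharedVariable) and getattr(var.type, "ndim", 0) > 0:
            static_shape = var.get_value(borrow=True).shape
            replacements[var] = pt.specify_shape(var, static_shape)
    return clone_replace(outputs, replace=replacements)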

Re: ConstantData, what we're doing now is actually better; I don't think it's correct to go backwards. It basically forces users to choose between pm.set_data and JAX forward sampling, which doesn't seem like something they should have to do. There should definitely be a way to have both.

For me, #7177 is more about allowing the use of static_argnums in the JAX functions we generate. Maybe that would solve a lot of our problems? In 99% of uses, the JAX function is completely re-generated each time it is needed anyway, so we essentially lose nothing by specifying inputs with unknown shape as static_argnums. I think we've talked about this before, but your answer didn't stick with me.

@ricardoV94
Member

> I had in mind that it should be possible to do the necessary replacements on the forward sampling graph directly, like model -> forward_graph -> frozen_forward_graph -> sampling_function, without making the round trip model -> fgraph -> frozen_model -> frozen_forward_graph -> sampling_function. Is this not possible? I was thinking we just need to extract all the necessary shape information from the model and then apply rewrites that are essentially just pt.specify_shape.

I'm afraid of adding special logic inside the forward samplers that is JAX-specific. I preferred the freeze_rv_and_dims route because it reuses generic code that has other applications.

RE: static_argnums, we should just try it. Basically we need to replace a vector input by a tuple of scalars, all of which are static_argnums in the compiled function (since numpy arrays aren't hashable and so can't be used as static_argnums). I think that may still be a clean solution that doesn't require us to do anything here.
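
A small, self-contained illustration of the static_argnums idea (plain JAX, not PyMC code): the hashable tuple of scalars is marked static so JAX specializes the trace on it, while the array data stays a regular dynamic argument.

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

@partial(jax.jit, static_argnums=(1,))
def forward(flat_data, shape):
    # `shape` is a concrete Python tuple at trace time, so it can drive the reshape
    data = flat_data.reshape(shape)
    return data.sum(axis=0)

x = np.random.normal(size=(5, 3))
print(forward(jnp.asarray(x.ravel()), x.shape))  # the shape (5, 3) is passed as a static tuple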

@jessegrabowski
Member Author

I guess I had in mind that this hypothetical graph operation could just replace freeze_rv_and_dims. It's sort of a hack to go back to a model representation when we could just operate on the graph (logp or forward sampling) directly. From a rewrite perspective, there shouldn't be any difference between these, no?

But I agree that the whole nesting thing is a bit of a PITA, and it might just be better to attack that somehow.

@ricardoV94
Member

ricardoV94 commented Jun 10, 2024

#7352 makes it possible to apply freeze_rv_and_dims inside a model without worries.

The question of why we go back from fgraph to model is something we can start tackling in #7268. Potentially pm.Model could become a thinner shell that just builds the corresponding fgraph under the hood, but that may take some time to do properly (i.e., not break everything accidentally).

@AlexAndorra
Contributor

I love the idea @jessegrabowski! That would definitely make that part of the API clearer to users, and I think it's welcome.

I wanna make sure I understand what freeze_dims_and_data does though: will that be a problem when users call pm.set_data with out-of-sample data before calling pm.sample_posterior_predictive?
And is that also needed with Numba mode?

@jessegrabowski
Member Author

jessegrabowski commented Jun 21, 2024

freeze_dims_and_data is only required for JAX, because JAX doesn't allow dynamic array shapes inside JIT functions. This isn't the case for numba jit functions (you can @njit a function and then pass whatever shapes you like), so it wouldn't be necessary there. Not sure about PyTorch.
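
A small illustration of that difference (plain numba/JAX, not PyMC code): a numba-jitted function accepts arrays of any length after a single compilation per dtype/ndim, whereas jax.jit specializes on the concrete shape and retraces whenever the shape changes.

import numpy as np
import numba
import jax
import jax.numpy as jnp

@numba.njit
def nb_total(x):
    return x.sum()

nb_total(np.ones(3))     # compiled once for 1-D float64
nb_total(np.ones(100))   # different length, no recompilation needed

@jax.jit
def jx_total(x):
    return jnp.sum(x)

jx_total(jnp.ones(3))    # traced and compiled for shape (3,)
jx_total(jnp.ones(100))  # new shape -> a new trace and compilation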

freeze_dims_and_data(model) creates a copy of model, with all shape information inserted from available data/coords. So this code would not work with pm.set_data:

with freeze_dims_and_data(model):
    pm.set_data({'X':df_test.values}, coords={'obs_idx':df_test.index})
    idata_pred = pm.sample_posterior_predictive(idata)

It would fail with a shape error since the model was first frozen, then pm.set_data was called. But it would work if you do things in the other order: first update the data that will be used by freeze to determine all the static shapes, then call freeze:

with model:
    pm.set_data({'X':df_test.values}, coords={'obs_idx':df_test.index})
with freeze_dims_and_data(model):
    idata_pred = pm.sample_posterior_predictive(idata)

Actually now that #7352 is merged, you can even do this:

with model:
    pm.set_data({'X':df_test.values}, coords={'obs_idx':df_test.index})
    with freeze_dims_and_data(model):
        idata_pred = pm.sample_posterior_predictive(idata)

Which is essentially what I envision happening automatically inside pm.sample_*_predictive if you pass mode="JAX".

@ricardoV94
Member

Should be fixed by pymc-devs/pytensor#1029
