
JAX/NUMPYRO does not include constant_data #5781

Closed
jbh1128d1 opened this issue May 18, 2022 · 8 comments


@jbh1128d1

When I sample larger datasets using JAX (the only sampler that runs on them without errors), the constant_data group is not included in the InferenceData output. See the example below.

Example:

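```python
import pymc as pm
import pymc.sampling_jax

coords = {
    "obs_id": [0, 1, 2, 3, 4],
}
with pm.Model(coords=coords) as rugby_model:
    # Immutable data container: expected to appear in the constant_data group
    item_idx = pm.Data("item_idx", [0, 1, 2, 3, 4], dims="obs_id", mutable=False)
    a = pm.Normal("a", 0.0, sigma=10.0, shape=5)

    theta = a[item_idx]
    sigma = pm.HalfCauchy("error", 0.5)

    y = pm.Normal("y", theta, sigma=sigma, observed=[3, 2, 6, 8, 4])

    # Default PyMC sampler: the returned InferenceData contains constant_data
    idata = pm.sample()
    # NumPyro (JAX) sampler: the returned InferenceData lacks constant_data
    idata_jax = pymc.sampling_jax.sample_numpyro_nuts(tune=1000, chains=4, target_accept=0.9)
```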



@ricardoV94
Member

ricardoV94 commented May 18, 2022

Jeez, how many InferenceData groups are there? xD I confirmed that this is still missing in main. Thanks for reporting!

@OriolAbril
Member

@ricardoV94 worst case scenario below 😜


@twiecki
Member

twiecki commented May 18, 2022

@ricardoV94 Can you add some more info on what needs to be done here?

@ricardoV94
Member

@OriolAbril might have a better idea; I am not very familiar with the InferenceData creation in PyMC.

@OriolAbril
Member

The pm.Data variables are included in the generated InferenceData object in the constant_data group, similarly to how the values passed as observed to any variable in the model are added to the observed_data group. I don't know how the JAX sampler wrappers work (at any step, not just InferenceData creation), so I don't know what specific steps are needed to fix this. It is also strange to have observed_data but not constant_data 🤔.

@danhphan
Member

Hi, I think this issue happens because no constant_data is included in az_trace in pymc.sampling_jax.sample_numpyro_nuts(). The same thing happens in pymc.sampling_jax.sample_blackjax_nuts().

#5189 did a great job of adding log_likelihood, observed_data, and sample_stats to the numpyro sampler.

So we can add constant_data to the sample_numpyro_nuts function by creating a function find_constant_data(model), similar to find_observations(model) for observed_data.
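A minimal sketch of the filtering idea behind such a helper, in plain Python. It assumes a hypothetical ToyModel class standing in for a PyMC model (it is NOT the PyMC API), and it assumes the rule "a named variable is constant data if it is not a free RV, an observed RV, or a deterministic". The actual fix in PyMC may use a different rule.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Set


@dataclass
class ToyModel:
    """Hypothetical stand-in for a PyMC model (NOT the PyMC API).

    named_vars maps every named variable to its value; the other fields
    record which names are free RVs, observed RVs, or deterministics.
    """
    named_vars: Dict[str, Any] = field(default_factory=dict)
    free_rvs: Set[str] = field(default_factory=set)
    observed: Dict[str, Any] = field(default_factory=dict)  # name -> observed values
    deterministics: Set[str] = field(default_factory=set)


def find_observations(model: ToyModel) -> Dict[str, Any]:
    """Collect the values passed as `observed=` (would feed observed_data)."""
    return dict(model.observed)


def find_constant_data(model: ToyModel) -> Dict[str, Any]:
    """Hypothetical helper: collect data containers (would feed constant_data).

    Assumed rule: keep every named variable that is not a free RV,
    an observed RV, or a deterministic.
    """
    excluded = model.free_rvs | set(model.observed) | model.deterministics
    return {
        name: value
        for name, value in model.named_vars.items()
        if name not in excluded
    }


# Toy version of the model from the issue; RVs carry no fixed value here (None).
model = ToyModel(
    named_vars={"item_idx": [0, 1, 2, 3, 4], "a": None, "error": None, "y": None},
    free_rvs={"a", "error"},
    observed={"y": [3, 2, 6, 8, 4]},
)

groups = {
    "observed_data": find_observations(model),
    "constant_data": find_constant_data(model),
}
print(groups)
# {'observed_data': {'y': [3, 2, 6, 8, 4]}, 'constant_data': {'item_idx': [0, 1, 2, 3, 4]}}
```

In a real fix, the returned dictionary would be passed into the InferenceData conversion alongside the observations, so both groups are built the same way.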

If no one takes this in a couple of days, I am happy to do it :)

@ricardoV94
Member

@danhphan is this one resolved?

@danhphan
Member

Hi @ricardoV94, yes, I think it was fixed in PR #5807.


5 participants