
Add constant_data into sample_numpyro_nuts() #5807

Merged
merged 3 commits into from
Jun 20, 2022

Conversation

danhphan
Member

This PR adds constant_data into pymc.sampling_jax.sample_numpyro_nuts(). It addresses issue #5781.

cc @ricardoV94 :)

Let me know if it needs to be changed. Thank you.

@codecov

codecov bot commented May 26, 2022

Codecov Report

Merging #5807 (6f50e37) into main (7cc24bc) will decrease coverage by 0.00%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##             main    #5807      +/-   ##
==========================================
- Coverage   89.50%   89.50%   -0.01%     
==========================================
  Files          73       73              
  Lines       13276    13274       -2     
==========================================
- Hits        11883    11881       -2     
  Misses       1393     1393              
Impacted Files                      Coverage Δ
pymc/backends/arviz.py              90.61% <100.00%> (+0.32%) ⬆️
pymc/sampling_jax.py                96.95% <100.00%> (ø)
pymc/step_methods/hmc/base_hmc.py   89.68% <0.00%> (-0.80%) ⬇️

@danhphan
Member Author

It seems that blackjax failed when running on GPU.

On my local machine, when running the tests on CPU as follows, all the tests in test_sampling_jax.py passed:
JAX_PLATFORM_NAME=cpu pytest pymc/tests/test_sampling_jax.py

However, when running with the gpu flag, numpyro worked well while blackjax failed:

JAX_PLATFORM_NAME=gpu pytest pymc/tests/test_sampling_jax.py
Outputs:

========================================================================= short test summary info =========================================================================
FAILED pymc/tests/test_sampling_jax.py::test_transform_samples[None-sample_blackjax_nuts] - ValueError: compiling computation that requires 2 logical devices, but only ...
FAILED pymc/tests/test_sampling_jax.py::test_transform_samples[cpu-sample_blackjax_nuts] - ValueError: compiling computation that requires 2 logical devices, but only 1...
FAILED pymc/tests/test_sampling_jax.py::test_deterministic_samples[sample_blackjax_nuts] - ValueError: compiling computation that requires 2 logical devices, but only 1...
FAILED pymc/tests/test_sampling_jax.py::test_seeding[2-None-sample_blackjax_nuts] - ValueError: compiling computation that requires 2 logical devices, but only 1 XLA de...
FAILED pymc/tests/test_sampling_jax.py::test_seeding[2-123-sample_blackjax_nuts] - ValueError: compiling computation that requires 2 logical devices, but only 1 XLA dev...
========================================================== 5 failed, 22 passed, 12 warnings in 74.31s (0:01:14) ===========================================================

It seems to be a downstream error from blackjax?
ValueError: compiling computation that requires 2 logical devices, but only 1 XLA devices are available (num_replicas=2, num_partitions=1)
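The error suggests blackjax parallelizes across chains with pmap, which needs one XLA device per chain, while a single GPU exposes only one. A possible workaround (an assumption on my part, not verified against this setup) is to fall back to CPU and ask XLA to emulate multiple host devices:

```shell
# Expose 2 virtual CPU devices so a pmap over 2 chains can compile;
# the flag is XLA's --xla_force_host_platform_device_count.
XLA_FLAGS="--xla_force_host_platform_device_count=2" \
JAX_PLATFORM_NAME=cpu pytest pymc/tests/test_sampling_jax.py
```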

@danhphan
Member Author

Also, any ideas why the pre-commit failed at mypy? I'm still not able to figure it out :) Thanks

@ricardoV94
Member

If you rebase from main, the failing blackjax tests should go away

@danhphan
Member Author

Hi, yes, agreed. I will refactor the code.

@danhphan danhphan force-pushed the sampling_jax_constant_data branch from a1fab39 to 034ca36 on May 29, 2022 04:30
@danhphan
Member Author

Hopefully it is alright now :)

Member

@ricardoV94 ricardoV94 left a comment


Looks good. Left some suggestions to simplify and tidy up the original functions.

@danhphan danhphan force-pushed the sampling_jax_constant_data branch from 034ca36 to 7821d13 on May 31, 2022 03:33
@danhphan
Member Author

Hi, it has been updated :)

@danhphan
Member Author

danhphan commented Jun 4, 2022

Hi @ricardoV94, I think this PR is alright to be merged :) Thanks

Member

@ricardoV94 ricardoV94 left a comment


Looks good, but we need to test that constant_data is showing up in test_sampling_jax.py, probably in the same test that already checks for other InferenceData groups

"""If there are observations available, return them as a dictionary."""
if model is None:
return None
Member


We don't need the model is None case. Whoever calls this function should always pass a model. There's no good reason why they would call the function with None.

def find_constants(model: Optional["Model"]) -> Dict[str, Var]:
    """If there are constants available, return them as a dictionary."""
    # The constant data vars must be either pm.Data, TensorConstant or SharedVariable
    if model is None:
Member


Same here

Member Author


Hi, I think it is better than the previous code now. Why did we not remove it in the previous review here? It could be good to have some defensive code here, since we are not 100% sure people will always call it with a model. Anyway, this is an easy fix and just a minor thing.

Member

@ricardoV94 ricardoV94 Jun 4, 2022


Why did we not remove it in the previous review here?

We just missed it then.

It could be good to have some defensive code here, since we are not 100% sure people will always call it with a model.

I don't see why you would call this function without a model. It just doesn't make sense. And if someone does, that's probably a mistake and the best thing is to let it crash naturally.

Adding special cases like this can make it more difficult to spot errors instead of helping.
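The fail-fast approach described above can be sketched as follows. This is an illustrative simplification, not the actual pymc implementation: the observed_RVs attribute and the .data field on each variable are stand-ins for pymc's real model internals.

```python
def find_observations(model) -> dict:
    """Return observed data as a name -> data dictionary.

    There is deliberately no `model is None` guard: calling this
    without a model is a caller bug, and the immediate AttributeError
    on `None.observed_RVs` points straight at it.
    """
    return {rv.name: rv.data for rv in model.observed_RVs}
```

Passing None now raises AttributeError at the call site instead of silently returning None and surfacing as a confusing failure later.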

@danhphan
Member Author

danhphan commented Jun 4, 2022

Looks good, but we need to test that constant_data is showing up in test_sampling_jax.py, probably in the same test that already checks for other InferenceData groups

Hi, I cannot find where this code is though. Can you point it out?

I have already run the following code and checked to make sure that constant_data is shown in idata_jax.

import numpy as np
import pymc as pm
import aesara as ae
import aesara.tensor as at
from pymc.sampling_jax import sample_numpyro_nuts

coords = {
    "obs_id": [0, 1, 2, 3, 4],
}
with pm.Model(coords=coords) as rugby_model:
    item_idx = pm.Data("item_idx", [0, 1, 2, 3, 4], dims="obs_id", mutable=False)
    b = ae.shared(0.1)
    obs = np.random.normal(10, 2, size=100)
    c = ae.shared(obs, borrow=True, name="obs")
    a = pm.Normal("a", 0.0, sigma=10.0, shape=5)

    theta = a[item_idx]
    sigma = pm.HalfCauchy("error", 0.5)

    y = pm.Normal("y", theta, sigma=sigma, observed=[3, 2, 6, 8, 4])

    idata = pm.sample()
    idata_jax = sample_numpyro_nuts(tune=1000, chains=4, target_accept=0.9)
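Conceptually, collecting the constant_data group amounts to filtering model inputs by type, mirroring the comment in find_constants that constant data vars must be pm.Data, TensorConstant, or SharedVariable. The sketch below is illustrative only: the stand-in classes and the model.inputs attribute are hypothetical, not the real pymc/aesara API.

```python
from types import SimpleNamespace

# Stand-ins for aesara's TensorConstant / SharedVariable; the real
# function inspects actual aesara graph objects.
class TensorConstant(SimpleNamespace):
    pass

class SharedVariable(SimpleNamespace):
    pass

def find_constants(model):
    """Collect constant inputs into a name -> value dictionary,
    skipping every variable that is not a recognized constant type."""
    constant_data = {}
    for var in model.inputs:  # hypothetical attribute for this sketch
        if isinstance(var, (TensorConstant, SharedVariable)):
            constant_data[var.name] = var.value
    return constant_data
```

In the example model above, item_idx (immutable pm.Data), b, and c (shared variables) would all be picked up by a filter like this, while free random variables such as a would not.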

@michaelosthege
Member

Hi, I cannot find where this code is though. Can you point it out?

I don't see an existing test that checks pm.Data containers with jax sampling.
Best create a new one right below this one:

def test_idata_kwargs(sampler, idata_kwargs, postprocessing_backend):

But please make sure to set something like chains=1, tune=10, draws=20 so the test runs fast.

@michaelosthege michaelosthege added this to the v4.0.1 milestone Jun 11, 2022
@danhphan
Member Author

Yes, I will add tests for both observations and constants, as well as remove the model is None check, soon.

Just need some more time to do that :) Thanks

@michaelosthege
Member

@danhphan let us know if someone should take over the rest here. This is one of currently two open items in the 4.0.1 milestone.

@michaelosthege michaelosthege force-pushed the sampling_jax_constant_data branch from 7821d13 to 6f50e37 on June 20, 2022 10:58
@danhphan
Member Author

Hi @michaelosthege, sorry, I'm quite busy with other more urgent work right now, so please feel free to go ahead with the remaining steps. I will not be able to do it early this week. If no one updates this, I will try to spend some time this weekend to complete it. Thank you.

@ricardoV94 ricardoV94 merged commit 403f2d5 into pymc-devs:main Jun 20, 2022
@ricardoV94
Member

Thanks @danhphan and @michaelosthege!

@danhphan danhphan deleted the sampling_jax_constant_data branch August 31, 2022 10:54