Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing nested sampling #1871

Merged
merged 7 commits into from
Oct 16, 2024
Merged

Conversation

renecotyfanboy
Copy link
Contributor

There are some import discrepancies between jaxns<2.6 and jaxns 2.6.*. This small fix seems to do the trick!

@fehiepsi
Copy link
Member

could you update jaxns version in setup and docs/requirements

"jaxns==2.4.8",

docs/requirements.txt Outdated Show resolved Hide resolved
@renecotyfanboy
Copy link
Contributor Author

@fehiepsi any idea on how to enforce double precision in the test base for this specific test ? I don't know how JAX and pytest work together and if it is possible to enable x64 for a single test

@fehiepsi
Copy link
Member

fehiepsi commented Oct 1, 2024

You can add a flag like this one

JAX_ENABLE_X64=1 pytest -vs test/infer/test_mcmc.py -k x64

@renecotyfanboy renecotyfanboy force-pushed the fixing-nested-sampling branch 2 times, most recently from f5785e1 to 6c21533 Compare October 1, 2024 10:02
@renecotyfanboy renecotyfanboy force-pushed the fixing-nested-sampling branch from 6c21533 to 0c42f8e Compare October 1, 2024 12:44
@renecotyfanboy
Copy link
Contributor Author

Okay I hope it works now, I had a lot of trouble since latest versions of jaxns enforces double precision to summarize, I had to :

  • Makes a special case for pytest if double precision is not enabled to skip NS related tests
  • Add an extra invoke for dedicated jaxns tests
  • Change the expected result of a doctest in affine transform in the documentation, as the output is jnp.float64 instead of float32. All the other tests ran smoothly on my computer
  • Updated a bit the jaxns contributed code. To be fair, at some point I should take some time to update it properly with the new interfaces proposed by the package

examples/gaussian_shells.py Outdated Show resolved Hide resolved
test/contrib/test_nested_sampling.py Show resolved Hide resolved
@renecotyfanboy
Copy link
Contributor Author

Okay @fehiepsi this should work now. For you to know, most of the struggles came from the fact that importing jaxns automatically enable double precision. So top-level imports in pytests triggered the double precision even if not required. I don't import it if the double precision is disabled, and the related test suite is automatically skipped in this situation.

test/contrib/test_nested_sampling.py Show resolved Hide resolved
docs/requirements.txt Outdated Show resolved Hide resolved
.github/workflows/ci.yml Show resolved Hide resolved
examples/gaussian_shells.py Outdated Show resolved Hide resolved
numpyro/distributions/transforms.py Outdated Show resolved Hide resolved
@renecotyfanboy
Copy link
Contributor Author

I think it's ready for review

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woohoo, thanks @renecotyfanboy!

@fehiepsi fehiepsi merged commit d18bcfb into pyro-ppl:master Oct 16, 2024
4 checks passed
@renecotyfanboy renecotyfanboy deleted the fixing-nested-sampling branch October 16, 2024 18:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants