Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sample stats for blackjax nuts #6264

Merged
merged 5 commits into from
Nov 14, 2022

Conversation

TimOliverMaier
Copy link
Contributor

@TimOliverMaier TimOliverMaier commented Nov 3, 2022

What is this PR about?
With this PR I want to address #5718. As of now no stats are saved in InferenceData if one uses sample_blackjax_nuts.
However I am unsure about the meaning of the attributes energy and num_trajectory_expansions.
Here is what is documented in the blackjax code. I first understood energy to be the absolute energy of a state and num_trajectory_expansions to be the tree_depth (That's the current state of the code). But second reading made me skeptical. Any help here is very welcome 😃

How it works:

  • blackjax returns sample statistics as instance of blackjax.mcmc.nuts.NUTSInfo. The attributes of this instance
    are converted to a dictionary Dict[str,ArrayLike] mapping the stats with pymc naming conventions
  • potential energy is stored alongside states in a blackjax.mcmc.hmc.HMCState object. The potential energy is extracted and
    stored in the above mentioned dictionary.
  • the dictionary is passed as sample_stats argument to arviz.from_dict at the end of pymc.sampling.jax.sample_blackjax_nuts

Checklist

Bugfixes / New features

  • pm.sample.jax.sample_blackjax_nuts returns InferenceData with sample_stats group

@codecov
Copy link

codecov bot commented Nov 3, 2022

Codecov Report

Merging #6264 (5f02ad9) into main (b5db350) will increase coverage by 6.16%.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6264      +/-   ##
==========================================
+ Coverage   88.15%   94.31%   +6.16%     
==========================================
  Files         111      111              
  Lines       23820    23855      +35     
==========================================
+ Hits        20998    22499    +1501     
+ Misses       2822     1356    -1466     
Impacted Files Coverage Δ
pymc/sampling/jax.py 97.34% <100.00%> (+0.14%) ⬆️
pymc/tests/sampling/test_jax.py 100.00% <100.00%> (ø)
pymc/model.py 89.51% <0.00%> (+0.27%) ⬆️
pymc/distributions/shape_utils.py 97.85% <0.00%> (+0.42%) ⬆️
pymc/backends/arviz.py 90.08% <0.00%> (+2.89%) ⬆️
pymc/distributions/logprob.py 97.26% <0.00%> (+6.16%) ⬆️
pymc/step_methods/hmc/quadpotential.py 80.69% <0.00%> (+6.93%) ⬆️
pymc/model_graph.py 78.82% <0.00%> (+12.35%) ⬆️
pymc/variational/updates.py 92.11% <0.00%> (+54.18%) ⬆️
pymc/distributions/timeseries.py 94.45% <0.00%> (+65.99%) ⬆️
... and 4 more

@TimOliverMaier TimOliverMaier marked this pull request as ready for review November 3, 2022 16:50
@junpenglao
Copy link
Member

Your interpretation of the mapping between the different naming is correct here. LGTM.

@TimOliverMaier
Copy link
Contributor Author

Thank you @junpenglao . The failing test succeded locally. Will look into it asap

@TimOliverMaier
Copy link
Contributor Author

The test failed because I had a newer version of blackjax installed, that stores n_steps (called num_integration_steps there) . The test now no longer asserts the existence of a DataArray n_steps in sample_stats. The code however would store n_steps if a user has a newer version of blackjax installed.

Copy link
Member

@junpenglao junpenglao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good - you will need to rebase the change.

@ricardoV94
Copy link
Member

@TimOliverMaier can you add (and fill) the relevant bullet point from the PR template to your top comment? We use those for the auto-generated release notes: https://github.com/pymc-devs/pymc/blob/main/.github/PULL_REQUEST_TEMPLATE.md

@TimOliverMaier
Copy link
Contributor Author

TimOliverMaier commented Nov 8, 2022

In the last commit I only moved the sample_stats argument into the call of partial. That's where other groups are passed and that's how it is done in the numpyro nuts function.

@michaelosthege michaelosthege changed the title sample stats for blackjax nuts Sample stats for blackjax nuts Nov 14, 2022
Copy link
Member

@michaelosthege michaelosthege left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this is ready to merge, so let's do it!

@michaelosthege michaelosthege merged commit 3f9d2e2 into pymc-devs:main Nov 14, 2022
wrongu pushed a commit to wrongu/pymc that referenced this pull request Dec 1, 2022
* tests for added sample statistics
* sample stats test with more draws
* redesigned test for older blackjax version
* record blackjax sample stats
* moved sample stats argument to partial call
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants