-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sample stats for blackjax nuts #6264
Sample stats for blackjax nuts #6264
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #6264 +/- ##
==========================================
+ Coverage 88.15% 94.31% +6.16%
==========================================
Files 111 111
Lines 23820 23855 +35
==========================================
+ Hits 20998 22499 +1501
+ Misses 2822 1356 -1466
|
Your interpretation of the mapping between the different naming is correct here. LGTM. |
Thank you @junpenglao . The failing test succeded locally. Will look into it asap |
The test failed because I had a newer version of blackjax installed, that stores |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good - you will need to rebase the change.
0af8b82
to
32f3da3
Compare
@TimOliverMaier can you add (and fill) the relevant bullet point from the PR template to your top comment? We use those for the auto-generated release notes: https://github.com/pymc-devs/pymc/blob/main/.github/PULL_REQUEST_TEMPLATE.md |
In the last commit I only moved the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like this is ready to merge, so let's do it!
* tests for added sample statistics * sample stats test with more draws * redesigned test for older blackjax version * record blackjax sample stats * moved sample stats argument to partial call
What is this PR about?
With this PR I want to address #5718. As of now no stats are saved in InferenceData if one uses
sample_blackjax_nuts
.However I am unsure about the meaning of the attributes
energy
andnum_trajectory_expansions
.Here is what is documented in the blackjax code. I first understood
energy
to be the absolute energy of a state andnum_trajectory_expansions
to be thetree_depth
(That's the current state of the code). But second reading made me skeptical. Any help here is very welcome 😃How it works:
blackjax.mcmc.nuts.NUTSInfo
. The attributes of this instanceare converted to a dictionary
Dict[str,ArrayLike]
mapping the stats with pymc naming conventionsblackjax.mcmc.hmc.HMCState
object. The potential energy is extracted andstored in the above mentioned dictionary.
sample_stats
argument toarviz.from_dict
at the end ofpymc.sampling.jax.sample_blackjax_nuts
Checklist
Bugfixes / New features
pm.sample.jax.sample_blackjax_nuts
returnsInferenceData
withsample_stats
group