-
Notifications
You must be signed in to change notification settings - Fork 155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: batched mcmc reshaping #1210
Conversation
@gmoss13 I had to change one thing in your reshape-permute magic to make it work for |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1210 +/- ##
==========================================
- Coverage 84.55% 75.86% -8.70%
==========================================
Files 96 97 +1
Lines 7603 7682 +79
==========================================
- Hits 6429 5828 -601
- Misses 1174 1854 +680
Flags with carried forward coverage won't be shown. Click here to find out more.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot @janfb, I think these changes fix the issues we discussed! I added some suggestions, happy to discuss further (also happy to implement them myself if you agree).
62c24db
to
295ff42
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Jan! Added a couple of very minor comments.
295ff42
to
7e76605
Compare
when the
num_chains
was higher or not a multiple of the number of samples to be generated with MCMCsample_batched
, the samples were reshaped incorrectly: The remaining samples were pushed into the last dimension here:sbi/sbi/inference/posteriors/mcmc_posterior.py
Line 467 in b275448
This is a fix where we first collect all the generated samples from the chains (possibly more than needed), then select as many as we need and then reshape into the desired
(*sample_shape, batch_size, input_shape)
.I also added tests that cover all the cases, which makes them quite slow.