feat: batched sampling and log prob methods. #1153
Conversation
This currently won't work due to #1154.
Codecov Report
Attention: Patch coverage is …

@@ Coverage Diff @@
##             main    #1153      +/-   ##
==========================================
+ Coverage   72.91%   72.95%   +0.04%
==========================================
  Files          93       93
  Lines        7394     7459      +65
==========================================
+ Hits         5391     5442      +51
- Misses       2003     2017      +14
Okay, the SNPE stuff works now. The MCMC stuff still needs to be done and tested.
ok, merging
great effort! 👏
I added a couple of comments for clarification.
Oh wow, thanks for fixing this mess :D. I am not sure what went wrong in the rebase (merging main in is probably preferable when there have been many commits to main; I will keep that in mind).
Co-authored-by: Jan <janfb@users.noreply.github.com>
Looks good to me now! 🎉
pytest.param(SNLE_A, marks=pytest.mark.xfail(raises=NotImplementedError)),
pytest.param(SNRE_A, marks=pytest.mark.xfail(raises=NotImplementedError)),
pytest.param(SNRE_B, marks=pytest.mark.xfail(raises=NotImplementedError)),
pytest.param(SNRE_C, marks=pytest.mark.xfail(raises=NotImplementedError)),
👍
Happy with this test, but not sure if it belongs under posterior_nn_test; should it instead be under mcmc_test?
I think this test is fine here (it just checks that all posteriors produce the same shapes).
In mcmc_test we should properly test that batched MCMC sampling works as expected (the current tests did not catch that the MCMC implementation produced wrong results), i.e. this should be part of #1176. A minimal sketch of the kind of check meant here is below.
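For reference, a rough sketch of such a correctness check (assumptions: a trained posterior and two hypothetical observations x_o_1 and x_o_2 exist; the mean comparison is a weak stand-in for a proper c2st check):

import torch

x_os = torch.stack([x_o_1, x_o_2])  # (2, num_dim), hypothetical observations

# One batched call covers all conditions at once; shape follows the tests below.
samples_batched = posterior.sample_batched((1000,), x_os)  # (1000, 2, num_dim)

# Check correctness, not just shape: each batch slice should match samples
# drawn separately for the corresponding condition.
for i, x_o in enumerate(x_os):
    samples_single = posterior.sample((1000,), x=x_o)
    assert torch.allclose(
        samples_batched[:, i].mean(0), samples_single.mean(0), atol=0.2
    ), "Batched and per-condition samples disagree beyond tolerance."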
there seems to be a failing test during CD: pytest tests/linearGaussian_snpe_test.py::test_c2st_multi_round_snpe_on_linearGaussian[snpe_a]
@janfb I have already responded to the reshape comment (but you can apparently only see it in the code review view, and I cannot reply to the new comment you made). In short, this reshape was not introduced in this PR. It was already there; I just moved it one line up and added a note (currently, MCMC only works with a single event_dim anyway).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work @manuelgloeckler, thanks!
else:
    samples = posterior.sample_batched((10,), ones(x_o_batch_dim, num_dim))

assert samples.shape == (10, x_o_batch_dim, num_dim), "Sample shape wrong"
+1
samples = density_estimator.sample(sample_shape, condition=condition)
samples = samples.reshape(-1, batch_dim, *input_event_shape)  # Flat for comp.

samples_separate1 = density_estimator.sample(
nice! I would suggest, in a future issue/PR, testing this behavior in all batched sample / batched log_prob tests, as some of them currently only test the shape.
)
samples = []
for posterior_index, sample_size in torch.vstack(
    posterior_indizes.unique(return_counts=True)
Typo: indizes should be indices.
Okay, the SNPE_A test fails because the SNPE_A posterior switches its batching behavior between rounds: in the first round it samples from the DensityEstimator (which follows the new convention), and after the first round it samples using a custom implementation (which follows the old convention).
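To illustrate the two conventions (the shapes here are assumptions for illustration, not exact SNPE_A internals):

import torch

sample_shape, batch_dim, event_dim = (10,), 5, 3

# New DensityEstimator convention: samples keep an explicit condition batch dim.
new_style = torch.randn(*sample_shape, batch_dim, event_dim)  # (10, 5, 3)

# Old custom-sampler convention: the batch is folded into the sample dimension,
# so code written against one convention silently breaks on the other.
old_style = new_style.reshape(-1, event_dim)  # (50, 3)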
Main is currently failing tests due to:

max_err = np.max(error)
> assert max_err < 0.0027
E assert 0.0033298 < 0.0027

This seems like a random fluctuation (due to some change). Should I increase the tolerance a bit to fix this? [Edit: I increased the tolerance to 0.005, just for the record]
No, please keep the tolerance as it was. This will be fixed in #1177
This reverts commit 2aac705.
Alright, I will merge after my local GPU/slow tests have finished.
Cool!
Edits by @janfb:
Adds methods to sample and evaluate given many observations, potentially vectorized. It uses the capabilities of ConditionalDensityEstimators to broadcast sample and evaluate for multiple conditions. Not implementing it for MCMC for now; moved to #1176. A rough usage sketch follows.
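A usage sketch of the added API (the Gaussian toy simulator and the exact numbers are illustrative assumptions; the shape convention follows the tests above):

import torch
from sbi.inference import SNPE
from sbi.utils import BoxUniform

num_dim = 3
prior = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))

def simulator(theta):
    # Toy simulator: Gaussian noise around theta (illustrative assumption).
    return theta + 0.1 * torch.randn_like(theta)

theta = prior.sample((1000,))
x = simulator(theta)

inference = SNPE(prior=prior)
inference.append_simulations(theta, x).train()
posterior = inference.build_posterior()

# Batched sampling and evaluation across several observations at once.
x_os = simulator(prior.sample((5,)))                   # (5, num_dim)
samples = posterior.sample_batched((100,), x_os)       # (100, 5, num_dim)
log_probs = posterior.log_prob_batched(samples, x_os)  # (100, 5)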