feat: batched sampling and log prob methods. #1153

Merged
merged 76 commits into from
Jun 18, 2024

Conversation

michaeldeistler

@michaeldeistler michaeldeistler commented May 3, 2024

Edits by @janfb :
Add methods to sample and evaluate log probabilities given many observations, potentially vectorized.

This uses the ability of ConditionalDensityEstimators to broadcast sampling and evaluation across multiple conditions.

Not implemented for MCMC for now; moved to #1176.
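
As a rough illustration of the batched convention discussed in this PR, here is a minimal sketch in plain Python. `ToyConditionalGaussian` and `shape_of` are hypothetical names invented for this example and are not part of sbi. The sample shape (num_samples, batch, dim) follows the shapes asserted in this PR's tests; the (num_samples, batch) log-prob shape is an assumption.

```python
import math
import random


class ToyConditionalGaussian:
    """Hypothetical stand-in for a conditional density estimator.

    Models p(theta | x) as an independent unit-variance Gaussian centered on x.
    This is NOT the sbi API, only an illustration of the batching convention.
    """

    def sample(self, num_samples, conditions, rng):
        # conditions: list of `batch` conditions, each a list of `dim` floats.
        # Returns a nested list of shape (num_samples, batch, dim).
        return [
            [[rng.gauss(mu, 1.0) for mu in cond] for cond in conditions]
            for _ in range(num_samples)
        ]

    def log_prob(self, thetas, conditions):
        # thetas: shape (num_samples, batch, dim); returns (num_samples, batch).
        def log_normal(value, mean):
            return -0.5 * (value - mean) ** 2 - 0.5 * math.log(2 * math.pi)

        return [
            [
                sum(log_normal(t, mu) for t, mu in zip(theta, cond))
                for theta, cond in zip(row, conditions)
            ]
            for row in thetas
        ]


def shape_of(nested):
    """Return the shape of a regular, non-empty nested list as a tuple."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


rng = random.Random(0)
estimator = ToyConditionalGaussian()
x_batch = [[0.0, 0.0], [5.0, -5.0], [1.0, 2.0]]  # batch of 3 observations

samples = estimator.sample(10, x_batch, rng)
log_probs = estimator.log_prob(samples, x_batch)

print(shape_of(samples))    # (10, 3, 2)
print(shape_of(log_probs))  # (10, 3)
```

Each observation in the batch gets its own column of samples, so one call replaces a Python loop over observations.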

@manuelgloeckler

This currently won't work due to #1154.

@manuelgloeckler added the "blocked" label (Something is in the way of fixing this. Refer to it in the issue) May 7, 2024
@manuelgloeckler removed the "blocked" label May 7, 2024

codecov bot commented May 7, 2024

Codecov Report

Attention: Patch coverage is 75.94937% with 19 lines in your changes missing coverage. Please review.

Project coverage is 72.95%. Comparing base (3f722e3) to head (26444f7).
Report is 6 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1153      +/-   ##
==========================================
+ Coverage   72.91%   72.95%   +0.04%     
==========================================
  Files          93       93              
  Lines        7394     7459      +65     
==========================================
+ Hits         5391     5442      +51     
- Misses       2003     2017      +14     
Flag Coverage Δ
unittests 72.95% <75.94%> (+0.04%) ⬆️

Flags with carried forward coverage won't be shown.

Files Coverage Δ
sbi/inference/posteriors/direct_posterior.py 98.78% <100.00%> (+0.41%) ⬆️
sbi/inference/posteriors/mcmc_posterior.py 85.23% <100.00%> (+0.21%) ⬆️
sbi/neural_nets/density_estimators/nflows_flow.py 63.46% <ø> (ø)
sbi/inference/posteriors/base_posterior.py 86.04% <66.66%> (+1.70%) ⬆️
sbi/inference/posteriors/importance_posterior.py 55.38% <50.00%> (-0.18%) ⬇️
sbi/inference/posteriors/rejection_posterior.py 80.95% <50.00%> (-1.55%) ⬇️
sbi/inference/posteriors/vi_posterior.py 80.86% <50.00%> (-0.39%) ⬇️
sbi/samplers/rejection/rejection.py 88.00% <92.85%> (+0.94%) ⬆️
sbi/inference/snpe/snpe_a.py 62.60% <0.00%> (-0.95%) ⬇️
sbi/inference/posteriors/ensemble_posterior.py 50.00% <11.11%> (-3.54%) ⬇️

... and 5 files with indirect coverage changes

@manuelgloeckler

Okay, the SNPE part works now.

MCMC stuff still needs to be done and tested.

@janfb

janfb commented Jun 13, 2024

ok, merging main into amortizedsample fixed the conflicts and removed all the tracked changes from the other PRs. 👍

@janfb left a comment

great effort! 👏
I added a couple of comments for clarification.

sbi/inference/posteriors/direct_posterior.py (outdated, resolved)
sbi/inference/posteriors/importance_posterior.py (outdated, resolved)
sbi/inference/posteriors/mcmc_posterior.py (resolved)
sbi/inference/posteriors/mcmc_posterior.py (outdated, resolved)
sbi/samplers/rejection/rejection.py (resolved)
tests/density_estimator_test.py (outdated, resolved)
tests/density_estimator_test.py (outdated, resolved)
tests/posterior_nn_test.py (outdated, resolved)
tests/posterior_nn_test.py (outdated, resolved)
@janfb changed the title from "feat: sample method for batched conditions" to "feat: batched sampling and log prob methods." on Jun 13, 2024
@manuelgloeckler

Oh wow, thanks for fixing this mess :D. I am not sure what went wrong during the rebase (merging main in is probably preferable when main has received many commits; I will keep that in mind).

@manuelgloeckler requested review from gmoss13 and janfb June 18, 2024 08:55
@janfb left a comment

Looks good to me now! 🎉

sbi/inference/posteriors/mcmc_posterior.py (resolved)
tests/density_estimator_test.py (resolved)
Comment on lines +91 to +94
pytest.param(SNLE_A, marks=pytest.mark.xfail(raises=NotImplementedError)),
pytest.param(SNRE_A, marks=pytest.mark.xfail(raises=NotImplementedError)),
pytest.param(SNRE_B, marks=pytest.mark.xfail(raises=NotImplementedError)),
pytest.param(SNRE_C, marks=pytest.mark.xfail(raises=NotImplementedError)),

👍

Happy with this test, but I am not sure it belongs in posterior_nn_test. Should it go in mcmc_test instead?

I think this test is fine here (it just checks that all posteriors produce the same shapes).

I think in MCMC_test we should properly test that batched MCMC sampling works as expected (the current tests did not catch that the MCMC implementation produced wrong results), i.e. this should be part of #1176.

@janfb

janfb commented Jun 18, 2024

There seems to be a failing test during CD:

pytest tests/linearGaussian_snpe_test.py::test_c2st_multi_round_snpe_on_linearGaussian[snpe_a]

@manuelgloeckler

@janfb I have already responded about the reshape (but apparently you can only see it in the code review, and I cannot reply to the new comment you made).

In short, this reshape was not introduced in this PR. It was already there; I just moved it up a line and added a note (currently, MCMC only works with a single event_dim anyway).

@gmoss13 left a comment

Great work @manuelgloeckler, thanks!

else:
    samples = posterior.sample_batched((10,), ones(x_o_batch_dim, num_dim))

assert samples.shape == (10, x_o_batch_dim, num_dim), "Sample shape wrong"

+1

samples = density_estimator.sample((1000,), condition=condition)
samples = density_estimator.sample(sample_shape, condition=condition)
samples = samples.reshape(-1, batch_dim, *input_event_shape) # Flat for comp.

samples_separate1 = density_estimator.sample(

nice! I would suggest in a future issue/PR to test this behavior for all batched sample/batched log_prob tests, as in some of them we currently only test the shape.
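
A value-level check of the kind suggested here could look roughly like the sketch below. It is written in plain Python with toy samplers; `sample_batched` and `sample_single` are hypothetical helpers for this example, not sbi functions, and comparing means is only a crude stand-in for whatever comparison the real tests use.

```python
import random
import statistics


def sample_batched(num_samples, conditions, rng):
    """Toy batched sampler: returns shape (num_samples, batch) for scalar theta."""
    return [[rng.gauss(c, 1.0) for c in conditions] for _ in range(num_samples)]


def sample_single(num_samples, condition, rng):
    """Toy reference sampler for one condition: returns num_samples floats."""
    return [rng.gauss(condition, 1.0) for _ in range(num_samples)]


conditions = [-3.0, 0.0, 4.0]
num_samples = 2000

batched = sample_batched(num_samples, conditions, random.Random(0))

max_mean_gap = 0.0
for j, cond in enumerate(conditions):
    # Batched samples that belong to condition j.
    column = [row[j] for row in batched]
    # Independent reference draw for the same condition.
    separate = sample_single(num_samples, cond, random.Random(j + 1))
    gap = abs(statistics.fmean(column) - statistics.fmean(separate))
    max_mean_gap = max(max_mean_gap, gap)

# Loose tolerance: the standard error of each mean is about 1 / sqrt(2000).
assert max_mean_gap < 0.2
```

A shape-only test would pass even if the columns were assigned to the wrong conditions; comparing each column against a separately drawn reference catches that case.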

)
samples = []
for posterior_index, sample_size in torch.vstack(
posterior_indizes.unique(return_counts=True)

Typo: indizes should be indices.
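
The loop in the snippet above appears to iterate over unique component indices together with their counts. A minimal plain-Python sketch of that grouping pattern follows; `sample_from_mixture` and the two Gaussian components are hypothetical, and `collections.Counter` stands in for `unique(return_counts=True)`.

```python
import random
from collections import Counter


def sample_from_mixture(num_samples, weights, component_samplers, rng):
    """Draw num_samples from a weighted mixture by grouping draws per component."""
    # Pick a component index for every requested sample.
    indices = rng.choices(range(len(weights)), weights=weights, k=num_samples)

    samples = []
    # One (index, count) pair per component that was actually selected,
    # so each component is sampled once with the right sample size.
    for index, count in sorted(Counter(indices).items()):
        samples.extend(component_samplers[index](count, rng))
    return samples


# Two hypothetical components: Gaussians centered at -2 and +2.
components = [
    lambda n, rng: [rng.gauss(-2.0, 1.0) for _ in range(n)],
    lambda n, rng: [rng.gauss(2.0, 1.0) for _ in range(n)],
]

draws = sample_from_mixture(100, [0.3, 0.7], components, random.Random(0))
assert len(draws) == 100
```

Grouping by index means each component is called once per batch of draws rather than once per sample.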

sbi/inference/posteriors/mcmc_posterior.py (resolved)
@manuelgloeckler

Okay, the SNPE_A test fails because the SNPE_A posterior switches its "batching" behavior between rounds:

  • After the first round, samples from the proposal have shape (500, 1, 2), which is correct because x_o has batch_dim 1.
  • After the second round, samples from the proposal have shape (1, 500, 2), which is incorrect.

This happened because in the first round it sampled from the DensityEstimator (which follows the new convention), while after the first round it sampled using a custom implementation (which follows the old convention).
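
The mismatch between the two conventions can be sketched in plain Python with nested lists. `shape_of` and `swap_leading_axes` are hypothetical helpers for illustration only; the actual fix in the PR may look different.

```python
def shape_of(nested):
    """Return the shape of a regular, non-empty nested list as a tuple."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


def swap_leading_axes(samples):
    """Turn shape (a, b, dim) into (b, a, dim) for a regular nested list."""
    return [list(group) for group in zip(*samples)]


num_samples, batch_dim, theta_dim = 500, 1, 2

# Old-style output as described above: batch axis first, shape (1, 500, 2).
old_style = [
    [[0.0] * theta_dim for _ in range(num_samples)] for _ in range(batch_dim)
]

# New-style output: sample axis first, shape (500, 1, 2).
new_style = swap_leading_axes(old_style)

assert shape_of(old_style) == (1, 500, 2)
assert shape_of(new_style) == (500, 1, 2)
```

Asserting the expected shape right after sampling, as the tests in this PR do, makes such a switch visible immediately instead of surfacing as a downstream failure.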

@manuelgloeckler

manuelgloeckler commented Jun 18, 2024

Main is currently failing tests due to:

max_err = np.max(error)
>       assert max_err < 0.0027
E       assert 0.0033298 < 0.0027

This looks like a random fluctuation (triggered by some change).

Should I increase the tolerance a bit to fix this? [Edit: I increased the tolerance to 0.005, just for the record]

@janfb

janfb commented Jun 18, 2024

> Main is currently failing tests due to:
>
> max_err = np.max(error)
> >       assert max_err < 0.0027
> E       assert 0.0033298 < 0.0027
>
> This looks like a random fluctuation (triggered by some change).
>
> Should I increase the tolerance a bit to fix this? [Edit: I increased the tolerance to 0.005, just for the record]

No, please keep the tolerance as it was. This will be fixed in #1177

This reverts commit 2aac705.
@manuelgloeckler

Alright, I will merge after my local GPU/Slow tests are finished.

@janfb

janfb commented Jun 18, 2024

> Alright, I will merge after my local GPU/Slow tests are finished.

Cool!

@manuelgloeckler merged commit 4951439 into main Jun 18, 2024
6 checks passed
@janfb deleted the amortizedsample branch June 20, 2024 10:31