
Rejection sampling for likelihood-based methods #487

Merged: 8 commits into main on May 19, 2021

Conversation

@michaeldeistler (Contributor) commented May 11, 2021

API change

Old:

posterior = build_posterior(sample_with_mcmc=True)

New:

posterior = build_posterior(sample_with="mcmc")  # or "rejection"

Code structure changes

We already had a method that did something similar to rejection sampling: sample_posterior_within_prior. I renamed this method to rejection_sample_posterior_within_prior(). In addition, there is now a closely related method rejection_sample(), which implements standard rejection sampling. In SNPE, if rejection_sampling_parameters={"proposal": ...} is not specified, we use rejection_sample_posterior_within_prior(); if a proposal is specified, we use rejection_sample(). In SNLE and SNRE, we always use rejection_sample() with the prior as the proposal.

All three methods (SNLE, SNPE, SNRE) use rejection_sample() if sample_with='rejection'. SNLE and SNRE use the prior as the proposal; SNPE uses the neural net as the proposal.
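The dispatch described above can be sketched in a few lines. This is an illustration of the rule, not the sbi implementation; the helper name select_rejection_fn is hypothetical:

```python
def select_rejection_fn(method: str, rejection_sampling_parameters: dict) -> str:
    """Hypothetical dispatch mirroring the rules above: SNPE falls back to
    sampling within the prior unless an explicit proposal is given; SNLE and
    SNRE always run standard rejection sampling (with the prior as proposal)."""
    if method == "SNPE" and "proposal" not in rejection_sampling_parameters:
        return "rejection_sample_posterior_within_prior"
    return "rejection_sample"
```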

All changes

  • Remove sample_with_mcmc: bool in favor of sample_with: str. The parameter accepts [`mcmc` | `rejection`] and paves the way for further sampling methods. Default is `mcmc`.
  • New method rejection_sample() in sbi/mcmc/rejection_sampling.py.
  • Use it also for SNPE and SNRE.
  • Deal with the entire infrastructure (e.g. store as hyperparameters, etc.).
  • Test all results.
  • Make everything backward-compatible.
  • Use rejection sampling in some of our tests? (just because it's super fast)
  • Avoid three flow passes for SNPE via binary_rejection_criterion=True and get rid of ignore_scaling.
  • GPU tests.
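For reference, the standard rejection sampling that a function like rejection_sample() performs can be sketched generically: draw from the proposal and accept each draw with probability target(x) / (m * proposal(x)), where m bounds the density ratio from above. This is a generic illustration, not the sbi implementation:

```python
import random


def rejection_sample(target_pdf, proposal_sample, proposal_pdf, m, num_samples, rng):
    """Generic rejection sampling: draw x from the proposal and accept it with
    probability target_pdf(x) / (m * proposal_pdf(x))."""
    samples = []
    while len(samples) < num_samples:
        x = proposal_sample(rng)
        if rng.random() < target_pdf(x) / (m * proposal_pdf(x)):
            samples.append(x)
    return samples


# Target p(x) = 2x on [0, 1], proposal uniform on [0, 1], bound m = 2.
rng = random.Random(0)
draws = rejection_sample(
    lambda x: 2.0 * x, lambda r: r.random(), lambda x: 1.0, 2.0, 5000, rng
)
mean = sum(draws) / len(draws)  # the true mean of p is 2/3
```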

@michaeldeistler marked this pull request as draft May 11, 2021 16:29
@codecov-commenter commented May 11, 2021

Codecov Report

Merging #487 (ac31fae) into main (0ba6c31) will decrease coverage by 1.33%.
The diff coverage is 58.36%.


@@            Coverage Diff             @@
##             main     #487      +/-   ##
==========================================
- Coverage   69.05%   67.71%   -1.34%     
==========================================
  Files          55       55              
  Lines        3832     3968     +136     
==========================================
+ Hits         2646     2687      +41     
- Misses       1186     1281      +95     
Flag Coverage Δ
unittests 67.71% <58.36%> (-1.34%) ⬇️


Impacted Files Coverage Δ
sbi/inference/snle/snle_base.py 97.59% <ø> (ø)
sbi/inference/snpe/snpe_a.py 66.35% <ø> (ø)
sbi/inference/snre/snre_base.py 96.80% <ø> (ø)
sbi/utils/__init__.py 100.00% <ø> (ø)
sbi/utils/sbiutils.py 71.20% <54.45%> (-11.51%) ⬇️
sbi/inference/posteriors/ratio_based_posterior.py 77.27% <55.00%> (-8.07%) ⬇️
...inference/posteriors/likelihood_based_posterior.py 72.97% <57.14%> (-9.00%) ⬇️
sbi/inference/posteriors/base_posterior.py 66.11% <60.65%> (-5.61%) ⬇️
sbi/inference/posteriors/direct_posterior.py 77.77% <63.15%> (-11.34%) ⬇️
sbi/inference/snpe/snpe_base.py 87.12% <100.00%> (+0.40%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@jan-matthis (Contributor) commented

Thanks for tackling this!

One quick comment regarding the first point:

remove sample_with_mcmc: bool in favor of sample_with: str. This parameter accepts [mcmc | rejection] but it paves the way for further sampling methods. Default is mcmc

Could we keep backwards compatibility, i.e., issue a deprecation warning before switching out the interface completely?
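One possible shape for such a shim (the helper name resolve_sample_with is illustrative, not the sbi code): map the deprecated boolean onto the new string argument and warn.

```python
import warnings
from typing import Optional


def resolve_sample_with(
    sample_with: str = "mcmc", sample_with_mcmc: Optional[bool] = None
) -> str:
    """Illustrative backward-compat shim: if a caller still passes the old
    boolean, translate it to the new string argument and emit a warning."""
    if sample_with_mcmc is not None:
        warnings.warn(
            "`sample_with_mcmc` is deprecated; use `sample_with='mcmc'` or "
            "`sample_with='rejection'` instead.",
            DeprecationWarning,
        )
        sample_with = "mcmc" if sample_with_mcmc else "rejection"
    if sample_with not in ("mcmc", "rejection"):
        raise ValueError(f"Unknown sampling method: {sample_with}")
    return sample_with
```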

@michaeldeistler force-pushed the rejection-sampling branch 3 times, most recently from 2899b0b to 768d830 on May 17, 2021 08:54
@michaeldeistler marked this pull request as ready for review May 18, 2021 08:40
@michaeldeistler force-pushed the rejection-sampling branch 2 times, most recently from 11a0c1f to 10ac215 on May 18, 2021 10:03
@michaeldeistler force-pushed the rejection-sampling branch 2 times, most recently from 595c84e to 2d1484f on May 19, 2021 06:32
@michaeldeistler force-pushed the rejection-sampling branch 2 times, most recently from e84fad6 to f14b6a8 on May 19, 2021 07:01
@michaeldeistler requested a review from janfb May 19, 2021 07:44
@michaeldeistler (Contributor, Author) commented May 19, 2021

@janfb This is ready for review now. I will be testing it more extensively in my own project now -- there might still be small bugs, which I will fix in the process. However, the main code is now in place. Eventually, I decided to keep two functions: rejection_sample_posterior_within_prior() and rejection_sample(). They are closely related, but I think you were right: merging them into a single function created more code. Nonetheless, they share the same interface:

posterior = inference.build_posterior(sample_with="rejection")

If rejection_sampling_parameters={"proposal": ...} is not specified, we use rejection_sample_posterior_within_prior(). If a proposal is specified, we use rejection_sample().
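The difference between the two can be illustrated generically: sampling "within the prior" amounts to drawing from the unconstrained density estimator and keeping only draws inside the prior support, with the acceptance rate estimating the leakage. A minimal sketch under that assumption (generic code, not the sbi implementation):

```python
import random


def sample_within_support(draw, in_support, num_samples, rng):
    """Draw from an unconstrained sampler and keep only draws that fall inside
    the prior support; the acceptance rate estimates the leaked mass."""
    accepted, proposed = [], 0
    while len(accepted) < num_samples:
        x = draw(rng)
        proposed += 1
        if in_support(x):
            accepted.append(x)
    return accepted, len(accepted) / proposed


# Leaky "posterior": standard normal; prior support: the interval [0, 1].
rng = random.Random(0)
samples, acceptance = sample_within_support(
    lambda r: r.gauss(0.0, 1.0), lambda x: 0.0 <= x <= 1.0, 2000, rng
)
```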

@janfb (Contributor) left a comment:

Overall great work!

@@ -50,8 +55,10 @@ def __init__(
neural_net: nn.Module,
prior,
x_shape: torch.Size,
sample_with: str,
@janfb (Contributor):
Couldn't we get rid of mcmc_method by using sample_with="slice", etc.?

@michaeldeistler (Contributor, Author):

I see your point; it would definitely work. My only concern is that inexperienced users might prefer setting sample_with='mcmc' (they might not know what slice sampling is). I'd like to keep it as it is for now, just to show explicitly that e.g. slice is an MCMC method, not a rejection sampling method.

I think in the long run, we might want something like this:

from sbi.samplers import MCMC, Rejection, VariationalInference

inference = SNLE(prior)
_ = inference.append_simulations(theta, x).train()
sampler = MCMC(inference, method="slice_np")  # can only sample().
posterior = inference.build_posterior(sampler)  # can sample(), log_prob(), map()

But this would of course be a huge refactor.

@janfb (Contributor):

OK, I agree, let's leave it like this for now.
As an intermediate step in a future PR we could also use mcmc_slice, mcmc_hmc, etc., but that could be confusing as well.

@@ -60,6 +67,8 @@ def __init__(
neural_net: A classifier for SNRE, a density estimator for SNPE and SNL.
prior: Prior distribution with `.log_prob()` and `.sample()`.
x_shape: Shape of the simulator data.
sample_with: Method to use for sampling from the posterior. Must be one of
[`mcmc` | `rejection`].
mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`,
@janfb (Contributor):
Wouldn't it make sense to change the default to slice_np_vectorized? It is tested and way faster, no?

@michaeldeistler (Contributor, Author):
I agree, we should at least advertise slice_np_vectorized (it's never mentioned in docstrings). But I'd prefer keeping this for another PR.

Resolved review threads: sbi/inference/posteriors/base_posterior.py (3), sbi/utils/sbiutils.py (5)
@janfb (Contributor) left a comment:

good to go now, thanks!

@michaeldeistler merged commit 3c3a79e into main May 19, 2021
@michaeldeistler deleted the rejection-sampling branch May 19, 2021 15:53

4 participants