Rejection sampling for likelihood-based methods #487
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##             main     #487      +/-   ##
==========================================
- Coverage   69.05%   67.71%    -1.34%
==========================================
  Files          55       55
  Lines        3832     3968     +136
==========================================
+ Hits         2646     2687      +41
- Misses       1186     1281      +95
```
Thanks for tackling this! One quick comment regarding the first point: could we keep backwards compatibility, i.e., issue a deprecation warning before switching out the interface completely?
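One way to do this is a small shim that accepts the old boolean flag, warns, and translates it to the new string argument. A minimal sketch with hypothetical function and parameter names (not the actual sbi code):

```python
import warnings

def build_posterior(sample_with: str = "mcmc",
                    sample_with_mcmc: bool = None):
    # Hypothetical shim: if a caller still passes the old boolean flag,
    # emit a DeprecationWarning and map it onto the new string argument.
    if sample_with_mcmc is not None:
        warnings.warn(
            "`sample_with_mcmc` is deprecated; use `sample_with='mcmc'` "
            "or `sample_with='rejection'` instead.",
            DeprecationWarning,
        )
        sample_with = "mcmc" if sample_with_mcmc else "rejection"
    return sample_with  # downstream code only ever sees the new argument
```

Old call sites keep working (with a warning) until the flag is removed in a later release.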
@janfb This is ready for review now. I will be testing it more extensively in my own project; there might still be small bugs, which I will fix in the process of doing that. However, the main code is now in place. Eventually, I decided to keep two functions:
If
Overall great work!
```
@@ -50,8 +55,10 @@ def __init__(
    neural_net: nn.Module,
    prior,
    x_shape: torch.Size,
    sample_with: str,
```
Couldn't we get rid of `mcmc_method` by using `sample_with="slice"` etc.?
I see your point, it would definitely work. My only concern is that inexperienced users might prefer setting `sample_with='mcmc'` (they might not know what slice sampling is). I think I'd like to keep it as it is for now, just to show explicitly that e.g. `slice` is an `mcmc` method, not a `rejection` sampling method.
I think in the long run, we might want something like this:

```python
from sbi.samplers import MCMC, Rejection, VariationalInference

inference = SNLE(prior)
_ = inference.append_simulations(theta, x).train()
sampler = MCMC(inference, method="slice_np")  # can only sample().
posterior = inference.build_posterior(sampler)  # can sample(), log_prob(), map().
```

But this would of course be a huge refactor.
OK, I agree, let's leave it like this for now. As an intermediate step in a future PR we could also use `mcmc_slice`, `mcmc_hmc`, etc., but this could be confusing as well.
```
@@ -60,6 +67,8 @@ def __init__(
    neural_net: A classifier for SNRE, a density estimator for SNPE and SNL.
    prior: Prior distribution with `.log_prob()` and `.sample()`.
    x_shape: Shape of the simulator data.
    sample_with: Method to use for sampling from the posterior. Must be one of
        [`mcmc` | `rejection`].
    mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`,
```
Wouldn't it make sense to change the default to `slice_np_vectorized`? It is tested and way faster, no?
I agree, we should at least advertise `slice_np_vectorized` (it's never mentioned in docstrings). But I'd prefer keeping this for another PR.
good to go now, thanks!
API change
Old:
New:
Code structure changes
We already had a method that did something similar to rejection sampling: `sample_posterior_within_prior`. I renamed this method to `rejection_sample_posterior_within_prior()`. In addition, a very similar method `rejection_sample()` now exists. It implements standard rejection sampling. In SNPE, if `rejection_sampling_parameters={"proposal": ...}` is not specified, we use `rejection_sample_posterior_within_prior()`. If a proposal is specified, we use `rejection_sample()`. In SNLE and SNRE, we always use `rejection_sample()` with the prior as proposal.

All three methods (SNLE, SNPE, SNRE) use `rejection_sample()` if `sample_with='rejection'`. `SNLE` and `SNRE` use the prior as proposal. `SNPE` uses the neural net as proposal.

All changes
- Deprecated `sample_with_mcmc: bool` in favor of `sample_with: str`. This parameter accepts [`mcmc` | `rejection`], but it paves the way for further sampling methods. Default is `mcmc`.
- Added `rejection_sample()` in `sbi/mcmc/rejection_sampling.py`.
- Use `binary_rejection_criterion=True` and get rid of `ignore_scaling`.
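For reference, standard rejection sampling (the technique a function like `rejection_sample()` implements) can be sketched in a few lines of plain Python. The function below and its target/proposal example are illustrative only, not the sbi implementation: draw from the proposal and accept each draw with probability p(θ) / (M · q(θ)), where log p − log q ≤ log M everywhere.

```python
import math
import random

def rejection_sample(target_log_prob, proposal_sample, proposal_log_prob,
                     log_m, num_samples):
    """Accept each proposal draw theta with probability
    p(theta) / (M * q(theta)); p may be unnormalized."""
    samples = []
    while len(samples) < num_samples:
        theta = proposal_sample()
        log_accept_prob = (
            target_log_prob(theta) - proposal_log_prob(theta) - log_m
        )
        if math.log(random.random()) < log_accept_prob:
            samples.append(theta)
    return samples

# Example: unnormalized standard normal truncated to [-1, 1],
# with a Uniform(-1, 1) proposal (density q = 1/2, bound M = 2,
# since max_theta exp(-theta^2 / 2) / (1/2) = 2 at theta = 0).
samples = rejection_sample(
    target_log_prob=lambda t: -0.5 * t * t,
    proposal_sample=lambda: random.uniform(-1.0, 1.0),
    proposal_log_prob=lambda t: math.log(0.5),
    log_m=math.log(2.0),
    num_samples=1000,
)
```

A tighter bound M means fewer rejected draws, which is exactly why SNPE prefers the neural net itself as proposal when one is available, while SNLE and SNRE fall back to the prior.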