Incompatibility of SMCABC log_prob with PyTorch 1.8 #504

Closed
atiyo opened this issue Jun 3, 2021 · 6 comments

atiyo commented Jun 3, 2021

Hello!

I was toying around with SMCABC via sbibm and came across an exception while trying to run:

import sbibm
from sbibm.algorithms import smc_abc
task = sbibm.get_task("two_moons")
posterior_samples, _, _ = smc_abc(
    task=task, num_samples=1_000, num_observation=1, num_simulations=10_000
)

which spits out the following (truncated to relevant bits):

~/.pyenv/versions/3.7.4/lib/python3.7/site-packages/sbi/inference/abc/smcabc.py in _sample_and_perturb(self, particles, weights, num_samples)
    497             parms_perturbed = self.get_new_kernel(parms).sample()
    498
--> 499             is_within_prior = torch.isfinite(self.prior.log_prob(parms_perturbed))
    500             num_accepted += is_within_prior.sum().item()
    501

~/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/distributions/independent.py in log_prob(self, value)
     89
     90     def log_prob(self, value):
---> 91         log_prob = self.base_dist.log_prob(value)
     92         return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)
     93

~/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/distributions/uniform.py in log_prob(self, value)
     71     def log_prob(self, value):
     72         if self._validate_args:
---> 73             self._validate_sample(value)
     74         lb = self.low.le(value).type_as(self.low)
     75         ub = self.high.gt(value).type_as(self.low)

~/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
    275         assert support is not None
    276         if not support.check(value).all():
--> 277             raise ValueError('The value argument must be within the support')
    278
    279     def _get_checked_instance(self, cls, _instance=None):

ValueError: The value argument must be within the support

I believe this comes down to the line in SMCABC that checks whether perturbed parameters are within the prior:

https://github.com/mackelab/sbi/blob/340424ecaacc5c375a4818927f8a3e2742ceb979/sbi/inference/abc/smcabc.py#L499

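The problem is that, starting with PyTorch 1.8, distributions validate their arguments by default, so log_prob checks the value against the distribution's support and raises a ValueError for out-of-support values instead of returning -inf.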
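A minimal sketch of the behaviour change in plain PyTorch, using a box-uniform prior over [0, 1]^2 built from Independent and Uniform (assumed here as a stand-in for the task's prior). Setting validate_args per distribution makes the contrast explicit: with validation off, out-of-support points get a log_prob of -inf, which is what the torch.isfinite check relies on; with validation on, the same call raises the ValueError seen in the traceback.

```python
import torch
from torch.distributions import Independent, Uniform

low, high = torch.zeros(2), torch.ones(2)
theta = torch.tensor([[0.5, 0.5], [1.5, 0.5]])  # second row lies outside [0, 1]^2

# Validation off: out-of-support points get log_prob = -inf, so isfinite() flags them.
lenient = Independent(Uniform(low, high, validate_args=False), 1)
print(torch.isfinite(lenient.log_prob(theta)))  # first row True, second row False

# Validation on (the default since PyTorch 1.8): the same call raises instead.
strict = Independent(Uniform(low, high, validate_args=True), 1)
try:
    strict.log_prob(theta)
except ValueError as err:
    print("ValueError:", err)
```

PyTorch also offers torch.distributions.Distribution.set_default_validate_args(False) to change the default for distributions created afterwards, though that would hide genuine input errors as well.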

This might be addressed by changing line 499 of smcabc.py to

is_within_prior = self.prior.support.check(parms_perturbed)
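To illustrate this replacement, here is a sketch with a box-uniform prior over [-1, 1]^2 (again an assumed stand-in) constructed with validation switched on. In recent PyTorch versions, the support of an Independent distribution wraps the base support so that check() reduces over the event dimension, returning one boolean per sample instead of raising.

```python
import torch
from torch.distributions import Independent, Uniform

# Box-uniform prior over [-1, 1]^2, built with argument validation switched on.
prior = Independent(Uniform(-torch.ones(2), torch.ones(2), validate_args=True), 1)
theta = torch.tensor([[0.0, 0.0], [1.5, 0.0], [-0.5, 0.9]])  # middle row is outside the box

# support.check() only tests membership, so it never raises on out-of-support values;
# the Independent support reduces over the last dimension to one flag per sample.
is_within_prior = prior.support.check(theta)
num_accepted = int(is_within_prior.sum().item())  # mirrors the counting on line 500
print(is_within_prior, num_accepted)
```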

However, this presupposes that the support.check() method exists, which might not be the case for user-defined priors. So a more conservative change might be:

try:
    is_within_prior = self.prior.support.check(parms_perturbed)
except AttributeError:
    is_within_prior = torch.isfinite(self.prior.log_prob(parms_perturbed)) 
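One caveat with the AttributeError fallback, assuming the user-defined prior subclasses torch.distributions.Distribution: the base class defines support as a property that raises NotImplementedError, so a subclass that never overrides it slips past `except AttributeError`. The HypotheticalUserPrior below is invented for illustration; a plain object with no support attribute at all would raise AttributeError as intended.

```python
import math

import torch
from torch.distributions import Distribution


class HypotheticalUserPrior(Distribution):
    """Hypothetical user-defined prior: implements log_prob but not `support`."""

    arg_constraints = {}  # declared so PyTorch does not warn about missing constraints

    def log_prob(self, value):
        # Standard-normal log-density, summed over the last dimension.
        return (-0.5 * value**2 - 0.5 * math.log(2 * math.pi)).sum(-1)


def proposed_check(prior, theta):
    """The pattern proposed above: fall back to log_prob only on AttributeError."""
    try:
        return prior.support.check(theta)
    except AttributeError:
        return torch.isfinite(prior.log_prob(theta))


def widened_check(prior, theta):
    """Same pattern, also catching the NotImplementedError from the base-class property."""
    try:
        return prior.support.check(theta)
    except (AttributeError, NotImplementedError):
        return torch.isfinite(prior.log_prob(theta))


prior = HypotheticalUserPrior()
theta = torch.zeros(4, 2)
try:
    proposed_check(prior, theta)
except NotImplementedError:
    print("`support` raised NotImplementedError, which `except AttributeError` lets through")
print(widened_check(prior, theta))
```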

I'll go ahead and raise the last change in a PR for the sake of taking a next step, but obviously I'm happy to defer to what the maintainers think is best!

michaeldeistler (Contributor) commented

Hi there,

thanks for creating this issue! Which version of sbi are you using? If it is not the newest version, please upgrade both sbi and sbibm and see if the error persists.

Best
Michael
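One way to answer the version question from Python, sketched with the standard-library importlib.metadata (Python 3.8+; on the 3.7 interpreter shown in the traceback, the importlib_metadata backport provides the same functions). The distribution names sbi, sbibm and torch are assumed to match the installed packages.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution name: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


print(installed_versions(["sbi", "sbibm", "torch"]))
```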

atiyo (Author) commented Jun 4, 2021

Thanks for the quick response!

I just upgraded to make sure I'm on sbi 0.16.0 and sbibm 1.0.6, which I believe are the latest versions, but the error still persists.

janfb (Contributor) commented Jun 4, 2021

Hey! I think this is indeed a bug. The perturbation can move particles outside the prior support, which causes the new PyTorch version to raise a ValueError.
But I think the fix suggested in #505 will not work because of the same problem. I think we need to use within_support and I actually have the fix on a local branch - just forgot to push it.
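The thread does not show within_support itself, so the following is only a hypothetical sketch of a check in the same spirit, assuming theta is a 2D batch of shape [num_samples, dim]: prefer support.check(), and only when the prior has no usable support fall back to log_prob, guarding against the ValueError that argument validation raises. HypotheticalWrappedPrior stands in for a user-defined prior that exposes only log_prob.

```python
import torch
from torch.distributions import Independent, Uniform


def within_support_sketch(prior, theta):
    """Hypothetical helper: return one boolean per row of `theta` ([num_samples, dim]).

    A sketch only; it is not sbi's `within_support`.
    """
    # Preferred path: ask the prior's support directly, which never raises on bad values.
    try:
        flags = prior.support.check(theta)
    except (AttributeError, NotImplementedError):
        flags = None
    if flags is not None:
        # Element-wise supports give [num_samples, dim]; reduce to one flag per sample.
        return flags.all(-1) if flags.dim() == theta.dim() else flags

    # Fallback: use log_prob, but catch the ValueError raised by argument validation.
    try:
        return torch.isfinite(prior.log_prob(theta))
    except ValueError:
        # Validation rejects the whole batch, so re-evaluate one sample at a time.
        row_flags = []
        for row in theta:
            try:
                row_flags.append(bool(torch.isfinite(prior.log_prob(row.unsqueeze(0))).all()))
            except ValueError:
                row_flags.append(False)
        return torch.tensor(row_flags, dtype=torch.bool)


class HypotheticalWrappedPrior:
    """Hypothetical user prior exposing only log_prob, backed by a validating box uniform."""

    def __init__(self, low, high):
        self._dist = Independent(Uniform(low, high, validate_args=True), 1)

    def log_prob(self, value):
        return self._dist.log_prob(value)


low, high = -torch.ones(2), torch.ones(2)
theta = torch.tensor([[0.0, 0.0], [1.5, 0.0], [0.5, -0.5]])  # middle row is outside the box

box_prior = Independent(Uniform(low, high), 1)
print(within_support_sketch(box_prior, theta))  # uses support.check
print(within_support_sketch(HypotheticalWrappedPrior(low, high), theta))  # uses the fallback
```

Checking the support first avoids calling log_prob on invalid values at all; the per-sample fallback is slow for large batches, but it is only reached for priors that offer neither a usable support nor a non-raising log_prob.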

atiyo (Author) commented Jun 4, 2021

Ah cool! Thanks for the quick fix. I'll close #505 then :)

janfb (Contributor) commented Jun 4, 2021

The fix is in main now. Is it working for you @atiyo?
The issue can be closed with release 0.16.1.

atiyo (Author) commented Jun 4, 2021

Seems to work now :) Thank you!

janfb closed this as completed Jul 15, 2021