Pyro v1.4.0 compatibility #286

Closed
AlexSauer opened this issue Aug 3, 2020 · 9 comments
Labels
bug Something isn't working


@AlexSauer

AlexSauer commented Aug 3, 2020

Hi,

I was experimenting with SNRE and encountered the following problem:

import torch
import sbi.utils as utils
from sbi.inference.base import infer

num_dim = 3
prior = utils.BoxUniform(low=-2*torch.ones(num_dim), high=2*torch.ones(num_dim))

def simulator(parameter_set):
    return 1.0 + parameter_set + torch.randn(parameter_set.shape) * 0.1

posterior = infer(simulator, prior, method='SNRE', num_simulations=1000)

observation = torch.zeros(3)
samples = posterior.sample((100,), x=observation, mcmc_method = 'hmc')

This throws an error in the method pyro_potential of snre_base, which seems to happen in the very last iteration. The same happens with mcmc_method = 'nuts'.
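For reference (a worked check added here, not part of the original report): because this toy simulator only adds Gaussian noise, its exact posterior can be written down. Here $\mathbf{1}$ is the all-ones vector and $\mathbb{I}[\cdot]$ is the indicator of the prior box.

```latex
x = \mathbf{1} + \theta + 0.1\,\varepsilon,\quad \varepsilon \sim \mathcal{N}(0, I_3)
\;\Longrightarrow\; p(x \mid \theta) = \mathcal{N}\!\left(x;\ \mathbf{1} + \theta,\ 0.1^2 I_3\right)

p(\theta \mid x_o = 0) \;\propto\; \mathcal{N}\!\left(0;\ \mathbf{1} + \theta,\ 0.1^2 I_3\right)\,\mathbb{I}\!\left[\theta \in [-2, 2]^3\right]
\;=\; \mathcal{N}\!\left(\theta;\ -\mathbf{1},\ 0.1^2 I_3\right)\,\mathbb{I}\!\left[\theta \in [-2, 2]^3\right]
```

The mode at $-\mathbf{1}$ is ten standard deviations from the nearest box edge, so the truncation is negligible and the posterior is essentially $\mathcal{N}(-\mathbf{1}, 0.1^2 I_3)$. A $\pm 3\sigma$ box around it covers only $(0.6/4)^3 \approx 0.34\%$ of the prior volume, which quantifies how narrow this posterior is.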

@michaeldeistler
Contributor

michaeldeistler commented Aug 3, 2020

Hi Alex,

thanks for reaching out! I guess there are two potential reasons for your problem:

  1. In our latest release, we fixed a bug in SNRE. Can you see if
import sbi
sbi.__version__

gives 0.11.1? If your version is older, please update sbi and see if the error persists.

  2. Your likelihood is quite narrow (standard deviation only 0.1). This gives a very narrow posterior, which can lead to problems in the MCMC step needed in SNLE and SNRE. This problem can be alleviated by using a different initialization strategy for the MCMC chain, called Sequential-Importance-Resampling (SIR):
samples = posterior.sample((100,), x=observation, mcmc_method = 'hmc', mcmc_parameters={'init_strategy': 'sir'})
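To illustrate the idea behind the 'sir' initialization, here is a minimal sketch in plain Python, assuming a box-uniform prior and an unnormalized log-posterior. It is not sbi's implementation; sir_init and log_posterior are hypothetical names. Candidates are drawn from the prior, weighted by how well they explain the observation, and one is resampled as the chain's starting point, so the chain starts inside the narrow posterior instead of at an arbitrary prior draw.

```python
import math
import random


def sir_init(log_target, low, high, num_candidates=1000, rng=None):
    """Choose one MCMC starting point via sampling-importance-resampling (SIR).

    Hypothetical helper for illustration; not sbi's implementation.
    Proposal: the box-uniform prior on [low, high]. Because that prior density
    is constant inside the box, each candidate's importance weight is
    proportional to exp(log_target(theta)), where log_target is the
    unnormalized log-posterior.
    """
    if rng is None:
        rng = random.Random()
    # 1. Draw candidate parameter vectors from the uniform prior.
    candidates = [
        [rng.uniform(lo, hi) for lo, hi in zip(low, high)]
        for _ in range(num_candidates)
    ]
    # 2. Log-weights, shifted by their maximum so exp() cannot overflow.
    log_w = [log_target(c) for c in candidates]
    max_log_w = max(log_w)
    weights = [math.exp(lw - max_log_w) for lw in log_w]
    # 3. Resample a single candidate with probability proportional to its weight.
    return rng.choices(candidates, weights=weights, k=1)[0]


# Usage on the toy problem from this thread: x = 1 + theta + 0.1 * noise, x_o = 0.
def log_posterior(theta, x_o=(0.0, 0.0, 0.0), sigma=0.1):
    # Gaussian log-likelihood up to an additive constant (prior is flat in the box).
    return -sum((xo - 1.0 - t) ** 2 for xo, t in zip(x_o, theta)) / (2 * sigma ** 2)


start = sir_init(log_posterior, low=[-2.0] * 3, high=[2.0] * 3,
                 num_candidates=20000, rng=random.Random(0))
print(start)  # expected to lie near the posterior mode (-1, -1, -1)
```

Drawing many cheap candidates costs only potential evaluations, which is why this kind of initialization helps when a randomly started chain would spend a long time finding a small high-density region.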

With this change and using v0.11.1, your code finishes successfully for me.

Let me know if this worked :)

Michael

@AlexSauer
Author

AlexSauer commented Aug 3, 2020

Thanks for the fast reply!
Unfortunately, it doesn't fix the problem...
I am running this in a notebook in Google Colab (https://colab.research.google.com/drive/1TmzrnRSg2XxO1XbiFik0XEyzhTEiWuOT?usp=sharing).

I first install your package using

%pip install sbi

and then execute

import torch
import sbi
import sbi.utils as utils
from sbi.inference.base import infer

print('Version: ', sbi.__version__)

num_dim = 3
prior = utils.BoxUniform(low=-2*torch.ones(num_dim), high=2*torch.ones(num_dim))
def simulator(parameter_set):
    return 1.0 + parameter_set + torch.randn(parameter_set.shape) * 0.1

posterior = infer(simulator, prior, method='SNRE', num_simulations=1000)
observation = torch.zeros(3)
samples = posterior.sample((100,), x=observation, mcmc_method = 'hmc', mcmc_parameters={'init_strategy': 'sir'})

This gives the following output and stack trace. To me, it looks as if it calls the method one more time even though sampling has already finished?

Version:  0.11.1
Running 1000 simulations.: 100%
1000/1000 [00:06<00:00, 150.79it/s]

Neural network successfully converged after 58 epochs.
Sample: 100%|██████████| 1021/1021 [00:14, 70.64it/s, step size=8.83e-01, acc. prob=0.567] 
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-8-3a4e9afaaa1c> in <module>()
     13 posterior = infer(simulator, prior, method='SNRE', num_simulations=1000)
     14 observation = torch.zeros(3)
---> 15 samples = posterior.sample((100,), x=observation, mcmc_method = 'hmc', mcmc_parameters={'init_strategy': 'sir'})

/usr/local/lib/python3.6/dist-packages/sbi/inference/posterior.py in sample(self, sample_shape, x, show_progress_bars, sample_with_mcmc, mcmc_method, mcmc_parameters)
    449                 show_progress_bars=show_progress_bars,
    450                 mcmc_method=mcmc_method,
--> 451                 **mcmc_parameters,
    452             )
    453         elif self._method_family == "snpe":

/usr/local/lib/python3.6/dist-packages/sbi/inference/posterior.py in _sample_posterior_mcmc(self, num_samples, x, mcmc_method, thin, warmup_steps, num_chains, init_strategy, init_strategy_num_candidates, show_progress_bars)
    560                     warmup_steps=warmup_steps,
    561                     num_chains=num_chains,
--> 562                     show_progress_bars=show_progress_bars,
    563                 ).detach()
    564             else:

/usr/local/lib/python3.6/dist-packages/sbi/inference/posterior.py in _pyro_mcmc(self, num_samples, potential_function, mcmc_method, thin, warmup_steps, num_chains, show_progress_bars)
    657             disable_progbar=not show_progress_bars,
    658         )
--> 659         sampler.run()
    660         samples = next(iter(sampler.get_samples().values())).reshape(
    661             -1, len(self._prior.mean)  # len(prior.mean) = dim of theta

/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    412             if getattr(self.kernel, "transforms", None) is None:
    413                 warmup_steps = 0
--> 414                 self.kernel.setup(warmup_steps, *args, **kwargs)
    415             # Use `kernel.transforms` when available
    416             if getattr(self.kernel, "transforms", None) is not None:

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
    306             z_grads, potential_energy = potential_grad(self.potential_fn, z)
    307         else:
--> 308             z_grads, potential_energy = {}, self.potential_fn(self.initial_params)
    309         self._cache(self.initial_params, potential_energy, z_grads)
    310         if self.initial_params:

/usr/local/lib/python3.6/dist-packages/sbi/inference/snre/snre_base.py in pyro_potential(self, theta)
    416         """
    417 
--> 418         theta = next(iter(theta.values()))
    419 
    420         # Theta and x should have shape (1, dim).

AttributeError: 'NoneType' object has no attribute 'values'

Not sure why it finishes for you but not here (I also just tried it on my own laptop and got the same error message)...
(I was actually experimenting with a different prior and simulator when I encountered the problem, and then used the example from your tutorial to check whether the error persists.)
Thanks for your help! :)

@michaeldeistler
Contributor

This seems to be caused by the recent update of Pyro to v1.4.0. I don't know yet what exactly breaks our code, but a quick fix for you might be to downgrade Pyro:

pip install pyro-ppl==1.3.1
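If you want a script or notebook to detect the problematic Pyro release before sampling, the following is a minimal sketch. The version boundaries come only from this thread (1.3.1 works, 1.4.0 breaks, 1.5.0 contains the fix), and parse_version, pyro_is_affected and check_installed_pyro are hypothetical helpers. It assumes Python 3.8+ for importlib.metadata, and it is only relevant for sbi versions that do not yet include the transforms={} workaround discussed later in this thread.

```python
import re
from importlib import metadata  # standard library since Python 3.8


def parse_version(version):
    """'1.4.0' -> (1, 4, 0). Hypothetical helper; ignores suffixes such as 'rc1'."""
    nums = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        nums.append(int(match.group()) if match else 0)
    return tuple(nums + [0] * (3 - len(nums)))


def pyro_is_affected(version):
    # Based only on this thread: 1.3.1 works, 1.4.0 breaks, 1.5.0 has the fix.
    return (1, 4, 0) <= parse_version(version) < (1, 5, 0)


def check_installed_pyro():
    """True/False for the installed pyro-ppl, or None if it is not installed."""
    try:
        return pyro_is_affected(metadata.version("pyro-ppl"))
    except metadata.PackageNotFoundError:
        return None


print(check_installed_pyro())
```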

Let me know if this worked :)

Michael

@AlexSauer
Author

Yes, great!
Thank you again! :)

@michaeldeistler
Contributor

Glad to hear that! :)

I'll leave this issue open until we have found a more stable fix for this.

@jan-matthis changed the title from "Sampling for SNRE and MCMC using Pyro" to "Pyro v1.4.0 compatibility" on Aug 4, 2020
@jan-matthis added the "bug" (Something isn't working) label on Aug 4, 2020
@jan-matthis
Contributor

Commit f7444cc pins Pyro to v1.3.1 as a temporary solution.

@michaeldeistler linked a pull request on Aug 4, 2020 that will close this issue
@michaeldeistler
Contributor

I would rather not close this until we have gotten rid of transforms={}.

@jan-matthis
Contributor

Sure, we can keep it open

For context: #339 introduced transforms={} as a workaround for compatibility with Pyro 1.4.0. We might remove it later and will close this issue then.
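The traceback above hints at why passing transforms={} avoids the crash. The following toy model is written only from the lines visible in that traceback (api.py lines 412-414, hmc.py line 308, snre_base.py line 418); it is not Pyro's or sbi's code, and ToyKernel, toy_run, sbi_like_potential and run_and_report are hypothetical names. It assumes, based on the error message, that initial_params is still None at the point where the potential is evaluated. When the kernel has no transforms, run() calls setup() once more, and the potential receives None instead of a dict; a non-None transforms value skips that branch.

```python
class ToyKernel:
    """Stand-in for an HMC/NUTS kernel driven by a custom potential function.

    Hypothetical toy, not Pyro's code: it only mimics the traceback lines.
    """

    def __init__(self, potential_fn, transforms=None):
        self.potential_fn = potential_fn
        self.transforms = transforms
        self.initial_params = None  # assumption: still unset at this point

    def setup(self, warmup_steps):
        # Like hmc.py line 308: evaluate the potential at initial_params.
        return self.potential_fn(self.initial_params)


def toy_run(kernel):
    # Like api.py lines 412-414: the extra setup() call happens only
    # when the kernel has no transforms.
    if getattr(kernel, "transforms", None) is None:
        kernel.setup(warmup_steps=0)
    return "sampling would start here"


def sbi_like_potential(theta):
    # Like snre_base.py line 418: expects a dict mapping site names to values.
    theta = next(iter(theta.values()))
    return sum(t ** 2 for t in theta)


def run_and_report(kernel):
    try:
        return toy_run(kernel)
    except AttributeError as err:
        return f"AttributeError: {err}"


print(run_and_report(ToyKernel(sbi_like_potential)))
print(run_and_report(ToyKernel(sbi_like_potential, transforms={})))
```

The first call reproduces the shape of the reported error; the second returns normally because an empty dict is not None.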

@jan-matthis
Contributor

Pyro 1.5.0, which includes the bugfix, has been released. We decided to keep the workaround in place nevertheless, so that we don't have to change the Pyro requirement in setup.py and force everyone to upgrade.
