
Add ArviZ integration #542

Closed
sethaxen opened this issue Sep 22, 2021 · 16 comments · Fixed by #607
Labels
enhancement New feature or request hackathon

Comments

@sethaxen
Contributor

sethaxen commented Sep 22, 2021

As suggested by @Meteore, the ArviZ package has a large number of MCMC diagnostics, statistics, and visualizations. See for example the gallery. It provides diagnostics/plots to PyMC3 but is PPL-agnostic.

It would be good to include here a converter to an arviz.InferenceData, which would automatically allow users to apply these diagnostics.
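The conversion described above mainly amounts to reshaping sbi's flat sample array into the `(chain, draw, *shape)` layout that ArviZ expects. A minimal sketch of that layout (array sizes and the variable name `theta` are illustrative, not from sbi; the commented line shows where `arviz.from_dict` would come in):

```python
import numpy as np

# Hypothetical: flat MCMC draws from an sbi posterior, all chains concatenated.
num_chains, num_draws, dim = 4, 250, 2
rng = np.random.default_rng(0)
flat = rng.normal(size=(num_chains * num_draws, dim))

# ArviZ's from_dict expects arrays shaped (chain, draw, *param_dims).
posterior_draws = flat.reshape(num_chains, num_draws, dim)

# With arviz installed, this would build an InferenceData object:
# idata = az.from_dict(posterior={"theta": posterior_draws})
print(posterior_draws.shape)
```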

@sethaxen
Contributor Author

Given that ArviZ already has a converter from Pyro, this would probably be really easy to do: https://arviz-devs.github.io/arviz/api/generated/arviz.from_pyro.html

@alvorithm alvorithm added the enhancement New feature or request label Sep 22, 2021
@michaeldeistler
Contributor

I agree, it would be great to support ArviZ.

@sethaxen
Contributor Author

Okay! I'm happy to open a PR.

@michaeldeistler
Contributor

Great, thank you!

@sethaxen
Contributor Author

LikelihoodBasedPosterior.sample just returns the samples without any of the sample statistics (e.g. divergences and log probability). Several questions:

  1. Can MCMC be run with multiple chains in parallel?
  2. Is it possible to have the fitted pyro MCMC object returned to the user? This would give the user access to the sampling statistics (and potentially allow them to resume sampling, but I don't know if this is a Pyro-supported feature).
  3. If so, does sbi handle transformations of constrained parameters to an unconstrained space itself, or does it rely on pyro to do that? The latter would be in a sense more convenient, because then IIUC pyro's MCMC object would already return draws in the constrained space.

@michaeldeistler
Contributor

michaeldeistler commented Sep 22, 2021

  1. Yes. `mcmc_parameters={"num_chains": 10}` will run multiple chains. If you use slice sampling (NumPy-based), the chains can also be vectorized with `mcmc_method="slice_np_vectorized"` together with `mcmc_parameters={"num_chains": 10}`
  2. That would in principle be possible. I would prefer not returning it by default, just to avoid such a major API change (so maybe let's add a flag return_sampler: bool = False to the .sample() method?) Alternatively, we could also save the sampler as an attribute of the posterior class (i.e. self.sampler(...) here).
  3. Unfortunately, pyro does not handle these transformations itself when a potential_fn() is used; it can only infer the transform if a model() is provided. At least I could not figure out how one would do it with a potential_fn(). We do the transforms ourselves here. I had created an issue on pyro about this here.

Hope this helps!
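The two configurations in point 1 can be written out as plain keyword arguments. This is only an illustration of the call shape described above; the keyword names follow the sbi API as quoted in this thread, and the commented `.sample()` call is hypothetical:

```python
# Serial slice sampling: chains run one after another.
serial_cfg = dict(
    mcmc_method="slice_np",
    mcmc_parameters={"num_chains": 10},
)

# Vectorized slice sampling: all chains advance in lockstep.
vectorized_cfg = dict(
    mcmc_method="slice_np_vectorized",
    mcmc_parameters={"num_chains": 10},
)

# With sbi installed, one of these would be spread into the sample call:
# samples = posterior.sample((1000,), x=observation, **vectorized_cfg)
print(serial_cfg["mcmc_method"], vectorized_cfg["mcmc_method"])
```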

@sethaxen
Contributor Author

sethaxen commented Sep 22, 2021

1. Yes. `mcmc_parameters={"num_chains": 10}` will run multiple chains.

perfect!

2. That would in principle be possible. I would prefer not returning it by default, just to avoid such a major API change (so maybe let's add a flag `return_sampler: bool = False` to the `.sample()` method?) Alternatively, we could also save the sampler as an attribute of the posterior class (i.e. `self.sampler(...)` [here](https://github.com/mackelab/sbi/blob/65b9873d2cab0a1954c0203833a99f8140dbba99/sbi/inference/posteriors/base_posterior.py#L612)).

I agree, the current API should be kept. The latter option is nice (storing the MCMC object as self.sampler). Pyro's MCMC also takes a thinning parameter https://num.pyro.ai/en/stable/mcmc.html#numpyro.infer.mcmc.MCMC. Is there a reason sbi does the thinning itself?

3. Unfortunately, pyro does not handle these transformations itself when a `potential_fn()` is used, it can only infer the transform if a `model()` is provided. At least I could not figure out how one would do it with a `potential_fn()`. We do the transforms ourselves [here](https://github.com/mackelab/sbi/blob/65b9873d2cab0a1954c0203833a99f8140dbba99/sbi/inference/posteriors/likelihood_based_posterior.py#L187).

Hm, that is unfortunate. I need to look more carefully at pyro.

Hope this helps!

yes, very helpful! Thanks! I'm learning both sbi and pyro at the same time, so there will likely be more questions.
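For reference, the thinning being discussed is just keeping every k-th draw; the docs linked above are numpyro's, whose `MCMC` accepts this directly as a `thinning` argument, whereas doing it by hand looks like this (the array here is a stand-in for real MCMC draws):

```python
import numpy as np

# Stand-in for 100 correlated MCMC draws of a 1-D parameter.
draws = np.arange(100).reshape(100, 1)

# Manual thinning: keep every `thin`-th draw to reduce autocorrelation.
thin = 5
thinned = draws[::thin]

print(len(thinned))  # 20
```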

@michaeldeistler
Contributor

michaeldeistler commented Sep 22, 2021

  1. Nope, there's no reason not to use pyro's thinning here. Feel free to use it.
  2. Yeah, I still think this is a bug but I don't know pyro well either (which is why I stopped bothering them about this ;) ).

@sethaxen
Contributor Author

sethaxen commented Sep 22, 2021

  2. Yeah, I still think this is a bug but I don't know pyro well either (which is why I stopped bothering them about this ;) ).

Perhaps this could be worked around with a helper function for constructing a Distribution from the LikelihoodBasedPosterior object. Then sbi would define a Pyro model instead of using potential_fn(), passing the transform defined by sbi.

A fringe benefit of this is that in principle one could use the fitted posterior as a prior in a Pyro model for Bayesian updating.
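A rough sketch of that helper, under the assumption that the posterior exposes a log-prob and a sampling function (the class name and constructor arguments are invented for illustration; a standard normal stands in for the fitted sbi posterior in the toy check):

```python
import torch
from torch.distributions import Distribution, constraints


class PosteriorAsDistribution(Distribution):
    """Hypothetical wrapper exposing an sbi posterior's log-prob and
    sampler as a torch Distribution, so pyro could use it in a model()
    (or as a prior for Bayesian updating) instead of a bare potential_fn."""

    arg_constraints = {}
    support = constraints.real

    def __init__(self, log_prob_fn, sample_fn, event_dim):
        self._log_prob_fn = log_prob_fn
        self._sample_fn = sample_fn
        super().__init__(event_shape=torch.Size([event_dim]), validate_args=False)

    def log_prob(self, value):
        return self._log_prob_fn(value)

    def sample(self, sample_shape=torch.Size()):
        return self._sample_fn(sample_shape)


# Toy check: a 2-D standard normal standing in for the fitted posterior.
ref = torch.distributions.Normal(0.0, 1.0)
dist = PosteriorAsDistribution(
    log_prob_fn=lambda v: ref.log_prob(v).sum(-1),
    sample_fn=lambda shape: ref.sample(shape + torch.Size([2])),
    event_dim=2,
)
print(float(dist.log_prob(torch.zeros(2))))
```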

@michaeldeistler
Contributor

michaeldeistler commented Sep 22, 2021

Yes, I think this could have worked. We added the transforms well after relying on pyro's potential_fn, so there's definitely some technical debt here. Your suggestion sounds very reasonable; we might want to do this in the future.

@sethaxen
Contributor Author

Alright, in the interest of picking the low-hanging fruit first, my proposal is to:

  1. pass the thinning keyword to MCMC
  2. make the sampler a stored attribute of LikelihoodBasedPosterior (at least there; it may make sense for other classes too)

That should be sufficient for allowing users to use ArviZ to diagnose model problems in the unconstrained space. Then later we could potentially do something like #542 (comment) so that the users get the sampler in the constrained space.
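Proposal 2 boils down to a one-line change plus an attribute. A minimal sketch of the idea, with invented names throughout (`run_mcmc` stands in for the real pyro/NumPy sampling code, and the `az.from_pyro` line shows the intended downstream use):

```python
class LikelihoodBasedPosteriorSketch:
    """Minimal sketch of proposal 2: keep the fitted MCMC object around
    on the posterior so users can hand it to ArviZ afterwards."""

    def __init__(self):
        self.sampler = None  # populated after the first .sample() call

    def sample(self, num_samples, run_mcmc):
        mcmc = run_mcmc(num_samples)  # hypothetical: returns a fitted sampler
        self.sampler = mcmc           # stored for later diagnostics
        return mcmc["samples"]


posterior = LikelihoodBasedPosteriorSketch()
draws = posterior.sample(10, lambda n: {"samples": list(range(n))})
# With arviz and a real pyro sampler:
# idata = az.from_pyro(posterior.sampler)
print(len(draws), posterior.sampler is not None)
```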

@michaeldeistler
Contributor

Sounds good to me. Regarding 2:
If you make it an attribute here, i.e. in the BasePosterior, then all methods will inherit the attribute (which I think is desirable). I'd also set the numpy-based samplers as an attribute, i.e. here rename posterior_sampler to self.sampler

@sethaxen
Contributor Author

I'd also set the numpy-based samplers as an attribute, i.e. here rename posterior_sampler to self.sampler

What are your thoughts on this particular case, where there's not just one sampler but a vector of samplers (one per chain)? I see several ways of handling this:

  1. have self.sampler return either a sampler or vector of samplers (this case)
  2. have self.samplers instead, which would return a vector of length one for all samplers except this one (downside: the vector has length one everywhere except this one case, where it has one entry per chain)
  3. introduce something like SliceSamplerSerial that has the same interface as SliceSamplerVectorized but internally loops over the chains and calls SliceSampler. Then this code would use SliceSamplerSerial.

To me (3) seems cleanest. What do you think?

@michaeldeistler
Contributor

michaeldeistler commented Sep 24, 2021

Just to make sure that I understand proposition 3 correctly:

you would move this loop into the new class SliceSamplerVectorized, right? I like the idea because it would simplify this entire if-else-case.

@sethaxen
Contributor Author

you would move this loop into the new class SliceSamplerVectorized, right? I like the idea because it would simplify this entire if-else-case.

SliceSamplerVectorized already exists, and it would be misleading for a class with that name to loop internally. My thought was to add a SliceSamplerSerial that would behave like SliceSamplerVectorized but use a loop instead. One easy way to do this would be to add something like SliceSamplerMultiChainBase that implements the methods shared by the two classes, and have both SliceSamplerSerial and SliceSamplerVectorized subclass it.

Alternatively, one could have a SliceSamplerMultiChain that has the linked if-else statement and deprecate SliceSamplerVectorized, but if the latter class is part of the API, this is not ideal.
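The class layout in the first option might look roughly like this. All names beyond SliceSamplerSerial and the shared base are illustrative, and the trivial per-chain "sampler" only demonstrates the interface, not slice sampling itself:

```python
import numpy as np


class SliceSamplerMultiChainBase:
    """Hypothetical shared base: both subclasses expose
    run(...) -> array of shape (num_chains, num_samples, dim)."""

    def __init__(self, log_prob_fn, init_params):
        self.log_prob_fn = log_prob_fn
        self.init_params = np.atleast_2d(init_params)  # (num_chains, dim)
        self.num_chains, self.dim = self.init_params.shape


class SliceSamplerSerial(SliceSamplerMultiChainBase):
    def run(self, num_samples, single_chain_sampler):
        # Loop over chains, delegating each to a one-chain sampler
        # (sbi's existing SliceSampler would fill this role).
        chains = [
            single_chain_sampler(self.log_prob_fn, x0, num_samples)
            for x0 in self.init_params
        ]
        return np.stack(chains)  # (num_chains, num_samples, dim)


# Toy single-chain "sampler" that just repeats the initial point:
toy = lambda log_prob, x0, n: np.tile(x0, (n, 1))
sampler = SliceSamplerSerial(lambda x: 0.0, np.zeros((3, 2)))
print(sampler.run(5, toy).shape)  # (3, 5, 2)
```

A vectorized subclass would implement the same `run` signature but advance all chains in lockstep, which is what makes the shared base worthwhile.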

@michaeldeistler
Contributor

michaeldeistler commented Sep 27, 2021

Sorryyyy, I meant SliceSamplerSerial, not SliceSamplerVectorized. So yeah, I completely agree with your original suggestion.
