
Implementation of DCC inference algorithm #1715

Merged · 7 commits merged into pyro-ppl:master on Feb 22, 2024

Conversation

@treigerm (Contributor) commented Jan 8, 2024

This is an initial bare-bones implementation of the Divide, Conquer, Combine (DCC) inference algorithm for programs with stochastic support, as discussed in #1697. I have also included a simple example to show how to use the interface. This is only a very basic implementation of the algorithm, to keep the size of the PR reasonable. In its current form, the algorithm assumes that the branching inside the program is done based on the outcomes of discrete sampling sites which are annotated with infer={"branching": True} (see the sketch after the list below).

The algorithm then proceeds as follows:

  1. Discover different branches, a.k.a. straight-line programs (SLPs), by sampling from the prior.
  2. Run inference on each discovered branch separately.
  3. Combine the inference results by weighting each branch proportional to its marginal likelihood estimate.
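For illustration, here is a minimal sketch of what such a model and the DCC interface might look like. The model and data are made up, and the exact constructor and run signatures are my best guess at the interface in this PR; only the infer={"branching": True} annotation is taken directly from the description above.

```python
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.stochastic_support import DCC


def model(y):
    # Discrete site that controls branching; the annotation tells DCC
    # which sites define the straight-line programs (SLPs).
    m = numpyro.sample("m", dist.Bernoulli(0.5), infer={"branching": True})
    if m == 0:
        loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    else:
        loc = numpyro.sample("loc", dist.Normal(5.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=y)


y = 2.0 + random.normal(random.PRNGKey(1), (100,))
dcc = DCC(model, mcmc_kwargs=dict(num_warmup=500, num_samples=1000))
# The result holds per-SLP posterior samples plus the weights from step 3.
result = dcc.run(random.PRNGKey(0), y)
```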

This is just a draft PR for now, to see whether the general approach for the implementation makes sense. If the approach is sensible, I will add a more detailed example, some tests, and documentation. The main questions that need to be answered at the moment are:

  • Does the general interface for the algorithm seem sensible?
  • Did I place the new implementation in the correct location (I put it in the contrib folder for now)?
  • Am I using the NumPyro primitives correctly?

@treigerm treigerm marked this pull request as draft January 8, 2024 08:38
@fehiepsi (Member)
Thanks for the contribution, @treigerm! The draft looks great.

Re location: yeah, it makes sense to put the algorithm in contrib. We can move it to infer in the future when the API is solid.
Re primitives: yes, the usage looks correct to me.

@treigerm treigerm marked this pull request as ready for review January 17, 2024 10:01
@treigerm (Contributor, Author)
I have now added tests and documentation. Please let me know if you think anything is missing!

After this PR is merged, my plan is to add the SDVI algorithm, and once that is in I can write a tutorial on how to use these two inference algorithms and their respective trade-offs.

@treigerm treigerm changed the title [WIP] Implementation of DCC inference algorithm Implementation of DCC inference algorithm Jan 17, 2024
"""
Weight each SLP proportional to its estimated normalization constant.
The normalization constants are estimated using importance sampling with
the proposal centered on the MCMC samples.
@fehiepsi (Member)
Is this standard Gaussian a good practical choice for the proposal? Looking at the paper, it seems that the authors used a Metropolis-within-Gibbs sampler.

@treigerm (Contributor, Author) commented Jan 29, 2024

Thank you for spending the time reviewing the PR!

Note that the Gaussian is centered on the MCMC samples (more precisely, each MCMC sample gives rise to a single proposal distribution). As long as the MCMC chain(s) are well mixed, this generally leads to good proposals. This is also what the paper describes (and what the authors' implementation, which I received upon request, does).

> Looking at the paper, it seems that the authors used a Metropolis-within-Gibbs sampler.

Actually, the Metropolis-within-Gibbs sampler is only used for the local inference tasks. For many models it isn't a very efficient inference algorithm because it only updates one variable at a time. Because the implementation here assumes that the branching is only done based on the outcomes of discrete sampling statements, it can use more efficient algorithms for local inference (like HMC or NUTS).
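For concreteness, here is a rough standalone sketch of this kind of estimator (not the code in this PR): each MCMC draw centers one Gaussian proposal, one sample is drawn from each proposal, and the normalization constant is estimated by the average importance weight. The function name, the log_joint callable, and the shapes are all illustrative.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def estimate_log_normalizer(rng_key, log_joint, mcmc_samples, scale=1.0):
    # mcmc_samples: (num_samples, dim) array of well-mixed MCMC draws.
    # One Gaussian proposal per draw and one importance sample per
    # proposal, which keeps the proposals local to the posterior mass.
    draws = mcmc_samples + scale * random.normal(rng_key, mcmc_samples.shape)
    # Log density of each draw under its own proposal.
    log_q = norm.logpdf(draws, loc=mcmc_samples, scale=scale).sum(axis=-1)
    log_w = jnp.stack([log_joint(z) for z in draws]) - log_q
    # log Z-hat = log of the mean importance weight, computed stably.
    return logsumexp(log_w) - jnp.log(draws.shape[0])
```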

@fehiepsi (Member)

I see, thanks. How about using the sample variance instead of unit variance?

@treigerm (Contributor, Author)

I have just updated the PR to make the scale in the proposal a parameter that can be set by the user. I agree that unit variance is probably not always desirable. However, I'm not sure whether the sample variance is desirable either. The main idea behind the algorithm for estimating the normalization constant (which is described in more detail in another paper, Layered Adaptive Importance Sampling) is that centering a proposal on top of each sample keeps the proposals fairly local. If there are multiple modes in the posterior, then using the sample variance could result in lots of proposed samples in the low-density regions between the modes.

I'd still be open to adding the option to use the sample variance, but this is currently complicated by the way the AutoNormal guide is implemented. You would want to compute the sample variance for each individual variable in the program, but as far as I can tell there is no way to set variable-specific variances in the AutoNormal guide (it is possible to set variable-specific means, though). So this might be a feature that would be reasonable to add at a later time?

@fehiepsi (Member)

Sorry for the late response! Thank you for the insights. I don't have a strong opinion on whether it's helpful to expose init_scale in AutoNormal. There is another way to substitute the sample variance, like substitute(guide, data={"auto_foo_scale": ...}), but we need to be careful about the domain of such a variable (it needs to be unconstrained).
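A minimal sketch of that substitute approach, assuming a toy model with a single latent site "theta" (the site name and the stand-in samples are hypothetical; AutoNormal names its per-site parameters "auto_<site>_loc" and "auto_<site>_scale"):

```python
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import substitute
from numpyro.infer.autoguide import AutoNormal


def model():
    numpyro.sample("theta", dist.Normal(0.0, 1.0))


# Stand-in for MCMC draws of "theta", already mapped to the unconstrained
# space (the domain caveat above: AutoNormal parametrizes that space).
unconstrained_samples = {"theta": jnp.array([0.1, -0.3, 0.5, 0.2])}
theta_scale = jnp.std(unconstrained_samples["theta"], axis=0)

# Pin the guide's scale parameter for "theta" to the sample value.
guide = AutoNormal(model)
guide = substitute(guide, data={"auto_theta_scale": theta_scale})
```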

@treigerm (Contributor, Author)

Ah, I see. I guess this would require knowledge of all the variable names in the program, though (but those could be extracted automatically)? For now I would lean towards leaving the implementation as is, to keep it simple, if that is okay.

@fehiepsi (Member) left a comment

LGTM pending linting issues. Thanks for the great contribution, @treigerm!

@fehiepsi fehiepsi added this to the 0.14 milestone Feb 16, 2024
@treigerm (Contributor, Author)

Thanks, @fehiepsi! I'm currently unsure why the CI tests are failing; it seems to be an import error, where it is not able to find the new numpyro.contrib.stochastic_support.dcc module. For me, the tests pass locally if I run XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/stochastic_support/test_dcc.py, so I'm not quite sure where the issue is coming from.

@fehiepsi (Member)

How about adding an __init__ file and exposing DCC etc. there? Then you can import

```python
from numpyro.contrib.stochastic_support import DCC
```
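For reference, a minimal version of such an __init__.py might look like the following, assuming DCC is defined in the dcc module named in the failing import above:

```python
# numpyro/contrib/stochastic_support/__init__.py
from numpyro.contrib.stochastic_support.dcc import DCC

__all__ = ["DCC"]
```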

@treigerm (Contributor, Author)

Ah yes, of course! I have added an __init__ file now. I also added the XLA_FLAGS="--xla_force_host_platform_device_count=2" flag to the contrib module tests so that the parallel chain sampling method passes. Tests are passing locally for me.

```diff
@@ -102,7 +102,7 @@ jobs:
       run: |
         pytest -vs --durations=20 test/infer/test_mcmc.py
         pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py
-        pytest -vs --durations=20 test/contrib
+        XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs --durations=20 test/contrib
```
@fehiepsi (Member)

Could you ignore the DCC test here and move it to the test chain below instead? Thanks!

@treigerm (Contributor, Author)

Done!

@fehiepsi fehiepsi merged commit c4ca3d8 into pyro-ppl:master Feb 22, 2024
4 checks passed
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request Feb 27, 2024
* Initial bare bones implementation of DCC

* Add tests and documentation

* Make scale in Normal proposal configurable

* Run linter

* Add __init__.py file and allow parallel inference in tests

* Move DCC tests to 'test chains' group
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request May 6, 2024