Implementation of DCC inference algorithm #1715 (merged)

fehiepsi merged 7 commits into pyro-ppl:master from treigerm:initial_dcc_implementation on Feb 22, 2024.
Changes from 2 commits (the diff below does not reflect the later commits).

Commits:
2ceb88d  Initial bare bones implementation of DCC (treigerm)
0450340  Add tests and documentation (treigerm)
d1a435b  Make scale in Normal proposal configurable (treigerm)
9a57c4a  Merge branch 'master' into initial_dcc_implementation (treigerm)
e60215f  Run linter (treigerm)
b09a7f7  Add __init__.py file and allow parallel inference in tests (treigerm)
00068d0  Move DCC tests to 'test chains' group (treigerm)
New file (171 lines added; the tests below import it as numpyro.contrib.stochastic_support.dcc):

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict, namedtuple

import jax
import jax.numpy as jnp
from jax import random

import numpyro.distributions as dist
from numpyro.handlers import condition, seed, trace
from numpyro.infer import MCMC, NUTS
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.util import init_to_value, log_density

DCCResult = namedtuple("DCCResult", ["samples", "slp_weights"])


class DCC:
    """
    Implements the Divide, Conquer, and Combine (DCC) algorithm for models with
    stochastic support from [1].

    .. note:: This implementation assumes that all stochastic branching is done based
        on the outcomes of discrete sampling sites that are annotated with
        `infer={"branching": True}`. For example,

        .. code-block:: python

            def model():
                model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
                if model1 == 0:
                    mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
                else:
                    mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
                numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)

    **References:**

    1. *Divide, Conquer, and Combine: a New Inference Strategy for Probabilistic
       Programs with Stochastic Support*,
       Yuan Zhou, Hongseok Yang, Yee Whye Teh, Tom Rainforth

    :param model: Python callable containing Pyro primitives :mod:`~numpyro.primitives`.
    :param dict mcmc_kwargs: Dictionary of arguments passed to :data:`~numpyro.infer.MCMC`.
    :param numpyro.infer.mcmc.MCMCKernel kernel_cls: MCMC kernel class that is used for
        local inference. Defaults to :class:`~numpyro.infer.NUTS`.
    :param int num_slp_samples: Number of samples to draw from the prior to discover the
        straight-line programs (SLPs).
    :param int max_slps: Maximum number of SLPs to discover. DCC will not run inference
        on more than `max_slps`.
    """

    def __init__(
        self,
        model,
        mcmc_kwargs,
        kernel_cls=NUTS,
        num_slp_samples=1000,
        max_slps=124,
    ):
        self.model = model
        self.kernel_cls = kernel_cls
        self.mcmc_kwargs = mcmc_kwargs

        self.num_slp_samples = num_slp_samples
        self.max_slps = max_slps

    def _find_slps(self, rng_key, *args, **kwargs):
        """
        Discover the straight-line programs (SLPs) in the model by sampling from the prior.
        This implementation assumes that all branching is done via discrete sampling sites
        that are annotated with `infer={"branching": True}`.
        """
        branching_traces = {}
        for _ in range(self.num_slp_samples):
            rng_key, subkey = random.split(rng_key)
            tr = trace(seed(self.model, subkey)).get_trace(*args, **kwargs)
            btr = self._get_branching_trace(tr)
            btr_str = ",".join(str(x) for x in btr.values())
            if btr_str not in branching_traces:
                branching_traces[btr_str] = btr
                if len(branching_traces) >= self.max_slps:
                    break

        return branching_traces

    def _get_branching_trace(self, tr):
        """
        Extract the sites from the trace that are annotated with `infer={"branching": True}`.
        """
        branching_trace = OrderedDict()
        for site in tr.values():
            if site["type"] == "sample" and site["infer"].get("branching", False):
                if (
                    not isinstance(site["fn"], dist.Distribution)
                    or not site["fn"].support.is_discrete
                ):
                    raise RuntimeError(
                        "Branching is only supported for discrete sampling sites."
                    )
                # It is essential that we convert the value to a Python int. If it remains
                # a JAX Array, then during JIT compilation it will be treated as an
                # AbstractArray, which means branching will result in an error.
                # Reference: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit
                branching_trace[site["name"]] = int(site["value"])
        return branching_trace

    def _run_mcmc(self, rng_key, branching_trace, *args, **kwargs):
        """
        Run MCMC on the model conditioned on the given branching trace.
        """
        slp_model = condition(self.model, data=branching_trace)
        kernel = self.kernel_cls(slp_model)
        mcmc = MCMC(kernel, **self.mcmc_kwargs)
        mcmc.run(rng_key, *args, **kwargs)

        return mcmc.get_samples()

    def _combine_samples(self, rng_key, samples, branching_traces, *args, **kwargs):
        """
        Weight each SLP proportional to its estimated normalization constant.
        The normalization constants are estimated using importance sampling with
        the proposal centered on the MCMC samples.
        """

        def log_weight(rng_key, i, slp_model, slp_samples):
            trace = {k: v[i] for k, v in slp_samples.items()}
            guide = AutoNormal(
                slp_model,
                init_loc_fn=init_to_value(values=trace),
                init_scale=1.0,
            )
            rng_key, subkey = random.split(rng_key)
            guide_trace = seed(guide, subkey)(*args, **kwargs)
            guide_log_density, _ = log_density(guide, args, kwargs, guide_trace)
            model_log_density, _ = log_density(slp_model, args, kwargs, guide_trace)
            return model_log_density - guide_log_density

        log_weights = jax.vmap(log_weight, in_axes=(None, 0, None, None))

        log_Zs = {}
        for bt, slp_samples in samples.items():
            num_samples = slp_samples[next(iter(slp_samples))].shape[0]
            slp_model = condition(self.model, data=branching_traces[bt])
            lws = log_weights(rng_key, jnp.arange(num_samples), slp_model, slp_samples)
            log_Zs[bt] = jax.scipy.special.logsumexp(lws) - jnp.log(num_samples)

        normalizer = jax.scipy.special.logsumexp(jnp.array(list(log_Zs.values())))
        slp_weights = {k: jnp.exp(v - normalizer) for k, v in log_Zs.items()}
        return DCCResult(samples, slp_weights)

    def run(self, rng_key, *args, **kwargs):
        """
        Run DCC and collect samples for all SLPs.

        :param jax.random.PRNGKey rng_key: Random number generator key.
        :param args: Arguments to the model.
        :param kwargs: Keyword arguments to the model.
        """
        rng_key, subkey = random.split(rng_key)
        branching_traces = self._find_slps(subkey, *args, **kwargs)

        samples = dict()
        for key, bt in branching_traces.items():
            rng_key, subkey = random.split(rng_key)
            samples[key] = self._run_mcmc(subkey, bt, *args, **kwargs)

        rng_key, subkey = random.split(rng_key)
        return self._combine_samples(subkey, samples, branching_traces, *args, **kwargs)
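
As an aside, here is a minimal usage sketch assembled from the docstring above and the tests below. The import path `numpyro.contrib.stochastic_support.dcc` is the one the tests use; the model and the numbers are illustrative only, not part of the PR:

```python
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.stochastic_support.dcc import DCC


def model(y):
    z = numpyro.sample("z", dist.Normal(0.0, 1.0))
    # The discrete site that selects the branch must be annotated as a branching site.
    m = numpyro.sample("m", dist.Bernoulli(0.5), infer={"branching": True})
    sigma = 1.0 if m == 0 else 2.0
    with numpyro.plate("data", y.shape[0]):
        numpyro.sample("obs", dist.Normal(z, sigma), obs=y)


y = dist.Normal(0, 1).sample(random.PRNGKey(1), (200,))
dcc = DCC(model, mcmc_kwargs=dict(num_warmup=500, num_samples=1000, num_chains=1))
result = dcc.run(random.PRNGKey(0), y)
# result.samples maps each discovered SLP (keyed by the comma-joined values of its
# branching sites, e.g. "0" and "1" here) to that SLP's MCMC samples;
# result.slp_weights maps the same keys to the normalized SLP weights.
print(result.slp_weights)
```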
New file (169 lines added): tests for the DCC implementation.

import math

import jax
import jax.numpy as jnp
import pytest
from jax import random
from numpy.testing import assert_allclose

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.stochastic_support.dcc import DCC
from numpyro.infer import HMC, NUTS, SA, BarkerMH


@pytest.mark.parametrize(
    "branch_dist",
    [dist.Normal(0, 1), dist.Gamma(1, 1)],
)
@pytest.mark.xfail(raises=RuntimeError)
def test_continuous_branching(branch_dist):
    rng_key = random.PRNGKey(0)

    def model():
        model1 = numpyro.sample("model1", branch_dist, infer={"branching": True})
        mean = 1.0 if model1 == 0 else 2.0
        numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)

    mcmc_kwargs = dict(
        num_warmup=500,
        num_samples=1000,
        num_chains=1,
    )

    dcc = DCC(model, mcmc_kwargs=mcmc_kwargs)
    rng_key, subkey = random.split(rng_key)
    dcc.run(subkey)


def test_different_address_path():
    rng_key = random.PRNGKey(0)

    def model():
        model1 = numpyro.sample(
            "model1", dist.Bernoulli(0.5), infer={"branching": True}
        )
        if model1 == 0:
            numpyro.sample("a1", dist.Normal(9.0, 1.0))
        else:
            numpyro.sample("a2", dist.Normal(9.0, 1.0))
            numpyro.sample("a3", dist.Normal(9.0, 1.0))
        mean = 1.0 if model1 == 0 else 2.0
        numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)

    mcmc_kwargs = dict(
        num_warmup=50,
        num_samples=50,
        num_chains=1,
        progress_bar=False,
    )

    dcc = DCC(model, mcmc_kwargs=mcmc_kwargs)
    rng_key, subkey = random.split(rng_key)
    dcc.run(subkey)


@pytest.mark.parametrize(
    "chain_method",
    ["sequential", "parallel", "vectorized"],
)
@pytest.mark.parametrize("kernel_cls", [NUTS, HMC, SA, BarkerMH])
def test_kernels(chain_method, kernel_cls):
    if chain_method == "vectorized" and kernel_cls in [SA, BarkerMH]:
        # These methods do not support vectorized execution.
        return

    def model(y):
        z = numpyro.sample("z", dist.Normal(0.0, 1.0))
        model1 = numpyro.sample(
            "model1", dist.Bernoulli(0.5), infer={"branching": True}
        )
        sigma = 1.0 if model1 == 0 else 2.0
        with numpyro.plate("data", y.shape[0]):
            numpyro.sample("obs", dist.Normal(z, sigma), obs=y)

    rng_key = random.PRNGKey(0)

    rng_key, subkey = random.split(rng_key)
    y_train = dist.Normal(0, 1).sample(subkey, (200,))

    mcmc_kwargs = dict(
        num_warmup=50,
        num_samples=50,
        num_chains=2,
        chain_method=chain_method,
        progress_bar=False,
    )

    dcc = DCC(model, mcmc_kwargs=mcmc_kwargs, kernel_cls=kernel_cls)
    rng_key, subkey = random.split(rng_key)
    dcc.run(subkey, y_train)


def test_weight_convergence():
    PRIOR_MEAN, PRIOR_STD = 0.0, 1.0
    LIKELIHOOD1_STD = 2.0
    LIKELIHOOD2_STD = 0.62177

    def log_marginal_likelihood(data, likelihood_std, prior_mean, prior_std):
        """
        Calculate the marginal likelihood of a model with Normal likelihood, unknown mean,
        and Normal prior.

        Taken from Section 2.5 at https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf.
        """
        num_data = data.shape[0]
        likelihood_var = jnp.square(likelihood_std)
        prior_var = jnp.square(prior_std)

        first_term = (
            jnp.log(likelihood_std)
            - num_data * jnp.log(jnp.sqrt(2 * math.pi) * likelihood_std)
            + 0.5 * jnp.log(num_data * prior_var + likelihood_var)
        )
        second_term = -(jnp.sum(jnp.square(data)) / (2 * likelihood_var)) - (
            jnp.square(prior_mean) / (2 * prior_var)
        )
        third_term = (
            (
                prior_var
                * jnp.square(num_data)
                * jnp.square(jnp.mean(data))
                / likelihood_var
            )
            + (likelihood_var * jnp.square(prior_mean) / prior_var)
            + 2 * num_data * jnp.mean(data) * prior_mean
        ) / (2 * (num_data * prior_var + likelihood_var))
        return first_term + second_term + third_term

    def model(y):
        z = numpyro.sample("z", dist.Normal(PRIOR_MEAN, PRIOR_STD))
        model1 = numpyro.sample(
            "model1", dist.Bernoulli(0.5), infer={"branching": True}
        )
        sigma = LIKELIHOOD1_STD if model1 == 0 else LIKELIHOOD2_STD
        with numpyro.plate("data", y.shape[0]):
            numpyro.sample("obs", dist.Normal(z, sigma), obs=y)

    rng_key = random.PRNGKey(0)

    rng_key, subkey = random.split(rng_key)
    y_train = dist.Normal(0, 1).sample(subkey, (200,))

    mcmc_kwargs = dict(
        num_warmup=500,
        num_samples=1000,
        num_chains=1,
    )

    dcc = DCC(model, mcmc_kwargs=mcmc_kwargs)
    rng_key, subkey = random.split(rng_key)
    dcc_result = dcc.run(subkey, y_train)
    slp_weights = jnp.array([dcc_result.slp_weights["0"], dcc_result.slp_weights["1"]])
    assert_allclose(1.0, jnp.sum(slp_weights))

    slp1_lml = log_marginal_likelihood(y_train, LIKELIHOOD1_STD, PRIOR_MEAN, PRIOR_STD)
    slp2_lml = log_marginal_likelihood(y_train, LIKELIHOOD2_STD, PRIOR_MEAN, PRIOR_STD)
    lmls = jnp.array([slp1_lml, slp2_lml])
    analytic_weights = jnp.exp(lmls - jax.scipy.special.logsumexp(lmls))
    assert_allclose(analytic_weights, slp_weights, rtol=1e-5, atol=1e-8)
Conversations

Reviewer:
Is this standard Gaussian a good practical choice for the proposal? Looking at the paper, it seems that the authors used a Metropolis-within-Gibbs sampler.
Author:
Thank you for taking the time to review the PR!

Note that the Gaussian is centered on the MCMC samples (more precisely, each MCMC sample gives rise to a single proposal distribution). As long as the MCMC chain(s) are well mixed, this generally leads to good proposals. This is also what the paper describes (and what the authors' implementation, which I received upon request, does).

Actually, the Metropolis-within-Gibbs sampler is only used for the local inference tasks. For many models it isn't a very efficient inference algorithm because it only updates one variable at a time. Because the implementation here assumes that branching is only done based on the outcomes of discrete sampling statements, it can use more efficient algorithms for local inference (like HMC or NUTS).
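
In equations, my reading of `_combine_samples` above is roughly the following (notation is mine, not from the PR, and the proposals live in the guide's unconstrained space):

$$
\log \hat{Z}_k = \operatorname{logsumexp}_{i=1,\dots,N}\big(\log p_k(\tilde{\theta}_i, y) - \log q_i(\tilde{\theta}_i)\big) - \log N,
\qquad \tilde{\theta}_i \sim q_i = \mathcal{N}(\theta_i, \sigma^2 I),
$$

where $\theta_1, \dots, \theta_N$ are the MCMC samples for SLP $k$ and $\sigma$ is the proposal scale (fixed to 1.0 in this revision, made configurable in commit d1a435b). The reported weights are then $w_k = \exp(\log \hat{Z}_k - \operatorname{logsumexp}_j \log \hat{Z}_j)$.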
Reviewer:
I see, thanks. How about using sample variance instead of unit variance?
Author:
I have just updated the PR to make the scale in the proposal a parameter that can be set by the user. I agree that unit variance is probably not always desirable. However, I'm not sure whether the sample variance is desirable either. The main idea behind the algorithm for estimating the normalization constant (described in more detail in another paper, Layered Adaptive Importance Sampling) is that the proposal on top of each sample stays fairly local. If there are multiple modes in the posterior, then using the sample variance could result in lots of proposed samples in the low-density regions between the modes.

I'd still be open to adding the option to use the sample variance, but this is currently complicated by the way the `AutoNormal` guide is implemented. You would want to compute the sample variance for each individual variable in the program, but as far as I can tell there is no way to set variable-specific variances in the `AutoNormal` guide (it is possible to set variable-specific means, though). So this might be a feature that would be reasonable to add at a later time?
Reviewer:
Sorry for the late response! Thank you for the insights. I don't have a strong opinion on whether it's helpful to expose `init_scale` in AutoNormal. There is another way to substitute the sample variance, like `substitute(guide, data={"auto_foo_scale": ...})`, but we need to be careful about the domain (it needs to be unconstrained) of such a `foo` variable.
Author:
Ah, I see. I guess this would require knowledge of all the variable names in the program, though (but that could be extracted automatically)? For now I would lean towards leaving the implementation as is to keep it simple, if that is okay.
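
For concreteness, a purely hypothetical sketch of the per-variable sample scales discussed in this thread (not part of this PR; the helper name `per_site_scales` is invented, and how its output would be wired into `AutoNormal`, for example via `substitute` on the guide's scale parameters while respecting the unconstrained-domain caveat above, is deliberately left open):

```python
import jax.numpy as jnp


def per_site_scales(slp_samples):
    """Per-site sample standard deviations computed from one SLP's MCMC samples.

    slp_samples maps site names to arrays of shape (num_samples, ...), as returned
    by MCMC.get_samples() for a single SLP.
    """
    return {name: jnp.std(values, axis=0) for name, values in slp_samples.items()}
```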