Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of DCC inference algorithm #1715

Merged
merged 7 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,11 @@ SteinVI Kernels
.. autoclass:: numpyro.contrib.einstein.stein_kernels.ProbabilityProductKernel


Stochastic Support
~~~~~~~~~~~~~~~~~~

.. autoclass:: numpyro.contrib.stochastic_support.dcc.DCC
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
171 changes: 171 additions & 0 deletions numpyro/contrib/stochastic_support/dcc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict, namedtuple

import jax
import jax.numpy as jnp
from jax import random

import numpyro.distributions as dist
from numpyro.handlers import condition, seed, trace
from numpyro.infer import MCMC, NUTS
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.util import init_to_value, log_density

DCCResult = namedtuple("DCCResult", ["samples", "slp_weights"])


class DCC:
"""
Implements the Divide, Conquer, and Combine (DCC) algorithm for models with
stochastic support from [1].

.. note:: This implementation assumes that all stochastic branching is done based on the
outcomes of discrete sampling sites that are annotated with `infer={"branching": True}`.
For example,

.. code-block:: python

def model():
model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
if model1 == 0:
mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
else:
mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)



**References:**

1. *Divide, Conquer, and Combine: a New Inference Strategy for Probabilistic Programs with Stochastic Support*,
Yuan Zhou, Hongseok Yang, Yee Whye Teh, Tom Rainforth

:param model: Python callable containing Pyro primitives :mod:`~numpyro.primitives`.
:param dict mcmc_kwargs: Dictionary of arguments passed to :data:`~numpyro.infer.MCMC`.
:param numpyro.infer.mcmc.MCMCKernel kernel_cls: MCMC kernel class that is used for
local inference. Defaults to :class:`~numpyro.infer.NUTS`.
:param int num_slp_samples: Number of samples to draw from the prior to discover the
straight-line programs (SLPs).
:param int max_slps: Maximum number of SLPs to discover. DCC will not run inference
on more than `max_slps`.
"""

def __init__(
self,
model,
mcmc_kwargs,
kernel_cls=NUTS,
num_slp_samples=1000,
max_slps=124,
):
self.model = model
self.kernel_cls = kernel_cls
self.mcmc_kwargs = mcmc_kwargs

self.num_slp_samples = num_slp_samples
self.max_slps = max_slps

def _find_slps(self, rng_key, *args, **kwargs):
"""
Discover the straight-line programs (SLPs) in the model by sampling from the prior.
This implementation assumes that all branching is done via discrete sampling sites
that are annotated with `infer={"branching": True}`.
"""
branching_traces = {}
for _ in range(self.num_slp_samples):
rng_key, subkey = random.split(rng_key)
tr = trace(seed(self.model, subkey)).get_trace(*args, **kwargs)
btr = self._get_branching_trace(tr)
btr_str = ",".join(str(x) for x in btr.values())
if btr_str not in branching_traces:
branching_traces[btr_str] = btr
if len(branching_traces) >= self.max_slps:
break

return branching_traces

def _get_branching_trace(self, tr):
"""
Extract the sites from the trace that are annotated with `infer={"branching": True}`.
"""
branching_trace = OrderedDict()
for site in tr.values():
if site["type"] == "sample" and site["infer"].get("branching", False):
if (
not isinstance(site["fn"], dist.Distribution)
or not site["fn"].support.is_discrete
):
raise RuntimeError(
"Branching is only supported for discrete sampling sites."
)
# It is essential that we convert the value to a Python int. If it remains
# a JAX Array, then during JIT compilation it will be treated as an AbstractArray
# which means branching will raise in an error.
# Reference: (https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit)
branching_trace[site["name"]] = int(site["value"])
return branching_trace

def _run_mcmc(self, rng_key, branching_trace, *args, **kwargs):
"""
Run MCMC on the model conditioned on the given branching trace.
"""
slp_model = condition(self.model, data=branching_trace)
kernel = self.kernel_cls(slp_model)
mcmc = MCMC(kernel, **self.mcmc_kwargs)
mcmc.run(rng_key, *args, **kwargs)

return mcmc.get_samples()

def _combine_samples(self, rng_key, samples, branching_traces, *args, **kwargs):
"""
Weight each SLP proportional to its estimated normalization constant.
The normalization constants are estimated using importance sampling with
the proposal centered on the MCMC samples.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this standard Gaussian a good practical choice for the proposal? Looking at the paper, it seems that the authors used a metropolis-within-gibbs sampler.

Copy link
Contributor Author

@treigerm treigerm Jan 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for spending the time reviewing the PR!

Note that the Gaussian is centered around the MCMC samples (more precisely, each MCMC sample gives rise to a single proposal distribution). As long as the MCMC chain(s) are well-mixed this generally leads to good proposals. This is also what the paper describes (and also what the author's implementation does which I have received upon request).

Looking at the paper, it seems that the authors used a metropolis-within-gibbs sampler.

Actually, the metropoylis-within-gibbs sampler is only used for the local inference tasks. For many models it isn't a very efficient inference algorithm because it only updates one variable at a time. Because the implementation here assumes that the branching is only done based on the outcomes of discrete sampling statements, it can use more efficient algorithms for local inference (like HMC or NUTS).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks. How about using sample variance instead of unit variance?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have just updated the PR to make the scale in the proposal a parameter that can be set by the user. I agree that unit variance is probably not always desirable. However, I'm not sure whether sample variance is desirable either. The main idea behind the algorithm for estimating the normalization constant (which in more detailed is described in another paper Layered Adaptive Importance Sampling) is that the proposal on top of each sample leads to fairly local proposals. If there are multiple modes in the posterior then using the sample variance could result in lots of proposed samples in the low density regions between the modes.

I'd still be open to adding the option to use the sample variance to the implementation but this is currently complicated by the way the AutoNormal guide is implemented. You would want to compute the sample variance for each individual variable in the program but as far as I can tell there is no way to set variable specific variances in the AutoNormal guide (it's possible to set variable specific means though). So this might be a feature that would be reasonable to add at a later time?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late response! Thank you for the insights - I don't have a strong opinion on whether it's helpful to expose init_scale in AutoNormal. There is another way to substitute sample variance, like substitute(guide, data={"auto_foo_scale": ...}) but we need to be careful at the domain (needs to be unconstrained) of such a foo variable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I seem, I guess this would require knowledge about all the variable names in the program though (but that could be extracted automatically)? For now I would lean towards leaving the implementation as is to keep it simple, if that is okay.

"""

def log_weight(rng_key, i, slp_model, slp_samples):
trace = {k: v[i] for k, v in slp_samples.items()}
guide = AutoNormal(
slp_model,
init_loc_fn=init_to_value(values=trace),
init_scale=1.0,
)
rng_key, subkey = random.split(rng_key)
guide_trace = seed(guide, subkey)(*args, **kwargs)
guide_log_density, _ = log_density(guide, args, kwargs, guide_trace)
model_log_density, _ = log_density(slp_model, args, kwargs, guide_trace)
return model_log_density - guide_log_density

log_weights = jax.vmap(log_weight, in_axes=(None, 0, None, None))

log_Zs = {}
for bt, slp_samples in samples.items():
num_samples = slp_samples[next(iter(slp_samples))].shape[0]
slp_model = condition(self.model, data=branching_traces[bt])
lws = log_weights(rng_key, jnp.arange(num_samples), slp_model, slp_samples)
log_Zs[bt] = jax.scipy.special.logsumexp(lws) - jnp.log(num_samples)

normalizer = jax.scipy.special.logsumexp(jnp.array(list(log_Zs.values())))
slp_weights = {k: jnp.exp(v - normalizer) for k, v in log_Zs.items()}
return DCCResult(samples, slp_weights)

def run(self, rng_key, *args, **kwargs):
"""
Run DCC and collect samples for all SLPs.

:param jax.random.PRNGKey rng_key: Random number generator key.
:param args: Arguments to the model.
:param kwargs: Keyword arguments to the model.
"""
rng_key, subkey = random.split(rng_key)
branching_traces = self._find_slps(subkey, *args, **kwargs)

samples = dict()
for key, bt in branching_traces.items():
rng_key, subkey = random.split(rng_key)
samples[key] = self._run_mcmc(subkey, bt, *args, **kwargs)

rng_key, subkey = random.split(rng_key)
return self._combine_samples(subkey, samples, branching_traces, *args, **kwargs)
169 changes: 169 additions & 0 deletions test/contrib/stochastic_support/test_dcc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import math

import jax
import jax.numpy as jnp
import pytest
from jax import random
from numpy.testing import assert_allclose

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.stochastic_support.dcc import DCC
from numpyro.infer import HMC, NUTS, SA, BarkerMH


@pytest.mark.parametrize(
"branch_dist",
[dist.Normal(0, 1), dist.Gamma(1, 1)],
)
@pytest.mark.xfail(raises=RuntimeError)
def test_continuous_branching(branch_dist):
rng_key = random.PRNGKey(0)

def model():
model1 = numpyro.sample("model1", branch_dist, infer={"branching": True})
mean = 1.0 if model1 == 0 else 2.0
numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)

mcmc_kwargs = dict(
num_warmup=500,
num_samples=1000,
num_chains=1,
)

dcc = DCC(model, mcmc_kwargs=mcmc_kwargs)
rng_key, subkey = random.split(rng_key)
dcc.run(subkey)


def test_different_address_path():
rng_key = random.PRNGKey(0)

def model():
model1 = numpyro.sample(
"model1", dist.Bernoulli(0.5), infer={"branching": True}
)
if model1 == 0:
numpyro.sample("a1", dist.Normal(9.0, 1.0))
else:
numpyro.sample("a2", dist.Normal(9.0, 1.0))
numpyro.sample("a3", dist.Normal(9.0, 1.0))
mean = 1.0 if model1 == 0 else 2.0
numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)

mcmc_kwargs = dict(
num_warmup=50,
num_samples=50,
num_chains=1,
progress_bar=False,
)

dcc = DCC(model, mcmc_kwargs=mcmc_kwargs)
rng_key, subkey = random.split(rng_key)
dcc.run(subkey)


@pytest.mark.parametrize(
"chain_method",
["sequential", "parallel", "vectorized"],
)
@pytest.mark.parametrize("kernel_cls", [NUTS, HMC, SA, BarkerMH])
def test_kernels(chain_method, kernel_cls):
if chain_method == "vectorized" and kernel_cls in [SA, BarkerMH]:
# These methods do not support vectorized execution.
return

def model(y):
z = numpyro.sample("z", dist.Normal(0.0, 1.0))
model1 = numpyro.sample(
"model1", dist.Bernoulli(0.5), infer={"branching": True}
)
sigma = 1.0 if model1 == 0 else 2.0
with numpyro.plate("data", y.shape[0]):
numpyro.sample("obs", dist.Normal(z, sigma), obs=y)

rng_key = random.PRNGKey(0)

rng_key, subkey = random.split(rng_key)
y_train = dist.Normal(0, 1).sample(subkey, (200,))

mcmc_kwargs = dict(
num_warmup=50,
num_samples=50,
num_chains=2,
chain_method=chain_method,
progress_bar=False,
)

dcc = DCC(model, mcmc_kwargs=mcmc_kwargs, kernel_cls=kernel_cls)
rng_key, subkey = random.split(rng_key)
dcc.run(subkey, y_train)


def test_weight_convergence():
PRIOR_MEAN, PRIOR_STD = 0.0, 1.0
LIKELIHOOD1_STD = 2.0
LIKELIHOOD2_STD = 0.62177

def log_marginal_likelihood(data, likelihood_std, prior_mean, prior_std):
"""
Calculate the marginal likelihood of a model with Normal likelihood, unknown mean,
and Normal prior.

Taken from Section 2.5 at https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf.
"""
num_data = data.shape[0]
likelihood_var = jnp.square(likelihood_std)
prior_var = jnp.square(prior_std)

first_term = (
jnp.log(likelihood_std)
- num_data * jnp.log(jnp.sqrt(2 * math.pi) * likelihood_std)
+ 0.5 * jnp.log(num_data * prior_var + likelihood_var)
)
second_term = -(jnp.sum(jnp.square(data)) / (2 * likelihood_var)) - (
jnp.square(prior_mean) / (2 * prior_var)
)
third_term = (
(
prior_var
* jnp.square(num_data)
* jnp.square(jnp.mean(data))
/ likelihood_var
)
+ (likelihood_var * jnp.square(prior_mean) / prior_var)
+ 2 * num_data * jnp.mean(data) * prior_mean
) / (2 * (num_data * prior_var + likelihood_var))
return first_term + second_term + third_term

def model(y):
z = numpyro.sample("z", dist.Normal(PRIOR_MEAN, PRIOR_STD))
model1 = numpyro.sample(
"model1", dist.Bernoulli(0.5), infer={"branching": True}
)
sigma = LIKELIHOOD1_STD if model1 == 0 else LIKELIHOOD2_STD
with numpyro.plate("data", y.shape[0]):
numpyro.sample("obs", dist.Normal(z, sigma), obs=y)

rng_key = random.PRNGKey(0)

rng_key, subkey = random.split(rng_key)
y_train = dist.Normal(0, 1).sample(subkey, (200,))

mcmc_kwargs = dict(
num_warmup=500,
num_samples=1000,
num_chains=1,
)

dcc = DCC(model, mcmc_kwargs=mcmc_kwargs)
rng_key, subkey = random.split(rng_key)
dcc_result = dcc.run(subkey, y_train)
slp_weights = jnp.array([dcc_result.slp_weights["0"], dcc_result.slp_weights["1"]])
assert_allclose(1.0, jnp.sum(slp_weights))

slp1_lml = log_marginal_likelihood(y_train, LIKELIHOOD1_STD, PRIOR_MEAN, PRIOR_STD)
slp2_lml = log_marginal_likelihood(y_train, LIKELIHOOD2_STD, PRIOR_MEAN, PRIOR_STD)
lmls = jnp.array([slp1_lml, slp2_lml])
analytic_weights = jnp.exp(lmls - jax.scipy.special.logsumexp(lmls))
assert_allclose(analytic_weights, slp_weights, rtol=1e-5, atol=1e-8)