diff --git a/docs/source/conf.py b/docs/source/conf.py
index 4f0aa83a2..adf47cdd4 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -9,7 +9,6 @@
 import nbsphinx
 import sphinx_rtd_theme
 
-
 # import pkg_resources
 
 # -*- coding: utf-8 -*-
@@ -33,6 +32,7 @@
 
 # HACK: This is to ensure that local functions are documented by sphinx.
 from numpyro.infer.hmc import hmc  # noqa: E402
+
 hmc(None, None)
 
 # -- Project information -----------------------------------------------------
diff --git a/examples/covtype.py b/examples/covtype.py
index 1dcff36e7..702b0d151 100644
--- a/examples/covtype.py
+++ b/examples/covtype.py
@@ -4,13 +4,18 @@
 import argparse
 import time
 
+import matplotlib.pyplot as plt
+
 from jax import random
 import jax.numpy as jnp
 
 import numpyro
 import numpyro.distributions as dist
 from numpyro.examples.datasets import COVTYPE, load_dataset
-from numpyro.infer import MCMC, NUTS
+from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SA, SVI, Trace_ELBO, init_to_value
+from numpyro.infer.autoguide import AutoBNAFNormal
+from numpyro.infer.hmc_gibbs import taylor_proxy
+from numpyro.infer.reparam import NeuTraReparam
 
 
 def _load_dataset():
@@ -33,22 +38,76 @@ def _load_dataset():
     return features, labels
 
 
-def model(data, labels):
+def model(data, labels, subsample_size=None):
     dim = data.shape[1]
     coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
-    logits = jnp.dot(data, coefs)
-    return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)
+    with numpyro.plate("N", data.shape[0], subsample_size=subsample_size) as idx:
+        logits = jnp.dot(data[idx], coefs)
+        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels[idx])
 
 
 def benchmark_hmc(args, features, labels):
-    step_size = jnp.sqrt(0.5 / features.shape[0])
-    trajectory_length = step_size * args.num_steps
     rng_key = random.PRNGKey(1)
     start = time.time()
-    kernel = NUTS(model, trajectory_length=trajectory_length)
-    mcmc = MCMC(kernel, 0, args.num_samples)
-    mcmc.run(rng_key, features, labels)
-    mcmc.print_summary()
+    # a MAP estimate from the following source:
+    # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
+    ref_params = {"coefs": jnp.array([
+        +2.03420663e+00, -3.53567265e-02, -1.49223924e-01, -3.07049364e-01,
+        -1.00028366e-01, -1.46827862e-01, -1.64167881e-01, -4.20344204e-01,
+        +9.47479829e-02, -1.12681836e-02, +2.64442056e-01, -1.22087866e-01,
+        -6.00568838e-02, -3.79419506e-01, -1.06668741e-01, -2.97053963e-01,
+        -2.05253899e-01, -4.69537191e-02, -2.78072730e-02, -1.43250525e-01,
+        -6.77954629e-02, -4.34899796e-03, +5.90927452e-02, +7.23133609e-02,
+        +1.38526391e-02, -1.24497898e-01, -1.50733739e-02, -2.68872194e-02,
+        -1.80925727e-02, +3.47936489e-02, +4.03552800e-02, -9.98773426e-03,
+        +6.20188080e-02, +1.15002751e-01, +1.32145107e-01, +2.69109547e-01,
+        +2.45785132e-01, +1.19035013e-01, -2.59744357e-02, +9.94279515e-04,
+        +3.39266285e-02, -1.44057125e-02, -6.95222765e-02, -7.52013028e-02,
+        +1.21171586e-01, +2.29205526e-02, +1.47308692e-01, -8.34354162e-02,
+        -9.34122875e-02, -2.97472421e-02, -3.03937674e-01, -1.70958012e-01,
+        -1.59496680e-01, -1.88516974e-01, -1.20889175e+00])}
+    if args.algo == "HMC":
+        step_size = jnp.sqrt(0.5 / features.shape[0])
+        trajectory_length = step_size * args.num_steps
+        kernel = HMC(model, step_size=step_size, trajectory_length=trajectory_length, adapt_step_size=False,
+                     dense_mass=args.dense_mass)
+        subsample_size = None
+    elif args.algo == "NUTS":
+        kernel = NUTS(model, dense_mass=args.dense_mass)
+        subsample_size = None
+    elif args.algo == "HMCECS":
+        subsample_size = 1000
+        inner_kernel = NUTS(model, init_strategy=init_to_value(values=ref_params),
+                            dense_mass=args.dense_mass)
+        # note: with num_blocks=100 and subsample_size=1000, each MCMC step updates
+        # 10 indices, so it takes 50,000 MCMC steps to iterate over the whole dataset
+        kernel = HMCECS(inner_kernel, num_blocks=100, proxy=taylor_proxy(ref_params))
+    elif args.algo == "SA":
+        # NB: this kernel requires large num_warmup and num_samples,
+        # and running on GPU is much faster than on CPU
+        kernel = SA(model, adapt_state_size=1000, init_strategy=init_to_value(values=ref_params))
+        subsample_size = None
+    elif args.algo == "FlowHMCECS":
+        subsample_size = 1000
+        guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
+        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
+        params, losses = svi.run(random.PRNGKey(2), 2000, features, labels)
+        plt.plot(losses)
+        plt.show()
+
+        neutra = NeuTraReparam(guide, params)
+        neutra_model = neutra.reparam(model)
+        neutra_ref_params = {"auto_shared_latent": jnp.zeros(55)}
+        # no need to adapt the mass matrix if the flow does a good job
+        inner_kernel = NUTS(neutra_model, init_strategy=init_to_value(values=neutra_ref_params),
+                            adapt_mass_matrix=False)
+        kernel = HMCECS(inner_kernel, num_blocks=100, proxy=taylor_proxy(neutra_ref_params))
+    else:
+        raise ValueError("Invalid algorithm; choose one of 'HMC', 'NUTS', 'HMCECS', 'SA' or 'FlowHMCECS'.")
+    mcmc = MCMC(kernel, args.num_warmup, args.num_samples)
+    mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",))
+    print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
+    mcmc.print_summary(exclude_deterministic=False)
     print('\nMCMC elapsed time:', time.time() - start)
 
 
@@ -60,14 +119,20 @@ def main(args):
 
 
 if __name__ == '__main__':
     assert numpyro.__version__.startswith('0.5.0')
     parser = argparse.ArgumentParser(description="parse args")
-    parser.add_argument('-n', '--num-samples', default=100, type=int, help='number of samples')
+    parser.add_argument('-n', '--num-samples', default=1000, type=int, help='number of samples')
+    parser.add_argument('--num-warmup', default=1000, type=int, help='number of warmup steps')
     parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")')
     parser.add_argument('--num-chains', nargs='?', default=1, type=int)
-    parser.add_argument('--algo', default='NUTS', type=str, help='whether to run "HMC" or "NUTS"')
+    parser.add_argument('--algo', default='HMCECS', type=str,
+                        help='whether to run "HMC", "NUTS", "HMCECS", "SA" or "FlowHMCECS"')
+    parser.add_argument('--dense-mass', action="store_true")
+    parser.add_argument('--x64', action="store_true")
     parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
     args = parser.parse_args()
 
     numpyro.set_platform(args.device)
     numpyro.set_host_device_count(args.num_chains)
+    if args.x64:
+        numpyro.enable_x64()
 
     main(args)
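For orientation, the new model signature composes with HMCECS even without a proxy. A reduced, self-contained sketch (toy data and sizes are mine, not from the example) of the naive subsampling path, i.e. proxy=None:

    import jax.numpy as jnp
    from jax import random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import HMCECS, MCMC, NUTS

    features = random.normal(random.PRNGKey(0), (5000, 3))
    labels = random.bernoulli(random.PRNGKey(1), 0.5, (5000,)).astype(jnp.float32)

    def model(data, labels, subsample_size=None):
        dim = data.shape[1]
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        with numpyro.plate("N", data.shape[0], subsample_size=subsample_size) as idx:
            # each potential evaluation only touches `subsample_size` rows
            numpyro.sample('obs', dist.Bernoulli(logits=jnp.dot(data[idx], coefs)), obs=labels[idx])

    mcmc = MCMC(HMCECS(NUTS(model), num_blocks=10), 100, 100)
    mcmc.run(random.PRNGKey(2), features, labels, 500)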
diff --git a/numpyro/diagnostics.py b/numpyro/diagnostics.py
index f85ee471d..ab2b9996d 100644
--- a/numpyro/diagnostics.py
+++ b/numpyro/diagnostics.py
@@ -161,6 +161,7 @@ def effective_sample_size(x):
     :return: effective sample size of ``x``.
     :rtype: numpy.ndarray
     """
+    assert x.ndim >= 2
     assert x.shape[1] >= 2
 
diff --git a/numpyro/examples/datasets.py b/numpyro/examples/datasets.py
index 846e93892..18ace29e3 100644
--- a/numpyro/examples/datasets.py
+++ b/numpyro/examples/datasets.py
@@ -4,11 +4,13 @@
 from collections import namedtuple
 import csv
 import gzip
+import io
 import os
 import pickle
 import struct
 from urllib.parse import urlparse
 from urllib.request import urlretrieve
+import warnings
 import zipfile
 
 import numpy as np
@@ -23,25 +25,20 @@
                         '.data'))
 os.makedirs(DATA_DIR, exist_ok=True)
 
-
 dset = namedtuple('dset', ['name', 'urls'])
 
-
 BASEBALL = dset('baseball', [
     'https://d2hg8soec8ck9v.cloudfront.net/datasets/EfronMorrisBB.txt',
 ])
 
-
 COVTYPE = dset('covtype', [
     'https://d2hg8soec8ck9v.cloudfront.net/datasets/covtype.zip',
 ])
 
-
 DIPPER_VOLE = dset('dipper_vole', [
     'https://github.com/pyro-ppl/datasets/blob/master/dipper_vole.zip?raw=true',
 ])
 
-
 MNIST = dset('mnist', [
     'https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/train-images-idx3-ubyte.gz',
     'https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/train-labels-idx1-ubyte.gz',
@@ -49,26 +46,26 @@
     'https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/t10k-labels-idx1-ubyte.gz',
 ])
 
-
 SP500 = dset('SP500', [
     'https://d2hg8soec8ck9v.cloudfront.net/datasets/SP500.csv',
 ])
 
-
 UCBADMIT = dset('ucbadmit', [
     'https://d2hg8soec8ck9v.cloudfront.net/datasets/UCBadmit.csv',
 ])
 
-
 LYNXHARE = dset('lynxhare', [
     'https://d2hg8soec8ck9v.cloudfront.net/datasets/LynxHare.txt',
 ])
 
-
 JSB_CHORALES = dset('jsb_chorales', [
     'https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/jsb_chorales.pickle',
 ])
 
+HIGGS = dset("higgs", [
+    "https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz",
+])
+
 
 def _download(dset):
     for url in dset.urls:
@@ -240,6 +237,21 @@ def _load_jsb_chorales():
     return processed_dataset
 
 
+def _load_higgs():
+    warnings.warn("Higgs is a 2.6 GB dataset")
+    _download(HIGGS)
+
+    file_path = os.path.join(DATA_DIR, 'HIGGS.csv.gz')
+    with io.TextIOWrapper(gzip.open(file_path, 'rb')) as f:
+        csv_reader = csv.reader(f, delimiter=',', quoting=csv.QUOTE_NONE)
+        obs = []
+        data = []
+        for row in csv_reader:
+            obs.append(row[0])
+            data.append(row[1:])
+    return np.stack(obs), np.stack(data)
+
+
 def _load(dset):
     if dset == BASEBALL:
         return _load_baseball()
@@ -257,6 +269,8 @@ def _load(dset):
         return _load_lynxhare()
     elif dset == JSB_CHORALES:
         return _load_jsb_chorales()
+    elif dset == HIGGS:
+        return _load_higgs()
     raise ValueError('Dataset - {} not found.'.format(dset.name))
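A usage sketch for the new HIGGS entry, following the same load_dataset pattern that examples/covtype.py already uses for COVTYPE (hedged: the first call downloads the full 2.6 GB file, and the returned arrays mirror whatever _load_higgs stacks above):

    from numpyro.examples.datasets import HIGGS, load_dataset

    _, fetch = load_dataset(HIGGS, shuffle=False)
    obs, data = fetch()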
diff --git a/numpyro/handlers.py b/numpyro/handlers.py
index d97c96f95..a40572504 100644
--- a/numpyro/handlers.py
+++ b/numpyro/handlers.py
@@ -1,6 +1,5 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
-
 """
 This provides a small set of effect handlers in NumPyro that are modeled
 after Pyro's `poutine `_ module.
@@ -136,6 +135,7 @@ class trace(Messenger):
              'type': 'sample',
              'value': DeviceArray(-0.20584235, dtype=float32)})])
     """
+
     def __enter__(self):
         super(trace, self).__enter__()
         self.trace = OrderedDict()
@@ -146,7 +146,7 @@ def postprocess_message(self, msg):
             # skip recording helper messages e.g. `control_flow`, `to_data`, `to_funsor`
             # which has no name
             return
-        assert not(msg['type'] == 'sample' and msg['name'] in self.trace), \
+        assert not (msg['type'] == 'sample' and msg['name'] in self.trace), \
             'all sites must have unique names but got `{}` duplicated'.format(msg['name'])
         self.trace[msg['name']] = msg.copy()
@@ -191,6 +191,7 @@ class replay(Messenger):
         -0.20584235
     >>> assert replayed_trace['a']['value'] == exec_trace['a']['value']
     """
+
     def __init__(self, fn=None, guide_trace=None):
         assert guide_trace is not None
         self.guide_trace = guide_trace
@@ -234,6 +235,7 @@ class block(Messenger):
     >>> assert 'a' not in trace_block_a
     >>> assert 'b' in trace_block_a
     """
+
     def __init__(self, fn=None, hide_fn=None, hide=None):
         if hide_fn is not None:
             self.hide_fn = hide_fn
@@ -350,6 +352,7 @@ class condition(Messenger):
     >>> assert exec_trace['a']['value'] == -1
     >>> assert exec_trace['a']['is_observed']
     """
+
     def __init__(self, fn=None, data=None, condition_fn=None):
         self.condition_fn = condition_fn
         self.data = data
@@ -386,6 +389,7 @@ class infer_config(Messenger):
     :param fn: a stochastic function (callable containing NumPyro primitive calls)
     :param config_fn: a callable taking a site and returning an infer dict
     """
+
     def __init__(self, fn=None, config_fn=None):
         super().__init__(fn)
         self.config_fn = config_fn
@@ -470,6 +474,7 @@ class mask(Messenger):
     :param mask: a boolean or a boolean-valued array for masking elementwise log
         probability of sample sites (`True` includes a site, `False` excludes a site).
     """
+
     def __init__(self, fn=None, mask=True):
         if lax.dtype(mask) != 'bool':
             raise ValueError("`mask` should be a bool array.")
@@ -506,6 +511,7 @@ class reparam(Messenger):
         :class:`~numpyro.infer.reparam.Reparam` or None.
     :type config: dict or callable
     """
+
     def __init__(self, fn=None, config=None):
         assert isinstance(config, dict) or callable(config)
         self.config = config
@@ -550,6 +556,7 @@ class scale(Messenger):
         of log probability.
     :type scale: float or numpy.ndarray
     """
+
     def __init__(self, fn=None, scale=1.):
         if not_jax_tracer(scale):
             if np.any(np.less_equal(scale, 0)):
@@ -587,6 +594,7 @@ class scope(Messenger):
     :param str prefix: a string to prepend to sample names
     :param str divider: a string to join the prefix and sample name; default to `'/'`
     """
+
     def __init__(self, fn=None, prefix='', divider='/'):
         self.prefix = prefix
         self.divider = divider
@@ -638,6 +646,7 @@ class seed(Messenger):
     >>> y = handlers.seed(model, rng_seed=1)()
     >>> assert x == y
     """
+
     def __init__(self, fn=None, rng_seed=None):
         if isinstance(rng_seed, int) or (isinstance(rng_seed, jnp.ndarray) and not jnp.shape(rng_seed)):
             rng_seed = random.PRNGKey(rng_seed)
@@ -647,10 +656,10 @@ def __init__(self, fn=None, rng_seed=None):
         super(seed, self).__init__(fn)
 
     def process_message(self, msg):
-        if (msg['type'] == 'sample' and not msg['is_observed'] and
-                msg['kwargs']['rng_key'] is None) or msg['type'] in ['prng_key', 'plate', 'control_flow']:
-            # no need to create a new key when value is available
+        if (msg['type'] == 'sample' and not msg['is_observed'] and msg['kwargs']['rng_key'] is None) \
+                or msg['type'] in ['prng_key', 'plate', 'control_flow']:
             if msg['value'] is not None:
+                # no need to create a new key when value is available
                 return
             self.rng_key, rng_key_sample = random.split(self.rng_key)
             msg['kwargs']['rng_key'] = rng_key_sample
@@ -691,6 +700,7 @@ class substitute(Messenger):
     >>> exec_trace = trace(substitute(model, {'a': -1})).get_trace()
     >>> assert exec_trace['a']['value'] == -1
     """
+
     def __init__(self, fn=None, data=None, substitute_fn=None):
         self.substitute_fn = substitute_fn
         self.data = data
@@ -760,6 +770,7 @@ class do(Messenger):
     >>> assert not exec_trace['z'].get('stop', None)
     >>> assert z_square == 1
     """
+
     def __init__(self, fn=None, data=None):
         self.data = data
         self._intervener_id = str(id(self))
diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py
index c42bb3098..bd4966bdd 100644
--- a/numpyro/infer/autoguide.py
+++ b/numpyro/infer/autoguide.py
@@ -214,19 +214,14 @@ def __call__(self, *args, **kwargs):
                                     event_dim=event_dim)
 
             site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim)
-            if site["fn"].support in [constraints.real, constraints.real_vector]:
+            if site["fn"].support is constraints.real \
+                    or (isinstance(site["fn"].support, constraints.independent) and
+                        site["fn"].support.base_constraint is constraints.real):
                 result[name] = numpyro.sample(name, site_fn)
             else:
-                unconstrained_value = numpyro.sample("{}_unconstrained".format(name), site_fn,
-                                                     infer={"is_auxiliary": True})
                 transform = biject_to(site['fn'].support)
-                value = transform(unconstrained_value)
-                log_density = - transform.log_abs_det_jacobian(unconstrained_value, value)
-                log_density = sum_rightmost(log_density,
-                                            jnp.ndim(log_density) - jnp.ndim(value) + site["fn"].event_dim)
-                delta_dist = dist.Delta(value, log_density=log_density, event_dim=site["fn"].event_dim)
-                result[name] = numpyro.sample(name, delta_dist)
+                guide_dist = dist.TransformedDistribution(site_fn, transform)
+                result[name] = numpyro.sample(name, guide_dist)
 
         return result
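To see why the TransformedDistribution guide above matches the removed Delta construction, a minimal check (my own toy example with a positive-constrained site): pushing the unconstrained Normal through biject_to(support) picks up exactly the Jacobian term the old code added by hand.

    import jax.numpy as jnp
    import numpyro.distributions as dist
    from numpyro.distributions import constraints
    from numpyro.distributions.transforms import biject_to

    base = dist.Normal(0., 1.)                    # unconstrained guide factor
    transform = biject_to(constraints.positive)   # e.g. the support of a scale parameter
    guide_dist = dist.TransformedDistribution(base, transform)

    u = jnp.array(0.3)                            # unconstrained value
    x = transform(u)                              # constrained value
    expected = base.log_prob(u) - transform.log_abs_det_jacobian(u, x)
    assert jnp.allclose(guide_dist.log_prob(x), expected)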
diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py
index 89f1c89d3..ef6881123 100644
--- a/numpyro/infer/hmc_gibbs.py
+++ b/numpyro/infer/hmc_gibbs.py
@@ -1,17 +1,21 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
-from collections import namedtuple
+from collections import defaultdict, namedtuple
 import copy
 from functools import partial
+import warnings
 
-from jax import device_put, grad, jacfwd, ops, random, value_and_grad
+from jax import device_put, grad, hessian, jacfwd, jacobian, lax, ops, random, value_and_grad
 import jax.numpy as jnp
 from jax.scipy.special import expit
 
-from numpyro.handlers import condition, seed, substitute, trace
+import numpyro
+from numpyro.distributions.transforms import biject_to
+from numpyro.handlers import block, condition, seed, substitute, trace
 from numpyro.infer.hmc import HMC
 from numpyro.infer.mcmc import MCMCKernel
+from numpyro.infer.util import _unconstrain_reparam
 from numpyro.util import cond, fori_loop, identity, ravel_pytree
 
 HMCGibbsState = namedtuple("HMCGibbsState", "z, hmc_state, rng_key")
@@ -247,7 +251,6 @@ def _discrete_modified_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, s
 
 
 def _discrete_gibbs_fn(potential_fn, support_sizes, proposal_fn):
-
     def gibbs_fn(rng_key, gibbs_sites, hmc_sites, pe):
         # get support_sizes of gibbs_sites
         support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites})
@@ -397,31 +400,52 @@ def potential_fn(z_gibbs, z_hmc):
     return HMCGibbsState(z, hmc_state, rng_key)
 
 
-def _subsample_gibbs_fn(potential_fn, plate_sizes, num_blocks=1):
-
-    def gibbs_fn(rng_key, gibbs_sites, hmc_sites, pe):
-        assert set(gibbs_sites) == set(plate_sizes)
-        u_new = {}
-        for name in gibbs_sites:
-            size, subsample_size = plate_sizes[name]
-            rng_key, subkey, block_key = random.split(rng_key, 3)
-            block_size = subsample_size // num_blocks
-
-            chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks)
-            new_idx = random.randint(subkey, minval=0, maxval=size, shape=(subsample_size,))
-            block_mask = jnp.arange(subsample_size) // block_size == chosen_block
-            u_new[name] = jnp.where(block_mask, new_idx, gibbs_sites[name])
-
-        # given a fixed hmc_sites, pe_new - pe_curr = loglik_new - loglik_curr
-        pe_new = potential_fn(u_new, hmc_sites)
-        accept_prob = jnp.clip(jnp.exp(pe - pe_new), a_max=1.0)
-        gibbs_sites, pe = cond(random.bernoulli(rng_key, accept_prob),
-                               (u_new, pe_new), identity,
-                               (gibbs_sites, pe), identity)
-        return gibbs_sites, pe
-    return gibbs_fn
+def _update_block(rng_key, num_blocks, subsample_idx, plate_size):
+    size, subsample_size = plate_size
+    rng_key, subkey, block_key = random.split(rng_key, 3)
+    block_size = (subsample_size - 1) // num_blocks + 1
+    pad = block_size - (subsample_size - 1) % block_size - 1
+
+    chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks)
+    new_idx = random.randint(subkey, minval=0, maxval=size, shape=(block_size,))
+    subsample_idx_padded = jnp.pad(subsample_idx, (0, pad))
+    start = chosen_block * block_size
+    subsample_idx_padded = lax.dynamic_update_slice_in_dim(
+        subsample_idx_padded, new_idx, start, 0)
+    return rng_key, subsample_idx_padded[:subsample_size], pad, new_idx, start
+
+
+def _block_update(plate_sizes, num_blocks, rng_key, gibbs_sites, gibbs_state):
+    u_new = {}
+    for name, subsample_idx in gibbs_sites.items():
+        rng_key, u_new[name], *_ = _update_block(rng_key, num_blocks, subsample_idx, plate_sizes[name])
+    return u_new, gibbs_state
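The padding arithmetic in _update_block is easiest to verify numerically; for these toy sizes (my own), num_blocks blocks of block_size indices exactly tile the padded subsample:

    subsample_size, num_blocks = 100, 3
    block_size = (subsample_size - 1) // num_blocks + 1       # ceil(100 / 3) = 34
    pad = block_size - (subsample_size - 1) % block_size - 1  # 34 - 31 - 1 = 2
    assert block_size * num_blocks == subsample_size + pad    # 102 == 102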
+
+
+def _block_update_proxy(num_blocks, rng_key, gibbs_sites, plate_sizes):
+    u_new = {}
+    pads = {}
+    new_idxs = {}
+    starts = {}
+    for name, subsample_idx in gibbs_sites.items():
+        rng_key, u_new[name], pads[name], new_idxs[name], starts[name] = _update_block(
+            rng_key, num_blocks, subsample_idx, plate_sizes[name])
+    return u_new, pads, new_idxs, starts
+
+
+HMCECSState = namedtuple("HMCECSState", "z, hmc_state, rng_key, gibbs_state, accept_prob")
+TaylorProxyState = namedtuple("TaylorProxyState", "ref_subsample_log_liks, "
+                              "ref_subsample_log_lik_grads, ref_subsample_log_lik_hessians")
+
+
+def _wrap_gibbs_state(model):
+    def wrapped_fn(*args, **kwargs):
+        # this lets the estimate_likelihood handler know the current gibbs_state
+        msg = {"type": "_gibbs_state", "value": kwargs.pop("_gibbs_state", ())}
+        numpyro.primitives.apply_stack(msg)
+        return model(*args, **kwargs)
+
+    return wrapped_fn
 
 
 class HMCECS(HMCGibbs):
@@ -448,9 +472,13 @@ class HMCECS(HMCGibbs):
        Quiroz, M., Kohn, R., Villani, M., & Tran, M. N. (2018)
     3. *The Block Pseudo-Margional Sampler*,
        Tran, M.-N., Kohn, R., Quiroz, M. Villani, M. (2017)
+    4. *The Fundamental Incompatibility of Scalable Hamiltonian Monte Carlo and Naive Data Subsampling*,
+       Betancourt, M. (2015)
 
     :param inner_kernel: One of :class:`~numpyro.infer.hmc.HMC` or :class:`~numpyro.infer.hmc.NUTS`.
     :param int num_blocks: Number of blocks to partition subsample into.
+    :param proxy: Either :func:`~numpyro.infer.hmc_gibbs.taylor_proxy` for likelihood estimation,
+        or None for naive (in-between-trajectory) subsampling as outlined in [4].
 
     **Example**
 
@@ -476,48 +504,285 @@ class HMCECS(HMCGibbs):
     >>> assert abs(jnp.mean(samples) - 1.) < 0.1
     """
 
-    def __init__(self, inner_kernel, *, num_blocks=1):
+    def __init__(self, inner_kernel, *, num_blocks=1, proxy=None):
         super().__init__(inner_kernel, lambda *args: None, None)
+
+        self.inner_kernel._model = _wrap_gibbs_state(self.inner_kernel._model)
         self._num_blocks = num_blocks
+        self._proxy = proxy
+
+    def postprocess_fn(self, args, kwargs):
+        def fn(z):
+            model_kwargs = {} if kwargs is None else kwargs.copy()
+            hmc_sites = {k: v for k, v in z.items() if k not in self._gibbs_sites}
+            gibbs_sites = {k: v for k, v in z.items() if k in self._gibbs_sites}
+            model_kwargs["_gibbs_sites"] = gibbs_sites
+            hmc_sites = self.inner_kernel.postprocess_fn(args, model_kwargs)(hmc_sites)
+            return hmc_sites
+
+        return fn
 
     def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
         model_kwargs = {} if model_kwargs is None else model_kwargs.copy()
         rng_key, key_u = random.split(rng_key)
         self._prototype_trace = trace(seed(self.model, key_u)).get_trace(*model_args, **model_kwargs)
-        self._plate_sizes = {
+        self._subsample_plate_sizes = {
             name: site["args"]
             for name, site in self._prototype_trace.items()
             if site["type"] == "plate" and site["args"][0] > site["args"][1]  # i.e. size > subsample_size
         }
-        self._gibbs_sites = list(self._plate_sizes.keys())
-        return super().init(rng_key, num_warmup, init_params, model_args, model_kwargs)
+        self._gibbs_sites = list(self._subsample_plate_sizes.keys())
+        if self._proxy is not None:
+            proxy_fn, gibbs_init, self._gibbs_update = self._proxy(self._prototype_trace,
+                                                                   self._subsample_plate_sizes,
+                                                                   self.model,
+                                                                   model_args,
+                                                                   model_kwargs.copy(),
+                                                                   num_blocks=self._num_blocks)
+            method = perturbed_method(self._subsample_plate_sizes, proxy_fn)
+            self.inner_kernel._model = estimate_likelihood(self.inner_kernel._model, method)
+
+            z_gibbs = {name: site["value"] for name, site in self._prototype_trace.items()
+                       if name in self._gibbs_sites}
+            rng_key, rng_state = random.split(rng_key)
+            gibbs_state = gibbs_init(rng_state, z_gibbs)
+        else:
+            self._gibbs_update = partial(_block_update, self._subsample_plate_sizes, self._num_blocks)
+            gibbs_state = ()
+
+        model_kwargs["_gibbs_state"] = gibbs_state
+        state = super().init(rng_key, num_warmup, init_params, model_args, model_kwargs)
+        return HMCECSState(state.z, state.hmc_state, state.rng_key, gibbs_state, jnp.array(0.))
 
     def sample(self, state, model_args, model_kwargs):
-        model_kwargs = {} if model_kwargs is None else model_kwargs
+        model_kwargs = {} if model_kwargs is None else model_kwargs.copy()
         rng_key, rng_gibbs = random.split(state.rng_key)
 
-        def potential_fn(z_gibbs, z_hmc):
+        def potential_fn(z_gibbs, gibbs_state, z_hmc):
             return self.inner_kernel._potential_fn_gen(
-                *model_args, _gibbs_sites=z_gibbs, **model_kwargs)(z_hmc)
+                *model_args, _gibbs_sites=z_gibbs, _gibbs_state=gibbs_state, **model_kwargs)(z_hmc)
 
         z_gibbs = {k: v for k, v in state.z.items() if k not in state.hmc_state.z}
-        z_hmc = {k: v for k, v in state.z.items() if k in state.hmc_state.z}
-        model_kwargs_ = model_kwargs.copy()
-        model_kwargs_["_gibbs_sites"] = z_gibbs
+        z_gibbs_new, gibbs_state_new = self._gibbs_update(rng_key, z_gibbs, state.gibbs_state)
 
-        gibbs_fn = _subsample_gibbs_fn(potential_fn, self._plate_sizes, self._num_blocks)
-        z_gibbs, pe = gibbs_fn(rng_key=rng_gibbs, gibbs_sites=z_gibbs, hmc_sites=z_hmc,
-                               pe=state.hmc_state.potential_energy)
+        # given fixed hmc_sites, pe_curr - pe_new = loglik_new - loglik_curr
+        pe = state.hmc_state.potential_energy
+        pe_new = potential_fn(z_gibbs_new, gibbs_state_new, state.hmc_state.z)
+        accept_prob = jnp.clip(jnp.exp(pe - pe_new), a_max=1.0)
+        transition = random.bernoulli(rng_key, accept_prob)
+        grad_ = jacfwd if self.inner_kernel._forward_mode_differentiation else grad
+        z_gibbs, gibbs_state, pe, z_grad = cond(transition,
+                                                (z_gibbs_new, gibbs_state_new, pe_new),
+                                                lambda vals: vals + (grad_(partial(potential_fn, vals[0],
+                                                                                   vals[1]))(state.hmc_state.z),),
+                                                (z_gibbs, state.gibbs_state, pe, state.hmc_state.z_grad),
+                                                identity)
 
-        if self.inner_kernel._forward_mode_differentiation:
-            z_grad = jacfwd(partial(potential_fn, z_gibbs))(state.hmc_state.z)
-        else:
-            z_grad = grad(partial(potential_fn, z_gibbs))(state.hmc_state.z)
         hmc_state = state.hmc_state._replace(z_grad=z_grad, potential_energy=pe)
 
-        model_kwargs_["_gibbs_sites"] = z_gibbs
-        hmc_state = self.inner_kernel.sample(hmc_state, model_args, model_kwargs_)
+        model_kwargs["_gibbs_sites"] = z_gibbs
+        model_kwargs["_gibbs_state"] = gibbs_state
+        hmc_state = self.inner_kernel.sample(hmc_state, model_args, model_kwargs)
 
         z = {**z_gibbs, **hmc_state.z}
-
-        return HMCGibbsState(z, hmc_state, rng_key)
+        return HMCECSState(z, hmc_state, rng_key, gibbs_state, accept_prob)
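The index refresh above is a Metropolis-Hastings move with a symmetric proposal, so the ratio collapses to a difference of potential energies. A standalone sketch (my own toy numbers):

    import jax.numpy as jnp
    from jax import random

    def mh_accept(rng_key, pe_curr, pe_new):
        # accept with probability min(1, exp(pe_curr - pe_new)), i.e. min(1, exp(loglik_new - loglik_curr))
        accept_prob = jnp.clip(jnp.exp(pe_curr - pe_new), a_max=1.0)
        return random.bernoulli(rng_key, accept_prob), accept_prob

    accepted, prob = mh_accept(random.PRNGKey(0), jnp.array(10.0), jnp.array(9.5))
    assert prob == 1.0  # the proposal lowers the energy, so it is always accepted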
 
 
+def perturbed_method(subsample_plate_sizes, proxy_fn):
+    def estimator(likelihoods, params, gibbs_state):
+        subsample_log_liks = defaultdict(float)
+        for (fn, value, name, subsample_dim) in likelihoods.values():
+            subsample_log_liks[name] += _sum_all_except_at_dim(fn.log_prob(value), subsample_dim)
+
+        log_lik_sum = 0.
+
+        proxy_value_all, proxy_value_subsample = proxy_fn(params, subsample_log_liks.keys(), gibbs_state)
+
+        for name, subsample_log_lik in subsample_log_liks.items():  # loop over all subsample sites
+            n, m = subsample_plate_sizes[name]
+
+            diff = subsample_log_lik - proxy_value_subsample[name]
+
+            unbiased_log_lik = proxy_value_all[name] + n * jnp.mean(diff)
+            variance = n ** 2 / m * jnp.var(diff)
+            log_lik_sum += unbiased_log_lik - 0.5 * variance
+        return log_lik_sum
+
+    return estimator
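In the notation of Quiroz et al. (2018), with full-data size n, subsample size m, and per-datum corrections d_i, the estimator above computes (my transcription of the code):

    \hat{\ell}(\theta) = \ell_{\mathrm{proxy}}(\theta) + \frac{n}{m} \sum_{i \in \mathcal{S}} d_i(\theta),
    \qquad d_i(\theta) = \ell_i(\theta) - \ell_{\mathrm{proxy},i}(\theta),
    \qquad \hat{\sigma}^2(\theta) = \frac{n^2}{m} \, \widehat{\mathrm{Var}}\,[d_i(\theta)],

and the returned log likelihood is the bias-corrected \hat{\ell}(\theta) - \tfrac{1}{2}\hat{\sigma}^2(\theta).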
+
+
+def taylor_proxy(reference_params):
+    """Control variate for unbiased log likelihood estimation using a Taylor expansion around a reference
+    parameter. Suggested for subsampling in [1].
+
+    :param dict reference_params: Model parameterization at the MLE or MAP estimate.
+
+    **References:**
+
+    [1] *Towards scaling up Markov chain Monte Carlo: an adaptive subsampling approach*,
+        Bardenet, R., Doucet, A., Holmes, C. (2014)
+    """
+
+    def construct_proxy_fn(prototype_trace, subsample_plate_sizes, model, model_args, model_kwargs, num_blocks=1):
+        ref_params = {name: biject_to(prototype_trace[name]["fn"].support).inv(value)
+                      for name, value in reference_params.items()}
+
+        ref_params_flat, unravel_fn = ravel_pytree(ref_params)
+
+        def log_likelihood(params_flat, subsample_indices=None):
+            if subsample_indices is None:
+                subsample_indices = {k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()}
+            params = unravel_fn(params_flat)
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                params = {name: biject_to(prototype_trace[name]["fn"].support)(value)
+                          for name, value in params.items()}
+                with block(), trace() as tr, substitute(data=subsample_indices), substitute(data=params):
+                    model(*model_args, **model_kwargs)
+
+            log_lik = {}
+            for site in tr.values():
+                if site["type"] == "sample" and site["is_observed"]:
+                    for frame in site["cond_indep_stack"]:
+                        if frame.name in log_lik:
+                            log_lik[frame.name] += _sum_all_except_at_dim(
+                                site["fn"].log_prob(site["value"]), frame.dim)
+                        else:
+                            log_lik[frame.name] = _sum_all_except_at_dim(
+                                site["fn"].log_prob(site["value"]), frame.dim)
+            return log_lik
+
+        def log_likelihood_sum(params_flat, subsample_indices=None):
+            return {k: v.sum() for k, v in log_likelihood(params_flat, subsample_indices).items()}
+
+        # these stats are dicts keyed by subsample names
+        ref_log_likelihoods_sum = log_likelihood_sum(ref_params_flat)
+        ref_log_likelihood_grads_sum = jacobian(log_likelihood_sum)(ref_params_flat)
+        ref_log_likelihood_hessians_sum = hessian(log_likelihood_sum)(ref_params_flat)
+
+        def gibbs_init(rng_key, gibbs_sites):
+            ref_subsample_log_liks = log_likelihood(ref_params_flat, gibbs_sites)
+            ref_subsample_log_lik_grads = jacfwd(log_likelihood)(ref_params_flat, gibbs_sites)
+            ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(ref_params_flat, gibbs_sites)
+            return TaylorProxyState(ref_subsample_log_liks, ref_subsample_log_lik_grads,
+                                    ref_subsample_log_lik_hessians)
+
+        def gibbs_update(rng_key, gibbs_sites, gibbs_state):
+            u_new, pads, new_idxs, starts = _block_update_proxy(num_blocks, rng_key, gibbs_sites,
+                                                                subsample_plate_sizes)
+
+            new_states = defaultdict(dict)
+            ref_subsample_log_liks = log_likelihood(ref_params_flat, new_idxs)
+            ref_subsample_log_lik_grads = jacfwd(log_likelihood)(ref_params_flat, new_idxs)
+            ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(ref_params_flat, new_idxs)
+            for stat, new_block_values, last_values in zip(
+                    ["log_liks", "grads", "hessians"],
+                    [ref_subsample_log_liks,
+                     ref_subsample_log_lik_grads,
+                     ref_subsample_log_lik_hessians],
+                    [gibbs_state.ref_subsample_log_liks,
+                     gibbs_state.ref_subsample_log_lik_grads,
+                     gibbs_state.ref_subsample_log_lik_hessians]):
+                for name, subsample_idx in gibbs_sites.items():
+                    size, subsample_size = subsample_plate_sizes[name]
+                    pad, start = pads[name], starts[name]
+                    new_value = jnp.pad(last_values[name],
+                                        [(0, pad)] + [(0, 0)] * (jnp.ndim(last_values[name]) - 1))
+                    new_value = lax.dynamic_update_slice_in_dim(
+                        new_value, new_block_values[name], start, 0)
+                    new_states[stat][name] = new_value[:subsample_size]
+            gibbs_state = TaylorProxyState(new_states["log_liks"], new_states["grads"], new_states["hessians"])
+            return u_new, gibbs_state
+
+        def proxy_fn(params, subsample_lik_sites, gibbs_state):
+            params_flat, _ = ravel_pytree(params)
+            params_diff = params_flat - ref_params_flat
+
+            ref_subsample_log_liks = gibbs_state.ref_subsample_log_liks
+            ref_subsample_log_lik_grads = gibbs_state.ref_subsample_log_lik_grads
+            ref_subsample_log_lik_hessians = gibbs_state.ref_subsample_log_lik_hessians
+
+            proxy_sum = defaultdict(float)
+            proxy_subsample = defaultdict(float)
+            for name in subsample_lik_sites:
+                proxy_subsample[name] = (ref_subsample_log_liks[name] +
+                                         jnp.dot(ref_subsample_log_lik_grads[name], params_diff) +
+                                         0.5 * jnp.dot(jnp.dot(ref_subsample_log_lik_hessians[name], params_diff),
                                                       params_diff))
+
+                proxy_sum[name] = (ref_log_likelihoods_sum[name] +
+                                   jnp.dot(ref_log_likelihood_grads_sum[name], params_diff) +
+                                   0.5 * jnp.dot(jnp.dot(ref_log_likelihood_hessians_sum[name], params_diff),
+                                                 params_diff))
+            return proxy_sum, proxy_subsample
+
+        return proxy_fn, gibbs_init, gibbs_update
+
+    return construct_proxy_fn
+
+
+def _sum_all_except_at_dim(x, dim):
+    x = x.reshape((-1,) + x.shape[dim:]).sum(0)
+    return x.reshape(x.shape[:1] + (-1,)).sum(-1)
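The per-datum proxy in proxy_fn is just a second-order Taylor expansion around ref_params_flat; a self-contained sketch with made-up reference statistics:

    import jax.numpy as jnp

    ref_log_lik = jnp.array(1.3)                       # log likelihood at the reference point
    grad_ref = jnp.array([0.2, -0.1])                  # its gradient
    hess_ref = jnp.array([[-1.0, 0.0], [0.0, -2.0]])   # its Hessian

    params_diff = jnp.array([0.05, -0.02])             # current minus reference parameters
    proxy = ref_log_lik + jnp.dot(grad_ref, params_diff) \
        + 0.5 * jnp.dot(jnp.dot(hess_ref, params_diff), params_diff)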
+
+
+class estimate_likelihood(numpyro.primitives.Messenger):
+    def __init__(self, fn=None, method=None):
+        # estimate_likelihood accepts likelihood tuples (fn, value, subsample_name, subsample_dim)
+        # and the current unconstrained params, and returns the log of the bias-corrected likelihood
+        assert method is not None
+        super().__init__(fn)
+        self.method = method
+        self.params = None
+        self.likelihoods = {}
+        self.subsample_plates = {}
+        self.gibbs_state = None
+
+    def __enter__(self):
+        for handler in numpyro.primitives._PYRO_STACK[::-1]:
+            # the potential_fn in HMC nests the PYRO_STACK like trace(...), so we can extract
+            # the unconstrained params from the _unconstrain_reparam substitute_fn
+            if isinstance(handler, substitute) and isinstance(handler.substitute_fn, partial) \
+                    and handler.substitute_fn.func is _unconstrain_reparam:
+                self.params = handler.substitute_fn.args[0]
+                break
+        return super().__enter__()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # make sure the exit traceback is nice if an error happens
+        super().__exit__(exc_type, exc_value, traceback)
+        if exc_type is not None:
+            return
+
+        if self.params is None:
+            return
+
+        # add a numpyro.factor; ideally, we would skip this computation when making predictions
+        # see: https://github.com/pyro-ppl/pyro/issues/2744
+        numpyro.factor("_biased_corrected_log_likelihood",
+                       self.method(self.likelihoods, self.params, self.gibbs_state))
+
+        # clean up
+        self.params = None
+        self.likelihoods = {}
+        self.subsample_plates = {}
+        self.gibbs_state = None
+
+    def process_message(self, msg):
+        if self.params is None:
+            return
+
+        if msg["type"] == "_gibbs_state":
+            self.gibbs_state = msg["value"]
+            return
+
+        if msg["type"] == "sample" and msg["is_observed"]:
+            assert msg["name"] not in self.params
+            # store the likelihood for the estimator
+            for frame in msg["cond_indep_stack"]:
+                if frame.name in self.subsample_plates:
+                    if msg["name"] in self.likelihoods:
+                        raise RuntimeError(f"Multiple subsample plates at site {msg['name']} "
+                                           "are not allowed. Please reshape your data.")
+                    self.likelihoods[msg["name"]] = (msg["fn"], msg["value"], frame.name, frame.dim)
+                    # mask the current likelihood
+                    msg["fn"] = msg["fn"].mask(False)
+        elif msg["type"] == "plate" and msg["args"][0] > msg["args"][1]:
+            self.subsample_plates[msg["name"]] = msg["value"]
diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py
index d71c50a5c..8b4227819 100644
--- a/numpyro/infer/util.py
+++ b/numpyro/infer/util.py
@@ -432,7 +432,7 @@ def initialize_model(rng_key, model,
     # substitute param sites from model_trace to model so
     # we don't need to generate again parameters of `numpyro.module`
     model = substitute(model, data={k: site["value"] for k, site in model_trace.items()
-                                    if site["type"] in ["param", "plate"]})
+                                    if site["type"] in ["param"]})
     constrained_values = {k: v['value'] for k, v in model_trace.items()
                           if v['type'] == 'sample' and not v['is_observed']
                           and not v['fn'].is_discrete}
@@ -460,7 +460,8 @@ def initialize_model(rng_key, model,
             init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
     prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
     (init_params, pe, grad), is_valid = find_valid_initial_params(
-        rng_key, model,
+        rng_key, substitute(model, data={k: site["value"] for k, site in model_trace.items()
+                                         if site["type"] in ["plate"]}),
         init_strategy=init_strategy,
         enum=has_enumerate_support,
         model_args=model_args,
@@ -482,7 +483,7 @@ def initialize_model(rng_key, model,
             for w in ws:
                 # at site information to the warning message
                 w.message.args = ("Site {}: {}".format(site["name"], w.message.args[0]),) \
-                    + w.message.args[1:]
+                                 + w.message.args[1:]
                 warnings.showwarning(w.message, w.category, w.filename, w.lineno,
                                      file=w.file, line=w.line)
     raise RuntimeError("Cannot find valid initial parameters. Please check your model again.")
@@ -491,7 +492,6 @@ def initialize_model(rng_key, model,
 def _predictive(rng_key, model, posterior_samples, batch_shape, return_sites=None,
                 parallel=True, model_args=(), model_kwargs={}):
-
     def single_prediction(val):
         rng_key, samples = val
         model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace(
""" + def __init__(self, name, size, subsample_size=None, dim=None): self.name = name self.size = size @@ -304,10 +304,11 @@ def _subsample(name, size, subsample_size, dim): } apply_stack(msg) subsample = msg['value'] + subsample_size = msg['args'][1] if subsample_size is not None and subsample_size != subsample.shape[0]: warnings.warn("subsample_size does not match len(subsample), {} vs {}.".format( subsample_size, len(subsample)) + - " Did you accidentally use different subsample_size in the model and guide?") + " Did you accidentally use different subsample_size in the model and guide?") cond_indep_stack = msg['cond_indep_stack'] occupied_dims = {f.dim for f in cond_indep_stack} if dim is None: @@ -355,7 +356,7 @@ def process_message(self, msg): msg['fn'] = msg['fn'].expand(batch_shape) if self.size != self.subsample_size: scale = 1. if msg['scale'] is None else msg['scale'] - msg['scale'] = scale * self.size / self.subsample_size + msg['scale'] = scale * (self.size / self.subsample_size if self.subsample_size else 1) def postprocess_message(self, msg): if msg["type"] in ("subsample", "param") and self.dim is not None: diff --git a/setup.py b/setup.py index bd20fd1b0..dcad69e77 100644 --- a/setup.py +++ b/setup.py @@ -33,9 +33,9 @@ author='Uber AI Labs', install_requires=[ # TODO: pin to a specific version for the release (until JAX's API becomes stable) - 'jax==0.2.8', + 'jax>=0.2.8', # check min version here: https://github.com/google/jax/blob/master/jax/lib/__init__.py#L26 - 'jaxlib==0.1.59', + 'jaxlib>=0.1.59', 'tqdm', ], extras_require={ diff --git a/test/test_hmc_gibbs.py b/test/test_hmc_gibbs.py index 9b0737599..e15eb47e0 100644 --- a/test/test_hmc_gibbs.py +++ b/test/test_hmc_gibbs.py @@ -1,20 +1,20 @@ # Copyright Contributors to the Pyro project. 
diff --git a/test/test_hmc_gibbs.py b/test/test_hmc_gibbs.py
index 9b0737599..e15eb47e0 100644
--- a/test/test_hmc_gibbs.py
+++ b/test/test_hmc_gibbs.py
@@ -1,20 +1,20 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
-
 from functools import partial
 
 import numpy as np
 from numpy.testing import assert_allclose
 import pytest
 
-from jax import random
+from jax import hessian, jacrev, random, vmap
 import jax.numpy as jnp
 from jax.scipy.linalg import cho_factor, cho_solve, inv, solve_triangular
 
 import numpyro
 import numpyro.distributions as dist
-from numpyro.handlers import plate
 from numpyro.infer import HMC, HMCECS, MCMC, NUTS, DiscreteHMCGibbs, HMCGibbs
+from numpyro.infer.hmc_gibbs import taylor_proxy
+from numpyro.infer.util import log_density
 
 
 def _linear_regression_gibbs_fn(X, XX, XY, Y, rng_key, gibbs_sites, hmc_sites):
@@ -206,28 +206,25 @@ def model(probs, locs):
     assert_allclose(jnp.var(samples["c"]), 1.03, atol=0.1)
 
 
-@pytest.mark.parametrize('kernel_cls', [HMC, NUTS])
 @pytest.mark.parametrize('num_blocks', [1, 2, 50, 100])
-def test_subsample_gibbs_partitioning(kernel_cls, num_blocks):
-    def model(obs):
-        with plate('N', obs.shape[0], subsample_size=100) as idx:
-            numpyro.sample('x', dist.Normal(0, 1), obs=obs[idx])
-
-    obs = random.normal(random.PRNGKey(0), (10000,)) / 100
-    kernel = HMCECS(kernel_cls(model), num_blocks=num_blocks)
-    state = kernel.init(random.PRNGKey(1), 10, None, model_args=(obs,), model_kwargs=None)
-    gibbs_sites = {'N': jnp.arange(100)}
-
-    def potential_fn(z_gibbs, z_hmc):
-        return kernel.inner_kernel._potential_fn_gen(obs, _gibbs_sites=z_gibbs)(z_hmc)
-
-    gibbs_fn = numpyro.infer.hmc_gibbs._subsample_gibbs_fn(potential_fn, kernel._plate_sizes, num_blocks)
-    new_gibbs_sites, _ = gibbs_fn(random.PRNGKey(2), gibbs_sites, state.hmc_state.z,
-                                  state.hmc_state.potential_energy)  # accept_prob > .999
+def test_block_update_partitioning(num_blocks):
+    plate_size = 10000, 100
+
+    plate_sizes = {'N': plate_size}
+    gibbs_sites = {'N': jnp.arange(plate_size[1])}
+    gibbs_state = {}
+
+    new_gibbs_sites, new_gibbs_state = numpyro.infer.hmc_gibbs._block_update(plate_sizes,
+                                                                             num_blocks,
+                                                                             random.PRNGKey(2),
+                                                                             gibbs_sites,
+                                                                             gibbs_state)
     block_size = 100 // num_blocks
     for name in gibbs_sites:
         assert block_size == jnp.not_equal(gibbs_sites[name], new_gibbs_sites[name]).sum()
 
+    assert gibbs_state == new_gibbs_state
+
 
 def test_enum_subsample_smoke():
     def model(data):
@@ -240,3 +237,104 @@ def model(data):
     kernel = HMCECS(NUTS(model), num_blocks=10)
     mcmc = MCMC(kernel, 10, 10)
     mcmc.run(random.PRNGKey(0), data)
+
+
+@pytest.mark.parametrize('kernel_cls', [HMC, NUTS])
+@pytest.mark.parametrize('num_block', [1, 2, 50])
+@pytest.mark.parametrize('subsample_size', [50, 150])
+def test_hmcecs_normal_normal(kernel_cls, num_block, subsample_size):
+    true_loc = jnp.array([0.3, 0.1, 0.9])
+    num_warmup, num_samples = 200, 200
+    data = true_loc + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(random.PRNGKey(1), (10000,))
+
+    def model(data, subsample_size):
+        mean = numpyro.sample('mean', dist.Normal().expand((3,)).to_event(1))
+        with numpyro.plate('batch', data.shape[0], dim=-2, subsample_size=subsample_size):
+            sub_data = numpyro.subsample(data, 0)
+            numpyro.sample("obs", dist.Normal(mean, 1), obs=sub_data)
+
+    ref_params = {'mean': true_loc + dist.Normal(true_loc, 5e-2).sample(random.PRNGKey(0))}
+    proxy_fn = taylor_proxy(ref_params)
+
+    kernel = HMCECS(kernel_cls(model), proxy=proxy_fn)
+    mcmc = MCMC(kernel, num_warmup, num_samples)
+    mcmc.run(random.PRNGKey(0), data, subsample_size)
+
+    samples = mcmc.get_samples()
+    assert_allclose(np.mean(samples['mean'], axis=0), true_loc, atol=0.1)
+    assert len(samples['mean']) == num_samples
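As an aside on the model above: inside a plate, numpyro.subsample slices data along the plate dimension, which for this layout is the same as indexing with the plate's indices. A minimal sketch (my own shapes):

    import jax.numpy as jnp
    from jax import random
    import numpyro
    from numpyro.handlers import seed

    data = random.normal(random.PRNGKey(0), (1000, 3))

    def model():
        with numpyro.plate("batch", data.shape[0], dim=-2, subsample_size=100) as idx:
            sub = numpyro.subsample(data, event_dim=0)
            assert sub.shape == (100, 3)
            assert jnp.array_equal(sub, data[idx])

    seed(model, random.PRNGKey(1))()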
+
+
+@pytest.mark.parametrize('subsample_size', [5, 10, 15])
+def test_taylor_proxy_norm(subsample_size):
+    data_key, tr_key, rng_key = random.split(random.PRNGKey(0), 3)
+    ref_params = jnp.array([0.1, 0.5, -0.2])
+    sigma = .1
+
+    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(data_key, (100,))
+    n, _ = data.shape
+
+    def model(data, subsample_size):
+        mean = numpyro.sample('mean', dist.Normal(ref_params, jnp.ones_like(ref_params)))
+        with numpyro.plate('data', data.shape[0], subsample_size=subsample_size, dim=-2) as idx:
+            numpyro.sample('obs', dist.Normal(mean, sigma), obs=data[idx])
+
+    def log_prob_fn(params):
+        return vmap(dist.Normal(params, sigma).log_prob)(data).sum(-1)
+
+    log_prob = log_prob_fn(ref_params)
+    log_norm_jac = jacrev(log_prob_fn)(ref_params)
+    log_norm_hessian = hessian(log_prob_fn)(ref_params)
+
+    tr = numpyro.handlers.trace(numpyro.handlers.seed(model, tr_key)).get_trace(data, subsample_size)
+    plate_sizes = {'data': (n, subsample_size)}
+
+    proxy_constructor = taylor_proxy({'mean': ref_params})
+    proxy_fn, gibbs_init, gibbs_update = proxy_constructor(tr, plate_sizes, model, (data, subsample_size), {})
+
+    def taylor_expand_2nd_order(idx, pos):
+        return log_prob[idx] + (log_norm_jac[idx] @ pos) + .5 * (pos @ log_norm_hessian[idx]) @ pos
+
+    def taylor_expand_2nd_order_sum(pos):
+        return log_prob.sum() + log_norm_jac.sum(0) @ pos + .5 * pos @ log_norm_hessian.sum(0) @ pos
+
+    for _ in range(5):
+        split_key, perturb_key, rng_key = random.split(rng_key, 3)
+        perturb_params = ref_params + dist.Normal(.1, 0.1).sample(perturb_key, ref_params.shape)
+        subsample_idx = random.randint(rng_key, (subsample_size,), 0, n)
+        gibbs_site = {'data': subsample_idx}
+        proxy_state = gibbs_init(None, gibbs_site)
+        actual_proxy_sum, actual_proxy_sub = proxy_fn({'mean': perturb_params}, ['data'], proxy_state)
+        assert_allclose(actual_proxy_sub['data'],
+                        taylor_expand_2nd_order(subsample_idx, perturb_params - ref_params), rtol=1e-5)
+        assert_allclose(actual_proxy_sum['data'],
+                        taylor_expand_2nd_order_sum(perturb_params - ref_params), rtol=1e-5)
+
+
+@pytest.mark.filterwarnings('ignore::UserWarning')
+@pytest.mark.parametrize('kernel_cls', [HMC, NUTS])
+def test_estimate_likelihood(kernel_cls):
+    data_key, tr_key, sub_key, rng_key = random.split(random.PRNGKey(0), 4)
+    ref_params = jnp.array([0.1, 0.5, -0.2])
+    sigma = .1
+    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(data_key, (10_000,))
+    n, _ = data.shape
+    num_warmup = 200
+    num_samples = 200
+    num_blocks = 20
+
+    def model(data):
+        mean = numpyro.sample('mean', dist.Normal(ref_params, jnp.ones_like(ref_params)))
+        with numpyro.plate('N', data.shape[0], subsample_size=100, dim=-2) as idx:
+            numpyro.sample('obs', dist.Normal(mean, sigma), obs=data[idx])
+
+    proxy_fn = taylor_proxy({'mean': ref_params})
+    kernel = HMCECS(kernel_cls(model), proxy=proxy_fn, num_blocks=num_blocks)
+    mcmc = MCMC(kernel, num_warmup, num_samples)
+
+    mcmc.run(random.PRNGKey(0), data, extra_fields=['hmc_state.potential_energy'])
+
+    pes = mcmc.get_extra_fields()['hmc_state.potential_energy']
+    samples = mcmc.get_samples()
+    pes_full = vmap(lambda sample: log_density(model, (data,), {},
+                                               {**sample, **{'N': jnp.arange(n)}})[0])(samples)
+
+    assert jnp.var(jnp.exp(-pes - pes_full)) < 1.
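Putting the pieces together, an end-to-end sketch (my own toy setup, mirroring the tests above) of energy-conserving subsampling with a Taylor proxy built around a crude reference estimate:

    import jax.numpy as jnp
    from jax import random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import HMCECS, MCMC, NUTS
    from numpyro.infer.hmc_gibbs import taylor_proxy

    data = 1. + random.normal(random.PRNGKey(0), (10_000,))

    def model(data):
        mean = numpyro.sample("mean", dist.Normal(0., 1.))
        with numpyro.plate("N", data.shape[0], subsample_size=100) as idx:
            numpyro.sample("obs", dist.Normal(mean, 1.), obs=data[idx])

    # the empirical mean is a cheap stand-in for a MAP estimate here
    kernel = HMCECS(NUTS(model), num_blocks=10, proxy=taylor_proxy({"mean": jnp.mean(data)}))
    mcmc = MCMC(kernel, 500, 500)
    mcmc.run(random.PRNGKey(1), data)
    assert abs(jnp.mean(mcmc.get_samples()["mean"]) - 1.) < 0.1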