From 4321595cda9a343c5556656a6092b1de6ed26331 Mon Sep 17 00:00:00 2001 From: Lys Date: Tue, 8 Sep 2020 13:32:28 +0200 Subject: [PATCH 01/93] start --- examples/logistic_hmcecs.py | 0 numpyro/contrib/hmcecs.py | 571 ++++++++++++++++++++++++++++++++++++ 2 files changed, 571 insertions(+) create mode 100644 examples/logistic_hmcecs.py create mode 100644 numpyro/contrib/hmcecs.py diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py new file mode 100644 index 000000000..e69de29bb diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py new file mode 100644 index 000000000..0a3f89b29 --- /dev/null +++ b/numpyro/contrib/hmcecs.py @@ -0,0 +1,571 @@ +"""Contributed code for HMC and NUTS energy conserving""" + +from collections import namedtuple +import math +import os +import warnings + +from jax import device_put, lax, partial, random, vmap +from jax.dtypes import canonicalize_dtype +from jax.flatten_util import ravel_pytree +import jax.numpy as jnp + +from numpyro.infer.hmc_util import ( + IntegratorState, + build_tree, + euclidean_kinetic_energy, + find_reasonable_step_size, + velocity_verlet, + warmup_adapter +) +from numpyro.infer.mcmc import MCMCKernel +from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model +from numpyro.util import cond, fori_loop, identity + +HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', + 'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key']) +""" +A :func:`~collections.namedtuple` consisting of the following fields: + + - **i** - iteration. This is reset to 0 after warmup. + - **z** - Python collection representing values (unconstrained samples from + the posterior) at latent sites. + - **z_grad** - Gradient of potential energy w.r.t. latent sample sites. + - **potential_energy** - Potential energy computed at the given value of ``z``. + - **energy** - Sum of potential energy and kinetic energy of the current state. + - **num_steps** - Number of steps in the Hamiltonian trajectory (for diagnostics). + - **accept_prob** - Acceptance probability of the proposal. Note that ``z`` + does not correspond to the proposal if it is rejected. + - **mean_accept_prob** - Mean acceptance probability until current iteration + during warmup adaptation or sampling (for diagnostics). + - **diverging** - A boolean value to indicate whether the current trajectory is diverging. + - **adapt_state** - A ``HMCAdaptState`` namedtuple which contains adaptation information + during warmup: + + + **step_size** - Step size to be used by the integrator in the next iteration. + + **inverse_mass_matrix** - The inverse mass matrix to be used for the next + iteration. + + **mass_matrix_sqrt** - The square root of mass matrix to be used for the next + iteration. In case of dense mass, this is the Cholesky factorization of the + mass matrix. + + - **rng_key** - random number generator seed used for the iteration. +""" + + +def _get_num_steps(step_size, trajectory_length): + num_steps = jnp.clip(trajectory_length / step_size, a_min=1) + # NB: casting to jnp.int64 does not take effect (returns jnp.int32 instead) + # if jax_enable_x64 is False + return num_steps.astype(canonicalize_dtype(jnp.int64)) + + +def momentum_generator(prototype_r, mass_matrix_sqrt, rng_key): + _, unpack_fn = ravel_pytree(prototype_r) + eps = random.normal(rng_key, jnp.shape(mass_matrix_sqrt)[:1]) + if mass_matrix_sqrt.ndim == 1: + r = jnp.multiply(mass_matrix_sqrt, eps) + return unpack_fn(r) + elif mass_matrix_sqrt.ndim == 2: + r = jnp.dot(mass_matrix_sqrt, eps) + return unpack_fn(r) + else: + raise ValueError("Mass matrix has incorrect number of dims.") + + +def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'): + r""" + Hamiltonian Monte Carlo inference, using either fixed number of + steps or the No U-Turn Sampler (NUTS) with adaptive path length. + + **References:** + + 1. *MCMC Using Hamiltonian Dynamics*, + Radford M. Neal + 2. *The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo*, + Matthew D. Hoffman, and Andrew Gelman. + 3. *A Conceptual Introduction to Hamiltonian Monte Carlo`*, + Michael Betancourt + + :param potential_fn: Python callable that computes the potential energy + given input parameters. The input parameters to `potential_fn` can be + any python collection type, provided that `init_params` argument to + `init_kernel` has the same type. + :param potential_fn_gen: Python callable that when provided with model + arguments / keyword arguments returns `potential_fn`. This + may be provided to do inference on the same model with changing data. + If the data shape remains the same, we can compile `sample_kernel` + once, and use the same for multiple inference runs. + :param kinetic_fn: Python callable that returns the kinetic energy given + inverse mass matrix and momentum. If not provided, the default is + euclidean kinetic energy. + :param str algo: Whether to run ``HMC`` with fixed number of steps or ``NUTS`` + with adaptive path length. Default is ``NUTS``. + :return: a tuple of callables (`init_kernel`, `sample_kernel`), the first + one to initialize the sampler, and the second one to generate samples + given an existing one. + + .. warning:: + Instead of using this interface directly, we would highly recommend you + to use the higher level :class:`numpyro.infer.MCMC` API instead. + + **Example** + + .. doctest:: + + >>> import jax + >>> from jax import random + >>> import jax.numpy as jnp + >>> import numpyro + >>> import numpyro.distributions as dist + >>> from numpyro.infer.hmc import hmc + >>> from numpyro.infer.util import initialize_model + >>> from numpyro.util import fori_collect + + >>> true_coefs = jnp.array([1., 2., 3.]) + >>> data = random.normal(random.PRNGKey(2), (2000, 3)) + >>> dim = 3 + >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3)) + >>> + >>> def model(data, labels): + ... coefs_mean = jnp.zeros(dim) + ... coefs = numpyro.sample('beta', dist.Normal(coefs_mean, jnp.ones(3))) + ... intercept = numpyro.sample('intercept', dist.Normal(0., 10.)) + ... return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels) + >>> + >>> model_info = initialize_model(random.PRNGKey(0), model, model_args=(data, labels,)) + >>> init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS') + >>> hmc_state = init_kernel(model_info.param_info, + ... trajectory_length=10, + ... num_warmup=300) + >>> samples = fori_collect(0, 500, sample_kernel, hmc_state, + ... transform=lambda state: model_info.postprocess_fn(state.z)) + >>> print(jnp.mean(samples['beta'], axis=0)) # doctest: +SKIP + [0.9153987 2.0754058 2.9621222] + """ + if kinetic_fn is None: + kinetic_fn = euclidean_kinetic_energy + vv_update = None + trajectory_len = None + max_treedepth = None + wa_update = None + wa_steps = None + max_delta_energy = 1000. + if algo not in {'HMC', 'NUTS'}: + raise ValueError('`algo` must be one of `HMC` or `NUTS`.') + + def init_kernel(init_params, + num_warmup, + step_size=1.0, + inverse_mass_matrix=None, + adapt_step_size=True, + adapt_mass_matrix=True, + dense_mass=False, + target_accept_prob=0.8, + trajectory_length=2*math.pi, + max_tree_depth=10, + find_heuristic_step_size=False, + model_args=(), + model_kwargs=None, + rng_key=random.PRNGKey(0)): + """ + Initializes the HMC sampler. + + :param init_params: Initial parameters to begin sampling. The type must + be consistent with the input type to `potential_fn`. + :param int num_warmup: Number of warmup steps; samples generated + during warmup are discarded. + :param float step_size: Determines the size of a single step taken by the + verlet integrator while computing the trajectory using Hamiltonian + dynamics. If not specified, it will be set to 1. + :param numpy.ndarray inverse_mass_matrix: Initial value for inverse mass matrix. + This may be adapted during warmup if adapt_mass_matrix = True. + If no value is specified, then it is initialized to the identity matrix. + :param bool adapt_step_size: A flag to decide if we want to adapt step_size + during warm-up phase using Dual Averaging scheme. + :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass + matrix during warm-up phase using Welford scheme. + :param bool dense_mass: A flag to decide if mass matrix is dense or + diagonal (default when ``dense_mass=False``) + :param float target_accept_prob: Target acceptance probability for step size + adaptation using Dual Averaging. Increasing this value will lead to a smaller + step size, hence the sampling will be slower but more robust. Default to 0.8. + :param float trajectory_length: Length of a MCMC trajectory for HMC. Default + value is :math:`2\\pi`. + :param int max_tree_depth: Max depth of the binary tree created during the doubling + scheme of NUTS sampler. Defaults to 10. + :param bool find_heuristic_step_size: whether to a heuristic function to adjust the + step size at the beginning of each adaptation window. Defaults to False. + :param tuple model_args: Model arguments if `potential_fn_gen` is specified. + :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified. + :param jax.random.PRNGKey rng_key: random key to be used as the source of + randomness. + + """ + step_size = lax.convert_element_type(step_size, canonicalize_dtype(jnp.float64)) + nonlocal wa_update, trajectory_len, max_treedepth, vv_update, wa_steps + wa_steps = num_warmup + trajectory_len = trajectory_length + max_treedepth = max_tree_depth + if isinstance(init_params, ParamInfo): + z, pe, z_grad = init_params + else: + z, pe, z_grad = init_params, None, None + pe_fn = potential_fn + if potential_fn_gen: + if pe_fn is not None: + raise ValueError('Only one of `potential_fn` or `potential_fn_gen` must be provided.') + else: + kwargs = {} if model_kwargs is None else model_kwargs + pe_fn = potential_fn_gen(*model_args, **kwargs) + + find_reasonable_ss = None + if find_heuristic_step_size: + find_reasonable_ss = partial(find_reasonable_step_size, + pe_fn, + kinetic_fn, + momentum_generator) + + wa_init, wa_update = warmup_adapter(num_warmup, + adapt_step_size=adapt_step_size, + adapt_mass_matrix=adapt_mass_matrix, + dense_mass=dense_mass, + target_accept_prob=target_accept_prob, + find_reasonable_step_size=find_reasonable_ss) + + rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3) + z_info = IntegratorState(z=z, potential_energy=pe, z_grad=z_grad) + wa_state = wa_init(z_info, rng_key_wa, step_size, + inverse_mass_matrix=inverse_mass_matrix, + mass_matrix_size=jnp.size(ravel_pytree(z)[0])) + r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum) + vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn) + vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad) + energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) + hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, + 0, 0., 0., False, wa_state, rng_key_hmc) + return device_put(hmc_state) + + def _hmc_next(step_size, inverse_mass_matrix, vv_state, + model_args, model_kwargs, rng_key): + if potential_fn_gen: + nonlocal vv_update + pe_fn = potential_fn_gen(*model_args, **model_kwargs) + _, vv_update = velocity_verlet(pe_fn, kinetic_fn) + + num_steps = _get_num_steps(step_size, trajectory_len) + vv_state_new = fori_loop(0, num_steps, + lambda i, val: vv_update(step_size, inverse_mass_matrix, val), + vv_state) + energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r) + energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r) + delta_energy = energy_new - energy_old + delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy) + accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0) + diverging = delta_energy > max_delta_energy + transition = random.bernoulli(rng_key, accept_prob) + vv_state, energy = cond(transition, + (vv_state_new, energy_new), identity, + (vv_state, energy_old), identity) + return vv_state, energy, num_steps, accept_prob, diverging + + def _nuts_next(step_size, inverse_mass_matrix, vv_state, + model_args, model_kwargs, rng_key): + if potential_fn_gen: + nonlocal vv_update + pe_fn = potential_fn_gen(*model_args, **model_kwargs) + _, vv_update = velocity_verlet(pe_fn, kinetic_fn) + + binary_tree = build_tree(vv_update, kinetic_fn, vv_state, + inverse_mass_matrix, step_size, rng_key, + max_delta_energy=max_delta_energy, + max_tree_depth=max_treedepth) + accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals + num_steps = binary_tree.num_proposals + vv_state = IntegratorState(z=binary_tree.z_proposal, + r=vv_state.r, + potential_energy=binary_tree.z_proposal_pe, + z_grad=binary_tree.z_proposal_grad) + return vv_state, binary_tree.z_proposal_energy, num_steps, accept_prob, binary_tree.diverging + + _next = _nuts_next if algo == 'NUTS' else _hmc_next + + def sample_kernel(hmc_state, model_args=(), model_kwargs=None): + """ + Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed (possibly adapted) + step size and return a new :data:`~numpyro.infer.mcmc.HMCState`. + + :param hmc_state: Current sample (and associated state). + :param tuple model_args: Model arguments if `potential_fn_gen` is specified. + :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified. + :return: new proposed :data:`~numpyro.infer.mcmc.HMCState` from simulating + Hamiltonian dynamics given existing state. + + """ + model_kwargs = {} if model_kwargs is None else model_kwargs + rng_key, rng_key_momentum, rng_key_transition = random.split(hmc_state.rng_key, 3) + r = momentum_generator(hmc_state.z, hmc_state.adapt_state.mass_matrix_sqrt, rng_key_momentum) + vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad) + vv_state, energy, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size, + hmc_state.adapt_state.inverse_mass_matrix, + vv_state, + model_args, + model_kwargs, + rng_key_transition) + # not update adapt_state after warmup phase + adapt_state = cond(hmc_state.i < wa_steps, + (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state), + lambda args: wa_update(*args), + hmc_state.adapt_state, + identity) + + itr = hmc_state.i + 1 + n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps) + mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n + + return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps, + accept_prob, mean_accept_prob, diverging, adapt_state, rng_key) + + # Make `init_kernel` and `sample_kernel` visible from the global scope once + # `hmc` is called for sphinx doc generation. + if 'SPHINX_BUILD' in os.environ: + hmc.init_kernel = init_kernel + hmc.sample_kernel = sample_kernel + + return init_kernel, sample_kernel + + +class HMC(MCMCKernel): + """ + Hamiltonian Monte Carlo inference, using fixed trajectory length, with + provision for step size and mass matrix adaptation. + + **References:** + + 1. *MCMC Using Hamiltonian Dynamics*, + Radford M. Neal + + :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. + If model is provided, `potential_fn` will be inferred using the model. + :param potential_fn: Python callable that computes the potential energy + given input parameters. The input parameters to `potential_fn` can be + any python collection type, provided that `init_params` argument to + `init_kernel` has the same type. + :param kinetic_fn: Python callable that returns the kinetic energy given + inverse mass matrix and momentum. If not provided, the default is + euclidean kinetic energy. + :param float step_size: Determines the size of a single step taken by the + verlet integrator while computing the trajectory using Hamiltonian + dynamics. If not specified, it will be set to 1. + :param bool adapt_step_size: A flag to decide if we want to adapt step_size + during warm-up phase using Dual Averaging scheme. + :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass + matrix during warm-up phase using Welford scheme. + :param bool dense_mass: A flag to decide if mass matrix is dense or + diagonal (default when ``dense_mass=False``) + :param float target_accept_prob: Target acceptance probability for step size + adaptation using Dual Averaging. Increasing this value will lead to a smaller + step size, hence the sampling will be slower but more robust. Default to 0.8. + :param float trajectory_length: Length of a MCMC trajectory for HMC. Default + value is :math:`2\\pi`. + :param callable init_strategy: a per-site initialization function. + See :ref:`init_strategy` section for available functions. + :param bool find_heuristic_step_size: whether to a heuristic function to adjust the + step size at the beginning of each adaptation window. Defaults to False. + """ + def __init__(self, + model=None, + potential_fn=None, + kinetic_fn=None, + step_size=1.0, + adapt_step_size=True, + adapt_mass_matrix=True, + dense_mass=False, + target_accept_prob=0.8, + trajectory_length=2 * math.pi, + init_strategy=init_to_uniform, + find_heuristic_step_size=False): + if not (model is None) ^ (potential_fn is None): + raise ValueError('Only one of `model` or `potential_fn` must be specified.') + self._model = model + self._potential_fn = potential_fn + self._kinetic_fn = kinetic_fn if kinetic_fn is not None else euclidean_kinetic_energy + self._step_size = step_size + self._adapt_step_size = adapt_step_size + self._adapt_mass_matrix = adapt_mass_matrix + self._dense_mass = dense_mass + self._target_accept_prob = target_accept_prob + self._trajectory_length = trajectory_length + self._algo = 'HMC' + self._max_tree_depth = 10 + self._init_strategy = init_strategy + self._find_heuristic_step_size = find_heuristic_step_size + # Set on first call to init + self._init_fn = None + self._postprocess_fn = None + self._sample_fn = None + + def _init_state(self, rng_key, model_args, model_kwargs, init_params): + if self._model is not None: + init_params, potential_fn, postprocess_fn, model_trace = initialize_model( + rng_key, + self._model, + dynamic_args=True, + model_args=model_args, + model_kwargs=model_kwargs) + if any(v['type'] == 'param' for v in model_trace.values()): + warnings.warn("'param' sites will be treated as constants during inference. To define " + "an improper variable, please use a 'sample' site with log probability " + "masked out. For example, `sample('x', dist.LogNormal(0, 1).mask(False)` " + "means that `x` has improper distribution over the positive domain.") + if self._init_fn is None: + self._init_fn, self._sample_fn = hmc(potential_fn_gen=potential_fn, + kinetic_fn=self._kinetic_fn, + algo=self._algo) + self._postprocess_fn = postprocess_fn + elif self._init_fn is None: + self._init_fn, self._sample_fn = hmc(potential_fn=self._potential_fn, + kinetic_fn=self._kinetic_fn, + algo=self._algo) + + return init_params + + @property + def model(self): + return self._model + + @property + def sample_field(self): + return 'z' + + @property + def default_fields(self): + return ('z', 'diverging') + + def get_diagnostics_str(self, state): + return '{} steps of size {:.2e}. acc. prob={:.2f}'.format(state.num_steps, + state.adapt_state.step_size, + state.mean_accept_prob) + + def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}): + # non-vectorized + if rng_key.ndim == 1: + rng_key, rng_key_init_model = random.split(rng_key) + # vectorized + else: + rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1) + init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params) + if self._potential_fn and init_params is None: + raise ValueError('Valid value of `init_params` must be provided with' + ' `potential_fn`.') + + hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731 + init_params, + num_warmup=num_warmup, + step_size=self._step_size, + adapt_step_size=self._adapt_step_size, + adapt_mass_matrix=self._adapt_mass_matrix, + dense_mass=self._dense_mass, + target_accept_prob=self._target_accept_prob, + trajectory_length=self._trajectory_length, + max_tree_depth=self._max_tree_depth, + find_heuristic_step_size=self._find_heuristic_step_size, + model_args=model_args, + model_kwargs=model_kwargs, + rng_key=rng_key, + ) + if rng_key.ndim == 1: + init_state = hmc_init_fn(init_params, rng_key) + else: + # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some + # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, + # wa_steps because those variables do not depend on traced args: init_params, rng_key. + init_state = vmap(hmc_init_fn)(init_params, rng_key) + sample_fn = vmap(self._sample_fn, in_axes=(0, None, None)) + self._sample_fn = sample_fn + return init_state + + def postprocess_fn(self, args, kwargs): + if self._postprocess_fn is None: + return identity + return self._postprocess_fn(*args, **kwargs) + + def sample(self, state, model_args, model_kwargs): + """ + Run HMC from the given :data:`~numpyro.infer.hmc.HMCState` and return the resulting + :data:`~numpyro.infer.hmc.HMCState`. + + :param HMCState state: Represents the current state. + :param model_args: Arguments provided to the model. + :param model_kwargs: Keyword arguments provided to the model. + :return: Next `state` after running HMC. + """ + return self._sample_fn(state, model_args, model_kwargs) + + +class NUTS(HMC): + """ + Hamiltonian Monte Carlo inference, using the No U-Turn Sampler (NUTS) + with adaptive path length and mass matrix adaptation. + + **References:** + + 1. *MCMC Using Hamiltonian Dynamics*, + Radford M. Neal + 2. *The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo*, + Matthew D. Hoffman, and Andrew Gelman. + 3. *A Conceptual Introduction to Hamiltonian Monte Carlo`*, + Michael Betancourt + + :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. + If model is provided, `potential_fn` will be inferred using the model. + :param potential_fn: Python callable that computes the potential energy + given input parameters. The input parameters to `potential_fn` can be + any python collection type, provided that `init_params` argument to + `init_kernel` has the same type. + :param kinetic_fn: Python callable that returns the kinetic energy given + inverse mass matrix and momentum. If not provided, the default is + euclidean kinetic energy. + :param float step_size: Determines the size of a single step taken by the + verlet integrator while computing the trajectory using Hamiltonian + dynamics. If not specified, it will be set to 1. + :param bool adapt_step_size: A flag to decide if we want to adapt step_size + during warm-up phase using Dual Averaging scheme. + :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass + matrix during warm-up phase using Welford scheme. + :param bool dense_mass: A flag to decide if mass matrix is dense or + diagonal (default when ``dense_mass=False``) + :param float target_accept_prob: Target acceptance probability for step size + adaptation using Dual Averaging. Increasing this value will lead to a smaller + step size, hence the sampling will be slower but more robust. Default to 0.8. + :param float trajectory_length: Length of a MCMC trajectory for HMC. This arg has + no effect in NUTS sampler. + :param int max_tree_depth: Max depth of the binary tree created during the doubling + scheme of NUTS sampler. Defaults to 10. + :param callable init_strategy: a per-site initialization function. + See :ref:`init_strategy` section for available functions. + :param bool find_heuristic_step_size: whether to a heuristic function to adjust the + step size at the beginning of each adaptation window. Defaults to False. + """ + def __init__(self, + model=None, + potential_fn=None, + kinetic_fn=None, + step_size=1.0, + adapt_step_size=True, + adapt_mass_matrix=True, + dense_mass=False, + target_accept_prob=0.8, + trajectory_length=None, + max_tree_depth=10, + init_strategy=init_to_uniform, + find_heuristic_step_size=False): + super(NUTS, self).__init__(potential_fn=potential_fn, model=model, kinetic_fn=kinetic_fn, + step_size=step_size, adapt_step_size=adapt_step_size, + adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass, + target_accept_prob=target_accept_prob, + trajectory_length=trajectory_length, + init_strategy=init_strategy, + find_heuristic_step_size=find_heuristic_step_size) + self._max_tree_depth = max_tree_depth + self._algo = 'NUTS' From b8001a9d9f2bfc75741f4a90c8b0181223f5a184 Mon Sep 17 00:00:00 2001 From: Lys Date: Tue, 8 Sep 2020 19:50:13 +0200 Subject: [PATCH 02/93] start hmcecs two --- examples/logistic_hmcecs.py | 76 ++++++++++++++++ numpyro/contrib/hmcecs.py | 167 +++++++++++++++++++++++++++++++++--- 2 files changed, 230 insertions(+), 13 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index e69de29bb..d9649f7c2 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -0,0 +1,76 @@ +""" Logistic regression model as implemetned in https://arxiv.org/pdf/1708.00955.pdf with Higgs Dataset """ +import jax +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +from numpyro.infer import NUTS, MCMC, Predictive +from numpyro.contrib.hmcecs import HMC +from sklearn.datasets import load_breast_cancer + + +# TODO: import Higgs data! ---> http://archive.ics.uci.edu/ml/machine-learning-databases/00280/ +# https://towardsdatascience.com/identifying-higgs-bosons-from-background-noise-pyspark-d7983234207e + +def model(feats, obs): + """ Logistic regression model + + """ + n, m = feats.shape + precision = numpyro.sample('precision', dist.continuous.Uniform(0, 4)) + theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), precision * jnp.ones(m))) + + numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(feats, theta)), obs=obs) + + +def infer_nuts(rng_key, feats, obs, samples=10, warmup=5, ): + kernel = NUTS(model=model) + mcmc = MCMC(kernel, num_warmup=warmup, num_samples=samples) + mcmc.run(rng_key, feats, obs) + # mcmc.print_summary() + return mcmc.get_samples() + + +def infer_hmcecs(rng_key, feats, obs, g=2, samples=10, warmup=5, ): + hmcecs_key, map_key = jax.random.split(rng_key) + n, _ = feats.shape + model_args = (feats, obs) + print("Running Nuts for map estimation") + z_map = {key: value.mean(0) for key, value in infer_nuts(map_key, feats, obs).items()} + print("Running MCMC subsampling") + kernel = HMC(model=model) + mcmc = MCMC(kernel,num_warmup=warmup,num_samples=samples) + mcmc.run(rng_key,observations=obs,features=feats,subsample_method="perturb",m=4,g=2,z_map = z_map) + return mcmc.get_samples() + + + +def breast_cancer_data(): + dataset = load_breast_cancer() + feats = dataset.data + feats = (feats - feats.mean(0)) / feats.std(0) + feats = jnp.hstack((feats, jnp.ones((feats.shape[0], 1)))) + return feats, dataset.target.reshape((-1, 1)) + + +def higgs_data(): + return + + +if __name__ == '__main__': + rng_key = jax.random.PRNGKey(37) + rng_key, feat_key, obs_key = jax.random.split(rng_key, 3) + n = 100 + m = 10 + + feats, obs = breast_cancer_data() + + from jax.config import config + + config.update('jax_disable_jit', True) + est_posterior = infer_hmcecs(rng_key, feats=feats, obs=obs) + + exit() + predictions = Predictive(model, posterior_samples=est_posterior)(rng_key, feats, None)['obs'] + + # for i, y in enumerate(obs): + # print(i, y[0], jnp.sum(predictions[i]) > 50) \ No newline at end of file diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 0a3f89b29..e23646c9c 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -1,11 +1,11 @@ -"""Contributed code for HMC and NUTS energy conserving""" +"""Contributed code for HMC and NUTS energy conserving sampling from """ from collections import namedtuple import math import os import warnings -from jax import device_put, lax, partial, random, vmap +from jax import device_put, lax, partial, random, vmap,jacfwd, hessian,jit,ops from jax.dtypes import canonicalize_dtype from jax.flatten_util import ravel_pytree import jax.numpy as jnp @@ -19,11 +19,13 @@ warmup_adapter ) from numpyro.infer.mcmc import MCMCKernel -from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model +from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model, log_density from numpyro.util import cond, fori_loop, identity HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', - 'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key']) + 'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key', + 'u','blocks', 'hmc_state', 'z_map', 'll_map', 'jac_map', 'hess_map', + 'control_variates', 'll_u']) """ A :func:`~collections.namedtuple` consisting of the following fields: @@ -50,6 +52,14 @@ mass matrix. - **rng_key** - random number generator seed used for the iteration. + - **u** - Subsample + - **blocks** - blocks in which the subsample is divided + - **z_map** - MAP estimation of the model parameters to initialize the subsampling. + - **ll_map** - Log likelihood of the map estimated parameters. + - **jac_map** - Jacobian vector from the map estimated parameters. + - **hess_map** - Hessian matrix from the map estimated parameters + - **Control variates** - Log likelihood correction + - **ll_u** - Log likelihood of the subsample """ @@ -72,6 +82,22 @@ def momentum_generator(prototype_r, mass_matrix_sqrt, rng_key): else: raise ValueError("Mass matrix has incorrect number of dims.") +@partial(jit, static_argnums=(2, 3, 4)) +def _update_block(rng_key, u, n, m, g): + """Returns updated indexes for the subsample""" + rng_key_block, rng_key_index = random.split(rng_key) + + # uniformly choose block to update + chosen_block = random.randint(rng_key, (), 0, g + 1) + + idxs_new = random.randint(rng_key_index, (m // g,), 0, n) + + u_new = jnp.zeros(m, jnp.dtype(u)) + for i in range(m): + u_new = ops.index_add(u_new, i, + lax.cond(i // g == chosen_block, i, lambda _: idxs_new[i % (m // g)], i, lambda _: u[i])) + return u_new + def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'): r""" @@ -167,7 +193,17 @@ def init_kernel(init_params, find_heuristic_step_size=False, model_args=(), model_kwargs=None, - rng_key=random.PRNGKey(0)): + rng_key=random.PRNGKey(0), + u=None, + blocks=None, + hmc_state = None, + z_map=None, + ll_map = None, + jac_map = None, + hess_map= None, + control_variates= None, + ll_u=None + ): """ Initializes the HMC sampler. @@ -204,6 +240,7 @@ def init_kernel(init_params, """ step_size = lax.convert_element_type(step_size, canonicalize_dtype(jnp.float64)) nonlocal wa_update, trajectory_len, max_treedepth, vv_update, wa_steps + #nonlocal n,m,g #TODO: This needs to be activated wa_steps = num_warmup trajectory_len = trajectory_length max_treedepth = max_tree_depth @@ -243,7 +280,10 @@ def init_kernel(init_params, vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad) energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, - 0, 0., 0., False, wa_state, rng_key_hmc) + 0, 0., 0., False, wa_state, rng_key_hmc, + u, blocks, hmc_state, z_map, ll_map, jac_map, hess_map, + control_variates, ll_u + ) return device_put(hmc_state) def _hmc_next(step_size, inverse_mass_matrix, vv_state, @@ -324,7 +364,8 @@ def sample_kernel(hmc_state, model_args=(), model_kwargs=None): mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps, - accept_prob, mean_accept_prob, diverging, adapt_state, rng_key) + accept_prob, mean_accept_prob, diverging, adapt_state, rng_key,u, blocks, hmc_state, z_map, ll_map, jac_map, hess_map, + control_variates, ll_u) # Make `init_kernel` and `sample_kernel` visible from the global scope once # `hmc` is called for sphinx doc generation. @@ -334,6 +375,50 @@ def sample_kernel(hmc_state, model_args=(), model_kwargs=None): return init_kernel, sample_kernel +def _log_prob(trace): + """ Compute probability of each observation """ + node = trace['observations'] + return jnp.sum(node['fn'].log_prob(node['value']), 1) + +def _hmcecs_potential(model, model_args, u, control_variates, jac_map, z, z_map, hess_map, n, m): + """Estimate the potential dynamic energy for the HMC ECS implementation. The calculation follows section 7.2.1 in https://jmlr.org/papers/volume18/15-205/15-205.pdf + The computation has a complexity of O(1) and it's highly dependant on the quality of the map estimate""" + ratio_pop_sub = n / m # ratio of population size to subsample + z_flat, _ = ravel_pytree(z) + zmap_flat, _ = ravel_pytree(z_map) + + _, trace = log_density(model, model_args, {}, z) # log likelihood for subsample + z_diff = z_flat - zmap_flat + + control_variates += jac_map.T @ z_diff + .5 * z_diff.T @ hess_map @ (z_flat - 2 * zmap_flat) + + lq_sub = _log_prob(trace) - control_variates[u] #correction of the likelihood based on the difference between the estimation and the map estimate + + d_hat = ratio_pop_sub * jnp.sum(lq_sub) # assume uniform distribution for subsample! + l_hat = d_hat + jnp.sum(control_variates) + + lq_sub_mean = jnp.mean(lq_sub) + sigma = ratio_pop_sub ** 2 * jnp.sum(lq_sub - lq_sub_mean) + return l_hat - .5 * sigma, control_variates, lq_sub + +def _grad_hmcecs_potential(model,model_args, model_kwargs,u, z, z_map, n, m, jac_map, hess_map, lq_sub): + ratio_pop_sub = n / m # ratio of population size to subsample + z_flat, treedef = ravel_pytree(z) + zmap_flat, _ = ravel_pytree(z_map) + + grad_cv = jac_map + hess_map @ (z_flat - zmap_flat) + + grad_lsub, _ = ravel_pytree(jacfwd(lambda args: partial(log_density, model, model_args, model_kwargs)(args)[0])(z)) #jacobian + grad_lhat = jnp.sum(jac_map, 0) + jnp.sum(hess_map, 0) + ratio_pop_sub * jnp.sum(grad_lsub - grad_cv) + + lq_sub_mean = jnp.mean(lq_sub) + grad_dhat = grad_lhat - grad_cv - jnp.mean(grad_lhat - grad_cv) + + # Note: factor 2 cancels with 1/2 from grad(L_hat) = grad_lhat - .5 * 2 * ratio_pop_sub**2 * ... + grad_sigma = ratio_pop_sub ** 2 * (jnp.sum(lq_sub) * grad_dhat - lq_sub_mean * jnp.sum( + grad_dhat)) # TODO: figure out lq_sub (20,) @ grad_dhat (z.shape) + + return treedef(grad_lhat - grad_sigma) #unflatten tree class HMC(MCMCKernel): """ @@ -372,6 +457,8 @@ class HMC(MCMCKernel): See :ref:`init_strategy` section for available functions. :param bool find_heuristic_step_size: whether to a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False. + :param subsample If "perturb" is provided, the "potential_fn" function will be calculated + using the equations from section 7.2.1 in https://jmlr.org/papers/volume18/15-205/15-205.pdf """ def __init__(self, model=None, @@ -384,9 +471,14 @@ def __init__(self, target_accept_prob=0.8, trajectory_length=2 * math.pi, init_strategy=init_to_uniform, - find_heuristic_step_size=False): + find_heuristic_step_size=False, + subsample_method = None, + m= None, + g = None + ): if not (model is None) ^ (potential_fn is None): raise ValueError('Only one of `model` or `potential_fn` must be specified.') + self._model = model self._potential_fn = potential_fn self._kinetic_fn = kinetic_fn if kinetic_fn is not None else euclidean_kinetic_energy @@ -400,12 +492,23 @@ def __init__(self, self._max_tree_depth = 10 self._init_strategy = init_strategy self._find_heuristic_step_size = find_heuristic_step_size + self._subsample_method = 'perturbed' + self._m = m if m is not None else 4 + self._g = g if g is not None else 2 # Set on first call to init self._init_fn = None self._postprocess_fn = None self._sample_fn = None - def _init_state(self, rng_key, model_args, model_kwargs, init_params): + def _init_state(self, rng_key, model_args, model_kwargs, init_params,z_map): + if self._subsample_method is not None: + warnings.warn("Assumption that the observations have a shape of (n_elements,)") + assert z_map is not None + n = model_kwargs["observations"].shape[0] + m = self._m + g = self._g + u = random.randint(rng_key, (self._m,), 0, n) + model_kwargs = self.model_kwargs_sub(u, model_kwargs) if self._model is not None: init_params, potential_fn, postprocess_fn, model_trace = initialize_model( rng_key, @@ -428,7 +531,8 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): kinetic_fn=self._kinetic_fn, algo=self._algo) - return init_params + return init_params #TODO: Return subsample state? + @property def model(self): @@ -446,18 +550,48 @@ def get_diagnostics_str(self, state): return '{} steps of size {:.2e}. acc. prob={:.2f}'.format(state.num_steps, state.adapt_state.step_size, state.mean_accept_prob) + def model_kwargs_sub(self,u, kwargs): + """Subsample observations and features""" + for key_arg, val_arg in kwargs.items(): + if key_arg == "observations" or key_arg == "features": + kwargs[key_arg] = jnp.take(val_arg, u, axis=0) + return kwargs + + def _block_indices(self,size, num_blocks): + a = jnp.repeat(jnp.arange(num_blocks - 1), size // num_blocks) + b = jnp.repeat(num_blocks - 1, size - len(jnp.repeat(jnp.arange(num_blocks - 1), size // num_blocks))) + return jnp.hstack((a, b)) + - def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}): + def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={},z_map=None): # non-vectorized if rng_key.ndim == 1: rng_key, rng_key_init_model = random.split(rng_key) # vectorized else: rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1) - init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params) + init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params,z_map) if self._potential_fn and init_params is None: raise ValueError('Valid value of `init_params` must be provided with' ' `potential_fn`.') + if self._subsample_method is not None: + #TODO: Does this make sense to repeat? + rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key = random.split(rng_key, 5) + n = model_kwargs["observations"] + u = random.randint(rng_key, (self._m,), 0, n) + blocks = self._block_indices(self._m, self._g) + model_kwargs = self.model_kwargs_sub(u,model_kwargs) + ld_fn = lambda args: partial(log_density, self._model,(model_args, model_kwargs),{})(model_args)[0] #TODO: I changed args to model_args, still got detected + + ll_map = ld_fn(z_map) + jac_map, _ = ravel_pytree(jacfwd(ld_fn)(z_map)) + hess_map, _ = ravel_pytree(hessian(ld_fn)(z_map)) + hess_map = jnp.reshape(hess_map, (jac_map.shape[0], jac_map.shape[0])) + _, tr = log_density(self._model,model_args, model_kwargs,z_map) + obs_node = tr['observations'] + control_variates = jnp.sum(obs_node['fn'].log_prob(obs_node['value']), 1) + + init_params, _, postprocess_fn, _ = initialize_model(rng_key_init_model, self._model,model_args, model_kwargs) hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731 init_params, @@ -476,6 +610,13 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg ) if rng_key.ndim == 1: init_state = hmc_init_fn(init_params, rng_key) + return init_state + elif self._subsample_method is not None: + #TODO: No f***** clue --> Return subsample state? + init_state = vmap(hmc_init_fn)(init_params, rng_key) + sample_fn = vmap(self._sample_fn, in_axes=(0, None, None)) + self._sample_fn = sample_fn + return init_state else: # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, @@ -483,7 +624,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg init_state = vmap(hmc_init_fn)(init_params, rng_key) sample_fn = vmap(self._sample_fn, in_axes=(0, None, None)) self._sample_fn = sample_fn - return init_state + return init_state def postprocess_fn(self, args, kwargs): if self._postprocess_fn is None: From 26219ce57e9e12c9b4ac75bfdf143190895aad87 Mon Sep 17 00:00:00 2001 From: Lys Date: Mon, 14 Sep 2020 19:17:01 +0200 Subject: [PATCH 03/93] structuring --- examples/logistic_hmcecs.py | 8 +- numpyro/contrib/hmcecs.py | 189 ++++++++++++++++++++++---------- numpyro/contrib/hmcecs_utils.py | 108 ++++++++++++++++++ 3 files changed, 245 insertions(+), 60 deletions(-) create mode 100644 numpyro/contrib/hmcecs_utils.py diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index d9649f7c2..536ae3058 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -33,13 +33,17 @@ def infer_nuts(rng_key, feats, obs, samples=10, warmup=5, ): def infer_hmcecs(rng_key, feats, obs, g=2, samples=10, warmup=5, ): hmcecs_key, map_key = jax.random.split(rng_key) n, _ = feats.shape + model_args = (feats, obs) print("Running Nuts for map estimation") z_map = {key: value.mean(0) for key, value in infer_nuts(map_key, feats, obs).items()} + + #Observations = (569,1) + #Features = (569,31) print("Running MCMC subsampling") - kernel = HMC(model=model) + kernel = HMC(model=model,z_ref=z_map,m=50,g=10) #,subsample_method="perturb") mcmc = MCMC(kernel,num_warmup=warmup,num_samples=samples) - mcmc.run(rng_key,observations=obs,features=feats,subsample_method="perturb",m=4,g=2,z_map = z_map) + mcmc.run(rng_key,feats,obs) return mcmc.get_samples() diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index e23646c9c..9dc9daa16 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -1,4 +1,4 @@ -"""Contributed code for HMC and NUTS energy conserving sampling from """ +"""Contributed code for HMC and NUTS energy conserving sampling adapted from """ from collections import namedtuple import math @@ -21,11 +21,10 @@ from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model, log_density from numpyro.util import cond, fori_loop, identity - +from numpyro.contrib.hmcecs_utils import grad_potential,potential_est,log_density_hmcecs HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', - 'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key', - 'u','blocks', 'hmc_state', 'z_map', 'll_map', 'jac_map', 'hess_map', - 'control_variates', 'll_u']) + 'mean_accept_prob', 'diverging', 'adapt_state','rng_key']) +HMCECSState = namedtuple("HMCECState",["u","hmc_state","z_ref","ll_ref","jac_all","hess_all","ll_u"]) """ A :func:`~collections.namedtuple` consisting of the following fields: @@ -54,7 +53,7 @@ - **rng_key** - random number generator seed used for the iteration. - **u** - Subsample - **blocks** - blocks in which the subsample is divided - - **z_map** - MAP estimation of the model parameters to initialize the subsampling. + - **z_ref** - MAP estimation of the model parameters to initialize the subsampling. - **ll_map** - Log likelihood of the map estimated parameters. - **jac_map** - Jacobian vector from the map estimated parameters. - **hess_map** - Hessian matrix from the map estimated parameters @@ -84,22 +83,30 @@ def momentum_generator(prototype_r, mass_matrix_sqrt, rng_key): @partial(jit, static_argnums=(2, 3, 4)) def _update_block(rng_key, u, n, m, g): - """Returns updated indexes for the subsample""" + """Returns the indexes from the subsample that will be updated, there is replacement. + The number of indexes to be updated depend on the block size, higher block size more correlation among elements in the subsample. + :param rng_key + :param u subsample + :param n total number of data + :param m subsample size + :param g block size: subsample subdivision""" rng_key_block, rng_key_index = random.split(rng_key) # uniformly choose block to update - chosen_block = random.randint(rng_key, (), 0, g + 1) + chosen_block = random.randint(rng_key, shape=(), minval= 0, maxval=g + 1) #TODO: assertions for g values? why minval=0?division by 0 - idxs_new = random.randint(rng_key_index, (m // g,), 0, n) + idxs_new = random.randint(rng_key_index, shape=(m // g,), minval=0, maxval=n) #chose block within the subsample to update - u_new = jnp.zeros(m, jnp.dtype(u)) + u_new = jnp.zeros(m, jnp.dtype(u)) #empty array with size m for i in range(m): + #if index in the subsample // g = chosen block : pick new indexes from the subsample size + #else not update: keep the same indexes u_new = ops.index_add(u_new, i, lax.cond(i // g == chosen_block, i, lambda _: idxs_new[i % (m // g)], i, lambda _: u[i])) return u_new -def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'): +def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, grad_potential_fn_gen=None,algo='NUTS'): r""" Hamiltonian Monte Carlo inference, using either fixed number of steps or the No U-Turn Sampler (NUTS) with adaptive path length. @@ -197,7 +204,7 @@ def init_kernel(init_params, u=None, blocks=None, hmc_state = None, - z_map=None, + z_ref=None, ll_map = None, jac_map = None, hess_map= None, @@ -280,10 +287,7 @@ def init_kernel(init_params, vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad) energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, - 0, 0., 0., False, wa_state, rng_key_hmc, - u, blocks, hmc_state, z_map, ll_map, jac_map, hess_map, - control_variates, ll_u - ) + 0, 0., 0., False, wa_state,rng_key_hmc) return device_put(hmc_state) def _hmc_next(step_size, inverse_mass_matrix, vv_state, @@ -364,8 +368,7 @@ def sample_kernel(hmc_state, model_args=(), model_kwargs=None): mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps, - accept_prob, mean_accept_prob, diverging, adapt_state, rng_key,u, blocks, hmc_state, z_map, ll_map, jac_map, hess_map, - control_variates, ll_u) + accept_prob, mean_accept_prob, diverging, adapt_state,rng_key) # Make `init_kernel` and `sample_kernel` visible from the global scope once # `hmc` is called for sphinx doc generation. @@ -380,7 +383,7 @@ def _log_prob(trace): node = trace['observations'] return jnp.sum(node['fn'].log_prob(node['value']), 1) -def _hmcecs_potential(model, model_args, u, control_variates, jac_map, z, z_map, hess_map, n, m): +def _hmcecs_potential(model, model_args, u, control_variates, jac_map, z, z_ref, hess_map, n, m): """Estimate the potential dynamic energy for the HMC ECS implementation. The calculation follows section 7.2.1 in https://jmlr.org/papers/volume18/15-205/15-205.pdf The computation has a complexity of O(1) and it's highly dependant on the quality of the map estimate""" ratio_pop_sub = n / m # ratio of population size to subsample @@ -401,10 +404,10 @@ def _hmcecs_potential(model, model_args, u, control_variates, jac_map, z, z_map, sigma = ratio_pop_sub ** 2 * jnp.sum(lq_sub - lq_sub_mean) return l_hat - .5 * sigma, control_variates, lq_sub -def _grad_hmcecs_potential(model,model_args, model_kwargs,u, z, z_map, n, m, jac_map, hess_map, lq_sub): +def _grad_hmcecs_potential(model,model_args, model_kwargs,u, z, z_ref, n, m, jac_map, hess_map, lq_sub): ratio_pop_sub = n / m # ratio of population size to subsample z_flat, treedef = ravel_pytree(z) - zmap_flat, _ = ravel_pytree(z_map) + zmap_flat, _ = ravel_pytree(z_ref) grad_cv = jac_map + hess_map @ (z_flat - zmap_flat) @@ -457,12 +460,16 @@ class HMC(MCMCKernel): See :ref:`init_strategy` section for available functions. :param bool find_heuristic_step_size: whether to a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False. - :param subsample If "perturb" is provided, the "potential_fn" function will be calculated + :param subsample_method If "perturb" is provided, the "potential_fn" function will be calculated using the equations from section 7.2.1 in https://jmlr.org/papers/volume18/15-205/15-205.pdf + :param m subsample size + :param g block size + :param z_ref MAP estimate of the parameters """ def __init__(self, model=None, potential_fn=None, + grad_potential = None, kinetic_fn=None, step_size=1.0, adapt_step_size=True, @@ -474,13 +481,15 @@ def __init__(self, find_heuristic_step_size=False, subsample_method = None, m= None, - g = None + g = None, + z_ref= None ): if not (model is None) ^ (potential_fn is None): raise ValueError('Only one of `model` or `potential_fn` must be specified.') self._model = model self._potential_fn = potential_fn + self._grad_potential = grad_potential self._kinetic_fn = kinetic_fn if kinetic_fn is not None else euclidean_kinetic_energy self._step_size = step_size self._adapt_step_size = adapt_step_size @@ -492,23 +501,55 @@ def __init__(self, self._max_tree_depth = 10 self._init_strategy = init_strategy self._find_heuristic_step_size = find_heuristic_step_size - self._subsample_method = 'perturbed' + self._subsample_method = subsample_method self._m = m if m is not None else 4 self._g = g if g is not None else 2 + self._z_ref = z_ref + self._n = None # Set on first call to init self._init_fn = None self._postprocess_fn = None self._sample_fn = None + self._subsample_fn = None + + def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ref): + self._n = model_kwargs["observations"].shape[0] + + u = random.randint(rng_key, (self._m,), 0, self._n) + model_kwargs = self.model_kwargs_sub(u, model_kwargs) + + rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key = random.split(rng_key, 5) + + ld_fn = lambda args: partial(log_density_hmcecs, self._model, model_args, model_kwargs, prior=False)(args)[0] + + ll_ref = ld_fn(z_ref) + jac_all, _ = ravel_pytree(jacfwd(ld_fn)(z_ref)) + hess_all, _ = ravel_pytree(hessian(ld_fn)(z_ref)) + + jac_all = jac_all.reshape(self._n, -1).sum(0) + k, = jac_all.shape + hess_all = hess_all.reshape(self._n, k, k).sum(0) + + self._potential_fn = lambda model, args, ll_ref, jac_all, z_ref, hess_all, n, m: \ + lambda z: potential_est(model=model, model_args=args, model_kwargs=model_kwargs, ll_ref=ll_ref, + jac_all=jac_all, + z=z, z_ref=z_ref, hess_all=hess_all, n=n, m=m) + self._grad_potential = lambda model, args, ll_ref, jac_all, z_ref, hess_all, n, m: \ + lambda z: grad_potential(model=model, model_args=args, model_kwargs=model_kwargs, jac_all=jac_all, z=z, + z_ref=z_ref, + hess_all=hess_all, n=n, m=m) + - def _init_state(self, rng_key, model_args, model_kwargs, init_params,z_map): + def _init_state(self, rng_key, model_args, model_kwargs, init_params): if self._subsample_method is not None: - warnings.warn("Assumption that the observations have a shape of (n_elements,)") - assert z_map is not None - n = model_kwargs["observations"].shape[0] - m = self._m - g = self._g - u = random.randint(rng_key, (self._m,), 0, n) - model_kwargs = self.model_kwargs_sub(u, model_kwargs) + print(self._z_ref) + assert self._z_ref is not None, "Please provide a (i.e map) estimate for the parameters" + self._init_subsample_state(rng_key, model_args, model_kwargs, init_params,self._z_ref) + self._init_fn, self._subsample_fn = hmc(potential_fn_gen=self._potential_fn, + kinetic_fn=euclidean_kinetic_energy, + grad_potential_fn_gen=self._grad_potential, + algo='HMC') # no need to be returned here to be used for sampling, because they are init sampler and subsampler are updated... + if self._model is not None: init_params, potential_fn, postprocess_fn, model_trace = initialize_model( rng_key, @@ -531,7 +572,7 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params,z_map): kinetic_fn=self._kinetic_fn, algo=self._algo) - return init_params #TODO: Return subsample state? + return init_params @property @@ -563,35 +604,19 @@ def _block_indices(self,size, num_blocks): return jnp.hstack((a, b)) - def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={},z_map=None): + def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}): # non-vectorized if rng_key.ndim == 1: rng_key, rng_key_init_model = random.split(rng_key) # vectorized else: rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1) - init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params,z_map) + init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params) #should work for all cases + if self._potential_fn and init_params is None: raise ValueError('Valid value of `init_params` must be provided with' ' `potential_fn`.') - if self._subsample_method is not None: - #TODO: Does this make sense to repeat? - rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key = random.split(rng_key, 5) - n = model_kwargs["observations"] - u = random.randint(rng_key, (self._m,), 0, n) - blocks = self._block_indices(self._m, self._g) - model_kwargs = self.model_kwargs_sub(u,model_kwargs) - ld_fn = lambda args: partial(log_density, self._model,(model_args, model_kwargs),{})(model_args)[0] #TODO: I changed args to model_args, still got detected - - ll_map = ld_fn(z_map) - jac_map, _ = ravel_pytree(jacfwd(ld_fn)(z_map)) - hess_map, _ = ravel_pytree(hessian(ld_fn)(z_map)) - hess_map = jnp.reshape(hess_map, (jac_map.shape[0], jac_map.shape[0])) - _, tr = log_density(self._model,model_args, model_kwargs,z_map) - obs_node = tr['observations'] - control_variates = jnp.sum(obs_node['fn'].log_prob(obs_node['value']), 1) - - init_params, _, postprocess_fn, _ = initialize_model(rng_key_init_model, self._model,model_args, model_kwargs) + hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731 init_params, @@ -611,12 +636,12 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg if rng_key.ndim == 1: init_state = hmc_init_fn(init_params, rng_key) return init_state - elif self._subsample_method is not None: - #TODO: No f***** clue --> Return subsample state? + elif self._subsample_method: init_state = vmap(hmc_init_fn)(init_params, rng_key) - sample_fn = vmap(self._sample_fn, in_axes=(0, None, None)) - self._sample_fn = sample_fn + subsample_fn = vmap(self._subsample_fn, in_axes=(0, None, None)) + self._subsample_fn = subsample_fn return init_state + else: # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, @@ -642,7 +667,55 @@ def sample(self, state, model_args, model_kwargs): :return: Next `state` after running HMC. """ return self._sample_fn(state, model_args, model_kwargs) - + def subsample(self,subsamplestate,model_args,model_kwargs): + """ + Run HMC from the given :data:`~numpyro.infer.hmc.HMCECSState` and return the resulting + :data:`~numpyro.infer.hmc.HMCECSState`. + + :param HMCECSState state: Represents the current state. + :param model_args: Arguments provided to the model. + :param model_kwargs: Keyword arguments provided to the model. + :return: Next `subsample state` after running HMC. + """ + + rng_key_subsample, rng_key_transition, rng_key_likelihood, rng_key = random.split(subsamplestate.hmc_state.rng_key,4) + + u_new = _update_block(rng_key_subsample, subsamplestate.u, self._n, self._m, self._g) + + # estimate likelihood of subsample with single block updated + llu_new = potential_est(model=self._model, + model_args=model_args, + model_kwargs=model_kwargs, + jac_all=subsamplestate.jac_all, + hess_all=subsamplestate.hess_all, + ll_ref=subsamplestate.ll_ref, + z=subsamplestate.hmc_state.z, + z_ref=subsamplestate.z_ref, + n=self._n, m=self._m) + + # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) + # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. + accept_prob = jnp.clip(jnp.exp(-llu_new + subsamplestate.ll_u), a_max=1.) + transition = random.bernoulli(rng_key_transition, accept_prob) + u, ll_u = cond(transition, + (u_new, llu_new), identity, + (subsamplestate.u, subsamplestate.ll_u), identity) + + ######## UPDATE PARAMETERS ########## + + + hmc_subsamplestate= HMCECSState(u=u, hmc_state=subsamplestate.hmc_state, + z_ref=subsamplestate.z_ref, + ll_u=ll_u,ll_ref=subsamplestate.ll_ref, + jac_all=subsamplestate.jac_all, + hess_all=subsamplestate.hess_all) + + return self._subsample_fn(hmc_subsamplestate,model_args=(self._model, + model_args, + subsamplestate.ll_ref, + subsamplestate.jac_all, + subsamplestate.z_ref, + subsamplestate.hess_all, self._n, self._m),model_kwargs=model_kwargs) class NUTS(HMC): """ diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py new file mode 100644 index 000000000..7906098ba --- /dev/null +++ b/numpyro/contrib/hmcecs_utils.py @@ -0,0 +1,108 @@ +from functools import partial + +import jax +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +from numpyro.distributions.util import is_identically_one +from numpyro.handlers import substitute, trace +from numpyro.util import ravel_pytree + +def log_density_hmcecs(model, model_args, model_kwargs, params, prior=True): + """ + (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given + latent values ``params``. If prior == False, the log probability of the prior probability + over the parameters is not computed, solely the log probability of the observations + + :param model: Python callable containing NumPyro primitives. + :param tuple model_args: args provided to the model. + :param dict model_kwargs: kwargs provided to the model. + :param dict params: dictionary of current parameter values keyed by site + name. + :return: log of joint density and a corresponding model trace + """ + model = substitute(model, data=params) + model_trace = trace(model).get_trace(*model_args, **model_kwargs) + log_joint = jnp.array(0.) + for site in model_trace.values(): + if site['type'] == 'sample' and not isinstance(site['fn'], dist.PRNGIdentity) and not site['is_observed']: + value = site['value'] + intermediates = site['intermediates'] + scale = site['scale'] + if intermediates: + log_prob = site['fn'].log_prob(value, intermediates) + else: + log_prob = site['fn'].log_prob(value) + + if (scale is not None) and (not is_identically_one(scale)): + log_prob = scale * log_prob + + if prior: + log_prob = jnp.sum(log_prob) + log_joint = log_joint + log_prob + return log_joint, model_trace + + + +def grad_potential(model, model_args, model_kwargs,z, z_ref, jac_all, hess_all, n, m, *args, **kwargs): + + k, = jac_all.shape + z_flat, treedef = ravel_pytree(z) + zref_flat, _ = ravel_pytree(z_ref) + z_diff = z_flat - zref_flat + + ld_fn = lambda args: partial(log_density, model, model_args, model_kwargs, prior = False)(args)[0] + + jac_ref, _ = ravel_pytree(jax.jacfwd(ld_fn)(z_ref)) + hess_ref, _ = ravel_pytree(jax.hessian(ld_fn)(z_ref)) + + jac_ref = jac_ref.reshape(m, k) + hess_ref = hess_ref.reshape(m, k, k) + + grad_sum = jac_all + hess_all.dot(z_diff) + jac_sub, _ = ravel_pytree(jax.jacfwd(ld_fn)(z)) + + ll_sub, _ = log_density(model, model_args, model_kwargs, z,prior=False) # log likelihood for subsample with current theta + ll_ref, _ = log_density(model, model_args, model_kwargs, z_ref,prior=False) # log likelihood for subsample with reference theta + + diff = ll_sub - (ll_ref + jac_ref @ z_diff + .5 * z_diff @ hess_ref @ z_diff.T) + + jac_sub = jac_sub.reshape(jac_ref.shape) - jac_ref + + grad_d_k = jac_sub - z_diff.dot(hess_ref) + + gradll = -(grad_sum + n / m * (jac_sub.sum(0) - hess_ref.sum(0).dot(z_diff))) + n ** 2 / (m ** 2) * ( + diff - diff.mean(0)).T.dot(grad_d_k - grad_d_k.mean(0)) + + ld_fn = lambda args: partial(log_density, model, model_args, model_kwargs,prior=True)(args)[0] + jac_sub, _ = ravel_pytree(jax.jacfwd(ld_fn)(z)) + + return treedef(gradll - jac_sub) + + +def potential_est(model, model_args, model_kwargs,ll_ref, jac_all, hess_all, z, z_ref, n, m): + # Agrees with reference upto constant factor on prior + k, = jac_all.shape # number of features + z_flat, _ = ravel_pytree(z) + zref_flat, _ = ravel_pytree(z_ref) + + z_diff = z_flat - zref_flat + + ld_fn = lambda args: partial(log_density, model, model_args, model_kwargs,prior=False)(args)[0] + + jac_sub, _ = ravel_pytree(jax.jacfwd(ld_fn)(z_ref)) + hess_sub, _ = ravel_pytree(jax.hessian(ld_fn)(z_ref)) + + proxy = jnp.sum(ll_ref) + jac_all.T @ z_diff + .5 * z_diff.T @ hess_all @ z_diff + + ll_sub, _ = log_density(model, model_args, model_kwargs, z,prior=False) # log likelihood for subsample with current theta + ll_ref, _ = log_density(model, model_args, model_kwargs, z_ref,prior=False) # log likelihood for subsample with reference theta + + diff = ll_sub - (ll_ref + jac_sub.reshape((m, k)) @ z_diff + .5 * z_diff @ hess_sub.reshape((m, k, k)) @ z_diff.T) + l_hat = proxy + n / m * jnp.sum(diff) + + sigma = n ** 2 / m * jnp.var(diff) + + ll_prior, _ = log_density(model, model_args, model_kwargs, z,prior=True) + + return (-l_hat + .5 * sigma) - ll_prior \ No newline at end of file From 9a326db9f0cd248f4797857a58564eb63fe040b6 Mon Sep 17 00:00:00 2001 From: Lys Date: Mon, 14 Sep 2020 20:01:34 +0200 Subject: [PATCH 04/93] small fix --- examples/logistic_hmcecs.py | 4 ++-- numpyro/contrib/hmcecs.py | 22 ++++++++++++++++++---- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 536ae3058..f2447fcd5 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -35,13 +35,13 @@ def infer_hmcecs(rng_key, feats, obs, g=2, samples=10, warmup=5, ): n, _ = feats.shape model_args = (feats, obs) - print("Running Nuts for map estimation") + print("Running NUTS for map estimation") z_map = {key: value.mean(0) for key, value in infer_nuts(map_key, feats, obs).items()} #Observations = (569,1) #Features = (569,31) print("Running MCMC subsampling") - kernel = HMC(model=model,z_ref=z_map,m=50,g=10) #,subsample_method="perturb") + kernel = HMC(model=model,z_ref=z_map,m=50,g=10,subsample_method="perturb") mcmc = MCMC(kernel,num_warmup=warmup,num_samples=samples) mcmc.run(rng_key,feats,obs) return mcmc.get_samples() diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 9dc9daa16..2ae4cdd50 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -513,19 +513,25 @@ def __init__(self, self._subsample_fn = None def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ref): - self._n = model_kwargs["observations"].shape[0] + + self._n = model_args[0].shape[0] u = random.randint(rng_key, (self._m,), 0, self._n) + + model_args = self.model_args_sub(u,model_args) model_kwargs = self.model_kwargs_sub(u, model_kwargs) rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key = random.split(rng_key, 5) ld_fn = lambda args: partial(log_density_hmcecs, self._model, model_args, model_kwargs, prior=False)(args)[0] - + print(z_ref["theta"].shape) + print(z_ref.keys()) + exit() ll_ref = ld_fn(z_ref) jac_all, _ = ravel_pytree(jacfwd(ld_fn)(z_ref)) hess_all, _ = ravel_pytree(hessian(ld_fn)(z_ref)) - + print(jac_all.shape) + exit() jac_all = jac_all.reshape(self._n, -1).sum(0) k, = jac_all.shape hess_all = hess_all.reshape(self._n, k, k).sum(0) @@ -542,7 +548,6 @@ def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ def _init_state(self, rng_key, model_args, model_kwargs, init_params): if self._subsample_method is not None: - print(self._z_ref) assert self._z_ref is not None, "Please provide a (i.e map) estimate for the parameters" self._init_subsample_state(rng_key, model_args, model_kwargs, init_params,self._z_ref) self._init_fn, self._subsample_fn = hmc(potential_fn_gen=self._potential_fn, @@ -591,6 +596,15 @@ def get_diagnostics_str(self, state): return '{} steps of size {:.2e}. acc. prob={:.2f}'.format(state.num_steps, state.adapt_state.step_size, state.mean_accept_prob) + def model_args_sub(self,u,model_args): + """Subsample observations and features""" + args = [] + for arg in model_args: + if isinstance(arg, jnp.ndarray): + args.append(jnp.take(arg, u, axis=0)) + else: + args.append(arg) + return args def model_kwargs_sub(self,u, kwargs): """Subsample observations and features""" for key_arg, val_arg in kwargs.items(): From 44d2fc199cd7380c5e265d6f3e91d88b444cacb9 Mon Sep 17 00:00:00 2001 From: Lys Date: Wed, 16 Sep 2020 17:49:29 +0200 Subject: [PATCH 05/93] ADDED: verlet, new log density --- examples/logistic_hmcecs.py | 2 +- numpyro/contrib/hmcecs.py | 81 +++++++------------- numpyro/contrib/hmcecs_utils.py | 128 ++++++++++++++++++++++++-------- 3 files changed, 128 insertions(+), 83 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index f2447fcd5..646f20e15 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -17,6 +17,7 @@ def model(feats, obs): """ n, m = feats.shape precision = numpyro.sample('precision', dist.continuous.Uniform(0, 4)) + #precision = 0.5 theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), precision * jnp.ones(m))) numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(feats, theta)), obs=obs) @@ -34,7 +35,6 @@ def infer_hmcecs(rng_key, feats, obs, g=2, samples=10, warmup=5, ): hmcecs_key, map_key = jax.random.split(rng_key) n, _ = feats.shape - model_args = (feats, obs) print("Running NUTS for map estimation") z_map = {key: value.mean(0) for key, value in infer_nuts(map_key, feats, obs).items()} diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 2ae4cdd50..9b39f459c 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -21,7 +21,7 @@ from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model, log_density from numpyro.util import cond, fori_loop, identity -from numpyro.contrib.hmcecs_utils import grad_potential,potential_est,log_density_hmcecs +from numpyro.contrib.hmcecs_utils import grad_potential,potential_est,log_density_hmcecs, velocity_verlet_hmcecs HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', 'mean_accept_prob', 'diverging', 'adapt_state','rng_key']) HMCECSState = namedtuple("HMCECState",["u","hmc_state","z_ref","ll_ref","jac_all","hess_all","ll_u"]) @@ -90,6 +90,8 @@ def _update_block(rng_key, u, n, m, g): :param n total number of data :param m subsample size :param g block size: subsample subdivision""" + if not (g > m) or (g < 1): + raise ValueError('Block size (g) needs to = or > than 1 and smaller than the subsample size {}'.format(m)) rng_key_block, rng_key_index = random.split(rng_key) # uniformly choose block to update @@ -247,7 +249,6 @@ def init_kernel(init_params, """ step_size = lax.convert_element_type(step_size, canonicalize_dtype(jnp.float64)) nonlocal wa_update, trajectory_len, max_treedepth, vv_update, wa_steps - #nonlocal n,m,g #TODO: This needs to be activated wa_steps = num_warmup trajectory_len = trajectory_length max_treedepth = max_tree_depth @@ -262,6 +263,11 @@ def init_kernel(init_params, else: kwargs = {} if model_kwargs is None else model_kwargs pe_fn = potential_fn_gen(*model_args, **kwargs) + if grad_potential_fn_gen: + kwargs = {} if model_kwargs is None else model_kwargs + gpe_fn = grad_potential_fn_gen(*model_args, **kwargs) + else: + gpe_fn = None find_reasonable_ss = None if find_heuristic_step_size: @@ -283,7 +289,7 @@ def init_kernel(init_params, inverse_mass_matrix=inverse_mass_matrix, mass_matrix_size=jnp.size(ravel_pytree(z)[0])) r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum) - vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn) + vv_init, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,grad_potential_fn=gpe_fn) vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad) energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, @@ -293,9 +299,14 @@ def init_kernel(init_params, def _hmc_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key): if potential_fn_gen: + if grad_potential_fn_gen: + kwargs = {} if model_kwargs is None else model_kwargs + gpe_fn = grad_potential_fn_gen(*model_args, **kwargs) + else: + gpe_fn = None nonlocal vv_update pe_fn = potential_fn_gen(*model_args, **model_kwargs) - _, vv_update = velocity_verlet(pe_fn, kinetic_fn) + _, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,gpe_fn) num_steps = _get_num_steps(step_size, trajectory_len) vv_state_new = fori_loop(0, num_steps, @@ -318,7 +329,12 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, if potential_fn_gen: nonlocal vv_update pe_fn = potential_fn_gen(*model_args, **model_kwargs) - _, vv_update = velocity_verlet(pe_fn, kinetic_fn) + if grad_potential_fn_gen: + kwargs = {} if model_kwargs is None else model_kwargs + gpe_fn = grad_potential_fn_gen(*model_args, **kwargs) + else: + gpe_fn = None + _, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,gpe_fn) binary_tree = build_tree(vv_update, kinetic_fn, vv_state, inverse_mass_matrix, step_size, rng_key, @@ -383,45 +399,6 @@ def _log_prob(trace): node = trace['observations'] return jnp.sum(node['fn'].log_prob(node['value']), 1) -def _hmcecs_potential(model, model_args, u, control_variates, jac_map, z, z_ref, hess_map, n, m): - """Estimate the potential dynamic energy for the HMC ECS implementation. The calculation follows section 7.2.1 in https://jmlr.org/papers/volume18/15-205/15-205.pdf - The computation has a complexity of O(1) and it's highly dependant on the quality of the map estimate""" - ratio_pop_sub = n / m # ratio of population size to subsample - z_flat, _ = ravel_pytree(z) - zmap_flat, _ = ravel_pytree(z_map) - - _, trace = log_density(model, model_args, {}, z) # log likelihood for subsample - z_diff = z_flat - zmap_flat - - control_variates += jac_map.T @ z_diff + .5 * z_diff.T @ hess_map @ (z_flat - 2 * zmap_flat) - - lq_sub = _log_prob(trace) - control_variates[u] #correction of the likelihood based on the difference between the estimation and the map estimate - - d_hat = ratio_pop_sub * jnp.sum(lq_sub) # assume uniform distribution for subsample! - l_hat = d_hat + jnp.sum(control_variates) - - lq_sub_mean = jnp.mean(lq_sub) - sigma = ratio_pop_sub ** 2 * jnp.sum(lq_sub - lq_sub_mean) - return l_hat - .5 * sigma, control_variates, lq_sub - -def _grad_hmcecs_potential(model,model_args, model_kwargs,u, z, z_ref, n, m, jac_map, hess_map, lq_sub): - ratio_pop_sub = n / m # ratio of population size to subsample - z_flat, treedef = ravel_pytree(z) - zmap_flat, _ = ravel_pytree(z_ref) - - grad_cv = jac_map + hess_map @ (z_flat - zmap_flat) - - grad_lsub, _ = ravel_pytree(jacfwd(lambda args: partial(log_density, model, model_args, model_kwargs)(args)[0])(z)) #jacobian - grad_lhat = jnp.sum(jac_map, 0) + jnp.sum(hess_map, 0) + ratio_pop_sub * jnp.sum(grad_lsub - grad_cv) - - lq_sub_mean = jnp.mean(lq_sub) - grad_dhat = grad_lhat - grad_cv - jnp.mean(grad_lhat - grad_cv) - - # Note: factor 2 cancels with 1/2 from grad(L_hat) = grad_lhat - .5 * 2 * ratio_pop_sub**2 * ... - grad_sigma = ratio_pop_sub ** 2 * (jnp.sum(lq_sub) * grad_dhat - lq_sub_mean * jnp.sum( - grad_dhat)) # TODO: figure out lq_sub (20,) @ grad_dhat (z.shape) - - return treedef(grad_lhat - grad_sigma) #unflatten tree class HMC(MCMCKernel): """ @@ -522,20 +499,18 @@ def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ model_kwargs = self.model_kwargs_sub(u, model_kwargs) rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key = random.split(rng_key, 5) + ld_fn = lambda args: partial(log_density_hmcecs, self._model, model_args, model_kwargs,prior=False)(args)[0] - ld_fn = lambda args: partial(log_density_hmcecs, self._model, model_args, model_kwargs, prior=False)(args)[0] - print(z_ref["theta"].shape) - print(z_ref.keys()) - exit() - ll_ref = ld_fn(z_ref) - jac_all, _ = ravel_pytree(jacfwd(ld_fn)(z_ref)) - hess_all, _ = ravel_pytree(hessian(ld_fn)(z_ref)) + ll_ref = ld_fn(z_ref) #loglikelihood of the reference parameter estimates (u,u) under the subsample + + jac_all, jac_all_unflat = ravel_pytree(jacfwd(ld_fn)(z_ref)) #contains the jacobian for the non observed parameters "theta":(u,u,features), "precision" : (u,u) + hess_all, hess_all_unflat = ravel_pytree(hessian(ld_fn)(z_ref)) print(jac_all.shape) exit() jac_all = jac_all.reshape(self._n, -1).sum(0) k, = jac_all.shape hess_all = hess_all.reshape(self._n, k, k).sum(0) - + exit() self._potential_fn = lambda model, args, ll_ref, jac_all, z_ref, hess_all, n, m: \ lambda z: potential_est(model=model, model_args=args, model_kwargs=model_kwargs, ll_ref=ll_ref, jac_all=jac_all, @@ -597,7 +572,7 @@ def get_diagnostics_str(self, state): state.adapt_state.step_size, state.mean_accept_prob) def model_args_sub(self,u,model_args): - """Subsample observations and features""" + """Subsample observations and features according to u subsample indexes""" args = [] for arg in model_args: if isinstance(arg, jnp.ndarray): diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index 7906098ba..51f05f686 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -2,13 +2,20 @@ import jax import jax.numpy as jnp +from jax import grad, value_and_grad +from jax.tree_util import tree_multimap import numpyro import numpyro.distributions as dist from numpyro.distributions.util import is_identically_one from numpyro.handlers import substitute, trace from numpyro.util import ravel_pytree +from collections import namedtuple -def log_density_hmcecs(model, model_args, model_kwargs, params, prior=True): +IntegratorState = namedtuple('IntegratorState', ['z', 'r', 'potential_energy', 'z_grad']) +IntegratorState.__new__.__defaults__ = (None,) * len(IntegratorState._fields) + + +def log_density_hmcecs(model, model_args, model_kwargs, params,prior=False): """ (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given latent values ``params``. If prior == False, the log probability of the prior probability @@ -23,35 +30,49 @@ def log_density_hmcecs(model, model_args, model_kwargs, params, prior=True): """ model = substitute(model, data=params) model_trace = trace(model).get_trace(*model_args, **model_kwargs) + log_joint = jnp.array(0.) - for site in model_trace.values(): - if site['type'] == 'sample' and not isinstance(site['fn'], dist.PRNGIdentity) and not site['is_observed']: - value = site['value'] - intermediates = site['intermediates'] - scale = site['scale'] - if intermediates: - log_prob = site['fn'].log_prob(value, intermediates) - else: - log_prob = site['fn'].log_prob(value) - - if (scale is not None) and (not is_identically_one(scale)): - log_prob = scale * log_prob - - if prior: + if not prior: + for site in model_trace.values(): + if site['type'] == 'sample' and site['is_observed'] and not isinstance(site['fn'], dist.PRNGIdentity): + value = site['value'] + intermediates = site['intermediates'] + scale = site['scale'] + if intermediates: + log_prob = site['fn'].log_prob(value, intermediates) + else: + log_prob = site['fn'].log_prob(value) #TODO: The shape here is duplicated + + if (scale is not None) and (not is_identically_one(scale)): + log_prob = scale * log_prob + + return log_prob, model_trace + else: + for site in model_trace.values(): + if site['type'] == 'sample' and not isinstance(site['fn'], dist.PRNGIdentity) and not site['is_observed']: #Prior prob + value = site['value'] + intermediates = site['intermediates'] + scale = site['scale'] + if intermediates: + log_prob = site['fn'].log_prob(value, intermediates) + else: + log_prob = site['fn'].log_prob(value) + + if (scale is not None) and (not is_identically_one(scale)): + log_prob = scale * log_prob + log_prob = jnp.sum(log_prob) log_joint = log_joint + log_prob - return log_joint, model_trace - - + return log_joint, model_trace def grad_potential(model, model_args, model_kwargs,z, z_ref, jac_all, hess_all, n, m, *args, **kwargs): - + """Calculate the gradient of the potential energy function""" k, = jac_all.shape z_flat, treedef = ravel_pytree(z) zref_flat, _ = ravel_pytree(z_ref) z_diff = z_flat - zref_flat - ld_fn = lambda args: partial(log_density, model, model_args, model_kwargs, prior = False)(args)[0] + ld_fn = lambda args: partial(log_density_hmcecs, model, model_args, model_kwargs,prior=False)(args)[0] jac_ref, _ = ravel_pytree(jax.jacfwd(ld_fn)(z_ref)) hess_ref, _ = ravel_pytree(jax.hessian(ld_fn)(z_ref)) @@ -62,8 +83,8 @@ def grad_potential(model, model_args, model_kwargs,z, z_ref, jac_all, hess_all, grad_sum = jac_all + hess_all.dot(z_diff) jac_sub, _ = ravel_pytree(jax.jacfwd(ld_fn)(z)) - ll_sub, _ = log_density(model, model_args, model_kwargs, z,prior=False) # log likelihood for subsample with current theta - ll_ref, _ = log_density(model, model_args, model_kwargs, z_ref,prior=False) # log likelihood for subsample with reference theta + ll_sub, _ = log_density_hmcecs(model, model_args, model_kwargs, z,prior=False) # log likelihood for subsample with current theta + ll_ref, _ = log_density_hmcecs(model, model_args, model_kwargs, z_ref,prior=False) # log likelihood for subsample with reference theta diff = ll_sub - (ll_ref + jac_ref @ z_diff + .5 * z_diff @ hess_ref @ z_diff.T) @@ -74,13 +95,14 @@ def grad_potential(model, model_args, model_kwargs,z, z_ref, jac_all, hess_all, gradll = -(grad_sum + n / m * (jac_sub.sum(0) - hess_ref.sum(0).dot(z_diff))) + n ** 2 / (m ** 2) * ( diff - diff.mean(0)).T.dot(grad_d_k - grad_d_k.mean(0)) - ld_fn = lambda args: partial(log_density, model, model_args, model_kwargs,prior=True)(args)[0] + ld_fn = lambda args: partial(log_density_hmcecs, model, model_args, model_kwargs,prior=True)(args)[0] jac_sub, _ = ravel_pytree(jax.jacfwd(ld_fn)(z)) return treedef(gradll - jac_sub) - def potential_est(model, model_args, model_kwargs,ll_ref, jac_all, hess_all, z, z_ref, n, m): + """Estimate the potential dynamic energy for the HMC ECS implementation. The calculation follows section 7.2.1 in https://jmlr.org/papers/volume18/15-205/15-205.pdf + The computation has a complexity of O(1) and it's highly dependant on the quality of the map estimate""" # Agrees with reference upto constant factor on prior k, = jac_all.shape # number of features z_flat, _ = ravel_pytree(z) @@ -88,21 +110,69 @@ def potential_est(model, model_args, model_kwargs,ll_ref, jac_all, hess_all, z, z_diff = z_flat - zref_flat - ld_fn = lambda args: partial(log_density, model, model_args, model_kwargs,prior=False)(args)[0] + ld_fn = lambda args: partial(log_density_hmcecs, model, model_args, model_kwargs,prior=False)(args)[0] jac_sub, _ = ravel_pytree(jax.jacfwd(ld_fn)(z_ref)) hess_sub, _ = ravel_pytree(jax.hessian(ld_fn)(z_ref)) proxy = jnp.sum(ll_ref) + jac_all.T @ z_diff + .5 * z_diff.T @ hess_all @ z_diff - ll_sub, _ = log_density(model, model_args, model_kwargs, z,prior=False) # log likelihood for subsample with current theta - ll_ref, _ = log_density(model, model_args, model_kwargs, z_ref,prior=False) # log likelihood for subsample with reference theta + ll_sub, _ = log_density_hmcecs(model, model_args, model_kwargs, z,prior=False) # log likelihood for subsample with current theta + ll_ref, _ = log_density_hmcecs(model, model_args, model_kwargs, z_ref,prior=False) # log likelihood for subsample with reference theta diff = ll_sub - (ll_ref + jac_sub.reshape((m, k)) @ z_diff + .5 * z_diff @ hess_sub.reshape((m, k, k)) @ z_diff.T) l_hat = proxy + n / m * jnp.sum(diff) sigma = n ** 2 / m * jnp.var(diff) - ll_prior, _ = log_density(model, model_args, model_kwargs, z,prior=True) + ll_prior, _ = log_density_hmcecs(model, model_args, model_kwargs, z,prior=True) + + return (-l_hat + .5 * sigma) - ll_prior + +def velocity_verlet_hmcecs(potential_fn, kinetic_fn, grad_potential_fn=None): + r""" + Second order symplectic integrator that uses the velocity verlet algorithm + for position `z` and momentum `r`. + + :param potential_fn: Python callable that computes the potential energy + given input parameters. The input parameters to `potential_fn` can be + any python collection type. If HMCECS is used the gradient of the potential + energy funtion is calculated + :param kinetic_fn: Python callable that returns the kinetic energy given + inverse mass matrix and momentum. + :return: a pair of (`init_fn`, `update_fn`). + """ + compute_value_grad = value_and_grad(potential_fn) if grad_potential_fn is None \ + else lambda z: (potential_fn(z), grad_potential_fn(z)) + + def init_fn(z, r, potential_energy=None, z_grad=None): + """ + :param z: Position of the particle. + :param r: Momentum of the particle. + :param potential_energy: Potential energy at `z`. + :param z_grad: gradient of potential energy at `z`. + :return: initial state for the integrator. + """ + if potential_energy is None or z_grad is None: + potential_energy, z_grad = compute_value_grad(z) + + return IntegratorState(z, r, potential_energy, z_grad) + + def update_fn(step_size, inverse_mass_matrix, state): + """ + :param float step_size: Size of a single step. + :param inverse_mass_matrix: Inverse of mass matrix, which is used to + calculate kinetic energy. + :param state: Current state of the integrator. + :return: new state for the integrator. + """ + z, r, _, z_grad = state + r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad) # r(n+1/2) + r_grad = grad(kinetic_fn, argnums=1)(inverse_mass_matrix, r) + z = tree_multimap(lambda z, r_grad: z + step_size * r_grad, z, r_grad) # z(n+1) + potential_energy, z_grad = compute_value_grad(z) + r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad) # r(n+1) + return IntegratorState(z, r, potential_energy, z_grad) + + return init_fn, update_fn - return (-l_hat + .5 * sigma) - ll_prior \ No newline at end of file From 4eeb1f00979ef20f937959d06491a4147bdb6cfd Mon Sep 17 00:00:00 2001 From: Lys Date: Fri, 18 Sep 2020 18:27:45 +0200 Subject: [PATCH 06/93] FIXED: initialization model parameters --- examples/logistic_hmcecs.py | 11 +- numpyro/contrib/hmcecs.py | 197 ++++++++++++++++++++------------ numpyro/contrib/hmcecs_utils.py | 11 +- 3 files changed, 143 insertions(+), 76 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 646f20e15..c0dc6955b 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -6,7 +6,7 @@ from numpyro.infer import NUTS, MCMC, Predictive from numpyro.contrib.hmcecs import HMC from sklearn.datasets import load_breast_cancer - +numpyro.set_platform("cpu") # TODO: import Higgs data! ---> http://archive.ics.uci.edu/ml/machine-learning-databases/00280/ # https://towardsdatascience.com/identifying-higgs-bosons-from-background-noise-pyspark-d7983234207e @@ -23,7 +23,7 @@ def model(feats, obs): numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(feats, theta)), obs=obs) -def infer_nuts(rng_key, feats, obs, samples=10, warmup=5, ): +def infer_nuts(rng_key, feats, obs, samples=5, warmup=5, ): kernel = NUTS(model=model) mcmc = MCMC(kernel, num_warmup=warmup, num_samples=samples) mcmc.run(rng_key, feats, obs) @@ -35,13 +35,15 @@ def infer_hmcecs(rng_key, feats, obs, g=2, samples=10, warmup=5, ): hmcecs_key, map_key = jax.random.split(rng_key) n, _ = feats.shape + + print("Running NUTS for map estimation") z_map = {key: value.mean(0) for key, value in infer_nuts(map_key, feats, obs).items()} #Observations = (569,1) #Features = (569,31) print("Running MCMC subsampling") - kernel = HMC(model=model,z_ref=z_map,m=50,g=10,subsample_method="perturb") + kernel = HMC(model=model,z_ref=z_map,m=5,g=2,subsample_method="perturb") mcmc = MCMC(kernel,num_warmup=warmup,num_samples=samples) mcmc.run(rng_key,feats,obs) return mcmc.get_samples() @@ -53,7 +55,8 @@ def breast_cancer_data(): feats = dataset.data feats = (feats - feats.mean(0)) / feats.std(0) feats = jnp.hstack((feats, jnp.ones((feats.shape[0], 1)))) - return feats, dataset.target.reshape((-1, 1)) + + return feats[:10], dataset.target[:10] def higgs_data(): diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 9b39f459c..280a587a9 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -478,11 +478,17 @@ def __init__(self, self._max_tree_depth = 10 self._init_strategy = init_strategy self._find_heuristic_step_size = find_heuristic_step_size - self._subsample_method = subsample_method - self._m = m if m is not None else 4 - self._g = g if g is not None else 2 - self._z_ref = z_ref + #HMCECS parameters + self.subsample_method = subsample_method + self.m = m if m is not None else 4 + self.g = g if g is not None else 2 + self.z_ref = z_ref self._n = None + self._ll_ref = None + self._jac_all = None + self._hess_all = None + self._ll_u = None + self._u = None # Set on first call to init self._init_fn = None self._postprocess_fn = None @@ -491,44 +497,52 @@ def __init__(self, def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ref): - self._n = model_args[0].shape[0] - u = random.randint(rng_key, (self._m,), 0, self._n) - - model_args = self.model_args_sub(u,model_args) - model_kwargs = self.model_kwargs_sub(u, model_kwargs) - - rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key = random.split(rng_key, 5) - ld_fn = lambda args: partial(log_density_hmcecs, self._model, model_args, model_kwargs,prior=False)(args)[0] + rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key,rng_key_init_model = random.split(rng_key, 6) + ld_fn = lambda args: log_density_hmcecs( model=self._model,model_args=model_args,model_kwargs= model_kwargs,params=args,prior=False)[0] ll_ref = ld_fn(z_ref) #loglikelihood of the reference parameter estimates (u,u) under the subsample - - jac_all, jac_all_unflat = ravel_pytree(jacfwd(ld_fn)(z_ref)) #contains the jacobian for the non observed parameters "theta":(u,u,features), "precision" : (u,u) + jac_all, jac_all_unflat = ravel_pytree(jacfwd(ld_fn)(z_ref)) #contains the jacobian for the non observed parameters "theta":(u,features), "precision" : (u,u) hess_all, hess_all_unflat = ravel_pytree(hessian(ld_fn)(z_ref)) - print(jac_all.shape) - exit() - jac_all = jac_all.reshape(self._n, -1).sum(0) - k, = jac_all.shape - hess_all = hess_all.reshape(self._n, k, k).sum(0) - exit() - self._potential_fn = lambda model, args, ll_ref, jac_all, z_ref, hess_all, n, m: \ - lambda z: potential_est(model=model, model_args=args, model_kwargs=model_kwargs, ll_ref=ll_ref, - jac_all=jac_all, - z=z, z_ref=z_ref, hess_all=hess_all, n=n, m=m) - self._grad_potential = lambda model, args, ll_ref, jac_all, z_ref, hess_all, n, m: \ - lambda z: grad_potential(model=model, model_args=args, model_kwargs=model_kwargs, jac_all=jac_all, z=z, - z_ref=z_ref, - hess_all=hess_all, n=n, m=m) - - + self._jac_all = jac_all.reshape(self._n, -1).sum(0) + k, = self._jac_all.shape + self._hess_all = hess_all.reshape(self._n, k, k).sum(0) + self._u = random.randint(rng_key, (self.m,), 0, self._n) + + init_params, _, postprocess_fn, _ = initialize_model(rng_key_init_model, self._model, + model_args=self.model_args_sub(u=self._u,model_args=model_args)) # TODO: fix init strategy! + + def potential_fn(self,model, model_args,model_kwargs, ll_ref, jac_all, z_ref, hess_all, n, m): + return lambda z: potential_est(model=model, model_args=model_args, model_kwargs=model_kwargs,ll_ref=ll_ref, jac_all=jac_all, + z=z, z_ref=z_ref, hess_all=hess_all, n=n, m=m) + def grad_potencial_fn(self,model, model_args, model_kwargs, jac_all,z_ref,hess_all, n, m): + return lambda z: grad_potential(model=model, model_args=model_args, model_kwargs=model_kwargs, jac_all=jac_all, z=z, + z_ref=z_ref,hess_all=hess_all, n=n, m=m) def _init_state(self, rng_key, model_args, model_kwargs, init_params): - if self._subsample_method is not None: - assert self._z_ref is not None, "Please provide a (i.e map) estimate for the parameters" - self._init_subsample_state(rng_key, model_args, model_kwargs, init_params,self._z_ref) + if self.subsample_method is not None: + assert self.z_ref is not None, "Please provide a (i.e map) estimate for the parameters" + self._n = model_args[0].shape[0] + + #Initialize the subsample state + self._init_subsample_state(rng_key, model_args, model_kwargs, init_params,self.z_ref) + # + # #Initialize the potential and its gradient + self._potential_fn = self.potential_fn(model=self._model, model_args=model_args, model_kwargs=model_kwargs, ll_ref=self._ll_ref, + jac_all=self._jac_all, z_ref=self.z_ref, hess_all=self._hess_all, n=self._n, m=self.m) + self._grad_potential = self.grad_potencial_fn(model=self._model,model_args = model_args,model_kwargs=model_kwargs, + jac_all=self._jac_all,z_ref=self.z_ref,hess_all=self._hess_all,n=self._n,m=self.m) + #Initialize the model parameters + init_params, potential_fn, postprocess_fn, model_trace = initialize_model( + rng_key, + self._model, + dynamic_args=True, + model_args=model_args, + model_kwargs=model_kwargs) #TODO: review, change initialization + #Initialize the hmc sampler self._init_fn, self._subsample_fn = hmc(potential_fn_gen=self._potential_fn, kinetic_fn=euclidean_kinetic_energy, grad_potential_fn_gen=self._grad_potential, - algo='HMC') # no need to be returned here to be used for sampling, because they are init sampler and subsampler are updated... + algo='HMC') if self._model is not None: init_params, potential_fn, postprocess_fn, model_trace = initialize_model( @@ -579,7 +593,7 @@ def model_args_sub(self,u,model_args): args.append(jnp.take(arg, u, axis=0)) else: args.append(arg) - return args + return tuple(args) def model_kwargs_sub(self,u, kwargs): """Subsample observations and features""" for key_arg, val_arg in kwargs.items(): @@ -594,6 +608,7 @@ def _block_indices(self,size, num_blocks): def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}): + """Initialize sampling algorithms""" # non-vectorized if rng_key.ndim == 1: rng_key, rng_key_init_model = random.split(rng_key) @@ -606,39 +621,81 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg raise ValueError('Valid value of `init_params` must be provided with' ' `potential_fn`.') - - hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731 - init_params, - num_warmup=num_warmup, - step_size=self._step_size, - adapt_step_size=self._adapt_step_size, - adapt_mass_matrix=self._adapt_mass_matrix, - dense_mass=self._dense_mass, - target_accept_prob=self._target_accept_prob, - trajectory_length=self._trajectory_length, - max_tree_depth=self._max_tree_depth, - find_heuristic_step_size=self._find_heuristic_step_size, - model_args=model_args, - model_kwargs=model_kwargs, - rng_key=rng_key, - ) - if rng_key.ndim == 1: - init_state = hmc_init_fn(init_params, rng_key) - return init_state - elif self._subsample_method: - init_state = vmap(hmc_init_fn)(init_params, rng_key) - subsample_fn = vmap(self._subsample_fn, in_axes=(0, None, None)) - self._subsample_fn = subsample_fn - return init_state + if self.subsample_method == "perturb": + print("Initializing sampler") + hmc_init_fn = lambda init_params,rng_key: self._init_fn(init_params=init_params, + num_warmup = num_warmup, + step_size = self._step_size, + adapt_step_size = self._adapt_step_size, + adapt_mass_matrix = self._adapt_mass_matrix, + dense_mass = self._dense_mass, + target_accept_prob = self._target_accept_prob, + trajectory_length=self._trajectory_length, + max_tree_depth=self._max_tree_depth, + find_heuristic_step_size=self._find_heuristic_step_size, + model_args=( + self._model, self.model_args_sub(self._u, model_args),self._ll_ref, self._jac_all, self.z_ref, self._hess_all, + self._n, self.m)) + if rng_key.ndim ==1: + print(init_params) + exit() + init_state = hmc_init_fn(init_params, rng_key) + + + self._ll_u = self._potential_fn(self._model, self.model_args_sub(self._u, model_args), self._ll_ref, + self._jac_all, self._hess_all, + init_state.z, self.z_ref, self._n, self.m) + + init_subsample_state = HMCECSState(u=self._u, hmc_state=init_state, z_ref=self.z_ref, ll_u=self._ll_u, + jac_all=self._jac_all, + hess_all=self._hess_all, ll_ref=self._ll_ref) + return init_state,device_put(init_subsample_state) + else: + # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some + # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, + # wa_steps because those variables do not depend on traced args: init_params, rng_key. + init_state = vmap(hmc_init_fn)(init_params, rng_key) + self._ll_u = self._potential_fn(self._model, self.model_args_sub(self._u, model_args), self._ll_ref, + self._jac_all, self._hess_all, + init_state.z, self.z_ref, self._n, self.m) + + init_subsample_state = HMCECSState(u=self._u, hmc_state=init_state, z_ref=self.z_ref, ll_u=self._ll_u, + jac_all=self._jac_all, + hess_all=self._hess_all, ll_ref=self._ll_ref) + sample_fn = vmap(self._sample_fn, in_axes=(0, None, None)) + subsample = vmap(self._subsample_fn, in_axes=(0,None,None)) + self._sample_fn = sample_fn + self._subsample_fn = subsample + return init_state, device_put(init_subsample_state) + exit() else: - # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some - # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, - # wa_steps because those variables do not depend on traced args: init_params, rng_key. - init_state = vmap(hmc_init_fn)(init_params, rng_key) - sample_fn = vmap(self._sample_fn, in_axes=(0, None, None)) - self._sample_fn = sample_fn - return init_state + hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731 + init_params, + num_warmup=num_warmup, + step_size=self._step_size, + adapt_step_size=self._adapt_step_size, + adapt_mass_matrix=self._adapt_mass_matrix, + dense_mass=self._dense_mass, + target_accept_prob=self._target_accept_prob, + trajectory_length=self._trajectory_length, + max_tree_depth=self._max_tree_depth, + find_heuristic_step_size=self._find_heuristic_step_size, + model_args=model_args, + model_kwargs=model_kwargs, + rng_key=rng_key, + ) + if rng_key.ndim == 1: + init_state = hmc_init_fn(init_params, rng_key) + return init_state + else: + # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some + # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, + # wa_steps because those variables do not depend on traced args: init_params, rng_key. + init_state = vmap(hmc_init_fn)(init_params, rng_key) + sample_fn = vmap(self._sample_fn, in_axes=(0, None, None)) + self._sample_fn = sample_fn + return init_state def postprocess_fn(self, args, kwargs): if self._postprocess_fn is None: @@ -669,10 +726,10 @@ def subsample(self,subsamplestate,model_args,model_kwargs): rng_key_subsample, rng_key_transition, rng_key_likelihood, rng_key = random.split(subsamplestate.hmc_state.rng_key,4) - u_new = _update_block(rng_key_subsample, subsamplestate.u, self._n, self._m, self._g) + u_new = _update_block(rng_key_subsample, subsamplestate.u, self._n, self.m, self.g) # estimate likelihood of subsample with single block updated - llu_new = potential_est(model=self._model, + llu_new = self._potential_fn(model=self._model, model_args=model_args, model_kwargs=model_kwargs, jac_all=subsamplestate.jac_all, @@ -680,7 +737,7 @@ def subsample(self,subsamplestate,model_args,model_kwargs): ll_ref=subsamplestate.ll_ref, z=subsamplestate.hmc_state.z, z_ref=subsamplestate.z_ref, - n=self._n, m=self._m) + n=self._n, m=self.m) # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. @@ -704,7 +761,7 @@ def subsample(self,subsamplestate,model_args,model_kwargs): subsamplestate.ll_ref, subsamplestate.jac_all, subsamplestate.z_ref, - subsamplestate.hess_all, self._n, self._m),model_kwargs=model_kwargs) + subsamplestate.hess_all, self._n, self.m),model_kwargs=model_kwargs) class NUTS(HMC): """ diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index 51f05f686..2603a405f 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -28,9 +28,9 @@ def log_density_hmcecs(model, model_args, model_kwargs, params,prior=False): name. :return: log of joint density and a corresponding model trace """ + model = substitute(model, data=params) model_trace = trace(model).get_trace(*model_args, **model_kwargs) - log_joint = jnp.array(0.) if not prior: for site in model_trace.values(): @@ -41,8 +41,12 @@ def log_density_hmcecs(model, model_args, model_kwargs, params,prior=False): if intermediates: log_prob = site['fn'].log_prob(value, intermediates) else: + #print(site["name"]) + #print("value shape") + #print(value.shape) log_prob = site['fn'].log_prob(value) #TODO: The shape here is duplicated - + #print("Log prob shape") + #print(log_prob.shape) if (scale is not None) and (not is_identically_one(scale)): log_prob = scale * log_prob @@ -176,3 +180,6 @@ def update_fn(step_size, inverse_mass_matrix, state): return init_fn, update_fn +def initialize_model_hmcecs(): + pass + From ca9decefe7fdbc97cced534cb4adc53785981eb5 Mon Sep 17 00:00:00 2001 From: Lys Date: Mon, 21 Sep 2020 18:54:52 +0200 Subject: [PATCH 07/93] FIXED: Arguments potential function --- examples/logistic_hmcecs.py | 2 +- numpyro/contrib/hmcecs.py | 107 ++++++++++++++++---------------- numpyro/contrib/hmcecs_utils.py | 20 +++++- 3 files changed, 71 insertions(+), 58 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index c0dc6955b..91782367a 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -23,7 +23,7 @@ def model(feats, obs): numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(feats, theta)), obs=obs) -def infer_nuts(rng_key, feats, obs, samples=5, warmup=5, ): +def infer_nuts(rng_key, feats, obs, samples=5, warmup=0, ): kernel = NUTS(model=model) mcmc = MCMC(kernel, num_warmup=warmup, num_samples=samples) mcmc.run(rng_key, feats, obs) diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 280a587a9..f234895f4 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -21,7 +21,7 @@ from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model, log_density from numpyro.util import cond, fori_loop, identity -from numpyro.contrib.hmcecs_utils import grad_potential,potential_est,log_density_hmcecs, velocity_verlet_hmcecs +from numpyro.contrib.hmcecs_utils import grad_potential,potential_est,log_density_hmcecs, velocity_verlet_hmcecs, init_near_values HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', 'mean_accept_prob', 'diverging', 'adapt_state','rng_key']) HMCECSState = namedtuple("HMCECState",["u","hmc_state","z_ref","ll_ref","jac_all","hess_all","ll_u"]) @@ -202,17 +202,7 @@ def init_kernel(init_params, find_heuristic_step_size=False, model_args=(), model_kwargs=None, - rng_key=random.PRNGKey(0), - u=None, - blocks=None, - hmc_state = None, - z_ref=None, - ll_map = None, - jac_map = None, - hess_map= None, - control_variates= None, - ll_u=None - ): + rng_key=random.PRNGKey(0)): """ Initializes the HMC sampler. @@ -262,6 +252,8 @@ def init_kernel(init_params, raise ValueError('Only one of `potential_fn` or `potential_fn_gen` must be provided.') else: kwargs = {} if model_kwargs is None else model_kwargs + # print(potential_fn_gen.__code__.co_varnames) + # exit() pe_fn = potential_fn_gen(*model_args, **kwargs) if grad_potential_fn_gen: kwargs = {} if model_kwargs is None else model_kwargs @@ -499,51 +491,61 @@ def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key,rng_key_init_model = random.split(rng_key, 6) - - ld_fn = lambda args: log_density_hmcecs( model=self._model,model_args=model_args,model_kwargs= model_kwargs,params=args,prior=False)[0] - ll_ref = ld_fn(z_ref) #loglikelihood of the reference parameter estimates (u,u) under the subsample - jac_all, jac_all_unflat = ravel_pytree(jacfwd(ld_fn)(z_ref)) #contains the jacobian for the non observed parameters "theta":(u,features), "precision" : (u,u) - hess_all, hess_all_unflat = ravel_pytree(hessian(ld_fn)(z_ref)) - self._jac_all = jac_all.reshape(self._n, -1).sum(0) - k, = self._jac_all.shape - self._hess_all = hess_all.reshape(self._n, k, k).sum(0) + self._n = model_args[0].shape[0] self._u = random.randint(rng_key, (self.m,), 0, self._n) - init_params, _, postprocess_fn, _ = initialize_model(rng_key_init_model, self._model, - model_args=self.model_args_sub(u=self._u,model_args=model_args)) # TODO: fix init strategy! + ld_fn = lambda args: jnp.sum(partial(log_density_hmcecs, self._model, model_args, {},prior=False)(args)[0]) + + self._ll_ref = ld_fn(z_ref) + self._jac_all, _ = ravel_pytree(jacfwd(ld_fn)(z_ref)) + hess_all, _ = ravel_pytree(hessian(ld_fn)(z_ref)) - def potential_fn(self,model, model_args,model_kwargs, ll_ref, jac_all, z_ref, hess_all, n, m): - return lambda z: potential_est(model=model, model_args=model_args, model_kwargs=model_kwargs,ll_ref=ll_ref, jac_all=jac_all, - z=z, z_ref=z_ref, hess_all=hess_all, n=n, m=m) - def grad_potencial_fn(self,model, model_args, model_kwargs, jac_all,z_ref,hess_all, n, m): - return lambda z: grad_potential(model=model, model_args=model_args, model_kwargs=model_kwargs, jac_all=jac_all, z=z, + k, = self._jac_all.shape + self._hess_all = hess_all.reshape((k, k)) + # Initialize the potential and its gradient + self._potential_fn = self.potential_fn_hmcecs(model=self._model, model_args=model_args, model_kwargs=model_kwargs, + ll_ref=self._ll_ref, + jac_all=self._jac_all, z_ref=self.z_ref, hess_all=self._hess_all, + n=self._n, m=self.m) + + self._grad_potential = self.grad_potencial_fn_hmcecs(model=self._model, model_args=model_args, + model_kwargs=model_kwargs, + jac_all=self._jac_all, z_ref=self.z_ref, hess_all=self._hess_all, + n=self._n, m=self.m) + # Initialize the model parameters + init_params, potential_fn, postprocess_fn, model_trace = initialize_model( + rng_key, + self._model, + init_strategy=partial(init_near_values, values=self.z_ref), + dynamic_args=True, + model_args=self.model_args_sub(self._u, model_args), + model_kwargs=model_kwargs) + + return init_params, potential_fn, postprocess_fn, model_trace + + def potential_fn_hmcecs(self,model, model_args,model_kwargs, ll_ref, jac_all, z_ref, hess_all, n, m): + + return lambda model, args, kwargs,ll_ref, jac_all, z_ref, hess_all, n, m: \ + lambda z: potential_est(model=model, model_args=args, model_kwargs=model_kwargs,ll_ref=ll_ref, jac_all=jac_all, + z=z, z_ref=z_ref, hess_all=hess_all, n=n, m=m) + + def grad_potencial_fn_hmcecs(self,model, model_args, model_kwargs, jac_all,z_ref,hess_all, n, m): + return lambda model, args,kwargs ,ll_ref, jac_all, z_ref, hess_all, n, m:\ + lambda z: grad_potential(model=model, model_args=model_args, model_kwargs=model_kwargs, jac_all=jac_all, z=z, z_ref=z_ref,hess_all=hess_all, n=n, m=m) + + def _init_state(self, rng_key, model_args, model_kwargs, init_params): if self.subsample_method is not None: assert self.z_ref is not None, "Please provide a (i.e map) estimate for the parameters" - self._n = model_args[0].shape[0] - #Initialize the subsample state - self._init_subsample_state(rng_key, model_args, model_kwargs, init_params,self.z_ref) - # - # #Initialize the potential and its gradient - self._potential_fn = self.potential_fn(model=self._model, model_args=model_args, model_kwargs=model_kwargs, ll_ref=self._ll_ref, - jac_all=self._jac_all, z_ref=self.z_ref, hess_all=self._hess_all, n=self._n, m=self.m) - self._grad_potential = self.grad_potencial_fn(model=self._model,model_args = model_args,model_kwargs=model_kwargs, - jac_all=self._jac_all,z_ref=self.z_ref,hess_all=self._hess_all,n=self._n,m=self.m) - #Initialize the model parameters - init_params, potential_fn, postprocess_fn, model_trace = initialize_model( - rng_key, - self._model, - dynamic_args=True, - model_args=model_args, - model_kwargs=model_kwargs) #TODO: review, change initialization - #Initialize the hmc sampler + self._algo = "HMC" + init_params, potential_fn, postprocess_fn, model_trace=self._init_subsample_state(rng_key, model_args, model_kwargs, init_params,self.z_ref) + # Initialize the hmc sampler self._init_fn, self._subsample_fn = hmc(potential_fn_gen=self._potential_fn, kinetic_fn=euclidean_kinetic_energy, grad_potential_fn_gen=self._grad_potential, - algo='HMC') - + algo=self._algo) if self._model is not None: init_params, potential_fn, postprocess_fn, model_trace = initialize_model( rng_key, @@ -622,7 +624,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg ' `potential_fn`.') if self.subsample_method == "perturb": - print("Initializing sampler") + hmc_init_fn = lambda init_params,rng_key: self._init_fn(init_params=init_params, num_warmup = num_warmup, step_size = self._step_size, @@ -633,15 +635,11 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg trajectory_length=self._trajectory_length, max_tree_depth=self._max_tree_depth, find_heuristic_step_size=self._find_heuristic_step_size, - model_args=( - self._model, self.model_args_sub(self._u, model_args),self._ll_ref, self._jac_all, self.z_ref, self._hess_all, - self._n, self.m)) + model_args=(self._model,self.model_args_sub(self._u,model_args),model_kwargs,self._ll_ref,self._jac_all,self.z_ref,self._hess_all,self._n,self.m), + model_kwargs=model_kwargs) if rng_key.ndim ==1: - print(init_params) - exit() - init_state = hmc_init_fn(init_params, rng_key) - + init_state = hmc_init_fn(init_params, rng_key) self._ll_u = self._potential_fn(self._model, self.model_args_sub(self._u, model_args), self._ll_ref, self._jac_all, self._hess_all, init_state.z, self.z_ref, self._n, self.m) @@ -667,7 +665,6 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg self._sample_fn = sample_fn self._subsample_fn = subsample return init_state, device_put(init_subsample_state) - exit() else: hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731 diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index 2603a405f..d487894a9 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -180,6 +180,22 @@ def update_fn(step_size, inverse_mass_matrix, state): return init_fn, update_fn -def initialize_model_hmcecs(): - pass +def init_near_values(site=None, values={}): + """Initialize the sampling to a noisy map estimate of the parameters""" + from functools import partial + + from numpyro.distributions.continuous import Normal + from numpyro.infer.initialization import init_to_uniform + + if site is None: + return partial(init_near_values(values=values)) + + if site['type'] == 'sample' and not site['is_observed']: + if site['name'] in values: + try: + rng_key = site['kwargs'].get('rng_key') + sample_shape = site['kwargs'].get('sample_shape') + return values[site['name']] + Normal(0., 1e-3).sample(rng_key, sample_shape) + except: + return init_to_uniform(site) From f01c027a8798cc24b86c172160a65f5780acd9cf Mon Sep 17 00:00:00 2001 From: Lys Date: Tue, 22 Sep 2020 18:34:45 +0200 Subject: [PATCH 08/93] FIXED: Arguments mess --- numpyro/contrib/hmcecs.py | 285 ++++++++++++++++++++------------ numpyro/contrib/hmcecs_utils.py | 18 +- numpyro/infer/mcmc.py | 1 + 3 files changed, 192 insertions(+), 112 deletions(-) diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index f234895f4..99778e549 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -21,7 +21,7 @@ from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model, log_density from numpyro.util import cond, fori_loop, identity -from numpyro.contrib.hmcecs_utils import grad_potential,potential_est,log_density_hmcecs, velocity_verlet_hmcecs, init_near_values +from numpyro.contrib.hmcecs_utils import grad_potential,potential_est,log_density_hmcecs, velocity_verlet_hmcecs, init_near_values,tuplemerge HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', 'mean_accept_prob', 'diverging', 'adapt_state','rng_key']) HMCECSState = namedtuple("HMCECState",["u","hmc_state","z_ref","ll_ref","jac_all","hess_all","ll_u"]) @@ -90,8 +90,9 @@ def _update_block(rng_key, u, n, m, g): :param n total number of data :param m subsample size :param g block size: subsample subdivision""" - if not (g > m) or (g < 1): - raise ValueError('Block size (g) needs to = or > than 1 and smaller than the subsample size {}'.format(m)) + + if (g > m) or (g < 1): + raise ValueError('Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(g,m)) rng_key_block, rng_key_index = random.split(rng_key) # uniformly choose block to update @@ -202,7 +203,15 @@ def init_kernel(init_params, find_heuristic_step_size=False, model_args=(), model_kwargs=None, - rng_key=random.PRNGKey(0)): + model = None, + ll_ref=None, + jac_all=None, + z_ref= None, + hess_all=None, + n = None, + m = None, + rng_key=random.PRNGKey(0), + subsample_method=None): """ Initializes the HMC sampler. @@ -242,26 +251,35 @@ def init_kernel(init_params, wa_steps = num_warmup trajectory_len = trajectory_length max_treedepth = max_tree_depth + if isinstance(init_params, ParamInfo): z, pe, z_grad = init_params else: z, pe, z_grad = init_params, None, None + pe_fn = potential_fn if potential_fn_gen: if pe_fn is not None: raise ValueError('Only one of `potential_fn` or `potential_fn_gen` must be provided.') else: - kwargs = {} if model_kwargs is None else model_kwargs - # print(potential_fn_gen.__code__.co_varnames) - # exit() - pe_fn = potential_fn_gen(*model_args, **kwargs) + if subsample_method == "perturb": + gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs,ll_ref,jac_all,z, z_ref, hess_all, n, m) + pe_fn = potential_fn_gen(model, model_args, model_kwargs,ll_ref, jac_all,z,z_ref ,hess_all, n, m) + else: + kwargs = {} if model_kwargs is None else model_kwargs + pe_fn = potential_fn_gen(*model_args, **kwargs) + gpe_fn = grad_potential_fn_gen(*model_args, **kwargs,) if grad_potential_fn_gen: - kwargs = {} if model_kwargs is None else model_kwargs - gpe_fn = grad_potential_fn_gen(*model_args, **kwargs) + if subsample_method == "perturb": + gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs, ll_ref, jac_all, z, z_ref, hess_all,n, m) + else: + kwargs = {} if model_kwargs is None else model_kwargs + gpe_fn = grad_potential_fn_gen(*model_args, **kwargs) else: gpe_fn = None find_reasonable_ss = None + if find_heuristic_step_size: find_reasonable_ss = partial(find_reasonable_step_size, pe_fn, @@ -280,27 +298,39 @@ def init_kernel(init_params, wa_state = wa_init(z_info, rng_key_wa, step_size, inverse_mass_matrix=inverse_mass_matrix, mass_matrix_size=jnp.size(ravel_pytree(z)[0])) + r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum) vv_init, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,grad_potential_fn=gpe_fn) vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad) energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) + hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, 0, 0., 0., False, wa_state,rng_key_hmc) return device_put(hmc_state) def _hmc_next(step_size, inverse_mass_matrix, vv_state, - model_args, model_kwargs, rng_key): + model_args, model_kwargs, rng_key,subsample_method, + model,ll_ref,jac_all,z,z_ref,hess_all,n,m): if potential_fn_gen: if grad_potential_fn_gen: - kwargs = {} if model_kwargs is None else model_kwargs - gpe_fn = grad_potential_fn_gen(*model_args, **kwargs) + if subsample_method == "perturb": + gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs, ll_ref,jac_all, vv_state.z, z_ref, hess_all, n,m) + pe_fn = potential_fn_gen(model, model_args, model_kwargs, ll_ref,jac_all, vv_state.z, z_ref, hess_all, n,m) + else: + kwargs = {} if model_kwargs is None else model_kwargs + gpe_fn = grad_potential_fn_gen(*model_args, **kwargs,) + pe_fn = potential_fn_gen(*model_args, **model_kwargs) else: gpe_fn = None + pe_fn = potential_fn_gen(*model_args, **model_kwargs) nonlocal vv_update - pe_fn = potential_fn_gen(*model_args, **model_kwargs) + #pe_fn = potential_fn_gen(*model_args, **model_kwargs) _, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,gpe_fn) num_steps = _get_num_steps(step_size, trajectory_len) + #TODO: the verlet update function is taking the full model_args, instead of model_args_sub + print(model_args[0].shape) + print("####################################################") vv_state_new = fori_loop(0, num_steps, lambda i, val: vv_update(step_size, inverse_mass_matrix, val), vv_state) @@ -317,13 +347,18 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, return vv_state, energy, num_steps, accept_prob, diverging def _nuts_next(step_size, inverse_mass_matrix, vv_state, - model_args, model_kwargs, rng_key): + model_args, model_kwargs, rng_key,subsample_method): if potential_fn_gen: nonlocal vv_update pe_fn = potential_fn_gen(*model_args, **model_kwargs) if grad_potential_fn_gen: - kwargs = {} if model_kwargs is None else model_kwargs - gpe_fn = grad_potential_fn_gen(*model_args, **kwargs) + if subsample_method == "perturbed": + model, model_args_sub, model_kwargs, ll_ref, jac_all, z, z_ref, hess_all, n, m = model_args + gpe_fn = grad_potential_fn_gen(model, model_args_sub, model_kwargs,ll_ref,jac_all, z, z_ref, n, + m) + else: + kwargs = {} if model_kwargs is None else model_kwargs + gpe_fn = grad_potential_fn_gen(*model_args, **kwargs, ) else: gpe_fn = None _, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,gpe_fn) @@ -342,7 +377,7 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, _next = _nuts_next if algo == 'NUTS' else _hmc_next - def sample_kernel(hmc_state, model_args=(), model_kwargs=None): + def sample_kernel(hmc_state,model,ll_ref,jac_all,z,z_ref,hess_all,n,m,model_args=(),model_kwargs=None,subsample_method=None,): """ Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed (possibly adapted) step size and return a new :data:`~numpyro.infer.mcmc.HMCState`. @@ -354,16 +389,21 @@ def sample_kernel(hmc_state, model_args=(), model_kwargs=None): Hamiltonian dynamics given existing state. """ + model_kwargs = {} if model_kwargs is None else model_kwargs rng_key, rng_key_momentum, rng_key_transition = random.split(hmc_state.rng_key, 3) r = momentum_generator(hmc_state.z, hmc_state.adapt_state.mass_matrix_sqrt, rng_key_momentum) vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad) + + vv_state, energy, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size, hmc_state.adapt_state.inverse_mass_matrix, vv_state, model_args, model_kwargs, - rng_key_transition) + rng_key_transition, + subsample_method, + model,ll_ref,jac_all,z,z_ref,hess_all,n,m) # not update adapt_state after warmup phase adapt_state = cond(hmc_state.i < wa_steps, (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state), @@ -488,7 +528,7 @@ def __init__(self, self._subsample_fn = None def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ref): - + "Compute the jacobian, hessian and gradient for all the data" rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key,rng_key_init_model = random.split(rng_key, 6) self._n = model_args[0].shape[0] @@ -502,16 +542,6 @@ def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ k, = self._jac_all.shape self._hess_all = hess_all.reshape((k, k)) - # Initialize the potential and its gradient - self._potential_fn = self.potential_fn_hmcecs(model=self._model, model_args=model_args, model_kwargs=model_kwargs, - ll_ref=self._ll_ref, - jac_all=self._jac_all, z_ref=self.z_ref, hess_all=self._hess_all, - n=self._n, m=self.m) - - self._grad_potential = self.grad_potencial_fn_hmcecs(model=self._model, model_args=model_args, - model_kwargs=model_kwargs, - jac_all=self._jac_all, z_ref=self.z_ref, hess_all=self._hess_all, - n=self._n, m=self.m) # Initialize the model parameters init_params, potential_fn, postprocess_fn, model_trace = initialize_model( rng_key, @@ -521,18 +551,9 @@ def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ model_args=self.model_args_sub(self._u, model_args), model_kwargs=model_kwargs) - return init_params, potential_fn, postprocess_fn, model_trace - - def potential_fn_hmcecs(self,model, model_args,model_kwargs, ll_ref, jac_all, z_ref, hess_all, n, m): - return lambda model, args, kwargs,ll_ref, jac_all, z_ref, hess_all, n, m: \ - lambda z: potential_est(model=model, model_args=args, model_kwargs=model_kwargs,ll_ref=ll_ref, jac_all=jac_all, - z=z, z_ref=z_ref, hess_all=hess_all, n=n, m=m) - def grad_potencial_fn_hmcecs(self,model, model_args, model_kwargs, jac_all,z_ref,hess_all, n, m): - return lambda model, args,kwargs ,ll_ref, jac_all, z_ref, hess_all, n, m:\ - lambda z: grad_potential(model=model, model_args=model_args, model_kwargs=model_kwargs, jac_all=jac_all, z=z, - z_ref=z_ref,hess_all=hess_all, n=n, m=m) + return init_params, potential_fn, postprocess_fn, model_trace def _init_state(self, rng_key, model_args, model_kwargs, init_params): @@ -540,12 +561,25 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): assert self.z_ref is not None, "Please provide a (i.e map) estimate for the parameters" #Initialize the subsample state self._algo = "HMC" - init_params, potential_fn, postprocess_fn, model_trace=self._init_subsample_state(rng_key, model_args, model_kwargs, init_params,self.z_ref) - # Initialize the hmc sampler - self._init_fn, self._subsample_fn = hmc(potential_fn_gen=self._potential_fn, + # Initialize the potential and its gradient + self._potential_fn = lambda model, args, kwargs, ll_ref, jac_all,z, z_ref, hess_all, n, m: \ + lambda z: potential_est(model=self._model, model_args=model_args, model_kwargs=model_kwargs, + ll_ref=self._ll_ref, + jac_all=self._jac_all, z=z, z_ref=z_ref, hess_all=hess_all, n=self._n, m=self.m) + self._grad_potential = lambda model, args, kwargs,ll_ref, jac_all,z, z_ref, hess_all, n, m:\ + lambda z: grad_potential(model=self._model, model_args=model_args, + model_kwargs=model_kwargs, + jac_all=self._jac_all,z=z, + z_ref=self.z_ref, hess_all=self._hess_all, + n=self._n, m=self.m) + # Initialize the hmc sampler: sample_fn = sample_kernel + self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, kinetic_fn=euclidean_kinetic_energy, grad_potential_fn_gen=self._grad_potential, algo=self._algo) + + init_params, potential_fn, postprocess_fn, model_trace=self._init_subsample_state(rng_key, model_args, model_kwargs, init_params,self.z_ref) + if self._model is not None: init_params, potential_fn, postprocess_fn, model_trace = initialize_model( rng_key, @@ -608,7 +642,6 @@ def _block_indices(self,size, num_blocks): b = jnp.repeat(num_blocks - 1, size - len(jnp.repeat(jnp.arange(num_blocks - 1), size // num_blocks))) return jnp.hstack((a, b)) - def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}): """Initialize sampling algorithms""" # non-vectorized @@ -624,7 +657,6 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg ' `potential_fn`.') if self.subsample_method == "perturb": - hmc_init_fn = lambda init_params,rng_key: self._init_fn(init_params=init_params, num_warmup = num_warmup, step_size = self._step_size, @@ -635,36 +667,64 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg trajectory_length=self._trajectory_length, max_tree_depth=self._max_tree_depth, find_heuristic_step_size=self._find_heuristic_step_size, - model_args=(self._model,self.model_args_sub(self._u,model_args),model_kwargs,self._ll_ref,self._jac_all,self.z_ref,self._hess_all,self._n,self.m), - model_kwargs=model_kwargs) - if rng_key.ndim ==1: - - init_state = hmc_init_fn(init_params, rng_key) - self._ll_u = self._potential_fn(self._model, self.model_args_sub(self._u, model_args), self._ll_ref, - self._jac_all, self._hess_all, - init_state.z, self.z_ref, self._n, self.m) + model_args=self.model_args_sub(self._u,model_args), + model_kwargs=model_kwargs, + subsample_method= self.subsample_method, + model=self._model, + ll_ref =self._ll_ref, + jac_all=self._jac_all, + z_ref=self.z_ref, + hess_all = self._hess_all, + n=self._n,m=self.m) - init_subsample_state = HMCECSState(u=self._u, hmc_state=init_state, z_ref=self.z_ref, ll_u=self._ll_u, - jac_all=self._jac_all, - hess_all=self._hess_all, ll_ref=self._ll_ref) - return init_state,device_put(init_subsample_state) + if rng_key.ndim ==1: + init_state = hmc_init_fn(init_params, rng_key) #HMCState + self._ll_u = potential_est(self._model, + self.model_args_sub(self._u, model_args), + model_kwargs, + self._ll_ref, + self._jac_all, + self._hess_all, + init_state.z, + self.z_ref, + self._n, + self.m) + hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, + hmc_state=init_state, + z_ref=self.z_ref, + ll_u=self._ll_u, + jac_all=self._jac_all, + hess_all=self._hess_all, + ll_ref=self._ll_ref) + + init_sub_state = hmc_init_sub_fn(init_params,rng_key) #HMCState + + HMCCombinedState = tuplemerge(init_state._asdict(),init_sub_state._asdict()) + + + return HMCCombinedState else: + #For more than 2 chains # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, # wa_steps because those variables do not depend on traced args: init_params, rng_key. init_state = vmap(hmc_init_fn)(init_params, rng_key) - self._ll_u = self._potential_fn(self._model, self.model_args_sub(self._u, model_args), self._ll_ref, + self._ll_u = potential_est(self._model, self.model_args_sub(self._u, model_args), model_kwargs,self._ll_ref, self._jac_all, self._hess_all, init_state.z, self.z_ref, self._n, self.m) - init_subsample_state = HMCECSState(u=self._u, hmc_state=init_state, z_ref=self.z_ref, ll_u=self._ll_u, + hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state, z_ref=self.z_ref, ll_u=self._ll_u, jac_all=self._jac_all, hess_all=self._hess_all, ll_ref=self._ll_ref) + + init_subsample_state = vmap(hmc_init_sub_fn)(init_params,rng_key) + sample_fn = vmap(self._sample_fn, in_axes=(0, None, None)) subsample = vmap(self._subsample_fn, in_axes=(0,None,None)) + HMCCombinedState = tuplemerge(init_state._asdict,init_subsample_state._asdict()) self._sample_fn = sample_fn self._subsample_fn = subsample - return init_state, device_put(init_subsample_state) + return HMCCombinedState else: hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731 @@ -709,56 +769,61 @@ def sample(self, state, model_args, model_kwargs): :param model_kwargs: Keyword arguments provided to the model. :return: Next `state` after running HMC. """ - return self._sample_fn(state, model_args, model_kwargs) - def subsample(self,subsamplestate,model_args,model_kwargs): - """ - Run HMC from the given :data:`~numpyro.infer.hmc.HMCECSState` and return the resulting - :data:`~numpyro.infer.hmc.HMCECSState`. - - :param HMCECSState state: Represents the current state. - :param model_args: Arguments provided to the model. - :param model_kwargs: Keyword arguments provided to the model. - :return: Next `subsample state` after running HMC. - """ - - rng_key_subsample, rng_key_transition, rng_key_likelihood, rng_key = random.split(subsamplestate.hmc_state.rng_key,4) - - u_new = _update_block(rng_key_subsample, subsamplestate.u, self._n, self.m, self.g) - - # estimate likelihood of subsample with single block updated - llu_new = self._potential_fn(model=self._model, - model_args=model_args, - model_kwargs=model_kwargs, - jac_all=subsamplestate.jac_all, - hess_all=subsamplestate.hess_all, - ll_ref=subsamplestate.ll_ref, - z=subsamplestate.hmc_state.z, - z_ref=subsamplestate.z_ref, - n=self._n, m=self.m) - - # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) - # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. - accept_prob = jnp.clip(jnp.exp(-llu_new + subsamplestate.ll_u), a_max=1.) - transition = random.bernoulli(rng_key_transition, accept_prob) - u, ll_u = cond(transition, - (u_new, llu_new), identity, - (subsamplestate.u, subsamplestate.ll_u), identity) - - ######## UPDATE PARAMETERS ########## - - - hmc_subsamplestate= HMCECSState(u=u, hmc_state=subsamplestate.hmc_state, - z_ref=subsamplestate.z_ref, - ll_u=ll_u,ll_ref=subsamplestate.ll_ref, - jac_all=subsamplestate.jac_all, - hess_all=subsamplestate.hess_all) - - return self._subsample_fn(hmc_subsamplestate,model_args=(self._model, - model_args, - subsamplestate.ll_ref, - subsamplestate.jac_all, - subsamplestate.z_ref, - subsamplestate.hess_all, self._n, self.m),model_kwargs=model_kwargs) + + if self.subsample_method == "perturb": + rng_key_subsample, rng_key_transition, rng_key_likelihood, rng_key = random.split( + state.hmc_state.rng_key, 4) + + u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) + + # estimate likelihood of subsample with single block updated + llu_new = potential_est(model=self._model, + model_args=self.model_args_sub(u_new,model_args), + model_kwargs=model_kwargs, + ll_ref = state.ll_ref, + jac_all=state.jac_all, + hess_all= state.hess_all, + z=state.hmc_state.z, + z_ref=state.z_ref, + n=self._n, m=self.m) + # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) + # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. + + accept_prob = jnp.clip(jnp.exp(-llu_new + state.ll_u), a_max=1.) + transition = random.bernoulli(rng_key_transition, accept_prob) + u, ll_u = cond(transition, + (u_new, llu_new), identity, + (state.u, state.ll_u), identity) + + ######## UPDATE PARAMETERS ########## + + hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state, + z_ref=state.z_ref, + ll_u=ll_u, ll_ref=state.ll_ref, + jac_all=state.jac_all, + hess_all=state.hess_all) + hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(),state._asdict()) + print(model_args[0].shape) + + return self._sample_fn(hmc_subsamplestate, + model_args=self.model_args_sub(u,model_args), + model_kwargs=model_kwargs, + subsample_method=self.subsample_method, + model = self._model, + ll_ref = state.ll_ref, + jac_all =state.jac_all, + z= state.z, + z_ref = state.z_ref, + hess_all = state.hess_all, + n= self._n, + m= self.m) + + else: + return self._sample_fn(state, model_args, model_kwargs) + + + + class NUTS(HMC): """ diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index d487894a9..a664c93c0 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -70,7 +70,7 @@ def log_density_hmcecs(model, model_args, model_kwargs, params,prior=False): return log_joint, model_trace def grad_potential(model, model_args, model_kwargs,z, z_ref, jac_all, hess_all, n, m, *args, **kwargs): - """Calculate the gradient of the potential energy function""" + """Calculate the gradient of the potential energy function for the current subsample""" k, = jac_all.shape z_flat, treedef = ravel_pytree(z) zref_flat, _ = ravel_pytree(z_ref) @@ -104,6 +104,15 @@ def grad_potential(model, model_args, model_kwargs,z, z_ref, jac_all, hess_all, return treedef(gradll - jac_sub) +def reducer( accum, d ): + accum.update(d) + return accum + +def tuplemerge( *dictionaries ): + from functools import reduce + merged = reduce( reducer, dictionaries, {} ) + return namedtuple('HMCCombinedState', merged )(**merged) # <==== Gist of the gist + def potential_est(model, model_args, model_kwargs,ll_ref, jac_all, hess_all, z, z_ref, n, m): """Estimate the potential dynamic energy for the HMC ECS implementation. The calculation follows section 7.2.1 in https://jmlr.org/papers/volume18/15-205/15-205.pdf The computation has a complexity of O(1) and it's highly dependant on the quality of the map estimate""" @@ -112,8 +121,10 @@ def potential_est(model, model_args, model_kwargs,ll_ref, jac_all, hess_all, z, z_flat, _ = ravel_pytree(z) zref_flat, _ = ravel_pytree(z_ref) - z_diff = z_flat - zref_flat + z_diff = z_flat - zref_flat + #print(model_args[0].shape) + #print("........................................................................................") ld_fn = lambda args: partial(log_density_hmcecs, model, model_args, model_kwargs,prior=False)(args)[0] jac_sub, _ = ravel_pytree(jax.jacfwd(ld_fn)(z_ref)) @@ -146,6 +157,7 @@ def velocity_verlet_hmcecs(potential_fn, kinetic_fn, grad_potential_fn=None): inverse mass matrix and momentum. :return: a pair of (`init_fn`, `update_fn`). """ + compute_value_grad = value_and_grad(potential_fn) if grad_potential_fn is None \ else lambda z: (potential_fn(z), grad_potential_fn(z)) @@ -160,6 +172,7 @@ def init_fn(z, r, potential_energy=None, z_grad=None): if potential_energy is None or z_grad is None: potential_energy, z_grad = compute_value_grad(z) + return IntegratorState(z, r, potential_energy, z_grad) def update_fn(step_size, inverse_mass_matrix, state): @@ -171,6 +184,7 @@ def update_fn(step_size, inverse_mass_matrix, state): :return: new state for the integrator. """ z, r, _, z_grad = state + r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad) # r(n+1/2) r_grad = grad(kinetic_fn, argnums=1)(inverse_mass_matrix, r) z = tree_multimap(lambda z, r_grad: z + step_size * r_grad, z, r_grad) # z(n+1) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index dfe0155a6..c8b05fc6f 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -306,6 +306,7 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields): postprocess_fn = self.sampler.postprocess_fn(args, kwargs) else: postprocess_fn = self.postprocess_fn + diagnostics = lambda x: self.sampler.get_diagnostics_str(x[0]) if rng_key.ndim == 1 else '' # noqa: E731 init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,) lower_idx = self._collection_params["lower"] From cc2f1b0dc38f49e70f01930ac7f0f39491b1b209 Mon Sep 17 00:00:00 2001 From: Lys Date: Fri, 25 Sep 2020 18:50:15 +0200 Subject: [PATCH 09/93] FIXED? shapes error new problem stuff being 1D (precision) --- numpyro/contrib/hmcecs.py | 155 ++++++++++++++++++-------------- numpyro/contrib/hmcecs_utils.py | 50 ++++++++--- numpyro/infer/mcmc.py | 7 +- 3 files changed, 132 insertions(+), 80 deletions(-) diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 99778e549..2a10f3fd1 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -21,7 +21,8 @@ from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model, log_density from numpyro.util import cond, fori_loop, identity -from numpyro.contrib.hmcecs_utils import grad_potential,potential_est,log_density_hmcecs, velocity_verlet_hmcecs, init_near_values,tuplemerge +from numpyro.contrib.hmcecs_utils import grad_potential,potential_est,log_density_hmcecs, \ + velocity_verlet_hmcecs, init_near_values,tuplemerge,model_args_sub,model_kwargs_sub HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', 'mean_accept_prob', 'diverging', 'adapt_state','rng_key']) HMCECSState = namedtuple("HMCECState",["u","hmc_state","z_ref","ll_ref","jac_all","hess_all","ll_u"]) @@ -96,7 +97,7 @@ def _update_block(rng_key, u, n, m, g): rng_key_block, rng_key_index = random.split(rng_key) # uniformly choose block to update - chosen_block = random.randint(rng_key, shape=(), minval= 0, maxval=g + 1) #TODO: assertions for g values? why minval=0?division by 0 + chosen_block = random.randint(rng_key, shape=(), minval= 0, maxval=g + 1) idxs_new = random.randint(rng_key_index, shape=(m // g,), minval=0, maxval=n) #chose block within the subsample to update @@ -208,8 +209,10 @@ def init_kernel(init_params, jac_all=None, z_ref= None, hess_all=None, + ll_u = None, n = None, m = None, + u= None, rng_key=random.PRNGKey(0), subsample_method=None): """ @@ -263,7 +266,8 @@ def init_kernel(init_params, raise ValueError('Only one of `potential_fn` or `potential_fn_gen` must be provided.') else: if subsample_method == "perturb": - gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs,ll_ref,jac_all,z, z_ref, hess_all, n, m) + #model, model_args, model_kwargs,z, z_ref, jac_all, hess_all, n, m,u=None + gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs,z, z_ref,jac_all, hess_all, n, m,u) pe_fn = potential_fn_gen(model, model_args, model_kwargs,ll_ref, jac_all,z,z_ref ,hess_all, n, m) else: kwargs = {} if model_kwargs is None else model_kwargs @@ -271,7 +275,8 @@ def init_kernel(init_params, gpe_fn = grad_potential_fn_gen(*model_args, **kwargs,) if grad_potential_fn_gen: if subsample_method == "perturb": - gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs, ll_ref, jac_all, z, z_ref, hess_all,n, m) + #model, model_args, model_kwargs,z, z_ref, jac_all, hess_all, n, m,u=None + gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs, z, z_ref,jac_all, hess_all,n, m,u) else: kwargs = {} if model_kwargs is None else model_kwargs gpe_fn = grad_potential_fn_gen(*model_args, **kwargs) @@ -302,19 +307,30 @@ def init_kernel(init_params, r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum) vv_init, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,grad_potential_fn=gpe_fn) vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad) + energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, 0, 0., 0., False, wa_state,rng_key_hmc) + hmc_sub_state = HMCECSState(u=3, hmc_state=hmc_state, z_ref=z_ref, ll_ref=ll_ref, jac_all=jac_all, + hess_all=hess_all, ll_u=ll_u) + # if subsample_method == "perturb": + # hmc_sub_state = HMCECSState(u=3,hmc_state=hmc_state,z_ref=z_ref,ll_ref=ll_ref,jac_all=jac_all,hess_all=hess_all,ll_u=ll_u) + # return device_put(hmc_sub_state) + # else: + # return device_put(hmc_state) + hmc_state = tuplemerge(hmc_sub_state._asdict(),hmc_state._asdict()) + return device_put(hmc_state) def _hmc_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key,subsample_method, - model,ll_ref,jac_all,z,z_ref,hess_all,n,m): + model,ll_ref,jac_all,z,z_ref,hess_all,ll_u,u,n,m): if potential_fn_gen: if grad_potential_fn_gen: if subsample_method == "perturb": - gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs, ll_ref,jac_all, vv_state.z, z_ref, hess_all, n,m) + #model, model_args, model_kwargs,z, z_ref, jac_all, hess_all, n, m,u=None + gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs,vv_state.z, z_ref,jac_all, hess_all, n,m,u) pe_fn = potential_fn_gen(model, model_args, model_kwargs, ll_ref,jac_all, vv_state.z, z_ref, hess_all, n,m) else: kwargs = {} if model_kwargs is None else model_kwargs @@ -325,15 +341,16 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, pe_fn = potential_fn_gen(*model_args, **model_kwargs) nonlocal vv_update #pe_fn = potential_fn_gen(*model_args, **model_kwargs) - _, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,gpe_fn) + _, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,gpe_fn) #TODO:vv_update might be updating wrong num_steps = _get_num_steps(step_size, trajectory_len) - #TODO: the verlet update function is taking the full model_args, instead of model_args_sub - print(model_args[0].shape) - print("####################################################") + vv_state_new = fori_loop(0, num_steps, - lambda i, val: vv_update(step_size, inverse_mass_matrix, val), + lambda i, val: vv_update(step_size, inverse_mass_matrix, val,u), #TODO added u vv_state) + # vv_state_new = fori_loop(0, num_steps, + # lambda i, val: vv_update(step_size, inverse_mass_matrix, val), + # vv_state) energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r) energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r) delta_energy = energy_new - energy_old @@ -347,15 +364,16 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, return vv_state, energy, num_steps, accept_prob, diverging def _nuts_next(step_size, inverse_mass_matrix, vv_state, - model_args, model_kwargs, rng_key,subsample_method): + model_args, model_kwargs, rng_key,subsample_method, + model=None,ll_ref=None,jac_all=None,z = None,z_ref=None,hess_all=None,ll_u=None,u=None,n=None,m=None): if potential_fn_gen: nonlocal vv_update pe_fn = potential_fn_gen(*model_args, **model_kwargs) if grad_potential_fn_gen: if subsample_method == "perturbed": - model, model_args_sub, model_kwargs, ll_ref, jac_all, z, z_ref, hess_all, n, m = model_args - gpe_fn = grad_potential_fn_gen(model, model_args_sub, model_kwargs,ll_ref,jac_all, z, z_ref, n, - m) + #model, model_args_sub, model_kwargs, ll_ref, jac_all, z, z_ref, hess_all,ll_u,u, n, m = model_args + gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs,ll_ref,jac_all, z, z_ref, n, + m,u) else: kwargs = {} if model_kwargs is None else model_kwargs gpe_fn = grad_potential_fn_gen(*model_args, **kwargs, ) @@ -377,7 +395,8 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, _next = _nuts_next if algo == 'NUTS' else _hmc_next - def sample_kernel(hmc_state,model,ll_ref,jac_all,z,z_ref,hess_all,n,m,model_args=(),model_kwargs=None,subsample_method=None,): + def sample_kernel(hmc_state,model_args=(),model_kwargs=None,subsample_method=None, + model=None,ll_ref=None,jac_all=None,z=None,z_ref=None,hess_all=None,ll_u=None,u=None,n=None,m=None,): """ Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed (possibly adapted) step size and return a new :data:`~numpyro.infer.mcmc.HMCState`. @@ -391,19 +410,20 @@ def sample_kernel(hmc_state,model,ll_ref,jac_all,z,z_ref,hess_all,n,m,model_args """ model_kwargs = {} if model_kwargs is None else model_kwargs + #if subsample_method =="perturb": + # hmc_state = hmc_state.hmc_state rng_key, rng_key_momentum, rng_key_transition = random.split(hmc_state.rng_key, 3) r = momentum_generator(hmc_state.z, hmc_state.adapt_state.mass_matrix_sqrt, rng_key_momentum) vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad) - vv_state, energy, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size, hmc_state.adapt_state.inverse_mass_matrix, vv_state, - model_args, + model_args_sub(u,model_args), model_kwargs, rng_key_transition, subsample_method, - model,ll_ref,jac_all,z,z_ref,hess_all,n,m) + model,ll_ref,jac_all,z,z_ref,hess_all,ll_u,u,n,m) # not update adapt_state after warmup phase adapt_state = cond(hmc_state.i < wa_steps, (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state), @@ -414,9 +434,17 @@ def sample_kernel(hmc_state,model,ll_ref,jac_all,z,z_ref,hess_all,n,m,model_args itr = hmc_state.i + 1 n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps) mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n - - return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps, + hmcstate = HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps, accept_prob, mean_accept_prob, diverging, adapt_state,rng_key) + hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state, z_ref=z_ref, ll_ref=ll_ref, jac_all=jac_all, + hess_all=hess_all, ll_u=ll_u) + # if subsample_method == "perturbed": + # hmc_sub_state = HMCECSState(u=u,hmc_state=hmc_state,z_ref=z_ref,ll_ref=ll_ref,jac_all=jac_all,hess_all=hess_all,ll_u=ll_u) + # return hmc_sub_state + # else: + # return hmcstate + hmcstate = tuplemerge(hmc_sub_state._asdict(),hmcstate._asdict()) + return hmcstate # Make `init_kernel` and `sample_kernel` visible from the global scope once # `hmc` is called for sphinx doc generation. @@ -548,7 +576,7 @@ def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ self._model, init_strategy=partial(init_near_values, values=self.z_ref), dynamic_args=True, - model_args=self.model_args_sub(self._u, model_args), + model_args=model_args_sub(self._u, model_args), model_kwargs=model_kwargs) @@ -561,17 +589,17 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): assert self.z_ref is not None, "Please provide a (i.e map) estimate for the parameters" #Initialize the subsample state self._algo = "HMC" - # Initialize the potential and its gradient + # Initialize the potential and gradient potential functions self._potential_fn = lambda model, args, kwargs, ll_ref, jac_all,z, z_ref, hess_all, n, m: \ lambda z: potential_est(model=self._model, model_args=model_args, model_kwargs=model_kwargs, ll_ref=self._ll_ref, - jac_all=self._jac_all, z=z, z_ref=z_ref, hess_all=hess_all, n=self._n, m=self.m) + jac_all=self._jac_all, z=z, z_ref=z_ref, hess_all=hess_all, n=self._n, m=self.m,u=self._u) self._grad_potential = lambda model, args, kwargs,ll_ref, jac_all,z, z_ref, hess_all, n, m:\ lambda z: grad_potential(model=self._model, model_args=model_args, model_kwargs=model_kwargs, jac_all=self._jac_all,z=z, z_ref=self.z_ref, hess_all=self._hess_all, - n=self._n, m=self.m) + n=self._n, m=self.m,u=self._u) # Initialize the hmc sampler: sample_fn = sample_kernel self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, kinetic_fn=euclidean_kinetic_energy, @@ -621,21 +649,6 @@ def get_diagnostics_str(self, state): return '{} steps of size {:.2e}. acc. prob={:.2f}'.format(state.num_steps, state.adapt_state.step_size, state.mean_accept_prob) - def model_args_sub(self,u,model_args): - """Subsample observations and features according to u subsample indexes""" - args = [] - for arg in model_args: - if isinstance(arg, jnp.ndarray): - args.append(jnp.take(arg, u, axis=0)) - else: - args.append(arg) - return tuple(args) - def model_kwargs_sub(self,u, kwargs): - """Subsample observations and features""" - for key_arg, val_arg in kwargs.items(): - if key_arg == "observations" or key_arg == "features": - kwargs[key_arg] = jnp.take(val_arg, u, axis=0) - return kwargs def _block_indices(self,size, num_blocks): a = jnp.repeat(jnp.arange(num_blocks - 1), size // num_blocks) @@ -667,7 +680,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg trajectory_length=self._trajectory_length, max_tree_depth=self._max_tree_depth, find_heuristic_step_size=self._find_heuristic_step_size, - model_args=self.model_args_sub(self._u,model_args), + model_args=model_args_sub(self._u,model_args), model_kwargs=model_kwargs, subsample_method= self.subsample_method, model=self._model, @@ -675,12 +688,14 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg jac_all=self._jac_all, z_ref=self.z_ref, hess_all = self._hess_all, - n=self._n,m=self.m) + n=self._n,m=self.m, + u = self._u) if rng_key.ndim ==1: - init_state = hmc_init_fn(init_params, rng_key) #HMCState + init_state = hmc_init_fn(init_params, rng_key) #HMCState + HMCECSState + self._ll_u = potential_est(self._model, - self.model_args_sub(self._u, model_args), + model_args_sub(self._u, model_args), model_kwargs, self._ll_ref, self._jac_all, @@ -688,30 +703,31 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg init_state.z, self.z_ref, self._n, - self.m) - hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, - hmc_state=init_state, - z_ref=self.z_ref, - ll_u=self._ll_u, - jac_all=self._jac_all, - hess_all=self._hess_all, - ll_ref=self._ll_ref) - - init_sub_state = hmc_init_sub_fn(init_params,rng_key) #HMCState - - HMCCombinedState = tuplemerge(init_state._asdict(),init_sub_state._asdict()) - - - return HMCCombinedState + self.m, + u = self._u) + # hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, + # hmc_state=init_state.hmc_state, + # z_ref=self.z_ref, + # ll_u=self._ll_u, + # jac_all=self._jac_all, + # hess_all=self._hess_all, + # ll_ref=self._ll_ref) + # + # init_sub_state = hmc_init_sub_fn(init_params,rng_key) #HMCState + # + # init_sub_state = tuplemerge(init_state._asdict(),init_sub_state._asdict()) + # print(init_sub_state._fields) + # exit() + return init_state else: #For more than 2 chains # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, # wa_steps because those variables do not depend on traced args: init_params, rng_key. init_state = vmap(hmc_init_fn)(init_params, rng_key) - self._ll_u = potential_est(self._model, self.model_args_sub(self._u, model_args), model_kwargs,self._ll_ref, + self._ll_u = potential_est(self._model, model_args_sub(self._u, model_args), model_kwargs,self._ll_ref, self._jac_all, self._hess_all, - init_state.z, self.z_ref, self._n, self.m) + init_state.z, self.z_ref, self._n, self.m,self._u) hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state, z_ref=self.z_ref, ll_u=self._ll_u, jac_all=self._jac_all, @@ -771,21 +787,21 @@ def sample(self, state, model_args, model_kwargs): """ if self.subsample_method == "perturb": + rng_key_subsample, rng_key_transition, rng_key_likelihood, rng_key = random.split( - state.hmc_state.rng_key, 4) + state.rng_key, 4) u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) - # estimate likelihood of subsample with single block updated llu_new = potential_est(model=self._model, - model_args=self.model_args_sub(u_new,model_args), + model_args=model_args_sub(u_new,model_args), model_kwargs=model_kwargs, ll_ref = state.ll_ref, jac_all=state.jac_all, hess_all= state.hess_all, - z=state.hmc_state.z, + z=state.z, z_ref=state.z_ref, - n=self._n, m=self.m) + n=self._n, m=self.m,u=self._u) # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. @@ -794,7 +810,7 @@ def sample(self, state, model_args, model_kwargs): u, ll_u = cond(transition, (u_new, llu_new), identity, (state.u, state.ll_u), identity) - + self._u = u ######## UPDATE PARAMETERS ########## hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state, @@ -803,10 +819,9 @@ def sample(self, state, model_args, model_kwargs): jac_all=state.jac_all, hess_all=state.hess_all) hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(),state._asdict()) - print(model_args[0].shape) return self._sample_fn(hmc_subsamplestate, - model_args=self.model_args_sub(u,model_args), + model_args=model_args, model_kwargs=model_kwargs, subsample_method=self.subsample_method, model = self._model, @@ -815,6 +830,8 @@ def sample(self, state, model_args, model_kwargs): z= state.z, z_ref = state.z_ref, hess_all = state.hess_all, + ll_u = ll_u, + u= u, n= self._n, m= self.m) diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index a664c93c0..8284772b7 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -15,6 +15,34 @@ IntegratorState.__new__.__defaults__ = (None,) * len(IntegratorState._fields) +def model_args_sub(u, model_args): + """Subsample observations and features according to u subsample indexes""" + if isinstance(model_args,dict): + args = {} + for key, val in model_args.items(): + if isinstance(val, jnp.ndarray) and val.shape[0] > len(u): + args[key] = jnp.take(val, u, axis=0) + else: + args[key] = val + return args + + else: + args = [] + for arg in model_args: + if isinstance(arg, jnp.ndarray) and arg.shape[0] > len(u): + args.append(jnp.take(arg, u, axis=0)) + else: + args.append(arg) + return tuple(args) + + +def model_kwargs_sub(u, kwargs): + """Subsample observations and features""" + for key_arg, val_arg in kwargs.items(): + if key_arg == "observations" or key_arg == "features": + kwargs[key_arg] = jnp.take(val_arg, u, axis=0) + return kwargs + def log_density_hmcecs(model, model_args, model_kwargs, params,prior=False): """ (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given @@ -41,12 +69,8 @@ def log_density_hmcecs(model, model_args, model_kwargs, params,prior=False): if intermediates: log_prob = site['fn'].log_prob(value, intermediates) else: - #print(site["name"]) - #print("value shape") - #print(value.shape) log_prob = site['fn'].log_prob(value) #TODO: The shape here is duplicated - #print("Log prob shape") - #print(log_prob.shape) + if (scale is not None) and (not is_identically_one(scale)): log_prob = scale * log_prob @@ -69,8 +93,10 @@ def log_density_hmcecs(model, model_args, model_kwargs, params,prior=False): log_joint = log_joint + log_prob return log_joint, model_trace -def grad_potential(model, model_args, model_kwargs,z, z_ref, jac_all, hess_all, n, m, *args, **kwargs): +def grad_potential(model, model_args, model_kwargs,z, z_ref, jac_all, hess_all, n, m,u=None, *args, **kwargs): """Calculate the gradient of the potential energy function for the current subsample""" + if any(arg.shape[0] > m for arg in model_args): + model_args = model_args_sub(u,model_args) k, = jac_all.shape z_flat, treedef = ravel_pytree(z) zref_flat, _ = ravel_pytree(z_ref) @@ -113,9 +139,13 @@ def tuplemerge( *dictionaries ): merged = reduce( reducer, dictionaries, {} ) return namedtuple('HMCCombinedState', merged )(**merged) # <==== Gist of the gist -def potential_est(model, model_args, model_kwargs,ll_ref, jac_all, hess_all, z, z_ref, n, m): +def potential_est(model, model_args, model_kwargs,ll_ref, jac_all, hess_all, z, z_ref, n, m,u=None): """Estimate the potential dynamic energy for the HMC ECS implementation. The calculation follows section 7.2.1 in https://jmlr.org/papers/volume18/15-205/15-205.pdf The computation has a complexity of O(1) and it's highly dependant on the quality of the map estimate""" + + if any(arg.shape[0] > m for arg in model_args): + model_args = model_args_sub(u,model_args) + # Agrees with reference upto constant factor on prior k, = jac_all.shape # number of features z_flat, _ = ravel_pytree(z) @@ -123,8 +153,7 @@ def potential_est(model, model_args, model_kwargs,ll_ref, jac_all, hess_all, z, z_diff = z_flat - zref_flat - #print(model_args[0].shape) - #print("........................................................................................") + ld_fn = lambda args: partial(log_density_hmcecs, model, model_args, model_kwargs,prior=False)(args)[0] jac_sub, _ = ravel_pytree(jax.jacfwd(ld_fn)(z_ref)) @@ -175,7 +204,7 @@ def init_fn(z, r, potential_energy=None, z_grad=None): return IntegratorState(z, r, potential_energy, z_grad) - def update_fn(step_size, inverse_mass_matrix, state): + def update_fn(step_size, inverse_mass_matrix, state,u=None): """ :param float step_size: Size of a single step. :param inverse_mass_matrix: Inverse of mass matrix, which is used to @@ -190,6 +219,7 @@ def update_fn(step_size, inverse_mass_matrix, state): z = tree_multimap(lambda z, r_grad: z + step_size * r_grad, z, r_grad) # z(n+1) potential_energy, z_grad = compute_value_grad(z) r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad) # r(n+1) + #return IntegratorState(z, r, potential_energy, z_grad) return IntegratorState(z, r, potential_energy, z_grad) return init_fn, update_fn diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index c8b05fc6f..b18bcbaa8 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -168,6 +168,7 @@ def _sample_fn_jit_args(state, sampler): def _sample_fn_nojit_args(state, sampler, args, kwargs): # state is a tuple of size 1 - containing HMCState + return sampler.sample(state[0], args, kwargs), @@ -175,7 +176,11 @@ def _collect_fn(collect_fields): @cached_by(_collect_fn, collect_fields) def collect(x): if collect_fields: - return attrgetter(*collect_fields)(x[0]) + # f = getattr(x[0], '_fields', None) + # if any(n == "hmc_state" for n in f): + # return attrgetter(*collect_fields)(x[0].hmc_state) + # else: + return attrgetter(*collect_fields)(x[0]) else: return x[0] From c3de2536ab91ee7d57e1962a81a50c2e673c801c Mon Sep 17 00:00:00 2001 From: Lys Date: Mon, 28 Sep 2020 17:49:54 +0200 Subject: [PATCH 10/93] Sampling working Added dataset --- .gitignore | 2 + examples/logistic_hmcecs.py | 73 +++++++++++++++++++++++++++------ numpyro/contrib/hmcecs.py | 46 ++++++++++++--------- numpyro/contrib/hmcecs_utils.py | 2 +- numpyro/examples/datasets.py | 35 +++++++++++++++- 5 files changed, 122 insertions(+), 36 deletions(-) diff --git a/.gitignore b/.gitignore index 4259586b1..c335c476e 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ numpyro/examples/.data # docs docs/build docs/.DS_Store + +examples/HIGGS.csv.gz diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 91782367a..de0bdeb5c 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -4,8 +4,17 @@ import numpyro import numpyro.distributions as dist from numpyro.infer import NUTS, MCMC, Predictive -from numpyro.contrib.hmcecs import HMC +import sys, os +from jax.config import config +import datetime +sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/contrib/') +sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/examples/') + +from hmcecs import HMC +#from numpyro.contrib.hmcecs import HMC + from sklearn.datasets import load_breast_cancer +from datasets import _load_higgs numpyro.set_platform("cpu") # TODO: import Higgs data! ---> http://archive.ics.uci.edu/ml/machine-learning-databases/00280/ @@ -16,8 +25,8 @@ def model(feats, obs): """ n, m = feats.shape - precision = numpyro.sample('precision', dist.continuous.Uniform(0, 4)) - #precision = 0.5 + #precision = numpyro.sample('precision', dist.continuous.Uniform(1, 4)) + precision = 0.5 theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), precision * jnp.ones(m))) numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(feats, theta)), obs=obs) @@ -31,7 +40,7 @@ def infer_nuts(rng_key, feats, obs, samples=5, warmup=0, ): return mcmc.get_samples() -def infer_hmcecs(rng_key, feats, obs, g=2, samples=10, warmup=5, ): +def infer_hmcecs(rng_key, feats, obs, m=50,g=20,samples=10, warmup=5, ): hmcecs_key, map_key = jax.random.split(rng_key) n, _ = feats.shape @@ -43,7 +52,7 @@ def infer_hmcecs(rng_key, feats, obs, g=2, samples=10, warmup=5, ): #Observations = (569,1) #Features = (569,31) print("Running MCMC subsampling") - kernel = HMC(model=model,z_ref=z_map,m=5,g=2,subsample_method="perturb") + kernel = HMC(model=model,z_ref=z_map,m=m,g=g,subsample_method="perturb") mcmc = MCMC(kernel,num_warmup=warmup,num_samples=samples) mcmc.run(rng_key,feats,obs) return mcmc.get_samples() @@ -56,25 +65,63 @@ def breast_cancer_data(): feats = (feats - feats.mean(0)) / feats.std(0) feats = jnp.hstack((feats, jnp.ones((feats.shape[0], 1)))) - return feats[:10], dataset.target[:10] + return feats[:500], dataset.target[:500] def higgs_data(): - return - + observations,features = _load_higgs() + return observations[:1000],features[:1000] + + +def Plot(samples): + import matplotlib.pyplot as plt + import pandas as pd + import seaborn as sns + import time + + for sample in [0,7,15,25]: + plt.figure(sample) + + #samples = pd.DataFrame.from_records(samples,index="theta") + sns.kdeplot(data=samples["theta"][sample]) + plt.xlabel(r"$\theta") + plt.ylabel("Density") + plt.title(r"$\theta$ {} Density plot".format(sample)) + plt.savefig("{}/KDE_plot_theta_{}.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")),sample)) + + +def Folders(folder_name): + """ Folder for all the generated images It will updated everytime!!! Save the previous folder before running again. Creates folder in current directory""" + import os + import shutil + basepath = os.getcwd() + if not basepath: + newpath = folder_name + else: + newpath = basepath + "/%s" % folder_name + + if not os.path.exists(newpath): + try: + original_umask = os.umask(0) + os.makedirs(newpath, 0o777) + finally: + os.umask(original_umask) + else: + shutil.rmtree(newpath) # removes all the subdirectories! + os.makedirs(newpath,0o777) if __name__ == '__main__': rng_key = jax.random.PRNGKey(37) rng_key, feat_key, obs_key = jax.random.split(rng_key, 3) - n = 100 - m = 10 - feats, obs = breast_cancer_data() - from jax.config import config + #feats, obs = breast_cancer_data() + feats,obs = higgs_data() + now = datetime.datetime.now() + Folders("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss"))) config.update('jax_disable_jit', True) - est_posterior = infer_hmcecs(rng_key, feats=feats, obs=obs) + est_posterior = infer_hmcecs(rng_key, feats=feats, obs=obs, m =50,g=20) exit() predictions = Predictive(model, posterior_samples=est_posterior)(rng_key, feats, None)['obs'] diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 2a10f3fd1..0b2c9523d 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -21,7 +21,10 @@ from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model, log_density from numpyro.util import cond, fori_loop, identity -from numpyro.contrib.hmcecs_utils import grad_potential,potential_est,log_density_hmcecs, \ +import sys +sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/contrib/') + +from hmcecs_utils import grad_potential,potential_est,log_density_hmcecs, \ velocity_verlet_hmcecs, init_near_values,tuplemerge,model_args_sub,model_kwargs_sub HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', 'mean_accept_prob', 'diverging', 'adapt_state','rng_key']) @@ -100,8 +103,8 @@ def _update_block(rng_key, u, n, m, g): chosen_block = random.randint(rng_key, shape=(), minval= 0, maxval=g + 1) idxs_new = random.randint(rng_key_index, shape=(m // g,), minval=0, maxval=n) #chose block within the subsample to update - u_new = jnp.zeros(m, jnp.dtype(u)) #empty array with size m + for i in range(m): #if index in the subsample // g = chosen block : pick new indexes from the subsample size #else not update: keep the same indexes @@ -368,15 +371,18 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, model=None,ll_ref=None,jac_all=None,z = None,z_ref=None,hess_all=None,ll_u=None,u=None,n=None,m=None): if potential_fn_gen: nonlocal vv_update - pe_fn = potential_fn_gen(*model_args, **model_kwargs) if grad_potential_fn_gen: - if subsample_method == "perturbed": + if subsample_method == "perturb": #model, model_args_sub, model_kwargs, ll_ref, jac_all, z, z_ref, hess_all,ll_u,u, n, m = model_args gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs,ll_ref,jac_all, z, z_ref, n, m,u) + pe_fn = potential_fn_gen(model, model_args, model_kwargs, ll_ref,jac_all, vv_state.z, z_ref, hess_all, n,m) + else: kwargs = {} if model_kwargs is None else model_kwargs gpe_fn = grad_potential_fn_gen(*model_args, **kwargs, ) + pe_fn = potential_fn_gen(*model_args, **model_kwargs) + else: gpe_fn = None _, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,gpe_fn) @@ -588,6 +594,7 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): if self.subsample_method is not None: assert self.z_ref is not None, "Please provide a (i.e map) estimate for the parameters" #Initialize the subsample state + self._algo = "HMC" # Initialize the potential and gradient potential functions self._potential_fn = lambda model, args, kwargs, ll_ref, jac_all,z, z_ref, hess_all, n, m: \ @@ -693,7 +700,6 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg if rng_key.ndim ==1: init_state = hmc_init_fn(init_params, rng_key) #HMCState + HMCECSState - self._ll_u = potential_est(self._model, model_args_sub(self._u, model_args), model_kwargs, @@ -705,20 +711,20 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg self._n, self.m, u = self._u) - # hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, - # hmc_state=init_state.hmc_state, - # z_ref=self.z_ref, - # ll_u=self._ll_u, - # jac_all=self._jac_all, - # hess_all=self._hess_all, - # ll_ref=self._ll_ref) - # - # init_sub_state = hmc_init_sub_fn(init_params,rng_key) #HMCState - # - # init_sub_state = tuplemerge(init_state._asdict(),init_sub_state._asdict()) - # print(init_sub_state._fields) - # exit() - return init_state + + hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, + hmc_state=init_state.hmc_state, + z_ref=self.z_ref, + ll_u=self._ll_u, + jac_all=self._jac_all, + hess_all=self._hess_all, + ll_ref=self._ll_ref) + + init_sub_state = hmc_init_sub_fn(init_params,rng_key) #HMCState + init_sub_state = tuplemerge(init_state._asdict(),init_sub_state._asdict()) + + + return init_sub_state else: #For more than 2 chains # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some @@ -791,7 +797,7 @@ def sample(self, state, model_args, model_kwargs): rng_key_subsample, rng_key_transition, rng_key_likelihood, rng_key = random.split( state.rng_key, 4) - u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) + u_new = _update_block(rng_key_subsample, self._u, self._n, self.m, self.g) # estimate likelihood of subsample with single block updated llu_new = potential_est(model=self._model, model_args=model_args_sub(u_new,model_args), diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index 8284772b7..a6fbba53d 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -169,7 +169,7 @@ def potential_est(model, model_args, model_kwargs,ll_ref, jac_all, hess_all, z, sigma = n ** 2 / m * jnp.var(diff) - ll_prior, _ = log_density_hmcecs(model, model_args, model_kwargs, z,prior=True) + ll_prior, _ = log_density_hmcecs(model, model_args, model_kwargs, z,prior=True) #TODO: work with hierachical models return (-l_hat + .5 * sigma) - ll_prior diff --git a/numpyro/examples/datasets.py b/numpyro/examples/datasets.py index b7ec550b0..feecccfd3 100644 --- a/numpyro/examples/datasets.py +++ b/numpyro/examples/datasets.py @@ -9,9 +9,9 @@ import struct from urllib.parse import urlparse from urllib.request import urlretrieve - +import warnings import numpy as np - +import pandas as pd from jax import device_put, lax from jax.interpreters.xla import DeviceArray @@ -63,6 +63,9 @@ 'https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/jsb_chorales.pickle', ]) +HIGGS = dset("higgs",[ + "https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz", +]) def _download(dset): for url in dset.urls: @@ -216,6 +219,32 @@ def _load_jsb_chorales(): processed_dataset[k] = (lengths, _pad_sequence(sequences).astype("int32")) return processed_dataset +def _load_higgs(): + warnings.warn("Downloading 2.6 GB dataset") + _download(HIGGS) + file_path = os.path.join(DATA_DIR, 'HIGGS.csv.gz') + df = pd.read_csv(file_path, header=None) + obs, feats = df.iloc[:, 0], df.iloc[:, 1:] + return obs.to_numpy(), feats.to_numpy() + + #SLOW (no pandas) option + # observations,features = [],[] + # with gzip.open(file_path, mode='rt') as f: + # csv_reader = csv.DictReader( + # f, + # delimiter=',', + # restkey="30", + # fieldnames=['observations'] +['feature_{}'.format(i) for i in range(28)], + # ) + # + # for row in csv_reader: + # observations.append(row["observations"]) + # for i in range(28): + # print(row["feature_{}".format(i)]) + # features.append(row["feature_{}".format(i)]) + # return {"observations": np.stack(observations),"features": np.stack(features)} + + def _load(dset): if dset == BASEBALL: @@ -232,6 +261,8 @@ def _load(dset): return _load_lynxhare() elif dset == JSB_CHORALES: return _load_jsb_chorales() + elif dset == HIGGS: + return _load_higgs() raise ValueError('Dataset - {} not found.'.format(dset.name)) From b17d53d72941a130660c63573d7811ac73c24469 Mon Sep 17 00:00:00 2001 From: Lys Date: Mon, 28 Sep 2020 20:07:54 +0200 Subject: [PATCH 11/93] Seems to be working --- .gitignore | 1 + examples/logistic_hmcecs.py | 8 ++++---- numpyro/contrib/hmcecs.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index c335c476e..2ff5ff5fc 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,4 @@ docs/build docs/.DS_Store examples/HIGGS.csv.gz +examples/PLOTS* diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index de0bdeb5c..501d9e02e 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -40,10 +40,10 @@ def infer_nuts(rng_key, feats, obs, samples=5, warmup=0, ): return mcmc.get_samples() -def infer_hmcecs(rng_key, feats, obs, m=50,g=20,samples=10, warmup=5, ): +def infer_hmcecs(rng_key, feats, obs, m=50,g=20,samples=1000, warmup=500, ): hmcecs_key, map_key = jax.random.split(rng_key) n, _ = feats.shape - + print("Using {} samples".format(str(samples+warmup))) print("Running NUTS for map estimation") @@ -70,7 +70,7 @@ def breast_cancer_data(): def higgs_data(): observations,features = _load_higgs() - return observations[:1000],features[:1000] + return features[:1000],observations[:1000] def Plot(samples): @@ -122,7 +122,7 @@ def Folders(folder_name): Folders("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss"))) config.update('jax_disable_jit', True) est_posterior = infer_hmcecs(rng_key, feats=feats, obs=obs, m =50,g=20) - + Plot(est_posterior) exit() predictions = Predictive(model, posterior_samples=est_posterior)(rng_key, feats, None)['obs'] diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 0b2c9523d..8632e0e09 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -595,7 +595,7 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): assert self.z_ref is not None, "Please provide a (i.e map) estimate for the parameters" #Initialize the subsample state - self._algo = "HMC" + self._algo = "NUTS" # Initialize the potential and gradient potential functions self._potential_fn = lambda model, args, kwargs, ll_ref, jac_all,z, z_ref, hess_all, n, m: \ lambda z: potential_est(model=self._model, model_args=model_args, model_kwargs=model_kwargs, From 40be6c3a54e07391c20efd6d020ba90a74da84a5 Mon Sep 17 00:00:00 2001 From: Lys Date: Tue, 29 Sep 2020 17:18:04 +0200 Subject: [PATCH 12/93] Added: Plotting and save samples to example --- examples/logistic_hmcecs.py | 28 +++++++++++++++++++++------- numpyro/contrib/hmcecs.py | 8 +++++--- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 501d9e02e..b87db0c49 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -6,7 +6,8 @@ from numpyro.infer import NUTS, MCMC, Predictive import sys, os from jax.config import config -import datetime +import datetime,time + sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/contrib/') sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/examples/') @@ -40,7 +41,7 @@ def infer_nuts(rng_key, feats, obs, samples=5, warmup=0, ): return mcmc.get_samples() -def infer_hmcecs(rng_key, feats, obs, m=50,g=20,samples=1000, warmup=500, ): +def infer_hmcecs(rng_key, feats, obs, m=50,g=20,samples=10, warmup=5, ): hmcecs_key, map_key = jax.random.split(rng_key) n, _ = feats.shape print("Using {} samples".format(str(samples+warmup))) @@ -52,9 +53,18 @@ def infer_hmcecs(rng_key, feats, obs, m=50,g=20,samples=1000, warmup=500, ): #Observations = (569,1) #Features = (569,31) print("Running MCMC subsampling") - kernel = HMC(model=model,z_ref=z_map,m=m,g=g,subsample_method="perturb") + start = time.time() + kernel = HMC(model=model,z_ref=z_map,m=m,g=g,subsample_method="perturb",algo="NUTS") + mcmc = MCMC(kernel,num_warmup=warmup,num_samples=samples) mcmc.run(rng_key,feats,obs) + stop = time.time() + file_hyperparams = open("PLOTS_{}/Hyperparameters_{}.txt".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss"),now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")), "a") + file_hyperparams.write('MCMC/NUTS elapsed time: {} \n'.format(time.time() - start)) + file_hyperparams.close() + + save_obj(mcmc.get_samples(),"{}/MCMC_Dict_Samples.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")))) + return mcmc.get_samples() @@ -70,8 +80,12 @@ def breast_cancer_data(): def higgs_data(): observations,features = _load_higgs() - return features[:1000],observations[:1000] - + return features[:100],observations[:100] +def save_obj(obj, name): + import _pickle as cPickle + import bz2 + with bz2.BZ2File(name, "wb") as f: + cPickle.dump(obj, f) def Plot(samples): import matplotlib.pyplot as plt @@ -115,8 +129,8 @@ def Folders(folder_name): rng_key, feat_key, obs_key = jax.random.split(rng_key, 3) - #feats, obs = breast_cancer_data() - feats,obs = higgs_data() + feats, obs = breast_cancer_data() + #feats,obs = higgs_data() now = datetime.datetime.now() Folders("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss"))) diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 8632e0e09..b1532b12a 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -525,7 +525,9 @@ def __init__(self, subsample_method = None, m= None, g = None, - z_ref= None + z_ref= None, + algo = "HMC", + covariate_fn = None, #TODO: substitute with default Taylor expansion ): if not (model is None) ^ (potential_fn is None): raise ValueError('Only one of `model` or `potential_fn` must be specified.') @@ -540,7 +542,7 @@ def __init__(self, self._dense_mass = dense_mass self._target_accept_prob = target_accept_prob self._trajectory_length = trajectory_length - self._algo = 'HMC' + self._algo = algo self._max_tree_depth = 10 self._init_strategy = init_strategy self._find_heuristic_step_size = find_heuristic_step_size @@ -560,6 +562,7 @@ def __init__(self, self._postprocess_fn = None self._sample_fn = None self._subsample_fn = None + self.covariates_fn = None def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ref): "Compute the jacobian, hessian and gradient for all the data" @@ -595,7 +598,6 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): assert self.z_ref is not None, "Please provide a (i.e map) estimate for the parameters" #Initialize the subsample state - self._algo = "NUTS" # Initialize the potential and gradient potential functions self._potential_fn = lambda model, args, kwargs, ll_ref, jac_all,z, z_ref, hess_all, n, m: \ lambda z: potential_est(model=self._model, model_args=model_args, model_kwargs=model_kwargs, From f58dbf73c6367e42ae4fcd95fd5bfdac940e763b Mon Sep 17 00:00:00 2001 From: Lys Date: Tue, 29 Sep 2020 17:39:28 +0200 Subject: [PATCH 13/93] ADDED: Assertion errors working with hierarchical priors --- examples/logistic_hmcecs.py | 4 ++-- numpyro/contrib/hmcecs.py | 33 +++++++++++++-------------------- numpyro/contrib/hmcecs_utils.py | 26 ++++++++------------------ 3 files changed, 23 insertions(+), 40 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index b87db0c49..c0c12877e 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -26,8 +26,8 @@ def model(feats, obs): """ n, m = feats.shape - #precision = numpyro.sample('precision', dist.continuous.Uniform(1, 4)) - precision = 0.5 + precision = numpyro.sample('precision', dist.continuous.Uniform(1, 4)) + #precision = 0.5 theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), precision * jnp.ones(m))) numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(feats, theta)), obs=obs) diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index b1532b12a..f58b88a09 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -317,11 +317,6 @@ def init_kernel(init_params, 0, 0., 0., False, wa_state,rng_key_hmc) hmc_sub_state = HMCECSState(u=3, hmc_state=hmc_state, z_ref=z_ref, ll_ref=ll_ref, jac_all=jac_all, hess_all=hess_all, ll_u=ll_u) - # if subsample_method == "perturb": - # hmc_sub_state = HMCECSState(u=3,hmc_state=hmc_state,z_ref=z_ref,ll_ref=ll_ref,jac_all=jac_all,hess_all=hess_all,ll_u=ll_u) - # return device_put(hmc_sub_state) - # else: - # return device_put(hmc_state) hmc_state = tuplemerge(hmc_sub_state._asdict(),hmc_state._asdict()) return device_put(hmc_state) @@ -344,16 +339,14 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, pe_fn = potential_fn_gen(*model_args, **model_kwargs) nonlocal vv_update #pe_fn = potential_fn_gen(*model_args, **model_kwargs) - _, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,gpe_fn) #TODO:vv_update might be updating wrong + _, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,gpe_fn) num_steps = _get_num_steps(step_size, trajectory_len) vv_state_new = fori_loop(0, num_steps, - lambda i, val: vv_update(step_size, inverse_mass_matrix, val,u), #TODO added u + lambda i, val: vv_update(step_size, inverse_mass_matrix, val), vv_state) - # vv_state_new = fori_loop(0, num_steps, - # lambda i, val: vv_update(step_size, inverse_mass_matrix, val), - # vv_state) + energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r) energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r) delta_energy = energy_new - energy_old @@ -373,7 +366,6 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, nonlocal vv_update if grad_potential_fn_gen: if subsample_method == "perturb": - #model, model_args_sub, model_kwargs, ll_ref, jac_all, z, z_ref, hess_all,ll_u,u, n, m = model_args gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs,ll_ref,jac_all, z, z_ref, n, m,u) pe_fn = potential_fn_gen(model, model_args, model_kwargs, ll_ref,jac_all, vv_state.z, z_ref, hess_all, n,m) @@ -444,11 +436,6 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None,subsample_method=Non accept_prob, mean_accept_prob, diverging, adapt_state,rng_key) hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state, z_ref=z_ref, ll_ref=ll_ref, jac_all=jac_all, hess_all=hess_all, ll_u=ll_u) - # if subsample_method == "perturbed": - # hmc_sub_state = HMCECSState(u=u,hmc_state=hmc_state,z_ref=z_ref,ll_ref=ll_ref,jac_all=jac_all,hess_all=hess_all,ll_u=ll_u) - # return hmc_sub_state - # else: - # return hmcstate hmcstate = tuplemerge(hmc_sub_state._asdict(),hmcstate._asdict()) return hmcstate @@ -616,7 +603,13 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): algo=self._algo) init_params, potential_fn, postprocess_fn, model_trace=self._init_subsample_state(rng_key, model_args, model_kwargs, init_params,self.z_ref) - + if (self.g > self.m) or (self.g < 1): + raise ValueError( + 'Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(self.g, + self.m)) + elif (self.m > self._n): + raise ValueError( + 'Subsample size (m) = {} needs to = or < than data size (n) {}'.format(self.m, self._n)) if self._model is not None: init_params, potential_fn, postprocess_fn, model_trace = initialize_model( rng_key, @@ -727,7 +720,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg return init_sub_state - else: + else: #TODO: Check that it works for more than 2 chains #For more than 2 chains # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, @@ -744,10 +737,10 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg init_subsample_state = vmap(hmc_init_sub_fn)(init_params,rng_key) sample_fn = vmap(self._sample_fn, in_axes=(0, None, None)) - subsample = vmap(self._subsample_fn, in_axes=(0,None,None)) + subsample_fn = vmap(self._subsample_fn, in_axes=(0,None,None)) HMCCombinedState = tuplemerge(init_state._asdict,init_subsample_state._asdict()) self._sample_fn = sample_fn - self._subsample_fn = subsample + self._subsample_fn = subsample_fn return HMCCombinedState else: diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index a6fbba53d..20b07bf8f 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -17,23 +17,13 @@ def model_args_sub(u, model_args): """Subsample observations and features according to u subsample indexes""" - if isinstance(model_args,dict): - args = {} - for key, val in model_args.items(): - if isinstance(val, jnp.ndarray) and val.shape[0] > len(u): - args[key] = jnp.take(val, u, axis=0) - else: - args[key] = val - return args - - else: - args = [] - for arg in model_args: - if isinstance(arg, jnp.ndarray) and arg.shape[0] > len(u): - args.append(jnp.take(arg, u, axis=0)) - else: - args.append(arg) - return tuple(args) + args = [] + for arg in model_args: + if isinstance(arg, jnp.ndarray) and arg.shape[0] > len(u): + args.append(jnp.take(arg, u, axis=0)) + else: + args.append(arg) + return tuple(args) def model_kwargs_sub(u, kwargs): @@ -204,7 +194,7 @@ def init_fn(z, r, potential_energy=None, z_grad=None): return IntegratorState(z, r, potential_energy, z_grad) - def update_fn(step_size, inverse_mass_matrix, state,u=None): + def update_fn(step_size, inverse_mass_matrix, state): """ :param float step_size: Size of a single step. :param inverse_mass_matrix: Inverse of mass matrix, which is used to From 3a2552385fe81b8756bbf03004534889c7df1ca0 Mon Sep 17 00:00:00 2001 From: Lys Date: Wed, 30 Sep 2020 11:20:03 +0200 Subject: [PATCH 14/93] working on more than 1 chain --- examples/logistic_hmcecs.py | 4 ++-- numpyro/contrib/hmcecs.py | 11 ++++++++--- numpyro/infer/mcmc.py | 1 + 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index c0c12877e..207e52aa4 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -13,7 +13,7 @@ from hmcecs import HMC #from numpyro.contrib.hmcecs import HMC - +#numpyro.set_host_device_count(2) from sklearn.datasets import load_breast_cancer from datasets import _load_higgs numpyro.set_platform("cpu") @@ -56,7 +56,7 @@ def infer_hmcecs(rng_key, feats, obs, m=50,g=20,samples=10, warmup=5, ): start = time.time() kernel = HMC(model=model,z_ref=z_map,m=m,g=g,subsample_method="perturb",algo="NUTS") - mcmc = MCMC(kernel,num_warmup=warmup,num_samples=samples) + mcmc = MCMC(kernel,num_warmup=warmup,num_samples=samples,num_chains=1) mcmc.run(rng_key,feats,obs) stop = time.time() file_hyperparams = open("PLOTS_{}/Hyperparameters_{}.txt".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss"),now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")), "a") diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index f58b88a09..9795a2721 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -665,6 +665,8 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg # vectorized else: rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1) + + init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params) #should work for all cases if self._potential_fn and init_params is None: @@ -672,6 +674,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg ' `potential_fn`.') if self.subsample_method == "perturb": + hmc_init_fn = lambda init_params,rng_key: self._init_fn(init_params=init_params, num_warmup = num_warmup, step_size = self._step_size, @@ -693,7 +696,10 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg n=self._n,m=self.m, u = self._u) + print(rng_key.shape) + if rng_key.ndim ==1: + init_state = hmc_init_fn(init_params, rng_key) #HMCState + HMCECSState self._ll_u = potential_est(self._model, model_args_sub(self._u, model_args), @@ -718,14 +724,13 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg init_sub_state = hmc_init_sub_fn(init_params,rng_key) #HMCState init_sub_state = tuplemerge(init_state._asdict(),init_sub_state._asdict()) - return init_sub_state - else: #TODO: Check that it works for more than 2 chains - #For more than 2 chains + else: #TODO: What is this for? It does not go into it for num_chains>1 # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, # wa_steps because those variables do not depend on traced args: init_params, rng_key. init_state = vmap(hmc_init_fn)(init_params, rng_key) + self._ll_u = potential_est(self._model, model_args_sub(self._u, model_args), model_kwargs,self._ll_ref, self._jac_all, self._hess_all, init_state.z, self.z_ref, self._n, self.m,self._u) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index b18bcbaa8..7e0d2751f 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -420,6 +420,7 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): if self.num_chains > 1 and rng_key.ndim == 1: rng_key = random.split(rng_key, self.num_chains) + if self._warmup_state is not None: self._set_collection_params(0, self.num_samples, self.num_samples, "sample") init_state = self._warmup_state._replace(rng_key=rng_key) From a0758573e4e1dd191bcf142aa9a2a2554482b0b8 Mon Sep 17 00:00:00 2001 From: Lys Date: Wed, 30 Sep 2020 17:06:54 +0200 Subject: [PATCH 15/93] ADDED: more plotting --- examples/logistic_hmcecs.py | 51 +++++--- numpyro/contrib/hmcecs.py | 222 ++++++++++++++++++-------------- numpyro/contrib/hmcecs_utils.py | 93 ++++++------- 3 files changed, 191 insertions(+), 175 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 207e52aa4..46710851c 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -16,6 +16,7 @@ #numpyro.set_host_device_count(2) from sklearn.datasets import load_breast_cancer from datasets import _load_higgs +import jax.numpy as np_jax numpyro.set_platform("cpu") # TODO: import Higgs data! ---> http://archive.ics.uci.edu/ml/machine-learning-databases/00280/ @@ -33,7 +34,7 @@ def model(feats, obs): numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(feats, theta)), obs=obs) -def infer_nuts(rng_key, feats, obs, samples=5, warmup=0, ): +def infer_nuts(rng_key, feats, obs, samples=10, warmup=5, ): kernel = NUTS(model=model) mcmc = MCMC(kernel, num_warmup=warmup, num_samples=samples) mcmc.run(rng_key, feats, obs) @@ -41,29 +42,34 @@ def infer_nuts(rng_key, feats, obs, samples=5, warmup=0, ): return mcmc.get_samples() -def infer_hmcecs(rng_key, feats, obs, m=50,g=20,samples=10, warmup=5, ): +def infer_hmcecs(rng_key, feats, obs, m=None,g=None,samples=10, warmup=5,algo="NUTS",subsample_method=None ): hmcecs_key, map_key = jax.random.split(rng_key) n, _ = feats.shape print("Using {} samples".format(str(samples+warmup))) print("Running NUTS for map estimation") - z_map = {key: value.mean(0) for key, value in infer_nuts(map_key, feats, obs).items()} - - #Observations = (569,1) - #Features = (569,31) + if subsample_method=="perturb": + z_map = {key: value.mean(0) for key, value in infer_nuts(map_key, feats, obs).items()} + else: + z_map = None print("Running MCMC subsampling") start = time.time() - kernel = HMC(model=model,z_ref=z_map,m=m,g=g,subsample_method="perturb",algo="NUTS") + kernel = HMC(model=model,z_ref=z_map,m=m,g=g,algo=algo,subsample_method=subsample_method) mcmc = MCMC(kernel,num_warmup=warmup,num_samples=samples,num_chains=1) mcmc.run(rng_key,feats,obs) stop = time.time() file_hyperparams = open("PLOTS_{}/Hyperparameters_{}.txt".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss"),now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")), "a") - file_hyperparams.write('MCMC/NUTS elapsed time: {} \n'.format(time.time() - start)) + file_hyperparams.write('MCMC/NUTS elapsed time {}: {} \n'.format(subsample_method,time.time() - start)) + file_hyperparams.write('Effective size {}: {}\n'.format(subsample_method,samples)) + file_hyperparams.write('Warm up size {}: {}\n'.format(subsample_method,warmup)) + file_hyperparams.write('Subsample size (m): {}\n'.format(m)) + file_hyperparams.write('Block size (g): {}\n'.format(g)) + file_hyperparams.close() - save_obj(mcmc.get_samples(),"{}/MCMC_Dict_Samples.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")))) + save_obj(mcmc.get_samples(),"{}/MCMC_Dict_Samples_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")),subsample_method)) return mcmc.get_samples() @@ -75,19 +81,19 @@ def breast_cancer_data(): feats = (feats - feats.mean(0)) / feats.std(0) feats = jnp.hstack((feats, jnp.ones((feats.shape[0], 1)))) - return feats[:500], dataset.target[:500] + return feats[:50], dataset.target[:50] def higgs_data(): observations,features = _load_higgs() - return features[:100],observations[:100] + return features[:10],observations[:10] def save_obj(obj, name): import _pickle as cPickle import bz2 with bz2.BZ2File(name, "wb") as f: cPickle.dump(obj, f) -def Plot(samples): +def Plot(samples_ECS,samples_NUTS): import matplotlib.pyplot as plt import pandas as pd import seaborn as sns @@ -97,9 +103,12 @@ def Plot(samples): plt.figure(sample) #samples = pd.DataFrame.from_records(samples,index="theta") - sns.kdeplot(data=samples["theta"][sample]) + sns.kdeplot(data=samples_ECS["theta"][sample],color="r",label="ECS") + sns.kdeplot(data=samples_NUTS["theta"][sample],color="b",label="NUTS") + plt.xlabel(r"$\theta") plt.ylabel("Density") + plt.legend() plt.title(r"$\theta$ {} Density plot".format(sample)) plt.savefig("{}/KDE_plot_theta_{}.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")),sample)) @@ -129,16 +138,18 @@ def Folders(folder_name): rng_key, feat_key, obs_key = jax.random.split(rng_key, 3) - feats, obs = breast_cancer_data() - #feats,obs = higgs_data() + #feats, obs = breast_cancer_data() + feats,obs = higgs_data() now = datetime.datetime.now() Folders("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss"))) config.update('jax_disable_jit', True) - est_posterior = infer_hmcecs(rng_key, feats=feats, obs=obs, m =50,g=20) - Plot(est_posterior) + m = int(np_jax.sqrt(obs.shape[0])*2) + g= int(m//3) + est_posterior_ECS = infer_hmcecs(rng_key, feats=feats, obs=obs, m =m,g=g,algo="NUTS",subsample_method="perturb") + est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats, obs=obs, m =m,g=g,algo="NUTS") + + Plot(est_posterior_ECS,est_posterior_NUTS) exit() - predictions = Predictive(model, posterior_samples=est_posterior)(rng_key, feats, None)['obs'] + predictions = Predictive(model, posterior_samples=est_posterior_ECS)(rng_key, feats, None)['obs'] - # for i, y in enumerate(obs): - # print(i, y[0], jnp.sum(predictions[i]) > 50) \ No newline at end of file diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 9795a2721..a091eff51 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -24,11 +24,15 @@ import sys sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/contrib/') -from hmcecs_utils import grad_potential,potential_est,log_density_hmcecs, \ - velocity_verlet_hmcecs, init_near_values,tuplemerge,model_args_sub,model_kwargs_sub +from hmcecs_utils import potential_est,log_density_hmcecs, \ + velocity_verlet_hmcecs, init_near_values,tuplemerge,\ + model_args_sub,model_kwargs_sub,taylor_proxy,svi_proxy,neural_proxy HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', 'mean_accept_prob', 'diverging', 'adapt_state','rng_key']) -HMCECSState = namedtuple("HMCECState",["u","hmc_state","z_ref","ll_ref","jac_all","hess_all","ll_u"]) +#HMCECSState = namedtuple("HMCECState",["u","hmc_state","z_ref","ll_ref","jac_all","hess_all","ll_u"]) + +HMCECSState = namedtuple("HMCECState",['u', 'hmc_state', 'z_ref', 'll_u']) + """ A :func:`~collections.namedtuple` consisting of the following fields: @@ -113,7 +117,7 @@ def _update_block(rng_key, u, n, m, g): return u_new -def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, grad_potential_fn_gen=None,algo='NUTS'): +def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, grad_potential_fn_gen=None,covariate_fn=None,algo='NUTS'): r""" Hamiltonian Monte Carlo inference, using either fixed number of steps or the No U-Turn Sampler (NUTS) with adaptive path length. @@ -217,7 +221,8 @@ def init_kernel(init_params, m = None, u= None, rng_key=random.PRNGKey(0), - subsample_method=None): + subsample_method=None, + covariate_fn = None): """ Initializes the HMC sampler. @@ -269,20 +274,18 @@ def init_kernel(init_params, raise ValueError('Only one of `potential_fn` or `potential_fn_gen` must be provided.') else: if subsample_method == "perturb": - #model, model_args, model_kwargs,z, z_ref, jac_all, hess_all, n, m,u=None - gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs,z, z_ref,jac_all, hess_all, n, m,u) - pe_fn = potential_fn_gen(model, model_args, model_kwargs,ll_ref, jac_all,z,z_ref ,hess_all, n, m) + kwargs = {} if model_kwargs is None else model_kwargs + + proxy,proxy_u = covariate_fn(ll_ref, jac_all, hess_all) + pe_fn = potential_fn_gen(model, model_args,model_kwargs, z, z_ref, n, m, proxy, proxy_u,u) + else: kwargs = {} if model_kwargs is None else model_kwargs pe_fn = potential_fn_gen(*model_args, **kwargs) - gpe_fn = grad_potential_fn_gen(*model_args, **kwargs,) + #gpe_fn = grad_potential_fn_gen(*model_args, **kwargs,) if grad_potential_fn_gen: - if subsample_method == "perturb": - #model, model_args, model_kwargs,z, z_ref, jac_all, hess_all, n, m,u=None - gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs, z, z_ref,jac_all, hess_all,n, m,u) - else: - kwargs = {} if model_kwargs is None else model_kwargs - gpe_fn = grad_potential_fn_gen(*model_args, **kwargs) + kwargs = {} if model_kwargs is None else model_kwargs + gpe_fn = grad_potential_fn_gen(*model_args, **kwargs) else: gpe_fn = None @@ -308,15 +311,19 @@ def init_kernel(init_params, mass_matrix_size=jnp.size(ravel_pytree(z)[0])) r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum) - vv_init, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,grad_potential_fn=gpe_fn) + #vv_init, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,grad_potential_fn=gpe_fn) + vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn) + vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad) energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, 0, 0., 0., False, wa_state,rng_key_hmc) - hmc_sub_state = HMCECSState(u=3, hmc_state=hmc_state, z_ref=z_ref, ll_ref=ll_ref, jac_all=jac_all, - hess_all=hess_all, ll_u=ll_u) + hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state, + z_ref=z_ref, + ll_u=ll_u) + hmc_state = tuplemerge(hmc_sub_state._asdict(),hmc_state._asdict()) return device_put(hmc_state) @@ -326,27 +333,27 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, model,ll_ref,jac_all,z,z_ref,hess_all,ll_u,u,n,m): if potential_fn_gen: if grad_potential_fn_gen: + kwargs = {} if model_kwargs is None else model_kwargs + gpe_fn = grad_potential_fn_gen(*model_args, **kwargs, ) + pe_fn = potential_fn_gen(*model_args, **model_kwargs) + + else: if subsample_method == "perturb": - #model, model_args, model_kwargs,z, z_ref, jac_all, hess_all, n, m,u=None - gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs,vv_state.z, z_ref,jac_all, hess_all, n,m,u) - pe_fn = potential_fn_gen(model, model_args, model_kwargs, ll_ref,jac_all, vv_state.z, z_ref, hess_all, n,m) - else: + proxy, proxy_u = covariate_fn(ll_ref, jac_all, hess_all) + pe_fn = potential_fn_gen(model, model_args, model_kwargs,vv_state.z, z_ref, n, m, proxy, proxy_u, u) kwargs = {} if model_kwargs is None else model_kwargs - gpe_fn = grad_potential_fn_gen(*model_args, **kwargs,) + gpe_fn = None + else: + gpe_fn = None pe_fn = potential_fn_gen(*model_args, **model_kwargs) - else: - gpe_fn = None - pe_fn = potential_fn_gen(*model_args, **model_kwargs) nonlocal vv_update - #pe_fn = potential_fn_gen(*model_args, **model_kwargs) - _, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,gpe_fn) + _, vv_update = velocity_verlet(pe_fn, kinetic_fn) num_steps = _get_num_steps(step_size, trajectory_len) vv_state_new = fori_loop(0, num_steps, lambda i, val: vv_update(step_size, inverse_mass_matrix, val), vv_state) - energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r) energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r) delta_energy = energy_new - energy_old @@ -365,19 +372,21 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, if potential_fn_gen: nonlocal vv_update if grad_potential_fn_gen: - if subsample_method == "perturb": - gpe_fn = grad_potential_fn_gen(model, model_args, model_kwargs,ll_ref,jac_all, z, z_ref, n, - m,u) - pe_fn = potential_fn_gen(model, model_args, model_kwargs, ll_ref,jac_all, vv_state.z, z_ref, hess_all, n,m) - else: kwargs = {} if model_kwargs is None else model_kwargs gpe_fn = grad_potential_fn_gen(*model_args, **kwargs, ) pe_fn = potential_fn_gen(*model_args, **model_kwargs) else: - gpe_fn = None - _, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,gpe_fn) + if subsample_method == "perturb": + proxy, proxy_u = covariate_fn(ll_ref, jac_all, hess_all) + pe_fn = potential_fn_gen(model, model_args, model_kwargs, vv_state.z, z_ref, n, m, proxy, + proxy_u, u) + gpe_fn = None + else: + gpe_fn = None + pe_fn = potential_fn_gen(*model_args, **model_kwargs) + _, vv_update = velocity_verlet(pe_fn, kinetic_fn) binary_tree = build_tree(vv_update, kinetic_fn, vv_state, inverse_mass_matrix, step_size, rng_key, @@ -393,8 +402,11 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, _next = _nuts_next if algo == 'NUTS' else _hmc_next - def sample_kernel(hmc_state,model_args=(),model_kwargs=None,subsample_method=None, - model=None,ll_ref=None,jac_all=None,z=None,z_ref=None,hess_all=None,ll_u=None,u=None,n=None,m=None,): + def sample_kernel(hmc_state,model_args=(),model_kwargs=None, + subsample_method=None,covariate_fn=None, + model=None,ll_ref=None,jac_all=None, + z=None,z_ref=None,hess_all=None,ll_u=None, + u=None,n=None,m=None,): """ Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed (possibly adapted) step size and return a new :data:`~numpyro.infer.mcmc.HMCState`. @@ -408,16 +420,16 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None,subsample_method=Non """ model_kwargs = {} if model_kwargs is None else model_kwargs - #if subsample_method =="perturb": - # hmc_state = hmc_state.hmc_state + rng_key, rng_key_momentum, rng_key_transition = random.split(hmc_state.rng_key, 3) r = momentum_generator(hmc_state.z, hmc_state.adapt_state.mass_matrix_sqrt, rng_key_momentum) vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad) - + if subsample_method =="perturb": + model_args = model_args_sub(u,model_args) vv_state, energy, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size, hmc_state.adapt_state.inverse_mass_matrix, vv_state, - model_args_sub(u,model_args), + model_args, model_kwargs, rng_key_transition, subsample_method, @@ -434,8 +446,9 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None,subsample_method=Non mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n hmcstate = HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps, accept_prob, mean_accept_prob, diverging, adapt_state,rng_key) - hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state, z_ref=z_ref, ll_ref=ll_ref, jac_all=jac_all, - hess_all=hess_all, ll_u=ll_u) + hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state, + z_ref=z_ref, + ll_u=ll_u) hmcstate = tuplemerge(hmc_sub_state._asdict(),hmcstate._asdict()) return hmcstate @@ -495,6 +508,7 @@ class HMC(MCMCKernel): :param m subsample size :param g block size :param z_ref MAP estimate of the parameters + :param covariate_fn Proxy function to calculate the covariates for the likelihood correction """ def __init__(self, model=None, @@ -514,7 +528,7 @@ def __init__(self, g = None, z_ref= None, algo = "HMC", - covariate_fn = None, #TODO: substitute with default Taylor expansion + covariate_fn = None, ): if not (model is None) ^ (potential_fn is None): raise ValueError('Only one of `model` or `potential_fn` must be specified.') @@ -549,7 +563,8 @@ def __init__(self, self._postprocess_fn = None self._sample_fn = None self._subsample_fn = None - self.covariates_fn = None + self.proxy = "taylor", + self.covariate_fn = covariate_fn def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ref): "Compute the jacobian, hessian and gradient for all the data" @@ -583,24 +598,27 @@ def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ def _init_state(self, rng_key, model_args, model_kwargs, init_params): if self.subsample_method is not None: assert self.z_ref is not None, "Please provide a (i.e map) estimate for the parameters" - #Initialize the subsample state + # Choose the covariate calculation method + if self.proxy == "svi": + self.covariate_fn = lambda ll_ref, jac_all, hess_all:svi_proxy(ll_ref, jac_all, hess_all) + elif self.proxy == "neural": + self.covariate_fn = lambda ll_ref, jac_all, hess_all:neural_proxy(ll_ref, jac_all, hess_all) + else: + warnings.warn("Using default second order Taylor expansion, change by using the proxy flag to {svi,neural}") + self.covariate_fn = lambda ll_ref, jac_all, hess_all:taylor_proxy(ll_ref, jac_all, hess_all) + # Initialize the potential and gradient potential functions - self._potential_fn = lambda model, args, kwargs, ll_ref, jac_all,z, z_ref, hess_all, n, m: \ - lambda z: potential_est(model=self._model, model_args=model_args, model_kwargs=model_kwargs, - ll_ref=self._ll_ref, - jac_all=self._jac_all, z=z, z_ref=z_ref, hess_all=hess_all, n=self._n, m=self.m,u=self._u) - self._grad_potential = lambda model, args, kwargs,ll_ref, jac_all,z, z_ref, hess_all, n, m:\ - lambda z: grad_potential(model=self._model, model_args=model_args, - model_kwargs=model_kwargs, - jac_all=self._jac_all,z=z, - z_ref=self.z_ref, hess_all=self._hess_all, - n=self._n, m=self.m,u=self._u) + + self._potential_fn = lambda model, model_args, model_kwargs,z, z_ref, n, m, proxy, proxy_u,u : lambda z:potential_est(model=model, + model_args=model_args,model_kwargs=model_kwargs,z=z,z_ref=z_ref,n=n,m = m,proxy=proxy,proxy_u=proxy_u,u=u) + + # Initialize the hmc sampler: sample_fn = sample_kernel self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, kinetic_fn=euclidean_kinetic_energy, - grad_potential_fn_gen=self._grad_potential, - algo=self._algo) + algo=self._algo, + covariate_fn=self.covariate_fn) init_params, potential_fn, postprocess_fn, model_trace=self._init_subsample_state(rng_key, model_args, model_kwargs, init_params,self.z_ref) if (self.g > self.m) or (self.g < 1): @@ -672,7 +690,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg if self._potential_fn and init_params is None: raise ValueError('Valid value of `init_params` must be provided with' ' `potential_fn`.') - + #TODO: assert the subsample method name is correct if self.subsample_method == "perturb": hmc_init_fn = lambda init_params,rng_key: self._init_fn(init_params=init_params, @@ -693,33 +711,33 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg jac_all=self._jac_all, z_ref=self.z_ref, hess_all = self._hess_all, - n=self._n,m=self.m, - u = self._u) - - print(rng_key.shape) + n=self._n, + m=self.m, + u = self._u, + covariate_fn = self.covariate_fn) if rng_key.ndim ==1: init_state = hmc_init_fn(init_params, rng_key) #HMCState + HMCECSState - self._ll_u = potential_est(self._model, - model_args_sub(self._u, model_args), - model_kwargs, - self._ll_ref, - self._jac_all, - self._hess_all, - init_state.z, - self.z_ref, - self._n, - self.m, - u = self._u) + + self._proxy, self._proxy_u = self.covariate_fn(self._ll_ref, self._jac_all, self._hess_all) + self._ll_u = potential_est(model=self._model, + model_args = model_args_sub(self._u, model_args), + model_kwargs=model_kwargs, + z=init_state.z, + z_ref=self.z_ref, + n=self._n, + m=self.m, + proxy=self._proxy, + proxy_u=self._proxy_u, + u=self._u) + + hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state.hmc_state, z_ref=self.z_ref, - ll_u=self._ll_u, - jac_all=self._jac_all, - hess_all=self._hess_all, - ll_ref=self._ll_ref) + ll_u=self._ll_u) init_sub_state = hmc_init_sub_fn(init_params,rng_key) #HMCState init_sub_state = tuplemerge(init_state._asdict(),init_sub_state._asdict()) @@ -731,13 +749,19 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg # wa_steps because those variables do not depend on traced args: init_params, rng_key. init_state = vmap(hmc_init_fn)(init_params, rng_key) - self._ll_u = potential_est(self._model, model_args_sub(self._u, model_args), model_kwargs,self._ll_ref, - self._jac_all, self._hess_all, - init_state.z, self.z_ref, self._n, self.m,self._u) + self._proxy, self._proxy_u = self.covariate_fn(self._ll_ref, self._jac_all, self._hess_all) + self._ll_u = potential_est(model=self._model, + model_args=model_args_sub(self._u, model_args), + model_kwargs=model_kwargs, + z=init_state.z, + z_ref=self.z_ref, + n=self._n, + m=self.m, + proxy=self._proxy, + proxy_u=self._proxy_u, + u=self._u) - hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state, z_ref=self.z_ref, ll_u=self._ll_u, - jac_all=self._jac_all, - hess_all=self._hess_all, ll_ref=self._ll_ref) + hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state, z_ref=self.z_ref, ll_u=self._ll_u) init_subsample_state = vmap(hmc_init_sub_fn)(init_params,rng_key) @@ -797,17 +821,18 @@ def sample(self, state, model_args, model_kwargs): rng_key_subsample, rng_key_transition, rng_key_likelihood, rng_key = random.split( state.rng_key, 4) - u_new = _update_block(rng_key_subsample, self._u, self._n, self.m, self.g) + u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) # estimate likelihood of subsample with single block updated + self._proxy, self._proxy_u = self.covariate_fn(self._ll_ref, self._jac_all, self._hess_all) + llu_new = potential_est(model=self._model, model_args=model_args_sub(u_new,model_args), model_kwargs=model_kwargs, - ll_ref = state.ll_ref, - jac_all=state.jac_all, - hess_all= state.hess_all, z=state.z, - z_ref=state.z_ref, - n=self._n, m=self.m,u=self._u) + z_ref=self.z_ref, + proxy = self._proxy, + proxy_u = self._proxy_u, + n=self._n, m=self.m,u=state.u) # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. @@ -816,14 +841,11 @@ def sample(self, state, model_args, model_kwargs): u, ll_u = cond(transition, (u_new, llu_new), identity, (state.u, state.ll_u), identity) - self._u = u + self._u = u #Just in case , but not necessary ######## UPDATE PARAMETERS ########## hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state, - z_ref=state.z_ref, - ll_u=ll_u, ll_ref=state.ll_ref, - jac_all=state.jac_all, - hess_all=state.hess_all) + ll_u=ll_u,z_ref=self.z_ref) hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(),state._asdict()) return self._sample_fn(hmc_subsamplestate, @@ -831,11 +853,11 @@ def sample(self, state, model_args, model_kwargs): model_kwargs=model_kwargs, subsample_method=self.subsample_method, model = self._model, - ll_ref = state.ll_ref, - jac_all =state.jac_all, + ll_ref = self._ll_ref, + jac_all =self._jac_all, z= state.z, - z_ref = state.z_ref, - hess_all = state.hess_all, + z_ref = self.z_ref, + hess_all = self._hess_all, ll_u = ll_u, u= u, n= self._n, diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index 20b07bf8f..c2229e36e 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -83,42 +83,6 @@ def log_density_hmcecs(model, model_args, model_kwargs, params,prior=False): log_joint = log_joint + log_prob return log_joint, model_trace -def grad_potential(model, model_args, model_kwargs,z, z_ref, jac_all, hess_all, n, m,u=None, *args, **kwargs): - """Calculate the gradient of the potential energy function for the current subsample""" - if any(arg.shape[0] > m for arg in model_args): - model_args = model_args_sub(u,model_args) - k, = jac_all.shape - z_flat, treedef = ravel_pytree(z) - zref_flat, _ = ravel_pytree(z_ref) - z_diff = z_flat - zref_flat - - ld_fn = lambda args: partial(log_density_hmcecs, model, model_args, model_kwargs,prior=False)(args)[0] - - jac_ref, _ = ravel_pytree(jax.jacfwd(ld_fn)(z_ref)) - hess_ref, _ = ravel_pytree(jax.hessian(ld_fn)(z_ref)) - - jac_ref = jac_ref.reshape(m, k) - hess_ref = hess_ref.reshape(m, k, k) - - grad_sum = jac_all + hess_all.dot(z_diff) - jac_sub, _ = ravel_pytree(jax.jacfwd(ld_fn)(z)) - - ll_sub, _ = log_density_hmcecs(model, model_args, model_kwargs, z,prior=False) # log likelihood for subsample with current theta - ll_ref, _ = log_density_hmcecs(model, model_args, model_kwargs, z_ref,prior=False) # log likelihood for subsample with reference theta - - diff = ll_sub - (ll_ref + jac_ref @ z_diff + .5 * z_diff @ hess_ref @ z_diff.T) - - jac_sub = jac_sub.reshape(jac_ref.shape) - jac_ref - - grad_d_k = jac_sub - z_diff.dot(hess_ref) - - gradll = -(grad_sum + n / m * (jac_sub.sum(0) - hess_ref.sum(0).dot(z_diff))) + n ** 2 / (m ** 2) * ( - diff - diff.mean(0)).T.dot(grad_d_k - grad_d_k.mean(0)) - - ld_fn = lambda args: partial(log_density_hmcecs, model, model_args, model_kwargs,prior=True)(args)[0] - jac_sub, _ = ravel_pytree(jax.jacfwd(ld_fn)(z)) - - return treedef(gradll - jac_sub) def reducer( accum, d ): accum.update(d) @@ -129,39 +93,28 @@ def tuplemerge( *dictionaries ): merged = reduce( reducer, dictionaries, {} ) return namedtuple('HMCCombinedState', merged )(**merged) # <==== Gist of the gist -def potential_est(model, model_args, model_kwargs,ll_ref, jac_all, hess_all, z, z_ref, n, m,u=None): - """Estimate the potential dynamic energy for the HMC ECS implementation. The calculation follows section 7.2.1 in https://jmlr.org/papers/volume18/15-205/15-205.pdf - The computation has a complexity of O(1) and it's highly dependant on the quality of the map estimate""" +def potential_est(model, model_args,model_kwargs, z, z_ref, n, m, proxy, proxy_u,u=None): if any(arg.shape[0] > m for arg in model_args): model_args = model_args_sub(u,model_args) + ll_sub, _ = log_density_hmcecs(model, model_args, model_kwargs, z,prior=False) # log likelihood for subsample with current theta - # Agrees with reference upto constant factor on prior - k, = jac_all.shape # number of features - z_flat, _ = ravel_pytree(z) - zref_flat, _ = ravel_pytree(z_ref) + diff = ll_sub - proxy_u(z, z_ref, model, model_args) + l_hat = proxy(z, z_ref) + n / m * jnp.sum(diff) + sigma = n ** 2 / m * jnp.var(diff) - z_diff = z_flat - zref_flat + ll_prior, _ = log_density_hmcecs(model, model_args, model_kwargs, z,prior=True) + + return (-l_hat + .5 * sigma) - ll_prior - ld_fn = lambda args: partial(log_density_hmcecs, model, model_args, model_kwargs,prior=False)(args)[0] - jac_sub, _ = ravel_pytree(jax.jacfwd(ld_fn)(z_ref)) - hess_sub, _ = ravel_pytree(jax.hessian(ld_fn)(z_ref)) - proxy = jnp.sum(ll_ref) + jac_all.T @ z_diff + .5 * z_diff.T @ hess_all @ z_diff - ll_sub, _ = log_density_hmcecs(model, model_args, model_kwargs, z,prior=False) # log likelihood for subsample with current theta - ll_ref, _ = log_density_hmcecs(model, model_args, model_kwargs, z_ref,prior=False) # log likelihood for subsample with reference theta - diff = ll_sub - (ll_ref + jac_sub.reshape((m, k)) @ z_diff + .5 * z_diff @ hess_sub.reshape((m, k, k)) @ z_diff.T) - l_hat = proxy + n / m * jnp.sum(diff) - sigma = n ** 2 / m * jnp.var(diff) - ll_prior, _ = log_density_hmcecs(model, model_args, model_kwargs, z,prior=True) #TODO: work with hierachical models - return (-l_hat + .5 * sigma) - ll_prior def velocity_verlet_hmcecs(potential_fn, kinetic_fn, grad_potential_fn=None): r""" @@ -233,3 +186,33 @@ def init_near_values(site=None, values={}): except: return init_to_uniform(site) +def taylor_proxy(ll_ref, jac_all, hess_all): + def proxy(z, z_ref): + z_flat, _ = ravel_pytree(z) + zref_flat, _ = ravel_pytree(z_ref) + z_diff = z_flat - zref_flat + return jnp.sum(ll_ref) + jac_all.T @ z_diff + .5 * z_diff.T @ hess_all @ z_diff + + def proxy_u(z, z_ref, model, model_args): + z_flat, _ = ravel_pytree(z) + zref_flat, _ = ravel_pytree(z_ref) + z_diff = z_flat - zref_flat + + ld_fn = lambda args: jnp.sum(partial(log_density_hmcecs, model, model_args, {},prior=False)(args)[0]) + + ll_sub, jac_sub = jax.value_and_grad(ld_fn)(z_ref) + k, = jac_all.shape + hess_sub, _ = ravel_pytree(jax.hessian(ld_fn)(z_ref)) + jac_sub, _ = ravel_pytree(jac_sub) + + return ll_sub + jac_sub @ z_diff + .5 * z_diff @ hess_sub.reshape((k, k)) @ z_diff.T + + return proxy, proxy_u + +def svi_proxy(): + return None + +def neural_proxy(): + return None + + From b40f662f042566033614b5a5c30eead2040c9e66 Mon Sep 17 00:00:00 2001 From: Lys Date: Wed, 30 Sep 2020 18:57:17 +0200 Subject: [PATCH 16/93] ADDED: More tests and proxies --- examples/logistic_hmcecs.py | 66 +++++++++++++++++++++++++++---------- numpyro/diagnostics.py | 1 + numpyro/infer/mcmc.py | 1 + 3 files changed, 51 insertions(+), 17 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 46710851c..4b5ffc6a3 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -17,6 +17,13 @@ from sklearn.datasets import load_breast_cancer from datasets import _load_higgs import jax.numpy as np_jax +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +import time +from numpyro.diagnostics import summary +from jax.tree_util import tree_flatten,tree_map + numpyro.set_platform("cpu") # TODO: import Higgs data! ---> http://archive.ics.uci.edu/ml/machine-learning-databases/00280/ @@ -34,39 +41,44 @@ def model(feats, obs): numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(feats, theta)), obs=obs) -def infer_nuts(rng_key, feats, obs, samples=10, warmup=5, ): +def infer_nuts(rng_key, feats, obs, samples, warmup ): kernel = NUTS(model=model) mcmc = MCMC(kernel, num_warmup=warmup, num_samples=samples) mcmc.run(rng_key, feats, obs) - # mcmc.print_summary() - return mcmc.get_samples() + #mcmc.print_summary() + samples = mcmc.get_samples() + samples = tree_map(lambda x: x[None, ...], samples) + r_hat_average = np_jax.sum(summary(samples)["theta"]["r_hat"])/len(summary(samples)["theta"]["r_hat"]) + return mcmc.get_samples(), r_hat_average -def infer_hmcecs(rng_key, feats, obs, m=None,g=None,samples=10, warmup=5,algo="NUTS",subsample_method=None ): + +def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None,algo="NUTS",subsample_method=None ): hmcecs_key, map_key = jax.random.split(rng_key) n, _ = feats.shape - print("Using {} samples".format(str(samples+warmup))) + print("Using {} samples".format(str(n_samples+warmup))) print("Running NUTS for map estimation") if subsample_method=="perturb": - z_map = {key: value.mean(0) for key, value in infer_nuts(map_key, feats, obs).items()} + samples,r_hat_average = infer_nuts(map_key, feats, obs,samples=15,warmup=5) + z_map = {key: value.mean(0) for key, value in samples.items()} else: z_map = None print("Running MCMC subsampling") start = time.time() kernel = HMC(model=model,z_ref=z_map,m=m,g=g,algo=algo,subsample_method=subsample_method) - mcmc = MCMC(kernel,num_warmup=warmup,num_samples=samples,num_chains=1) + mcmc = MCMC(kernel,num_warmup=warmup,num_samples=n_samples,num_chains=1) mcmc.run(rng_key,feats,obs) stop = time.time() file_hyperparams = open("PLOTS_{}/Hyperparameters_{}.txt".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss"),now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")), "a") file_hyperparams.write('MCMC/NUTS elapsed time {}: {} \n'.format(subsample_method,time.time() - start)) - file_hyperparams.write('Effective size {}: {}\n'.format(subsample_method,samples)) + file_hyperparams.write('Effective size {}: {}\n'.format(subsample_method,n_samples)) file_hyperparams.write('Warm up size {}: {}\n'.format(subsample_method,warmup)) file_hyperparams.write('Subsample size (m): {}\n'.format(m)) file_hyperparams.write('Block size (g): {}\n'.format(g)) - + file_hyperparams.write('Data size (n): {}\n'.format(feats.shape[0])) file_hyperparams.close() save_obj(mcmc.get_samples(),"{}/MCMC_Dict_Samples_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")),subsample_method)) @@ -81,23 +93,35 @@ def breast_cancer_data(): feats = (feats - feats.mean(0)) / feats.std(0) feats = jnp.hstack((feats, jnp.ones((feats.shape[0], 1)))) - return feats[:50], dataset.target[:50] + return feats[:100], dataset.target[:100] def higgs_data(): observations,features = _load_higgs() - return features[:10],observations[:10] + return features,observations def save_obj(obj, name): import _pickle as cPickle import bz2 with bz2.BZ2File(name, "wb") as f: cPickle.dump(obj, f) +def determine_best_sample_size(rng_key,feats,obs): + """Determine amount of effective sample size for z_map initialization""" + effective_sample_list=[5,10,20,30,50] + r_hat_average_list=[] + for effective_sample in effective_sample_list: + samples, r_hat_average = infer_nuts(rng_key,feats,obs,effective_sample,warmup=6) + r_hat_average_list.append(r_hat_average) + + plt.plot(effective_sample_list,r_hat_average_list) + plt.xlabel(r"Effective sample size") + plt.ylabel(r"$\hat{r}$") + plt.title("Determine best effective sample size for z_map") + plt.savefig("{}/Best_effective_size_z_map.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")))) + + def Plot(samples_ECS,samples_NUTS): - import matplotlib.pyplot as plt - import pandas as pd - import seaborn as sns - import time + for sample in [0,7,15,25]: plt.figure(sample) @@ -113,6 +137,7 @@ def Plot(samples_ECS,samples_NUTS): plt.savefig("{}/KDE_plot_theta_{}.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")),sample)) + def Folders(folder_name): """ Folder for all the generated images It will updated everytime!!! Save the previous folder before running again. Creates folder in current directory""" import os @@ -141,13 +166,20 @@ def Folders(folder_name): #feats, obs = breast_cancer_data() feats,obs = higgs_data() + + + now = datetime.datetime.now() Folders("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss"))) config.update('jax_disable_jit', True) + + + #determine_best_sample_size(rng_key,feats[:100],obs[:100]) + m = int(np_jax.sqrt(obs.shape[0])*2) g= int(m//3) - est_posterior_ECS = infer_hmcecs(rng_key, feats=feats, obs=obs, m =m,g=g,algo="NUTS",subsample_method="perturb") - est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats, obs=obs, m =m,g=g,algo="NUTS") + est_posterior_ECS = infer_hmcecs(rng_key, feats=feats[:100], obs=obs[:100],n_samples=100,warmup=50, m =m,g=g,algo="NUTS",subsample_method="perturb") + est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:100], obs=obs[:100], n_samples=100,warmup=50,m =m,g=g,algo="NUTS") Plot(est_posterior_ECS,est_posterior_NUTS) exit() diff --git a/numpyro/diagnostics.py b/numpyro/diagnostics.py index f85ee471d..ab2b9996d 100644 --- a/numpyro/diagnostics.py +++ b/numpyro/diagnostics.py @@ -161,6 +161,7 @@ def effective_sample_size(x): :return: effective sample size of ``x``. :rtype: numpy.ndarray """ + assert x.ndim >= 2 assert x.shape[1] >= 2 diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 7e0d2751f..28067e905 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -500,6 +500,7 @@ def get_extra_fields(self, group_by_chain=False): def print_summary(self, prob=0.9, exclude_deterministic=True): # Exclude deterministic sites by default sites = self._states[self._sample_field] + if isinstance(sites, dict) and exclude_deterministic: sites = {k: v for k, v in self._states[self._sample_field].items() if k in self._last_state.z} From 54bca1211051a67376ff24bbcda8b218cea75645 Mon Sep 17 00:00:00 2001 From: Lys Date: Thu, 1 Oct 2020 14:41:05 +0200 Subject: [PATCH 17/93] Small state fix --- examples/logistic_hmcecs.py | 33 ++++++++++++++++--------------- numpyro/contrib/hmcecs.py | 35 +++++++++++++++------------------ numpyro/contrib/hmcecs_utils.py | 9 ++------- 3 files changed, 35 insertions(+), 42 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 4b5ffc6a3..c8d4864c8 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -35,6 +35,8 @@ def model(feats, obs): """ n, m = feats.shape precision = numpyro.sample('precision', dist.continuous.Uniform(1, 4)) + #precision = numpyro.sample('precision', dist.continuous.HalfNormal(1)) + #precision = 0.5 theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), precision * jnp.ones(m))) @@ -59,13 +61,14 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, print("Using {} samples".format(str(n_samples+warmup))) - print("Running NUTS for map estimation") if subsample_method=="perturb": - samples,r_hat_average = infer_nuts(map_key, feats, obs,samples=15,warmup=5) + print("Running NUTS for map estimation") + samples,r_hat_average = infer_nuts(map_key, feats, obs,samples=10,warmup=5) z_map = {key: value.mean(0) for key, value in samples.items()} + print("Running MCMC subsampling") + else: z_map = None - print("Running MCMC subsampling") start = time.time() kernel = HMC(model=model,z_ref=z_map,m=m,g=g,algo=algo,subsample_method=subsample_method) @@ -79,6 +82,7 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, file_hyperparams.write('Subsample size (m): {}\n'.format(m)) file_hyperparams.write('Block size (g): {}\n'.format(g)) file_hyperparams.write('Data size (n): {}\n'.format(feats.shape[0])) + file_hyperparams.write('...........................................\n') file_hyperparams.close() save_obj(mcmc.get_samples(),"{}/MCMC_Dict_Samples_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")),subsample_method)) @@ -93,7 +97,7 @@ def breast_cancer_data(): feats = (feats - feats.mean(0)) / feats.std(0) feats = jnp.hstack((feats, jnp.ones((feats.shape[0], 1)))) - return feats[:100], dataset.target[:100] + return feats, dataset.target def higgs_data(): @@ -105,7 +109,7 @@ def save_obj(obj, name): with bz2.BZ2File(name, "wb") as f: cPickle.dump(obj, f) -def determine_best_sample_size(rng_key,feats,obs): +def Determine_best_sample_size(rng_key,feats,obs): """Determine amount of effective sample size for z_map initialization""" effective_sample_list=[5,10,20,30,50] r_hat_average_list=[] @@ -163,23 +167,20 @@ def Folders(folder_name): rng_key, feat_key, obs_key = jax.random.split(rng_key, 3) - #feats, obs = breast_cancer_data() - feats,obs = higgs_data() - - + feats, obs = breast_cancer_data() + #feats,obs = higgs_data() now = datetime.datetime.now() Folders("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss"))) config.update('jax_disable_jit', True) - - #determine_best_sample_size(rng_key,feats[:100],obs[:100]) - - m = int(np_jax.sqrt(obs.shape[0])*2) - g= int(m//3) - est_posterior_ECS = infer_hmcecs(rng_key, feats=feats[:100], obs=obs[:100],n_samples=100,warmup=50, m =m,g=g,algo="NUTS",subsample_method="perturb") - est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:100], obs=obs[:100], n_samples=100,warmup=50,m =m,g=g,algo="NUTS") + #Determine_best_sample_size(rng_key,feats[:100],obs[:100]) + factor = 100 + m = int(np_jax.sqrt(obs[:factor].shape[0])*2) + g= int(m//6) + est_posterior_ECS = infer_hmcecs(rng_key, feats=feats[:factor], obs=obs[:factor],n_samples=10,warmup=5, m =m,g=g,algo="NUTS",subsample_method="perturb") + est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor], obs=obs[:factor], n_samples=10,warmup=5,m =m,g=g,algo="NUTS") Plot(est_posterior_ECS,est_posterior_NUTS) exit() diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index a091eff51..22b86365b 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -108,12 +108,12 @@ def _update_block(rng_key, u, n, m, g): idxs_new = random.randint(rng_key_index, shape=(m // g,), minval=0, maxval=n) #chose block within the subsample to update u_new = jnp.zeros(m, jnp.dtype(u)) #empty array with size m - for i in range(m): #if index in the subsample // g = chosen block : pick new indexes from the subsample size #else not update: keep the same indexes u_new = ops.index_add(u_new, i, lax.cond(i // g == chosen_block, i, lambda _: idxs_new[i % (m // g)], i, lambda _: u[i])) + return u_new @@ -275,14 +275,11 @@ def init_kernel(init_params, else: if subsample_method == "perturb": kwargs = {} if model_kwargs is None else model_kwargs - proxy,proxy_u = covariate_fn(ll_ref, jac_all, hess_all) pe_fn = potential_fn_gen(model, model_args,model_kwargs, z, z_ref, n, m, proxy, proxy_u,u) - else: kwargs = {} if model_kwargs is None else model_kwargs pe_fn = potential_fn_gen(*model_args, **kwargs) - #gpe_fn = grad_potential_fn_gen(*model_args, **kwargs,) if grad_potential_fn_gen: kwargs = {} if model_kwargs is None else model_kwargs gpe_fn = grad_potential_fn_gen(*model_args, **kwargs) @@ -342,9 +339,7 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, proxy, proxy_u = covariate_fn(ll_ref, jac_all, hess_all) pe_fn = potential_fn_gen(model, model_args, model_kwargs,vv_state.z, z_ref, n, m, proxy, proxy_u, u) kwargs = {} if model_kwargs is None else model_kwargs - gpe_fn = None else: - gpe_fn = None pe_fn = potential_fn_gen(*model_args, **model_kwargs) nonlocal vv_update _, vv_update = velocity_verlet(pe_fn, kinetic_fn) @@ -372,19 +367,15 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, if potential_fn_gen: nonlocal vv_update if grad_potential_fn_gen: - kwargs = {} if model_kwargs is None else model_kwargs gpe_fn = grad_potential_fn_gen(*model_args, **kwargs, ) pe_fn = potential_fn_gen(*model_args, **model_kwargs) - else: if subsample_method == "perturb": proxy, proxy_u = covariate_fn(ll_ref, jac_all, hess_all) pe_fn = potential_fn_gen(model, model_args, model_kwargs, vv_state.z, z_ref, n, m, proxy, proxy_u, u) - gpe_fn = None else: - gpe_fn = None pe_fn = potential_fn_gen(*model_args, **model_kwargs) _, vv_update = velocity_verlet(pe_fn, kinetic_fn) @@ -420,12 +411,14 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, """ model_kwargs = {} if model_kwargs is None else model_kwargs - + print(hmc_state._fields) + if subsample_method =="perturb": + model_args = model_args_sub(u,model_args) + #hmc_state = hmc_state.hmc_state #TODO: Probably not necessary since keys are merged rng_key, rng_key_momentum, rng_key_transition = random.split(hmc_state.rng_key, 3) r = momentum_generator(hmc_state.z, hmc_state.adapt_state.mass_matrix_sqrt, rng_key_momentum) vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad) - if subsample_method =="perturb": - model_args = model_args_sub(u,model_args) + vv_state, energy, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size, hmc_state.adapt_state.inverse_mass_matrix, vv_state, @@ -567,9 +560,9 @@ def __init__(self, self.covariate_fn = covariate_fn def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ref): - "Compute the jacobian, hessian and gradient for all the data" - - rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key,rng_key_init_model = random.split(rng_key, 6) + "Compute the jacobian, hessian and log likelihood for all the data" + rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key_init_model = random.split( + rng_key, 5) self._n = model_args[0].shape[0] self._u = random.randint(rng_key, (self.m,), 0, self._n) @@ -583,7 +576,7 @@ def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ self._hess_all = hess_all.reshape((k, k)) # Initialize the model parameters init_params, potential_fn, postprocess_fn, model_trace = initialize_model( - rng_key, + rng_key_init_model, self._model, init_strategy=partial(init_near_values, values=self.z_ref), dynamic_args=True, @@ -598,6 +591,8 @@ def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ def _init_state(self, rng_key, model_args, model_kwargs, init_params): if self.subsample_method is not None: assert self.z_ref is not None, "Please provide a (i.e map) estimate for the parameters" + + # Choose the covariate calculation method if self.proxy == "svi": self.covariate_fn = lambda ll_ref, jac_all, hess_all:svi_proxy(ll_ref, jac_all, hess_all) @@ -690,7 +685,6 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg if self._potential_fn and init_params is None: raise ValueError('Valid value of `init_params` must be provided with' ' `potential_fn`.') - #TODO: assert the subsample method name is correct if self.subsample_method == "perturb": hmc_init_fn = lambda init_params,rng_key: self._init_fn(init_params=init_params, @@ -711,6 +705,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg jac_all=self._jac_all, z_ref=self.z_ref, hess_all = self._hess_all, + ll_u = self._ll_u, n=self._n, m=self.m, u = self._u, @@ -817,11 +812,11 @@ def sample(self, state, model_args, model_kwargs): """ if self.subsample_method == "perturb": - rng_key_subsample, rng_key_transition, rng_key_likelihood, rng_key = random.split( state.rng_key, 4) u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) + # estimate likelihood of subsample with single block updated self._proxy, self._proxy_u = self.covariate_fn(self._ll_ref, self._jac_all, self._hess_all) @@ -842,10 +837,12 @@ def sample(self, state, model_args, model_kwargs): (u_new, llu_new), identity, (state.u, state.ll_u), identity) self._u = u #Just in case , but not necessary + self._ll_u = ll_u ######## UPDATE PARAMETERS ########## hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state, ll_u=ll_u,z_ref=self.z_ref) + hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(),state._asdict()) return self._sample_fn(hmc_subsamplestate, diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index c2229e36e..ddc8d1a99 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -90,7 +90,9 @@ def reducer( accum, d ): def tuplemerge( *dictionaries ): from functools import reduce + merged = reduce( reducer, dictionaries, {} ) + return namedtuple('HMCCombinedState', merged )(**merged) # <==== Gist of the gist @@ -109,13 +111,6 @@ def potential_est(model, model_args,model_kwargs, z, z_ref, n, m, proxy, proxy_u return (-l_hat + .5 * sigma) - ll_prior - - - - - - - def velocity_verlet_hmcecs(potential_fn, kinetic_fn, grad_potential_fn=None): r""" Second order symplectic integrator that uses the velocity verlet algorithm From 765b3d6f453e02cf1a349d56835f074892592a0a Mon Sep 17 00:00:00 2001 From: Lys Date: Tue, 6 Oct 2020 14:33:57 +0200 Subject: [PATCH 18/93] Fixed : Proxies and init --- examples/Running_Tests.sh | 25 + examples/autoguide_hmcecs.py | 713 ++++++++++++++++++++++++++++ examples/logistic_hmcecs.py | 166 ++++--- examples/logistic_hmcecs_svi.py | 61 +++ numpyro/contrib/hmcecs.py | 240 +++++----- numpyro/contrib/hmcecs_utils.py | 90 ++-- numpyro/distributions/continuous.py | 2 +- 7 files changed, 1076 insertions(+), 221 deletions(-) create mode 100644 examples/Running_Tests.sh create mode 100644 examples/autoguide_hmcecs.py create mode 100644 examples/logistic_hmcecs_svi.py diff --git a/examples/Running_Tests.sh b/examples/Running_Tests.sh new file mode 100644 index 000000000..3c38ec25e --- /dev/null +++ b/examples/Running_Tests.sh @@ -0,0 +1,25 @@ +#!/bin/sh +#python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo NUTS -algo NUTS -map_init NUTS & +#python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo NUTS -algo NUTS -map_init HMC & +#python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo NUTS -algo NUTS -map_init SVI & #Slow, wrong number of epochs,repeat + +echo NUTS,HMC,NUTS +python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo NUTS -algo HMC -map_init NUTS & +echo NUTS,HMC,HMC +#python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo NUTS -algo HMC -map_init HMC & +echo NUTS,HMC,SVI +#python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo NUTS -algo HMC -map_init SVI & + +echo HMC,NUTS,NUTS +python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo HMC -algo NUTS -map_init NUTS & +echo HMC,NUTS,HMC +python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo HMC -algo NUTS -map_init HMC & +echo HMC,NUTS,SVI +python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo HMC -algo NUTS -map_init SVI & + +echo HMC,HMC,NUTS +python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo HMC -algo HMC -map_init NUTS & +echo HMC,HMC,HMC +python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo HMC -algo HMC -map_init HMC & +echo HMC,HMC,SVI +python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo HMC -algo HMC -map_init SVI & diff --git a/examples/autoguide_hmcecs.py b/examples/autoguide_hmcecs.py new file mode 100644 index 000000000..518badd91 --- /dev/null +++ b/examples/autoguide_hmcecs.py @@ -0,0 +1,713 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from pyro.infer.autoguide +from abc import ABC, abstractmethod +import warnings + +from jax import hessian, lax, random, tree_map +from jax.experimental import stax +from jax.flatten_util import ravel_pytree +import jax.numpy as jnp + +import numpyro +from numpyro import handlers +from numpyro.nn.auto_reg_nn import AutoregressiveNN +from numpyro.nn.block_neural_arn import BlockNeuralAutoregressiveNN +import numpyro.distributions as dist +from numpyro.distributions import constraints +from numpyro.distributions.flows import BlockNeuralAutoregressiveTransform, InverseAutoregressiveTransform +from numpyro.distributions.transforms import ( + AffineTransform, + ComposeTransform, + LowerCholeskyAffine, + PermuteTransform, + UnpackTransform, + biject_to +) +from numpyro.distributions.util import cholesky_of_inverse, sum_rightmost +from numpyro.infer.elbo import ELBO +from numpyro.infer.util import initialize_model, init_to_uniform, find_valid_initial_params +from numpyro.util import not_jax_tracer +from contextlib import ExitStack + +__all__ = [ + 'AutoContinuous', + 'AutoGuide', + 'AutoDiagonalNormal', + 'AutoLaplaceApproximation', + 'AutoLowRankMultivariateNormal', + 'AutoMultivariateNormal', + 'AutoBNAFNormal', + 'AutoIAFNormal', + 'AutoDelta' +] + +class ReinitGuide(ABC): + @abstractmethod + def init_params(self): + raise NotImplementedError + + @abstractmethod + def find_params(self, rng_keys, *args, **kwargs): + raise NotImplementedError +class AutoGuide(ABC): + """ + Base class for automatic guides. + + Derived classes must implement the :meth:`__call__` method. + + :param callable model: a pyro model + :param str prefix: a prefix that will be prefixed to all param internal sites + """ + + def __init__(self, model, prefix='auto', create_plates=None): + assert isinstance(prefix, str) + self.model = model + self.prefix = prefix + self.prototype_trace = None + self._prototype_frames = {} + self.create_plates = create_plates + + @abstractmethod + def __call__(self, *args, **kwargs): + """ + A guide with the same ``*args, **kwargs`` as the base ``model``. + + :return: A dict mapping sample site name to sampled value. + :rtype: dict + """ + raise NotImplementedError + + @abstractmethod + def sample_posterior(self, rng_key, params, *args, **kwargs): + """ + Generate samples from the approximate posterior over the latent + sites in the model. + + :param jax.random.PRNGKey rng_key: PRNG seed. + :param params: Current parameters of model and autoguide. + :param sample_shape: (keyword argument) shape of samples to be drawn. + :return: batch of samples from the approximate posterior. + """ + raise NotImplementedError + + @abstractmethod + def _sample_latent(self, *args, **kwargs): + """ + Samples an encoded latent given the same ``*args, **kwargs`` as the + base ``model``. + """ + raise NotImplementedError + + def _setup_prototype(self, *args, **kwargs): + # run the model so we can inspect its structure + rng_key = random.PRNGKey(0) + #rng_key = numpyro.rng_key("_{}_rng_key_setup".format(self.prefix)) + model = handlers.seed(self.model, rng_key) + self.prototype_trace = handlers.block(handlers.trace(model).get_trace)(*args, **kwargs) + self._args = args + self._kwargs = kwargs + for _, site in self.prototype_trace.items(): + if site['type'] != 'sample' or site['is_observed']: + continue + for frame in site['cond_indep_stack']: + if frame.vectorized: + self._prototype_frames[frame.name] = frame + else: + raise NotImplementedError("AutoGuide does not support sequential numpyro.plate") + + def _create_plates(self, *args, **kwargs): + if self.create_plates is None: + self.plates = {} + else: + plates = self.create_plates(*args, **kwargs) + if isinstance(plates, numpyro.plate): + plates = [plates] + assert all(isinstance(p, numpyro.plate) for p in plates), \ + "create_plates() returned a non-plate" + self.plates = {p.name: p for p in plates} + for name, frame in sorted(self._prototype_frames.items()): + if name not in self.plates: + self.plates[name] = numpyro.plate(name, frame.size, dim=frame.dim) + return self.plates + + +class AutoContinuous(AutoGuide): + """ + Base class for implementations of continuous-valued Automatic + Differentiation Variational Inference [1]. + + Each derived class implements its own :meth:`_get_posterior` method. + + Assumes model structure and latent dimension are fixed, and all latent + variables are continuous. + + **Reference:** + + 1. *Automatic Differentiation Variational Inference*, + Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. + Blei + + :param callable model: A NumPyro model. + :param str prefix: a prefix that will be prefixed to all param internal sites. + :param callable init_strategy: A per-site initialization function. + See :ref:`init_strategy` section for available functions. + """ + + def __init__(self, model, prefix="auto", init_strategy=init_to_uniform): + self.init_strategy = init_strategy + super(AutoContinuous, self).__init__(model, prefix=prefix) + + def _setup_prototype(self, *args, **kwargs): + rng_key = numpyro.rng_key("_{}_rng_key_setup".format(self.prefix)) + with handlers.block(): + init_params, _, self._postprocess_fn, self.prototype_trace = initialize_model( + rng_key, self.model, + init_strategy=self.init_strategy, + dynamic_args=False, + model_args=args, + model_kwargs=kwargs) + + self._init_latent, unpack_latent = ravel_pytree(init_params[0]) + # this is to match the behavior of Pyro, where we can apply + # unpack_latent for a batch of samples + self._unpack_latent = UnpackTransform(unpack_latent) + self.latent_dim = jnp.size(self._init_latent) + if self.latent_dim == 0: + raise RuntimeError('{} found no latent variables; Use an empty guide instead' + .format(type(self).__name__)) + + @abstractmethod + def _get_posterior(self): + raise NotImplementedError + + def _sample_latent(self, *args, **kwargs): + sample_shape = kwargs.pop('sample_shape', ()) + posterior = self._get_posterior() + return numpyro.sample("_{}_latent".format(self.prefix), posterior, sample_shape=sample_shape) + + def __call__(self, *args, **kwargs): + """ + An automatic guide with the same ``*args, **kwargs`` as the base ``model``. + + :return: A dict mapping sample site name to sampled value. + :rtype: dict + """ + if self.prototype_trace is None: + # run model to inspect the model structure + self._setup_prototype(*args, **kwargs) + + latent = self._sample_latent(*args, **kwargs) + + # unpack continuous latent samples + result = {} + + for name, unconstrained_value in self._unpack_latent(latent).items(): + site = self.prototype_trace[name] + transform = biject_to(site['fn'].support) + value = transform(unconstrained_value) + log_density = - transform.log_abs_det_jacobian(unconstrained_value, value) + event_ndim = len(site['fn'].event_shape) + log_density = sum_rightmost(log_density, + jnp.ndim(log_density) - jnp.ndim(value) + event_ndim) + delta_dist = dist.Delta(value, log_density=log_density, event_dim=event_ndim) + result[name] = numpyro.sample(name, delta_dist) + + return result + + def _unpack_and_constrain(self, latent_sample, params): + def unpack_single_latent(latent): + unpacked_samples = self._unpack_latent(latent) + # add param sites in model + unpacked_samples.update({k: v for k, v in params.items() if k in self.prototype_trace + and v['type'] == 'param'}) + return self._postprocess_fn(unpacked_samples) + + sample_shape = jnp.shape(latent_sample)[:-1] + if sample_shape: + latent_sample = jnp.reshape(latent_sample, (-1, jnp.shape(latent_sample)[-1])) + unpacked_samples = lax.map(unpack_single_latent, latent_sample) + return tree_map(lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]), + unpacked_samples) + else: + return unpack_single_latent(latent_sample) + + def get_base_dist(self): + """ + Returns the base distribution of the posterior when reparameterized + as a :class:`~numpyro.distributions.distribution.TransformedDistribution`. This + should not depend on the model's `*args, **kwargs`. + """ + raise NotImplementedError + + def get_transform(self, params): + """ + Returns the transformation learned by the guide to generate samples from the unconstrained + (approximate) posterior. + + :param dict params: Current parameters of model and autoguide. + The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` + method from :class:`~numpyro.infer.svi.SVI`. + :return: the transform of posterior distribution + :rtype: :class:`~numpyro.distributions.transforms.Transform` + """ + posterior = handlers.substitute(self._get_posterior, params)() + assert isinstance(posterior, dist.TransformedDistribution), \ + "posterior is not a transformed distribution" + if len(posterior.transforms) > 0: + return ComposeTransform(posterior.transforms) + else: + return posterior.transforms[0] + + def get_posterior(self, params): + """ + Returns the posterior distribution. + + :param dict params: Current parameters of model and autoguide. + The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` + method from :class:`~numpyro.infer.svi.SVI`. + """ + base_dist = self.get_base_dist() + transform = self.get_transform(params) + return dist.TransformedDistribution(base_dist, transform) + + def sample_posterior(self, rng_key, params, sample_shape=()): + """ + Get samples from the learned posterior. + + :param jax.random.PRNGKey rng_key: random key to be used draw samples. + :param dict params: Current parameters of model and autoguide. + The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` + method from :class:`~numpyro.infer.svi.SVI`. + :param tuple sample_shape: batch shape of each latent sample, defaults to (). + :return: a dict containing samples drawn the this guide. + :rtype: dict + """ + latent_sample = handlers.substitute( + handlers.seed(self._sample_latent, rng_key), params)(sample_shape=sample_shape) + return self._unpack_and_constrain(latent_sample, params) + + def median(self, params): + """ + Returns the posterior median value of each latent variable. + + :param dict params: A dict containing parameter values. + The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` + method from :class:`~numpyro.infer.svi.SVI`. + :return: A dict mapping sample site name to median tensor. + :rtype: dict + """ + raise NotImplementedError + + def quantiles(self, params, quantiles): + """ + Returns posterior quantiles each latent variable. Example:: + + print(guide.quantiles(opt_state, [0.05, 0.5, 0.95])) + + :param dict params: A dict containing parameter values. + The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` + method from :class:`~numpyro.infer.svi.SVI`. + :param list quantiles: A list of requested quantiles between 0 and 1. + :return: A dict mapping sample site name to a list of quantile values. + :rtype: dict + """ + raise NotImplementedError + + +class AutoDiagonalNormal(AutoContinuous): + """ + This implementation of :class:`AutoContinuous` uses a Normal distribution + with a diagonal covariance matrix to construct a guide over the entire + latent space. The guide does not depend on the model's ``*args, **kwargs``. + + Usage:: + + guide = AutoDiagonalNormal(model, ...) + svi = SVI(model, guide, ...) + """ + + def __init__(self, model, prefix="auto", init_strategy=init_to_uniform, init_scale=0.1): + if init_scale <= 0: + raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) + self._init_scale = init_scale + super().__init__(model, prefix, init_strategy) + + def _get_posterior(self): + loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent) + scale = numpyro.param('{}_scale'.format(self.prefix), + jnp.full(self.latent_dim, self._init_scale), + constraint=constraints.positive) + return dist.Normal(loc, scale) + + def get_base_dist(self): + return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1) + + def get_transform(self, params): + loc = params['{}_loc'.format(self.prefix)] + scale = params['{}_scale'.format(self.prefix)] + return AffineTransform(loc, scale, domain=constraints.real_vector) + + def get_posterior(self, params): + """ + Returns a diagonal Normal posterior distribution. + """ + transform = self.get_transform(params) + return dist.Normal(transform.loc, transform.scale) + + def median(self, params): + loc = params['{}_loc'.format(self.prefix)] + return self._unpack_and_constrain(loc, params) + + def quantiles(self, params, quantiles): + quantiles = jnp.array(quantiles)[..., None] + latent = self.get_posterior(params).icdf(quantiles) + return self._unpack_and_constrain(latent, params) + + +class AutoMultivariateNormal(AutoContinuous): + """ + This implementation of :class:`AutoContinuous` uses a MultivariateNormal + distribution to construct a guide over the entire latent space. + The guide does not depend on the model's ``*args, **kwargs``. + + Usage:: + + guide = AutoMultivariateNormal(model, ...) + svi = SVI(model, guide, ...) + """ + + def __init__(self, model, prefix="auto", init_strategy=init_to_uniform, init_scale=0.1): + if init_scale <= 0: + raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) + self._init_scale = init_scale + super().__init__(model, prefix, init_strategy) + + def _get_posterior(self): + loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent) + scale_tril = numpyro.param('{}_scale_tril'.format(self.prefix), + jnp.identity(self.latent_dim) * self._init_scale, + constraint=constraints.lower_cholesky) + return dist.MultivariateNormal(loc, scale_tril=scale_tril) + + def get_base_dist(self): + return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1) + + def get_transform(self, params): + loc = params['{}_loc'.format(self.prefix)] + scale_tril = params['{}_scale_tril'.format(self.prefix)] + return LowerCholeskyAffine(loc, scale_tril) #TODO: Changed MultivariateAffineTransform to LowerCholeskyAffine + + def get_posterior(self, params): + """ + Returns a multivariate Normal posterior distribution. + """ + transform = self.get_transform(params) + return dist.MultivariateNormal(transform.loc, transform.scale_tril) + + def median(self, params): + loc = params['{}_loc'.format(self.prefix)] + return self._unpack_and_constrain(loc, params) + + def quantiles(self, params, quantiles): + transform = self.get_transform(params) + quantiles = jnp.array(quantiles)[..., None] + latent = dist.Normal(transform.loc, jnp.diagonal(transform.scale_tril)).icdf(quantiles) + return self._unpack_and_constrain(latent, params) + + +class AutoLowRankMultivariateNormal(AutoContinuous): + """ + This implementation of :class:`AutoContinuous` uses a LowRankMultivariateNormal + distribution to construct a guide over the entire latent space. + The guide does not depend on the model's ``*args, **kwargs``. + + Usage:: + + guide = AutoLowRankMultivariateNormal(model, rank=2, ...) + svi = SVI(model, guide, ...) + """ + + def __init__(self, model, prefix="auto", init_strategy=init_to_uniform, init_scale=0.1, rank=None): + if init_scale <= 0: + raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) + self._init_scale = init_scale + self.rank = rank + super(AutoLowRankMultivariateNormal, self).__init__( + model, prefix=prefix, init_strategy=init_strategy) + + def _get_posterior(self, *args, **kwargs): + rank = int(round(self.latent_dim ** 0.5)) if self.rank is None else self.rank + loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent) + cov_factor = numpyro.param('{}_cov_factor'.format(self.prefix), jnp.zeros((self.latent_dim, rank))) + scale = numpyro.param('{}_scale'.format(self.prefix), + jnp.full(self.latent_dim, self._init_scale), + constraint=constraints.positive) + cov_diag = scale * scale + cov_factor = cov_factor * scale[..., None] + return dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag) + + def get_base_dist(self): + return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1) + + def get_transform(self, params): + posterior = self.get_posterior(params) + return LowerCholeskyAffine(posterior.loc, posterior.scale_tril) + + def get_posterior(self, params): + """ + Returns a lowrank multivariate Normal posterior distribution. + """ + loc = params['{}_loc'.format(self.prefix)] + cov_factor = params['{}_cov_factor'.format(self.prefix)] + scale = params['{}_scale'.format(self.prefix)] + cov_diag = scale * scale + cov_factor = cov_factor * scale[..., None] + return dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag) + + def median(self, params): + loc = params['{}_loc'.format(self.prefix)] + return self._unpack_and_constrain(loc, params) + + def quantiles(self, params, quantiles): + transform = self.get_transform(params) + quantiles = jnp.array(quantiles)[..., None] + latent = dist.Normal(transform.loc, jnp.diagonal(transform.scale_tril)).icdf(quantiles) + return self._unpack_and_constrain(latent, params) + + +class AutoLaplaceApproximation(AutoContinuous): + r""" + Laplace approximation (quadratic approximation) approximates the posterior + :math:`\log p(z | x)` by a multivariate normal distribution in the + unconstrained space. Under the hood, it uses Delta distributions to + construct a MAP guide over the entire (unconstrained) latent space. Its + covariance is given by the inverse of the hessian of :math:`-\log p(x, z)` + at the MAP point of `z`. + + Usage:: + + guide = AutoLaplaceApproximation(model, ...) + svi = SVI(model, guide, ...) + """ + + def _setup_prototype(self, *args, **kwargs): + super(AutoLaplaceApproximation, self)._setup_prototype(*args, **kwargs) + + def loss_fn(params): + # we are doing maximum likelihood, so only require `num_particles=1` and an arbitrary rng_key. + return ELBO().loss(random.PRNGKey(0), params, self.model, self, *args, **kwargs) + + self._loss_fn = loss_fn + + def _get_posterior(self, *args, **kwargs): + # sample from Delta guide + loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent) + return dist.Delta(loc, event_dim=1) + + def get_base_dist(self): + return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1) + + def get_transform(self, params): + def loss_fn(z): + params1 = params.copy() + params1['{}_loc'.format(self.prefix)] = z + return self._loss_fn(params1) + + loc = params['{}_loc'.format(self.prefix)] + precision = hessian(loss_fn)(loc) + scale_tril = cholesky_of_inverse(precision) + if not_jax_tracer(scale_tril): + if jnp.any(jnp.isnan(scale_tril)): + warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior" + " samples from AutoLaplaceApproxmiation will be constant (equal to" + " the MAP point).") + scale_tril = jnp.where(jnp.isnan(scale_tril), 0., scale_tril) + return LowerCholeskyAffine(loc, scale_tril) + + def get_posterior(self, params): + """ + Returns a multivariate Normal posterior distribution. + """ + transform = self.get_transform(params) + return dist.MultivariateNormal(transform.loc, scale_tril=transform.scale_tril) + + def sample_posterior(self, rng_key, params, sample_shape=()): + latent_sample = self.get_posterior(params).sample(rng_key, sample_shape) + return self._unpack_and_constrain(latent_sample, params) + + def median(self, params): + loc = params['{}_loc'.format(self.prefix)] + return self._unpack_and_constrain(loc, params) + + def quantiles(self, params, quantiles): + transform = self.get_transform(params) + quantiles = jnp.array(quantiles)[..., None] + latent = dist.Normal(transform.loc, jnp.diagonal(transform.scale_tril)).icdf(quantiles) + return self._unpack_and_constrain(latent, params) + + +class AutoIAFNormal(AutoContinuous): + """ + This implementation of :class:`AutoContinuous` uses a Diagonal Normal + distribution transformed via a + :class:`~numpyro.distributions.flows.InverseAutoregressiveTransform` + to construct a guide over the entire latent space. The guide does not + depend on the model's ``*args, **kwargs``. + + Usage:: + + guide = AutoIAFNormal(model, hidden_dims=[20], skip_connections=True, ...) + svi = SVI(model, guide, ...) + + :param callable model: a generative model. + :param str prefix: a prefix that will be prefixed to all param internal sites. + :param callable init_strategy: A per-site initialization function. + :param int num_flows: the number of flows to be used, defaults to 3. + :param list hidden_dims: the dimensionality of the hidden units per layer. + Defaults to ``[latent_dim, latent_dim]``. + :param bool skip_connections: whether to add skip connections from the input to the + output of each flow. Defaults to False. + :param callable nonlinearity: the nonlinearity to use in the feedforward network. + Defaults to :func:`jax.experimental.stax.Elu`. + """ + + def __init__(self, model, prefix="auto", init_strategy=init_to_uniform, + num_flows=3, hidden_dims=None, skip_connections=False, nonlinearity=stax.Elu): + self.num_flows = num_flows + # 2-layer, stax.Elu, skip_connections=False by default following the experiments in + # IAF paper (https://arxiv.org/abs/1606.04934) + # and Neutra paper (https://arxiv.org/abs/1903.03704) + self._hidden_dims = hidden_dims + self._skip_connections = skip_connections + self._nonlinearity = nonlinearity + super(AutoIAFNormal, self).__init__(model, prefix=prefix, init_strategy=init_strategy) + + def _get_posterior(self): + if self.latent_dim == 1: + raise ValueError('latent dim = 1. Consider using AutoDiagonalNormal instead') + hidden_dims = [self.latent_dim, self.latent_dim] if self._hidden_dims is None else self._hidden_dims + flows = [] + for i in range(self.num_flows): + if i > 0: + flows.append(PermuteTransform(jnp.arange(self.latent_dim)[::-1])) + arn = AutoregressiveNN(self.latent_dim, hidden_dims, + permutation=jnp.arange(self.latent_dim), + skip_connections=self._skip_connections, + nonlinearity=self._nonlinearity) + arnn = numpyro.module('{}_arn__{}'.format(self.prefix, i), arn, (self.latent_dim,)) + flows.append(InverseAutoregressiveTransform(arnn)) + return dist.TransformedDistribution(self.get_base_dist(), flows) + + def get_base_dist(self): + return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1) + + +class AutoBNAFNormal(AutoContinuous): + """ + This implementation of :class:`AutoContinuous` uses a Diagonal Normal + distribution transformed via a + :class:`~numpyro.distributions.flows.BlockNeuralAutoregressiveTransform` + to construct a guide over the entire latent space. The guide does not + depend on the model's ``*args, **kwargs``. + + Usage:: + + guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[50, 50], ...) + svi = SVI(model, guide, ...) + + **References** + + 1. *Block Neural Autoregressive Flow*, + Nicola De Cao, Ivan Titov, Wilker Aziz + + :param callable model: a generative model. + :param str prefix: a prefix that will be prefixed to all param internal sites. + :param callable init_strategy: A per-site initialization function. + :param int num_flows: the number of flows to be used, defaults to 3. + :param list hidden_factors: Hidden layer i has ``hidden_factors[i]`` hidden units per + input dimension. This corresponds to both :math:`a` and :math:`b` in reference [1]. + The elements of hidden_factors must be integers. + """ + + def __init__(self, model, prefix="auto", init_strategy=init_to_uniform, num_flows=1, + hidden_factors=[8, 8]): + self.num_flows = num_flows + self._hidden_factors = hidden_factors + super(AutoBNAFNormal, self).__init__(model, prefix=prefix, init_strategy=init_strategy) + + def _get_posterior(self): + if self.latent_dim == 1: + raise ValueError('latent dim = 1. Consider using AutoDiagonalNormal instead') + flows = [] + for i in range(self.num_flows): + if i > 0: + flows.append(PermuteTransform(jnp.arange(self.latent_dim)[::-1])) + residual = "gated" if i < (self.num_flows - 1) else None + arn = BlockNeuralAutoregressiveNN(self.latent_dim, self._hidden_factors, residual) + arnn = numpyro.module('{}_arn__{}'.format(self.prefix, i), arn, (self.latent_dim,)) + flows.append(BlockNeuralAutoregressiveTransform(arnn)) + return dist.TransformedDistribution(self.get_base_dist(), flows) + + def get_base_dist(self): + return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1) + + +class AutoDelta(AutoGuide, ReinitGuide): + def __init__(self, model, *, prefix='auto', init_strategy=init_to_uniform(), create_plates=None): + self.init_strategy = init_strategy + self._param_map = None + self._init_params = None + super(AutoDelta, self).__init__(model, prefix=prefix, create_plates=create_plates) + + def init_params(self): + return self._init_params + + def __call__(self, *args, **kwargs): + if self.prototype_trace is None: + self._setup_prototype(*args, **kwargs) + plates = self._create_plates(*args, **kwargs) + result = {} + for name, site in self.prototype_trace.items(): + if site['type'] != 'sample' or site['is_observed']: + continue + with ExitStack() as stack: + for frame in site['cond_indep_stack']: + stack.enter_context(plates[frame.name]) + if site['intermediates']: + event_dim = len(site['fn'].base_dist.event_shape) + else: + event_dim = len(site['fn'].event_shape) + param_name, param_val, constraint = self._param_map[name] + val_param = numpyro.param(param_name, param_val, constraint=constraint) + result[name] = numpyro.sample(name, dist.Delta(val_param, event_dim=event_dim)) + return result + + def _sample_latent(self, *args, **kwargs): + raise NotImplementedError + + def sample_posterior(self, rng_key, *args, **kwargs): + raise NotImplementedError + + def find_params(self, rng_keys, *args, **kwargs): + params = {site['name']: site['value'] for site in self.prototype_trace.values() + if site['type'] == 'sample' and not site['is_observed']} + (init_params, _, _), _ = handlers.block(find_valid_initial_params)(rng_keys, self.model, + init_strategy=self.init_strategy, + model_args=args, + model_kwargs=kwargs, + prototype_params=params) + for name, site in self.prototype_trace.items(): + if site['type'] == 'sample' and not site['is_observed']: + param_name = "{}_{}".format(self.prefix, name) + param_val = biject_to(site['fn'].support)(init_params[name]) + params[name] = (param_name, param_val, site['fn'].support) + self._param_map = params + self._init_params = {param: (val, constr) for param, val, constr in self._param_map.values()} + + def _setup_prototype(self, *args, **kwargs): + super(AutoDelta, self)._setup_prototype(*args, **kwargs) + #rng_key = numpyro.rng_key("_{}_rng_key_init".format(self.prefix)) + rng_key = random.PRNGKey(1) + self.find_params(rng_key, *args, **kwargs) \ No newline at end of file diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index c8d4864c8..767609d2d 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -1,4 +1,5 @@ """ Logistic regression model as implemetned in https://arxiv.org/pdf/1708.00955.pdf with Higgs Dataset """ +#!/usr/bin/env python import jax import jax.numpy as jnp import numpyro @@ -7,15 +8,18 @@ import sys, os from jax.config import config import datetime,time +import argparse sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/contrib/') sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/examples/') from hmcecs import HMC #from numpyro.contrib.hmcecs import HMC -#numpyro.set_host_device_count(2) + from sklearn.datasets import load_breast_cancer -from datasets import _load_higgs +#from datasets import _load_higgs +from numpyro.examples.datasets import _load_higgs +from logistic_hmcecs_svi import svi_map import jax.numpy as np_jax import matplotlib.pyplot as plt import pandas as pd @@ -26,25 +30,35 @@ numpyro.set_platform("cpu") -# TODO: import Higgs data! ---> http://archive.ics.uci.edu/ml/machine-learning-databases/00280/ -# https://towardsdatascience.com/identifying-higgs-bosons-from-background-noise-pyspark-d7983234207e +def breast_cancer_data(): + dataset = load_breast_cancer() + feats = dataset.data + feats = (feats - feats.mean(0)) / feats.std(0) + feats = jnp.hstack((feats, jnp.ones((feats.shape[0], 1)))) + + return feats, dataset.target + + +def higgs_data(): + observations,features = _load_higgs() + return features,observations +def save_obj(obj, name): + import _pickle as cPickle + import bz2 + with bz2.BZ2File(name, "wb") as f: + cPickle.dump(obj, f) def model(feats, obs): """ Logistic regression model """ n, m = feats.shape - precision = numpyro.sample('precision', dist.continuous.Uniform(1, 4)) - #precision = numpyro.sample('precision', dist.continuous.HalfNormal(1)) - - #precision = 0.5 - theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), precision * jnp.ones(m))) + theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), 2 * jnp.ones(m))) numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(feats, theta)), obs=obs) - def infer_nuts(rng_key, feats, obs, samples, warmup ): - kernel = NUTS(model=model) + kernel = NUTS(model=model,target_accept_prob=0.8) mcmc = MCMC(kernel, num_warmup=warmup, num_samples=samples) mcmc.run(rng_key, feats, obs) #mcmc.print_summary() @@ -55,27 +69,55 @@ def infer_nuts(rng_key, feats, obs, samples, warmup ): return mcmc.get_samples(), r_hat_average -def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None,algo="NUTS",subsample_method=None ): + + +def infer_hmc(rng_key, feats, obs, samples, warmup ): + kernel = HMC(model=model,target_accept_prob=0.8) + mcmc = MCMC(kernel, num_warmup=warmup, num_samples=samples) + mcmc.run(rng_key, feats, obs) + #mcmc.print_summary() + samples = mcmc.get_samples() + samples = tree_map(lambda x: x[None, ...], samples) + r_hat_average = np_jax.sum(summary(samples)["theta"]["r_hat"])/len(summary(samples)["theta"]["r_hat"]) + + return mcmc.get_samples(), r_hat_average + + + + + +def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None,algo="NUTS",subsample_method=None,map_method=None,num_epochs=None ): hmcecs_key, map_key = jax.random.split(rng_key) n, _ = feats.shape print("Using {} samples".format(str(n_samples+warmup))) - if subsample_method=="perturb": - print("Running NUTS for map estimation") - samples,r_hat_average = infer_nuts(map_key, feats, obs,samples=10,warmup=5) - z_map = {key: value.mean(0) for key, value in samples.items()} + if map_method == "NUTS": + print("Running NUTS for map estimation") + samples,r_hat_average = infer_nuts(map_key, feats, obs,samples=100,warmup=50) + z_map = {key: value.mean(0) for key, value in samples.items()} + if map_method == "HMC": + print("Running HMC for map estimation") + samples, r_hat_average = infer_hmc(map_key, feats, obs, samples=100, warmup=50) + z_map = {key: value.mean(0) for key, value in samples.items()} + + if map_method == "SVI": + print("Running SVI for map estimation") + z_map = svi_map(model, map_key, feats=feats, obs=obs,num_epochs=num_epochs,batch_size = m) + z_map = {k[5:]: v for k, v in z_map.items()} #highlight: [5:] is to skip the "auto" part + save_obj(z_map,"{}/MAP_Dict_Samples_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), map_method)) print("Running MCMC subsampling") else: z_map = None + start = time.time() - kernel = HMC(model=model,z_ref=z_map,m=m,g=g,algo=algo,subsample_method=subsample_method) + kernel = HMC(model=model,z_ref=z_map,m=m,g=g,algo=algo,subsample_method=subsample_method,target_accept_prob=0.8) mcmc = MCMC(kernel,num_warmup=warmup,num_samples=n_samples,num_chains=1) mcmc.run(rng_key,feats,obs) stop = time.time() - file_hyperparams = open("PLOTS_{}/Hyperparameters_{}.txt".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss"),now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")), "a") + file_hyperparams = open("PLOTS_{}/Hyperparameters_{}.txt".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms"),now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), "a") file_hyperparams.write('MCMC/NUTS elapsed time {}: {} \n'.format(subsample_method,time.time() - start)) file_hyperparams.write('Effective size {}: {}\n'.format(subsample_method,n_samples)) file_hyperparams.write('Warm up size {}: {}\n'.format(subsample_method,warmup)) @@ -85,30 +127,12 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, file_hyperparams.write('...........................................\n') file_hyperparams.close() - save_obj(mcmc.get_samples(),"{}/MCMC_Dict_Samples_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")),subsample_method)) + save_obj(mcmc.get_samples(),"{}/MCMC_Dict_Samples_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")),subsample_method)) return mcmc.get_samples() -def breast_cancer_data(): - dataset = load_breast_cancer() - feats = dataset.data - feats = (feats - feats.mean(0)) / feats.std(0) - feats = jnp.hstack((feats, jnp.ones((feats.shape[0], 1)))) - - return feats, dataset.target - - -def higgs_data(): - observations,features = _load_higgs() - return features,observations -def save_obj(obj, name): - import _pickle as cPickle - import bz2 - with bz2.BZ2File(name, "wb") as f: - cPickle.dump(obj, f) - def Determine_best_sample_size(rng_key,feats,obs): """Determine amount of effective sample size for z_map initialization""" effective_sample_list=[5,10,20,30,50] @@ -121,24 +145,24 @@ def Determine_best_sample_size(rng_key,feats,obs): plt.xlabel(r"Effective sample size") plt.ylabel(r"$\hat{r}$") plt.title("Determine best effective sample size for z_map") - plt.savefig("{}/Best_effective_size_z_map.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")))) + plt.savefig("{}/Best_effective_size_z_map.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")))) -def Plot(samples_ECS,samples_NUTS): +def Plot(samples_ECS,samples_NUTS,ecs_algo,algo): for sample in [0,7,15,25]: plt.figure(sample) #samples = pd.DataFrame.from_records(samples,index="theta") - sns.kdeplot(data=samples_ECS["theta"][sample],color="r",label="ECS") - sns.kdeplot(data=samples_NUTS["theta"][sample],color="b",label="NUTS") + sns.kdeplot(data=samples_ECS["theta"][sample],color="r",label="ECS-{}".format(ecs_algo)) + sns.kdeplot(data=samples_NUTS["theta"][sample],color="b",label="{}".format(algo)) plt.xlabel(r"$\theta") plt.ylabel("Density") plt.legend() plt.title(r"$\theta$ {} Density plot".format(sample)) - plt.savefig("{}/KDE_plot_theta_{}.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss")),sample)) + plt.savefig("{}/KDE_plot_theta_{}.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")),sample)) @@ -162,27 +186,61 @@ def Folders(folder_name): shutil.rmtree(newpath) # removes all the subdirectories! os.makedirs(newpath,0o777) +def Tests(map_method,ecs_algo,algo,n_samples,n_warmup,epochs): + factor = 100 + m = int(np_jax.sqrt(obs[:factor].shape[0])*2) + g= 5 + est_posterior_ECS = infer_hmcecs(rng_key, feats=feats[:factor], obs=obs[:factor], + n_samples=n_samples, + warmup=n_warmup, + m =m,g=g, + algo=ecs_algo, + subsample_method="perturb", + map_method = map_method, + num_epochs=epochs) + est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor], obs=obs[:factor], n_samples=n_samples,warmup=n_warmup,m =m,g=g,algo=algo) + + Plot(est_posterior_ECS,est_posterior_NUTS,ecs_algo,algo) + if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('-num_samples', nargs='?', default=100, type=int) + parser.add_argument('-num_warmup', nargs='?', default=50, type=int) + parser.add_argument('-ecs_algo', nargs='?', default="NUTS", type=str) + parser.add_argument('-algo', nargs='?', default="HMC", type=str) + parser.add_argument('-map_init', nargs='?', default="NUTS", type=str) + parser.add_argument("-epochs",default=100,type=int) + args = parser.parse_args() + + rng_key = jax.random.PRNGKey(37) + rng_key, feat_key, obs_key = jax.random.split(rng_key, 3) + now = datetime.datetime.now() + Folders("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms"))) + file_hyperparams = open("PLOTS_{}/Hyperparameters_{}.txt".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms"), + now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), "a") + file_hyperparams.write('ECS algo : {} \n'.format(args.ecs_algo)) + file_hyperparams.write('algo : {} \n'.format(args.algo)) + file_hyperparams.write('MAP init : {} \n'.format(args.map_init)) + file_hyperparams.write('SVI epochs : {} \n'.format(args.epochs)) - feats, obs = breast_cancer_data() - #feats,obs = higgs_data() + higgs = True + if higgs: + feats,obs = higgs_data() + file_hyperparams.write('Dataset : HIGGS \n') - now = datetime.datetime.now() - Folders("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss"))) + else: + feats, obs = breast_cancer_data() + file_hyperparams.write('Dataset : BREAST CANCER DATA \n') + + file_hyperparams.close() config.update('jax_disable_jit', True) #Determine_best_sample_size(rng_key,feats[:100],obs[:100]) - factor = 100 - m = int(np_jax.sqrt(obs[:factor].shape[0])*2) - g= int(m//6) - est_posterior_ECS = infer_hmcecs(rng_key, feats=feats[:factor], obs=obs[:factor],n_samples=10,warmup=5, m =m,g=g,algo="NUTS",subsample_method="perturb") - est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor], obs=obs[:factor], n_samples=10,warmup=5,m =m,g=g,algo="NUTS") + Tests(args.map_init,args.ecs_algo,args.algo,args.num_samples,args.num_warmup,args.epochs) - Plot(est_posterior_ECS,est_posterior_NUTS) - exit() - predictions = Predictive(model, posterior_samples=est_posterior_ECS)(rng_key, feats, None)['obs'] diff --git a/examples/logistic_hmcecs_svi.py b/examples/logistic_hmcecs_svi.py new file mode 100644 index 000000000..ebe7ad35e --- /dev/null +++ b/examples/logistic_hmcecs_svi.py @@ -0,0 +1,61 @@ +import jax.numpy as np_jax +import numpy as np +from jax import lax +def load_dataset(observations,features, batch_size=None, shuffle=True): + + arrays = (observations,features) + num_records = observations.shape[0] + idxs = np_jax.arange(num_records) + if not batch_size: + batch_size = num_records + + def init(): + return num_records // batch_size, np.random.permutation(idxs) if shuffle else idxs + + def get_batch(i=0, idxs=idxs): + ret_idx = lax.dynamic_slice_in_dim(idxs, i * batch_size, batch_size) + batch_data = np_jax.take(arrays[0], ret_idx, axis=0) + batch_matrix =np_jax.take(np_jax.take(arrays[1], ret_idx, axis=0),ret_idx,axis=1) + return (batch_data,batch_matrix) + return init, get_batch +def svi_map(model, rng_key, feats,obs,num_epochs,batch_size): + """ + MLE in numpy: https://medium.com/@rrfd/what-is-maximum-likelihood-estimation-examples-in-python-791153818030i + Cost function: -log (likelihood(parameters|data) + Calculate pdf of the parameter|data under the distribution + """ + from jax import random, jit + from numpyro import optim + from numpyro.infer.elbo import RenyiELBO + from numpyro.infer.svi import SVI + from numpyro.util import fori_loop + import time + import numpyro + numpyro.set_platform("gpu") + + from autoguide_hmcecs import AutoDelta + n, _ = feats.shape + guide = AutoDelta(model) + loss = RenyiELBO(alpha=2, num_particles=1) + svi = SVI(model, guide, optim.Adam(0.001), loss=loss) + svi_state = svi.init( rng_key,feats,obs) + train_init, train_fetch = load_dataset(obs,feats, batch_size=batch_size) + num_train, train_idx = train_init() + + @jit + def epoch_train(svi_state): + def body_fn(i, val): + batch_obs = train_fetch(i, train_idx)[0] + batch_feats = train_fetch(i, train_idx)[1] + loss_sum, svi_state = val + svi_state, loss = svi.update(svi_state, feats,obs) + loss_sum += loss + return loss_sum, svi_state + + return fori_loop(0, n, body_fn, (0., svi_state)) + + for i in range(num_epochs): + t_start = time.time() + train_loss, svi_state = epoch_train(svi_state) + print("Epoch {}: loss = {} ({:.2f} s.)".format(i, train_loss, time.time() - t_start)) + return svi.get_params(svi_state) \ No newline at end of file diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 22b86365b..41edf5007 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -24,9 +24,9 @@ import sys sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/contrib/') -from hmcecs_utils import potential_est,log_density_hmcecs, \ - velocity_verlet_hmcecs, init_near_values,tuplemerge,\ - model_args_sub,model_kwargs_sub,taylor_proxy,svi_proxy,neural_proxy +from hmcecs_utils import potential_est, init_near_values,tuplemerge,\ + model_args_sub,model_kwargs_sub,taylor_proxy,svi_proxy,neural_proxy,log_density_obs_hmcecs,log_density_prior_hmcecs + HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', 'mean_accept_prob', 'diverging', 'adapt_state','rng_key']) #HMCECSState = namedtuple("HMCECState",["u","hmc_state","z_ref","ll_ref","jac_all","hess_all","ll_u"]) @@ -117,7 +117,7 @@ def _update_block(rng_key, u, n, m, g): return u_new -def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, grad_potential_fn_gen=None,covariate_fn=None,algo='NUTS'): +def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, grad_potential_fn_gen=None,algo='NUTS'): r""" Hamiltonian Monte Carlo inference, using either fixed number of steps or the No U-Turn Sampler (NUTS) with adaptive path length. @@ -222,7 +222,8 @@ def init_kernel(init_params, u= None, rng_key=random.PRNGKey(0), subsample_method=None, - covariate_fn = None): + proxy_fn=None, + proxy_u_fn = None): """ Initializes the HMC sampler. @@ -275,8 +276,8 @@ def init_kernel(init_params, else: if subsample_method == "perturb": kwargs = {} if model_kwargs is None else model_kwargs - proxy,proxy_u = covariate_fn(ll_ref, jac_all, hess_all) - pe_fn = potential_fn_gen(model, model_args,model_kwargs, z, z_ref, n, m, proxy, proxy_u,u) + + pe_fn = potential_fn_gen(model, model_args,model_kwargs, z, z_ref, n, m, proxy_fn, proxy_u_fn,u) else: kwargs = {} if model_kwargs is None else model_kwargs pe_fn = potential_fn_gen(*model_args, **kwargs) @@ -323,11 +324,14 @@ def init_kernel(init_params, hmc_state = tuplemerge(hmc_sub_state._asdict(),hmc_state._asdict()) + return device_put(hmc_state) def _hmc_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key,subsample_method, - model,ll_ref,jac_all,z,z_ref,hess_all,ll_u,u,n,m): + proxy_fn = None, proxy_u_fn = None, + model = None, ll_ref = None, jac_all = None, z = None, z_ref = None, hess_all = None, + ll_u = None, u = None, n = None, m = None): if potential_fn_gen: if grad_potential_fn_gen: kwargs = {} if model_kwargs is None else model_kwargs @@ -336,8 +340,7 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, else: if subsample_method == "perturb": - proxy, proxy_u = covariate_fn(ll_ref, jac_all, hess_all) - pe_fn = potential_fn_gen(model, model_args, model_kwargs,vv_state.z, z_ref, n, m, proxy, proxy_u, u) + pe_fn = potential_fn_gen(model, model_args, model_kwargs,vv_state.z, z_ref, n, m, proxy_fn, proxy_u_fn, u) kwargs = {} if model_kwargs is None else model_kwargs else: pe_fn = potential_fn_gen(*model_args, **model_kwargs) @@ -359,10 +362,12 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, vv_state, energy = cond(transition, (vv_state_new, energy_new), identity, (vv_state, energy_old), identity) + return vv_state, energy, num_steps, accept_prob, diverging def _nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key,subsample_method, + proxy_fn=None,proxy_u_fn=None, model=None,ll_ref=None,jac_all=None,z = None,z_ref=None,hess_all=None,ll_u=None,u=None,n=None,m=None): if potential_fn_gen: nonlocal vv_update @@ -372,9 +377,8 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, pe_fn = potential_fn_gen(*model_args, **model_kwargs) else: if subsample_method == "perturb": - proxy, proxy_u = covariate_fn(ll_ref, jac_all, hess_all) - pe_fn = potential_fn_gen(model, model_args, model_kwargs, vv_state.z, z_ref, n, m, proxy, - proxy_u, u) + pe_fn = potential_fn_gen(model, model_args, model_kwargs, vv_state.z, z_ref, n, m, proxy_fn, + proxy_u_fn, u) else: pe_fn = potential_fn_gen(*model_args, **model_kwargs) _, vv_update = velocity_verlet(pe_fn, kinetic_fn) @@ -394,7 +398,7 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, _next = _nuts_next if algo == 'NUTS' else _hmc_next def sample_kernel(hmc_state,model_args=(),model_kwargs=None, - subsample_method=None,covariate_fn=None, + subsample_method=None,proxy_fn=None,proxy_u_fn=None, model=None,ll_ref=None,jac_all=None, z=None,z_ref=None,hess_all=None,ll_u=None, u=None,n=None,m=None,): @@ -411,10 +415,8 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, """ model_kwargs = {} if model_kwargs is None else model_kwargs - print(hmc_state._fields) if subsample_method =="perturb": model_args = model_args_sub(u,model_args) - #hmc_state = hmc_state.hmc_state #TODO: Probably not necessary since keys are merged rng_key, rng_key_momentum, rng_key_transition = random.split(hmc_state.rng_key, 3) r = momentum_generator(hmc_state.z, hmc_state.adapt_state.mass_matrix_sqrt, rng_key_momentum) vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad) @@ -426,6 +428,8 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, model_kwargs, rng_key_transition, subsample_method, + proxy_fn, + proxy_u_fn, model,ll_ref,jac_all,z,z_ref,hess_all,ll_u,u,n,m) # not update adapt_state after warmup phase adapt_state = cond(hmc_state.i < wa_steps, @@ -520,8 +524,7 @@ def __init__(self, m= None, g = None, z_ref= None, - algo = "HMC", - covariate_fn = None, + algo = "HMC" ): if not (model is None) ^ (potential_fn is None): raise ValueError('Only one of `model` or `potential_fn` must be specified.') @@ -556,42 +559,31 @@ def __init__(self, self._postprocess_fn = None self._sample_fn = None self._subsample_fn = None - self.proxy = "taylor", - self.covariate_fn = covariate_fn + self.proxy = "taylor" + self._proxy_fn = None + self._proxy_u_fn = None def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ref): "Compute the jacobian, hessian and log likelihood for all the data" - rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key_init_model = random.split( - rng_key, 5) + rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key = random.split(rng_key, 5) + self._n = model_args[0].shape[0] self._u = random.randint(rng_key, (self.m,), 0, self._n) - - ld_fn = lambda args: jnp.sum(partial(log_density_hmcecs, self._model, model_args, {},prior=False)(args)[0]) - - self._ll_ref = ld_fn(z_ref) - self._jac_all, _ = ravel_pytree(jacfwd(ld_fn)(z_ref)) - hess_all, _ = ravel_pytree(hessian(ld_fn)(z_ref)) - - k, = self._jac_all.shape - self._hess_all = hess_all.reshape((k, k)) - # Initialize the model parameters - init_params, potential_fn, postprocess_fn, model_trace = initialize_model( - rng_key_init_model, - self._model, - init_strategy=partial(init_near_values, values=self.z_ref), - dynamic_args=True, - model_args=model_args_sub(self._u, model_args), - model_kwargs=model_kwargs) - - - - return init_params, potential_fn, postprocess_fn, model_trace + if self.proxy == "taylor": + ld_fn = lambda args: jnp.sum(partial(log_density_obs_hmcecs, self._model, model_args, model_kwargs)(args)[0]) + self._jac_all, _ = ravel_pytree(jacfwd(ld_fn)(z_ref)) + hess_all, _ = ravel_pytree(hessian(ld_fn)(z_ref)) + k, = self._jac_all.shape + self._hess_all = hess_all.reshape((k, k)) + ld_fn = lambda args: partial(log_density_obs_hmcecs,self._model,model_args,model_kwargs)(args)[0] + self._ll_ref = ld_fn(z_ref) def _init_state(self, rng_key, model_args, model_kwargs, init_params): if self.subsample_method is not None: assert self.z_ref is not None, "Please provide a (i.e map) estimate for the parameters" + self._init_subsample_state(rng_key, model_args, model_kwargs, init_params,self.z_ref) # Choose the covariate calculation method if self.proxy == "svi": @@ -600,50 +592,63 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): self.covariate_fn = lambda ll_ref, jac_all, hess_all:neural_proxy(ll_ref, jac_all, hess_all) else: warnings.warn("Using default second order Taylor expansion, change by using the proxy flag to {svi,neural}") - self.covariate_fn = lambda ll_ref, jac_all, hess_all:taylor_proxy(ll_ref, jac_all, hess_all) - + self._proxy_fn, self._proxy_u_fn = taylor_proxy(self._ll_ref, self._jac_all, self._hess_all) # Initialize the potential and gradient potential functions - self._potential_fn = lambda model, model_args, model_kwargs,z, z_ref, n, m, proxy, proxy_u,u : lambda z:potential_est(model=model, - model_args=model_args,model_kwargs=model_kwargs,z=z,z_ref=z_ref,n=n,m = m,proxy=proxy,proxy_u=proxy_u,u=u) + self._potential_fn = lambda model, model_args, model_kwargs,z, z_ref, n, m, proxy_fn, proxy_u_fn,u : lambda z:potential_est(model=model, + model_args=model_args,model_kwargs=model_kwargs,z=z,z_ref=z_ref,n=n,m = m,proxy_fn=proxy_fn,proxy_u_fn=proxy_u_fn,u=u) # Initialize the hmc sampler: sample_fn = sample_kernel self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, kinetic_fn=euclidean_kinetic_energy, - algo=self._algo, - covariate_fn=self.covariate_fn) + algo=self._algo) + + + self._init_strategy = partial(init_near_values, values=self.z_ref) + # Initialize the model parameters + rng_key_init_model, rng_key = random.split(rng_key) - init_params, potential_fn, postprocess_fn, model_trace=self._init_subsample_state(rng_key, model_args, model_kwargs, init_params,self.z_ref) - if (self.g > self.m) or (self.g < 1): - raise ValueError( - 'Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(self.g, - self.m)) - elif (self.m > self._n): - raise ValueError( - 'Subsample size (m) = {} needs to = or < than data size (n) {}'.format(self.m, self._n)) - if self._model is not None: init_params, potential_fn, postprocess_fn, model_trace = initialize_model( - rng_key, + rng_key_init_model, self._model, + init_strategy=self._init_strategy, dynamic_args=True, - model_args=model_args, + model_args=model_args_sub(self._u, model_args), model_kwargs=model_kwargs) - if any(v['type'] == 'param' for v in model_trace.values()): - warnings.warn("'param' sites will be treated as constants during inference. To define " - "an improper variable, please use a 'sample' site with log probability " - "masked out. For example, `sample('x', dist.LogNormal(0, 1).mask(False)` " - "means that `x` has improper distribution over the positive domain.") - if self._init_fn is None: - self._init_fn, self._sample_fn = hmc(potential_fn_gen=potential_fn, + + + if (self.g > self.m) or (self.g < 1): + raise ValueError( + 'Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(self.g, + self.m)) + elif (self.m > self._n): + raise ValueError( + 'Subsample size (m) = {} needs to = or < than data size (n) {}'.format(self.m, self._n)) + else: + if self._model is not None: + init_params, potential_fn, postprocess_fn, model_trace = initialize_model( + rng_key, + self._model, + dynamic_args=True, + model_args=model_args, + model_kwargs=model_kwargs) + + if any(v['type'] == 'param' for v in model_trace.values()): + warnings.warn("'param' sites will be treated as constants during inference. To define " + "an improper variable, please use a 'sample' site with log probability " + "masked out. For example, `sample('x', dist.LogNormal(0, 1).mask(False)` " + "means that `x` has improper distribution over the positive domain.") + if self._init_fn is None: + self._init_fn, self._sample_fn = hmc(potential_fn_gen=potential_fn, + kinetic_fn=self._kinetic_fn, + algo=self._algo) + self._postprocess_fn = postprocess_fn + elif self._init_fn is None: + self._init_fn, self._sample_fn = hmc(potential_fn=self._potential_fn, kinetic_fn=self._kinetic_fn, algo=self._algo) - self._postprocess_fn = postprocess_fn - elif self._init_fn is None: - self._init_fn, self._sample_fn = hmc(potential_fn=self._potential_fn, - kinetic_fn=self._kinetic_fn, - algo=self._algo) return init_params @@ -681,41 +686,41 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params) #should work for all cases - if self._potential_fn and init_params is None: raise ValueError('Valid value of `init_params` must be provided with' ' `potential_fn`.') if self.subsample_method == "perturb": hmc_init_fn = lambda init_params,rng_key: self._init_fn(init_params=init_params, - num_warmup = num_warmup, - step_size = self._step_size, - adapt_step_size = self._adapt_step_size, - adapt_mass_matrix = self._adapt_mass_matrix, - dense_mass = self._dense_mass, - target_accept_prob = self._target_accept_prob, - trajectory_length=self._trajectory_length, - max_tree_depth=self._max_tree_depth, - find_heuristic_step_size=self._find_heuristic_step_size, - model_args=model_args_sub(self._u,model_args), - model_kwargs=model_kwargs, - subsample_method= self.subsample_method, - model=self._model, - ll_ref =self._ll_ref, - jac_all=self._jac_all, - z_ref=self.z_ref, - hess_all = self._hess_all, - ll_u = self._ll_u, - n=self._n, - m=self.m, - u = self._u, - covariate_fn = self.covariate_fn) + num_warmup = num_warmup, + step_size = self._step_size, + adapt_step_size = self._adapt_step_size, + adapt_mass_matrix = self._adapt_mass_matrix, + dense_mass = self._dense_mass, + target_accept_prob = self._target_accept_prob, + trajectory_length=self._trajectory_length, + max_tree_depth=self._max_tree_depth, + find_heuristic_step_size=self._find_heuristic_step_size, + model_args=model_args_sub(self._u,model_args), + model_kwargs=model_kwargs, + subsample_method= self.subsample_method, + model=self._model, + ll_ref =self._ll_ref, + jac_all=self._jac_all, + z_ref=self.z_ref, + hess_all = self._hess_all, + ll_u = self._ll_u, + n=self._n, + m=self.m, + u = self._u, + proxy_fn = self._proxy_fn, + proxy_u_fn = self._proxy_u_fn) if rng_key.ndim ==1: - - init_state = hmc_init_fn(init_params, rng_key) #HMCState + HMCECSState - - self._proxy, self._proxy_u = self.covariate_fn(self._ll_ref, self._jac_all, self._hess_all) + rng_key_hmc_init = jnp.array([1000966916, 171341646]) + init_state = hmc_init_fn(init_params, rng_key_hmc_init) #HMCState + HMCECSState + if self.proxy == "taylor": + self._proxy_fn, self._proxy_u_fn = taylor_proxy(self._ll_ref, self._jac_all, self._hess_all) self._ll_u = potential_est(model=self._model, model_args = model_args_sub(self._u, model_args), model_kwargs=model_kwargs, @@ -723,19 +728,15 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg z_ref=self.z_ref, n=self._n, m=self.m, - proxy=self._proxy, - proxy_u=self._proxy_u, + proxy_fn=self._proxy_fn, + proxy_u_fn=self._proxy_u_fn, u=self._u) - - - hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, - hmc_state=init_state.hmc_state, - z_ref=self.z_ref, - ll_u=self._ll_u) - - init_sub_state = hmc_init_sub_fn(init_params,rng_key) #HMCState - init_sub_state = tuplemerge(init_state._asdict(),init_sub_state._asdict()) + hmc_init_sub_state = HMCECSState(u=self._u, + hmc_state=init_state.hmc_state, + z_ref=self.z_ref, + ll_u=self._ll_u) + init_sub_state = tuplemerge(init_state._asdict(),hmc_init_sub_state._asdict()) return init_sub_state else: #TODO: What is this for? It does not go into it for num_chains>1 @@ -744,7 +745,6 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg # wa_steps because those variables do not depend on traced args: init_params, rng_key. init_state = vmap(hmc_init_fn)(init_params, rng_key) - self._proxy, self._proxy_u = self.covariate_fn(self._ll_ref, self._jac_all, self._hess_all) self._ll_u = potential_est(model=self._model, model_args=model_args_sub(self._u, model_args), model_kwargs=model_kwargs, @@ -752,8 +752,8 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg z_ref=self.z_ref, n=self._n, m=self.m, - proxy=self._proxy, - proxy_u=self._proxy_u, + proxy_fn=self._proxy_fn, + proxy_u_fn=self._proxy_u_fn, u=self._u) hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state, z_ref=self.z_ref, ll_u=self._ll_u) @@ -816,28 +816,24 @@ def sample(self, state, model_args, model_kwargs): state.rng_key, 4) u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) - # estimate likelihood of subsample with single block updated - self._proxy, self._proxy_u = self.covariate_fn(self._ll_ref, self._jac_all, self._hess_all) - llu_new = potential_est(model=self._model, model_args=model_args_sub(u_new,model_args), model_kwargs=model_kwargs, z=state.z, z_ref=self.z_ref, - proxy = self._proxy, - proxy_u = self._proxy_u, + proxy_fn = self._proxy_fn, + proxy_u_fn = self._proxy_u_fn, n=self._n, m=self.m,u=state.u) # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. - accept_prob = jnp.clip(jnp.exp(-llu_new + state.ll_u), a_max=1.) transition = random.bernoulli(rng_key_transition, accept_prob) u, ll_u = cond(transition, (u_new, llu_new), identity, (state.u, state.ll_u), identity) - self._u = u #Just in case , but not necessary - self._ll_u = ll_u + print("ll_u") + print(ll_u) ######## UPDATE PARAMETERS ########## hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state, @@ -849,6 +845,8 @@ def sample(self, state, model_args, model_kwargs): model_args=model_args, model_kwargs=model_kwargs, subsample_method=self.subsample_method, + proxy_fn = self._proxy_fn, + proxy_u_fn = self._proxy_u_fn, model = self._model, ll_ref = self._ll_ref, jac_all =self._jac_all, diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index ddc8d1a99..e72b1d8ea 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -32,12 +32,28 @@ def model_kwargs_sub(u, kwargs): if key_arg == "observations" or key_arg == "features": kwargs[key_arg] = jnp.take(val_arg, u, axis=0) return kwargs - -def log_density_hmcecs(model, model_args, model_kwargs, params,prior=False): +def log_density_obs_hmcecs(model, model_args, model_kwargs, params): + model = substitute(model, data=params) + model_trace = trace(model).get_trace(*model_args, **model_kwargs) + log_joint = jnp.array(0.) + for site in model_trace.values(): + if site['type'] == 'sample' and site['is_observed'] and not isinstance(site['fn'], dist.PRNGIdentity): + value = site['value'] + intermediates = site['intermediates'] + scale = site['scale'] + if intermediates: + log_prob = site['fn'].log_prob(value, intermediates) + else: + log_prob = site['fn'].log_prob(value) + if (scale is not None) and (not is_identically_one(scale)): + log_prob = scale * log_prob + log_joint += log_prob + + return log_prob, model_trace +def log_density_prior_hmcecs(model, model_args, model_kwargs, params): """ (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given - latent values ``params``. If prior == False, the log probability of the prior probability - over the parameters is not computed, solely the log probability of the observations + latent values ``params``. :param model: Python callable containing NumPyro primitives. :param tuple model_args: args provided to the model. @@ -46,42 +62,25 @@ def log_density_hmcecs(model, model_args, model_kwargs, params,prior=False): name. :return: log of joint density and a corresponding model trace """ - model = substitute(model, data=params) model_trace = trace(model).get_trace(*model_args, **model_kwargs) log_joint = jnp.array(0.) - if not prior: - for site in model_trace.values(): - if site['type'] == 'sample' and site['is_observed'] and not isinstance(site['fn'], dist.PRNGIdentity): - value = site['value'] - intermediates = site['intermediates'] - scale = site['scale'] - if intermediates: - log_prob = site['fn'].log_prob(value, intermediates) - else: - log_prob = site['fn'].log_prob(value) #TODO: The shape here is duplicated - - if (scale is not None) and (not is_identically_one(scale)): - log_prob = scale * log_prob - - return log_prob, model_trace - else: - for site in model_trace.values(): - if site['type'] == 'sample' and not isinstance(site['fn'], dist.PRNGIdentity) and not site['is_observed']: #Prior prob - value = site['value'] - intermediates = site['intermediates'] - scale = site['scale'] - if intermediates: - log_prob = site['fn'].log_prob(value, intermediates) - else: - log_prob = site['fn'].log_prob(value) - - if (scale is not None) and (not is_identically_one(scale)): - log_prob = scale * log_prob - - log_prob = jnp.sum(log_prob) - log_joint = log_joint + log_prob - return log_joint, model_trace + for site in model_trace.values(): + if site['type'] == 'sample' and not isinstance(site['fn'], dist.PRNGIdentity) and not site['is_observed']: + value = site['value'] + intermediates = site['intermediates'] + scale = site['scale'] + if intermediates: + log_prob = site['fn'].log_prob(value, intermediates) + else: + log_prob = site['fn'].log_prob(value) + + if (scale is not None) and (not is_identically_one(scale)): + log_prob = scale * log_prob + + log_prob = jnp.sum(log_prob) + log_joint = log_joint + log_prob + return log_joint, model_trace def reducer( accum, d ): @@ -96,21 +95,22 @@ def tuplemerge( *dictionaries ): return namedtuple('HMCCombinedState', merged )(**merged) # <==== Gist of the gist -def potential_est(model, model_args,model_kwargs, z, z_ref, n, m, proxy, proxy_u,u=None): - if any(arg.shape[0] > m for arg in model_args): - model_args = model_args_sub(u,model_args) - ll_sub, _ = log_density_hmcecs(model, model_args, model_kwargs, z,prior=False) # log likelihood for subsample with current theta +def potential_est(model, model_args,model_kwargs, z, z_ref, n, m, proxy_fn, proxy_u_fn,u=None): + #if any(arg.shape[0] > m for arg in model_args): + # model_args = model_args_sub(u,model_args) + ll_sub, _ = log_density_obs_hmcecs(model, model_args, model_kwargs, z) # log likelihood for subsample with current theta - diff = ll_sub - proxy_u(z, z_ref, model, model_args) - l_hat = proxy(z, z_ref) + n / m * jnp.sum(diff) + diff = ll_sub - proxy_u_fn(z, z_ref, model, model_args) + l_hat = proxy_fn(z, z_ref) + n / m * jnp.sum(diff) sigma = n ** 2 / m * jnp.var(diff) - ll_prior, _ = log_density_hmcecs(model, model_args, model_kwargs, z,prior=True) + ll_prior, _ = log_density_prior_hmcecs(model, model_args, model_kwargs, z) return (-l_hat + .5 * sigma) - ll_prior + def velocity_verlet_hmcecs(potential_fn, kinetic_fn, grad_potential_fn=None): r""" Second order symplectic integrator that uses the velocity verlet algorithm @@ -193,7 +193,7 @@ def proxy_u(z, z_ref, model, model_args): zref_flat, _ = ravel_pytree(z_ref) z_diff = z_flat - zref_flat - ld_fn = lambda args: jnp.sum(partial(log_density_hmcecs, model, model_args, {},prior=False)(args)[0]) + ld_fn = lambda args: jnp.sum(partial(log_density_obs_hmcecs, model, model_args, {})(args)[0]) ll_sub, jac_sub = jax.value_and_grad(ld_fn)(z_ref) k, = jac_all.shape diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 283b34d87..6553e7586 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -918,7 +918,7 @@ def sample(self, key, sample_shape=()): @validate_sample def log_prob(self, value): - normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale) + normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale) #TODO:Added jnp.abs value_scaled = (value - self.loc) / self.scale return -0.5 * value_scaled ** 2 - normalize_term From 1fc67401ee098f535c765c955abb7cabd92c0380 Mon Sep 17 00:00:00 2001 From: Lys Date: Wed, 7 Oct 2020 20:31:52 +0200 Subject: [PATCH 19/93] Working examples --- examples/logistic_hmcecs.py | 12 ++++++------ numpyro/contrib/hmcecs.py | 3 +-- numpyro/contrib/hmcecs_utils.py | 2 +- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 767609d2d..5a368bae2 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -94,11 +94,11 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, if subsample_method=="perturb": if map_method == "NUTS": print("Running NUTS for map estimation") - samples,r_hat_average = infer_nuts(map_key, feats, obs,samples=100,warmup=50) + samples,r_hat_average = infer_nuts(map_key, feats[:1000], obs[:1000],samples=500,warmup=250) z_map = {key: value.mean(0) for key, value in samples.items()} if map_method == "HMC": print("Running HMC for map estimation") - samples, r_hat_average = infer_hmc(map_key, feats, obs, samples=100, warmup=50) + samples, r_hat_average = infer_hmc(map_key, feats[:1000], obs[:1000], samples=50, warmup=250) z_map = {key: value.mean(0) for key, value in samples.items()} if map_method == "SVI": @@ -187,10 +187,10 @@ def Folders(folder_name): os.makedirs(newpath,0o777) def Tests(map_method,ecs_algo,algo,n_samples,n_warmup,epochs): - factor = 100 - m = int(np_jax.sqrt(obs[:factor].shape[0])*2) + factor_NUTS = 1000 + m = int(np_jax.sqrt(obs.shape[0])*2) g= 5 - est_posterior_ECS = infer_hmcecs(rng_key, feats=feats[:factor], obs=obs[:factor], + est_posterior_ECS = infer_hmcecs(rng_key, feats=feats, obs=obs, n_samples=n_samples, warmup=n_warmup, m =m,g=g, @@ -198,7 +198,7 @@ def Tests(map_method,ecs_algo,algo,n_samples,n_warmup,epochs): subsample_method="perturb", map_method = map_method, num_epochs=epochs) - est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor], obs=obs[:factor], n_samples=n_samples,warmup=n_warmup,m =m,g=g,algo=algo) + est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor_NUTS], obs=obs[:factor_NUTS], n_samples=n_samples,warmup=n_warmup,m =m,g=g,algo=algo) Plot(est_posterior_ECS,est_posterior_NUTS,ecs_algo,algo) diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 41edf5007..0b04226d0 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -832,8 +832,7 @@ def sample(self, state, model_args, model_kwargs): u, ll_u = cond(transition, (u_new, llu_new), identity, (state.u, state.ll_u), identity) - print("ll_u") - print(ll_u) + ######## UPDATE PARAMETERS ########## hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state, diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index e72b1d8ea..fc6f046be 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -47,7 +47,7 @@ def log_density_obs_hmcecs(model, model_args, model_kwargs, params): log_prob = site['fn'].log_prob(value) if (scale is not None) and (not is_identically_one(scale)): log_prob = scale * log_prob - log_joint += log_prob + log_joint += log_prob #TODO: log_joint += jnp.sum(log_prob) ?---> gives a single number return log_prob, model_trace def log_density_prior_hmcecs(model, model_args, model_kwargs, params): From 3993afdee634193f0430445d8bc18ef437718d2d Mon Sep 17 00:00:00 2001 From: Lys Date: Mon, 19 Oct 2020 20:31:56 +0200 Subject: [PATCH 20/93] Maybe working --- examples/autoguide_hmcecs.py | 23 ++++- examples/logistic_hmcecs.py | 157 +++++++++++++++++++++++--------- examples/logistic_hmcecs_svi.py | 7 +- numpyro/contrib/hmcecs.py | 134 ++++++++++++++------------- numpyro/contrib/hmcecs_utils.py | 38 +++++--- numpyro/distributions/kl.py | 4 + numpyro/distributions/util.py | 2 +- 7 files changed, 243 insertions(+), 122 deletions(-) diff --git a/examples/autoguide_hmcecs.py b/examples/autoguide_hmcecs.py index 518badd91..cf3805ae8 100644 --- a/examples/autoguide_hmcecs.py +++ b/examples/autoguide_hmcecs.py @@ -160,7 +160,8 @@ def __init__(self, model, prefix="auto", init_strategy=init_to_uniform): super(AutoContinuous, self).__init__(model, prefix=prefix) def _setup_prototype(self, *args, **kwargs): - rng_key = numpyro.rng_key("_{}_rng_key_setup".format(self.prefix)) + rng_key = random.PRNGKey(0) + #rng_key = numpyro.rng_key("_{}_rng_key_setup".format(self.prefix)) with handlers.block(): init_params, _, self._postprocess_fn, self.prototype_trace = initialize_model( rng_key, self.model, @@ -187,6 +188,26 @@ def _sample_latent(self, *args, **kwargs): posterior = self._get_posterior() return numpyro.sample("_{}_latent".format(self.prefix), posterior, sample_shape=sample_shape) + def expectation(self, latent): + """Computes the expectation/probabilities of the parameters of the guide. The expectation over the variance over the latent space is bounded + using the reparametrization trick""" + if self.prototype_trace is None: + raise ValueError() # TODO: fix value error + + result = {} + for name, unconstrained_value in latent.items(): + site = self.prototype_trace[name] + transform = biject_to(site['fn'].support) + + value = transform(unconstrained_value) + log_density = - transform.log_abs_det_jacobian(unconstrained_value, value) + event_ndim = len(site['fn'].event_shape) + log_density = sum_rightmost(log_density, + jnp.ndim(log_density) - jnp.ndim(value) + event_ndim) + prob = jnp.exp(log_density) + result[name] = prob * value + return result + def __call__(self, *args, **kwargs): """ An automatic guide with the same ``*args, **kwargs`` as the base ``model``. diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 5a368bae2..3f1bd5670 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -9,6 +9,9 @@ from jax.config import config import datetime,time import argparse +import numpy as np +from numpyro.distributions.kl import kl_divergence +from matplotlib.pyplot import cm sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/contrib/') sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/examples/') @@ -47,6 +50,13 @@ def save_obj(obj, name): import bz2 with bz2.BZ2File(name, "wb") as f: cPickle.dump(obj, f) +def load_obj(name): + import _pickle as cPickle + import bz2 + data = bz2.BZ2File(name, "rb") + data = cPickle.load(data) + + return data def model(feats, obs): """ Logistic regression model @@ -86,38 +96,60 @@ def infer_hmc(rng_key, feats, obs, samples, warmup ): -def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None,algo="NUTS",subsample_method=None,map_method=None,num_epochs=None ): +def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None,algo="NUTS",subsample_method=None,map_method=None,proxy="taylor",num_epochs=None ): hmcecs_key, map_key = jax.random.split(rng_key) n, _ = feats.shape - print("Using {} samples".format(str(n_samples+warmup))) - - if subsample_method=="perturb": + file_hyperparams = open("PLOTS_{}/Hyperparameters_{}.txt".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms"),now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), "a") + if subsample_method=="perturb" and proxy== "taylor": + map_samples = 100 + map_warmup = 50 + factor_NUTS = 1000 if map_method == "NUTS": - print("Running NUTS for map estimation") - samples,r_hat_average = infer_nuts(map_key, feats[:1000], obs[:1000],samples=500,warmup=250) - z_map = {key: value.mean(0) for key, value in samples.items()} + print("Running NUTS for map estimation {} + {} samples".format(map_samples,map_warmup)) + file_hyperparams.write('MAP samples : {} \n'.format(map_samples)) + file_hyperparams.write('MAP warmup : {} \n'.format(map_warmup)) + samples,r_hat_average = infer_nuts(map_key, feats[:factor_NUTS], obs[:factor_NUTS],samples=map_samples,warmup=map_warmup) + z_ref = {key: value.mean(0) for key, value in samples.items()} if map_method == "HMC": print("Running HMC for map estimation") - samples, r_hat_average = infer_hmc(map_key, feats[:1000], obs[:1000], samples=50, warmup=250) - z_map = {key: value.mean(0) for key, value in samples.items()} + file_hyperparams.write('MAP samples : {} \n'.format(map_samples)) + file_hyperparams.write('MAP warmup : {} \n'.format(map_warmup)) + samples, r_hat_average = infer_hmc(map_key, feats[:factor_NUTS], obs[:factor_NUTS], samples=map_samples, warmup=map_warmup) + z_ref = {key: value.mean(0) for key, value in samples.items()} if map_method == "SVI": print("Running SVI for map estimation") - z_map = svi_map(model, map_key, feats=feats, obs=obs,num_epochs=num_epochs,batch_size = m) - z_map = {k[5:]: v for k, v in z_map.items()} #highlight: [5:] is to skip the "auto" part - save_obj(z_map,"{}/MAP_Dict_Samples_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), map_method)) - print("Running MCMC subsampling") + file_hyperparams.write('SVI epochs : {} \n'.format(num_epochs)) + z_ref = svi_map(model, map_key, feats=feats, obs=obs,num_epochs=num_epochs,batch_size = m) + z_ref = {k[5:]: v for k, v in z_ref.items()} #highlight: [5:] is to skip the "auto" part + svi = None + save_obj(z_ref,"{}/MAP_Dict_Samples_MAP_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), map_method)) + print("Running MCMC subsampling with Taylor proxy") + elif subsample_method =="perturb" and proxy=="svi": + factor_SVI = obs.shape[0] + batch_size = int(factor_SVI//10) + print("Running SVI for map estimation") + file_hyperparams.write('SVI epochs : {} \n'.format(num_epochs)) + map_key, post_key = jax.random.split(map_key) + z_ref, svi, svi_state = svi_map(model, map_key, feats=feats[:factor_SVI], obs=obs[:factor_SVI], + num_epochs=num_epochs, batch_size=batch_size) + z_ref = svi.guide.sample_posterior(post_key, svi.get_params(svi_state), (100,)) + z_ref = {name: value.mean(0) for name, value in z_ref.items()} #highlight: AutoDiagonalNormal does not have auto_ in front of the parmeters + + save_obj(z_ref,"{}/MAP_Dict_Samples_Proxy_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), + proxy)) + print("Running MCMC subsampling with SVI proxy") else: - z_map = None + z_ref = None + svi = None start = time.time() - kernel = HMC(model=model,z_ref=z_map,m=m,g=g,algo=algo,subsample_method=subsample_method,target_accept_prob=0.8) + kernel = HMC(model=model,z_ref=z_ref,m=m,g=g,algo=algo,subsample_method=subsample_method,proxy=proxy,svi_fn=svi,target_accept_prob=0.8) mcmc = MCMC(kernel,num_warmup=warmup,num_samples=n_samples,num_chains=1) mcmc.run(rng_key,feats,obs) stop = time.time() - file_hyperparams = open("PLOTS_{}/Hyperparameters_{}.txt".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms"),now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), "a") file_hyperparams.write('MCMC/NUTS elapsed time {}: {} \n'.format(subsample_method,time.time() - start)) file_hyperparams.write('Effective size {}: {}\n'.format(subsample_method,n_samples)) file_hyperparams.write('Warm up size {}: {}\n'.format(subsample_method,warmup)) @@ -127,7 +159,7 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, file_hyperparams.write('...........................................\n') file_hyperparams.close() - save_obj(mcmc.get_samples(),"{}/MCMC_Dict_Samples_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")),subsample_method)) + save_obj(mcmc.get_samples(),"{}/MCMC_Dict_Samples_{}_m_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")),subsample_method,m)) return mcmc.get_samples() @@ -146,26 +178,20 @@ def Determine_best_sample_size(rng_key,feats,obs): plt.ylabel(r"$\hat{r}$") plt.title("Determine best effective sample size for z_map") plt.savefig("{}/Best_effective_size_z_map.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")))) - - -def Plot(samples_ECS,samples_NUTS,ecs_algo,algo): - - +def Plot(samples_ECS,samples_NUTS,ecs_algo,algo,proxy,m,kl=None): for sample in [0,7,15,25]: - plt.figure(sample) - + plt.figure(sample + m +3) #samples = pd.DataFrame.from_records(samples,index="theta") - sns.kdeplot(data=samples_ECS["theta"][sample],color="r",label="ECS-{}".format(ecs_algo)) + sns.kdeplot(data=samples_ECS["theta"][sample],color="r",label="ECS-{}-{} proxy".format(ecs_algo,proxy)) sns.kdeplot(data=samples_NUTS["theta"][sample],color="b",label="{}".format(algo)) - + #if kl != None: + # sns.kdeplot(data=kl, color="g", label="KL; m: {}".format(m)) plt.xlabel(r"$\theta") plt.ylabel("Density") plt.legend() - plt.title(r"$\theta$ {} Density plot".format(sample)) - plt.savefig("{}/KDE_plot_theta_{}.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")),sample)) - - - + plt.title(r"$\theta$ {} m: {} Density plot".format(sample,str(m))) + plt.savefig("{}/KDE_plot_theta_{}_m_{}.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")),sample,str(m))) + plt.clf() def Folders(folder_name): """ Folder for all the generated images It will updated everytime!!! Save the previous folder before running again. Creates folder in current directory""" import os @@ -185,9 +211,29 @@ def Folders(folder_name): else: shutil.rmtree(newpath) # removes all the subdirectories! os.makedirs(newpath,0o777) - -def Tests(map_method,ecs_algo,algo,n_samples,n_warmup,epochs): +def Plot_KL(map_method,ecs_algo,algo,n_samples,n_warmup,epochs): + factor_ECS=obs.shape[0] + m = [int(np_jax.sqrt(obs[:factor_ECS].shape[0])),2*int(np_jax.sqrt(obs[:factor_ECS].shape[0])),4*int(np_jax.sqrt(obs[:factor_ECS].shape[0])),8*int(np_jax.sqrt(obs[:factor_ECS].shape[0]))] + g = 5 factor_NUTS = 1000 + colors = cm.rainbow(np.linspace(0, 1, len(m))) + est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor_NUTS], obs=obs[:factor_NUTS], + n_samples=n_samples, warmup=n_warmup, m="all", g=g, algo=algo) + for m_val, color in zip(m,colors): + est_posterior_ECS = infer_hmcecs(rng_key, feats=feats[:factor_ECS], obs=obs[:factor_ECS], + n_samples=n_samples, + warmup=n_warmup, + m=m_val, g=g, + algo=ecs_algo, + subsample_method="perturb", + map_method=map_method, + num_epochs=epochs) + + p = dist.Normal(est_posterior_ECS["theta"]) + q = dist.Normal(est_posterior_NUTS["theta"]) + kl = kl_divergence(p, q) + Plot(est_posterior_ECS, est_posterior_NUTS, ecs_algo, algo,m_val,kl=kl ) +def Tests(map_method,ecs_algo,algo,n_samples,n_warmup,epochs,proxy): m = int(np_jax.sqrt(obs.shape[0])*2) g= 5 est_posterior_ECS = infer_hmcecs(rng_key, feats=feats, obs=obs, @@ -196,27 +242,31 @@ def Tests(map_method,ecs_algo,algo,n_samples,n_warmup,epochs): m =m,g=g, algo=ecs_algo, subsample_method="perturb", + proxy=proxy, map_method = map_method, num_epochs=epochs) - est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor_NUTS], obs=obs[:factor_NUTS], n_samples=n_samples,warmup=n_warmup,m =m,g=g,algo=algo) + est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats, obs=obs, n_samples=n_samples,warmup=n_warmup,m =m,g=g,algo=algo) - Plot(est_posterior_ECS,est_posterior_NUTS,ecs_algo,algo) + Plot(est_posterior_ECS,est_posterior_NUTS,ecs_algo,algo,proxy,m) if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-num_samples', nargs='?', default=100, type=int) - parser.add_argument('-num_warmup', nargs='?', default=50, type=int) + parser.add_argument('-num_samples', nargs='?', default=50, type=int) + parser.add_argument('-num_warmup', nargs='?', default=5, type=int) parser.add_argument('-ecs_algo', nargs='?', default="NUTS", type=str) + parser.add_argument('-ecs_proxy', nargs='?', default="svi", type=str) parser.add_argument('-algo', nargs='?', default="HMC", type=str) parser.add_argument('-map_init', nargs='?', default="NUTS", type=str) - parser.add_argument("-epochs",default=100,type=int) + parser.add_argument("-epochs",default=2,type=int) args = parser.parse_args() rng_key = jax.random.PRNGKey(37) rng_key, feat_key, obs_key = jax.random.split(rng_key, 3) + if args.ecs_proxy == "svi": + args.map_init = "SVI" now = datetime.datetime.now() Folders("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms"))) @@ -225,10 +275,8 @@ def Tests(map_method,ecs_algo,algo,n_samples,n_warmup,epochs): file_hyperparams.write('ECS algo : {} \n'.format(args.ecs_algo)) file_hyperparams.write('algo : {} \n'.format(args.algo)) file_hyperparams.write('MAP init : {} \n'.format(args.map_init)) - file_hyperparams.write('SVI epochs : {} \n'.format(args.epochs)) - - higgs = True + higgs = False if higgs: feats,obs = higgs_data() file_hyperparams.write('Dataset : HIGGS \n') @@ -241,6 +289,31 @@ def Tests(map_method,ecs_algo,algo,n_samples,n_warmup,epochs): config.update('jax_disable_jit', True) #Determine_best_sample_size(rng_key,feats[:100],obs[:100]) - Tests(args.map_init,args.ecs_algo,args.algo,args.num_samples,args.num_warmup,args.epochs) + Tests(args.map_init,args.ecs_algo,args.algo,args.num_samples,args.num_warmup,args.epochs,args.ecs_proxy) + #Plot_KL(args.map_init,args.ecs_algo,args.algo,args.num_samples,args.num_warmup,args.epochs) + + + exit() + samples_ECS_3316 = load_obj("/home/lys/Dropbox/PhD/numpyro/examples/PLOTS_2020_10_09_11h41min24s333577ms_DONOTREMOVE/MCMC_Dict_Samples_perturb_m_3316.pkl") + samples_ECS_6632 = load_obj("/home/lys/Dropbox/PhD/numpyro/examples/PLOTS_2020_10_09_11h41min24s333577ms_DONOTREMOVE/MCMC_Dict_Samples_perturb_m_6632.pkl") + samples_ECS_132264 = load_obj("/home/lys/Dropbox/PhD/numpyro/examples/PLOTS_2020_10_09_11h41min24s333577ms_DONOTREMOVE/MCMC_Dict_Samples_perturb_m_13264.pkl") + + samples_HMC = load_obj("/home/lys/Dropbox/PhD/numpyro/examples/PLOTS_2020_10_09_11h41min24s333577ms_DONOTREMOVE/MCMC_Dict_Samples_None_m_all.pkl") + + p = dist.Normal(samples_ECS_3316["theta"]) + q = dist.Normal(samples_HMC["theta"]) + kl = kl_divergence(p, q) + print(np_jax.average(kl)) + Plot(samples_ECS_3316, samples_HMC, args.ecs_algo, args.algo, 3316,args.ecs_proxy ,kl=kl) + + # samples = pd.DataFrame.from_records(samples,index="theta") + # sns.kdeplot(data=kl, color=color, label="m : ".format(m_val)) + # plt.figure(m_val) + # plt.xlabel(r"$\theta") + # plt.ylabel("Density") + # plt.legend() + # plt.title(r"$\theta$ KL-divergence") + # plt.savefig("{}/KL_divergence_m_{}.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), + # str(m_val))) diff --git a/examples/logistic_hmcecs_svi.py b/examples/logistic_hmcecs_svi.py index ebe7ad35e..b95f47d85 100644 --- a/examples/logistic_hmcecs_svi.py +++ b/examples/logistic_hmcecs_svi.py @@ -33,9 +33,10 @@ def svi_map(model, rng_key, feats,obs,num_epochs,batch_size): import numpyro numpyro.set_platform("gpu") - from autoguide_hmcecs import AutoDelta + from autoguide_hmcecs import AutoDelta, AutoDiagonalNormal n, _ = feats.shape - guide = AutoDelta(model) + #guide = AutoDelta(model) + guide = AutoDiagonalNormal(model) loss = RenyiELBO(alpha=2, num_particles=1) svi = SVI(model, guide, optim.Adam(0.001), loss=loss) svi_state = svi.init( rng_key,feats,obs) @@ -58,4 +59,4 @@ def body_fn(i, val): t_start = time.time() train_loss, svi_state = epoch_train(svi_state) print("Epoch {}: loss = {} ({:.2f} s.)".format(i, train_loss, time.time() - t_start)) - return svi.get_params(svi_state) \ No newline at end of file + return svi.get_params(svi_state), svi, svi_state \ No newline at end of file diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 0b04226d0..d8fc290db 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -31,7 +31,7 @@ 'mean_accept_prob', 'diverging', 'adapt_state','rng_key']) #HMCECSState = namedtuple("HMCECState",["u","hmc_state","z_ref","ll_ref","jac_all","hess_all","ll_u"]) -HMCECSState = namedtuple("HMCECState",['u', 'hmc_state', 'z_ref', 'll_u']) +HMCECSState = namedtuple("HMCECState",['u', 'hmc_state', 'll_u']) """ A :func:`~collections.namedtuple` consisting of the following fields: @@ -276,8 +276,8 @@ def init_kernel(init_params, else: if subsample_method == "perturb": kwargs = {} if model_kwargs is None else model_kwargs - - pe_fn = potential_fn_gen(model, model_args,model_kwargs, z, z_ref, n, m, proxy_fn, proxy_u_fn,u) + pe_fn = potential_fn_gen(model=model, model_args=model_args, model_kwargs=kwargs, z=z, n=n, m=m, + proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) else: kwargs = {} if model_kwargs is None else model_kwargs pe_fn = potential_fn_gen(*model_args, **kwargs) @@ -318,9 +318,7 @@ def init_kernel(init_params, hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, 0, 0., 0., False, wa_state,rng_key_hmc) - hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state, - z_ref=z_ref, - ll_u=ll_u) + hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state,ll_u=ll_u) hmc_state = tuplemerge(hmc_sub_state._asdict(),hmc_state._asdict()) @@ -340,7 +338,15 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, else: if subsample_method == "perturb": - pe_fn = potential_fn_gen(model, model_args, model_kwargs,vv_state.z, z_ref, n, m, proxy_fn, proxy_u_fn, u) + #pe_fn = potential_fn_gen(model, model_args, model_kwargs,vv_state.z, z_ref, n, m, proxy_fn, proxy_u_fn, u) + pe_fn = potential_fn_gen(model=model, + model_args=model_args, + model_kwargs=model_kwargs, + z=vv_state.z, + n=n, + m=m, + proxy_fn=proxy_fn, + proxy_u_fn=proxy_u_fn) kwargs = {} if model_kwargs is None else model_kwargs else: pe_fn = potential_fn_gen(*model_args, **model_kwargs) @@ -377,8 +383,15 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, pe_fn = potential_fn_gen(*model_args, **model_kwargs) else: if subsample_method == "perturb": - pe_fn = potential_fn_gen(model, model_args, model_kwargs, vv_state.z, z_ref, n, m, proxy_fn, - proxy_u_fn, u) + #pe_fn = potential_fn_gen(model, model_args, model_kwargs, vv_state.z, z_ref, n, m, proxy_fn,proxy_u_fn, u) + pe_fn = potential_fn_gen(model=model, + model_args=model_args, + model_kwargs=model_kwargs, + z=vv_state.z, + n=n, + m=m, + proxy_fn=proxy_fn, + proxy_u_fn=proxy_u_fn) else: pe_fn = potential_fn_gen(*model_args, **model_kwargs) _, vv_update = velocity_verlet(pe_fn, kinetic_fn) @@ -401,7 +414,7 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, subsample_method=None,proxy_fn=None,proxy_u_fn=None, model=None,ll_ref=None,jac_all=None, z=None,z_ref=None,hess_all=None,ll_u=None, - u=None,n=None,m=None,): + u=None,n=None,m=None,): #TODO: Remove so many args """ Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed (possibly adapted) step size and return a new :data:`~numpyro.infer.mcmc.HMCState`. @@ -444,7 +457,6 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, hmcstate = HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps, accept_prob, mean_accept_prob, diverging, adapt_state,rng_key) hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state, - z_ref=z_ref, ll_u=ll_u) hmcstate = tuplemerge(hmc_sub_state._asdict(),hmcstate._asdict()) return hmcstate @@ -521,6 +533,8 @@ def __init__(self, init_strategy=init_to_uniform, find_heuristic_step_size=False, subsample_method = None, + proxy="taylor", + svi_fn=None, m= None, g = None, z_ref= None, @@ -559,47 +573,43 @@ def __init__(self, self._postprocess_fn = None self._sample_fn = None self._subsample_fn = None - self.proxy = "taylor" + self.proxy = proxy + self.svi_fn = svi_fn self._proxy_fn = None self._proxy_u_fn = None def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ref): - "Compute the jacobian, hessian and log likelihood for all the data" + "Compute the jacobian, hessian and log likelihood for all the data. Used with taylor expansion proxy" rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key = random.split(rng_key, 5) - self._n = model_args[0].shape[0] - self._u = random.randint(rng_key, (self.m,), 0, self._n) - if self.proxy == "taylor": - ld_fn = lambda args: jnp.sum(partial(log_density_obs_hmcecs, self._model, model_args, model_kwargs)(args)[0]) - self._jac_all, _ = ravel_pytree(jacfwd(ld_fn)(z_ref)) - hess_all, _ = ravel_pytree(hessian(ld_fn)(z_ref)) - k, = self._jac_all.shape - self._hess_all = hess_all.reshape((k, k)) - ld_fn = lambda args: partial(log_density_obs_hmcecs,self._model,model_args,model_kwargs)(args)[0] - self._ll_ref = ld_fn(z_ref) + ld_fn = lambda args: jnp.sum(partial(log_density_obs_hmcecs, self._model, model_args, model_kwargs)(args)[0]) + self._jac_all, _ = ravel_pytree(jacfwd(ld_fn)(z_ref)) + hess_all, _ = ravel_pytree(hessian(ld_fn)(z_ref)) + k, = self._jac_all.shape + self._hess_all = hess_all.reshape((k, k)) + ld_fn = lambda args: partial(log_density_obs_hmcecs,self._model,model_args,model_kwargs)(args)[0] + self._ll_ref = ld_fn(z_ref) def _init_state(self, rng_key, model_args, model_kwargs, init_params): if self.subsample_method is not None: assert self.z_ref is not None, "Please provide a (i.e map) estimate for the parameters" - - self._init_subsample_state(rng_key, model_args, model_kwargs, init_params,self.z_ref) - + self._n = model_args[0].shape[0] + self._u = random.randint(rng_key, (self.m,), 0, self._n) # Choose the covariate calculation method if self.proxy == "svi": - self.covariate_fn = lambda ll_ref, jac_all, hess_all:svi_proxy(ll_ref, jac_all, hess_all) + self._proxy_fn,self._proxy_u_fn = svi_proxy(self.svi_fn,model_args,model_kwargs) + elif self.proxy == "neural": - self.covariate_fn = lambda ll_ref, jac_all, hess_all:neural_proxy(ll_ref, jac_all, hess_all) - else: + raise ValueError("Not implemented") + elif self.proxy == "taylor": warnings.warn("Using default second order Taylor expansion, change by using the proxy flag to {svi,neural}") - self._proxy_fn, self._proxy_u_fn = taylor_proxy(self._ll_ref, self._jac_all, self._hess_all) - + self._init_subsample_state(rng_key, model_args, model_kwargs, init_params, self.z_ref) + self._proxy_fn,self._proxy_u_fn = taylor_proxy(self.z_ref, self._model, self._ll_ref, self._jac_all, self._hess_all) # Initialize the potential and gradient potential functions - self._potential_fn = lambda model, model_args, model_kwargs,z, z_ref, n, m, proxy_fn, proxy_u_fn,u : lambda z:potential_est(model=model, - model_args=model_args,model_kwargs=model_kwargs,z=z,z_ref=z_ref,n=n,m = m,proxy_fn=proxy_fn,proxy_u_fn=proxy_u_fn,u=u) - - + self._potential_fn = lambda model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn : lambda z:potential_est(model=model, + model_args=model_args, model_kwargs=model_kwargs, z=z, n=n, m=m, proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) # Initialize the hmc sampler: sample_fn = sample_kernel self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, kinetic_fn=euclidean_kinetic_energy, @@ -618,7 +628,6 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): model_args=model_args_sub(self._u, model_args), model_kwargs=model_kwargs) - if (self.g > self.m) or (self.g < 1): raise ValueError( 'Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(self.g, @@ -717,24 +726,25 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg proxy_u_fn = self._proxy_u_fn) if rng_key.ndim ==1: - rng_key_hmc_init = jnp.array([1000966916, 171341646]) + #rng_key_hmc_init = jnp.array([1000966916, 171341646]) + rng_key_hmc_init,_ = random.split(rng_key) + init_state = hmc_init_fn(init_params, rng_key_hmc_init) #HMCState + HMCECSState if self.proxy == "taylor": - self._proxy_fn, self._proxy_u_fn = taylor_proxy(self._ll_ref, self._jac_all, self._hess_all) + self._proxy_fn,self._proxy_u_fn = taylor_proxy(self.z_ref, self._model, self._ll_ref, self._jac_all, self._hess_all) + elif self.proxy == "svi": + self._proxy_fn, self._proxy_u_fn = svi_proxy(self.svi_fn, model_args, model_kwargs) + self._ll_u = potential_est(model=self._model, - model_args = model_args_sub(self._u, model_args), + model_args=model_args_sub(self._u, model_args), model_kwargs=model_kwargs, z=init_state.z, - z_ref=self.z_ref, n=self._n, m=self.m, proxy_fn=self._proxy_fn, - proxy_u_fn=self._proxy_u_fn, - u=self._u) - + proxy_u_fn=self._proxy_u_fn) hmc_init_sub_state = HMCECSState(u=self._u, hmc_state=init_state.hmc_state, - z_ref=self.z_ref, ll_u=self._ll_u) init_sub_state = tuplemerge(init_state._asdict(),hmc_init_sub_state._asdict()) @@ -744,27 +754,22 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, # wa_steps because those variables do not depend on traced args: init_params, rng_key. init_state = vmap(hmc_init_fn)(init_params, rng_key) - self._ll_u = potential_est(model=self._model, model_args=model_args_sub(self._u, model_args), model_kwargs=model_kwargs, z=init_state.z, - z_ref=self.z_ref, n=self._n, m=self.m, proxy_fn=self._proxy_fn, - proxy_u_fn=self._proxy_u_fn, - u=self._u) + proxy_u_fn=self._proxy_u_fn) - hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state, z_ref=self.z_ref, ll_u=self._ll_u) + hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state, ll_u=self._ll_u) init_subsample_state = vmap(hmc_init_sub_fn)(init_params,rng_key) sample_fn = vmap(self._sample_fn, in_axes=(0, None, None)) - subsample_fn = vmap(self._subsample_fn, in_axes=(0,None,None)) HMCCombinedState = tuplemerge(init_state._asdict,init_subsample_state._asdict()) self._sample_fn = sample_fn - self._subsample_fn = subsample_fn return HMCCombinedState else: @@ -817,14 +822,22 @@ def sample(self, state, model_args, model_kwargs): u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) # estimate likelihood of subsample with single block updated + # llu_new = potential_est(model=self._model, + # model_args=model_args_sub(u_new,model_args), + # model_kwargs=model_kwargs, + # z=state.z, + # z_ref=self.z_ref, + # proxy_fn = self._proxy_fn, + # proxy_u_fn = self._proxy_u_fn, + # n=self._n, m=self.m,u=state.u) llu_new = potential_est(model=self._model, - model_args=model_args_sub(u_new,model_args), - model_kwargs=model_kwargs, - z=state.z, - z_ref=self.z_ref, - proxy_fn = self._proxy_fn, - proxy_u_fn = self._proxy_u_fn, - n=self._n, m=self.m,u=state.u) + model_args=model_args_sub(u_new,model_args), + model_kwargs=model_kwargs, + z=state.z, + n=self._n, + m=self.m, + proxy_fn=self._proxy_fn, + proxy_u_fn=self._proxy_u_fn) # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. accept_prob = jnp.clip(jnp.exp(-llu_new + state.ll_u), a_max=1.) @@ -835,8 +848,7 @@ def sample(self, state, model_args, model_kwargs): ######## UPDATE PARAMETERS ########## - hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state, - ll_u=ll_u,z_ref=self.z_ref) + hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state,ll_u=ll_u) hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(),state._asdict()) @@ -850,7 +862,7 @@ def sample(self, state, model_args, model_kwargs): ll_ref = self._ll_ref, jac_all =self._jac_all, z= state.z, - z_ref = self.z_ref, + z_ref = self.z_ref, #TODO: Not necessary , remove(z_ref, hess_all, jac_all,ll_ref) hess_all = self._hess_all, ll_u = ll_u, u= u, diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index fc6f046be..c3b0c6c3d 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -47,9 +47,10 @@ def log_density_obs_hmcecs(model, model_args, model_kwargs, params): log_prob = site['fn'].log_prob(value) if (scale is not None) and (not is_identically_one(scale)): log_prob = scale * log_prob - log_joint += log_prob #TODO: log_joint += jnp.sum(log_prob) ?---> gives a single number + #log_joint += log_prob #TODO: log_joint += jnp.sum(log_prob) ?---> gives a single number + log_joint = log_joint + jnp.sum(log_prob) - return log_prob, model_trace + return log_joint, model_trace def log_density_prior_hmcecs(model, model_args, model_kwargs, params): """ (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given @@ -95,13 +96,11 @@ def tuplemerge( *dictionaries ): return namedtuple('HMCCombinedState', merged )(**merged) # <==== Gist of the gist -def potential_est(model, model_args,model_kwargs, z, z_ref, n, m, proxy_fn, proxy_u_fn,u=None): - #if any(arg.shape[0] > m for arg in model_args): - # model_args = model_args_sub(u,model_args) - ll_sub, _ = log_density_obs_hmcecs(model, model_args, model_kwargs, z) # log likelihood for subsample with current theta +def potential_est(model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn): + ll_sub, _ = log_density_obs_hmcecs(model, model_args, {}, z) # log likelihood for subsample with current theta - diff = ll_sub - proxy_u_fn(z, z_ref, model, model_args) - l_hat = proxy_fn(z, z_ref) + n / m * jnp.sum(diff) + diff = ll_sub - proxy_u_fn(z=z, model_args=model_args, model_kwargs=model_kwargs) + l_hat = proxy_fn(z) + n / m * diff sigma = n ** 2 / m * jnp.var(diff) @@ -181,19 +180,19 @@ def init_near_values(site=None, values={}): except: return init_to_uniform(site) -def taylor_proxy(ll_ref, jac_all, hess_all): - def proxy(z, z_ref): +def taylor_proxy(z_ref, model, ll_ref, jac_all, hess_all): + def proxy(z, *args, **kwargs): z_flat, _ = ravel_pytree(z) zref_flat, _ = ravel_pytree(z_ref) z_diff = z_flat - zref_flat return jnp.sum(ll_ref) + jac_all.T @ z_diff + .5 * z_diff.T @ hess_all @ z_diff - def proxy_u(z, z_ref, model, model_args): + def proxy_u(z, model_args, model_kwargs, *args, **kwargs): z_flat, _ = ravel_pytree(z) zref_flat, _ = ravel_pytree(z_ref) z_diff = z_flat - zref_flat - ld_fn = lambda args: jnp.sum(partial(log_density_obs_hmcecs, model, model_args, {})(args)[0]) + ld_fn = lambda args: jnp.sum(partial(log_density_obs_hmcecs, model, model_args, model_kwargs)(args)[0]) ll_sub, jac_sub = jax.value_and_grad(ld_fn)(z_ref) k, = jac_all.shape @@ -204,8 +203,19 @@ def proxy_u(z, z_ref, model, model_args): return proxy, proxy_u -def svi_proxy(): - return None + +def svi_proxy(svi, model_args, model_kwargs): + def proxy(z, *args, **kwargs): + z_ref = svi.guide.expectation(z) + ll, _ = log_density_obs_hmcecs(svi.model, model_args, model_kwargs, z_ref) + return ll + + def proxy_u(z, model_args, model_kwargs, *args, **kwargs): + z_ref = svi.guide.expectation(z) + ll, _ = log_density_prior_hmcecs(svi.model, model_args, model_kwargs, z_ref) + return ll + + return proxy, proxy_u def neural_proxy(): return None diff --git a/numpyro/distributions/kl.py b/numpyro/distributions/kl.py index d54d361d6..e888ca42d 100644 --- a/numpyro/distributions/kl.py +++ b/numpyro/distributions/kl.py @@ -177,4 +177,8 @@ def _kl_masked_masked(p, q): def _kl_normal_normal(p, q): var_ratio = jnp.square(p.scale / q.scale) t1 = jnp.square((p.loc - q.loc) / q.scale) +<<<<<<< HEAD return 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio)) +======= + return 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio)) +>>>>>>> Maybe working diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index fb7d4f2b4..65e6e1741 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -535,4 +535,4 @@ def wrapper(self, *args, **kwargs): log_prob = jnp.where(mask, log_prob, -jnp.inf) return log_prob - return wrapper + return wrapper \ No newline at end of file From 9dd7025feeb5e8800033756b14e2796feb00454f Mon Sep 17 00:00:00 2001 From: Lys Date: Mon, 2 Nov 2020 15:29:58 +0100 Subject: [PATCH 21/93] Started adding Block-Poisson --- examples/logistic_hmcecs.py | 44 ++++++--- examples/logistic_hmcecs_svi.py | 7 +- numpyro/contrib/hmcecs.py | 167 +++++++++++++++++++++----------- numpyro/contrib/hmcecs_utils.py | 37 +++++++ 4 files changed, 178 insertions(+), 77 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 3f1bd5670..d8906b0c7 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -101,9 +101,9 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, n, _ = feats.shape file_hyperparams = open("PLOTS_{}/Hyperparameters_{}.txt".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms"),now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), "a") if subsample_method=="perturb" and proxy== "taylor": - map_samples = 100 - map_warmup = 50 - factor_NUTS = 1000 + map_samples = 5 + map_warmup = 20 + factor_NUTS = 100 if map_method == "NUTS": print("Running NUTS for map estimation {} + {} samples".format(map_samples,map_warmup)) file_hyperparams.write('MAP samples : {} \n'.format(map_samples)) @@ -116,7 +116,6 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, file_hyperparams.write('MAP warmup : {} \n'.format(map_warmup)) samples, r_hat_average = infer_hmc(map_key, feats[:factor_NUTS], obs[:factor_NUTS], samples=map_samples, warmup=map_warmup) z_ref = {key: value.mean(0) for key, value in samples.items()} - if map_method == "SVI": print("Running SVI for map estimation") file_hyperparams.write('SVI epochs : {} \n'.format(num_epochs)) @@ -127,8 +126,8 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, print("Running MCMC subsampling with Taylor proxy") elif subsample_method =="perturb" and proxy=="svi": factor_SVI = obs.shape[0] - batch_size = int(factor_SVI//10) - print("Running SVI for map estimation") + batch_size = 32 #int(factor_SVI//10) + print("Running SVI for map estimation with svi proxy") file_hyperparams.write('SVI epochs : {} \n'.format(num_epochs)) map_key, post_key = jax.random.split(map_key) z_ref, svi, svi_state = svi_map(model, map_key, feats=feats[:factor_SVI], obs=obs[:factor_SVI], @@ -148,6 +147,7 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, kernel = HMC(model=model,z_ref=z_ref,m=m,g=g,algo=algo,subsample_method=subsample_method,proxy=proxy,svi_fn=svi,target_accept_prob=0.8) mcmc = MCMC(kernel,num_warmup=warmup,num_samples=n_samples,num_chains=1) + print(feats.shape) mcmc.run(rng_key,feats,obs) stop = time.time() file_hyperparams.write('MCMC/NUTS elapsed time {}: {} \n'.format(subsample_method,time.time() - start)) @@ -211,12 +211,13 @@ def Folders(folder_name): else: shutil.rmtree(newpath) # removes all the subdirectories! os.makedirs(newpath,0o777) -def Plot_KL(map_method,ecs_algo,algo,n_samples,n_warmup,epochs): - factor_ECS=obs.shape[0] +def Plot_KL(map_method,ecs_algo,algo,proxy,n_samples,n_warmup,epochs): + factor_ECS= 1000 #obs.shape[0] m = [int(np_jax.sqrt(obs[:factor_ECS].shape[0])),2*int(np_jax.sqrt(obs[:factor_ECS].shape[0])),4*int(np_jax.sqrt(obs[:factor_ECS].shape[0])),8*int(np_jax.sqrt(obs[:factor_ECS].shape[0]))] g = 5 - factor_NUTS = 1000 + factor_NUTS = 100 colors = cm.rainbow(np.linspace(0, 1, len(m))) + print("Running standard NUTS") est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor_NUTS], obs=obs[:factor_NUTS], n_samples=n_samples, warmup=n_warmup, m="all", g=g, algo=algo) for m_val, color in zip(m,colors): @@ -226,13 +227,23 @@ def Plot_KL(map_method,ecs_algo,algo,n_samples,n_warmup,epochs): m=m_val, g=g, algo=ecs_algo, subsample_method="perturb", + proxy=proxy, map_method=map_method, num_epochs=epochs) p = dist.Normal(est_posterior_ECS["theta"]) q = dist.Normal(est_posterior_NUTS["theta"]) kl = kl_divergence(p, q) - Plot(est_posterior_ECS, est_posterior_NUTS, ecs_algo, algo,m_val,kl=kl ) + + Plot(samples_ECS=est_posterior_ECS, + samples_NUTS=est_posterior_NUTS, + ecs_algo= ecs_algo, + algo=algo, + proxy= proxy, + m = m_val, + kl=kl) + exit() + def Tests(map_method,ecs_algo,algo,n_samples,n_warmup,epochs,proxy): m = int(np_jax.sqrt(obs.shape[0])*2) g= 5 @@ -252,13 +263,13 @@ def Tests(map_method,ecs_algo,algo,n_samples,n_warmup,epochs,proxy): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-num_samples', nargs='?', default=50, type=int) - parser.add_argument('-num_warmup', nargs='?', default=5, type=int) + parser.add_argument('-num_samples', nargs='?', default=5,type=int) + parser.add_argument('-num_warmup', nargs='?', default=20, type=int) parser.add_argument('-ecs_algo', nargs='?', default="NUTS", type=str) - parser.add_argument('-ecs_proxy', nargs='?', default="svi", type=str) + parser.add_argument('-ecs_proxy', nargs='?', default="taylor", type=str) parser.add_argument('-algo', nargs='?', default="HMC", type=str) parser.add_argument('-map_init', nargs='?', default="NUTS", type=str) - parser.add_argument("-epochs",default=2,type=int) + parser.add_argument("-epochs",default=10,type=int) args = parser.parse_args() @@ -274,6 +285,7 @@ def Tests(map_method,ecs_algo,algo,n_samples,n_warmup,epochs,proxy): now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), "a") file_hyperparams.write('ECS algo : {} \n'.format(args.ecs_algo)) file_hyperparams.write('algo : {} \n'.format(args.algo)) + file_hyperparams.write('ECS proxy : {} \n'.format(args.ecs_proxy)) file_hyperparams.write('MAP init : {} \n'.format(args.map_init)) higgs = False @@ -289,8 +301,8 @@ def Tests(map_method,ecs_algo,algo,n_samples,n_warmup,epochs,proxy): config.update('jax_disable_jit', True) #Determine_best_sample_size(rng_key,feats[:100],obs[:100]) - Tests(args.map_init,args.ecs_algo,args.algo,args.num_samples,args.num_warmup,args.epochs,args.ecs_proxy) - #Plot_KL(args.map_init,args.ecs_algo,args.algo,args.num_samples,args.num_warmup,args.epochs) + #Tests(args.map_init,args.ecs_algo,args.algo,args.num_samples,args.num_warmup,args.epochs,args.ecs_proxy) + Plot_KL(args.map_init,args.ecs_algo,args.algo,args.ecs_proxy,args.num_samples,args.num_warmup,args.epochs) exit() diff --git a/examples/logistic_hmcecs_svi.py b/examples/logistic_hmcecs_svi.py index b95f47d85..3e67ab226 100644 --- a/examples/logistic_hmcecs_svi.py +++ b/examples/logistic_hmcecs_svi.py @@ -26,7 +26,7 @@ def svi_map(model, rng_key, feats,obs,num_epochs,batch_size): """ from jax import random, jit from numpyro import optim - from numpyro.infer.elbo import RenyiELBO + from numpyro.infer.elbo import RenyiELBO, ELBO from numpyro.infer.svi import SVI from numpyro.util import fori_loop import time @@ -37,8 +37,9 @@ def svi_map(model, rng_key, feats,obs,num_epochs,batch_size): n, _ = feats.shape #guide = AutoDelta(model) guide = AutoDiagonalNormal(model) - loss = RenyiELBO(alpha=2, num_particles=1) - svi = SVI(model, guide, optim.Adam(0.001), loss=loss) + #loss = RenyiELBO(alpha=2, num_particles=1) + loss = ELBO() + svi = SVI(model, guide, optim.Adam(0.0003), loss=loss) svi_state = svi.init( rng_key,feats,obs) train_init, train_fetch = load_dataset(obs,feats, batch_size=batch_size) num_train, train_idx = train_init() diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index d8fc290db..4c653b899 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -23,9 +23,9 @@ from numpyro.util import cond, fori_loop, identity import sys sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/contrib/') - +import numpyro.distributions as dist from hmcecs_utils import potential_est, init_near_values,tuplemerge,\ - model_args_sub,model_kwargs_sub,taylor_proxy,svi_proxy,neural_proxy,log_density_obs_hmcecs,log_density_prior_hmcecs + model_args_sub,model_kwargs_sub,taylor_proxy,svi_proxy,neural_proxy,log_density_obs_hmcecs,log_density_prior_hmcecs,signed_estimator HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', 'mean_accept_prob', 'diverging', 'adapt_state','rng_key']) @@ -94,7 +94,7 @@ def _update_block(rng_key, u, n, m, g): """Returns the indexes from the subsample that will be updated, there is replacement. The number of indexes to be updated depend on the block size, higher block size more correlation among elements in the subsample. :param rng_key - :param u subsample + :param u subsample indexes :param n total number of data :param m subsample size :param g block size: subsample subdivision""" @@ -102,20 +102,48 @@ def _update_block(rng_key, u, n, m, g): if (g > m) or (g < 1): raise ValueError('Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(g,m)) rng_key_block, rng_key_index = random.split(rng_key) - # uniformly choose block to update chosen_block = random.randint(rng_key, shape=(), minval= 0, maxval=g + 1) - - idxs_new = random.randint(rng_key_index, shape=(m // g,), minval=0, maxval=n) #chose block within the subsample to update + idxs_new = random.randint(rng_key_index, shape=(m // g,), minval=0, maxval=n) #choose block within the subsample to update u_new = jnp.zeros(m, jnp.dtype(u)) #empty array with size m for i in range(m): #if index in the subsample // g = chosen block : pick new indexes from the subsample size #else not update: keep the same indexes u_new = ops.index_add(u_new, i, lax.cond(i // g == chosen_block, i, lambda _: idxs_new[i % (m // g)], i, lambda _: u[i])) - return u_new +def _sample_u_poisson(rng_key, m, l): + """ Initialize subsamples u + :param m: subsample size + :param l: length of the current subsample block + :param g: number of blocks + """ + pois_key, sub_key = random.split(rng_key) + block_lengths = dist.discrete.Poisson(1).sample(pois_key, (l,)) + u = random.randint(sub_key, (jnp.sum(block_lengths), m), 0, m) + return jnp.split(u, jnp.cumsum(block_lengths), axis=0) + +@partial(jit, static_argnums=(2, 3, 4)) +def _update_block_poisson(rng_key, u, m, l, g): + """ Update block of u + :param rng_key + :param u: current subsample indexes + :param m: + :param l: + :param g: + """ + if (g > m) or (g < 1): + raise ValueError('Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(g,m)) + u = u.copy() + block_key, sample_key = random.split(rng_key) + num_updates = int(round(l / g, 0)) + chosen_blocks = random.randint(block_key, (num_updates,), 0, l) + new_blocks = _sample_u_poisson(sample_key, m, num_updates) + for i, block in enumerate(chosen_blocks): + u[block] = new_blocks[i] + return u + def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, grad_potential_fn_gen=None,algo='NUTS'): r""" @@ -212,10 +240,10 @@ def init_kernel(init_params, model_args=(), model_kwargs=None, model = None, - ll_ref=None, - jac_all=None, - z_ref= None, - hess_all=None, + #ll_ref=None, + #jac_all=None, + #z_ref= None, + #hess_all=None, ll_u = None, n = None, m = None, @@ -327,9 +355,18 @@ def init_kernel(init_params, def _hmc_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key,subsample_method, - proxy_fn = None, proxy_u_fn = None, - model = None, ll_ref = None, jac_all = None, z = None, z_ref = None, hess_all = None, - ll_u = None, u = None, n = None, m = None): + proxy_fn = None, + proxy_u_fn = None, + model = None, + #ll_ref = None, + #jac_all = None, + #z = None, + #z_ref = None, + #hess_all = None, + #ll_u = None, + u = None, + n = None, + m = None): if potential_fn_gen: if grad_potential_fn_gen: kwargs = {} if model_kwargs is None else model_kwargs @@ -374,7 +411,9 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, def _nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key,subsample_method, proxy_fn=None,proxy_u_fn=None, - model=None,ll_ref=None,jac_all=None,z = None,z_ref=None,hess_all=None,ll_u=None,u=None,n=None,m=None): + model=None, + #ll_ref=None,jac_all=None,z = None,z_ref=None,hess_all=None,ll_u=None,u=None, + n=None,m=None): if potential_fn_gen: nonlocal vv_update if grad_potential_fn_gen: @@ -411,9 +450,16 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, _next = _nuts_next if algo == 'NUTS' else _hmc_next def sample_kernel(hmc_state,model_args=(),model_kwargs=None, - subsample_method=None,proxy_fn=None,proxy_u_fn=None, - model=None,ll_ref=None,jac_all=None, - z=None,z_ref=None,hess_all=None,ll_u=None, + subsample_method=None, + proxy_fn=None, + proxy_u_fn=None, + model=None, + #ll_ref=None, + #jac_all=None, + #z=None, + #z_ref=None, + #hess_all=None, + ll_u=None, u=None,n=None,m=None,): #TODO: Remove so many args """ Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed (possibly adapted) @@ -443,7 +489,9 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, subsample_method, proxy_fn, proxy_u_fn, - model,ll_ref,jac_all,z,z_ref,hess_all,ll_u,u,n,m) + model, + #ll_ref,jac_all,z,z_ref,hess_all,ll_u,u, + n,m) # not update adapt_state after warmup phase adapt_state = cond(hmc_state.i < wa_steps, (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state), @@ -456,8 +504,7 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n hmcstate = HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps, accept_prob, mean_accept_prob, diverging, adapt_state,rng_key) - hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state, - ll_u=ll_u) + hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state,ll_u=ll_u) hmcstate = tuplemerge(hmc_sub_state._asdict(),hmcstate._asdict()) return hmcstate @@ -534,6 +581,7 @@ def __init__(self, find_heuristic_step_size=False, subsample_method = None, proxy="taylor", + estimator =None,#poisson or not svi_fn=None, m= None, g = None, @@ -568,6 +616,8 @@ def __init__(self, self._hess_all = None self._ll_u = None self._u = None + self._neg_ll = None + self._sign = None # Set on first call to init self._init_fn = None self._postprocess_fn = None @@ -577,6 +627,7 @@ def __init__(self, self.svi_fn = svi_fn self._proxy_fn = None self._proxy_u_fn = None + self.estimator = None def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ref): "Compute the jacobian, hessian and log likelihood for all the data. Used with taylor expansion proxy" @@ -599,11 +650,8 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): # Choose the covariate calculation method if self.proxy == "svi": self._proxy_fn,self._proxy_u_fn = svi_proxy(self.svi_fn,model_args,model_kwargs) - - elif self.proxy == "neural": - raise ValueError("Not implemented") elif self.proxy == "taylor": - warnings.warn("Using default second order Taylor expansion, change by using the proxy flag to {svi,neural}") + warnings.warn("Using default second order Taylor expansion, change by using the proxy flag to {svi}") self._init_subsample_state(rng_key, model_args, model_kwargs, init_params, self.z_ref) self._proxy_fn,self._proxy_u_fn = taylor_proxy(self.z_ref, self._model, self._ll_ref, self._jac_all, self._hess_all) # Initialize the potential and gradient potential functions @@ -635,6 +683,7 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): elif (self.m > self._n): raise ValueError( 'Subsample size (m) = {} needs to = or < than data size (n) {}'.format(self.m, self._n)) + else: if self._model is not None: init_params, potential_fn, postprocess_fn, model_trace = initialize_model( @@ -699,7 +748,6 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg raise ValueError('Valid value of `init_params` must be provided with' ' `potential_fn`.') if self.subsample_method == "perturb": - hmc_init_fn = lambda init_params,rng_key: self._init_fn(init_params=init_params, num_warmup = num_warmup, step_size = self._step_size, @@ -714,10 +762,10 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg model_kwargs=model_kwargs, subsample_method= self.subsample_method, model=self._model, - ll_ref =self._ll_ref, - jac_all=self._jac_all, - z_ref=self.z_ref, - hess_all = self._hess_all, + #ll_ref =self._ll_ref, + #jac_all=self._jac_all, + #z_ref=self.z_ref, + #hess_all = self._hess_all, ll_u = self._ll_u, n=self._n, m=self.m, @@ -734,22 +782,26 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg self._proxy_fn,self._proxy_u_fn = taylor_proxy(self.z_ref, self._model, self._ll_ref, self._jac_all, self._hess_all) elif self.proxy == "svi": self._proxy_fn, self._proxy_u_fn = svi_proxy(self.svi_fn, model_args, model_kwargs) - - self._ll_u = potential_est(model=self._model, - model_args=model_args_sub(self._u, model_args), - model_kwargs=model_kwargs, - z=init_state.z, - n=self._n, - m=self.m, - proxy_fn=self._proxy_fn, - proxy_u_fn=self._proxy_u_fn) - hmc_init_sub_state = HMCECSState(u=self._u, - hmc_state=init_state.hmc_state, - ll_u=self._ll_u) - init_sub_state = tuplemerge(init_state._asdict(),hmc_init_sub_state._asdict()) - - return init_sub_state + if self.estimator == "poisson": + print("Poisson, working on it") + self._neg_ll, self._sign = signed_estimator(self._model, model_args, model_kwargs, init_state.z, self.a, self.l, self._proxy_fn, self._proxy_u_fn) + else: + self._ll_u = potential_est(model=self._model, + model_args=model_args_sub(self._u, model_args), + model_kwargs=model_kwargs, + z=init_state.z, + n=self._n, + m=self.m, + proxy_fn=self._proxy_fn, + proxy_u_fn=self._proxy_u_fn) + hmc_init_sub_state = HMCECSState(u=self._u, + hmc_state=init_state.hmc_state, + ll_u=self._ll_u) + init_sub_state = tuplemerge(init_state._asdict(),hmc_init_sub_state._asdict()) + + return init_sub_state else: #TODO: What is this for? It does not go into it for num_chains>1 + raise ValueError("Not implemented for n_chains > 1") # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, # wa_steps because those variables do not depend on traced args: init_params, rng_key. @@ -822,14 +874,7 @@ def sample(self, state, model_args, model_kwargs): u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) # estimate likelihood of subsample with single block updated - # llu_new = potential_est(model=self._model, - # model_args=model_args_sub(u_new,model_args), - # model_kwargs=model_kwargs, - # z=state.z, - # z_ref=self.z_ref, - # proxy_fn = self._proxy_fn, - # proxy_u_fn = self._proxy_u_fn, - # n=self._n, m=self.m,u=state.u) + llu_new = potential_est(model=self._model, model_args=model_args_sub(u_new,model_args), model_kwargs=model_kwargs, @@ -845,6 +890,7 @@ def sample(self, state, model_args, model_kwargs): u, ll_u = cond(transition, (u_new, llu_new), identity, (state.u, state.ll_u), identity) + print(u) ######## UPDATE PARAMETERS ########## @@ -859,16 +905,21 @@ def sample(self, state, model_args, model_kwargs): proxy_fn = self._proxy_fn, proxy_u_fn = self._proxy_u_fn, model = self._model, - ll_ref = self._ll_ref, - jac_all =self._jac_all, - z= state.z, - z_ref = self.z_ref, #TODO: Not necessary , remove(z_ref, hess_all, jac_all,ll_ref) - hess_all = self._hess_all, + #ll_ref = self._ll_ref, + #jac_all =self._jac_all, + #z= state.z, + #z_ref = self.z_ref, #TODO: Not necessary , remove(z_ref, hess_all, jac_all,ll_ref) + #hess_all = self._hess_all, ll_u = ll_u, u= u, n= self._n, m= self.m) + + + + + else: return self._sample_fn(state, model_args, model_kwargs) diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index c3b0c6c3d..6988d7c81 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -181,6 +181,12 @@ def init_near_values(site=None, values={}): return init_to_uniform(site) def taylor_proxy(z_ref, model, ll_ref, jac_all, hess_all): + """Corrects the subsample likelihood using covariates the taylor expansion + :param z_ref = reference estimate (e.g MAP) of the model's parameters + :param model = model likelihood + :param ll_ref = reference loglikelihood + :param jac_all= Jacobian vector of the entire dataset + :param hess_all = Hessian matrix of the entire dataset""" def proxy(z, *args, **kwargs): z_flat, _ = ravel_pytree(z) zref_flat, _ = ravel_pytree(z_ref) @@ -221,3 +227,34 @@ def neural_proxy(): return None +def signed_estimator(model, model_args, model_kwargs, z, a, l, proxy, proxy_u): + """ + + :param model: + :param model_args: + :param model_kwargs: + :param z: + :param a: + :param l: Length of the block of data to be updated within the subsample + :param proxy: + :param proxy_u: + :return: + """ + xis = 0. + sign = 1. + + for args in model_args: + ll_sub, _ = log_density_obs_hmcecs(model, args, {}, z) # log likelihood for subsample with current theta + xi = (jnp.exp(ll_sub - proxy_u(z=z, model_args=args, model_kwargs=model_kwargs)) - a) / l + sign *= jnp.prod(jnp.sign(xi)) + xis += jnp.sum(jnp.abs(xi), axis=0) + + lhat = proxy(z) + (a + l) / l + xis + ll_prior, _ = log_density_prior_hmcecs(model, model_args, model_kwargs, z) + + neg_ll = - lhat - ll_prior + + return neg_ll, sign + + + From bdcd3524d495a7dd15428ce8ac621fa9e7e8a778 Mon Sep 17 00:00:00 2001 From: Lys Date: Mon, 2 Nov 2020 15:59:42 +0100 Subject: [PATCH 22/93] small stuff --- examples/logistic_hmcecs.py | 1 - numpyro/contrib/hmcecs.py | 1 - 2 files changed, 2 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index d8906b0c7..4db69a292 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -147,7 +147,6 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, kernel = HMC(model=model,z_ref=z_ref,m=m,g=g,algo=algo,subsample_method=subsample_method,proxy=proxy,svi_fn=svi,target_accept_prob=0.8) mcmc = MCMC(kernel,num_warmup=warmup,num_samples=n_samples,num_chains=1) - print(feats.shape) mcmc.run(rng_key,feats,obs) stop = time.time() file_hyperparams.write('MCMC/NUTS elapsed time {}: {} \n'.format(subsample_method,time.time() - start)) diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 4c653b899..3a3e114eb 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -890,7 +890,6 @@ def sample(self, state, model_args, model_kwargs): u, ll_u = cond(transition, (u_new, llu_new), identity, (state.u, state.ll_u), identity) - print(u) ######## UPDATE PARAMETERS ########## From 742162cfd12b6c09c8eb158a0b56856a57bdff8a Mon Sep 17 00:00:00 2001 From: Lys Date: Mon, 2 Nov 2020 19:30:08 +0100 Subject: [PATCH 23/93] Started adding poisson stuff Lots of things to do --- examples/logistic_hmcecs.py | 76 +++++++++++++++++++---------- numpyro/contrib/hmcecs.py | 85 +++++++++++++++++++++------------ numpyro/contrib/hmcecs_utils.py | 9 ++-- numpyro/distributions/kl.py | 4 -- 4 files changed, 110 insertions(+), 64 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 4db69a292..67ce31fd0 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -12,7 +12,7 @@ import numpy as np from numpyro.distributions.kl import kl_divergence from matplotlib.pyplot import cm - +#remember to export the path of the project sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/contrib/') sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/examples/') @@ -20,8 +20,8 @@ #from numpyro.contrib.hmcecs import HMC from sklearn.datasets import load_breast_cancer -#from datasets import _load_higgs -from numpyro.examples.datasets import _load_higgs +from datasets import _load_higgs +#from numpyro.examples.datasets import _load_higgs from logistic_hmcecs_svi import svi_map import jax.numpy as np_jax import matplotlib.pyplot as plt @@ -96,14 +96,15 @@ def infer_hmc(rng_key, feats, obs, samples, warmup ): -def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None,algo="NUTS",subsample_method=None,map_method=None,proxy="taylor",num_epochs=None ): +def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None,algo="NUTS",subsample_method=None,map_method=None,proxy="taylor",estimator=None,num_epochs=None ): hmcecs_key, map_key = jax.random.split(rng_key) n, _ = feats.shape file_hyperparams = open("PLOTS_{}/Hyperparameters_{}.txt".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms"),now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), "a") + if subsample_method=="perturb" and proxy== "taylor": - map_samples = 5 - map_warmup = 20 - factor_NUTS = 100 + map_samples = 10 + map_warmup = 5 + factor_NUTS = 50 if map_method == "NUTS": print("Running NUTS for map estimation {} + {} samples".format(map_samples,map_warmup)) file_hyperparams.write('MAP samples : {} \n'.format(map_samples)) @@ -144,7 +145,7 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, svi = None start = time.time() - kernel = HMC(model=model,z_ref=z_ref,m=m,g=g,algo=algo,subsample_method=subsample_method,proxy=proxy,svi_fn=svi,target_accept_prob=0.8) + kernel = HMC(model=model,z_ref=z_ref,m=m,g=g,algo=algo,subsample_method=subsample_method,proxy=proxy,svi_fn=svi,estimator = estimator,target_accept_prob=0.8) mcmc = MCMC(kernel,num_warmup=warmup,num_samples=n_samples,num_chains=1) mcmc.run(rng_key,feats,obs) @@ -177,11 +178,16 @@ def Determine_best_sample_size(rng_key,feats,obs): plt.ylabel(r"$\hat{r}$") plt.title("Determine best effective sample size for z_map") plt.savefig("{}/Best_effective_size_z_map.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")))) -def Plot(samples_ECS,samples_NUTS,ecs_algo,algo,proxy,m,kl=None): +def Plot(samples_ECS,samples_NUTS,ecs_algo,algo,proxy,estimator,m,kl=None): + if estimator : + label = "ECS-{}-{} proxy-{} estimator".format(ecs_algo, proxy, estimator) + else: + label = "ECS-{}-{} proxy".format(ecs_algo, proxy) for sample in [0,7,15,25]: plt.figure(sample + m +3) #samples = pd.DataFrame.from_records(samples,index="theta") - sns.kdeplot(data=samples_ECS["theta"][sample],color="r",label="ECS-{}-{} proxy".format(ecs_algo,proxy)) + + sns.kdeplot(data=samples_ECS["theta"][sample],color="r",label=label) sns.kdeplot(data=samples_NUTS["theta"][sample],color="b",label="{}".format(algo)) #if kl != None: # sns.kdeplot(data=kl, color="g", label="KL; m: {}".format(m)) @@ -210,15 +216,17 @@ def Folders(folder_name): else: shutil.rmtree(newpath) # removes all the subdirectories! os.makedirs(newpath,0o777) -def Plot_KL(map_method,ecs_algo,algo,proxy,n_samples,n_warmup,epochs): - factor_ECS= 1000 #obs.shape[0] +def Plot_KL(map_method,ecs_algo,algo,proxy,estimator,n_samples,n_warmup,epochs): + factor_ECS= 50 #obs.shape[0] m = [int(np_jax.sqrt(obs[:factor_ECS].shape[0])),2*int(np_jax.sqrt(obs[:factor_ECS].shape[0])),4*int(np_jax.sqrt(obs[:factor_ECS].shape[0])),8*int(np_jax.sqrt(obs[:factor_ECS].shape[0]))] g = 5 - factor_NUTS = 100 + factor_NUTS = 50 colors = cm.rainbow(np.linspace(0, 1, len(m))) - print("Running standard NUTS") - est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor_NUTS], obs=obs[:factor_NUTS], - n_samples=n_samples, warmup=n_warmup, m="all", g=g, algo=algo) + run_test = False + if run_test: + print("Running standard NUTS") + est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor_NUTS], obs=obs[:factor_NUTS], + n_samples=n_samples, warmup=n_warmup, m="all", g=g, algo=algo) for m_val, color in zip(m,colors): est_posterior_ECS = infer_hmcecs(rng_key, feats=feats[:factor_ECS], obs=obs[:factor_ECS], n_samples=n_samples, @@ -227,6 +235,7 @@ def Plot_KL(map_method,ecs_algo,algo,proxy,n_samples,n_warmup,epochs): algo=ecs_algo, subsample_method="perturb", proxy=proxy, + estimator=estimator, map_method=map_method, num_epochs=epochs) @@ -239,34 +248,45 @@ def Plot_KL(map_method,ecs_algo,algo,proxy,n_samples,n_warmup,epochs): ecs_algo= ecs_algo, algo=algo, proxy= proxy, + estimator = estimator, m = m_val, kl=kl) exit() -def Tests(map_method,ecs_algo,algo,n_samples,n_warmup,epochs,proxy): +def Tests(map_method,ecs_algo,algo,estimator,n_samples,n_warmup,epochs,proxy): m = int(np_jax.sqrt(obs.shape[0])*2) g= 5 est_posterior_ECS = infer_hmcecs(rng_key, feats=feats, obs=obs, n_samples=n_samples, warmup=n_warmup, - m =m,g=g, + m =m, + g=g, algo=ecs_algo, subsample_method="perturb", proxy=proxy, + estimator = estimator, map_method = map_method, num_epochs=epochs) - est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats, obs=obs, n_samples=n_samples,warmup=n_warmup,m =m,g=g,algo=algo) + est_posterior_NUTS = infer_hmcecs(rng_key, + feats=feats, + obs=obs, + n_samples=n_samples, + warmup=n_warmup, + m =m, + g=g, + algo=algo) - Plot(est_posterior_ECS,est_posterior_NUTS,ecs_algo,algo,proxy,m) + Plot(est_posterior_ECS,est_posterior_NUTS,ecs_algo,algo,proxy,estimator,m) if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-num_samples', nargs='?', default=5,type=int) - parser.add_argument('-num_warmup', nargs='?', default=20, type=int) + parser.add_argument('-num_samples', nargs='?', default=10,type=int) + parser.add_argument('-num_warmup', nargs='?', default=5, type=int) parser.add_argument('-ecs_algo', nargs='?', default="NUTS", type=str) parser.add_argument('-ecs_proxy', nargs='?', default="taylor", type=str) parser.add_argument('-algo', nargs='?', default="HMC", type=str) + parser.add_argument('-estimator', nargs='?', default="poisson", type=str) parser.add_argument('-map_init', nargs='?', default="NUTS", type=str) parser.add_argument("-epochs",default=10,type=int) args = parser.parse_args() @@ -301,7 +321,7 @@ def Tests(map_method,ecs_algo,algo,n_samples,n_warmup,epochs,proxy): #Determine_best_sample_size(rng_key,feats[:100],obs[:100]) #Tests(args.map_init,args.ecs_algo,args.algo,args.num_samples,args.num_warmup,args.epochs,args.ecs_proxy) - Plot_KL(args.map_init,args.ecs_algo,args.algo,args.ecs_proxy,args.num_samples,args.num_warmup,args.epochs) + Plot_KL(args.map_init,args.ecs_algo,args.algo,args.ecs_proxy,args.estimator,args.num_samples,args.num_warmup,args.epochs) exit() @@ -314,8 +334,14 @@ def Tests(map_method,ecs_algo,algo,n_samples,n_warmup,epochs,proxy): p = dist.Normal(samples_ECS_3316["theta"]) q = dist.Normal(samples_HMC["theta"]) kl = kl_divergence(p, q) - print(np_jax.average(kl)) - Plot(samples_ECS_3316, samples_HMC, args.ecs_algo, args.algo, 3316,args.ecs_proxy ,kl=kl) + Plot(samples_ECS=samples_ECS_3316, + samples_NUTS=samples_HMC, + ecs_algo= args.ecs_algo, + algo=args.algo, + proxy= args.proxy, + estimator = args.estimator, + m = 3316, + kl=kl) # samples = pd.DataFrame.from_records(samples,index="theta") # sns.kdeplot(data=kl, color=color, label="m : ".format(m_val)) diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 3a3e114eb..5be4bfc27 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -240,10 +240,10 @@ def init_kernel(init_params, model_args=(), model_kwargs=None, model = None, - #ll_ref=None, - #jac_all=None, - #z_ref= None, - #hess_all=None, + ll_ref=None, + jac_all=None, + z_ref= None, + hess_all=None, ll_u = None, n = None, m = None, @@ -358,12 +358,12 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, proxy_fn = None, proxy_u_fn = None, model = None, - #ll_ref = None, - #jac_all = None, - #z = None, - #z_ref = None, - #hess_all = None, - #ll_u = None, + ll_ref = None, + jac_all = None, + z = None, + z_ref = None, + hess_all = None, + ll_u = None, u = None, n = None, m = None): @@ -412,7 +412,7 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key,subsample_method, proxy_fn=None,proxy_u_fn=None, model=None, - #ll_ref=None,jac_all=None,z = None,z_ref=None,hess_all=None,ll_u=None,u=None, + ll_ref=None,jac_all=None,z = None,z_ref=None,hess_all=None,ll_u=None,u=None, n=None,m=None): if potential_fn_gen: nonlocal vv_update @@ -454,11 +454,11 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, proxy_fn=None, proxy_u_fn=None, model=None, - #ll_ref=None, - #jac_all=None, - #z=None, - #z_ref=None, - #hess_all=None, + ll_ref=None, + jac_all=None, + z=None, + z_ref=None, + hess_all=None, ll_u=None, u=None,n=None,m=None,): #TODO: Remove so many args """ @@ -490,7 +490,7 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, proxy_fn, proxy_u_fn, model, - #ll_ref,jac_all,z,z_ref,hess_all,ll_u,u, + ll_ref,jac_all,z,z_ref,hess_all,ll_u,u, n,m) # not update adapt_state after warmup phase adapt_state = cond(hmc_state.i < wa_steps, @@ -618,6 +618,8 @@ def __init__(self, self._u = None self._neg_ll = None self._sign = None + self._l = 1 #TODO: What to initialize this to? + self._a = 1 # Set on first call to init self._init_fn = None self._postprocess_fn = None @@ -627,7 +629,7 @@ def __init__(self, self.svi_fn = svi_fn self._proxy_fn = None self._proxy_u_fn = None - self.estimator = None + self.estimator = estimator def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ref): "Compute the jacobian, hessian and log likelihood for all the data. Used with taylor expansion proxy" @@ -654,10 +656,14 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): warnings.warn("Using default second order Taylor expansion, change by using the proxy flag to {svi}") self._init_subsample_state(rng_key, model_args, model_kwargs, init_params, self.z_ref) self._proxy_fn,self._proxy_u_fn = taylor_proxy(self.z_ref, self._model, self._ll_ref, self._jac_all, self._hess_all) - # Initialize the potential and gradient potential functions + if self.estimator =="poisson": + self._l = 1 # initialize? + self._a = 1 + # Initialize the potential and gradient potential functions self._potential_fn = lambda model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn : lambda z:potential_est(model=model, model_args=model_args, model_kwargs=model_kwargs, z=z, n=n, m=m, proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) + # Initialize the hmc sampler: sample_fn = sample_kernel self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, kinetic_fn=euclidean_kinetic_energy, @@ -762,10 +768,10 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg model_kwargs=model_kwargs, subsample_method= self.subsample_method, model=self._model, - #ll_ref =self._ll_ref, - #jac_all=self._jac_all, - #z_ref=self.z_ref, - #hess_all = self._hess_all, + ll_ref =self._ll_ref, + jac_all=self._jac_all, + z_ref=self.z_ref, + hess_all = self._hess_all, ll_u = self._ll_u, n=self._n, m=self.m, @@ -783,8 +789,14 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg elif self.proxy == "svi": self._proxy_fn, self._proxy_u_fn = svi_proxy(self.svi_fn, model_args, model_kwargs) if self.estimator == "poisson": + #signed pseudo-marginal algorithm with the block-Poisson estimator + #use the term signed PM for any pseudo-marginal algorithm that uses the technique in Lyne + # et al. (2015) where a pseudo-marginal sampler is run on the absolute value of the estimated + # posterior and subsequently sign-corrected by importance sampling. Similarly, we call the + # algorithm described in this section signed HMC-ECS print("Poisson, working on it") - self._neg_ll, self._sign = signed_estimator(self._model, model_args, model_kwargs, init_state.z, self.a, self.l, self._proxy_fn, self._proxy_u_fn) + self._neg_ll, self._sign = signed_estimator(self._model, model_args, model_kwargs, init_state.z, self._a, self._l, self._proxy_fn, self._proxy_u_fn) + exit() else: self._ll_u = potential_est(model=self._model, model_args=model_args_sub(self._u, model_args), @@ -801,7 +813,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg return init_sub_state else: #TODO: What is this for? It does not go into it for num_chains>1 - raise ValueError("Not implemented for n_chains > 1") + raise ValueError("Not implemented for chains > 1") # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, # wa_steps because those variables do not depend on traced args: init_params, rng_key. @@ -872,6 +884,18 @@ def sample(self, state, model_args, model_kwargs): rng_key_subsample, rng_key_transition, rng_key_likelihood, rng_key = random.split( state.rng_key, 4) + if self.estimator == "poisson": + #TODO: What to do here? does the negative likelihood need to be stored? how about the sign? store in the state? + u_new = _sample_u_poisson(rng_key,self.m,self._l) + neg_ll, sign = signed_estimator(model = self._model, + model_args=model_args, + model_kwargs=model_kwargs, + z=state.z, + a=self._a, + l =self._l, + proxy_fn = self._proxy_fn, + proxy_u_fn = self._proxy_u_fn) + u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) # estimate likelihood of subsample with single block updated @@ -894,7 +918,6 @@ def sample(self, state, model_args, model_kwargs): ######## UPDATE PARAMETERS ########## hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state,ll_u=ll_u) - hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(),state._asdict()) return self._sample_fn(hmc_subsamplestate, @@ -904,11 +927,11 @@ def sample(self, state, model_args, model_kwargs): proxy_fn = self._proxy_fn, proxy_u_fn = self._proxy_u_fn, model = self._model, - #ll_ref = self._ll_ref, - #jac_all =self._jac_all, - #z= state.z, - #z_ref = self.z_ref, #TODO: Not necessary , remove(z_ref, hess_all, jac_all,ll_ref) - #hess_all = self._hess_all, + ll_ref = self._ll_ref, + jac_all =self._jac_all, + z= state.z, + z_ref = self.z_ref, + hess_all = self._hess_all, ll_u = ll_u, u= u, n= self._n, diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index 6988d7c81..f8e0c0797 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -100,6 +100,7 @@ def potential_est(model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn ll_sub, _ = log_density_obs_hmcecs(model, model_args, {}, z) # log likelihood for subsample with current theta diff = ll_sub - proxy_u_fn(z=z, model_args=model_args, model_kwargs=model_kwargs) + l_hat = proxy_fn(z) + n / m * diff sigma = n ** 2 / m * jnp.var(diff) @@ -227,7 +228,7 @@ def neural_proxy(): return None -def signed_estimator(model, model_args, model_kwargs, z, a, l, proxy, proxy_u): +def signed_estimator(model, model_args, model_kwargs, z, a, l, proxy_fn, proxy_u_fn): """ :param model: @@ -243,13 +244,13 @@ def signed_estimator(model, model_args, model_kwargs, z, a, l, proxy, proxy_u): xis = 0. sign = 1. - for args in model_args: + for args in model_args: #TODO: Perhaps for index in len(model_args) ? ll_sub, _ = log_density_obs_hmcecs(model, args, {}, z) # log likelihood for subsample with current theta - xi = (jnp.exp(ll_sub - proxy_u(z=z, model_args=args, model_kwargs=model_kwargs)) - a) / l + xi = (jnp.exp(ll_sub - proxy_u_fn(z=z, model_args=args, model_kwargs=model_kwargs)) - a) / l sign *= jnp.prod(jnp.sign(xi)) xis += jnp.sum(jnp.abs(xi), axis=0) - lhat = proxy(z) + (a + l) / l + xis + lhat = proxy_fn(z) + (a + l) / l + xis ll_prior, _ = log_density_prior_hmcecs(model, model_args, model_kwargs, z) neg_ll = - lhat - ll_prior diff --git a/numpyro/distributions/kl.py b/numpyro/distributions/kl.py index e888ca42d..d54d361d6 100644 --- a/numpyro/distributions/kl.py +++ b/numpyro/distributions/kl.py @@ -177,8 +177,4 @@ def _kl_masked_masked(p, q): def _kl_normal_normal(p, q): var_ratio = jnp.square(p.scale / q.scale) t1 = jnp.square((p.loc - q.loc) / q.scale) -<<<<<<< HEAD return 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio)) -======= - return 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio)) ->>>>>>> Maybe working From 81b8956f7d00cf74aa87fa5f01bf97d315d99daa Mon Sep 17 00:00:00 2001 From: Lys Date: Wed, 4 Nov 2020 12:43:09 +0100 Subject: [PATCH 24/93] Working on documentation and poisson --- examples/logistic_hmcecs.py | 1 + numpyro/contrib/funsor/enum_messenger.py | 9 ++- numpyro/contrib/hmcecs.py | 77 ++++++++++++------------ numpyro/contrib/hmcecs_utils.py | 33 ++++++---- 4 files changed, 69 insertions(+), 51 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 67ce31fd0..4f164727c 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -156,6 +156,7 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, file_hyperparams.write('Subsample size (m): {}\n'.format(m)) file_hyperparams.write('Block size (g): {}\n'.format(g)) file_hyperparams.write('Data size (n): {}\n'.format(feats.shape[0])) + file_hyperparams.write('Estimator: {}\n'.format(estimator)) file_hyperparams.write('...........................................\n') file_hyperparams.close() diff --git a/numpyro/contrib/funsor/enum_messenger.py b/numpyro/contrib/funsor/enum_messenger.py index a7cf3b553..bba8ff7eb 100644 --- a/numpyro/contrib/funsor/enum_messenger.py +++ b/numpyro/contrib/funsor/enum_messenger.py @@ -7,8 +7,13 @@ from jax import lax import jax.numpy as jnp - -import funsor +try: + import funsor +except ImportError as e: + raise ImportError("Looking like you want to do inference for models with " + "discrete latent variables. This is an experimental feature. " + "You need to install `funsor` to be able to use this feature. " + "It can be installed with `pip install funsor`.") from e from numpyro.handlers import trace as OrigTraceMessenger from numpyro.primitives import Messenger, apply_stack from numpyro.primitives import plate as OrigPlateMessenger diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 5be4bfc27..d822e4ce6 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -91,7 +91,7 @@ def momentum_generator(prototype_r, mass_matrix_sqrt, rng_key): @partial(jit, static_argnums=(2, 3, 4)) def _update_block(rng_key, u, n, m, g): - """Returns the indexes from the subsample that will be updated, there is replacement. + """Returns indexes of the new subsample. The update mechanism selects blocks of indices within the subsample to be updated. The number of indexes to be updated depend on the block size, higher block size more correlation among elements in the subsample. :param rng_key :param u subsample indexes @@ -115,6 +115,9 @@ def _update_block(rng_key, u, n, m, g): def _sample_u_poisson(rng_key, m, l): """ Initialize subsamples u + ***References*** + 1.Hamiltonian Monte Carlo with Energy Conserving Subsampling + 2.The blockPoisson estimator for optimally tuned exact subsampling MCMC. :param m: subsample size :param l: length of the current subsample block :param g: number of blocks @@ -126,18 +129,21 @@ def _sample_u_poisson(rng_key, m, l): @partial(jit, static_argnums=(2, 3, 4)) def _update_block_poisson(rng_key, u, m, l, g): - """ Update block of u + """ Update block of u, where the length of the block of indexes to update is given by the Poisson distribution. + ***References*** + 1.Hamiltonian Monte Carlo with Energy Conserving Subsampling + 2.The blockPoisson estimator for optimally tuned exact subsampling MCMC. :param rng_key :param u: current subsample indexes - :param m: - :param l: - :param g: + :param m: Subsample size + :param l: lambda + :param g: Block size within subsample """ if (g > m) or (g < 1): raise ValueError('Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(g,m)) u = u.copy() block_key, sample_key = random.split(rng_key) - num_updates = int(round(l / g, 0)) + num_updates = int(round(l / g, 0)) # choose lambda/g number of blocks to update chosen_blocks = random.randint(block_key, (num_updates,), 0, l) new_blocks = _sample_u_poisson(sample_key, m, num_updates) for i, block in enumerate(chosen_blocks): @@ -618,8 +624,7 @@ def __init__(self, self._u = None self._neg_ll = None self._sign = None - self._l = 1 #TODO: What to initialize this to? - self._a = 1 + self._l = 100 #TODO: What to initialize this to? # Set on first call to init self._init_fn = None self._postprocess_fn = None @@ -657,8 +662,7 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): self._init_subsample_state(rng_key, model_args, model_kwargs, init_params, self.z_ref) self._proxy_fn,self._proxy_u_fn = taylor_proxy(self.z_ref, self._model, self._ll_ref, self._jac_all, self._hess_all) if self.estimator =="poisson": - self._l = 1 # initialize? - self._a = 1 + self._l = 100 # lambda # Initialize the potential and gradient potential functions self._potential_fn = lambda model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn : lambda z:potential_est(model=model, @@ -795,8 +799,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg # posterior and subsequently sign-corrected by importance sampling. Similarly, we call the # algorithm described in this section signed HMC-ECS print("Poisson, working on it") - self._neg_ll, self._sign = signed_estimator(self._model, model_args, model_kwargs, init_state.z, self._a, self._l, self._proxy_fn, self._proxy_u_fn) - exit() + self._neg_ll, self._sign = signed_estimator(self._model, model_args, model_kwargs, init_state.z, self._l, self._proxy_fn, self._proxy_u_fn) else: self._ll_u = potential_est(model=self._model, model_args=model_args_sub(self._u, model_args), @@ -883,42 +886,40 @@ def sample(self, state, model_args, model_kwargs): if self.subsample_method == "perturb": rng_key_subsample, rng_key_transition, rng_key_likelihood, rng_key = random.split( state.rng_key, 4) - if self.estimator == "poisson": #TODO: What to do here? does the negative likelihood need to be stored? how about the sign? store in the state? u_new = _sample_u_poisson(rng_key,self.m,self._l) neg_ll, sign = signed_estimator(model = self._model, - model_args=model_args, + model_args=model_args_sub(u_new,model_args), model_kwargs=model_kwargs, z=state.z, - a=self._a, l =self._l, proxy_fn = self._proxy_fn, proxy_u_fn = self._proxy_u_fn) - u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) - # estimate likelihood of subsample with single block updated - - llu_new = potential_est(model=self._model, - model_args=model_args_sub(u_new,model_args), - model_kwargs=model_kwargs, - z=state.z, - n=self._n, - m=self.m, - proxy_fn=self._proxy_fn, - proxy_u_fn=self._proxy_u_fn) - # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) - # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. - accept_prob = jnp.clip(jnp.exp(-llu_new + state.ll_u), a_max=1.) - transition = random.bernoulli(rng_key_transition, accept_prob) - u, ll_u = cond(transition, - (u_new, llu_new), identity, - (state.u, state.ll_u), identity) - - ######## UPDATE PARAMETERS ########## - - hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state,ll_u=ll_u) - hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(),state._asdict()) + else: + u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) + # estimate likelihood of subsample with single block updated + llu_new = potential_est(model=self._model, + model_args=model_args_sub(u_new,model_args), + model_kwargs=model_kwargs, + z=state.z, + n=self._n, + m=self.m, + proxy_fn=self._proxy_fn, + proxy_u_fn=self._proxy_u_fn) + # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) + # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. + accept_prob = jnp.clip(jnp.exp(-llu_new + state.ll_u), a_max=1.) + transition = random.bernoulli(rng_key_transition, accept_prob) + u, ll_u = cond(transition, + (u_new, llu_new), identity, + (state.u, state.ll_u), identity) + + ######## UPDATE PARAMETERS ########## + + hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state,ll_u=ll_u) + hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(),state._asdict()) return self._sample_fn(hmc_subsamplestate, model_args=model_args, diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index f8e0c0797..76db32620 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -9,6 +9,8 @@ from numpyro.distributions.util import is_identically_one from numpyro.handlers import substitute, trace from numpyro.util import ravel_pytree +from numpyro.handlers import seed, substitute, trace +from numpyro.contrib.funsor.infer_util import plate_to_enum_plate,packed_trace from collections import namedtuple IntegratorState = namedtuple('IntegratorState', ['z', 'r', 'potential_energy', 'z_grad']) @@ -33,8 +35,11 @@ def model_kwargs_sub(u, kwargs): kwargs[key_arg] = jnp.take(val_arg, u, axis=0) return kwargs def log_density_obs_hmcecs(model, model_args, model_kwargs, params): + #model = substitute(model, data=params) + #model_trace = trace(model).get_trace(*model_args, **model_kwargs) model = substitute(model, data=params) - model_trace = trace(model).get_trace(*model_args, **model_kwargs) + with plate_to_enum_plate(): + model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs) log_joint = jnp.array(0.) for site in model_trace.values(): if site['type'] == 'sample' and site['is_observed'] and not isinstance(site['fn'], dist.PRNGIdentity): @@ -63,8 +68,11 @@ def log_density_prior_hmcecs(model, model_args, model_kwargs, params): name. :return: log of joint density and a corresponding model trace """ + # model = substitute(model, data=params) + # model_trace = trace(model).get_trace(*model_args, **model_kwargs) model = substitute(model, data=params) - model_trace = trace(model).get_trace(*model_args, **model_kwargs) + with plate_to_enum_plate(): + model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs) log_joint = jnp.array(0.) for site in model_trace.values(): if site['type'] == 'sample' and not isinstance(site['fn'], dist.PRNGIdentity) and not site['is_observed']: @@ -228,28 +236,31 @@ def neural_proxy(): return None -def signed_estimator(model, model_args, model_kwargs, z, a, l, proxy_fn, proxy_u_fn): - """ +def signed_estimator(model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn): + """ + Function at minusloglike_estPoisson :param model: :param model_args: :param model_kwargs: :param z: - :param a: - :param l: Length of the block of data to be updated within the subsample + :param l: Lambda ~number of samples of the likelihood estimator :param proxy: :param proxy_u: :return: """ xis = 0. sign = 1. + d = 0 + a = d - l #For a fixed λ, V[LbB] is minimized at a = d − λ. Quiroz 2018c - for args in model_args: #TODO: Perhaps for index in len(model_args) ? - ll_sub, _ = log_density_obs_hmcecs(model, args, {}, z) # log likelihood for subsample with current theta - xi = (jnp.exp(ll_sub - proxy_u_fn(z=z, model_args=args, model_kwargs=model_kwargs)) - a) / l - sign *= jnp.prod(jnp.sign(xi)) - xis += jnp.sum(jnp.abs(xi), axis=0) + #for args in model_args: #TODO: Perhaps for index in len(model_args) ? + for args_i in range(len(model_args)): #Now it's doing everything twice + ll_sub, _ = log_density_obs_hmcecs(model, model_args, {}, z) # log likelihood for subsample with current theta + xi = (jnp.exp(ll_sub - proxy_u_fn(z=z, model_args=model_args, model_kwargs=model_kwargs)) - a) / l + sign *= jnp.prod(jnp.sign(xi)) + xis += jnp.sum(jnp.abs(xi)) #, axis=0) lhat = proxy_fn(z) + (a + l) / l + xis ll_prior, _ = log_density_prior_hmcecs(model, model_args, model_kwargs, z) From ffbaa9e54c153bb4d258e2c22f47ea2b2ae4162f Mon Sep 17 00:00:00 2001 From: Lys Date: Fri, 6 Nov 2020 17:34:51 +0100 Subject: [PATCH 25/93] Added: Poisson stuff (missing initialization) --- numpyro/contrib/hmcecs.py | 131 +++++++++++++++++++++----------- numpyro/contrib/hmcecs_utils.py | 33 ++++---- 2 files changed, 105 insertions(+), 59 deletions(-) diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index d822e4ce6..660c54156 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -119,12 +119,13 @@ def _sample_u_poisson(rng_key, m, l): 1.Hamiltonian Monte Carlo with Energy Conserving Subsampling 2.The blockPoisson estimator for optimally tuned exact subsampling MCMC. :param m: subsample size - :param l: length of the current subsample block + :param l: lambda u blocks :param g: number of blocks """ pois_key, sub_key = random.split(rng_key) - block_lengths = dist.discrete.Poisson(1).sample(pois_key, (l,)) - u = random.randint(sub_key, (jnp.sum(block_lengths), m), 0, m) + block_lengths = dist.discrete.Poisson(1).sample(pois_key, (l,)) #lambda block lengths + u = random.randint(sub_key, (jnp.sum(block_lengths), ), 0, m) + return jnp.split(u, jnp.cumsum(block_lengths), axis=0) @partial(jit, static_argnums=(2, 3, 4)) @@ -256,6 +257,7 @@ def init_kernel(init_params, u= None, rng_key=random.PRNGKey(0), subsample_method=None, + estimator=None, proxy_fn=None, proxy_u_fn = None): """ @@ -361,6 +363,7 @@ def init_kernel(init_params, def _hmc_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key,subsample_method, + estimator=None, proxy_fn = None, proxy_u_fn = None, model = None, @@ -416,6 +419,7 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, def _nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key,subsample_method, + estimator=None, proxy_fn=None,proxy_u_fn=None, model=None, ll_ref=None,jac_all=None,z = None,z_ref=None,hess_all=None,ll_u=None,u=None, @@ -457,6 +461,7 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, def sample_kernel(hmc_state,model_args=(),model_kwargs=None, subsample_method=None, + estimator = None, proxy_fn=None, proxy_u_fn=None, model=None, @@ -466,7 +471,7 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, z_ref=None, hess_all=None, ll_u=None, - u=None,n=None,m=None,): #TODO: Remove so many args + u=None,n=None,m=None,): """ Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed (possibly adapted) step size and return a new :data:`~numpyro.infer.mcmc.HMCState`. @@ -481,7 +486,10 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, model_kwargs = {} if model_kwargs is None else model_kwargs if subsample_method =="perturb": - model_args = model_args_sub(u,model_args) + if estimator == "poisson": + model_args = [model_args_sub(u_i, model_args) for u_i in u] #here u = poisson_u + else: + model_args = model_args_sub(u,model_args) rng_key, rng_key_momentum, rng_key_transition = random.split(hmc_state.rng_key, 3) r = momentum_generator(hmc_state.z, hmc_state.adapt_state.mass_matrix_sqrt, rng_key_momentum) vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad) @@ -493,6 +501,7 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, model_kwargs, rng_key_transition, subsample_method, + estimator, proxy_fn, proxy_u_fn, model, @@ -586,8 +595,8 @@ def __init__(self, init_strategy=init_to_uniform, find_heuristic_step_size=False, subsample_method = None, + estimator=None, # poisson or not proxy="taylor", - estimator =None,#poisson or not svi_fn=None, m= None, g = None, @@ -624,7 +633,7 @@ def __init__(self, self._u = None self._neg_ll = None self._sign = None - self._l = 100 #TODO: What to initialize this to? + self._l = 100 # Set on first call to init self._init_fn = None self._postprocess_fn = None @@ -634,6 +643,7 @@ def __init__(self, self.svi_fn = svi_fn self._proxy_fn = None self._proxy_u_fn = None + self._signed_estimator_fn = None self.estimator = estimator def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ref): @@ -653,7 +663,6 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): if self.subsample_method is not None: assert self.z_ref is not None, "Please provide a (i.e map) estimate for the parameters" self._n = model_args[0].shape[0] - self._u = random.randint(rng_key, (self.m,), 0, self._n) # Choose the covariate calculation method if self.proxy == "svi": self._proxy_fn,self._proxy_u_fn = svi_proxy(self.svi_fn,model_args,model_kwargs) @@ -662,29 +671,53 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): self._init_subsample_state(rng_key, model_args, model_kwargs, init_params, self.z_ref) self._proxy_fn,self._proxy_u_fn = taylor_proxy(self.z_ref, self._model, self._ll_ref, self._jac_all, self._hess_all) if self.estimator =="poisson": - self._l = 100 # lambda + self._l = 25 # lambda subsamples + self._u = _sample_u_poisson(rng_key, self.m, self._l) + #TODO: Confirm that the signed estimator is the new potential function---> If so the output has to be fixed + self._potential_fn = lambda model,model_args,model_kwargs,z,l, proxy_fn,proxy_u_fn : lambda z:signed_estimator(model = model,model_args=model_args, + model_kwargs= model_kwargs,z=z,l=l,proxy_fn=proxy_fn, + proxy_u_fn=proxy_u_fn) + # Initialize the hmc sampler: sample_fn = sample_kernel + self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, + kinetic_fn=euclidean_kinetic_energy, + algo=self._algo) - # Initialize the potential and gradient potential functions - self._potential_fn = lambda model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn : lambda z:potential_est(model=model, - model_args=model_args, model_kwargs=model_kwargs, z=z, n=n, m=m, proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) + self._init_strategy = partial(init_near_values, values=self.z_ref) + # Initialize the model parameters + rng_key_init_model, rng_key = random.split(rng_key) + model_args = [model_args_sub(u_i, model_args) for u_i in self._u] #TODO: The initialization function has to be initialized on a subsample + #Highlight: Initialize with one non empty list? + init_params, potential_fn, postprocess_fn, model_trace = initialize_model( + rng_key_init_model, + self._model, + init_strategy=self._init_strategy, + dynamic_args=True, + model_args=model_args, + model_kwargs=model_kwargs) - # Initialize the hmc sampler: sample_fn = sample_kernel - self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, - kinetic_fn=euclidean_kinetic_energy, - algo=self._algo) + else: + self._u = random.randint(rng_key, (self.m,), 0, self._n) + # Initialize the potential and gradient potential functions + self._potential_fn = lambda model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn : lambda z:potential_est(model=model, + model_args=model_args, model_kwargs=model_kwargs, z=z, n=n, m=m, proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) + # Initialize the hmc sampler: sample_fn = sample_kernel + self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, + kinetic_fn=euclidean_kinetic_energy, + algo=self._algo) - self._init_strategy = partial(init_near_values, values=self.z_ref) - # Initialize the model parameters - rng_key_init_model, rng_key = random.split(rng_key) - init_params, potential_fn, postprocess_fn, model_trace = initialize_model( - rng_key_init_model, - self._model, - init_strategy=self._init_strategy, - dynamic_args=True, - model_args=model_args_sub(self._u, model_args), - model_kwargs=model_kwargs) + self._init_strategy = partial(init_near_values, values=self.z_ref) + # Initialize the model parameters + rng_key_init_model, rng_key = random.split(rng_key) + + init_params, potential_fn, postprocess_fn, model_trace = initialize_model( + rng_key_init_model, + self._model, + init_strategy=self._init_strategy, + dynamic_args=True, + model_args=model_args_sub(self._u, model_args), + model_kwargs=model_kwargs) if (self.g > self.m) or (self.g < 1): raise ValueError( @@ -798,8 +831,17 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg # et al. (2015) where a pseudo-marginal sampler is run on the absolute value of the estimated # posterior and subsequently sign-corrected by importance sampling. Similarly, we call the # algorithm described in this section signed HMC-ECS - print("Poisson, working on it") - self._neg_ll, self._sign = signed_estimator(self._model, model_args, model_kwargs, init_state.z, self._l, self._proxy_fn, self._proxy_u_fn) + model_args = [model_args_sub(u_i, model_args)for u_i in self._u] + self._neg_ll, self._sign = signed_estimator(self._model, + model_args, + model_kwargs, + init_state.z, + self._l, + self._proxy_fn, + self._proxy_u_fn) + self._ll_u = self._sign*self._neg_ll #TODO: ??????????? + exit() + else: self._ll_u = potential_est(model=self._model, model_args=model_args_sub(self._u, model_args), @@ -809,6 +851,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg m=self.m, proxy_fn=self._proxy_fn, proxy_u_fn=self._proxy_u_fn) + exit() hmc_init_sub_state = HMCECSState(u=self._u, hmc_state=init_state.hmc_state, ll_u=self._ll_u) @@ -888,19 +931,21 @@ def sample(self, state, model_args, model_kwargs): state.rng_key, 4) if self.estimator == "poisson": #TODO: What to do here? does the negative likelihood need to be stored? how about the sign? store in the state? - u_new = _sample_u_poisson(rng_key,self.m,self._l) + u_new = _sample_u_poisson(rng_key, self.m, self._l) + model_args = [model_args_sub(u_i, model_args) for u_i in u_new] neg_ll, sign = signed_estimator(model = self._model, - model_args=model_args_sub(u_new,model_args), + model_args=model_args, model_kwargs=model_kwargs, z=state.z, l =self._l, proxy_fn = self._proxy_fn, proxy_u_fn = self._proxy_u_fn) + llu_new = neg_ll*sign else: u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) # estimate likelihood of subsample with single block updated - llu_new = potential_est(model=self._model, + llu_new = self._potential_fn(model=self._model, model_args=model_args_sub(u_new,model_args), model_kwargs=model_kwargs, z=state.z, @@ -908,18 +953,18 @@ def sample(self, state, model_args, model_kwargs): m=self.m, proxy_fn=self._proxy_fn, proxy_u_fn=self._proxy_u_fn) - # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) - # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. - accept_prob = jnp.clip(jnp.exp(-llu_new + state.ll_u), a_max=1.) - transition = random.bernoulli(rng_key_transition, accept_prob) - u, ll_u = cond(transition, - (u_new, llu_new), identity, - (state.u, state.ll_u), identity) - - ######## UPDATE PARAMETERS ########## - - hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state,ll_u=ll_u) - hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(),state._asdict()) + # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) + # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. + accept_prob = jnp.clip(jnp.exp(-llu_new + state.ll_u), a_max=1.) + transition = random.bernoulli(rng_key_transition, accept_prob) + u, ll_u = cond(transition, + (u_new, llu_new), identity, + (state.u, state.ll_u), identity) + + ######## UPDATE PARAMETERS ########## + + hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state,ll_u=ll_u) + hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(),state._asdict()) return self._sample_fn(hmc_subsamplestate, model_args=model_args, diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index 76db32620..d7563653b 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -35,11 +35,11 @@ def model_kwargs_sub(u, kwargs): kwargs[key_arg] = jnp.take(val_arg, u, axis=0) return kwargs def log_density_obs_hmcecs(model, model_args, model_kwargs, params): - #model = substitute(model, data=params) - #model_trace = trace(model).get_trace(*model_args, **model_kwargs) model = substitute(model, data=params) - with plate_to_enum_plate(): - model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs) + model_trace = trace(model).get_trace(*model_args, **model_kwargs) + #model = substitute(model, data=params) + # with plate_to_enum_plate(): + # model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs) log_joint = jnp.array(0.) for site in model_trace.values(): if site['type'] == 'sample' and site['is_observed'] and not isinstance(site['fn'], dist.PRNGIdentity): @@ -68,11 +68,11 @@ def log_density_prior_hmcecs(model, model_args, model_kwargs, params): name. :return: log of joint density and a corresponding model trace """ - # model = substitute(model, data=params) - # model_trace = trace(model).get_trace(*model_args, **model_kwargs) model = substitute(model, data=params) - with plate_to_enum_plate(): - model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs) + model_trace = trace(model).get_trace(*model_args, **model_kwargs) + # model = substitute(model, data=params) + # with plate_to_enum_plate(): + # model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs) log_joint = jnp.array(0.) for site in model_trace.values(): if site['type'] == 'sample' and not isinstance(site['fn'], dist.PRNGIdentity) and not site['is_observed']: @@ -241,10 +241,10 @@ def signed_estimator(model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn """ Function at minusloglike_estPoisson :param model: - :param model_args: + :param model_args: Subsample of model arguments [l,m,n_feats] :param model_kwargs: :param z: - :param l: Lambda ~number of samples of the likelihood estimator + :param l: Lambda number of subsamples (u indexes) :param proxy: :param proxy_u: :return: @@ -253,16 +253,17 @@ def signed_estimator(model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn sign = 1. d = 0 a = d - l #For a fixed λ, V[LbB] is minimized at a = d − λ. Quiroz 2018c + #TODO: Remove empty lists? + model_args = [args_l for args_l in model_args if len(args_l[0]) != 0] + for args_l in model_args: #Iterate over each of the lambda groups of model args - #for args in model_args: #TODO: Perhaps for index in len(model_args) ? - for args_i in range(len(model_args)): #Now it's doing everything twice - - ll_sub, _ = log_density_obs_hmcecs(model, model_args, {}, z) # log likelihood for subsample with current theta - xi = (jnp.exp(ll_sub - proxy_u_fn(z=z, model_args=model_args, model_kwargs=model_kwargs)) - a) / l + ll_sub, _ = log_density_obs_hmcecs(model, args_l, {}, z) # log likelihood for each u subsample + xi = (jnp.exp(ll_sub - proxy_u_fn(z=z, model_args=args_l, model_kwargs=model_kwargs)) - a) / l sign *= jnp.prod(jnp.sign(xi)) xis += jnp.sum(jnp.abs(xi)) #, axis=0) lhat = proxy_fn(z) + (a + l) / l + xis - ll_prior, _ = log_density_prior_hmcecs(model, model_args, model_kwargs, z) + + ll_prior, _ = log_density_prior_hmcecs(model, model_args[0], model_kwargs, z) #the ll of the prior does not depend on the model args, so we just take some pair neg_ll = - lhat - ll_prior From 1e83bb73f49761d5a786d600c8c20ef070d7a3ba Mon Sep 17 00:00:00 2001 From: Lys Date: Mon, 9 Nov 2020 20:23:43 +0100 Subject: [PATCH 26/93] BlockPoissonRunning --- examples/logistic_hmcecs.py | 8 +- numpyro/contrib/hmcecs.py | 190 +++++++++++++++++++++----------- numpyro/contrib/hmcecs_utils.py | 24 ++-- 3 files changed, 140 insertions(+), 82 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 4f164727c..79dd2f185 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -224,10 +224,10 @@ def Plot_KL(map_method,ecs_algo,algo,proxy,estimator,n_samples,n_warmup,epochs): factor_NUTS = 50 colors = cm.rainbow(np.linspace(0, 1, len(m))) run_test = False - if run_test: - print("Running standard NUTS") - est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor_NUTS], obs=obs[:factor_NUTS], - n_samples=n_samples, warmup=n_warmup, m="all", g=g, algo=algo) + #if run_test: + #print("Running standard NUTS") + est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor_NUTS], obs=obs[:factor_NUTS], + n_samples=n_samples, warmup=n_warmup, m="all", g=g, algo=algo) for m_val, color in zip(m,colors): est_posterior_ECS = infer_hmcecs(rng_key, feats=feats[:factor_ECS], obs=obs[:factor_ECS], n_samples=n_samples, diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 660c54156..79ddb1e49 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -22,8 +22,9 @@ from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model, log_density from numpyro.util import cond, fori_loop, identity import sys -sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/contrib/') +sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/contrib/') #TODO: remove import numpyro.distributions as dist +from itertools import chain from hmcecs_utils import potential_est, init_near_values,tuplemerge,\ model_args_sub,model_kwargs_sub,taylor_proxy,svi_proxy,neural_proxy,log_density_obs_hmcecs,log_density_prior_hmcecs,signed_estimator @@ -124,8 +125,8 @@ def _sample_u_poisson(rng_key, m, l): """ pois_key, sub_key = random.split(rng_key) block_lengths = dist.discrete.Poisson(1).sample(pois_key, (l,)) #lambda block lengths - u = random.randint(sub_key, (jnp.sum(block_lengths), ), 0, m) - + #u = random.randint(sub_key, (jnp.sum(block_lengths), ), 0, m) + u = random.randint(sub_key, (jnp.sum(block_lengths), m), 0, m) return jnp.split(u, jnp.cumsum(block_lengths), axis=0) @partial(jit, static_argnums=(2, 3, 4)) @@ -255,6 +256,7 @@ def init_kernel(init_params, n = None, m = None, u= None, + l=None, rng_key=random.PRNGKey(0), subsample_method=None, estimator=None, @@ -312,8 +314,11 @@ def init_kernel(init_params, else: if subsample_method == "perturb": kwargs = {} if model_kwargs is None else model_kwargs - pe_fn = potential_fn_gen(model=model, model_args=model_args, model_kwargs=kwargs, z=z, n=n, m=m, - proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) + if estimator == "poisson": + pe_fn = potential_fn_gen(model=model, model_args=model_args, model_kwargs=kwargs, z=z, l=l,proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) + else: + pe_fn = potential_fn_gen(model=model, model_args=model_args, model_kwargs=kwargs, z=z, n=n, m=m,proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) + else: kwargs = {} if model_kwargs is None else model_kwargs pe_fn = potential_fn_gen(*model_args, **kwargs) @@ -375,7 +380,8 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, ll_u = None, u = None, n = None, - m = None): + m = None, + l=None): if potential_fn_gen: if grad_potential_fn_gen: kwargs = {} if model_kwargs is None else model_kwargs @@ -384,15 +390,24 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, else: if subsample_method == "perturb": - #pe_fn = potential_fn_gen(model, model_args, model_kwargs,vv_state.z, z_ref, n, m, proxy_fn, proxy_u_fn, u) - pe_fn = potential_fn_gen(model=model, - model_args=model_args, - model_kwargs=model_kwargs, - z=vv_state.z, - n=n, - m=m, - proxy_fn=proxy_fn, - proxy_u_fn=proxy_u_fn) + if estimator == "poisson": + pe_fn = potential_fn_gen(model=model, + model_args=model_args, + model_kwargs=model_kwargs, + z=vv_state.z, + l=l, + proxy_fn=proxy_fn, + proxy_u_fn=proxy_u_fn) + + else: + pe_fn = potential_fn_gen(model=model, + model_args=model_args, + model_kwargs=model_kwargs, + z=vv_state.z, + n=n, + m=m, + proxy_fn=proxy_fn, + proxy_u_fn=proxy_u_fn) kwargs = {} if model_kwargs is None else model_kwargs else: pe_fn = potential_fn_gen(*model_args, **model_kwargs) @@ -423,7 +438,7 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, proxy_fn=None,proxy_u_fn=None, model=None, ll_ref=None,jac_all=None,z = None,z_ref=None,hess_all=None,ll_u=None,u=None, - n=None,m=None): + n=None,m=None,l=None): if potential_fn_gen: nonlocal vv_update if grad_potential_fn_gen: @@ -432,15 +447,26 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, pe_fn = potential_fn_gen(*model_args, **model_kwargs) else: if subsample_method == "perturb": - #pe_fn = potential_fn_gen(model, model_args, model_kwargs, vv_state.z, z_ref, n, m, proxy_fn,proxy_u_fn, u) - pe_fn = potential_fn_gen(model=model, - model_args=model_args, - model_kwargs=model_kwargs, - z=vv_state.z, - n=n, - m=m, - proxy_fn=proxy_fn, - proxy_u_fn=proxy_u_fn) + if estimator == "poisson": + pe_fn = potential_fn_gen(model=model, + model_args=model_args, + model_kwargs=model_kwargs, + z=vv_state.z, + l=l, + proxy_fn=proxy_fn, + proxy_u_fn=proxy_u_fn) + + + + else: + pe_fn = potential_fn_gen(model=model, + model_args=model_args, + model_kwargs=model_kwargs, + z=vv_state.z, + n=n, + m=m, + proxy_fn=proxy_fn, + proxy_u_fn=proxy_u_fn) else: pe_fn = potential_fn_gen(*model_args, **model_kwargs) _, vv_update = velocity_verlet(pe_fn, kinetic_fn) @@ -471,7 +497,7 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, z_ref=None, hess_all=None, ll_u=None, - u=None,n=None,m=None,): + u=None,n=None,m=None,l=None): """ Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed (possibly adapted) step size and return a new :data:`~numpyro.infer.mcmc.HMCState`. @@ -479,6 +505,20 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, :param hmc_state: Current sample (and associated state). :param tuple model_args: Model arguments if `potential_fn_gen` is specified. :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified. + :param subsample_method: Indicates if hmc energy conserving method shall be implemented for subsampling + :param proxy_fn + :param proxy_u_fn + :param model + :param ll_ref + :param jac_all + :param z + :param z_ref + :param hess_all + :param ll_u + :param u + :param n + :param m + :param l : lambda value for block poisson estimator method. Indicates the number of subsamples within a subsample :return: new proposed :data:`~numpyro.infer.mcmc.HMCState` from simulating Hamiltonian dynamics given existing state. @@ -506,7 +546,7 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, proxy_u_fn, model, ll_ref,jac_all,z,z_ref,hess_all,ll_u,u, - n,m) + n,m,l) # not update adapt_state after warmup phase adapt_state = cond(hmc_state.i < wa_steps, (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state), @@ -631,7 +671,6 @@ def __init__(self, self._hess_all = None self._ll_u = None self._u = None - self._neg_ll = None self._sign = None self._l = 100 # Set on first call to init @@ -639,6 +678,7 @@ def __init__(self, self._postprocess_fn = None self._sample_fn = None self._subsample_fn = None + self._sign = [] self.proxy = proxy self.svi_fn = svi_fn self._proxy_fn = None @@ -673,10 +713,11 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): if self.estimator =="poisson": self._l = 25 # lambda subsamples self._u = _sample_u_poisson(rng_key, self.m, self._l) + #TODO: Confirm that the signed estimator is the new potential function---> If so the output has to be fixed self._potential_fn = lambda model,model_args,model_kwargs,z,l, proxy_fn,proxy_u_fn : lambda z:signed_estimator(model = model,model_args=model_args, model_kwargs= model_kwargs,z=z,l=l,proxy_fn=proxy_fn, - proxy_u_fn=proxy_u_fn) + proxy_u_fn=proxy_u_fn)[0] # Initialize the hmc sampler: sample_fn = sample_kernel self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, kinetic_fn=euclidean_kinetic_energy, @@ -686,13 +727,15 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): # Initialize the model parameters rng_key_init_model, rng_key = random.split(rng_key) model_args = [model_args_sub(u_i, model_args) for u_i in self._u] #TODO: The initialization function has to be initialized on a subsample - #Highlight: Initialize with one non empty list? + + #model_args = list(chain(*model_args)) #Highlight: This just chains all the elements in the sublist , len(lists_of_lists) = n , len(chain(list_of_lists)) = sum(n_elements_inside_list=*n + self._init_strategy = partial(init_near_values, values=self.z_ref) init_params, potential_fn, postprocess_fn, model_trace = initialize_model( rng_key_init_model, self._model, init_strategy=self._init_strategy, dynamic_args=True, - model_args=model_args, + model_args=tuple([arg[0] for arg in next(chain(model_args))]), #Pick the first non-empty block model_kwargs=model_kwargs) else: @@ -787,10 +830,15 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params) #should work for all cases + if self._potential_fn and init_params is None: raise ValueError('Valid value of `init_params` must be provided with' ' `potential_fn`.') if self.subsample_method == "perturb": + if self.estimator == "poisson": + init_model_args = [model_args_sub(u_i, model_args) for u_i in self._u] + else: + init_model_args = model_args_sub(self._u,model_args) hmc_init_fn = lambda init_params,rng_key: self._init_fn(init_params=init_params, num_warmup = num_warmup, step_size = self._step_size, @@ -801,9 +849,10 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg trajectory_length=self._trajectory_length, max_tree_depth=self._max_tree_depth, find_heuristic_step_size=self._find_heuristic_step_size, - model_args=model_args_sub(self._u,model_args), + model_args=init_model_args, model_kwargs=model_kwargs, subsample_method= self.subsample_method, + estimator= self.estimator, model=self._model, ll_ref =self._ll_ref, jac_all=self._jac_all, @@ -813,6 +862,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg n=self._n, m=self.m, u = self._u, + l = self._l, proxy_fn = self._proxy_fn, proxy_u_fn = self._proxy_u_fn) @@ -821,10 +871,6 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg rng_key_hmc_init,_ = random.split(rng_key) init_state = hmc_init_fn(init_params, rng_key_hmc_init) #HMCState + HMCECSState - if self.proxy == "taylor": - self._proxy_fn,self._proxy_u_fn = taylor_proxy(self.z_ref, self._model, self._ll_ref, self._jac_all, self._hess_all) - elif self.proxy == "svi": - self._proxy_fn, self._proxy_u_fn = svi_proxy(self.svi_fn, model_args, model_kwargs) if self.estimator == "poisson": #signed pseudo-marginal algorithm with the block-Poisson estimator #use the term signed PM for any pseudo-marginal algorithm that uses the technique in Lyne @@ -832,15 +878,15 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg # posterior and subsequently sign-corrected by importance sampling. Similarly, we call the # algorithm described in this section signed HMC-ECS model_args = [model_args_sub(u_i, model_args)for u_i in self._u] - self._neg_ll, self._sign = signed_estimator(self._model, - model_args, - model_kwargs, - init_state.z, - self._l, - self._proxy_fn, - self._proxy_u_fn) - self._ll_u = self._sign*self._neg_ll #TODO: ??????????? - exit() + neg_ll, sign = signed_estimator(self._model, + model_args, + model_kwargs, + init_state.z, + self._l, + self._proxy_fn, + self._proxy_u_fn) + self._sign.append(sign) + self._ll_u = neg_ll else: self._ll_u = potential_est(model=self._model, @@ -851,27 +897,39 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg m=self.m, proxy_fn=self._proxy_fn, proxy_u_fn=self._proxy_u_fn) - exit() - hmc_init_sub_state = HMCECSState(u=self._u, - hmc_state=init_state.hmc_state, - ll_u=self._ll_u) - init_sub_state = tuplemerge(init_state._asdict(),hmc_init_sub_state._asdict()) + hmc_init_sub_state = HMCECSState(u=self._u, + hmc_state=init_state.hmc_state, + ll_u=self._ll_u) + init_sub_state = tuplemerge(init_state._asdict(),hmc_init_sub_state._asdict()) - return init_sub_state + return init_sub_state else: #TODO: What is this for? It does not go into it for num_chains>1 raise ValueError("Not implemented for chains > 1") # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, # wa_steps because those variables do not depend on traced args: init_params, rng_key. init_state = vmap(hmc_init_fn)(init_params, rng_key) - self._ll_u = potential_est(model=self._model, - model_args=model_args_sub(self._u, model_args), - model_kwargs=model_kwargs, - z=init_state.z, - n=self._n, - m=self.m, - proxy_fn=self._proxy_fn, - proxy_u_fn=self._proxy_u_fn) + if self.estimator == "poisson": + model_args = [model_args_sub(u_i, model_args)for u_i in self._u] + neg_ll, sign = signed_estimator(self._model, + model_args, + model_kwargs, + init_state.z, + self._l, + self._proxy_fn, + self._proxy_u_fn) + self._sign.append(sign) + self._ll_u = neg_ll + + else: + self._ll_u = potential_est(model=self._model, + model_args=model_args_sub(self._u, model_args), + model_kwargs=model_kwargs, + z=init_state.z, + n=self._n, + m=self.m, + proxy_fn=self._proxy_fn, + proxy_u_fn=self._proxy_u_fn) hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state, ll_u=self._ll_u) @@ -932,15 +990,16 @@ def sample(self, state, model_args, model_kwargs): if self.estimator == "poisson": #TODO: What to do here? does the negative likelihood need to be stored? how about the sign? store in the state? u_new = _sample_u_poisson(rng_key, self.m, self._l) - model_args = [model_args_sub(u_i, model_args) for u_i in u_new] neg_ll, sign = signed_estimator(model = self._model, - model_args=model_args, + model_args=[model_args_sub(u_i, model_args) for u_i in u_new], model_kwargs=model_kwargs, z=state.z, l =self._l, proxy_fn = self._proxy_fn, proxy_u_fn = self._proxy_u_fn) - llu_new = neg_ll*sign + self._sign.append(sign) + # Correct the negativeloglikelihood by substracting the density of the prior to calculate the potential + llu_new = jnp.min(jnp.array([0, -neg_ll + state.ll_u])) else: u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) @@ -956,11 +1015,12 @@ def sample(self, state, model_args, model_kwargs): # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. accept_prob = jnp.clip(jnp.exp(-llu_new + state.ll_u), a_max=1.) - transition = random.bernoulli(rng_key_transition, accept_prob) + transition = random.bernoulli(rng_key_transition, accept_prob) #TODO: Why Bernouilli instead of Uniform? u, ll_u = cond(transition, (u_new, llu_new), identity, (state.u, state.ll_u), identity) + ######## UPDATE PARAMETERS ########## hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state,ll_u=ll_u) @@ -970,6 +1030,7 @@ def sample(self, state, model_args, model_kwargs): model_args=model_args, model_kwargs=model_kwargs, subsample_method=self.subsample_method, + estimator =self.estimator, proxy_fn = self._proxy_fn, proxy_u_fn = self._proxy_u_fn, model = self._model, @@ -981,7 +1042,8 @@ def sample(self, state, model_args, model_kwargs): ll_u = ll_u, u= u, n= self._n, - m= self.m) + m= self.m, + l=self._l) diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index d7563653b..321950049 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -27,13 +27,13 @@ def model_args_sub(u, model_args): args.append(arg) return tuple(args) - def model_kwargs_sub(u, kwargs): """Subsample observations and features""" for key_arg, val_arg in kwargs.items(): if key_arg == "observations" or key_arg == "features": kwargs[key_arg] = jnp.take(val_arg, u, axis=0) return kwargs + def log_density_obs_hmcecs(model, model_args, model_kwargs, params): model = substitute(model, data=params) model_trace = trace(model).get_trace(*model_args, **model_kwargs) @@ -56,6 +56,7 @@ def log_density_obs_hmcecs(model, model_args, model_kwargs, params): log_joint = log_joint + jnp.sum(log_prob) return log_joint, model_trace + def log_density_prior_hmcecs(model, model_args, model_kwargs, params): """ (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given @@ -91,7 +92,6 @@ def log_density_prior_hmcecs(model, model_args, model_kwargs, params): log_joint = log_joint + log_prob return log_joint, model_trace - def reducer( accum, d ): accum.update(d) return accum @@ -103,7 +103,6 @@ def tuplemerge( *dictionaries ): return namedtuple('HMCCombinedState', merged )(**merged) # <==== Gist of the gist - def potential_est(model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn): ll_sub, _ = log_density_obs_hmcecs(model, model_args, {}, z) # log likelihood for subsample with current theta @@ -117,8 +116,6 @@ def potential_est(model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn return (-l_hat + .5 * sigma) - ll_prior - - def velocity_verlet_hmcecs(potential_fn, kinetic_fn, grad_potential_fn=None): r""" Second order symplectic integrator that uses the velocity verlet algorithm @@ -218,7 +215,6 @@ def proxy_u(z, model_args, model_kwargs, *args, **kwargs): return proxy, proxy_u - def svi_proxy(svi, model_args, model_kwargs): def proxy(z, *args, **kwargs): z_ref = svi.guide.expectation(z) @@ -235,11 +231,9 @@ def proxy_u(z, model_args, model_kwargs, *args, **kwargs): def neural_proxy(): return None - - def signed_estimator(model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn): """ - Function at minusloglike_estPoisson + Estimate the grdient potential estimate :param model: :param model_args: Subsample of model arguments [l,m,n_feats] :param model_kwargs: @@ -248,25 +242,27 @@ def signed_estimator(model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn :param proxy: :param proxy_u: :return: + neg_ll: Negative likelihood + sign """ xis = 0. sign = 1. d = 0 a = d - l #For a fixed λ, V[LbB] is minimized at a = d − λ. Quiroz 2018c - #TODO: Remove empty lists? model_args = [args_l for args_l in model_args if len(args_l[0]) != 0] for args_l in model_args: #Iterate over each of the lambda groups of model args - + args_l = tuple([arg.reshape(arg.shape[0]*arg.shape[1],-1) for arg in args_l]) #TODO:Not sure is this ok ll_sub, _ = log_density_obs_hmcecs(model, args_l, {}, z) # log likelihood for each u subsample xi = (jnp.exp(ll_sub - proxy_u_fn(z=z, model_args=args_l, model_kwargs=model_kwargs)) - a) / l sign *= jnp.prod(jnp.sign(xi)) xis += jnp.sum(jnp.abs(xi)) #, axis=0) lhat = proxy_fn(z) + (a + l) / l + xis - ll_prior, _ = log_density_prior_hmcecs(model, model_args[0], model_kwargs, z) #the ll of the prior does not depend on the model args, so we just take some pair - + prior_arg = tuple([arg.reshape(arg.shape[0] * arg.shape[1], -1) for arg in model_args[0]]) + ll_prior, _ = log_density_prior_hmcecs(model, prior_arg, model_kwargs, z) #the ll of the prior does not depend on the model args, so we just take some pair + # Correct the negativeloglikelihood by substracting the density of the prior to calculate the potential + #potentialEst = -loglikeEst - dprior(theta,pfamily,priorPar1,priorPar2) neg_ll = - lhat - ll_prior - return neg_ll, sign From 3f918ed0341b94dfc7945dbae9aee6f627f1f40a Mon Sep 17 00:00:00 2001 From: Lys Date: Tue, 10 Nov 2020 16:24:48 +0100 Subject: [PATCH 27/93] FIXED: potential estimator --- examples/logistic_hmcecs.py | 6 +-- numpyro/contrib/hmcecs.py | 83 ++++++++++++++++++--------------- numpyro/contrib/hmcecs_utils.py | 46 +++++++++++------- 3 files changed, 79 insertions(+), 56 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 79dd2f185..6e24efbc0 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -224,9 +224,9 @@ def Plot_KL(map_method,ecs_algo,algo,proxy,estimator,n_samples,n_warmup,epochs): factor_NUTS = 50 colors = cm.rainbow(np.linspace(0, 1, len(m))) run_test = False - #if run_test: - #print("Running standard NUTS") - est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor_NUTS], obs=obs[:factor_NUTS], + if run_test: + print("Running standard NUTS") + est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor_NUTS], obs=obs[:factor_NUTS], n_samples=n_samples, warmup=n_warmup, m="all", g=g, algo=algo) for m_val, color in zip(m,colors): est_posterior_ECS = infer_hmcecs(rng_key, feats=feats[:factor_ECS], obs=obs[:factor_ECS], diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 79ddb1e49..ce501283c 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -30,9 +30,8 @@ HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', 'mean_accept_prob', 'diverging', 'adapt_state','rng_key']) -#HMCECSState = namedtuple("HMCECState",["u","hmc_state","z_ref","ll_ref","jac_all","hess_all","ll_u"]) -HMCECSState = namedtuple("HMCECState",['u', 'hmc_state', 'll_u']) +HMCECSState = namedtuple("HMCECState",['u', 'hmc_state', 'll_u','sign']) """ A :func:`~collections.namedtuple` consisting of the following fields: @@ -248,6 +247,7 @@ def init_kernel(init_params, model_args=(), model_kwargs=None, model = None, + sign = None, ll_ref=None, jac_all=None, z_ref= None, @@ -294,6 +294,17 @@ def init_kernel(init_params, :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified. :param jax.random.PRNGKey rng_key: random key to be used as the source of randomness. + :param model:, + :param sign:, + :param ll_ref:, + :param jac_all, + :param z_ref, + :param hess_all, + :param ll_u , + :param n , + :param m , + :param u, + :param l, """ step_size = lax.convert_element_type(step_size, canonicalize_dtype(jnp.float64)) @@ -350,7 +361,7 @@ def init_kernel(init_params, mass_matrix_size=jnp.size(ravel_pytree(z)[0])) r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum) - #vv_init, vv_update = velocity_verlet_hmcecs(pe_fn, kinetic_fn,grad_potential_fn=gpe_fn) + vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn) vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad) @@ -359,7 +370,8 @@ def init_kernel(init_params, hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, 0, 0., 0., False, wa_state,rng_key_hmc) - hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state,ll_u=ll_u) + + hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state,ll_u=ll_u,sign=sign) hmc_state = tuplemerge(hmc_sub_state._asdict(),hmc_state._asdict()) @@ -372,13 +384,7 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, proxy_fn = None, proxy_u_fn = None, model = None, - ll_ref = None, - jac_all = None, - z = None, - z_ref = None, - hess_all = None, - ll_u = None, - u = None, + ll_ref = None,jac_all = None,z = None,z_ref = None,hess_all = None,ll_u = None,u = None, n = None, m = None, l=None): @@ -497,6 +503,7 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, z_ref=None, hess_all=None, ll_u=None, + sign = None, u=None,n=None,m=None,l=None): """ Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed (possibly adapted) @@ -559,7 +566,7 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n hmcstate = HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps, accept_prob, mean_accept_prob, diverging, adapt_state,rng_key) - hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state,ll_u=ll_u) + hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state,ll_u=ll_u,sign = sign) #TODO: Check if sign is correct hmcstate = tuplemerge(hmc_sub_state._asdict(),hmcstate._asdict()) return hmcstate @@ -714,7 +721,6 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): self._l = 25 # lambda subsamples self._u = _sample_u_poisson(rng_key, self.m, self._l) - #TODO: Confirm that the signed estimator is the new potential function---> If so the output has to be fixed self._potential_fn = lambda model,model_args,model_kwargs,z,l, proxy_fn,proxy_u_fn : lambda z:signed_estimator(model = model,model_args=model_args, model_kwargs= model_kwargs,z=z,l=l,proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn)[0] @@ -726,18 +732,18 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): self._init_strategy = partial(init_near_values, values=self.z_ref) # Initialize the model parameters rng_key_init_model, rng_key = random.split(rng_key) - model_args = [model_args_sub(u_i, model_args) for u_i in self._u] #TODO: The initialization function has to be initialized on a subsample + init_model_args = [model_args_sub(u_i, model_args) for u_i in self._u] - #model_args = list(chain(*model_args)) #Highlight: This just chains all the elements in the sublist , len(lists_of_lists) = n , len(chain(list_of_lists)) = sum(n_elements_inside_list=*n self._init_strategy = partial(init_near_values, values=self.z_ref) init_params, potential_fn, postprocess_fn, model_trace = initialize_model( rng_key_init_model, self._model, init_strategy=self._init_strategy, dynamic_args=True, - model_args=tuple([arg[0] for arg in next(chain(model_args))]), #Pick the first non-empty block + model_args=tuple([arg[0] for arg in next(chain(init_model_args))]), #Highlight:Pick the first non-empty block ; 'chain' joins all the elements in the sublist , len(lists_of_lists) = n , len(chain(list_of_lists)) = sum(n_elements_inside_list=*n) model_kwargs=model_kwargs) + else: self._u = random.randint(rng_key, (self.m,), 0, self._n) # Initialize the potential and gradient potential functions @@ -863,6 +869,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg m=self.m, u = self._u, l = self._l, + sign = self._sign, proxy_fn = self._proxy_fn, proxy_u_fn = self._proxy_u_fn) @@ -877,17 +884,19 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg # et al. (2015) where a pseudo-marginal sampler is run on the absolute value of the estimated # posterior and subsequently sign-corrected by importance sampling. Similarly, we call the # algorithm described in this section signed HMC-ECS - model_args = [model_args_sub(u_i, model_args)for u_i in self._u] - neg_ll, sign = signed_estimator(self._model, - model_args, - model_kwargs, - init_state.z, - self._l, - self._proxy_fn, - self._proxy_u_fn) + #model_args = [model_args_sub(u_i, model_args)for u_i in self._u] + neg_ll, sign = signed_estimator(model = self._model, + model_args = [model_args_sub(u_i, model_args)for u_i in self._u], + model_kwargs= model_kwargs, + z=init_state.z, + l=self._l, + proxy_fn=self._proxy_fn, + proxy_u_fn = self._proxy_u_fn) + self._sign.append(sign) self._ll_u = neg_ll + else: self._ll_u = potential_est(model=self._model, model_args=model_args_sub(self._u, model_args), @@ -899,7 +908,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg proxy_u_fn=self._proxy_u_fn) hmc_init_sub_state = HMCECSState(u=self._u, hmc_state=init_state.hmc_state, - ll_u=self._ll_u) + ll_u=self._ll_u,sign=self._sign) init_sub_state = tuplemerge(init_state._asdict(),hmc_init_sub_state._asdict()) return init_sub_state @@ -910,14 +919,14 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg # wa_steps because those variables do not depend on traced args: init_params, rng_key. init_state = vmap(hmc_init_fn)(init_params, rng_key) if self.estimator == "poisson": - model_args = [model_args_sub(u_i, model_args)for u_i in self._u] - neg_ll, sign = signed_estimator(self._model, - model_args, - model_kwargs, - init_state.z, - self._l, - self._proxy_fn, - self._proxy_u_fn) + #model_args = [model_args_sub(u_i, model_args)for u_i in self._u] + neg_ll, sign = signed_estimator(model=self._model, + model_args=[model_args_sub(u_i, model_args)for u_i in self._u], + model_kwargs= model_kwargs_sub, + z=init_state.z, + l = self._l, + proxy_fn = self._proxy_fn, + proxy_u_fn = self._proxy_u_fn) self._sign.append(sign) self._ll_u = neg_ll @@ -931,7 +940,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg proxy_fn=self._proxy_fn, proxy_u_fn=self._proxy_u_fn) - hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state, ll_u=self._ll_u) + hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state, ll_u=self._ll_u,sign = self._sign) init_subsample_state = vmap(hmc_init_sub_fn)(init_params,rng_key) @@ -1015,15 +1024,15 @@ def sample(self, state, model_args, model_kwargs): # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. accept_prob = jnp.clip(jnp.exp(-llu_new + state.ll_u), a_max=1.) - transition = random.bernoulli(rng_key_transition, accept_prob) #TODO: Why Bernouilli instead of Uniform? + transition = random.bernoulli(rng_key_transition, accept_prob) #TODO: Why Bernoulli instead of Uniform? u, ll_u = cond(transition, (u_new, llu_new), identity, (state.u, state.ll_u), identity) ######## UPDATE PARAMETERS ########## - - hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state,ll_u=ll_u) + print(self._sign) + hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state,ll_u=ll_u,sign=self._sign) hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(),state._asdict()) return self._sample_fn(hmc_subsamplestate, diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index 321950049..e04827b8b 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -104,6 +104,8 @@ def tuplemerge( *dictionaries ): return namedtuple('HMCCombinedState', merged )(**merged) # <==== Gist of the gist def potential_est(model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn): + """Computes the estimation of the likelihood of the potential + :param: proxy_U_fn : Function to calculate the covariates that correct the subsample likelihood""" ll_sub, _ = log_density_obs_hmcecs(model, model_args, {}, z) # log likelihood for subsample with current theta diff = ll_sub - proxy_u_fn(z=z, model_args=model_args, model_kwargs=model_kwargs) @@ -231,37 +233,49 @@ def proxy_u(z, model_args, model_kwargs, *args, **kwargs): def neural_proxy(): return None +def split_list(lst, n): + """Pair up the split model arguments back.""" + for i in range(0, len(lst), n): + if i+n < len(lst)-1: + yield tuple( map(lst.__getitem__, [i,i+n])) + else: + break def signed_estimator(model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn): """ - Estimate the grdient potential estimate - :param model: + Estimate the gradient potential estimate + :param model: Likelihood function :param model_args: Subsample of model arguments [l,m,n_feats] :param model_kwargs: - :param z: + :param z: Model parameters estimates :param l: Lambda number of subsamples (u indexes) - :param proxy: - :param proxy_u: + :param proxy_fn: + :param proxy_u_fn: :return: - neg_ll: Negative likelihood - sign + neg_ll: Negative likelihood estimate of the potential + sign: Sign of the likelihood estimate over the subsamples, it will be used after all the samples are collected """ + import itertools xis = 0. sign = 1. d = 0 a = d - l #For a fixed λ, V[LbB] is minimized at a = d − λ. Quiroz 2018c - model_args = [args_l for args_l in model_args if len(args_l[0]) != 0] + model_args = [args_l for args_l in model_args if len(args_l[0]) != 0] #remove empty lambda blocks for args_l in model_args: #Iterate over each of the lambda groups of model args - args_l = tuple([arg.reshape(arg.shape[0]*arg.shape[1],-1) for arg in args_l]) #TODO:Not sure is this ok - ll_sub, _ = log_density_obs_hmcecs(model, args_l, {}, z) # log likelihood for each u subsample - xi = (jnp.exp(ll_sub - proxy_u_fn(z=z, model_args=args_l, model_kwargs=model_kwargs)) - a) / l - sign *= jnp.prod(jnp.sign(xi)) - xis += jnp.sum(jnp.abs(xi)) #, axis=0) + block_len = args_l[0].shape[0] + args_l = [jnp.split(arg, arg.shape[0]) for arg in args_l] + args_l = list(itertools.chain.from_iterable(args_l)) #Join list of lists + args_l = [arr.squeeze(axis=0) for arr in args_l] + args_l = list(split_list(args_l,block_len)) + for args_l_b in args_l: + ll_sub, _ = log_density_obs_hmcecs(model, args_l_b, {}, z) # log likelihood for each u subsample + xi = (jnp.exp(ll_sub - proxy_u_fn(z=z, model_args=args_l_b, model_kwargs=model_kwargs)) - a) / l + sign *= jnp.prod(jnp.sign(xi)) + xis += jnp.sum(jnp.abs(xi)) #, axis=0) lhat = proxy_fn(z) + (a + l) / l + xis - prior_arg = tuple([arg.reshape(arg.shape[0] * arg.shape[1], -1) for arg in model_args[0]]) + prior_arg = tuple([arg.reshape(arg.shape[0] * arg.shape[1], -1) for arg in model_args[0]])#Join the block subsamples, does not matter because the prior does not look t them ll_prior, _ = log_density_prior_hmcecs(model, prior_arg, model_kwargs, z) #the ll of the prior does not depend on the model args, so we just take some pair - # Correct the negativeloglikelihood by substracting the density of the prior to calculate the potential - #potentialEst = -loglikeEst - dprior(theta,pfamily,priorPar1,priorPar2) + # Correct the negativeloglikelihood by substracting the density of the prior --> potentialEst = -loglikeEst - dprior(theta,pfamily,priorPar1,priorPar2) neg_ll = - lhat - ll_prior return neg_ll, sign From 9fc0f3ff565878958112e71183061293cfa8bbab Mon Sep 17 00:00:00 2001 From: Lys Date: Tue, 10 Nov 2020 20:23:00 +0100 Subject: [PATCH 28/93] FINISHED: Block-poisson missing tests --- examples/logistic_hmcecs.py | 3 +-- numpyro/contrib/hmcecs.py | 47 ++++++++++++++++++++++----------- numpyro/contrib/hmcecs_utils.py | 2 +- 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 6e24efbc0..3b9084542 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -64,7 +64,6 @@ def model(feats, obs): """ n, m = feats.shape theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), 2 * jnp.ones(m))) - numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(feats, theta)), obs=obs) def infer_nuts(rng_key, feats, obs, samples, warmup ): @@ -223,7 +222,7 @@ def Plot_KL(map_method,ecs_algo,algo,proxy,estimator,n_samples,n_warmup,epochs): g = 5 factor_NUTS = 50 colors = cm.rainbow(np.linspace(0, 1, len(m))) - run_test = False + run_test = True if run_test: print("Running standard NUTS") est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor_NUTS], obs=obs[:factor_NUTS], diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index ce501283c..58acaba0c 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -116,8 +116,8 @@ def _update_block(rng_key, u, n, m, g): def _sample_u_poisson(rng_key, m, l): """ Initialize subsamples u ***References*** - 1.Hamiltonian Monte Carlo with Energy Conserving Subsampling - 2.The blockPoisson estimator for optimally tuned exact subsampling MCMC. + 1. Hamiltonian Monte Carlo with Energy Conserving Subsampling + 2. The blockPoisson estimator for optimally tuned exact subsampling MCMC. :param m: subsample size :param l: lambda u blocks :param g: number of blocks @@ -132,8 +132,8 @@ def _sample_u_poisson(rng_key, m, l): def _update_block_poisson(rng_key, u, m, l, g): """ Update block of u, where the length of the block of indexes to update is given by the Poisson distribution. ***References*** - 1.Hamiltonian Monte Carlo with Energy Conserving Subsampling - 2.The blockPoisson estimator for optimally tuned exact subsampling MCMC. + 1. Hamiltonian Monte Carlo with Energy Conserving Subsampling + 2.T he blockPoisson estimator for optimally tuned exact subsampling MCMC. :param rng_key :param u: current subsample indexes :param m: Subsample size @@ -165,6 +165,9 @@ def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, grad_potentia Matthew D. Hoffman, and Andrew Gelman. 3. *A Conceptual Introduction to Hamiltonian Monte Carlo`*, Michael Betancourt + **ECS References*** + 1. Hamiltonian Monte Carlo with Energy Conserving Subsampling + 2. The blockPoisson estimator for optimally tuned exact subsampling MCMC. :param potential_fn: Python callable that computes the potential energy given input parameters. The input parameters to `potential_fn` can be @@ -292,8 +295,7 @@ def init_kernel(init_params, step size at the beginning of each adaptation window. Defaults to False. :param tuple model_args: Model arguments if `potential_fn_gen` is specified. :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified. - :param jax.random.PRNGKey rng_key: random key to be used as the source of - randomness. + :param model:, :param sign:, :param ll_ref:, @@ -305,6 +307,13 @@ def init_kernel(init_params, :param m , :param u, :param l, + :param jax.random.PRNGKey rng_key: random key to be used as the source of + randomness. + :param subsample_method: Allows for activation of HMC-ECS or Subsampling, + :param estimator: Allows between an approximate likelihood estimator of the potential function (default), or an exact + calculation (poisson) + :param proxy_fn: Pre-compiled function that calculates the covariate (likelihood correction) for the parameters given the reference estimate + :param proxy_u_fn: Pre-compiled function that calculates the covariate (likelihood correction) for the paraneters given the subsample (model_args) """ step_size = lax.convert_element_type(step_size, canonicalize_dtype(jnp.float64)) @@ -561,13 +570,25 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, hmc_state.adapt_state, identity) + itr = hmc_state.i + 1 n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps) mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n + hmcstate = HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps, accept_prob, mean_accept_prob, diverging, adapt_state,rng_key) - hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state,ll_u=ll_u,sign = sign) #TODO: Check if sign is correct + + # Highlight: The accepted proposals samples are in vv_state.z /hmcstate.z, as we return them, we change their sign + #TODO: Make this prettier + if subsample_method == "perturb" and estimator == "poisson" and itr > wa_steps: + z_new={} + for x,y in hmcstate.z.items(): + z_new[x] = y*sign[-1] + hmcstate = hmcstate._replace(z=z_new) + + hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state,ll_u=ll_u,sign = sign) hmcstate = tuplemerge(hmc_sub_state._asdict(),hmcstate._asdict()) + return hmcstate # Make `init_kernel` and `sample_kernel` visible from the global scope once @@ -893,7 +914,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg proxy_fn=self._proxy_fn, proxy_u_fn = self._proxy_u_fn) - self._sign.append(sign) + #self._sign.append(sign) #Highlight, do not append the sign here, not necessary self._ll_u = neg_ll @@ -997,7 +1018,6 @@ def sample(self, state, model_args, model_kwargs): rng_key_subsample, rng_key_transition, rng_key_likelihood, rng_key = random.split( state.rng_key, 4) if self.estimator == "poisson": - #TODO: What to do here? does the negative likelihood need to be stored? how about the sign? store in the state? u_new = _sample_u_poisson(rng_key, self.m, self._l) neg_ll, sign = signed_estimator(model = self._model, model_args=[model_args_sub(u_i, model_args) for u_i in u_new], @@ -1031,7 +1051,6 @@ def sample(self, state, model_args, model_kwargs): ######## UPDATE PARAMETERS ########## - print(self._sign) hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state,ll_u=ll_u,sign=self._sign) hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(),state._asdict()) @@ -1052,12 +1071,8 @@ def sample(self, state, model_args, model_kwargs): u= u, n= self._n, m= self.m, - l=self._l) - - - - - + l=self._l, + sign = self._sign) else: return self._sample_fn(state, model_args, model_kwargs) diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index e04827b8b..5e88b10ae 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -262,7 +262,7 @@ def signed_estimator(model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn model_args = [args_l for args_l in model_args if len(args_l[0]) != 0] #remove empty lambda blocks for args_l in model_args: #Iterate over each of the lambda groups of model args block_len = args_l[0].shape[0] - args_l = [jnp.split(arg, arg.shape[0]) for arg in args_l] + args_l = [jnp.split(arg, arg.shape[0]) for arg in args_l] # split the arrays of blocks args_l = list(itertools.chain.from_iterable(args_l)) #Join list of lists args_l = [arr.squeeze(axis=0) for arr in args_l] args_l = list(split_list(args_l,block_len)) From 24e1c1beadf1108b5c7ea3678be38e2686da10b9 Mon Sep 17 00:00:00 2001 From: Lys Date: Thu, 12 Nov 2020 15:06:03 +0100 Subject: [PATCH 29/93] MISSING: Postprocessing --- examples/logistic_hmcecs.py | 26 +++++++++++++++++--------- numpyro/contrib/hmcecs.py | 29 ++++++++++++++++------------- numpyro/contrib/hmcecs_utils.py | 8 ++++++-- numpyro/infer/mcmc.py | 9 +++++++-- 4 files changed, 46 insertions(+), 26 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 3b9084542..ff71799d9 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -4,7 +4,7 @@ import jax.numpy as jnp import numpyro import numpyro.distributions as dist -from numpyro.infer import NUTS, MCMC, Predictive +from numpyro.infer import NUTS, MCMC, Predictive,HMC import sys, os from jax.config import config import datetime,time @@ -16,7 +16,8 @@ sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/contrib/') sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/examples/') -from hmcecs import HMC +from hmcecs import HMCECS +from hmcecs_utils import poisson_samples_correction #from numpyro.contrib.hmcecs import HMC from sklearn.datasets import load_breast_cancer @@ -95,7 +96,7 @@ def infer_hmc(rng_key, feats, obs, samples, warmup ): -def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None,algo="NUTS",subsample_method=None,map_method=None,proxy="taylor",estimator=None,num_epochs=None ): +def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None,algo="NUTS",subsample_method=None,map_method=None,proxy="taylor",estimator=None,num_epochs=None,postprocess_fn=None ): hmcecs_key, map_key = jax.random.split(rng_key) n, _ = feats.shape file_hyperparams = open("PLOTS_{}/Hyperparameters_{}.txt".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms"),now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), "a") @@ -144,10 +145,17 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, svi = None start = time.time() - kernel = HMC(model=model,z_ref=z_ref,m=m,g=g,algo=algo,subsample_method=subsample_method,proxy=proxy,svi_fn=svi,estimator = estimator,target_accept_prob=0.8) - - mcmc = MCMC(kernel,num_warmup=warmup,num_samples=n_samples,num_chains=1) - mcmc.run(rng_key,feats,obs) + extra_fields = [] + if estimator == "poisson": + postprocess_fn = None # poisson_samples_correction + #extra_fields = ("sign",) + kernel = HMCECS(model=model,z_ref=z_ref,m=m,g=g,algo=algo, + subsample_method=subsample_method,proxy=proxy,svi_fn=svi, + estimator = estimator,target_accept_prob=0.8)#,postprocess_fn=postprocess_fn) + + mcmc = MCMC(kernel,num_warmup=warmup,num_samples=n_samples,num_chains=1,postprocess_fn=postprocess_fn) + mcmc.run(rng_key,feats,obs,extra_fields=extra_fields) + #extra_fields = mcmc.get_extra_fields() stop = time.time() file_hyperparams.write('MCMC/NUTS elapsed time {}: {} \n'.format(subsample_method,time.time() - start)) file_hyperparams.write('Effective size {}: {}\n'.format(subsample_method,n_samples)) @@ -158,7 +166,7 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, file_hyperparams.write('Estimator: {}\n'.format(estimator)) file_hyperparams.write('...........................................\n') file_hyperparams.close() - + #print(mcmc.get_samples().keys()) save_obj(mcmc.get_samples(),"{}/MCMC_Dict_Samples_{}_m_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")),subsample_method,m)) return mcmc.get_samples() @@ -222,7 +230,7 @@ def Plot_KL(map_method,ecs_algo,algo,proxy,estimator,n_samples,n_warmup,epochs): g = 5 factor_NUTS = 50 colors = cm.rainbow(np.linspace(0, 1, len(m))) - run_test = True + run_test = False if run_test: print("Running standard NUTS") est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor_NUTS], obs=obs[:factor_NUTS], diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 58acaba0c..5df672aa8 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -579,12 +579,12 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, accept_prob, mean_accept_prob, diverging, adapt_state,rng_key) # Highlight: The accepted proposals samples are in vv_state.z /hmcstate.z, as we return them, we change their sign - #TODO: Make this prettier - if subsample_method == "perturb" and estimator == "poisson" and itr > wa_steps: - z_new={} - for x,y in hmcstate.z.items(): - z_new[x] = y*sign[-1] - hmcstate = hmcstate._replace(z=z_new) + # #TODO: Make this prettier + # if subsample_method == "perturb" and estimator == "poisson" and itr > wa_steps: + # z_new={} + # for x,y in hmcstate.z.items(): + # z_new[x] = y*sign[-1] + # hmcstate = hmcstate._replace(z=z_new) hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state,ll_u=ll_u,sign = sign) hmcstate = tuplemerge(hmc_sub_state._asdict(),hmcstate._asdict()) @@ -605,7 +605,7 @@ def _log_prob(trace): return jnp.sum(node['fn'].log_prob(node['value']), 1) -class HMC(MCMCKernel): +class HMCECS(MCMCKernel): """ Hamiltonian Monte Carlo inference, using fixed trajectory length, with provision for step size and mass matrix adaptation. @@ -669,7 +669,8 @@ def __init__(self, m= None, g = None, z_ref= None, - algo = "HMC" + algo = "HMC", + postprocess_fn = None, ): if not (model is None) ^ (potential_fn is None): raise ValueError('Only one of `model` or `potential_fn` must be specified.') @@ -703,10 +704,10 @@ def __init__(self, self._l = 100 # Set on first call to init self._init_fn = None - self._postprocess_fn = None + self._postprocess_fn = postprocess_fn self._sample_fn = None self._subsample_fn = None - self._sign = [] + self._sign = jnp.array([jnp.nan]) self.proxy = proxy self.svi_fn = svi_fn self._proxy_fn = None @@ -948,7 +949,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg l = self._l, proxy_fn = self._proxy_fn, proxy_u_fn = self._proxy_u_fn) - self._sign.append(sign) + self._sign = jnp.array(sign) self._ll_u = neg_ll else: @@ -1026,7 +1027,9 @@ def sample(self, state, model_args, model_kwargs): l =self._l, proxy_fn = self._proxy_fn, proxy_u_fn = self._proxy_u_fn) - self._sign.append(sign) + + self._sign = jnp.append(self._sign,jnp.array([sign]),axis=0) + self._sign = self._sign[jnp.isfinite(self._sign)] #remove dummy start point, since we annot initialize empty arrays # Correct the negativeloglikelihood by substracting the density of the prior to calculate the potential llu_new = jnp.min(jnp.array([0, -neg_ll + state.ll_u])) @@ -1081,7 +1084,7 @@ def sample(self, state, model_args, model_kwargs): -class NUTS(HMC): +class NUTS(HMCECS): """ Hamiltonian Monte Carlo inference, using the No U-Turn Sampler (NUTS) with adaptive path length and mass matrix adaptation. diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index 5e88b10ae..4c27c3b8e 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -236,7 +236,7 @@ def neural_proxy(): def split_list(lst, n): """Pair up the split model arguments back.""" for i in range(0, len(lst), n): - if i+n < len(lst)-1: + if i+n < len(lst)-1: #TODO: Change back to len(lst), after debugging yield tuple( map(lst.__getitem__, [i,i+n])) else: break @@ -270,7 +270,7 @@ def signed_estimator(model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn ll_sub, _ = log_density_obs_hmcecs(model, args_l_b, {}, z) # log likelihood for each u subsample xi = (jnp.exp(ll_sub - proxy_u_fn(z=z, model_args=args_l_b, model_kwargs=model_kwargs)) - a) / l sign *= jnp.prod(jnp.sign(xi)) - xis += jnp.sum(jnp.abs(xi)) #, axis=0) + xis += jnp.sum(jnp.abs(xi)) lhat = proxy_fn(z) + (a + l) / l + xis prior_arg = tuple([arg.reshape(arg.shape[0] * arg.shape[1], -1) for arg in model_args[0]])#Join the block subsamples, does not matter because the prior does not look t them @@ -281,3 +281,7 @@ def signed_estimator(model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn +def poisson_samples_correction(*args,**kwargs): + "Changes the suport of the samples" + return args + diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 28067e905..2faf1fce6 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -180,6 +180,7 @@ def collect(x): # if any(n == "hmc_state" for n in f): # return attrgetter(*collect_fields)(x[0].hmc_state) # else: + print(attrgetter(*collect_fields)(x[0])) return attrgetter(*collect_fields)(x[0]) else: return x[0] @@ -285,8 +286,7 @@ def _get_cached_fn(self): if self._jit_model_args: fn = partial(_sample_fn_jit_args, sampler=self.sampler) else: - fn = partial(_sample_fn_nojit_args, sampler=self.sampler, - args=self._args, kwargs=self._kwargs) + fn = partial(_sample_fn_nojit_args, sampler=self.sampler,args=self._args, kwargs=self._kwargs) if key is not None: self._cache[key] = fn return fn @@ -316,7 +316,11 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields): init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,) lower_idx = self._collection_params["lower"] upper_idx = self._collection_params["upper"] +<<<<<<< HEAD phase = self._collection_params["phase"] +======= + #TODO: the returned object needs to accomodate the sign +>>>>>>> MISSING: Postprocessing collect_vals = fori_collect(lower_idx, upper_idx, @@ -328,6 +332,7 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields): collection_size=self._collection_params["collection_size"], progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase), diagnostics_fn=diagnostics) + states, last_val = collect_vals # Get first argument of type `HMCState` last_state = last_val[0] From 4243abed25800a3740d67dfb0ebe1f391e3a89e8 Mon Sep 17 00:00:00 2001 From: Lys Date: Fri, 13 Nov 2020 12:04:04 +0100 Subject: [PATCH 30/93] FIXED: sign --- examples/logistic_hmcecs.py | 6 ++++-- numpyro/contrib/hmcecs.py | 8 ++++---- numpyro/infer/mcmc.py | 21 ++++----------------- 3 files changed, 12 insertions(+), 23 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index ff71799d9..4fddb423d 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -148,14 +148,16 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, extra_fields = [] if estimator == "poisson": postprocess_fn = None # poisson_samples_correction - #extra_fields = ("sign",) + extra_fields = ("sign",) kernel = HMCECS(model=model,z_ref=z_ref,m=m,g=g,algo=algo, subsample_method=subsample_method,proxy=proxy,svi_fn=svi, estimator = estimator,target_accept_prob=0.8)#,postprocess_fn=postprocess_fn) mcmc = MCMC(kernel,num_warmup=warmup,num_samples=n_samples,num_chains=1,postprocess_fn=postprocess_fn) mcmc.run(rng_key,feats,obs,extra_fields=extra_fields) - #extra_fields = mcmc.get_extra_fields() + extra_fields = mcmc.get_extra_fields() + print(extra_fields) + print(extra_fields.keys()) stop = time.time() file_hyperparams.write('MCMC/NUTS elapsed time {}: {} \n'.format(subsample_method,time.time() - start)) file_hyperparams.write('Effective size {}: {}\n'.format(subsample_method,n_samples)) diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index 5df672aa8..dfa87ccc8 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -915,7 +915,7 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg proxy_fn=self._proxy_fn, proxy_u_fn = self._proxy_u_fn) - #self._sign.append(sign) #Highlight, do not append the sign here, not necessary + self._sign = jnp.array(sign) #Highlight, do not append the sign here, not necessary self._ll_u = neg_ll @@ -1027,9 +1027,9 @@ def sample(self, state, model_args, model_kwargs): l =self._l, proxy_fn = self._proxy_fn, proxy_u_fn = self._proxy_u_fn) - - self._sign = jnp.append(self._sign,jnp.array([sign]),axis=0) - self._sign = self._sign[jnp.isfinite(self._sign)] #remove dummy start point, since we annot initialize empty arrays + self._sign=jnp.array(sign) + #self._sign = jnp.append(self._sign,jnp.array([sign]),axis=0) + #self._sign = self._sign[jnp.isfinite(self._sign)] #remove dummy start point, since we annot initialize empty arrays # Correct the negativeloglikelihood by substracting the density of the prior to calculate the potential llu_new = jnp.min(jnp.array([0, -neg_ll + state.ll_u])) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 2faf1fce6..8ad86c914 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -168,7 +168,6 @@ def _sample_fn_jit_args(state, sampler): def _sample_fn_nojit_args(state, sampler, args, kwargs): # state is a tuple of size 1 - containing HMCState - return sampler.sample(state[0], args, kwargs), @@ -176,12 +175,7 @@ def _collect_fn(collect_fields): @cached_by(_collect_fn, collect_fields) def collect(x): if collect_fields: - # f = getattr(x[0], '_fields', None) - # if any(n == "hmc_state" for n in f): - # return attrgetter(*collect_fields)(x[0].hmc_state) - # else: - print(attrgetter(*collect_fields)(x[0])) - return attrgetter(*collect_fields)(x[0]) + return attrgetter(*collect_fields)(x[0]) else: return x[0] @@ -286,7 +280,8 @@ def _get_cached_fn(self): if self._jit_model_args: fn = partial(_sample_fn_jit_args, sampler=self.sampler) else: - fn = partial(_sample_fn_nojit_args, sampler=self.sampler,args=self._args, kwargs=self._kwargs) + fn = partial(_sample_fn_nojit_args, sampler=self.sampler, + args=self._args, kwargs=self._kwargs) if key is not None: self._cache[key] = fn return fn @@ -311,16 +306,11 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields): postprocess_fn = self.sampler.postprocess_fn(args, kwargs) else: postprocess_fn = self.postprocess_fn - diagnostics = lambda x: self.sampler.get_diagnostics_str(x[0]) if rng_key.ndim == 1 else '' # noqa: E731 init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,) lower_idx = self._collection_params["lower"] upper_idx = self._collection_params["upper"] -<<<<<<< HEAD phase = self._collection_params["phase"] -======= - #TODO: the returned object needs to accomodate the sign ->>>>>>> MISSING: Postprocessing collect_vals = fori_collect(lower_idx, upper_idx, @@ -332,7 +322,6 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields): collection_size=self._collection_params["collection_size"], progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase), diagnostics_fn=diagnostics) - states, last_val = collect_vals # Get first argument of type `HMCState` last_state = last_val[0] @@ -425,7 +414,6 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): if self.num_chains > 1 and rng_key.ndim == 1: rng_key = random.split(rng_key, self.num_chains) - if self._warmup_state is not None: self._set_collection_params(0, self.num_samples, self.num_samples, "sample") init_state = self._warmup_state._replace(rng_key=rng_key) @@ -505,11 +493,10 @@ def get_extra_fields(self, group_by_chain=False): def print_summary(self, prob=0.9, exclude_deterministic=True): # Exclude deterministic sites by default sites = self._states[self._sample_field] - if isinstance(sites, dict) and exclude_deterministic: sites = {k: v for k, v in self._states[self._sample_field].items() if k in self._last_state.z} print_summary(sites, prob=prob) extra_fields = self.get_extra_fields() if 'diverging' in extra_fields: - print("Number of divergences: {}".format(jnp.sum(extra_fields['diverging']))) + print("Number of divergences: {}".format(jnp.sum(extra_fields['diverging']))) \ No newline at end of file From f996d18704164d496bc9f34ecea0a0ba50b3a8da Mon Sep 17 00:00:00 2001 From: Lys Date: Fri, 13 Nov 2020 16:42:56 +0100 Subject: [PATCH 31/93] More debugging --- examples/logistic_hmcecs.py | 2 -- numpyro/contrib/hmcecs.py | 3 +-- numpyro/contrib/hmcecs_utils.py | 3 ++- numpyro/infer/mcmc.py | 6 ++++++ 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 4fddb423d..190483371 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -156,8 +156,6 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, mcmc = MCMC(kernel,num_warmup=warmup,num_samples=n_samples,num_chains=1,postprocess_fn=postprocess_fn) mcmc.run(rng_key,feats,obs,extra_fields=extra_fields) extra_fields = mcmc.get_extra_fields() - print(extra_fields) - print(extra_fields.keys()) stop = time.time() file_hyperparams.write('MCMC/NUTS elapsed time {}: {} \n'.format(subsample_method,time.time() - start)) file_hyperparams.write('Effective size {}: {}\n'.format(subsample_method,n_samples)) diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index dfa87ccc8..a7ca10d39 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -835,7 +835,7 @@ def sample_field(self): @property def default_fields(self): - return ('z', 'diverging') + return ('z', 'diverging','sign') def get_diagnostics_str(self, state): return '{} steps of size {:.2e}. acc. prob={:.2f}'.format(state.num_steps, @@ -906,7 +906,6 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg # et al. (2015) where a pseudo-marginal sampler is run on the absolute value of the estimated # posterior and subsequently sign-corrected by importance sampling. Similarly, we call the # algorithm described in this section signed HMC-ECS - #model_args = [model_args_sub(u_i, model_args)for u_i in self._u] neg_ll, sign = signed_estimator(model = self._model, model_args = [model_args_sub(u_i, model_args)for u_i in self._u], model_kwargs= model_kwargs, diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index 4c27c3b8e..b74cdfeaa 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -282,6 +282,7 @@ def signed_estimator(model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn def poisson_samples_correction(*args,**kwargs): - "Changes the suport of the samples" + "Changes the support of the samples by using the sign from the " + return args diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 8ad86c914..8fd343144 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -328,14 +328,20 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields): if len(collect_fields) == 1: states = (states,) states = dict(zip(collect_fields, states)) + #print(states) # Apply constraints if number of samples is non-zero + #print(self._sample_field) site_values = tree_flatten(states[self._sample_field])[0] + #print(site_values) # XXX: lax.map still works if some arrays have 0 size # so we only need to filter out the case site_value.shape[0] == 0 # (which happens when lower_idx==upper_idx) + print(self._sample_field) + #print(states[self._sample_field]) if len(site_values) > 0 and jnp.shape(site_values[0])[0] > 0: if self.chain_method == "vectorized" and self.num_chains > 1: postprocess_fn = vmap(postprocess_fn) + print(states[self._sample_field]) states[self._sample_field] = lax.map(postprocess_fn, states[self._sample_field]) return states, last_state From 1095f1932711952569c7df004f434ed5ccd52578 Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 9 Dec 2020 09:22:25 +0100 Subject: [PATCH 32/93] Fixed style. --- examples/autoguide_hmcecs.py | 13 +- numpyro/contrib/hmcecs.py | 386 ++++++++++++++++++----------------- 2 files changed, 205 insertions(+), 194 deletions(-) diff --git a/examples/autoguide_hmcecs.py b/examples/autoguide_hmcecs.py index cf3805ae8..9de321a6a 100644 --- a/examples/autoguide_hmcecs.py +++ b/examples/autoguide_hmcecs.py @@ -43,6 +43,7 @@ 'AutoDelta' ] + class ReinitGuide(ABC): @abstractmethod def init_params(self): @@ -51,6 +52,8 @@ def init_params(self): @abstractmethod def find_params(self, rng_keys, *args, **kwargs): raise NotImplementedError + + class AutoGuide(ABC): """ Base class for automatic guides. @@ -103,7 +106,7 @@ def _sample_latent(self, *args, **kwargs): def _setup_prototype(self, *args, **kwargs): # run the model so we can inspect its structure rng_key = random.PRNGKey(0) - #rng_key = numpyro.rng_key("_{}_rng_key_setup".format(self.prefix)) + # rng_key = numpyro.rng_key("_{}_rng_key_setup".format(self.prefix)) model = handlers.seed(self.model, rng_key) self.prototype_trace = handlers.block(handlers.trace(model).get_trace)(*args, **kwargs) self._args = args @@ -161,7 +164,7 @@ def __init__(self, model, prefix="auto", init_strategy=init_to_uniform): def _setup_prototype(self, *args, **kwargs): rng_key = random.PRNGKey(0) - #rng_key = numpyro.rng_key("_{}_rng_key_setup".format(self.prefix)) + # rng_key = numpyro.rng_key("_{}_rng_key_setup".format(self.prefix)) with handlers.block(): init_params, _, self._postprocess_fn, self.prototype_trace = initialize_model( rng_key, self.model, @@ -418,7 +421,7 @@ def get_base_dist(self): def get_transform(self, params): loc = params['{}_loc'.format(self.prefix)] scale_tril = params['{}_scale_tril'.format(self.prefix)] - return LowerCholeskyAffine(loc, scale_tril) #TODO: Changed MultivariateAffineTransform to LowerCholeskyAffine + return LowerCholeskyAffine(loc, scale_tril) # TODO: Changed MultivariateAffineTransform to LowerCholeskyAffine def get_posterior(self, params): """ @@ -729,6 +732,6 @@ def find_params(self, rng_keys, *args, **kwargs): def _setup_prototype(self, *args, **kwargs): super(AutoDelta, self)._setup_prototype(*args, **kwargs) - #rng_key = numpyro.rng_key("_{}_rng_key_init".format(self.prefix)) + # rng_key = numpyro.rng_key("_{}_rng_key_init".format(self.prefix)) rng_key = random.PRNGKey(1) - self.find_params(rng_key, *args, **kwargs) \ No newline at end of file + self.find_params(rng_key, *args, **kwargs) diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index a7ca10d39..4cbcb1f15 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -1,15 +1,20 @@ """Contributed code for HMC and NUTS energy conserving sampling adapted from """ -from collections import namedtuple import math import os import warnings +from collections import namedtuple +from itertools import chain -from jax import device_put, lax, partial, random, vmap,jacfwd, hessian,jit,ops +import jax.numpy as jnp +from jax import device_put, lax, partial, random, vmap, jacfwd, hessian, jit, ops from jax.dtypes import canonicalize_dtype from jax.flatten_util import ravel_pytree -import jax.numpy as jnp +import numpyro.distributions as dist +from numpyro.contrib.hmcecs_utils import potential_est, init_near_values, tuplemerge, \ + model_args_sub, model_kwargs_sub, taylor_proxy, svi_proxy, log_density_obs_hmcecs, \ + signed_estimator from numpyro.infer.hmc_util import ( IntegratorState, build_tree, @@ -19,19 +24,13 @@ warmup_adapter ) from numpyro.infer.mcmc import MCMCKernel -from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model, log_density +from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model from numpyro.util import cond, fori_loop, identity -import sys -sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/contrib/') #TODO: remove -import numpyro.distributions as dist -from itertools import chain -from hmcecs_utils import potential_est, init_near_values,tuplemerge,\ - model_args_sub,model_kwargs_sub,taylor_proxy,svi_proxy,neural_proxy,log_density_obs_hmcecs,log_density_prior_hmcecs,signed_estimator HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', - 'mean_accept_prob', 'diverging', 'adapt_state','rng_key']) + 'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key']) -HMCECSState = namedtuple("HMCECState",['u', 'hmc_state', 'll_u','sign']) +HMCECSState = namedtuple("HMCECState", ['u', 'hmc_state', 'll_u', 'sign']) """ A :func:`~collections.namedtuple` consisting of the following fields: @@ -89,6 +88,7 @@ def momentum_generator(prototype_r, mass_matrix_sqrt, rng_key): else: raise ValueError("Mass matrix has incorrect number of dims.") + @partial(jit, static_argnums=(2, 3, 4)) def _update_block(rng_key, u, n, m, g): """Returns indexes of the new subsample. The update mechanism selects blocks of indices within the subsample to be updated. @@ -100,19 +100,22 @@ def _update_block(rng_key, u, n, m, g): :param g block size: subsample subdivision""" if (g > m) or (g < 1): - raise ValueError('Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(g,m)) + raise ValueError( + 'Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(g, m)) rng_key_block, rng_key_index = random.split(rng_key) # uniformly choose block to update - chosen_block = random.randint(rng_key, shape=(), minval= 0, maxval=g + 1) - idxs_new = random.randint(rng_key_index, shape=(m // g,), minval=0, maxval=n) #choose block within the subsample to update - u_new = jnp.zeros(m, jnp.dtype(u)) #empty array with size m + chosen_block = random.randint(rng_key, shape=(), minval=0, maxval=g + 1) + idxs_new = random.randint(rng_key_index, shape=(m // g,), minval=0, + maxval=n) # choose block within the subsample to update + u_new = jnp.zeros(m, jnp.dtype(u)) # empty array with size m for i in range(m): - #if index in the subsample // g = chosen block : pick new indexes from the subsample size - #else not update: keep the same indexes + # if index in the subsample // g = chosen block : pick new indexes from the subsample size + # else not update: keep the same indexes u_new = ops.index_add(u_new, i, lax.cond(i // g == chosen_block, i, lambda _: idxs_new[i % (m // g)], i, lambda _: u[i])) return u_new + def _sample_u_poisson(rng_key, m, l): """ Initialize subsamples u ***References*** @@ -123,11 +126,12 @@ def _sample_u_poisson(rng_key, m, l): :param g: number of blocks """ pois_key, sub_key = random.split(rng_key) - block_lengths = dist.discrete.Poisson(1).sample(pois_key, (l,)) #lambda block lengths - #u = random.randint(sub_key, (jnp.sum(block_lengths), ), 0, m) + block_lengths = dist.discrete.Poisson(1).sample(pois_key, (l,)) # lambda block lengths + # u = random.randint(sub_key, (jnp.sum(block_lengths), ), 0, m) u = random.randint(sub_key, (jnp.sum(block_lengths), m), 0, m) return jnp.split(u, jnp.cumsum(block_lengths), axis=0) + @partial(jit, static_argnums=(2, 3, 4)) def _update_block_poisson(rng_key, u, m, l, g): """ Update block of u, where the length of the block of indexes to update is given by the Poisson distribution. @@ -141,10 +145,11 @@ def _update_block_poisson(rng_key, u, m, l, g): :param g: Block size within subsample """ if (g > m) or (g < 1): - raise ValueError('Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(g,m)) + raise ValueError( + 'Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(g, m)) u = u.copy() block_key, sample_key = random.split(rng_key) - num_updates = int(round(l / g, 0)) # choose lambda/g number of blocks to update + num_updates = int(round(l / g, 0)) # choose lambda/g number of blocks to update chosen_blocks = random.randint(block_key, (num_updates,), 0, l) new_blocks = _sample_u_poisson(sample_key, m, num_updates) for i, block in enumerate(chosen_blocks): @@ -152,7 +157,7 @@ def _update_block_poisson(rng_key, u, m, l, g): return u -def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, grad_potential_fn_gen=None,algo='NUTS'): +def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, grad_potential_fn_gen=None, algo='NUTS'): r""" Hamiltonian Monte Carlo inference, using either fixed number of steps or the No U-Turn Sampler (NUTS) with adaptive path length. @@ -244,27 +249,27 @@ def init_kernel(init_params, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, - trajectory_length=2*math.pi, + trajectory_length=2 * math.pi, max_tree_depth=10, find_heuristic_step_size=False, model_args=(), model_kwargs=None, - model = None, - sign = None, + model=None, + sign=None, ll_ref=None, jac_all=None, - z_ref= None, + z_ref=None, hess_all=None, - ll_u = None, - n = None, - m = None, - u= None, + ll_u=None, + n=None, + m=None, + u=None, l=None, rng_key=random.PRNGKey(0), subsample_method=None, estimator=None, proxy_fn=None, - proxy_u_fn = None): + proxy_u_fn=None): """ Initializes the HMC sampler. @@ -335,9 +340,11 @@ def init_kernel(init_params, if subsample_method == "perturb": kwargs = {} if model_kwargs is None else model_kwargs if estimator == "poisson": - pe_fn = potential_fn_gen(model=model, model_args=model_args, model_kwargs=kwargs, z=z, l=l,proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) + pe_fn = potential_fn_gen(model=model, model_args=model_args, model_kwargs=kwargs, z=z, l=l, + proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) else: - pe_fn = potential_fn_gen(model=model, model_args=model_args, model_kwargs=kwargs, z=z, n=n, m=m,proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) + pe_fn = potential_fn_gen(model=model, model_args=model_args, model_kwargs=kwargs, z=z, n=n, m=m, + proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) else: kwargs = {} if model_kwargs is None else model_kwargs @@ -378,24 +385,23 @@ def init_kernel(init_params, energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, - 0, 0., 0., False, wa_state,rng_key_hmc) - - hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state,ll_u=ll_u,sign=sign) + 0, 0., 0., False, wa_state, rng_key_hmc) - hmc_state = tuplemerge(hmc_sub_state._asdict(),hmc_state._asdict()) + hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state, ll_u=ll_u, sign=sign) + hmc_state = tuplemerge(hmc_sub_state._asdict(), hmc_state._asdict()) return device_put(hmc_state) def _hmc_next(step_size, inverse_mass_matrix, vv_state, - model_args, model_kwargs, rng_key,subsample_method, + model_args, model_kwargs, rng_key, subsample_method, estimator=None, - proxy_fn = None, - proxy_u_fn = None, - model = None, - ll_ref = None,jac_all = None,z = None,z_ref = None,hess_all = None,ll_u = None,u = None, - n = None, - m = None, + proxy_fn=None, + proxy_u_fn=None, + model=None, + ll_ref=None, jac_all=None, z=None, z_ref=None, hess_all=None, ll_u=None, u=None, + n=None, + m=None, l=None): if potential_fn_gen: if grad_potential_fn_gen: @@ -448,18 +454,18 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, return vv_state, energy, num_steps, accept_prob, diverging def _nuts_next(step_size, inverse_mass_matrix, vv_state, - model_args, model_kwargs, rng_key,subsample_method, + model_args, model_kwargs, rng_key, subsample_method, estimator=None, - proxy_fn=None,proxy_u_fn=None, + proxy_fn=None, proxy_u_fn=None, model=None, - ll_ref=None,jac_all=None,z = None,z_ref=None,hess_all=None,ll_u=None,u=None, - n=None,m=None,l=None): + ll_ref=None, jac_all=None, z=None, z_ref=None, hess_all=None, ll_u=None, u=None, + n=None, m=None, l=None): if potential_fn_gen: nonlocal vv_update if grad_potential_fn_gen: - kwargs = {} if model_kwargs is None else model_kwargs - gpe_fn = grad_potential_fn_gen(*model_args, **kwargs, ) - pe_fn = potential_fn_gen(*model_args, **model_kwargs) + kwargs = {} if model_kwargs is None else model_kwargs + gpe_fn = grad_potential_fn_gen(*model_args, **kwargs, ) + pe_fn = potential_fn_gen(*model_args, **model_kwargs) else: if subsample_method == "perturb": if estimator == "poisson": @@ -500,9 +506,9 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, _next = _nuts_next if algo == 'NUTS' else _hmc_next - def sample_kernel(hmc_state,model_args=(),model_kwargs=None, + def sample_kernel(hmc_state, model_args=(), model_kwargs=None, subsample_method=None, - estimator = None, + estimator=None, proxy_fn=None, proxy_u_fn=None, model=None, @@ -512,8 +518,8 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, z_ref=None, hess_all=None, ll_u=None, - sign = None, - u=None,n=None,m=None,l=None): + sign=None, + u=None, n=None, m=None, l=None): """ Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed (possibly adapted) step size and return a new :data:`~numpyro.infer.mcmc.HMCState`. @@ -541,11 +547,11 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, """ model_kwargs = {} if model_kwargs is None else model_kwargs - if subsample_method =="perturb": + if subsample_method == "perturb": if estimator == "poisson": - model_args = [model_args_sub(u_i, model_args) for u_i in u] #here u = poisson_u + model_args = [model_args_sub(u_i, model_args) for u_i in u] # here u = poisson_u else: - model_args = model_args_sub(u,model_args) + model_args = model_args_sub(u, model_args) rng_key, rng_key_momentum, rng_key_transition = random.split(hmc_state.rng_key, 3) r = momentum_generator(hmc_state.z, hmc_state.adapt_state.mass_matrix_sqrt, rng_key_momentum) vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad) @@ -561,8 +567,8 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, proxy_fn, proxy_u_fn, model, - ll_ref,jac_all,z,z_ref,hess_all,ll_u,u, - n,m,l) + ll_ref, jac_all, z, z_ref, hess_all, ll_u, u, + n, m, l) # not update adapt_state after warmup phase adapt_state = cond(hmc_state.i < wa_steps, (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state), @@ -570,13 +576,12 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, hmc_state.adapt_state, identity) - itr = hmc_state.i + 1 n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps) mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n hmcstate = HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps, - accept_prob, mean_accept_prob, diverging, adapt_state,rng_key) + accept_prob, mean_accept_prob, diverging, adapt_state, rng_key) # Highlight: The accepted proposals samples are in vv_state.z /hmcstate.z, as we return them, we change their sign # #TODO: Make this prettier @@ -586,8 +591,8 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, # z_new[x] = y*sign[-1] # hmcstate = hmcstate._replace(z=z_new) - hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state,ll_u=ll_u,sign = sign) - hmcstate = tuplemerge(hmc_sub_state._asdict(),hmcstate._asdict()) + hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state, ll_u=ll_u, sign=sign) + hmcstate = tuplemerge(hmc_sub_state._asdict(), hmcstate._asdict()) return hmcstate @@ -599,6 +604,7 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, return init_kernel, sample_kernel + def _log_prob(trace): """ Compute probability of each observation """ node = trace['observations'] @@ -649,10 +655,11 @@ class HMCECS(MCMCKernel): :param z_ref MAP estimate of the parameters :param covariate_fn Proxy function to calculate the covariates for the likelihood correction """ + def __init__(self, model=None, potential_fn=None, - grad_potential = None, + grad_potential=None, kinetic_fn=None, step_size=1.0, adapt_step_size=True, @@ -662,15 +669,15 @@ def __init__(self, trajectory_length=2 * math.pi, init_strategy=init_to_uniform, find_heuristic_step_size=False, - subsample_method = None, + subsample_method=None, estimator=None, # poisson or not proxy="taylor", svi_fn=None, - m= None, - g = None, - z_ref= None, - algo = "HMC", - postprocess_fn = None, + m=None, + g=None, + z_ref=None, + algo="HMC", + postprocess_fn=None, ): if not (model is None) ^ (potential_fn is None): raise ValueError('Only one of `model` or `potential_fn` must be specified.') @@ -689,7 +696,7 @@ def __init__(self, self._max_tree_depth = 10 self._init_strategy = init_strategy self._find_heuristic_step_size = find_heuristic_step_size - #HMCECS parameters + # HMCECS parameters self.subsample_method = subsample_method self.m = m if m is not None else 4 self.g = g if g is not None else 2 @@ -715,7 +722,7 @@ def __init__(self, self._signed_estimator_fn = None self.estimator = estimator - def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ref): + def _init_subsample_state(self, rng_key, model_args, model_kwargs, init_params, z_ref): "Compute the jacobian, hessian and log likelihood for all the data. Used with taylor expansion proxy" rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key = random.split(rng_key, 5) @@ -724,28 +731,29 @@ def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ hess_all, _ = ravel_pytree(hessian(ld_fn)(z_ref)) k, = self._jac_all.shape self._hess_all = hess_all.reshape((k, k)) - ld_fn = lambda args: partial(log_density_obs_hmcecs,self._model,model_args,model_kwargs)(args)[0] + ld_fn = lambda args: partial(log_density_obs_hmcecs, self._model, model_args, model_kwargs)(args)[0] self._ll_ref = ld_fn(z_ref) - def _init_state(self, rng_key, model_args, model_kwargs, init_params): if self.subsample_method is not None: assert self.z_ref is not None, "Please provide a (i.e map) estimate for the parameters" self._n = model_args[0].shape[0] # Choose the covariate calculation method if self.proxy == "svi": - self._proxy_fn,self._proxy_u_fn = svi_proxy(self.svi_fn,model_args,model_kwargs) + self._proxy_fn, self._proxy_u_fn = svi_proxy(self.svi_fn, model_args, model_kwargs) elif self.proxy == "taylor": warnings.warn("Using default second order Taylor expansion, change by using the proxy flag to {svi}") self._init_subsample_state(rng_key, model_args, model_kwargs, init_params, self.z_ref) - self._proxy_fn,self._proxy_u_fn = taylor_proxy(self.z_ref, self._model, self._ll_ref, self._jac_all, self._hess_all) - if self.estimator =="poisson": - self._l = 25 # lambda subsamples + self._proxy_fn, self._proxy_u_fn = taylor_proxy(self.z_ref, self._model, self._ll_ref, self._jac_all, + self._hess_all) + if self.estimator == "poisson": + self._l = 25 # lambda subsamples self._u = _sample_u_poisson(rng_key, self.m, self._l) - self._potential_fn = lambda model,model_args,model_kwargs,z,l, proxy_fn,proxy_u_fn : lambda z:signed_estimator(model = model,model_args=model_args, - model_kwargs= model_kwargs,z=z,l=l,proxy_fn=proxy_fn, - proxy_u_fn=proxy_u_fn)[0] + self._potential_fn = lambda model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn: lambda z: \ + signed_estimator(model=model, model_args=model_args, + model_kwargs=model_kwargs, z=z, l=l, proxy_fn=proxy_fn, + proxy_u_fn=proxy_u_fn)[0] # Initialize the hmc sampler: sample_fn = sample_kernel self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, kinetic_fn=euclidean_kinetic_energy, @@ -762,21 +770,24 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): self._model, init_strategy=self._init_strategy, dynamic_args=True, - model_args=tuple([arg[0] for arg in next(chain(init_model_args))]), #Highlight:Pick the first non-empty block ; 'chain' joins all the elements in the sublist , len(lists_of_lists) = n , len(chain(list_of_lists)) = sum(n_elements_inside_list=*n) + model_args=tuple([arg[0] for arg in next(chain(init_model_args))]), + # Highlight:Pick the first non-empty block ; 'chain' joins all the elements in the sublist , len(lists_of_lists) = n , len(chain(list_of_lists)) = sum(n_elements_inside_list=*n) model_kwargs=model_kwargs) else: self._u = random.randint(rng_key, (self.m,), 0, self._n) # Initialize the potential and gradient potential functions - self._potential_fn = lambda model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn : lambda z:potential_est(model=model, - model_args=model_args, model_kwargs=model_kwargs, z=z, n=n, m=m, proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) + self._potential_fn = lambda model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn: lambda \ + z: potential_est(model=model, + model_args=model_args, model_kwargs=model_kwargs, z=z, n=n, m=m, + proxy_fn=proxy_fn, + proxy_u_fn=proxy_u_fn) # Initialize the hmc sampler: sample_fn = sample_kernel self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, - kinetic_fn=euclidean_kinetic_energy, - algo=self._algo) - + kinetic_fn=euclidean_kinetic_energy, + algo=self._algo) self._init_strategy = partial(init_near_values, values=self.z_ref) # Initialize the model parameters @@ -791,12 +802,12 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): model_kwargs=model_kwargs) if (self.g > self.m) or (self.g < 1): - raise ValueError( - 'Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(self.g, - self.m)) + raise ValueError( + 'Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(self.g, + self.m)) elif (self.m > self._n): - raise ValueError( - 'Subsample size (m) = {} needs to = or < than data size (n) {}'.format(self.m, self._n)) + raise ValueError( + 'Subsample size (m) = {} needs to = or < than data size (n) {}'.format(self.m, self._n)) else: if self._model is not None: @@ -824,7 +835,6 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): return init_params - @property def model(self): return self._model @@ -835,14 +845,14 @@ def sample_field(self): @property def default_fields(self): - return ('z', 'diverging','sign') + return ('z', 'diverging', 'sign') def get_diagnostics_str(self, state): return '{} steps of size {:.2e}. acc. prob={:.2f}'.format(state.num_steps, state.adapt_state.step_size, state.mean_accept_prob) - def _block_indices(self,size, num_blocks): + def _block_indices(self, size, num_blocks): a = jnp.repeat(jnp.arange(num_blocks - 1), size // num_blocks) b = jnp.repeat(num_blocks - 1, size - len(jnp.repeat(jnp.arange(num_blocks - 1), size // num_blocks))) return jnp.hstack((a, b)) @@ -856,8 +866,8 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg else: rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1) - - init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params) #should work for all cases + init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, + init_params) # should work for all cases if self._potential_fn and init_params is None: raise ValueError('Valid value of `init_params` must be provided with' @@ -866,55 +876,55 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg if self.estimator == "poisson": init_model_args = [model_args_sub(u_i, model_args) for u_i in self._u] else: - init_model_args = model_args_sub(self._u,model_args) - hmc_init_fn = lambda init_params,rng_key: self._init_fn(init_params=init_params, - num_warmup = num_warmup, - step_size = self._step_size, - adapt_step_size = self._adapt_step_size, - adapt_mass_matrix = self._adapt_mass_matrix, - dense_mass = self._dense_mass, - target_accept_prob = self._target_accept_prob, - trajectory_length=self._trajectory_length, - max_tree_depth=self._max_tree_depth, - find_heuristic_step_size=self._find_heuristic_step_size, - model_args=init_model_args, - model_kwargs=model_kwargs, - subsample_method= self.subsample_method, - estimator= self.estimator, - model=self._model, - ll_ref =self._ll_ref, - jac_all=self._jac_all, - z_ref=self.z_ref, - hess_all = self._hess_all, - ll_u = self._ll_u, - n=self._n, - m=self.m, - u = self._u, - l = self._l, - sign = self._sign, - proxy_fn = self._proxy_fn, - proxy_u_fn = self._proxy_u_fn) - - if rng_key.ndim ==1: - #rng_key_hmc_init = jnp.array([1000966916, 171341646]) - rng_key_hmc_init,_ = random.split(rng_key) - - init_state = hmc_init_fn(init_params, rng_key_hmc_init) #HMCState + HMCECSState + init_model_args = model_args_sub(self._u, model_args) + hmc_init_fn = lambda init_params, rng_key: self._init_fn(init_params=init_params, + num_warmup=num_warmup, + step_size=self._step_size, + adapt_step_size=self._adapt_step_size, + adapt_mass_matrix=self._adapt_mass_matrix, + dense_mass=self._dense_mass, + target_accept_prob=self._target_accept_prob, + trajectory_length=self._trajectory_length, + max_tree_depth=self._max_tree_depth, + find_heuristic_step_size=self._find_heuristic_step_size, + model_args=init_model_args, + model_kwargs=model_kwargs, + subsample_method=self.subsample_method, + estimator=self.estimator, + model=self._model, + ll_ref=self._ll_ref, + jac_all=self._jac_all, + z_ref=self.z_ref, + hess_all=self._hess_all, + ll_u=self._ll_u, + n=self._n, + m=self.m, + u=self._u, + l=self._l, + sign=self._sign, + proxy_fn=self._proxy_fn, + proxy_u_fn=self._proxy_u_fn) + + if rng_key.ndim == 1: + # rng_key_hmc_init = jnp.array([1000966916, 171341646]) + rng_key_hmc_init, _ = random.split(rng_key) + + init_state = hmc_init_fn(init_params, rng_key_hmc_init) # HMCState + HMCECSState if self.estimator == "poisson": - #signed pseudo-marginal algorithm with the block-Poisson estimator - #use the term signed PM for any pseudo-marginal algorithm that uses the technique in Lyne + # signed pseudo-marginal algorithm with the block-Poisson estimator + # use the term signed PM for any pseudo-marginal algorithm that uses the technique in Lyne # et al. (2015) where a pseudo-marginal sampler is run on the absolute value of the estimated # posterior and subsequently sign-corrected by importance sampling. Similarly, we call the # algorithm described in this section signed HMC-ECS - neg_ll, sign = signed_estimator(model = self._model, - model_args = [model_args_sub(u_i, model_args)for u_i in self._u], - model_kwargs= model_kwargs, + neg_ll, sign = signed_estimator(model=self._model, + model_args=[model_args_sub(u_i, model_args) for u_i in self._u], + model_kwargs=model_kwargs, z=init_state.z, l=self._l, proxy_fn=self._proxy_fn, - proxy_u_fn = self._proxy_u_fn) + proxy_u_fn=self._proxy_u_fn) - self._sign = jnp.array(sign) #Highlight, do not append the sign here, not necessary + self._sign = jnp.array(sign) # Highlight, do not append the sign here, not necessary self._ll_u = neg_ll @@ -927,27 +937,27 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg m=self.m, proxy_fn=self._proxy_fn, proxy_u_fn=self._proxy_u_fn) - hmc_init_sub_state = HMCECSState(u=self._u, - hmc_state=init_state.hmc_state, - ll_u=self._ll_u,sign=self._sign) - init_sub_state = tuplemerge(init_state._asdict(),hmc_init_sub_state._asdict()) + hmc_init_sub_state = HMCECSState(u=self._u, + hmc_state=init_state.hmc_state, + ll_u=self._ll_u, sign=self._sign) + init_sub_state = tuplemerge(init_state._asdict(), hmc_init_sub_state._asdict()) return init_sub_state - else: #TODO: What is this for? It does not go into it for num_chains>1 + else: # TODO: What is this for? It does not go into it for num_chains>1 raise ValueError("Not implemented for chains > 1") # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, # wa_steps because those variables do not depend on traced args: init_params, rng_key. init_state = vmap(hmc_init_fn)(init_params, rng_key) if self.estimator == "poisson": - #model_args = [model_args_sub(u_i, model_args)for u_i in self._u] + # model_args = [model_args_sub(u_i, model_args)for u_i in self._u] neg_ll, sign = signed_estimator(model=self._model, - model_args=[model_args_sub(u_i, model_args)for u_i in self._u], - model_kwargs= model_kwargs_sub, + model_args=[model_args_sub(u_i, model_args) for u_i in self._u], + model_kwargs=model_kwargs_sub, z=init_state.z, - l = self._l, - proxy_fn = self._proxy_fn, - proxy_u_fn = self._proxy_u_fn) + l=self._l, + proxy_fn=self._proxy_fn, + proxy_u_fn=self._proxy_u_fn) self._sign = jnp.array(sign) self._ll_u = neg_ll @@ -961,12 +971,13 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg proxy_fn=self._proxy_fn, proxy_u_fn=self._proxy_u_fn) - hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state, ll_u=self._ll_u,sign = self._sign) + hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state, + ll_u=self._ll_u, sign=self._sign) - init_subsample_state = vmap(hmc_init_sub_fn)(init_params,rng_key) + init_subsample_state = vmap(hmc_init_sub_fn)(init_params, rng_key) sample_fn = vmap(self._sample_fn, in_axes=(0, None, None)) - HMCCombinedState = tuplemerge(init_state._asdict,init_subsample_state._asdict()) + HMCCombinedState = tuplemerge(init_state._asdict, init_subsample_state._asdict()) self._sample_fn = sample_fn return HMCCombinedState @@ -1019,16 +1030,16 @@ def sample(self, state, model_args, model_kwargs): state.rng_key, 4) if self.estimator == "poisson": u_new = _sample_u_poisson(rng_key, self.m, self._l) - neg_ll, sign = signed_estimator(model = self._model, + neg_ll, sign = signed_estimator(model=self._model, model_args=[model_args_sub(u_i, model_args) for u_i in u_new], model_kwargs=model_kwargs, z=state.z, - l =self._l, - proxy_fn = self._proxy_fn, - proxy_u_fn = self._proxy_u_fn) - self._sign=jnp.array(sign) - #self._sign = jnp.append(self._sign,jnp.array([sign]),axis=0) - #self._sign = self._sign[jnp.isfinite(self._sign)] #remove dummy start point, since we annot initialize empty arrays + l=self._l, + proxy_fn=self._proxy_fn, + proxy_u_fn=self._proxy_u_fn) + self._sign = jnp.array(sign) + # self._sign = jnp.append(self._sign,jnp.array([sign]),axis=0) + # self._sign = self._sign[jnp.isfinite(self._sign)] #remove dummy start point, since we annot initialize empty arrays # Correct the negativeloglikelihood by substracting the density of the prior to calculate the potential llu_new = jnp.min(jnp.array([0, -neg_ll + state.ll_u])) @@ -1036,53 +1047,49 @@ def sample(self, state, model_args, model_kwargs): u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) # estimate likelihood of subsample with single block updated llu_new = self._potential_fn(model=self._model, - model_args=model_args_sub(u_new,model_args), - model_kwargs=model_kwargs, - z=state.z, - n=self._n, - m=self.m, - proxy_fn=self._proxy_fn, - proxy_u_fn=self._proxy_u_fn) + model_args=model_args_sub(u_new, model_args), + model_kwargs=model_kwargs, + z=state.z, + n=self._n, + m=self.m, + proxy_fn=self._proxy_fn, + proxy_u_fn=self._proxy_u_fn) # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. accept_prob = jnp.clip(jnp.exp(-llu_new + state.ll_u), a_max=1.) - transition = random.bernoulli(rng_key_transition, accept_prob) #TODO: Why Bernoulli instead of Uniform? + transition = random.bernoulli(rng_key_transition, accept_prob) # TODO: Why Bernoulli instead of Uniform? u, ll_u = cond(transition, (u_new, llu_new), identity, (state.u, state.ll_u), identity) - ######## UPDATE PARAMETERS ########## - hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state,ll_u=ll_u,sign=self._sign) - hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(),state._asdict()) + hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state, ll_u=ll_u, sign=self._sign) + hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(), state._asdict()) return self._sample_fn(hmc_subsamplestate, model_args=model_args, model_kwargs=model_kwargs, subsample_method=self.subsample_method, - estimator =self.estimator, - proxy_fn = self._proxy_fn, - proxy_u_fn = self._proxy_u_fn, - model = self._model, - ll_ref = self._ll_ref, - jac_all =self._jac_all, - z= state.z, - z_ref = self.z_ref, - hess_all = self._hess_all, - ll_u = ll_u, - u= u, - n= self._n, - m= self.m, + estimator=self.estimator, + proxy_fn=self._proxy_fn, + proxy_u_fn=self._proxy_u_fn, + model=self._model, + ll_ref=self._ll_ref, + jac_all=self._jac_all, + z=state.z, + z_ref=self.z_ref, + hess_all=self._hess_all, + ll_u=ll_u, + u=u, + n=self._n, + m=self.m, l=self._l, - sign = self._sign) + sign=self._sign) else: return self._sample_fn(state, model_args, model_kwargs) - - - class NUTS(HMCECS): """ Hamiltonian Monte Carlo inference, using the No U-Turn Sampler (NUTS) @@ -1127,6 +1134,7 @@ class NUTS(HMCECS): :param bool find_heuristic_step_size: whether to a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False. """ + def __init__(self, model=None, potential_fn=None, From 0c180266763edeafd8fff77c27db7d970ce7582e Mon Sep 17 00:00:00 2001 From: OlaRonning Date: Mon, 14 Dec 2020 21:43:33 +0100 Subject: [PATCH 33/93] HMCECS working, fixed problems with SVI MAP and factored code. --- examples/hmcecs/covtype.py | 73 +++ examples/hmcecs/higgs.py | 138 ++++++ .../contrib}/autoguide_hmcecs.py | 0 numpyro/contrib/hmcecs.py | 454 ++++++++++-------- numpyro/contrib/hmcecs_utils.py | 72 +-- 5 files changed, 504 insertions(+), 233 deletions(-) create mode 100644 examples/hmcecs/covtype.py create mode 100644 examples/hmcecs/higgs.py rename {examples => numpyro/contrib}/autoguide_hmcecs.py (100%) diff --git a/examples/hmcecs/covtype.py b/examples/hmcecs/covtype.py new file mode 100644 index 000000000..fdac66a04 --- /dev/null +++ b/examples/hmcecs/covtype.py @@ -0,0 +1,73 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import time + +from jax import random +import jax.numpy as jnp + +import numpyro +import numpyro.distributions as dist +from numpyro.examples.datasets import COVTYPE, load_dataset +from numpyro.infer import MCMC, NUTS + + +def _load_dataset(): + _, fetch = load_dataset(COVTYPE, shuffle=False) + features, labels = fetch() + + # normalize features and add intercept + features = (features - features.mean(0)) / features.std(0) + features = jnp.hstack([features, jnp.ones((features.shape[0], 1))]) + + # make binary feature + _, counts = jnp.unique(labels, return_counts=True) + specific_category = jnp.argmax(counts) + labels = (labels == specific_category) + + N, dim = features.shape + print("Data shape:", features.shape) + print("Label distribution: {} has label 1, {} has label 0" + .format(labels.sum(), N - labels.sum())) + return features, labels + + +def model(data, labels): + dim = data.shape[1] + coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim))) + logits = jnp.dot(data, coefs) + return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels) + + +def benchmark_hmc(args, features, labels): + step_size = jnp.sqrt(0.5 / features.shape[0]) + trajectory_length = step_size * args.num_steps + rng_key = random.PRNGKey(1) + start = time.time() + kernel = NUTS(model, trajectory_length=trajectory_length) + mcmc = MCMC(kernel, 0, args.num_samples) + mcmc.run(rng_key, features, labels) + mcmc.print_summary() + print('\nMCMC elapsed time:', time.time() - start) + + +def main(args): + features, labels = _load_dataset() + benchmark_hmc(args, features, labels) + + +if __name__ == '__main__': + assert numpyro.__version__.startswith('0.4.1') + parser = argparse.ArgumentParser(description="parse args") + parser.add_argument('-n', '--num-samples', default=100, type=int, help='number of samples') + parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")') + parser.add_argument('--num-chains', nargs='?', default=1, type=int) + parser.add_argument('--algo', default='NUTS', type=str, help='whether to run "HMC" or "NUTS"') + parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".') + args = parser.parse_args() + + numpyro.set_platform(args.device) + numpyro.set_host_device_count(args.num_chains) + + main(args) diff --git a/examples/hmcecs/higgs.py b/examples/hmcecs/higgs.py new file mode 100644 index 000000000..527d3a82d --- /dev/null +++ b/examples/hmcecs/higgs.py @@ -0,0 +1,138 @@ +import jax +import jax.numpy as jnp +import numpy as np +from jax import jit, lax +from sklearn.datasets import load_breast_cancer + +import numpyro +import numpyro.distributions as dist +from examples.logistic_hmcecs_svi import svi_map +from numpyro import optim +from numpyro.contrib.autoguide_hmcecs import AutoDiagonalNormal +from numpyro.contrib.hmcecs import HMCECS +from numpyro.infer import NUTS, MCMC +from numpyro.infer.elbo import ELBO +from numpyro.infer.svi import SVI +from numpyro.util import fori_loop + +numpyro.set_platform("cpu") + + +def load_dataset(observations, features, batch_size=None, shuffle=True): + num_records = observations.shape[0] + idxs = jnp.arange(num_records) + if not batch_size: + batch_size = num_records + + def init(): + return num_records // batch_size, np.random.permutation(idxs) if shuffle else idxs + + def get_batch(i=0, idxs=idxs): + ret_idx = lax.dynamic_slice_in_dim(idxs, i * batch_size, batch_size) + batch_obs = jnp.take(observations, ret_idx, axis=0) + batch_feats = jnp.take(features, ret_idx, axis=0) + return batch_obs, batch_feats + + return init, get_batch + + +def svi_map(model, rng_key, feats, obs, num_epochs, batch_size): + @jit + def epoch_train(svi_state): + def body_fn(i, val): + batch_obs, batch_feats = train_fetch(i, train_idx) + loss_sum, svi_state = val + svi_state, loss = svi.update(svi_state, batch_feats, batch_obs) + loss_sum += loss + return loss_sum, svi_state + + return fori_loop(0, num_train, body_fn, (0., svi_state)) + + n, _ = feats.shape + guide = AutoDiagonalNormal(model) + svi = SVI(model, guide, optim.Adam(0.0003), loss=ELBO()) + svi_state = svi.init(rng_key, feats, obs) + train_init, train_fetch = load_dataset(obs, feats, batch_size=batch_size) + + for i in range(num_epochs): + num_train, train_idx = train_init() + train_loss, svi_state = epoch_train(svi_state) + return svi.get_params(svi_state), svi, svi_state + + +def breast_cancer_data(): + """ Logistic regression model as implemetned in https://arxiv.org/pdf/1708.00955.pdf with Higgs Dataset """ + dataset = load_breast_cancer() + feats = dataset.data + feats = (feats - feats.mean(0)) / feats.std(0) + feats = jnp.hstack((feats, jnp.ones((feats.shape[0], 1)))) + + return feats, dataset.target + + +def model(feats, obs): + """ Logistic regression model """ + n, m = feats.shape + theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), 2 * jnp.ones(m))) + numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(feats, theta)), obs=obs) + + +def infer_hmcecs(rng_key, feats, obs, m=None, g=None, n_samples=None, warmup=None, algo="NUTS", subsample_method=None, + map_method=None, proxy="taylor", estimator=None, num_epochs=None, postprocess_fn=None): + hmcecs_key, map_key = jax.random.split(rng_key) + n, _ = feats.shape + + if map_method == "SVI": + factor_SVI = obs.shape[0] + batch_size = 32 + map_key, post_key = jax.random.split(map_key) + z_ref, svi, svi_state = svi_map(model, map_key, feats=feats[:factor_SVI], obs=obs[:factor_SVI], + num_epochs=num_epochs, batch_size=batch_size) + z_ref = svi.guide.sample_posterior(post_key, svi.get_params(svi_state), (100,)) + z_ref = {name: value.mean(0) for name, value in z_ref.items()} + else: + svi = None + map_samples = 10 + map_warmup = 5 + if map_method == "NUTS": + kernel = NUTS(model=model, target_accept_prob=0.8) + if map_method == 'HMC': + kernel = NUTS(model=model, target_accept_prob=0.8) + mcmc = MCMC(kernel, num_warmup=map_warmup, num_samples=map_samples) + mcmc.run(rng_key, feats, obs) + samples = mcmc.get_samples() + z_ref = {key: value.mean(0) for key, value in samples.items()} + + extra_fields = [] + if estimator == "poisson": + postprocess_fn = None + extra_fields = ("sign",) + + kernel = HMCECS(model=model, z_ref=z_ref, m=m, g=g, algo=algo, subsample_method=subsample_method, proxy=proxy, + svi_fn=svi, estimator=estimator, target_accept_prob=0.8) + + mcmc = MCMC(kernel, num_warmup=warmup, num_samples=n_samples, num_chains=1, postprocess_fn=postprocess_fn) + mcmc.run(rng_key, feats, obs, extra_fields=extra_fields) + + return mcmc.get_samples() + + +if __name__ == '__main__': + num_samples = 10 + num_warmup = 5 + ecs_algo = 'NUTS' + ecs_proxy = 'taylor' + estimator = 'perturb' + map_init = 'SVI' + epochs = 1000 + rng_key = jax.random.PRNGKey(37) + + feats, obs = breast_cancer_data() + + n, = obs.shape + m = int(jnp.sqrt(n)) + g = 5 + + infer_hmcecs(rng_key, feats=feats, obs=obs, n_samples=num_samples, + warmup=num_warmup, m=m, g=g, algo=ecs_algo, subsample_method="perturb", + proxy=ecs_proxy, estimator=estimator, map_method=map_init, num_epochs=epochs) diff --git a/examples/autoguide_hmcecs.py b/numpyro/contrib/autoguide_hmcecs.py similarity index 100% rename from examples/autoguide_hmcecs.py rename to numpyro/contrib/autoguide_hmcecs.py diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py index a7ca10d39..6b3e25183 100644 --- a/numpyro/contrib/hmcecs.py +++ b/numpyro/contrib/hmcecs.py @@ -1,15 +1,19 @@ """Contributed code for HMC and NUTS energy conserving sampling adapted from """ -from collections import namedtuple import math import os import warnings +from collections import namedtuple +from itertools import chain -from jax import device_put, lax, partial, random, vmap,jacfwd, hessian,jit,ops +import jax.numpy as jnp +from jax import device_put, lax, partial, random, vmap, jacfwd, hessian, jit, ops from jax.dtypes import canonicalize_dtype from jax.flatten_util import ravel_pytree -import jax.numpy as jnp +import numpyro.distributions as dist +from numpyro.contrib.hmcecs_utils import potential_est, init_near_values, tuplemerge, \ + model_args_sub, model_kwargs_sub, taylor_proxy, svi_proxy, log_density_obs_hmcecs, signed_estimator from numpyro.infer.hmc_util import ( IntegratorState, build_tree, @@ -19,19 +23,13 @@ warmup_adapter ) from numpyro.infer.mcmc import MCMCKernel -from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model, log_density +from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model from numpyro.util import cond, fori_loop, identity -import sys -sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/contrib/') #TODO: remove -import numpyro.distributions as dist -from itertools import chain -from hmcecs_utils import potential_est, init_near_values,tuplemerge,\ - model_args_sub,model_kwargs_sub,taylor_proxy,svi_proxy,neural_proxy,log_density_obs_hmcecs,log_density_prior_hmcecs,signed_estimator HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', - 'mean_accept_prob', 'diverging', 'adapt_state','rng_key']) + 'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key']) -HMCECSState = namedtuple("HMCECState",['u', 'hmc_state', 'll_u','sign']) +HMCECSState = namedtuple("HMCECState", ['u', 'hmc_state', 'll_u', 'sign', 'z_and_sign']) """ A :func:`~collections.namedtuple` consisting of the following fields: @@ -89,6 +87,7 @@ def momentum_generator(prototype_r, mass_matrix_sqrt, rng_key): else: raise ValueError("Mass matrix has incorrect number of dims.") + @partial(jit, static_argnums=(2, 3, 4)) def _update_block(rng_key, u, n, m, g): """Returns indexes of the new subsample. The update mechanism selects blocks of indices within the subsample to be updated. @@ -100,19 +99,24 @@ def _update_block(rng_key, u, n, m, g): :param g block size: subsample subdivision""" if (g > m) or (g < 1): - raise ValueError('Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(g,m)) + raise ValueError( + 'Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(g, m)) rng_key_block, rng_key_index = random.split(rng_key) # uniformly choose block to update - chosen_block = random.randint(rng_key, shape=(), minval= 0, maxval=g + 1) - idxs_new = random.randint(rng_key_index, shape=(m // g,), minval=0, maxval=n) #choose block within the subsample to update - u_new = jnp.zeros(m, jnp.dtype(u)) #empty array with size m + chosen_block = random.randint(rng_key, shape=(), minval=0, maxval=g + 1) + idxs_new = random.randint(rng_key_index, shape=(m // g,), minval=0, + maxval=n) # choose block within the subsample to update + u_new = jnp.zeros(m, jnp.dtype(u)) # empty array with size m for i in range(m): - #if index in the subsample // g = chosen block : pick new indexes from the subsample size - #else not update: keep the same indexes + # if index in the subsample // g = chosen block : pick new indexes from the subsample size + # else not update: keep the same indexes u_new = ops.index_add(u_new, i, lax.cond(i // g == chosen_block, i, lambda _: idxs_new[i % (m // g)], i, lambda _: u[i])) return u_new + +# @partial(jit, static_argnums=(0,1,2)) +# @functools.partial(jit, static_argnums=(2)) def _sample_u_poisson(rng_key, m, l): """ Initialize subsamples u ***References*** @@ -123,12 +127,29 @@ def _sample_u_poisson(rng_key, m, l): :param g: number of blocks """ pois_key, sub_key = random.split(rng_key) - block_lengths = dist.discrete.Poisson(1).sample(pois_key, (l,)) #lambda block lengths - #u = random.randint(sub_key, (jnp.sum(block_lengths), ), 0, m) - u = random.randint(sub_key, (jnp.sum(block_lengths), m), 0, m) + block_lengths = dist.discrete.Poisson(1).sample(pois_key, (l,)) # lambda block lengths + # u = random.randint(sub_key, (jnp.sum(block_lengths), m), 0, m) + # @partial(mask, in_shapes=['(_,)'], out_shape='(_, _)') + # def u_rand(block_lenghts): + # b = jnp.sum(block_lengths).astype(int) + # #return jit(random.randint, static_argnums=(0,1, 2,3))(sub_key, (b,m), 0, m) + # return random.randint(sub_key, (b,m), 0, m) + # u = u_rand([block_lengths],{})#dict(b=jnp.sum(block_lengths).astype(int),m=m,l=l)) + # print(u.shape) + b = jnp.sum(block_lengths) + u_random = jit(random.randint, static_argnums=(0, 1, 2, 3)) + u = u_random(sub_key, (b, m), 0, m) + # @partial(mask,in_shapes=['(tmp,)'],out_shape='(b,)') + # def u_rand(block_lengths): + # return jnp.zeros(jnp.sum(block_lengths)) + # u = u_rand([block_lengths],dict(tmp=l,b=jnp.sum(block_lengths))) + # print(u.shape) + # exit() + return jnp.split(u, jnp.cumsum(block_lengths), axis=0) -@partial(jit, static_argnums=(2, 3, 4)) + +@partial(jit, static_argnums=(2, 3, 4, 5)) def _update_block_poisson(rng_key, u, m, l, g): """ Update block of u, where the length of the block of indexes to update is given by the Poisson distribution. ***References*** @@ -141,18 +162,20 @@ def _update_block_poisson(rng_key, u, m, l, g): :param g: Block size within subsample """ if (g > m) or (g < 1): - raise ValueError('Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(g,m)) + raise ValueError( + 'Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(g, m)) u = u.copy() block_key, sample_key = random.split(rng_key) - num_updates = int(round(l / g, 0)) # choose lambda/g number of blocks to update + num_updates = int(round(l / g, 0)) # choose lambda/g number of blocks to update chosen_blocks = random.randint(block_key, (num_updates,), 0, l) - new_blocks = _sample_u_poisson(sample_key, m, num_updates) + _sample_u_poisson_jit = jit(_sample_u_poisson, static_argnums=(2)) + new_blocks = _sample_u_poisson_jit(sample_key, m, num_updates) for i, block in enumerate(chosen_blocks): u[block] = new_blocks[i] return u -def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, grad_potential_fn_gen=None,algo='NUTS'): +def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, grad_potential_fn_gen=None, algo='NUTS'): r""" Hamiltonian Monte Carlo inference, using either fixed number of steps or the No U-Turn Sampler (NUTS) with adaptive path length. @@ -244,27 +267,28 @@ def init_kernel(init_params, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, - trajectory_length=2*math.pi, + trajectory_length=2 * math.pi, max_tree_depth=10, find_heuristic_step_size=False, model_args=(), model_kwargs=None, - model = None, - sign = None, + model=None, + sign=None, + sign_sum=None, ll_ref=None, jac_all=None, - z_ref= None, + z_ref=None, hess_all=None, - ll_u = None, - n = None, - m = None, - u= None, + ll_u=None, + n=None, + m=None, + u=None, l=None, rng_key=random.PRNGKey(0), subsample_method=None, estimator=None, proxy_fn=None, - proxy_u_fn = None): + proxy_u_fn=None): """ Initializes the HMC sampler. @@ -335,9 +359,11 @@ def init_kernel(init_params, if subsample_method == "perturb": kwargs = {} if model_kwargs is None else model_kwargs if estimator == "poisson": - pe_fn = potential_fn_gen(model=model, model_args=model_args, model_kwargs=kwargs, z=z, l=l,proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) + pe_fn = potential_fn_gen(model=model, model_args=model_args, model_kwargs=kwargs, z=z, l=l, + proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) else: - pe_fn = potential_fn_gen(model=model, model_args=model_args, model_kwargs=kwargs, z=z, n=n, m=m,proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) + pe_fn = potential_fn_gen(model=model, model_args=model_args, model_kwargs=kwargs, z=z, n=n, m=m, + proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) else: kwargs = {} if model_kwargs is None else model_kwargs @@ -378,25 +404,26 @@ def init_kernel(init_params, energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, - 0, 0., 0., False, wa_state,rng_key_hmc) + 0, 0., 0., False, wa_state, rng_key_hmc) - hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state,ll_u=ll_u,sign=sign) + z_and_sign = {**vv_state.z, 'sign': sign, "sign_sum": sign_sum} - hmc_state = tuplemerge(hmc_sub_state._asdict(),hmc_state._asdict()) + hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state, ll_u=ll_u, sign=sign, z_and_sign=z_and_sign) + hmc_state = tuplemerge(hmc_sub_state._asdict(), hmc_state._asdict()) return device_put(hmc_state) def _hmc_next(step_size, inverse_mass_matrix, vv_state, - model_args, model_kwargs, rng_key,subsample_method, + model_args, model_kwargs, rng_key, subsample_method, estimator=None, - proxy_fn = None, - proxy_u_fn = None, - model = None, - ll_ref = None,jac_all = None,z = None,z_ref = None,hess_all = None,ll_u = None,u = None, - n = None, - m = None, - l=None): + proxy_fn=None, + proxy_u_fn=None, + model=None, + ll_ref=None, jac_all=None, z=None, z_ref=None, hess_all=None, ll_u=None, u=None, + n=None, + m=None, + l=None, ): if potential_fn_gen: if grad_potential_fn_gen: kwargs = {} if model_kwargs is None else model_kwargs @@ -448,18 +475,18 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, return vv_state, energy, num_steps, accept_prob, diverging def _nuts_next(step_size, inverse_mass_matrix, vv_state, - model_args, model_kwargs, rng_key,subsample_method, + model_args, model_kwargs, rng_key, subsample_method, estimator=None, - proxy_fn=None,proxy_u_fn=None, + proxy_fn=None, proxy_u_fn=None, model=None, - ll_ref=None,jac_all=None,z = None,z_ref=None,hess_all=None,ll_u=None,u=None, - n=None,m=None,l=None): + ll_ref=None, jac_all=None, z=None, z_ref=None, hess_all=None, ll_u=None, u=None, + n=None, m=None, l=None): if potential_fn_gen: nonlocal vv_update if grad_potential_fn_gen: - kwargs = {} if model_kwargs is None else model_kwargs - gpe_fn = grad_potential_fn_gen(*model_args, **kwargs, ) - pe_fn = potential_fn_gen(*model_args, **model_kwargs) + kwargs = {} if model_kwargs is None else model_kwargs + gpe_fn = grad_potential_fn_gen(*model_args, **kwargs, ) + pe_fn = potential_fn_gen(*model_args, **model_kwargs) else: if subsample_method == "perturb": if estimator == "poisson": @@ -500,9 +527,9 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, _next = _nuts_next if algo == 'NUTS' else _hmc_next - def sample_kernel(hmc_state,model_args=(),model_kwargs=None, + def sample_kernel(hmc_state, model_args=(), model_kwargs=None, subsample_method=None, - estimator = None, + estimator=None, proxy_fn=None, proxy_u_fn=None, model=None, @@ -512,8 +539,9 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, z_ref=None, hess_all=None, ll_u=None, - sign = None, - u=None,n=None,m=None,l=None): + sign=None, + u=None, n=None, m=None, l=None, + sign_sum=None): """ Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed (possibly adapted) step size and return a new :data:`~numpyro.infer.mcmc.HMCState`. @@ -541,11 +569,11 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, """ model_kwargs = {} if model_kwargs is None else model_kwargs - if subsample_method =="perturb": + if subsample_method == "perturb": if estimator == "poisson": - model_args = [model_args_sub(u_i, model_args) for u_i in u] #here u = poisson_u + model_args = [model_args_sub(u_i, model_args) for u_i in u] # here u = poisson_u else: - model_args = model_args_sub(u,model_args) + model_args = model_args_sub(u, model_args) rng_key, rng_key_momentum, rng_key_transition = random.split(hmc_state.rng_key, 3) r = momentum_generator(hmc_state.z, hmc_state.adapt_state.mass_matrix_sqrt, rng_key_momentum) vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad) @@ -561,8 +589,8 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, proxy_fn, proxy_u_fn, model, - ll_ref,jac_all,z,z_ref,hess_all,ll_u,u, - n,m,l) + ll_ref, jac_all, z, z_ref, hess_all, ll_u, u, + n, m, l) # not update adapt_state after warmup phase adapt_state = cond(hmc_state.i < wa_steps, (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state), @@ -570,24 +598,19 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, hmc_state.adapt_state, identity) - itr = hmc_state.i + 1 n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps) mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n hmcstate = HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps, - accept_prob, mean_accept_prob, diverging, adapt_state,rng_key) + accept_prob, mean_accept_prob, diverging, adapt_state, rng_key) - # Highlight: The accepted proposals samples are in vv_state.z /hmcstate.z, as we return them, we change their sign - # #TODO: Make this prettier - # if subsample_method == "perturb" and estimator == "poisson" and itr > wa_steps: - # z_new={} - # for x,y in hmcstate.z.items(): - # z_new[x] = y*sign[-1] - # hmcstate = hmcstate._replace(z=z_new) + # Highlight: The accepted proposals samples are in vv_state.z /hmcstate.z, we store them together with the sign - hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state,ll_u=ll_u,sign = sign) - hmcstate = tuplemerge(hmc_sub_state._asdict(),hmcstate._asdict()) + sign_sum = cond(hmc_state.i < wa_steps, sign_sum, lambda sign_sum: float(0), sign_sum, identity) + z_and_sign = {**vv_state.z, 'sign': sign, "sign_sum": sign_sum} + hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state, ll_u=ll_u, sign=sign, z_and_sign=z_and_sign) + hmcstate = tuplemerge(hmc_sub_state._asdict(), hmcstate._asdict()) return hmcstate @@ -599,6 +622,7 @@ def sample_kernel(hmc_state,model_args=(),model_kwargs=None, return init_kernel, sample_kernel + def _log_prob(trace): """ Compute probability of each observation """ node = trace['observations'] @@ -649,10 +673,11 @@ class HMCECS(MCMCKernel): :param z_ref MAP estimate of the parameters :param covariate_fn Proxy function to calculate the covariates for the likelihood correction """ + def __init__(self, model=None, potential_fn=None, - grad_potential = None, + grad_potential=None, kinetic_fn=None, step_size=1.0, adapt_step_size=True, @@ -662,15 +687,15 @@ def __init__(self, trajectory_length=2 * math.pi, init_strategy=init_to_uniform, find_heuristic_step_size=False, - subsample_method = None, + subsample_method=None, estimator=None, # poisson or not proxy="taylor", svi_fn=None, - m= None, - g = None, - z_ref= None, - algo = "HMC", - postprocess_fn = None, + m=None, + g=None, + z_ref=None, + algo="HMC", + postprocess_fn=None, ): if not (model is None) ^ (potential_fn is None): raise ValueError('Only one of `model` or `potential_fn` must be specified.') @@ -689,7 +714,7 @@ def __init__(self, self._max_tree_depth = 10 self._init_strategy = init_strategy self._find_heuristic_step_size = find_heuristic_step_size - #HMCECS parameters + # HMCECS parameters self.subsample_method = subsample_method self.m = m if m is not None else 4 self.g = g if g is not None else 2 @@ -701,13 +726,14 @@ def __init__(self, self._ll_u = None self._u = None self._sign = None + self._sign_sum = float(0) self._l = 100 # Set on first call to init self._init_fn = None self._postprocess_fn = postprocess_fn self._sample_fn = None self._subsample_fn = None - self._sign = jnp.array([jnp.nan]) + self._sign = float(0) self.proxy = proxy self.svi_fn = svi_fn self._proxy_fn = None @@ -715,7 +741,7 @@ def __init__(self, self._signed_estimator_fn = None self.estimator = estimator - def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ref): + def _init_subsample_state(self, rng_key, model_args, model_kwargs, init_params, z_ref): "Compute the jacobian, hessian and log likelihood for all the data. Used with taylor expansion proxy" rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key = random.split(rng_key, 5) @@ -724,28 +750,30 @@ def _init_subsample_state(self,rng_key, model_args, model_kwargs, init_params,z_ hess_all, _ = ravel_pytree(hessian(ld_fn)(z_ref)) k, = self._jac_all.shape self._hess_all = hess_all.reshape((k, k)) - ld_fn = lambda args: partial(log_density_obs_hmcecs,self._model,model_args,model_kwargs)(args)[0] + ld_fn = lambda args: partial(log_density_obs_hmcecs, self._model, model_args, model_kwargs)(args)[0] self._ll_ref = ld_fn(z_ref) - def _init_state(self, rng_key, model_args, model_kwargs, init_params): if self.subsample_method is not None: assert self.z_ref is not None, "Please provide a (i.e map) estimate for the parameters" self._n = model_args[0].shape[0] # Choose the covariate calculation method if self.proxy == "svi": - self._proxy_fn,self._proxy_u_fn = svi_proxy(self.svi_fn,model_args,model_kwargs) + self._proxy_fn, self._proxy_u_fn = svi_proxy(self.svi_fn, model_args, model_kwargs) elif self.proxy == "taylor": warnings.warn("Using default second order Taylor expansion, change by using the proxy flag to {svi}") self._init_subsample_state(rng_key, model_args, model_kwargs, init_params, self.z_ref) - self._proxy_fn,self._proxy_u_fn = taylor_proxy(self.z_ref, self._model, self._ll_ref, self._jac_all, self._hess_all) - if self.estimator =="poisson": - self._l = 25 # lambda subsamples + self._proxy_fn, self._proxy_u_fn = taylor_proxy(self.z_ref, self._model, self._ll_ref, self._jac_all, + self._hess_all) + if self.estimator == "poisson": + self._l = 50 # lambda subsamples + # _sample_u_poisson_jit = jit(_sample_u_poisson, static_argnums=(1, 2)) self._u = _sample_u_poisson(rng_key, self.m, self._l) - self._potential_fn = lambda model,model_args,model_kwargs,z,l, proxy_fn,proxy_u_fn : lambda z:signed_estimator(model = model,model_args=model_args, - model_kwargs= model_kwargs,z=z,l=l,proxy_fn=proxy_fn, - proxy_u_fn=proxy_u_fn)[0] + self._potential_fn = lambda model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn: lambda z: \ + signed_estimator(model=model, model_args=model_args, + model_kwargs=model_kwargs, z=z, l=l, proxy_fn=proxy_fn, + proxy_u_fn=proxy_u_fn)[0] # Initialize the hmc sampler: sample_fn = sample_kernel self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, kinetic_fn=euclidean_kinetic_energy, @@ -762,21 +790,24 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): self._model, init_strategy=self._init_strategy, dynamic_args=True, - model_args=tuple([arg[0] for arg in next(chain(init_model_args))]), #Highlight:Pick the first non-empty block ; 'chain' joins all the elements in the sublist , len(lists_of_lists) = n , len(chain(list_of_lists)) = sum(n_elements_inside_list=*n) + model_args=tuple([arg[0] for arg in next(chain(init_model_args))]), + # Highlight:Pick the first non-empty block ; 'chain' joins all the elements in the sublist , len(lists_of_lists) = n , len(chain(list_of_lists)) = sum(n_elements_inside_list=*n) model_kwargs=model_kwargs) + self._postprocess_fn = self._poisson_samples_correction else: self._u = random.randint(rng_key, (self.m,), 0, self._n) # Initialize the potential and gradient potential functions - self._potential_fn = lambda model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn : lambda z:potential_est(model=model, - model_args=model_args, model_kwargs=model_kwargs, z=z, n=n, m=m, proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) + self._potential_fn = lambda model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn: lambda \ + z: potential_est(model=model, + model_args=model_args, model_kwargs=model_kwargs, z=z, n=n, m=m, proxy_fn=proxy_fn, + proxy_u_fn=proxy_u_fn) # Initialize the hmc sampler: sample_fn = sample_kernel self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, - kinetic_fn=euclidean_kinetic_energy, - algo=self._algo) - + kinetic_fn=euclidean_kinetic_energy, + algo=self._algo) self._init_strategy = partial(init_near_values, values=self.z_ref) # Initialize the model parameters @@ -791,12 +822,12 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): model_kwargs=model_kwargs) if (self.g > self.m) or (self.g < 1): - raise ValueError( - 'Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(self.g, - self.m)) + raise ValueError( + 'Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(self.g, + self.m)) elif (self.m > self._n): - raise ValueError( - 'Subsample size (m) = {} needs to = or < than data size (n) {}'.format(self.m, self._n)) + raise ValueError( + 'Subsample size (m) = {} needs to = or < than data size (n) {}'.format(self.m, self._n)) else: if self._model is not None: @@ -824,25 +855,30 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): return init_params - @property def model(self): return self._model @property def sample_field(self): - return 'z' + if self.estimator == "poisson": + return "z_and_sign" + else: + return "z" @property def default_fields(self): - return ('z', 'diverging','sign') + if self.estimator == "poisson": + return ('z', 'diverging', 'sign', "z_and_sign") + else: + return 'z' def get_diagnostics_str(self, state): return '{} steps of size {:.2e}. acc. prob={:.2f}'.format(state.num_steps, state.adapt_state.step_size, state.mean_accept_prob) - def _block_indices(self,size, num_blocks): + def _block_indices(self, size, num_blocks): a = jnp.repeat(jnp.arange(num_blocks - 1), size // num_blocks) b = jnp.repeat(num_blocks - 1, size - len(jnp.repeat(jnp.arange(num_blocks - 1), size // num_blocks))) return jnp.hstack((a, b)) @@ -856,8 +892,8 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg else: rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1) - - init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params) #should work for all cases + init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, + init_params) # should work for all cases if self._potential_fn and init_params is None: raise ValueError('Valid value of `init_params` must be provided with' @@ -866,58 +902,58 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg if self.estimator == "poisson": init_model_args = [model_args_sub(u_i, model_args) for u_i in self._u] else: - init_model_args = model_args_sub(self._u,model_args) - hmc_init_fn = lambda init_params,rng_key: self._init_fn(init_params=init_params, - num_warmup = num_warmup, - step_size = self._step_size, - adapt_step_size = self._adapt_step_size, - adapt_mass_matrix = self._adapt_mass_matrix, - dense_mass = self._dense_mass, - target_accept_prob = self._target_accept_prob, - trajectory_length=self._trajectory_length, - max_tree_depth=self._max_tree_depth, - find_heuristic_step_size=self._find_heuristic_step_size, - model_args=init_model_args, - model_kwargs=model_kwargs, - subsample_method= self.subsample_method, - estimator= self.estimator, - model=self._model, - ll_ref =self._ll_ref, - jac_all=self._jac_all, - z_ref=self.z_ref, - hess_all = self._hess_all, - ll_u = self._ll_u, - n=self._n, - m=self.m, - u = self._u, - l = self._l, - sign = self._sign, - proxy_fn = self._proxy_fn, - proxy_u_fn = self._proxy_u_fn) - - if rng_key.ndim ==1: - #rng_key_hmc_init = jnp.array([1000966916, 171341646]) - rng_key_hmc_init,_ = random.split(rng_key) - - init_state = hmc_init_fn(init_params, rng_key_hmc_init) #HMCState + HMCECSState + init_model_args = model_args_sub(self._u, model_args) + hmc_init_fn = lambda init_params, rng_key: self._init_fn(init_params=init_params, + num_warmup=num_warmup, + step_size=self._step_size, + adapt_step_size=self._adapt_step_size, + adapt_mass_matrix=self._adapt_mass_matrix, + dense_mass=self._dense_mass, + target_accept_prob=self._target_accept_prob, + trajectory_length=self._trajectory_length, + max_tree_depth=self._max_tree_depth, + find_heuristic_step_size=self._find_heuristic_step_size, + model_args=init_model_args, + model_kwargs=model_kwargs, + subsample_method=self.subsample_method, + estimator=self.estimator, + model=self._model, + ll_ref=self._ll_ref, + jac_all=self._jac_all, + z_ref=self.z_ref, + hess_all=self._hess_all, + ll_u=self._ll_u, + n=self._n, + m=self.m, + u=self._u, + l=self._l, + sign=self._sign, + sign_sum=self._sign_sum, + proxy_fn=self._proxy_fn, + proxy_u_fn=self._proxy_u_fn) + + if rng_key.ndim == 1: + # rng_key_hmc_init = jnp.array([1000966916, 171341646]) + rng_key_hmc_init, _ = random.split(rng_key) + + init_state = hmc_init_fn(init_params, rng_key_hmc_init) # HMCState + HMCECSState if self.estimator == "poisson": - #signed pseudo-marginal algorithm with the block-Poisson estimator - #use the term signed PM for any pseudo-marginal algorithm that uses the technique in Lyne + # signed pseudo-marginal algorithm with the block-Poisson estimator + # use the term signed PM for any pseudo-marginal algorithm that uses the technique in Lyne # et al. (2015) where a pseudo-marginal sampler is run on the absolute value of the estimated # posterior and subsequently sign-corrected by importance sampling. Similarly, we call the # algorithm described in this section signed HMC-ECS - neg_ll, sign = signed_estimator(model = self._model, - model_args = [model_args_sub(u_i, model_args)for u_i in self._u], - model_kwargs= model_kwargs, + neg_ll, sign = signed_estimator(model=self._model, + model_args=[model_args_sub(u_i, model_args) for u_i in self._u], + model_kwargs=model_kwargs, z=init_state.z, l=self._l, proxy_fn=self._proxy_fn, - proxy_u_fn = self._proxy_u_fn) + proxy_u_fn=self._proxy_u_fn) - self._sign = jnp.array(sign) #Highlight, do not append the sign here, not necessary + self._sign = jnp.array(sign) # Highlight, do not append the sign here, not necessary self._ll_u = neg_ll - else: self._ll_u = potential_est(model=self._model, model_args=model_args_sub(self._u, model_args), @@ -927,27 +963,30 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg m=self.m, proxy_fn=self._proxy_fn, proxy_u_fn=self._proxy_u_fn) - hmc_init_sub_state = HMCECSState(u=self._u, - hmc_state=init_state.hmc_state, - ll_u=self._ll_u,sign=self._sign) - init_sub_state = tuplemerge(init_state._asdict(),hmc_init_sub_state._asdict()) + z_and_sign = {**init_state.z, 'sign': self._sign, + "sign_sum": self._sign_sum} # ,"num_warmup":num_warmup} + hmc_init_sub_state = HMCECSState(u=self._u, + hmc_state=init_state.hmc_state, + ll_u=self._ll_u, + sign=self._sign, + z_and_sign=z_and_sign) + init_sub_state = tuplemerge(init_state._asdict(), hmc_init_sub_state._asdict()) return init_sub_state - else: #TODO: What is this for? It does not go into it for num_chains>1 + else: # TODO: What is this for? It does not go into it for num_chains>1 raise ValueError("Not implemented for chains > 1") # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, # wa_steps because those variables do not depend on traced args: init_params, rng_key. init_state = vmap(hmc_init_fn)(init_params, rng_key) if self.estimator == "poisson": - #model_args = [model_args_sub(u_i, model_args)for u_i in self._u] neg_ll, sign = signed_estimator(model=self._model, - model_args=[model_args_sub(u_i, model_args)for u_i in self._u], - model_kwargs= model_kwargs_sub, + model_args=[model_args_sub(u_i, model_args) for u_i in self._u], + model_kwargs=model_kwargs_sub, z=init_state.z, - l = self._l, - proxy_fn = self._proxy_fn, - proxy_u_fn = self._proxy_u_fn) + l=self._l, + proxy_fn=self._proxy_fn, + proxy_u_fn=self._proxy_u_fn) self._sign = jnp.array(sign) self._ll_u = neg_ll @@ -960,13 +999,15 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg m=self.m, proxy_fn=self._proxy_fn, proxy_u_fn=self._proxy_u_fn) + z_and_sign = {**vv_state.z, 'sign': self._sign, "sign_sum": self._sign_sum} + hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state, + ll_u=self._ll_u, sign=self._sign, + z_and_sign=z_and_sign) - hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state, ll_u=self._ll_u,sign = self._sign) - - init_subsample_state = vmap(hmc_init_sub_fn)(init_params,rng_key) + init_subsample_state = vmap(hmc_init_sub_fn)(init_params, rng_key) sample_fn = vmap(self._sample_fn, in_axes=(0, None, None)) - HMCCombinedState = tuplemerge(init_state._asdict,init_subsample_state._asdict()) + HMCCombinedState = tuplemerge(init_state._asdict, init_subsample_state._asdict()) self._sample_fn = sample_fn return HMCCombinedState @@ -998,10 +1039,24 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg self._sample_fn = sample_fn return init_state + def _poisson_postprocess(self, states): + """Changes the support of the parameters samples by using the sign estimated during the sampling + Ir = Sum [z_j*sign_j] / Sum (sign_j)""" + states_params = {k: states[k] for k in + states.keys() - {'sign', 'sign_sum'}} # change the support for all the parameters sampled + states_params = {key: (states_params[key] * states["sign"]) / states["sign"] for key in states_params.keys()} + return states_params + + def _poisson_samples_correction(self, states, *args, **kwargs): + """Changes the support of the samples by using the sign estimated during the samplinghttps://github.com/pyro-ppl/funsor + Ir = Sum [z_j*sign_j] / Sum (sign_j)""" + return self._poisson_postprocess + def postprocess_fn(self, args, kwargs): if self._postprocess_fn is None: return identity - return self._postprocess_fn(*args, **kwargs) + else: + return self._postprocess_fn(*args, **kwargs) def sample(self, state, model_args, model_kwargs): """ @@ -1018,25 +1073,25 @@ def sample(self, state, model_args, model_kwargs): rng_key_subsample, rng_key_transition, rng_key_likelihood, rng_key = random.split( state.rng_key, 4) if self.estimator == "poisson": + # _sample_u_poisson_jit = jit(_sample_u_poisson,static_argnums=(0,1,2)) u_new = _sample_u_poisson(rng_key, self.m, self._l) - neg_ll, sign = signed_estimator(model = self._model, + neg_ll, sign = signed_estimator(model=self._model, model_args=[model_args_sub(u_i, model_args) for u_i in u_new], model_kwargs=model_kwargs, z=state.z, - l =self._l, - proxy_fn = self._proxy_fn, - proxy_u_fn = self._proxy_u_fn) - self._sign=jnp.array(sign) - #self._sign = jnp.append(self._sign,jnp.array([sign]),axis=0) - #self._sign = self._sign[jnp.isfinite(self._sign)] #remove dummy start point, since we annot initialize empty arrays + l=self._l, + proxy_fn=self._proxy_fn, + proxy_u_fn=self._proxy_u_fn) + self._sign = jnp.array(sign) + state.z_and_sign["sign_sum"] += self._sign # TODO: Probably is a multiplication # Correct the negativeloglikelihood by substracting the density of the prior to calculate the potential llu_new = jnp.min(jnp.array([0, -neg_ll + state.ll_u])) else: u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) # estimate likelihood of subsample with single block updated - llu_new = self._potential_fn(model=self._model, - model_args=model_args_sub(u_new,model_args), + llu_new = potential_est(model=self._model, + model_args=model_args_sub(u_new, model_args), model_kwargs=model_kwargs, z=state.z, n=self._n, @@ -1046,43 +1101,41 @@ def sample(self, state, model_args, model_kwargs): # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. accept_prob = jnp.clip(jnp.exp(-llu_new + state.ll_u), a_max=1.) - transition = random.bernoulli(rng_key_transition, accept_prob) #TODO: Why Bernoulli instead of Uniform? + transition = random.bernoulli(rng_key_transition, accept_prob) # TODO: Why Bernoulli instead of Uniform? u, ll_u = cond(transition, (u_new, llu_new), identity, (state.u, state.ll_u), identity) - ######## UPDATE PARAMETERS ########## - hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state,ll_u=ll_u,sign=self._sign) - hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(),state._asdict()) - + z_and_sign = {**state.z, 'sign': self._sign, "sign_sum": self._sign_sum} + hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state, ll_u=ll_u, sign=self._sign, + z_and_sign=z_and_sign) + hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(), state._asdict()) return self._sample_fn(hmc_subsamplestate, model_args=model_args, model_kwargs=model_kwargs, subsample_method=self.subsample_method, - estimator =self.estimator, - proxy_fn = self._proxy_fn, - proxy_u_fn = self._proxy_u_fn, - model = self._model, - ll_ref = self._ll_ref, - jac_all =self._jac_all, - z= state.z, - z_ref = self.z_ref, - hess_all = self._hess_all, - ll_u = ll_u, - u= u, - n= self._n, - m= self.m, + estimator=self.estimator, + proxy_fn=self._proxy_fn, + proxy_u_fn=self._proxy_u_fn, + model=self._model, + ll_ref=self._ll_ref, + jac_all=self._jac_all, + z=state.z, + z_ref=self.z_ref, + hess_all=self._hess_all, + ll_u=ll_u, + u=u, + n=self._n, + m=self.m, l=self._l, - sign = self._sign) + sign=self._sign, + sign_sum=state.z_and_sign["sign_sum"]) else: return self._sample_fn(state, model_args, model_kwargs) - - - class NUTS(HMCECS): """ Hamiltonian Monte Carlo inference, using the No U-Turn Sampler (NUTS) @@ -1127,6 +1180,7 @@ class NUTS(HMCECS): :param bool find_heuristic_step_size: whether to a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False. """ + def __init__(self, model=None, potential_fn=None, diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index b74cdfeaa..6f1cf8ff1 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -1,17 +1,15 @@ +from collections import namedtuple from functools import partial import jax import jax.numpy as jnp from jax import grad, value_and_grad from jax.tree_util import tree_multimap -import numpyro + import numpyro.distributions as dist from numpyro.distributions.util import is_identically_one from numpyro.handlers import substitute, trace from numpyro.util import ravel_pytree -from numpyro.handlers import seed, substitute, trace -from numpyro.contrib.funsor.infer_util import plate_to_enum_plate,packed_trace -from collections import namedtuple IntegratorState = namedtuple('IntegratorState', ['z', 'r', 'potential_energy', 'z_grad']) IntegratorState.__new__.__defaults__ = (None,) * len(IntegratorState._fields) @@ -27,6 +25,7 @@ def model_args_sub(u, model_args): args.append(arg) return tuple(args) + def model_kwargs_sub(u, kwargs): """Subsample observations and features""" for key_arg, val_arg in kwargs.items(): @@ -34,10 +33,11 @@ def model_kwargs_sub(u, kwargs): kwargs[key_arg] = jnp.take(val_arg, u, axis=0) return kwargs + def log_density_obs_hmcecs(model, model_args, model_kwargs, params): model = substitute(model, data=params) model_trace = trace(model).get_trace(*model_args, **model_kwargs) - #model = substitute(model, data=params) + # model = substitute(model, data=params) # with plate_to_enum_plate(): # model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs) log_joint = jnp.array(0.) @@ -52,11 +52,12 @@ def log_density_obs_hmcecs(model, model_args, model_kwargs, params): log_prob = site['fn'].log_prob(value) if (scale is not None) and (not is_identically_one(scale)): log_prob = scale * log_prob - #log_joint += log_prob #TODO: log_joint += jnp.sum(log_prob) ?---> gives a single number + # log_joint += log_prob #TODO: log_joint += jnp.sum(log_prob) ?---> gives a single number log_joint = log_joint + jnp.sum(log_prob) return log_joint, model_trace + def log_density_prior_hmcecs(model, model_args, model_kwargs, params): """ (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given @@ -92,16 +93,19 @@ def log_density_prior_hmcecs(model, model_args, model_kwargs, params): log_joint = log_joint + log_prob return log_joint, model_trace -def reducer( accum, d ): - accum.update(d) - return accum -def tuplemerge( *dictionaries ): - from functools import reduce +def reducer(accum, d): + accum.update(d) + return accum - merged = reduce( reducer, dictionaries, {} ) - return namedtuple('HMCCombinedState', merged )(**merged) # <==== Gist of the gist +def tuplemerge(*dictionaries): + from functools import reduce + + merged = reduce(reducer, dictionaries, {}) + + return namedtuple('HMCCombinedState', merged)(**merged) # <==== Gist of the gist + def potential_est(model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn): """Computes the estimation of the likelihood of the potential @@ -118,6 +122,7 @@ def potential_est(model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn return (-l_hat + .5 * sigma) - ll_prior + def velocity_verlet_hmcecs(potential_fn, kinetic_fn, grad_potential_fn=None): r""" Second order symplectic integrator that uses the velocity verlet algorithm @@ -146,7 +151,6 @@ def init_fn(z, r, potential_energy=None, z_grad=None): if potential_energy is None or z_grad is None: potential_energy, z_grad = compute_value_grad(z) - return IntegratorState(z, r, potential_energy, z_grad) def update_fn(step_size, inverse_mass_matrix, state): @@ -164,11 +168,12 @@ def update_fn(step_size, inverse_mass_matrix, state): z = tree_multimap(lambda z, r_grad: z + step_size * r_grad, z, r_grad) # z(n+1) potential_energy, z_grad = compute_value_grad(z) r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad) # r(n+1) - #return IntegratorState(z, r, potential_energy, z_grad) + # return IntegratorState(z, r, potential_energy, z_grad) return IntegratorState(z, r, potential_energy, z_grad) return init_fn, update_fn + def init_near_values(site=None, values={}): """Initialize the sampling to a noisy map estimate of the parameters""" from functools import partial @@ -188,6 +193,7 @@ def init_near_values(site=None, values={}): except: return init_to_uniform(site) + def taylor_proxy(z_ref, model, ll_ref, jac_all, hess_all): """Corrects the subsample likelihood using covariates the taylor expansion :param z_ref = reference estimate (e.g MAP) of the model's parameters @@ -195,6 +201,7 @@ def taylor_proxy(z_ref, model, ll_ref, jac_all, hess_all): :param ll_ref = reference loglikelihood :param jac_all= Jacobian vector of the entire dataset :param hess_all = Hessian matrix of the entire dataset""" + def proxy(z, *args, **kwargs): z_flat, _ = ravel_pytree(z) zref_flat, _ = ravel_pytree(z_ref) @@ -217,6 +224,7 @@ def proxy_u(z, model_args, model_kwargs, *args, **kwargs): return proxy, proxy_u + def svi_proxy(svi, model_args, model_kwargs): def proxy(z, *args, **kwargs): z_ref = svi.guide.expectation(z) @@ -230,16 +238,20 @@ def proxy_u(z, model_args, model_kwargs, *args, **kwargs): return proxy, proxy_u + def neural_proxy(): return None + def split_list(lst, n): """Pair up the split model arguments back.""" for i in range(0, len(lst), n): - if i+n < len(lst)-1: #TODO: Change back to len(lst), after debugging - yield tuple( map(lst.__getitem__, [i,i+n])) + if i + n < len(lst): + yield tuple(map(lst.__getitem__, [i, i + n])) else: break + + def signed_estimator(model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn): """ Estimate the gradient potential estimate @@ -258,14 +270,14 @@ def signed_estimator(model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn xis = 0. sign = 1. d = 0 - a = d - l #For a fixed λ, V[LbB] is minimized at a = d − λ. Quiroz 2018c - model_args = [args_l for args_l in model_args if len(args_l[0]) != 0] #remove empty lambda blocks - for args_l in model_args: #Iterate over each of the lambda groups of model args + a = d - l # For a fixed λ, V[LbB] is minimized at a = d − λ. Quiroz 2018c + model_args = [args_l for args_l in model_args if len(args_l[0]) != 0] # remove empty lambda blocks + for args_l in model_args: # Iterate over each of the lambda groups of model args block_len = args_l[0].shape[0] - args_l = [jnp.split(arg, arg.shape[0]) for arg in args_l] # split the arrays of blocks - args_l = list(itertools.chain.from_iterable(args_l)) #Join list of lists + args_l = [jnp.split(arg, arg.shape[0]) for arg in args_l] # split the arrays of blocks + args_l = list(itertools.chain.from_iterable(args_l)) # Join list of lists args_l = [arr.squeeze(axis=0) for arr in args_l] - args_l = list(split_list(args_l,block_len)) + args_l = list(split_list(args_l, block_len)) for args_l_b in args_l: ll_sub, _ = log_density_obs_hmcecs(model, args_l_b, {}, z) # log likelihood for each u subsample xi = (jnp.exp(ll_sub - proxy_u_fn(z=z, model_args=args_l_b, model_kwargs=model_kwargs)) - a) / l @@ -273,16 +285,10 @@ def signed_estimator(model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn xis += jnp.sum(jnp.abs(xi)) lhat = proxy_fn(z) + (a + l) / l + xis - prior_arg = tuple([arg.reshape(arg.shape[0] * arg.shape[1], -1) for arg in model_args[0]])#Join the block subsamples, does not matter because the prior does not look t them - ll_prior, _ = log_density_prior_hmcecs(model, prior_arg, model_kwargs, z) #the ll of the prior does not depend on the model args, so we just take some pair + prior_arg = tuple([arg.reshape(arg.shape[0] * arg.shape[1], -1) for arg in model_args[ + 0]]) # Join the block subsamples, does not matter because the prior does not look t them + ll_prior, _ = log_density_prior_hmcecs(model, prior_arg, model_kwargs, + z) # the ll of the prior does not depend on the model args, so we just take some pair # Correct the negativeloglikelihood by substracting the density of the prior --> potentialEst = -loglikeEst - dprior(theta,pfamily,priorPar1,priorPar2) neg_ll = - lhat - ll_prior return neg_ll, sign - - - -def poisson_samples_correction(*args,**kwargs): - "Changes the support of the samples by using the sign from the " - - return args - From b8f8830cc9fa594ed877ba54310845990cb17713 Mon Sep 17 00:00:00 2001 From: ola Date: Tue, 15 Dec 2020 15:01:17 +0100 Subject: [PATCH 34/93] Added MNIST BNN example using flax. --- examples/hmcecs/__init__.py | 0 examples/hmcecs/higgs.py | 208 ++++++++++++++++++----------------- examples/hmcecs/lda.py | 60 ++++++++++ examples/hmcecs/mnist_bnn.py | 181 ++++++++++++++++++++++++++++++ numpyro/examples/datasets.py | 35 +----- 5 files changed, 353 insertions(+), 131 deletions(-) create mode 100644 examples/hmcecs/__init__.py create mode 100644 examples/hmcecs/lda.py create mode 100644 examples/hmcecs/mnist_bnn.py diff --git a/examples/hmcecs/__init__.py b/examples/hmcecs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/hmcecs/higgs.py b/examples/hmcecs/higgs.py index 527d3a82d..394248b8b 100644 --- a/examples/hmcecs/higgs.py +++ b/examples/hmcecs/higgs.py @@ -1,8 +1,12 @@ +""" Logistic regression model as implemetned in https://arxiv.org/pdf/1708.00955.pdf with Higgs Dataset """ +# !/usr/bin/env python +from collections import namedtuple + import jax import jax.numpy as jnp -import numpy as np -from jax import jit, lax -from sklearn.datasets import load_breast_cancer +import jax.numpy as np_jax +from jax.tree_util import tree_map +from sklearn.model_selection import train_test_split import numpyro import numpyro.distributions as dist @@ -10,129 +14,129 @@ from numpyro import optim from numpyro.contrib.autoguide_hmcecs import AutoDiagonalNormal from numpyro.contrib.hmcecs import HMCECS +from numpyro.diagnostics import summary +from numpyro.examples.datasets import _load_higgs from numpyro.infer import NUTS, MCMC -from numpyro.infer.elbo import ELBO +from numpyro.infer.elbo import Trace_ELBO from numpyro.infer.svi import SVI -from numpyro.util import fori_loop - -numpyro.set_platform("cpu") - - -def load_dataset(observations, features, batch_size=None, shuffle=True): - num_records = observations.shape[0] - idxs = jnp.arange(num_records) - if not batch_size: - batch_size = num_records - - def init(): - return num_records // batch_size, np.random.permutation(idxs) if shuffle else idxs - def get_batch(i=0, idxs=idxs): - ret_idx = lax.dynamic_slice_in_dim(idxs, i * batch_size, batch_size) - batch_obs = jnp.take(observations, ret_idx, axis=0) - batch_feats = jnp.take(features, ret_idx, axis=0) - return batch_obs, batch_feats +numpyro.set_platform("gpu") - return init, get_batch +DataLoaderState = namedtuple("DataLoaderState", ('iteration', 'rng_key', 'indexes', 'max_iter')) -def svi_map(model, rng_key, feats, obs, num_epochs, batch_size): - @jit - def epoch_train(svi_state): - def body_fn(i, val): - batch_obs, batch_feats = train_fetch(i, train_idx) - loss_sum, svi_state = val - svi_state, loss = svi.update(svi_state, batch_feats, batch_obs) - loss_sum += loss - return loss_sum, svi_state - - return fori_loop(0, num_train, body_fn, (0., svi_state)) - - n, _ = feats.shape - guide = AutoDiagonalNormal(model) - svi = SVI(model, guide, optim.Adam(0.0003), loss=ELBO()) - svi_state = svi.init(rng_key, feats, obs) - train_init, train_fetch = load_dataset(obs, feats, batch_size=batch_size) +def dataloader(*xs, batch_size=32, train_size=None, test_size=None, shuffle=True): + assert len(xs) > 1 + splitxs = train_test_split(*xs, train_size=train_size, test_size=test_size) + trainxs, testxs = splitxs[0::2], splitxs[1::2] + max_train_iter, max_test_iter = len(trainxs[0]) // batch_size, len(testxs[0]) // batch_size - for i in range(num_epochs): - num_train, train_idx = train_init() - train_loss, svi_state = epoch_train(svi_state) - return svi.get_params(svi_state), svi, svi_state + def make_dataset(dxs, max_iter): + def init(rng_key): + return DataLoaderState(0, rng_key, jnp.arange(len(dxs[0])), max_iter) + def next_step(state): -def breast_cancer_data(): - """ Logistic regression model as implemetned in https://arxiv.org/pdf/1708.00955.pdf with Higgs Dataset """ - dataset = load_breast_cancer() - feats = dataset.data - feats = (feats - feats.mean(0)) / feats.std(0) - feats = jnp.hstack((feats, jnp.ones((feats.shape[0], 1)))) + iteration = state.iteration % state.max_iter + batch = tuple(x[state.indexes[iteration * batch_size:(iteration + 1) * batch_size]] + for x in dxs) + if iteration + 1 == state.max_iter: + shuffle_rng_key, rng_key = jax.random.split(state.rng_key) + if shuffle: + indexes = jax.random.shuffle(shuffle_rng_key, state.indexes) + else: + indexes = state.indexes + return batch, DataLoaderState(state.iteration + 1, rng_key, indexes, state.max_iter) + else: + return batch, DataLoaderState(state.iteration + 1, state.rng_key, state.indexes, state.max_iter) - return feats, dataset.target + return init, next_step + return make_dataset(trainxs, max_train_iter), make_dataset(testxs, max_test_iter), testxs -def model(feats, obs): - """ Logistic regression model """ - n, m = feats.shape - theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), 2 * jnp.ones(m))) - numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(feats, theta)), obs=obs) - -def infer_hmcecs(rng_key, feats, obs, m=None, g=None, n_samples=None, warmup=None, algo="NUTS", subsample_method=None, - map_method=None, proxy="taylor", estimator=None, num_epochs=None, postprocess_fn=None): +def svi_map(model, rng_key, feats, obs, num_epochs, batch_size): + guide = AutoDiagonalNormal(model) + svi = SVI(model, guide, optim.Adam(0.0003), loss=Trace_ELBO()) + svi_rng_key, data_rng_key = jax.random.split(rng_key) + (init_train, next_train), _, _ = dataloader(feats, obs, train_size=0.9, batch_size=batch_size) + batch_fn = jax.jit(svi.update) + svi_state = None + data_state = init_train(data_rng_key) + num_batches = 0 + for _ in range(num_epochs): + for j in range(data_state.max_iter): + xs, data_state = next_train(data_state) + if svi_state is None: + svi_state = svi.init(svi_rng_key, *xs) + svi_state, _ = batch_fn(svi_state, *xs) + num_batches += 1 + return svi, svi_state + + +def infer_nuts(rng_key, features, obs, samples, warmup): + kernel = NUTS(model=logistic_regression, target_accept_prob=0.8) + mcmc = MCMC(kernel, num_warmup=warmup, num_samples=samples) + mcmc.run(rng_key, features, obs) + samples = mcmc.get_samples() + samples = tree_map(lambda x: x[None, ...], samples) + r_hat_average = np_jax.sum(summary(samples)["theta"]["r_hat"]) / len(summary(samples)["theta"]["r_hat"]) + + return mcmc.get_samples(), r_hat_average + + +def infer_hmcecs(rng_key, obs, features, m=None, g=None, n_samples=None, warmup=None, algo="NUTS", + subsample_method=None, map_method=None, proxy="taylor", estimator=None, num_epochs=None): hmcecs_key, map_key = jax.random.split(rng_key) - n, _ = feats.shape + n, _ = features.shape - if map_method == "SVI": - factor_SVI = obs.shape[0] - batch_size = 32 + svi = None + if map_method == "nuts": + samples, r_hat_average = infer_nuts(map_key, features, obs, samples=10, warmup=5) + z_ref = {key: value.mean(0) for key, value in samples.items()} + elif map_method == "svi": map_key, post_key = jax.random.split(map_key) - z_ref, svi, svi_state = svi_map(model, map_key, feats=feats[:factor_SVI], obs=obs[:factor_SVI], - num_epochs=num_epochs, batch_size=batch_size) + svi, svi_state = svi_map(logistic_regression, + map_key, + feats=features, + obs=obs, + num_epochs=num_epochs, + batch_size=256) z_ref = svi.guide.sample_posterior(post_key, svi.get_params(svi_state), (100,)) z_ref = {name: value.mean(0) for name, value in z_ref.items()} - else: - svi = None - map_samples = 10 - map_warmup = 5 - if map_method == "NUTS": - kernel = NUTS(model=model, target_accept_prob=0.8) - if map_method == 'HMC': - kernel = NUTS(model=model, target_accept_prob=0.8) - mcmc = MCMC(kernel, num_warmup=map_warmup, num_samples=map_samples) - mcmc.run(rng_key, feats, obs) - samples = mcmc.get_samples() - z_ref = {key: value.mean(0) for key, value in samples.items()} - extra_fields = [] - if estimator == "poisson": - postprocess_fn = None - extra_fields = ("sign",) + kernel = HMCECS(model=logistic_regression, z_ref=z_ref, m=m, g=g, algo=algo.upper(), + subsample_method=subsample_method, proxy=proxy, svi_fn=svi, + estimator=estimator, target_accept_prob=0.8) - kernel = HMCECS(model=model, z_ref=z_ref, m=m, g=g, algo=algo, subsample_method=subsample_method, proxy=proxy, - svi_fn=svi, estimator=estimator, target_accept_prob=0.8) - - mcmc = MCMC(kernel, num_warmup=warmup, num_samples=n_samples, num_chains=1, postprocess_fn=postprocess_fn) - mcmc.run(rng_key, feats, obs, extra_fields=extra_fields) + mcmc = MCMC(kernel, num_warmup=warmup, num_samples=n_samples, num_chains=1) + mcmc.run(rng_key, features, obs) return mcmc.get_samples() -if __name__ == '__main__': - num_samples = 10 - num_warmup = 5 - ecs_algo = 'NUTS' - ecs_proxy = 'taylor' - estimator = 'perturb' - map_init = 'SVI' - epochs = 1000 - rng_key = jax.random.PRNGKey(37) +def logistic_regression(features, obs): + n, m = features.shape + theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), 2 * jnp.ones(m))) + numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(features, theta)), obs=obs) + - feats, obs = breast_cancer_data() +def higgs_data(): + return _load_higgs() - n, = obs.shape - m = int(jnp.sqrt(n)) - g = 5 - infer_hmcecs(rng_key, feats=feats, obs=obs, n_samples=num_samples, - warmup=num_warmup, m=m, g=g, algo=ecs_algo, subsample_method="perturb", - proxy=ecs_proxy, estimator=estimator, map_method=map_init, num_epochs=epochs) +if __name__ == '__main__': + rng_key = jax.random.PRNGKey(37) + obs, feats = higgs_data() + num_examples = 1000 + + est_posterior_ECS = infer_hmcecs(rng_key, obs[:num_examples], feats[:num_examples], + n_samples=10, + warmup=5, + m=30, g=5, + algo='nuts', + subsample_method="perturb", + proxy='svi', + estimator='', + map_method='svi', + num_epochs=100) diff --git a/examples/hmcecs/lda.py b/examples/hmcecs/lda.py new file mode 100644 index 000000000..25a642c4f --- /dev/null +++ b/examples/hmcecs/lda.py @@ -0,0 +1,60 @@ +import sys + +from jax.experimental import stax +from sklearn.datasets import fetch_20newsgroups +from sklearn.feature_extraction.text import CountVectorizer + +import jax +import jax.numpy as jnp +from sklearn.utils import shuffle + +import numpyro +import numpyro.distributions as dist + +import numpy as np + +from numpyro.contrib.indexing import Vindex + + +def lda(doc_words, lengths, num_topics=20, num_words=100, num_max_elements=10, + num_hidden=100): + num_docs = doc_words.shape[0] + topic_word_probs = numpyro.sample('topic_word_probs', + dist.Dirichlet(jnp.ones((num_topics, num_words)) / num_words).to_event(1)) + 1e-7 + element_plate = numpyro.plate('words', num_max_elements, dim=-1) + with numpyro.plate('documents', num_docs, dim=-2): + document_topic_probs = numpyro.sample('topic_probs', dist.Dirichlet(jnp.ones(num_topics) / num_topics)) + with element_plate: + word_topic = numpyro.sample('word_topic', dist.Categorical(document_topic_probs)) + numpyro.sample('word', dist.Categorical(Vindex(topic_word_probs)[word_topic]), obs=doc_words) + + +def lda_guide(doc_words, lengths, num_topics=20, num_words=100, num_max_elements=10, + num_hidden=100): + num_docs = doc_words.shape[0] + topic_word_probs_val = numpyro.param('topic_word_probs_val', jnp.ones((num_topics, num_words)), + constraint=dist.constraints.simplex) + _topic_word_probs = numpyro.sample('topic_word_probs', dist.Delta(topic_word_probs_val).to_event(1)) + amortize_nn = numpyro.module('amortize_nn', stax.serial( + stax.Dense(num_hidden), + stax.Relu, + stax.Dense(num_topics), + stax.Softmax + ), (num_docs, num_max_elements)) + document_topic_probs_vals = amortize_nn(doc_words)[..., None, :] + 1e-7 + _document_topic_probs = numpyro.sample('topic_probs', dist.Delta(document_topic_probs_vals)) + + +def main(_argv): + newsgroups = fetch_20newsgroups()['data'] + num_words = 300 + count_vectorizer = CountVectorizer(max_df=.95, min_df=.01, + token_pattern=r'(?u)\b[^\d\W]\w+\b', + max_features=num_words, + stop_words='english') + newsgroups_docs = count_vectorizer.fit_transform(newsgroups) + rng_key = jax.random.PRNGKey(37) + + +if __name__ == '__main__': + main(sys.argv) diff --git a/examples/hmcecs/mnist_bnn.py b/examples/hmcecs/mnist_bnn.py new file mode 100644 index 000000000..6973218c5 --- /dev/null +++ b/examples/hmcecs/mnist_bnn.py @@ -0,0 +1,181 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +""" +Example: Bayesian Neural Network +================================ + +We demonstrate how to use NUTS to do inference on a simple (small) +Bayesian neural network with two hidden layers. +""" + +import argparse +import time + +import jax.numpy as jnp +import jax.random as random +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +from flax import nn +from jax import vmap + +import numpyro +import numpyro.distributions as dist +from numpyro import handlers +from numpyro.contrib.module import random_flax_module +from numpyro.examples.datasets import load_dataset, MNIST +from numpyro.infer import MCMC, NUTS + +matplotlib.use('Agg') # noqa: E402 + + +class Network(nn.Module): + def apply(self, x, hid_channels, out_channels): + l1 = nn.relu(nn.Dense(x, features=hid_channels)) + l2 = nn.relu(nn.Dense(l1, features=hid_channels)) + logits = nn.Dense(l2, features=out_channels) + return logits + + +def mnist_model(features, hid_channels, obs=None): + module = Network.partial(hid_channels=hid_channels, out_channels=10) + net = random_flax_module('snn', module, dist.Normal(0, 1.), input_shape=features.shape) + if obs is not None: + obs = obs[..., None] + numpyro.sample('obs', dist.Categorical(logits=net(features)), obs=obs) + + +def mnist_data(split='train'): + mnist_init, mnist_batch = load_dataset(MNIST, split=split) + _, idxs = mnist_init() + X, Y = mnist_batch(0, idxs) + _, m, _ = X.shape + X = X.reshape(-1, m ** 2) + return X, Y + + +def mnist_main(args): + hid_channels = 32 + X, Y = mnist_data() + rng_key, rng_key_predict = random.split(random.PRNGKey(37)) + samples = run_inference(mnist_model, args, rng_key, X[:args.num_data], hid_channels, Y[:args.num_data]) + + # predict Y_test at inputs X_test + vmap_args = (samples, random.split(rng_key_predict, args.num_samples * args.num_chains)) + X, Y = mnist_data('test') + predictions = vmap(lambda samples, rng_key: predict(mnist_model, rng_key, samples, X[:100], hid_channels))( + *vmap_args) + predictions = predictions[..., 0] + + +class RegNetwork(nn.Module): + def apply(self, x, hid_channels, out_channels): + l1 = nn.tanh(nn.Dense(x, features=hid_channels)) + l2 = nn.tahn(nn.Dense(l1, features=hid_channels)) + mean = nn.Dense(l2, features=out_channels) + return mean + + +def reg_model(features, obs, hid_channels): + in_channels, out_channels = features.shape[1], 1 + module = Network.partial(hid_channels=hid_channels, out_channels=out_channels) + + net = random_flax_module('snn', module, dist.Normal(0, 1.), input_shape=()) + mean = net(features) + + # we put a prior on the observation noise + prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0)) + sigma_obs = 1.0 / jnp.sqrt(prec_obs) # prior + + numpyro.sample("Y", dist.Normal(mean, sigma_obs), obs=obs[..., None]) + + +# helper function for HMC inference +def run_inference(model, args, rng_key, X, Y, D_H): + start = time.time() + kernel = NUTS(model) + mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains) + mcmc.run(rng_key, X, Y, D_H) + mcmc.print_summary() + print('\nMCMC elapsed time:', time.time() - start) + return mcmc.get_samples() + + +# helper function for prediction +def predict(model, rng_key, samples, *args, **kwargs): + model = handlers.substitute(handlers.seed(model, rng_key), samples) + # note that Y will be sampled in the model because we pass Y=None here + model_trace = handlers.trace(model).get_trace(*args, **kwargs) + return model_trace['obs']['value'] + + +# create artificial regression dataset +def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500): + D_Y = 1 # create 1d outputs + np.random.seed(0) + X = jnp.linspace(-1, 1, N) + X = jnp.power(X[:, np.newaxis], jnp.arange(D_X)) + W = 0.5 * np.random.randn(D_X) + Y = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + X[:, 1], 2.0) * jnp.sin(4.0 * X[:, 1]) + Y += sigma_obs * np.random.randn(N) + Y = Y[:, np.newaxis] + Y -= jnp.mean(Y) + Y /= jnp.std(Y) + + assert X.shape == (N, D_X) + assert Y.shape == (N, D_Y) + + X_test = jnp.linspace(-1.3, 1.3, N_test) + X_test = jnp.power(X_test[:, np.newaxis], jnp.arange(D_X)) + + return X, Y, X_test + + +def main(args): + N, D_X, D_H = args.num_data, 3, args.num_hidden + X, Y, X_test = get_data(N=N, D_X=D_X) + + # do inference + rng_key, rng_key_predict = random.split(random.PRNGKey(0)) + samples = run_inference(reg_model, args, rng_key, X, Y, D_H) + + # predict Y_test at inputs X_test + vmap_args = (samples, random.split(rng_key_predict, args.num_samples * args.num_chains)) + predictions = vmap(lambda samples, rng_key: predict(reg_model, rng_key, samples, X_test, D_H))(*vmap_args) + predictions = predictions[..., 0] + + # compute mean prediction and confidence interval around median + mean_prediction = jnp.mean(predictions, axis=0) + percentiles = np.percentile(predictions, [5.0, 95.0], axis=0) + + # make plots + fig, ax = plt.subplots(1, 1) + + # plot training data + ax.plot(X[:, 1], Y[:, 0], 'kx') + # plot 90% confidence level of predictions + ax.fill_between(X_test[:, 1], percentiles[0, :], percentiles[1, :], color='lightblue') + # plot mean prediction + ax.plot(X_test[:, 1], mean_prediction, 'blue', ls='solid', lw=2.0) + ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI") + + plt.savefig('bnn_plot.pdf') + plt.tight_layout() + + +if __name__ == "__main__": + assert numpyro.__version__.startswith('0.4.1') + parser = argparse.ArgumentParser(description="Bayesian neural network example") + parser.add_argument("-n", "--num-samples", nargs="?", default=20, type=int) + parser.add_argument("--num-warmup", nargs='?', default=10, type=int) + parser.add_argument("--num-chains", nargs='?', default=1, type=int) + parser.add_argument("--num-data", nargs='?', default=1000, type=int) + parser.add_argument("--num-hidden", nargs='?', default=5, type=int) + parser.add_argument("--device", default='gpu', type=str, help='use "cpu" or "gpu".') + args = parser.parse_args() + + numpyro.set_platform(args.device) + numpyro.set_host_device_count(args.num_chains) + + mnist_main(args) diff --git a/numpyro/examples/datasets.py b/numpyro/examples/datasets.py index feecccfd3..b20ff7038 100644 --- a/numpyro/examples/datasets.py +++ b/numpyro/examples/datasets.py @@ -1,15 +1,16 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import namedtuple import csv import gzip import os import pickle import struct +import warnings +from collections import namedtuple from urllib.parse import urlparse from urllib.request import urlretrieve -import warnings + import numpy as np import pandas as pd from jax import device_put, lax @@ -22,20 +23,16 @@ '.data')) os.makedirs(DATA_DIR, exist_ok=True) - dset = namedtuple('dset', ['name', 'urls']) - BASEBALL = dset('baseball', [ 'https://d2hg8soec8ck9v.cloudfront.net/datasets/EfronMorrisBB.txt', ]) - COVTYPE = dset('covtype', [ 'https://d2hg8soec8ck9v.cloudfront.net/datasets/covtype.zip', ]) - MNIST = dset('mnist', [ 'https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/train-images-idx3-ubyte.gz', 'https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/train-labels-idx1-ubyte.gz', @@ -43,30 +40,27 @@ 'https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/t10k-labels-idx1-ubyte.gz', ]) - SP500 = dset('SP500', [ 'https://d2hg8soec8ck9v.cloudfront.net/datasets/SP500.csv', ]) - UCBADMIT = dset('ucbadmit', [ 'https://d2hg8soec8ck9v.cloudfront.net/datasets/UCBadmit.csv', ]) - LYNXHARE = dset('lynxhare', [ 'https://d2hg8soec8ck9v.cloudfront.net/datasets/LynxHare.txt', ]) - JSB_CHORALES = dset('jsb_chorales', [ 'https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/jsb_chorales.pickle', ]) -HIGGS = dset("higgs",[ +HIGGS = dset("higgs", [ "https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz", ]) + def _download(dset): for url in dset.urls: file = os.path.basename(urlparse(url).path) @@ -219,6 +213,7 @@ def _load_jsb_chorales(): processed_dataset[k] = (lengths, _pad_sequence(sequences).astype("int32")) return processed_dataset + def _load_higgs(): warnings.warn("Downloading 2.6 GB dataset") _download(HIGGS) @@ -227,24 +222,6 @@ def _load_higgs(): obs, feats = df.iloc[:, 0], df.iloc[:, 1:] return obs.to_numpy(), feats.to_numpy() - #SLOW (no pandas) option - # observations,features = [],[] - # with gzip.open(file_path, mode='rt') as f: - # csv_reader = csv.DictReader( - # f, - # delimiter=',', - # restkey="30", - # fieldnames=['observations'] +['feature_{}'.format(i) for i in range(28)], - # ) - # - # for row in csv_reader: - # observations.append(row["observations"]) - # for i in range(28): - # print(row["feature_{}".format(i)]) - # features.append(row["feature_{}".format(i)]) - # return {"observations": np.stack(observations),"features": np.stack(features)} - - def _load(dset): if dset == BASEBALL: From 65531c2b68f6b1e6a3bd6805de3e339ec5b67544 Mon Sep 17 00:00:00 2001 From: Ola Date: Thu, 7 Jan 2021 11:13:35 +0100 Subject: [PATCH 35/93] Working potential with algebraic effect handlers. --- check_potential.py | 178 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100644 check_potential.py diff --git a/check_potential.py b/check_potential.py new file mode 100644 index 000000000..af022391b --- /dev/null +++ b/check_potential.py @@ -0,0 +1,178 @@ +from collections import namedtuple +from numpyro.primitives import Messenger +from functools import partial +import numpyro + +import jax +import jax.numpy as jnp +from jax import grad, value_and_grad +from jax.tree_util import tree_multimap +from jax import random + +import numpyro.distributions as dist +from numpyro.distributions.util import is_identically_one +from numpyro.handlers import substitute, trace, seed, block +from numpyro.primitives import _subsample_fn +from numpyro.util import ravel_pytree +from numpyro.contrib.hmcecs_utils import log_density_obs_hmcecs, log_density_prior_hmcecs +import jax.numpy as jnp + + +def _wrap_model(model): + def fn(*args, **kwargs): + subsample_values = kwargs.pop("_subsample_sites", {}) + with substitute(data=subsample_values): + model(*args, **kwargs) + + return fn + + +def _wrap_est_model(model, estimators, plate_sizes): + def fn(*args, **kwargs): + subsample_values = kwargs.pop("_subsample_sites", {}) + with substitute(data=subsample_values): + with estimator(model, estimators, plate_sizes): + model(*args, **kwargs) + + return fn + + +def model(data, *args, **kwargs): + x = numpyro.sample("x", dist.Normal(0., 1.)) + with numpyro.plate("N", data.shape[0], subsample_size=100): + batch = numpyro.subsample(data, event_dim=0) + obs = numpyro.sample("obs", dist.Normal(x, 1.), obs=batch) + + +class estimator(Messenger): + def __init__(self, fn, estimators, plate_sizes): + self.estimators = estimators + self.plate_sizes = plate_sizes + super(estimator, self).__init__(fn) + + def process_message(self, msg): + if msg['type'] == 'sample' and msg['is_observed'] and msg['cond_indep_stack']: + log_prob = msg['fn'].log_prob + msg['fn'].log_prob = lambda *args, **kwargs: \ + self.estimators[msg['name']](*args, name=msg['name'], z=_extract_params(msg['fn']), log_prob=log_prob, + sizes=self.plate_sizes[msg['cond_indep_stack'][0].name], + **kwargs) # TODO: check multiple levels + + +def my_estimator(value, name, z, sizes, log_prob, proxy_fn=lambda x, y: x, uproxy_fn=lambda x: x, **kwargs, ): + n, m = sizes + ll_sub = log_prob(value).sum() + diff = ll_sub - uproxy_fn(name, value, z) + l_hat = proxy_fn(name, z) + n / m * diff + sigma = n ** 2 / m * jnp.var(diff) + return l_hat - .5 * sigma + + +def _extract_params(distribution): + params, _ = distribution.tree_flatten() + return params + + +def my_taylor(ref_trace, ll_ref, jac_all, hess_all): + def proxy(name, z): + z_ref = _extract_params(ref_trace[name]['fn']) + jac, hess = jac_all[name], hess_all[name] + log_like = jnp.array(0.) + for argnum in range(len(z_ref)): + z_diff = z[argnum] - z_ref[argnum] + j, h = jac[argnum], hess[argnum] + k, = j.shape + log_like += j.T @ z_diff + .5 * z_diff.T @ h.reshape(k, k) @ z_diff # TODO: factor out + return ll_ref[name] + log_like + + def uproxy(name, value, z): + ref_dist = ref_trace[name]['fn'] + z_ref, aux_data = ref_dist.tree_flatten() + + log_prob = lambda *params: ref_dist.tree_unflatten(aux_data, params).log_prob(value).sum() + log_like = jnp.array(0.) + for argnum in range(len(z_ref)): + z_diff = z[argnum] - z_ref[argnum] + jac = jax.jacobian(log_prob, argnum)(*z_ref) + k, = jac.shape + hess = jax.hessian(log_prob, argnum)(*z_ref) + log_like += jac @ z_diff + .5 * z_diff @ hess.reshape(k, k) @ z_diff.T + + return log_prob(*z_ref) + log_like + + return proxy, uproxy + + +class subsample_size(Messenger): + def __init__(self, fn, plate_sizes, rng_key=None): + super(subsample_size, self).__init__(fn) + self.plate_sizes = plate_sizes + self.rng_key = rng_key + + def process_message(self, msg): + if msg['type'] == 'plate' and msg['args'] and msg["args"][0] > msg["args"][1]: + if msg['name'] in self.plate_sizes: + msg['args'] = self.plate_sizes[msg['name']] + msg['value'] = _subsample_fn(*msg['args'], self.rng_key) if msg["args"][1] < msg["args"][ + 0] else jnp.arange(msg["args"][0]) + + +def _tangent_curve(dist, value, tangent_fn): + z, aux_data = dist.tree_flatten() + log_prob = lambda *params: dist.tree_unflatten(aux_data, params).log_prob(value).sum() + return tuple(tangent_fn(log_prob, argnum)(*z) for argnum in range(len(z))) + + +def check_estimator_handler(): + data = random.normal(random.PRNGKey(1), (10000,)) + 1 + z = {'x': jnp.array(0.9511842)} + model_trace = trace(seed(model, random.PRNGKey(2))).get_trace(data) + u = {name: site["value"] for name, site in model_trace.items() + if site["type"] == "plate" and site["args"][0] > site["args"][1]} + z_ref = {k: v + .1 for k, v in z.items()} + plate_sizes_all = {name: (model_trace[name]["args"][0], model_trace[name]["args"][0]) for name in u} + with subsample_size(model, plate_sizes_all): + ref_trace = trace(substitute(model, data=z_ref)).get_trace(data) + jac_all = {name: _tangent_curve(site['fn'], site['value'], jax.jacobian) for name, site in ref_trace.items() if + (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} + hess_all = {name: _tangent_curve(site['fn'], site['value'], jax.hessian) for name, site in ref_trace.items() if + (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} + ll_ref = {name: site['fn'].log_prob(site['value']) for name, site in ref_trace.items() if + (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} + + plate_sizes = {name: model_trace[name]["args"] for name in u} + ref_trace = trace(substitute(model, data={**u, **z_ref})).get_trace(data) + z_ref, _ = ravel_pytree(z_ref) + proxy_fn, uproxy_fn = my_taylor(ref_trace, ll_ref, jac_all, hess_all) + with estimator(model, {'obs': partial(my_estimator, proxy_fn=proxy_fn, uproxy_fn=uproxy_fn)}, plate_sizes): + print(log_density_obs_hmcecs(_wrap_model(model), (data,), {"_subsample_sites": u}, z)) + + +def check_handler(): + data = random.normal(random.PRNGKey(1), (10000,)) + 1 + z = {'x': jnp.array(0.9511842)} + model_trace = trace(seed(model, random.PRNGKey(2))).get_trace(data) + u = {name: site["value"] for name, site in model_trace.items() + if site["type"] == "plate" and site["args"][0] > site["args"][1]} + z_ref = {k: v + .1 for k, v in z.items()} + plate_sizes_all = {name: (model_trace[name]["args"][0], model_trace[name]["args"][0]) for name in u} + with subsample_size(model, plate_sizes_all): + ref_trace = trace(substitute(model, data=z_ref)).get_trace(data) + jac_all = {name: _tangent_curve(site['fn'], site['value'], jax.jacobian) for name, site in ref_trace.items() if + (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} + hess_all = {name: _tangent_curve(site['fn'], site['value'], jax.hessian) for name, site in ref_trace.items() if + (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} + ll_ref = {name: site['fn'].log_prob(site['value']) for name, site in ref_trace.items() if + (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} + + plate_sizes = {name: model_trace[name]["args"] for name in u} + ref_trace = trace(substitute(model, data={**u, **z_ref})).get_trace(data) + z_ref, _ = ravel_pytree(z_ref) + proxy_fn, uproxy_fn = my_taylor(ref_trace, ll_ref, jac_all, hess_all) + print(log_density_obs_hmcecs( + _wrap_est_model(model, {'obs': partial(my_estimator, proxy_fn=proxy_fn, uproxy_fn=uproxy_fn)}, plate_sizes), + (data,), {"_subsample_sites": u}, z)) + + +if __name__ == '__main__': + check_handler() From 216c2cf2649937e28ef830b8c658b15e82550f48 Mon Sep 17 00:00:00 2001 From: Ola Date: Thu, 7 Jan 2021 11:51:53 +0100 Subject: [PATCH 36/93] Potential estimator integrated with ECS class. --- numpyro/contrib/ecs.py | 169 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 numpyro/contrib/ecs.py diff --git a/numpyro/contrib/ecs.py b/numpyro/contrib/ecs.py new file mode 100644 index 000000000..1316fb939 --- /dev/null +++ b/numpyro/contrib/ecs.py @@ -0,0 +1,169 @@ +""" Based on fehiepsi implementation: https://gist.github.com/fehiepsi/b4a5a80b245600b99467a0264be05fd5 """ +from collections import namedtuple +import copy + +from jax import device_put, lax, random, partial, jit, ops +import jax.numpy as jnp + +import numpyro +import numpyro.distributions as dist +from numpyro.handlers import substitute, trace, seed +from numpyro.infer import MCMC, NUTS, log_likelihood +from numpyro.infer.mcmc import MCMCKernel +from numpyro.util import identity +from numpyro.contrib.hmcecs_utils import init_near_values +from check_potential import my_estimator, my_taylor, estimator + +HMC_ECS_State = namedtuple("HMC_ECS_State", "uz, hmc_state, accept_prob, rng_key") +""" + - **uz** - a dict of current subsample indices and the current latent values + - **hmc_state** - current hmc_state + - **accept_prob** - acceptance probability of the proposal subsample indices + - **rng_key** - random key to generate new subsample indices +""" + +""" Notes: +- [x] init(...) ] +sample(...) + will use check_potential handler method! +""" + + +def _wrap_model(model): + def fn(*args, **kwargs): + subsample_values = kwargs.pop("_subsample_sites", {}) + with substitute(data=subsample_values): + model(*args, **kwargs) + + return fn + + +def _wrap_est_model(model, estimators, plate_sizes): + def fn(*args, **kwargs): + subsample_values = kwargs.pop("_subsample_sites", {}) + with substitute(data=subsample_values): + with estimator(model, estimators, plate_sizes): + model(*args, **kwargs) + + return fn + + +@partial(jit, static_argnums=(2, 3, 4)) +def _update_block(rng_key, u, n, m, g): + """Returns indexes of the new subsample. The update mechanism selects blocks of indices within the subsample to be updated. + The number of indexes to be updated depend on the block size, higher block size more correlation among elements in the subsample. + :param rng_key: + :param u: subsample indexes + :param n: total number of data + :param m: subsample size + :param g: number of subsample blocks + """ + + if not (0 < g <= m): + raise ValueError(f'Block size 0 < {g} <= {m}') + rng_key_block, rng_key_index = random.split(rng_key) + + chosen_block = random.randint(rng_key_block, shape=(), minval=0, maxval=g + 1) + idxs_new = random.choice(rng_key_index, n, shape=(m // g,), replace=False) + + u_new = jnp.zeros(m, jnp.dtype(u)) + for i in range(m): # TODO: look into block update + u_new = ops.index_add(u_new, i, lax.cond(i // g == chosen_block, + i, lambda _: idxs_new[i % (m // g)], + i, lambda _: u[i])) + return u_new + + +class ECS(MCMCKernel): + sample_field = "uz" + + def __init__(self, inner_kernel, estimator_fn=None, proxy_gen_fn=None, z_ref=None): + self.inner_kernel = copy.copy(inner_kernel) + self.inner_kernel._model = inner_kernel.model # Removed wrapper! + self._proxy_gen_fn = proxy_gen_fn + self._estimator_fn = estimator_fn + self._z_ref = z_ref + self._plate_sizes = None + + @property + def model(self): + return self.inner_kernel._model + + def postprocess_fn(self, args, kwargs): + def fn(uz): + z = {k: v for k, v in uz.items() if k not in self._plate_sizes} + return self.inner_kernel.postprocess_fn(args, kwargs)(z) + + return fn + + def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): + model_kwargs = {} if model_kwargs is None else model_kwargs.copy() + rng_key, key_u, key_z = random.split(rng_key, 3) + + prototype_trace = trace(seed(self.model, key_u)).get_trace(*model_args, **model_kwargs) + u = {name: site["value"] for name, site in prototype_trace.items() + if site["type"] == "plate" and site["args"][0] > site["args"][1]} + + self._plate_sizes = {name: prototype_trace[name]["args"] + (min(prototype_trace[name]["args"][1] // 2, 100),) + for name in u} + + proxy_fn, uproxy_fn = self._proxy_gen_fn(model, model_args, model_kwargs, prototype_trace, z_ref, u) + + estimators = {name: partial(self._estimator_fn, proxy_fn=proxy_fn, uproxy_fn=uproxy_fn) + for name, site in prototype_trace.items() if + (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} + self.inner_kernel._model = _wrap_est_model(model, estimators, self._plate_sizes) + init_params = {name: init_near_values(site, self._z_ref) for name, site in prototype_trace.items()} + model_kwargs["_subsample_sites"] = u + hmc_state = self.inner_kernel.init(key_z, num_warmup, init_params, + model_args, model_kwargs) + uz = {**u, **hmc_state.z} + return device_put(HMC_ECS_State(uz, hmc_state, 1., rng_key)) + + def sample(self, state, model_args, model_kwargs): + model_kwargs = {} if model_kwargs is None else model_kwargs.copy() + rng_key, key_u = random.split(state.rng_key) + u = {k: v for k, v in state.uz.items() if k in self._plate_sizes} + u_new = {} + for name, (size, subsample_size, num_blocks) in self._plate_sizes.items(): + key_u, subkey = random.split(key_u) + u_new[name] = _update_block(subkey, u[name], size, subsample_size, + num_blocks) # TODO: dynamically adjust block size + sample = self.postprocess_fn(model_args, model_kwargs)(state.hmc_state.z) + u_loglik = log_likelihood(self.model, sample, *model_args, batch_ndims=0, **model_kwargs, _subsample_sites=u) + u_loglik = sum(v.sum() for v in u_loglik.values()) + u_new_loglik = log_likelihood(self.model, sample, *model_args, batch_ndims=0, **model_kwargs, + _subsample_sites=u_new) + u_new_loglik = sum(v.sum() for v in u_new_loglik.values()) + accept_prob = jnp.clip(jnp.exp(u_new_loglik - u_loglik), a_max=1.0) + u = lax.cond(random.bernoulli(key_u, accept_prob), u_new, identity, u, identity) + model_kwargs["_subsample_sites"] = u + hmc_state = self.inner_kernel.sample(state.hmc_state, model_args, model_kwargs) + uz = {**u, **hmc_state.z} + return HMC_ECS_State(uz, hmc_state, accept_prob, rng_key) + + +def model(data, *args, **kwargs): + x = numpyro.sample("x", dist.Normal(0., 1.)) + with numpyro.plate("N", data.shape[0], subsample_size=300): + batch = numpyro.subsample(data, event_dim=0) + numpyro.sample("obs", dist.Normal(x, 1.), obs=batch) + + +def plain_model(data, *args, **kwargs): + x = numpyro.sample("x", dist.Normal(0., 1.)) + numpyro.sample("obs", dist.Normal(x, 1.), obs=data) + + +if __name__ == '__main__': + data = random.normal(random.PRNGKey(1), (10000,)) + 1 + kernel = NUTS(plain_model) + mcmc = MCMC(kernel, 3000, 3000) + mcmc.run(random.PRNGKey(1), data) + z_ref = {k: v.mean() for k, v in mcmc.get_samples().items()} + + kernel = ECS(NUTS(model), estimator_fn=my_estimator, proxy_gen_fn=my_taylor, z_ref=z_ref) + mcmc = MCMC(kernel, 1500, 1500) + mcmc.run(random.PRNGKey(0), data, extra_fields=("accept_prob",)) + # there is a bug when exclude_deterministic=True, which will be fixed upstream + mcmc.print_summary(exclude_deterministic=False) From d6e6700a79c6b3b98883fc6ab703ecfe61c1264d Mon Sep 17 00:00:00 2001 From: Ola Date: Fri, 8 Jan 2021 11:27:53 +0100 Subject: [PATCH 37/93] ECS wrapper working on toy example. --- check_potential.py | 95 ++++++++++++++++-------------------------- numpyro/contrib/ecs.py | 47 ++++++++++++++------- numpyro/primitives.py | 9 ++-- 3 files changed, 73 insertions(+), 78 deletions(-) diff --git a/check_potential.py b/check_potential.py index af022391b..a6d6dc6a1 100644 --- a/check_potential.py +++ b/check_potential.py @@ -1,21 +1,17 @@ -from collections import namedtuple -from numpyro.primitives import Messenger from functools import partial -import numpyro import jax import jax.numpy as jnp -from jax import grad, value_and_grad -from jax.tree_util import tree_multimap -from jax import random +from jax import random, hessian, jacfwd +import numpyro import numpyro.distributions as dist -from numpyro.distributions.util import is_identically_one -from numpyro.handlers import substitute, trace, seed, block -from numpyro.primitives import _subsample_fn +from numpyro.contrib.hmcecs_utils import log_density_obs_hmcecs +from numpyro.contrib.hmcecs_utils import potential_est, taylor_proxy +from numpyro.handlers import substitute, trace, seed +from numpyro.infer.util import log_density +from numpyro.primitives import Messenger, _subsample_fn from numpyro.util import ravel_pytree -from numpyro.contrib.hmcecs_utils import log_density_obs_hmcecs, log_density_prior_hmcecs -import jax.numpy as jnp def _wrap_model(model): @@ -37,13 +33,18 @@ def fn(*args, **kwargs): return fn -def model(data, *args, **kwargs): +def model(data): x = numpyro.sample("x", dist.Normal(0., 1.)) with numpyro.plate("N", data.shape[0], subsample_size=100): batch = numpyro.subsample(data, event_dim=0) obs = numpyro.sample("obs", dist.Normal(x, 1.), obs=batch) +def plain_model(data): + x = numpyro.sample("x", dist.Normal(0., 1.)) + obs = numpyro.sample("obs", dist.Normal(x, 1.), obs=data) + + class estimator(Messenger): def __init__(self, fn, estimators, plate_sizes): self.estimators = estimators @@ -53,6 +54,7 @@ def __init__(self, fn, estimators, plate_sizes): def process_message(self, msg): if msg['type'] == 'sample' and msg['is_observed'] and msg['cond_indep_stack']: log_prob = msg['fn'].log_prob + msg['scale'] = 1. msg['fn'].log_prob = lambda *args, **kwargs: \ self.estimators[msg['name']](*args, name=msg['name'], z=_extract_params(msg['fn']), log_prob=log_prob, sizes=self.plate_sizes[msg['cond_indep_stack'][0].name], @@ -60,7 +62,7 @@ def process_message(self, msg): def my_estimator(value, name, z, sizes, log_prob, proxy_fn=lambda x, y: x, uproxy_fn=lambda x: x, **kwargs, ): - n, m = sizes + n, m, g = sizes ll_sub = log_prob(value).sum() diff = ll_sub - uproxy_fn(name, value, z) l_hat = proxy_fn(name, z) + n / m * diff @@ -82,8 +84,8 @@ def proxy(name, z): z_diff = z[argnum] - z_ref[argnum] j, h = jac[argnum], hess[argnum] k, = j.shape - log_like += j.T @ z_diff + .5 * z_diff.T @ h.reshape(k, k) @ z_diff # TODO: factor out - return ll_ref[name] + log_like + log_like += j.T @ z_diff + .5 * z_diff.T @ h.reshape(k, k) @ z_diff + return ll_ref[name].sum() + log_like def uproxy(name, value, z): ref_dist = ref_trace[name]['fn'] @@ -98,7 +100,7 @@ def uproxy(name, value, z): hess = jax.hessian(log_prob, argnum)(*z_ref) log_like += jac @ z_diff + .5 * z_diff @ hess.reshape(k, k) @ z_diff.T - return log_prob(*z_ref) + log_like + return log_prob(*z_ref).sum() + log_like return proxy, uproxy @@ -123,31 +125,6 @@ def _tangent_curve(dist, value, tangent_fn): return tuple(tangent_fn(log_prob, argnum)(*z) for argnum in range(len(z))) -def check_estimator_handler(): - data = random.normal(random.PRNGKey(1), (10000,)) + 1 - z = {'x': jnp.array(0.9511842)} - model_trace = trace(seed(model, random.PRNGKey(2))).get_trace(data) - u = {name: site["value"] for name, site in model_trace.items() - if site["type"] == "plate" and site["args"][0] > site["args"][1]} - z_ref = {k: v + .1 for k, v in z.items()} - plate_sizes_all = {name: (model_trace[name]["args"][0], model_trace[name]["args"][0]) for name in u} - with subsample_size(model, plate_sizes_all): - ref_trace = trace(substitute(model, data=z_ref)).get_trace(data) - jac_all = {name: _tangent_curve(site['fn'], site['value'], jax.jacobian) for name, site in ref_trace.items() if - (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - hess_all = {name: _tangent_curve(site['fn'], site['value'], jax.hessian) for name, site in ref_trace.items() if - (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - ll_ref = {name: site['fn'].log_prob(site['value']) for name, site in ref_trace.items() if - (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - - plate_sizes = {name: model_trace[name]["args"] for name in u} - ref_trace = trace(substitute(model, data={**u, **z_ref})).get_trace(data) - z_ref, _ = ravel_pytree(z_ref) - proxy_fn, uproxy_fn = my_taylor(ref_trace, ll_ref, jac_all, hess_all) - with estimator(model, {'obs': partial(my_estimator, proxy_fn=proxy_fn, uproxy_fn=uproxy_fn)}, plate_sizes): - print(log_density_obs_hmcecs(_wrap_model(model), (data,), {"_subsample_sites": u}, z)) - - def check_handler(): data = random.normal(random.PRNGKey(1), (10000,)) + 1 z = {'x': jnp.array(0.9511842)} @@ -155,23 +132,25 @@ def check_handler(): u = {name: site["value"] for name, site in model_trace.items() if site["type"] == "plate" and site["args"][0] > site["args"][1]} z_ref = {k: v + .1 for k, v in z.items()} - plate_sizes_all = {name: (model_trace[name]["args"][0], model_trace[name]["args"][0]) for name in u} - with subsample_size(model, plate_sizes_all): - ref_trace = trace(substitute(model, data=z_ref)).get_trace(data) - jac_all = {name: _tangent_curve(site['fn'], site['value'], jax.jacobian) for name, site in ref_trace.items() if - (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - hess_all = {name: _tangent_curve(site['fn'], site['value'], jax.hessian) for name, site in ref_trace.items() if - (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - ll_ref = {name: site['fn'].log_prob(site['value']) for name, site in ref_trace.items() if - (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - - plate_sizes = {name: model_trace[name]["args"] for name in u} - ref_trace = trace(substitute(model, data={**u, **z_ref})).get_trace(data) - z_ref, _ = ravel_pytree(z_ref) - proxy_fn, uproxy_fn = my_taylor(ref_trace, ll_ref, jac_all, hess_all) - print(log_density_obs_hmcecs( - _wrap_est_model(model, {'obs': partial(my_estimator, proxy_fn=proxy_fn, uproxy_fn=uproxy_fn)}, plate_sizes), - (data,), {"_subsample_sites": u}, z)) + proxy_fn, uproxy_fn = my_taylor(model, (data,), {}, model_trace, z_ref, u) + plate_sizes = {name: model_trace[name]["args"] + (1,) for name in u} + wrapped_model = _wrap_est_model(model, {'obs': partial(my_estimator, proxy_fn=proxy_fn, uproxy_fn=uproxy_fn)}, + plate_sizes) + new_potential, _ = log_density(wrapped_model, (data,), {"_subsample_sites": u}, z) + print('new potential', new_potential) + + ld_fn = lambda args: jnp.sum(partial(log_density_obs_hmcecs, plain_model, (data,), {})(args)[0]) + jac_all, _ = ravel_pytree(jacfwd(ld_fn)(z_ref)) + print('ref jac all', jac_all) + hess_all, _ = ravel_pytree(hessian(ld_fn)(z_ref)) + k, = jac_all.shape + hess_all = hess_all.reshape((k, k)) + print('ref hess all', hess_all) + ll_ref = ld_fn(z_ref) + print('ref ll', ll_ref) + proxy_fn, uproxy_fn = taylor_proxy(z_ref, plain_model, ll_ref, jac_all, hess_all) + + print('reference potential', potential_est(plain_model, (data[u['N']],), {}, z, 10000, 100, proxy_fn, uproxy_fn)) if __name__ == '__main__': diff --git a/numpyro/contrib/ecs.py b/numpyro/contrib/ecs.py index 1316fb939..4cfbd0b70 100644 --- a/numpyro/contrib/ecs.py +++ b/numpyro/contrib/ecs.py @@ -1,18 +1,18 @@ """ Based on fehiepsi implementation: https://gist.github.com/fehiepsi/b4a5a80b245600b99467a0264be05fd5 """ -from collections import namedtuple import copy +from collections import namedtuple -from jax import device_put, lax, random, partial, jit, ops import jax.numpy as jnp +from jax import device_put, lax, random, partial, jit, jacobian, hessian, make_jaxpr import numpyro import numpyro.distributions as dist +from check_potential import my_estimator, my_taylor, estimator, subsample_size, _tangent_curve +from numpyro.contrib.hmcecs_utils import init_near_values from numpyro.handlers import substitute, trace, seed from numpyro.infer import MCMC, NUTS, log_likelihood from numpyro.infer.mcmc import MCMCKernel from numpyro.util import identity -from numpyro.contrib.hmcecs_utils import init_near_values -from check_potential import my_estimator, my_taylor, estimator HMC_ECS_State = namedtuple("HMC_ECS_State", "uz, hmc_state, accept_prob, rng_key") """ @@ -59,18 +59,14 @@ def _update_block(rng_key, u, n, m, g): :param g: number of subsample blocks """ - if not (0 < g <= m): - raise ValueError(f'Block size 0 < {g} <= {m}') rng_key_block, rng_key_index = random.split(rng_key) chosen_block = random.randint(rng_key_block, shape=(), minval=0, maxval=g + 1) - idxs_new = random.choice(rng_key_index, n, shape=(m // g,), replace=False) + new_idx = random.randint(rng_key_index, minval=0, maxval=n, shape=(m,)) + block_mask = (jnp.arange(m) // g == chosen_block).astype(int) + rest_mask = (block_mask - 1) ** 2 - u_new = jnp.zeros(m, jnp.dtype(u)) - for i in range(m): # TODO: look into block update - u_new = ops.index_add(u_new, i, lax.cond(i // g == chosen_block, - i, lambda _: idxs_new[i % (m // g)], - i, lambda _: u[i])) + u_new = u * rest_mask + block_mask * new_idx return u_new @@ -107,7 +103,21 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): self._plate_sizes = {name: prototype_trace[name]["args"] + (min(prototype_trace[name]["args"][1] // 2, 100),) for name in u} - proxy_fn, uproxy_fn = self._proxy_gen_fn(model, model_args, model_kwargs, prototype_trace, z_ref, u) + plate_sizes_all = {name: (prototype_trace[name]["args"][0], prototype_trace[name]["args"][0]) for name in u} + with subsample_size(model, plate_sizes_all): + ref_trace = trace(substitute(model, data=z_ref)).get_trace(*model_args, **model_kwargs) + jac_all = {name: _tangent_curve(site['fn'], site['value'], jacobian) for name, site in ref_trace.items() + if + (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} + hess_all = {name: _tangent_curve(site['fn'], site['value'], hessian) for name, site in ref_trace.items() + if + (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} + ll_ref = {name: site['fn'].log_prob(site['value']) for name, site in ref_trace.items() if + (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} + + ref_trace = trace(substitute(model, data={**z_ref, **u})).get_trace(*model_args, + **model_kwargs) # TODO: check reparam + proxy_fn, uproxy_fn = self._proxy_gen_fn(ref_trace, ll_ref, jac_all, hess_all) estimators = {name: partial(self._estimator_fn, proxy_fn=proxy_fn, uproxy_fn=uproxy_fn) for name, site in prototype_trace.items() if @@ -145,7 +155,7 @@ def sample(self, state, model_args, model_kwargs): def model(data, *args, **kwargs): x = numpyro.sample("x", dist.Normal(0., 1.)) - with numpyro.plate("N", data.shape[0], subsample_size=300): + with numpyro.plate("N", data.shape[0], subsample_size=1000): batch = numpyro.subsample(data, event_dim=0) numpyro.sample("obs", dist.Normal(x, 1.), obs=batch) @@ -156,13 +166,18 @@ def plain_model(data, *args, **kwargs): if __name__ == '__main__': - data = random.normal(random.PRNGKey(1), (10000,)) + 1 + data = random.normal(random.PRNGKey(1), (10_000,)) + 1 kernel = NUTS(plain_model) - mcmc = MCMC(kernel, 3000, 3000) + state = kernel.init(random.PRNGKey(1), 500, None, (data,), {}) + print(make_jaxpr(kernel.sample)(state, (data,), {}), file=open('nuts_jaxpr.txt', 'w')) + mcmc = MCMC(kernel, 500, 500) mcmc.run(random.PRNGKey(1), data) + mcmc.print_summary(exclude_deterministic=False) z_ref = {k: v.mean() for k, v in mcmc.get_samples().items()} kernel = ECS(NUTS(model), estimator_fn=my_estimator, proxy_gen_fn=my_taylor, z_ref=z_ref) + state = kernel.init(random.PRNGKey(1), 500, None, (data,), {}) + print(make_jaxpr(kernel.sample)(state, (data,), {}), file=open('ecs_jaxpr.txt', 'w')) mcmc = MCMC(kernel, 1500, 1500) mcmc.run(random.PRNGKey(0), data, extra_fields=("accept_prob",)) # there is a bug when exclude_deterministic=True, which will be fixed upstream diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 8c8ae05af..f286e907c 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -13,7 +13,6 @@ _PYRO_STACK = [] - CondIndepStackFrame = namedtuple('CondIndepStackFrame', ['name', 'dim', 'size']) @@ -36,7 +35,7 @@ def apply_stack(msg): # A Messenger that sets msg["stop"] == True also prevents application # of postprocess_message by Messengers above it on the stack # via the pointer variable from the process_message loop - for handler in _PYRO_STACK[-pointer-1:]: + for handler in _PYRO_STACK[-pointer - 1:]: handler.postprocess_message(msg) return msg @@ -239,6 +238,7 @@ class plate(Messenger): is used as the plate dim. If `None` (default), the leftmost available dim is allocated. """ + def __init__(self, name, size, subsample_size=None, dim=None): self.name = name self.size = size @@ -266,10 +266,11 @@ def _subsample(name, size, subsample_size, dim): } apply_stack(msg) subsample = msg['value'] + subsample_size = msg['args'][1] # TODO: rewrite plate if subsample_size is not None and subsample_size != subsample.shape[0]: raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format( subsample_size, len(subsample)) + - " Did you accidentally use different subsample_size in the model and guide?") + " Did you accidentally use different subsample_size in the model and guide?") cond_indep_stack = msg['cond_indep_stack'] occupied_dims = {f.dim for f in cond_indep_stack} if dim is None: @@ -334,7 +335,7 @@ def postprocess_message(self, msg): statement = "numpyro.subsample(..., event_dim={})".format(event_dim) raise ValueError( "Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}" - .format(self.name, self.size, self.dim, statement, shape)) + .format(self.name, self.size, self.dim, statement, shape)) if self.subsample_size < self.size: value = msg["value"] new_value = jnp.take(value, self._indices, dim) From c59b317d94364f1150b50b052f167e7f4ba6c13b Mon Sep 17 00:00:00 2001 From: Ola Date: Fri, 8 Jan 2021 14:47:07 +0100 Subject: [PATCH 38/93] cleaned code. --- check_potential.py | 157 ---- numpyro/contrib/ecs.py | 31 +- numpyro/contrib/hmcecs.py | 1205 ------------------------------- numpyro/contrib/hmcecs_utils.py | 354 +++------ numpyro/handlers.py | 15 +- numpyro/infer/util.py | 8 +- 6 files changed, 120 insertions(+), 1650 deletions(-) delete mode 100644 check_potential.py delete mode 100644 numpyro/contrib/hmcecs.py diff --git a/check_potential.py b/check_potential.py deleted file mode 100644 index a6d6dc6a1..000000000 --- a/check_potential.py +++ /dev/null @@ -1,157 +0,0 @@ -from functools import partial - -import jax -import jax.numpy as jnp -from jax import random, hessian, jacfwd - -import numpyro -import numpyro.distributions as dist -from numpyro.contrib.hmcecs_utils import log_density_obs_hmcecs -from numpyro.contrib.hmcecs_utils import potential_est, taylor_proxy -from numpyro.handlers import substitute, trace, seed -from numpyro.infer.util import log_density -from numpyro.primitives import Messenger, _subsample_fn -from numpyro.util import ravel_pytree - - -def _wrap_model(model): - def fn(*args, **kwargs): - subsample_values = kwargs.pop("_subsample_sites", {}) - with substitute(data=subsample_values): - model(*args, **kwargs) - - return fn - - -def _wrap_est_model(model, estimators, plate_sizes): - def fn(*args, **kwargs): - subsample_values = kwargs.pop("_subsample_sites", {}) - with substitute(data=subsample_values): - with estimator(model, estimators, plate_sizes): - model(*args, **kwargs) - - return fn - - -def model(data): - x = numpyro.sample("x", dist.Normal(0., 1.)) - with numpyro.plate("N", data.shape[0], subsample_size=100): - batch = numpyro.subsample(data, event_dim=0) - obs = numpyro.sample("obs", dist.Normal(x, 1.), obs=batch) - - -def plain_model(data): - x = numpyro.sample("x", dist.Normal(0., 1.)) - obs = numpyro.sample("obs", dist.Normal(x, 1.), obs=data) - - -class estimator(Messenger): - def __init__(self, fn, estimators, plate_sizes): - self.estimators = estimators - self.plate_sizes = plate_sizes - super(estimator, self).__init__(fn) - - def process_message(self, msg): - if msg['type'] == 'sample' and msg['is_observed'] and msg['cond_indep_stack']: - log_prob = msg['fn'].log_prob - msg['scale'] = 1. - msg['fn'].log_prob = lambda *args, **kwargs: \ - self.estimators[msg['name']](*args, name=msg['name'], z=_extract_params(msg['fn']), log_prob=log_prob, - sizes=self.plate_sizes[msg['cond_indep_stack'][0].name], - **kwargs) # TODO: check multiple levels - - -def my_estimator(value, name, z, sizes, log_prob, proxy_fn=lambda x, y: x, uproxy_fn=lambda x: x, **kwargs, ): - n, m, g = sizes - ll_sub = log_prob(value).sum() - diff = ll_sub - uproxy_fn(name, value, z) - l_hat = proxy_fn(name, z) + n / m * diff - sigma = n ** 2 / m * jnp.var(diff) - return l_hat - .5 * sigma - - -def _extract_params(distribution): - params, _ = distribution.tree_flatten() - return params - - -def my_taylor(ref_trace, ll_ref, jac_all, hess_all): - def proxy(name, z): - z_ref = _extract_params(ref_trace[name]['fn']) - jac, hess = jac_all[name], hess_all[name] - log_like = jnp.array(0.) - for argnum in range(len(z_ref)): - z_diff = z[argnum] - z_ref[argnum] - j, h = jac[argnum], hess[argnum] - k, = j.shape - log_like += j.T @ z_diff + .5 * z_diff.T @ h.reshape(k, k) @ z_diff - return ll_ref[name].sum() + log_like - - def uproxy(name, value, z): - ref_dist = ref_trace[name]['fn'] - z_ref, aux_data = ref_dist.tree_flatten() - - log_prob = lambda *params: ref_dist.tree_unflatten(aux_data, params).log_prob(value).sum() - log_like = jnp.array(0.) - for argnum in range(len(z_ref)): - z_diff = z[argnum] - z_ref[argnum] - jac = jax.jacobian(log_prob, argnum)(*z_ref) - k, = jac.shape - hess = jax.hessian(log_prob, argnum)(*z_ref) - log_like += jac @ z_diff + .5 * z_diff @ hess.reshape(k, k) @ z_diff.T - - return log_prob(*z_ref).sum() + log_like - - return proxy, uproxy - - -class subsample_size(Messenger): - def __init__(self, fn, plate_sizes, rng_key=None): - super(subsample_size, self).__init__(fn) - self.plate_sizes = plate_sizes - self.rng_key = rng_key - - def process_message(self, msg): - if msg['type'] == 'plate' and msg['args'] and msg["args"][0] > msg["args"][1]: - if msg['name'] in self.plate_sizes: - msg['args'] = self.plate_sizes[msg['name']] - msg['value'] = _subsample_fn(*msg['args'], self.rng_key) if msg["args"][1] < msg["args"][ - 0] else jnp.arange(msg["args"][0]) - - -def _tangent_curve(dist, value, tangent_fn): - z, aux_data = dist.tree_flatten() - log_prob = lambda *params: dist.tree_unflatten(aux_data, params).log_prob(value).sum() - return tuple(tangent_fn(log_prob, argnum)(*z) for argnum in range(len(z))) - - -def check_handler(): - data = random.normal(random.PRNGKey(1), (10000,)) + 1 - z = {'x': jnp.array(0.9511842)} - model_trace = trace(seed(model, random.PRNGKey(2))).get_trace(data) - u = {name: site["value"] for name, site in model_trace.items() - if site["type"] == "plate" and site["args"][0] > site["args"][1]} - z_ref = {k: v + .1 for k, v in z.items()} - proxy_fn, uproxy_fn = my_taylor(model, (data,), {}, model_trace, z_ref, u) - plate_sizes = {name: model_trace[name]["args"] + (1,) for name in u} - wrapped_model = _wrap_est_model(model, {'obs': partial(my_estimator, proxy_fn=proxy_fn, uproxy_fn=uproxy_fn)}, - plate_sizes) - new_potential, _ = log_density(wrapped_model, (data,), {"_subsample_sites": u}, z) - print('new potential', new_potential) - - ld_fn = lambda args: jnp.sum(partial(log_density_obs_hmcecs, plain_model, (data,), {})(args)[0]) - jac_all, _ = ravel_pytree(jacfwd(ld_fn)(z_ref)) - print('ref jac all', jac_all) - hess_all, _ = ravel_pytree(hessian(ld_fn)(z_ref)) - k, = jac_all.shape - hess_all = hess_all.reshape((k, k)) - print('ref hess all', hess_all) - ll_ref = ld_fn(z_ref) - print('ref ll', ll_ref) - proxy_fn, uproxy_fn = taylor_proxy(z_ref, plain_model, ll_ref, jac_all, hess_all) - - print('reference potential', potential_est(plain_model, (data[u['N']],), {}, z, 10000, 100, proxy_fn, uproxy_fn)) - - -if __name__ == '__main__': - check_handler() diff --git a/numpyro/contrib/ecs.py b/numpyro/contrib/ecs.py index 4cfbd0b70..0393094e0 100644 --- a/numpyro/contrib/ecs.py +++ b/numpyro/contrib/ecs.py @@ -7,8 +7,14 @@ import numpyro import numpyro.distributions as dist -from check_potential import my_estimator, my_taylor, estimator, subsample_size, _tangent_curve -from numpyro.contrib.hmcecs_utils import init_near_values +from numpyro.contrib.hmcecs_utils import ( + init_near_values, + difference_estimator_fn, + taylor_proxy, + estimator, + subsample_size, + _tangent_curve +) from numpyro.handlers import substitute, trace, seed from numpyro.infer import MCMC, NUTS, log_likelihood from numpyro.infer.mcmc import MCMCKernel @@ -29,15 +35,6 @@ """ -def _wrap_model(model): - def fn(*args, **kwargs): - subsample_values = kwargs.pop("_subsample_sites", {}) - with substitute(data=subsample_values): - model(*args, **kwargs) - - return fn - - def _wrap_est_model(model, estimators, plate_sizes): def fn(*args, **kwargs): subsample_values = kwargs.pop("_subsample_sites", {}) @@ -75,7 +72,7 @@ class ECS(MCMCKernel): def __init__(self, inner_kernel, estimator_fn=None, proxy_gen_fn=None, z_ref=None): self.inner_kernel = copy.copy(inner_kernel) - self.inner_kernel._model = inner_kernel.model # Removed wrapper! + self.inner_kernel._model = inner_kernel.model self._proxy_gen_fn = proxy_gen_fn self._estimator_fn = estimator_fn self._z_ref = z_ref @@ -167,18 +164,14 @@ def plain_model(data, *args, **kwargs): if __name__ == '__main__': data = random.normal(random.PRNGKey(1), (10_000,)) + 1 + # Get reference parameters kernel = NUTS(plain_model) - state = kernel.init(random.PRNGKey(1), 500, None, (data,), {}) - print(make_jaxpr(kernel.sample)(state, (data,), {}), file=open('nuts_jaxpr.txt', 'w')) mcmc = MCMC(kernel, 500, 500) mcmc.run(random.PRNGKey(1), data) mcmc.print_summary(exclude_deterministic=False) z_ref = {k: v.mean() for k, v in mcmc.get_samples().items()} - - kernel = ECS(NUTS(model), estimator_fn=my_estimator, proxy_gen_fn=my_taylor, z_ref=z_ref) - state = kernel.init(random.PRNGKey(1), 500, None, (data,), {}) - print(make_jaxpr(kernel.sample)(state, (data,), {}), file=open('ecs_jaxpr.txt', 'w')) + # Compute HMCECS + kernel = ECS(NUTS(model), estimator_fn=difference_estimator_fn, proxy_gen_fn=taylor_proxy, z_ref=z_ref) mcmc = MCMC(kernel, 1500, 1500) mcmc.run(random.PRNGKey(0), data, extra_fields=("accept_prob",)) - # there is a bug when exclude_deterministic=True, which will be fixed upstream mcmc.print_summary(exclude_deterministic=False) diff --git a/numpyro/contrib/hmcecs.py b/numpyro/contrib/hmcecs.py deleted file mode 100644 index 6b3e25183..000000000 --- a/numpyro/contrib/hmcecs.py +++ /dev/null @@ -1,1205 +0,0 @@ -"""Contributed code for HMC and NUTS energy conserving sampling adapted from """ - -import math -import os -import warnings -from collections import namedtuple -from itertools import chain - -import jax.numpy as jnp -from jax import device_put, lax, partial, random, vmap, jacfwd, hessian, jit, ops -from jax.dtypes import canonicalize_dtype -from jax.flatten_util import ravel_pytree - -import numpyro.distributions as dist -from numpyro.contrib.hmcecs_utils import potential_est, init_near_values, tuplemerge, \ - model_args_sub, model_kwargs_sub, taylor_proxy, svi_proxy, log_density_obs_hmcecs, signed_estimator -from numpyro.infer.hmc_util import ( - IntegratorState, - build_tree, - euclidean_kinetic_energy, - find_reasonable_step_size, - velocity_verlet, - warmup_adapter -) -from numpyro.infer.mcmc import MCMCKernel -from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model -from numpyro.util import cond, fori_loop, identity - -HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', - 'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key']) - -HMCECSState = namedtuple("HMCECState", ['u', 'hmc_state', 'll_u', 'sign', 'z_and_sign']) - -""" -A :func:`~collections.namedtuple` consisting of the following fields: - - - **i** - iteration. This is reset to 0 after warmup. - - **z** - Python collection representing values (unconstrained samples from - the posterior) at latent sites. - - **z_grad** - Gradient of potential energy w.r.t. latent sample sites. - - **potential_energy** - Potential energy computed at the given value of ``z``. - - **energy** - Sum of potential energy and kinetic energy of the current state. - - **num_steps** - Number of steps in the Hamiltonian trajectory (for diagnostics). - - **accept_prob** - Acceptance probability of the proposal. Note that ``z`` - does not correspond to the proposal if it is rejected. - - **mean_accept_prob** - Mean acceptance probability until current iteration - during warmup adaptation or sampling (for diagnostics). - - **diverging** - A boolean value to indicate whether the current trajectory is diverging. - - **adapt_state** - A ``HMCAdaptState`` namedtuple which contains adaptation information - during warmup: - - + **step_size** - Step size to be used by the integrator in the next iteration. - + **inverse_mass_matrix** - The inverse mass matrix to be used for the next - iteration. - + **mass_matrix_sqrt** - The square root of mass matrix to be used for the next - iteration. In case of dense mass, this is the Cholesky factorization of the - mass matrix. - - - **rng_key** - random number generator seed used for the iteration. - - **u** - Subsample - - **blocks** - blocks in which the subsample is divided - - **z_ref** - MAP estimation of the model parameters to initialize the subsampling. - - **ll_map** - Log likelihood of the map estimated parameters. - - **jac_map** - Jacobian vector from the map estimated parameters. - - **hess_map** - Hessian matrix from the map estimated parameters - - **Control variates** - Log likelihood correction - - **ll_u** - Log likelihood of the subsample -""" - - -def _get_num_steps(step_size, trajectory_length): - num_steps = jnp.clip(trajectory_length / step_size, a_min=1) - # NB: casting to jnp.int64 does not take effect (returns jnp.int32 instead) - # if jax_enable_x64 is False - return num_steps.astype(canonicalize_dtype(jnp.int64)) - - -def momentum_generator(prototype_r, mass_matrix_sqrt, rng_key): - _, unpack_fn = ravel_pytree(prototype_r) - eps = random.normal(rng_key, jnp.shape(mass_matrix_sqrt)[:1]) - if mass_matrix_sqrt.ndim == 1: - r = jnp.multiply(mass_matrix_sqrt, eps) - return unpack_fn(r) - elif mass_matrix_sqrt.ndim == 2: - r = jnp.dot(mass_matrix_sqrt, eps) - return unpack_fn(r) - else: - raise ValueError("Mass matrix has incorrect number of dims.") - - -@partial(jit, static_argnums=(2, 3, 4)) -def _update_block(rng_key, u, n, m, g): - """Returns indexes of the new subsample. The update mechanism selects blocks of indices within the subsample to be updated. - The number of indexes to be updated depend on the block size, higher block size more correlation among elements in the subsample. - :param rng_key - :param u subsample indexes - :param n total number of data - :param m subsample size - :param g block size: subsample subdivision""" - - if (g > m) or (g < 1): - raise ValueError( - 'Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(g, m)) - rng_key_block, rng_key_index = random.split(rng_key) - # uniformly choose block to update - chosen_block = random.randint(rng_key, shape=(), minval=0, maxval=g + 1) - idxs_new = random.randint(rng_key_index, shape=(m // g,), minval=0, - maxval=n) # choose block within the subsample to update - u_new = jnp.zeros(m, jnp.dtype(u)) # empty array with size m - for i in range(m): - # if index in the subsample // g = chosen block : pick new indexes from the subsample size - # else not update: keep the same indexes - u_new = ops.index_add(u_new, i, - lax.cond(i // g == chosen_block, i, lambda _: idxs_new[i % (m // g)], i, lambda _: u[i])) - return u_new - - -# @partial(jit, static_argnums=(0,1,2)) -# @functools.partial(jit, static_argnums=(2)) -def _sample_u_poisson(rng_key, m, l): - """ Initialize subsamples u - ***References*** - 1. Hamiltonian Monte Carlo with Energy Conserving Subsampling - 2. The blockPoisson estimator for optimally tuned exact subsampling MCMC. - :param m: subsample size - :param l: lambda u blocks - :param g: number of blocks - """ - pois_key, sub_key = random.split(rng_key) - block_lengths = dist.discrete.Poisson(1).sample(pois_key, (l,)) # lambda block lengths - # u = random.randint(sub_key, (jnp.sum(block_lengths), m), 0, m) - # @partial(mask, in_shapes=['(_,)'], out_shape='(_, _)') - # def u_rand(block_lenghts): - # b = jnp.sum(block_lengths).astype(int) - # #return jit(random.randint, static_argnums=(0,1, 2,3))(sub_key, (b,m), 0, m) - # return random.randint(sub_key, (b,m), 0, m) - # u = u_rand([block_lengths],{})#dict(b=jnp.sum(block_lengths).astype(int),m=m,l=l)) - # print(u.shape) - b = jnp.sum(block_lengths) - u_random = jit(random.randint, static_argnums=(0, 1, 2, 3)) - u = u_random(sub_key, (b, m), 0, m) - # @partial(mask,in_shapes=['(tmp,)'],out_shape='(b,)') - # def u_rand(block_lengths): - # return jnp.zeros(jnp.sum(block_lengths)) - # u = u_rand([block_lengths],dict(tmp=l,b=jnp.sum(block_lengths))) - # print(u.shape) - # exit() - - return jnp.split(u, jnp.cumsum(block_lengths), axis=0) - - -@partial(jit, static_argnums=(2, 3, 4, 5)) -def _update_block_poisson(rng_key, u, m, l, g): - """ Update block of u, where the length of the block of indexes to update is given by the Poisson distribution. - ***References*** - 1. Hamiltonian Monte Carlo with Energy Conserving Subsampling - 2.T he blockPoisson estimator for optimally tuned exact subsampling MCMC. - :param rng_key - :param u: current subsample indexes - :param m: Subsample size - :param l: lambda - :param g: Block size within subsample - """ - if (g > m) or (g < 1): - raise ValueError( - 'Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(g, m)) - u = u.copy() - block_key, sample_key = random.split(rng_key) - num_updates = int(round(l / g, 0)) # choose lambda/g number of blocks to update - chosen_blocks = random.randint(block_key, (num_updates,), 0, l) - _sample_u_poisson_jit = jit(_sample_u_poisson, static_argnums=(2)) - new_blocks = _sample_u_poisson_jit(sample_key, m, num_updates) - for i, block in enumerate(chosen_blocks): - u[block] = new_blocks[i] - return u - - -def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, grad_potential_fn_gen=None, algo='NUTS'): - r""" - Hamiltonian Monte Carlo inference, using either fixed number of - steps or the No U-Turn Sampler (NUTS) with adaptive path length. - - **References:** - - 1. *MCMC Using Hamiltonian Dynamics*, - Radford M. Neal - 2. *The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo*, - Matthew D. Hoffman, and Andrew Gelman. - 3. *A Conceptual Introduction to Hamiltonian Monte Carlo`*, - Michael Betancourt - **ECS References*** - 1. Hamiltonian Monte Carlo with Energy Conserving Subsampling - 2. The blockPoisson estimator for optimally tuned exact subsampling MCMC. - - :param potential_fn: Python callable that computes the potential energy - given input parameters. The input parameters to `potential_fn` can be - any python collection type, provided that `init_params` argument to - `init_kernel` has the same type. - :param potential_fn_gen: Python callable that when provided with model - arguments / keyword arguments returns `potential_fn`. This - may be provided to do inference on the same model with changing data. - If the data shape remains the same, we can compile `sample_kernel` - once, and use the same for multiple inference runs. - :param kinetic_fn: Python callable that returns the kinetic energy given - inverse mass matrix and momentum. If not provided, the default is - euclidean kinetic energy. - :param str algo: Whether to run ``HMC`` with fixed number of steps or ``NUTS`` - with adaptive path length. Default is ``NUTS``. - :return: a tuple of callables (`init_kernel`, `sample_kernel`), the first - one to initialize the sampler, and the second one to generate samples - given an existing one. - - .. warning:: - Instead of using this interface directly, we would highly recommend you - to use the higher level :class:`numpyro.infer.MCMC` API instead. - - **Example** - - .. doctest:: - - >>> import jax - >>> from jax import random - >>> import jax.numpy as jnp - >>> import numpyro - >>> import numpyro.distributions as dist - >>> from numpyro.infer.hmc import hmc - >>> from numpyro.infer.util import initialize_model - >>> from numpyro.util import fori_collect - - >>> true_coefs = jnp.array([1., 2., 3.]) - >>> data = random.normal(random.PRNGKey(2), (2000, 3)) - >>> dim = 3 - >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3)) - >>> - >>> def model(data, labels): - ... coefs_mean = jnp.zeros(dim) - ... coefs = numpyro.sample('beta', dist.Normal(coefs_mean, jnp.ones(3))) - ... intercept = numpyro.sample('intercept', dist.Normal(0., 10.)) - ... return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels) - >>> - >>> model_info = initialize_model(random.PRNGKey(0), model, model_args=(data, labels,)) - >>> init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS') - >>> hmc_state = init_kernel(model_info.param_info, - ... trajectory_length=10, - ... num_warmup=300) - >>> samples = fori_collect(0, 500, sample_kernel, hmc_state, - ... transform=lambda state: model_info.postprocess_fn(state.z)) - >>> print(jnp.mean(samples['beta'], axis=0)) # doctest: +SKIP - [0.9153987 2.0754058 2.9621222] - """ - if kinetic_fn is None: - kinetic_fn = euclidean_kinetic_energy - vv_update = None - trajectory_len = None - max_treedepth = None - wa_update = None - wa_steps = None - max_delta_energy = 1000. - if algo not in {'HMC', 'NUTS'}: - raise ValueError('`algo` must be one of `HMC` or `NUTS`.') - - def init_kernel(init_params, - num_warmup, - step_size=1.0, - inverse_mass_matrix=None, - adapt_step_size=True, - adapt_mass_matrix=True, - dense_mass=False, - target_accept_prob=0.8, - trajectory_length=2 * math.pi, - max_tree_depth=10, - find_heuristic_step_size=False, - model_args=(), - model_kwargs=None, - model=None, - sign=None, - sign_sum=None, - ll_ref=None, - jac_all=None, - z_ref=None, - hess_all=None, - ll_u=None, - n=None, - m=None, - u=None, - l=None, - rng_key=random.PRNGKey(0), - subsample_method=None, - estimator=None, - proxy_fn=None, - proxy_u_fn=None): - """ - Initializes the HMC sampler. - - :param init_params: Initial parameters to begin sampling. The type must - be consistent with the input type to `potential_fn`. - :param int num_warmup: Number of warmup steps; samples generated - during warmup are discarded. - :param float step_size: Determines the size of a single step taken by the - verlet integrator while computing the trajectory using Hamiltonian - dynamics. If not specified, it will be set to 1. - :param numpy.ndarray inverse_mass_matrix: Initial value for inverse mass matrix. - This may be adapted during warmup if adapt_mass_matrix = True. - If no value is specified, then it is initialized to the identity matrix. - :param bool adapt_step_size: A flag to decide if we want to adapt step_size - during warm-up phase using Dual Averaging scheme. - :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass - matrix during warm-up phase using Welford scheme. - :param bool dense_mass: A flag to decide if mass matrix is dense or - diagonal (default when ``dense_mass=False``) - :param float target_accept_prob: Target acceptance probability for step size - adaptation using Dual Averaging. Increasing this value will lead to a smaller - step size, hence the sampling will be slower but more robust. Default to 0.8. - :param float trajectory_length: Length of a MCMC trajectory for HMC. Default - value is :math:`2\\pi`. - :param int max_tree_depth: Max depth of the binary tree created during the doubling - scheme of NUTS sampler. Defaults to 10. - :param bool find_heuristic_step_size: whether to a heuristic function to adjust the - step size at the beginning of each adaptation window. Defaults to False. - :param tuple model_args: Model arguments if `potential_fn_gen` is specified. - :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified. - - :param model:, - :param sign:, - :param ll_ref:, - :param jac_all, - :param z_ref, - :param hess_all, - :param ll_u , - :param n , - :param m , - :param u, - :param l, - :param jax.random.PRNGKey rng_key: random key to be used as the source of - randomness. - :param subsample_method: Allows for activation of HMC-ECS or Subsampling, - :param estimator: Allows between an approximate likelihood estimator of the potential function (default), or an exact - calculation (poisson) - :param proxy_fn: Pre-compiled function that calculates the covariate (likelihood correction) for the parameters given the reference estimate - :param proxy_u_fn: Pre-compiled function that calculates the covariate (likelihood correction) for the paraneters given the subsample (model_args) - - """ - step_size = lax.convert_element_type(step_size, canonicalize_dtype(jnp.float64)) - nonlocal wa_update, trajectory_len, max_treedepth, vv_update, wa_steps - wa_steps = num_warmup - trajectory_len = trajectory_length - max_treedepth = max_tree_depth - - if isinstance(init_params, ParamInfo): - z, pe, z_grad = init_params - else: - z, pe, z_grad = init_params, None, None - - pe_fn = potential_fn - if potential_fn_gen: - if pe_fn is not None: - raise ValueError('Only one of `potential_fn` or `potential_fn_gen` must be provided.') - else: - if subsample_method == "perturb": - kwargs = {} if model_kwargs is None else model_kwargs - if estimator == "poisson": - pe_fn = potential_fn_gen(model=model, model_args=model_args, model_kwargs=kwargs, z=z, l=l, - proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) - else: - pe_fn = potential_fn_gen(model=model, model_args=model_args, model_kwargs=kwargs, z=z, n=n, m=m, - proxy_fn=proxy_fn, proxy_u_fn=proxy_u_fn) - - else: - kwargs = {} if model_kwargs is None else model_kwargs - pe_fn = potential_fn_gen(*model_args, **kwargs) - if grad_potential_fn_gen: - kwargs = {} if model_kwargs is None else model_kwargs - gpe_fn = grad_potential_fn_gen(*model_args, **kwargs) - else: - gpe_fn = None - - find_reasonable_ss = None - - if find_heuristic_step_size: - find_reasonable_ss = partial(find_reasonable_step_size, - pe_fn, - kinetic_fn, - momentum_generator) - - wa_init, wa_update = warmup_adapter(num_warmup, - adapt_step_size=adapt_step_size, - adapt_mass_matrix=adapt_mass_matrix, - dense_mass=dense_mass, - target_accept_prob=target_accept_prob, - find_reasonable_step_size=find_reasonable_ss) - - rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3) - z_info = IntegratorState(z=z, potential_energy=pe, z_grad=z_grad) - wa_state = wa_init(z_info, rng_key_wa, step_size, - inverse_mass_matrix=inverse_mass_matrix, - mass_matrix_size=jnp.size(ravel_pytree(z)[0])) - - r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum) - - vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn) - - vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad) - - energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) - - hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, - 0, 0., 0., False, wa_state, rng_key_hmc) - - z_and_sign = {**vv_state.z, 'sign': sign, "sign_sum": sign_sum} - - hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state, ll_u=ll_u, sign=sign, z_and_sign=z_and_sign) - - hmc_state = tuplemerge(hmc_sub_state._asdict(), hmc_state._asdict()) - - return device_put(hmc_state) - - def _hmc_next(step_size, inverse_mass_matrix, vv_state, - model_args, model_kwargs, rng_key, subsample_method, - estimator=None, - proxy_fn=None, - proxy_u_fn=None, - model=None, - ll_ref=None, jac_all=None, z=None, z_ref=None, hess_all=None, ll_u=None, u=None, - n=None, - m=None, - l=None, ): - if potential_fn_gen: - if grad_potential_fn_gen: - kwargs = {} if model_kwargs is None else model_kwargs - gpe_fn = grad_potential_fn_gen(*model_args, **kwargs, ) - pe_fn = potential_fn_gen(*model_args, **model_kwargs) - - else: - if subsample_method == "perturb": - if estimator == "poisson": - pe_fn = potential_fn_gen(model=model, - model_args=model_args, - model_kwargs=model_kwargs, - z=vv_state.z, - l=l, - proxy_fn=proxy_fn, - proxy_u_fn=proxy_u_fn) - - else: - pe_fn = potential_fn_gen(model=model, - model_args=model_args, - model_kwargs=model_kwargs, - z=vv_state.z, - n=n, - m=m, - proxy_fn=proxy_fn, - proxy_u_fn=proxy_u_fn) - kwargs = {} if model_kwargs is None else model_kwargs - else: - pe_fn = potential_fn_gen(*model_args, **model_kwargs) - nonlocal vv_update - _, vv_update = velocity_verlet(pe_fn, kinetic_fn) - - num_steps = _get_num_steps(step_size, trajectory_len) - - vv_state_new = fori_loop(0, num_steps, - lambda i, val: vv_update(step_size, inverse_mass_matrix, val), - vv_state) - energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r) - energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r) - delta_energy = energy_new - energy_old - delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy) - accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0) - diverging = delta_energy > max_delta_energy - transition = random.bernoulli(rng_key, accept_prob) - vv_state, energy = cond(transition, - (vv_state_new, energy_new), identity, - (vv_state, energy_old), identity) - - return vv_state, energy, num_steps, accept_prob, diverging - - def _nuts_next(step_size, inverse_mass_matrix, vv_state, - model_args, model_kwargs, rng_key, subsample_method, - estimator=None, - proxy_fn=None, proxy_u_fn=None, - model=None, - ll_ref=None, jac_all=None, z=None, z_ref=None, hess_all=None, ll_u=None, u=None, - n=None, m=None, l=None): - if potential_fn_gen: - nonlocal vv_update - if grad_potential_fn_gen: - kwargs = {} if model_kwargs is None else model_kwargs - gpe_fn = grad_potential_fn_gen(*model_args, **kwargs, ) - pe_fn = potential_fn_gen(*model_args, **model_kwargs) - else: - if subsample_method == "perturb": - if estimator == "poisson": - pe_fn = potential_fn_gen(model=model, - model_args=model_args, - model_kwargs=model_kwargs, - z=vv_state.z, - l=l, - proxy_fn=proxy_fn, - proxy_u_fn=proxy_u_fn) - - - - else: - pe_fn = potential_fn_gen(model=model, - model_args=model_args, - model_kwargs=model_kwargs, - z=vv_state.z, - n=n, - m=m, - proxy_fn=proxy_fn, - proxy_u_fn=proxy_u_fn) - else: - pe_fn = potential_fn_gen(*model_args, **model_kwargs) - _, vv_update = velocity_verlet(pe_fn, kinetic_fn) - - binary_tree = build_tree(vv_update, kinetic_fn, vv_state, - inverse_mass_matrix, step_size, rng_key, - max_delta_energy=max_delta_energy, - max_tree_depth=max_treedepth) - accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals - num_steps = binary_tree.num_proposals - vv_state = IntegratorState(z=binary_tree.z_proposal, - r=vv_state.r, - potential_energy=binary_tree.z_proposal_pe, - z_grad=binary_tree.z_proposal_grad) - return vv_state, binary_tree.z_proposal_energy, num_steps, accept_prob, binary_tree.diverging - - _next = _nuts_next if algo == 'NUTS' else _hmc_next - - def sample_kernel(hmc_state, model_args=(), model_kwargs=None, - subsample_method=None, - estimator=None, - proxy_fn=None, - proxy_u_fn=None, - model=None, - ll_ref=None, - jac_all=None, - z=None, - z_ref=None, - hess_all=None, - ll_u=None, - sign=None, - u=None, n=None, m=None, l=None, - sign_sum=None): - """ - Given an existing :data:`~numpyro.infer.mcmc.HMCState`, run HMC with fixed (possibly adapted) - step size and return a new :data:`~numpyro.infer.mcmc.HMCState`. - - :param hmc_state: Current sample (and associated state). - :param tuple model_args: Model arguments if `potential_fn_gen` is specified. - :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified. - :param subsample_method: Indicates if hmc energy conserving method shall be implemented for subsampling - :param proxy_fn - :param proxy_u_fn - :param model - :param ll_ref - :param jac_all - :param z - :param z_ref - :param hess_all - :param ll_u - :param u - :param n - :param m - :param l : lambda value for block poisson estimator method. Indicates the number of subsamples within a subsample - :return: new proposed :data:`~numpyro.infer.mcmc.HMCState` from simulating - Hamiltonian dynamics given existing state. - - """ - - model_kwargs = {} if model_kwargs is None else model_kwargs - if subsample_method == "perturb": - if estimator == "poisson": - model_args = [model_args_sub(u_i, model_args) for u_i in u] # here u = poisson_u - else: - model_args = model_args_sub(u, model_args) - rng_key, rng_key_momentum, rng_key_transition = random.split(hmc_state.rng_key, 3) - r = momentum_generator(hmc_state.z, hmc_state.adapt_state.mass_matrix_sqrt, rng_key_momentum) - vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad) - - vv_state, energy, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size, - hmc_state.adapt_state.inverse_mass_matrix, - vv_state, - model_args, - model_kwargs, - rng_key_transition, - subsample_method, - estimator, - proxy_fn, - proxy_u_fn, - model, - ll_ref, jac_all, z, z_ref, hess_all, ll_u, u, - n, m, l) - # not update adapt_state after warmup phase - adapt_state = cond(hmc_state.i < wa_steps, - (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state), - lambda args: wa_update(*args), - hmc_state.adapt_state, - identity) - - itr = hmc_state.i + 1 - n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps) - mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n - - hmcstate = HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps, - accept_prob, mean_accept_prob, diverging, adapt_state, rng_key) - - # Highlight: The accepted proposals samples are in vv_state.z /hmcstate.z, we store them together with the sign - - sign_sum = cond(hmc_state.i < wa_steps, sign_sum, lambda sign_sum: float(0), sign_sum, identity) - z_and_sign = {**vv_state.z, 'sign': sign, "sign_sum": sign_sum} - hmc_sub_state = HMCECSState(u=u, hmc_state=hmc_state, ll_u=ll_u, sign=sign, z_and_sign=z_and_sign) - hmcstate = tuplemerge(hmc_sub_state._asdict(), hmcstate._asdict()) - - return hmcstate - - # Make `init_kernel` and `sample_kernel` visible from the global scope once - # `hmc` is called for sphinx doc generation. - if 'SPHINX_BUILD' in os.environ: - hmc.init_kernel = init_kernel - hmc.sample_kernel = sample_kernel - - return init_kernel, sample_kernel - - -def _log_prob(trace): - """ Compute probability of each observation """ - node = trace['observations'] - return jnp.sum(node['fn'].log_prob(node['value']), 1) - - -class HMCECS(MCMCKernel): - """ - Hamiltonian Monte Carlo inference, using fixed trajectory length, with - provision for step size and mass matrix adaptation. - - **References:** - - 1. *MCMC Using Hamiltonian Dynamics*, - Radford M. Neal - - :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. - If model is provided, `potential_fn` will be inferred using the model. - :param potential_fn: Python callable that computes the potential energy - given input parameters. The input parameters to `potential_fn` can be - any python collection type, provided that `init_params` argument to - `init_kernel` has the same type. - :param kinetic_fn: Python callable that returns the kinetic energy given - inverse mass matrix and momentum. If not provided, the default is - euclidean kinetic energy. - :param float step_size: Determines the size of a single step taken by the - verlet integrator while computing the trajectory using Hamiltonian - dynamics. If not specified, it will be set to 1. - :param bool adapt_step_size: A flag to decide if we want to adapt step_size - during warm-up phase using Dual Averaging scheme. - :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass - matrix during warm-up phase using Welford scheme. - :param bool dense_mass: A flag to decide if mass matrix is dense or - diagonal (default when ``dense_mass=False``) - :param float target_accept_prob: Target acceptance probability for step size - adaptation using Dual Averaging. Increasing this value will lead to a smaller - step size, hence the sampling will be slower but more robust. Default to 0.8. - :param float trajectory_length: Length of a MCMC trajectory for HMC. Default - value is :math:`2\\pi`. - :param callable init_strategy: a per-site initialization function. - See :ref:`init_strategy` section for available functions. - :param bool find_heuristic_step_size: whether to a heuristic function to adjust the - step size at the beginning of each adaptation window. Defaults to False. - :param subsample_method If "perturb" is provided, the "potential_fn" function will be calculated - using the equations from section 7.2.1 in https://jmlr.org/papers/volume18/15-205/15-205.pdf - :param m subsample size - :param g block size - :param z_ref MAP estimate of the parameters - :param covariate_fn Proxy function to calculate the covariates for the likelihood correction - """ - - def __init__(self, - model=None, - potential_fn=None, - grad_potential=None, - kinetic_fn=None, - step_size=1.0, - adapt_step_size=True, - adapt_mass_matrix=True, - dense_mass=False, - target_accept_prob=0.8, - trajectory_length=2 * math.pi, - init_strategy=init_to_uniform, - find_heuristic_step_size=False, - subsample_method=None, - estimator=None, # poisson or not - proxy="taylor", - svi_fn=None, - m=None, - g=None, - z_ref=None, - algo="HMC", - postprocess_fn=None, - ): - if not (model is None) ^ (potential_fn is None): - raise ValueError('Only one of `model` or `potential_fn` must be specified.') - - self._model = model - self._potential_fn = potential_fn - self._grad_potential = grad_potential - self._kinetic_fn = kinetic_fn if kinetic_fn is not None else euclidean_kinetic_energy - self._step_size = step_size - self._adapt_step_size = adapt_step_size - self._adapt_mass_matrix = adapt_mass_matrix - self._dense_mass = dense_mass - self._target_accept_prob = target_accept_prob - self._trajectory_length = trajectory_length - self._algo = algo - self._max_tree_depth = 10 - self._init_strategy = init_strategy - self._find_heuristic_step_size = find_heuristic_step_size - # HMCECS parameters - self.subsample_method = subsample_method - self.m = m if m is not None else 4 - self.g = g if g is not None else 2 - self.z_ref = z_ref - self._n = None - self._ll_ref = None - self._jac_all = None - self._hess_all = None - self._ll_u = None - self._u = None - self._sign = None - self._sign_sum = float(0) - self._l = 100 - # Set on first call to init - self._init_fn = None - self._postprocess_fn = postprocess_fn - self._sample_fn = None - self._subsample_fn = None - self._sign = float(0) - self.proxy = proxy - self.svi_fn = svi_fn - self._proxy_fn = None - self._proxy_u_fn = None - self._signed_estimator_fn = None - self.estimator = estimator - - def _init_subsample_state(self, rng_key, model_args, model_kwargs, init_params, z_ref): - "Compute the jacobian, hessian and log likelihood for all the data. Used with taylor expansion proxy" - rng_key_subsample, rng_key_model, rng_key_hmc_init, rng_key_potential, rng_key = random.split(rng_key, 5) - - ld_fn = lambda args: jnp.sum(partial(log_density_obs_hmcecs, self._model, model_args, model_kwargs)(args)[0]) - self._jac_all, _ = ravel_pytree(jacfwd(ld_fn)(z_ref)) - hess_all, _ = ravel_pytree(hessian(ld_fn)(z_ref)) - k, = self._jac_all.shape - self._hess_all = hess_all.reshape((k, k)) - ld_fn = lambda args: partial(log_density_obs_hmcecs, self._model, model_args, model_kwargs)(args)[0] - self._ll_ref = ld_fn(z_ref) - - def _init_state(self, rng_key, model_args, model_kwargs, init_params): - if self.subsample_method is not None: - assert self.z_ref is not None, "Please provide a (i.e map) estimate for the parameters" - self._n = model_args[0].shape[0] - # Choose the covariate calculation method - if self.proxy == "svi": - self._proxy_fn, self._proxy_u_fn = svi_proxy(self.svi_fn, model_args, model_kwargs) - elif self.proxy == "taylor": - warnings.warn("Using default second order Taylor expansion, change by using the proxy flag to {svi}") - self._init_subsample_state(rng_key, model_args, model_kwargs, init_params, self.z_ref) - self._proxy_fn, self._proxy_u_fn = taylor_proxy(self.z_ref, self._model, self._ll_ref, self._jac_all, - self._hess_all) - if self.estimator == "poisson": - self._l = 50 # lambda subsamples - # _sample_u_poisson_jit = jit(_sample_u_poisson, static_argnums=(1, 2)) - self._u = _sample_u_poisson(rng_key, self.m, self._l) - - self._potential_fn = lambda model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn: lambda z: \ - signed_estimator(model=model, model_args=model_args, - model_kwargs=model_kwargs, z=z, l=l, proxy_fn=proxy_fn, - proxy_u_fn=proxy_u_fn)[0] - # Initialize the hmc sampler: sample_fn = sample_kernel - self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, - kinetic_fn=euclidean_kinetic_energy, - algo=self._algo) - - self._init_strategy = partial(init_near_values, values=self.z_ref) - # Initialize the model parameters - rng_key_init_model, rng_key = random.split(rng_key) - init_model_args = [model_args_sub(u_i, model_args) for u_i in self._u] - - self._init_strategy = partial(init_near_values, values=self.z_ref) - init_params, potential_fn, postprocess_fn, model_trace = initialize_model( - rng_key_init_model, - self._model, - init_strategy=self._init_strategy, - dynamic_args=True, - model_args=tuple([arg[0] for arg in next(chain(init_model_args))]), - # Highlight:Pick the first non-empty block ; 'chain' joins all the elements in the sublist , len(lists_of_lists) = n , len(chain(list_of_lists)) = sum(n_elements_inside_list=*n) - model_kwargs=model_kwargs) - self._postprocess_fn = self._poisson_samples_correction - - - else: - self._u = random.randint(rng_key, (self.m,), 0, self._n) - # Initialize the potential and gradient potential functions - self._potential_fn = lambda model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn: lambda \ - z: potential_est(model=model, - model_args=model_args, model_kwargs=model_kwargs, z=z, n=n, m=m, proxy_fn=proxy_fn, - proxy_u_fn=proxy_u_fn) - - # Initialize the hmc sampler: sample_fn = sample_kernel - self._init_fn, self._sample_fn = hmc(potential_fn_gen=self._potential_fn, - kinetic_fn=euclidean_kinetic_energy, - algo=self._algo) - - self._init_strategy = partial(init_near_values, values=self.z_ref) - # Initialize the model parameters - rng_key_init_model, rng_key = random.split(rng_key) - - init_params, potential_fn, postprocess_fn, model_trace = initialize_model( - rng_key_init_model, - self._model, - init_strategy=self._init_strategy, - dynamic_args=True, - model_args=model_args_sub(self._u, model_args), - model_kwargs=model_kwargs) - - if (self.g > self.m) or (self.g < 1): - raise ValueError( - 'Block size (g) = {} needs to = or > than 1 and smaller than the subsample size {}'.format(self.g, - self.m)) - elif (self.m > self._n): - raise ValueError( - 'Subsample size (m) = {} needs to = or < than data size (n) {}'.format(self.m, self._n)) - - else: - if self._model is not None: - init_params, potential_fn, postprocess_fn, model_trace = initialize_model( - rng_key, - self._model, - dynamic_args=True, - model_args=model_args, - model_kwargs=model_kwargs) - - if any(v['type'] == 'param' for v in model_trace.values()): - warnings.warn("'param' sites will be treated as constants during inference. To define " - "an improper variable, please use a 'sample' site with log probability " - "masked out. For example, `sample('x', dist.LogNormal(0, 1).mask(False)` " - "means that `x` has improper distribution over the positive domain.") - if self._init_fn is None: - self._init_fn, self._sample_fn = hmc(potential_fn_gen=potential_fn, - kinetic_fn=self._kinetic_fn, - algo=self._algo) - self._postprocess_fn = postprocess_fn - elif self._init_fn is None: - self._init_fn, self._sample_fn = hmc(potential_fn=self._potential_fn, - kinetic_fn=self._kinetic_fn, - algo=self._algo) - - return init_params - - @property - def model(self): - return self._model - - @property - def sample_field(self): - if self.estimator == "poisson": - return "z_and_sign" - else: - return "z" - - @property - def default_fields(self): - if self.estimator == "poisson": - return ('z', 'diverging', 'sign', "z_and_sign") - else: - return 'z' - - def get_diagnostics_str(self, state): - return '{} steps of size {:.2e}. acc. prob={:.2f}'.format(state.num_steps, - state.adapt_state.step_size, - state.mean_accept_prob) - - def _block_indices(self, size, num_blocks): - a = jnp.repeat(jnp.arange(num_blocks - 1), size // num_blocks) - b = jnp.repeat(num_blocks - 1, size - len(jnp.repeat(jnp.arange(num_blocks - 1), size // num_blocks))) - return jnp.hstack((a, b)) - - def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}): - """Initialize sampling algorithms""" - # non-vectorized - if rng_key.ndim == 1: - rng_key, rng_key_init_model = random.split(rng_key) - # vectorized - else: - rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1) - - init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, - init_params) # should work for all cases - - if self._potential_fn and init_params is None: - raise ValueError('Valid value of `init_params` must be provided with' - ' `potential_fn`.') - if self.subsample_method == "perturb": - if self.estimator == "poisson": - init_model_args = [model_args_sub(u_i, model_args) for u_i in self._u] - else: - init_model_args = model_args_sub(self._u, model_args) - hmc_init_fn = lambda init_params, rng_key: self._init_fn(init_params=init_params, - num_warmup=num_warmup, - step_size=self._step_size, - adapt_step_size=self._adapt_step_size, - adapt_mass_matrix=self._adapt_mass_matrix, - dense_mass=self._dense_mass, - target_accept_prob=self._target_accept_prob, - trajectory_length=self._trajectory_length, - max_tree_depth=self._max_tree_depth, - find_heuristic_step_size=self._find_heuristic_step_size, - model_args=init_model_args, - model_kwargs=model_kwargs, - subsample_method=self.subsample_method, - estimator=self.estimator, - model=self._model, - ll_ref=self._ll_ref, - jac_all=self._jac_all, - z_ref=self.z_ref, - hess_all=self._hess_all, - ll_u=self._ll_u, - n=self._n, - m=self.m, - u=self._u, - l=self._l, - sign=self._sign, - sign_sum=self._sign_sum, - proxy_fn=self._proxy_fn, - proxy_u_fn=self._proxy_u_fn) - - if rng_key.ndim == 1: - # rng_key_hmc_init = jnp.array([1000966916, 171341646]) - rng_key_hmc_init, _ = random.split(rng_key) - - init_state = hmc_init_fn(init_params, rng_key_hmc_init) # HMCState + HMCECSState - if self.estimator == "poisson": - # signed pseudo-marginal algorithm with the block-Poisson estimator - # use the term signed PM for any pseudo-marginal algorithm that uses the technique in Lyne - # et al. (2015) where a pseudo-marginal sampler is run on the absolute value of the estimated - # posterior and subsequently sign-corrected by importance sampling. Similarly, we call the - # algorithm described in this section signed HMC-ECS - neg_ll, sign = signed_estimator(model=self._model, - model_args=[model_args_sub(u_i, model_args) for u_i in self._u], - model_kwargs=model_kwargs, - z=init_state.z, - l=self._l, - proxy_fn=self._proxy_fn, - proxy_u_fn=self._proxy_u_fn) - - self._sign = jnp.array(sign) # Highlight, do not append the sign here, not necessary - self._ll_u = neg_ll - - else: - self._ll_u = potential_est(model=self._model, - model_args=model_args_sub(self._u, model_args), - model_kwargs=model_kwargs, - z=init_state.z, - n=self._n, - m=self.m, - proxy_fn=self._proxy_fn, - proxy_u_fn=self._proxy_u_fn) - z_and_sign = {**init_state.z, 'sign': self._sign, - "sign_sum": self._sign_sum} # ,"num_warmup":num_warmup} - hmc_init_sub_state = HMCECSState(u=self._u, - hmc_state=init_state.hmc_state, - ll_u=self._ll_u, - sign=self._sign, - z_and_sign=z_and_sign) - init_sub_state = tuplemerge(init_state._asdict(), hmc_init_sub_state._asdict()) - - return init_sub_state - else: # TODO: What is this for? It does not go into it for num_chains>1 - raise ValueError("Not implemented for chains > 1") - # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some - # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, - # wa_steps because those variables do not depend on traced args: init_params, rng_key. - init_state = vmap(hmc_init_fn)(init_params, rng_key) - if self.estimator == "poisson": - neg_ll, sign = signed_estimator(model=self._model, - model_args=[model_args_sub(u_i, model_args) for u_i in self._u], - model_kwargs=model_kwargs_sub, - z=init_state.z, - l=self._l, - proxy_fn=self._proxy_fn, - proxy_u_fn=self._proxy_u_fn) - self._sign = jnp.array(sign) - self._ll_u = neg_ll - - else: - self._ll_u = potential_est(model=self._model, - model_args=model_args_sub(self._u, model_args), - model_kwargs=model_kwargs, - z=init_state.z, - n=self._n, - m=self.m, - proxy_fn=self._proxy_fn, - proxy_u_fn=self._proxy_u_fn) - z_and_sign = {**vv_state.z, 'sign': self._sign, "sign_sum": self._sign_sum} - hmc_init_sub_fn = lambda init_params, rng_key: HMCECSState(u=self._u, hmc_state=init_state, - ll_u=self._ll_u, sign=self._sign, - z_and_sign=z_and_sign) - - init_subsample_state = vmap(hmc_init_sub_fn)(init_params, rng_key) - - sample_fn = vmap(self._sample_fn, in_axes=(0, None, None)) - HMCCombinedState = tuplemerge(init_state._asdict, init_subsample_state._asdict()) - self._sample_fn = sample_fn - return HMCCombinedState - - else: - hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731 - init_params, - num_warmup=num_warmup, - step_size=self._step_size, - adapt_step_size=self._adapt_step_size, - adapt_mass_matrix=self._adapt_mass_matrix, - dense_mass=self._dense_mass, - target_accept_prob=self._target_accept_prob, - trajectory_length=self._trajectory_length, - max_tree_depth=self._max_tree_depth, - find_heuristic_step_size=self._find_heuristic_step_size, - model_args=model_args, - model_kwargs=model_kwargs, - rng_key=rng_key, - ) - if rng_key.ndim == 1: - init_state = hmc_init_fn(init_params, rng_key) - return init_state - else: - # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some - # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth, - # wa_steps because those variables do not depend on traced args: init_params, rng_key. - init_state = vmap(hmc_init_fn)(init_params, rng_key) - sample_fn = vmap(self._sample_fn, in_axes=(0, None, None)) - self._sample_fn = sample_fn - return init_state - - def _poisson_postprocess(self, states): - """Changes the support of the parameters samples by using the sign estimated during the sampling - Ir = Sum [z_j*sign_j] / Sum (sign_j)""" - states_params = {k: states[k] for k in - states.keys() - {'sign', 'sign_sum'}} # change the support for all the parameters sampled - states_params = {key: (states_params[key] * states["sign"]) / states["sign"] for key in states_params.keys()} - return states_params - - def _poisson_samples_correction(self, states, *args, **kwargs): - """Changes the support of the samples by using the sign estimated during the samplinghttps://github.com/pyro-ppl/funsor - Ir = Sum [z_j*sign_j] / Sum (sign_j)""" - return self._poisson_postprocess - - def postprocess_fn(self, args, kwargs): - if self._postprocess_fn is None: - return identity - else: - return self._postprocess_fn(*args, **kwargs) - - def sample(self, state, model_args, model_kwargs): - """ - Run HMC from the given :data:`~numpyro.infer.hmc.HMCState` and return the resulting - :data:`~numpyro.infer.hmc.HMCState`. - - :param HMCState state: Represents the current state. - :param model_args: Arguments provided to the model. - :param model_kwargs: Keyword arguments provided to the model. - :return: Next `state` after running HMC. - """ - - if self.subsample_method == "perturb": - rng_key_subsample, rng_key_transition, rng_key_likelihood, rng_key = random.split( - state.rng_key, 4) - if self.estimator == "poisson": - # _sample_u_poisson_jit = jit(_sample_u_poisson,static_argnums=(0,1,2)) - u_new = _sample_u_poisson(rng_key, self.m, self._l) - neg_ll, sign = signed_estimator(model=self._model, - model_args=[model_args_sub(u_i, model_args) for u_i in u_new], - model_kwargs=model_kwargs, - z=state.z, - l=self._l, - proxy_fn=self._proxy_fn, - proxy_u_fn=self._proxy_u_fn) - self._sign = jnp.array(sign) - state.z_and_sign["sign_sum"] += self._sign # TODO: Probably is a multiplication - # Correct the negativeloglikelihood by substracting the density of the prior to calculate the potential - llu_new = jnp.min(jnp.array([0, -neg_ll + state.ll_u])) - - else: - u_new = _update_block(rng_key_subsample, state.u, self._n, self.m, self.g) - # estimate likelihood of subsample with single block updated - llu_new = potential_est(model=self._model, - model_args=model_args_sub(u_new, model_args), - model_kwargs=model_kwargs, - z=state.z, - n=self._n, - m=self.m, - proxy_fn=self._proxy_fn, - proxy_u_fn=self._proxy_u_fn) - # accept new subsample with probability min(1,L^{hat}_{u_new}(z) - L^{hat}_{u}(z)) - # NOTE: latent variables (z aka theta) same, subsample indices (u) different by one block. - accept_prob = jnp.clip(jnp.exp(-llu_new + state.ll_u), a_max=1.) - transition = random.bernoulli(rng_key_transition, accept_prob) # TODO: Why Bernoulli instead of Uniform? - u, ll_u = cond(transition, - (u_new, llu_new), identity, - (state.u, state.ll_u), identity) - - ######## UPDATE PARAMETERS ########## - z_and_sign = {**state.z, 'sign': self._sign, "sign_sum": self._sign_sum} - hmc_subsamplestate = HMCECSState(u=u, hmc_state=state.hmc_state, ll_u=ll_u, sign=self._sign, - z_and_sign=z_and_sign) - hmc_subsamplestate = tuplemerge(hmc_subsamplestate._asdict(), state._asdict()) - return self._sample_fn(hmc_subsamplestate, - model_args=model_args, - model_kwargs=model_kwargs, - subsample_method=self.subsample_method, - estimator=self.estimator, - proxy_fn=self._proxy_fn, - proxy_u_fn=self._proxy_u_fn, - model=self._model, - ll_ref=self._ll_ref, - jac_all=self._jac_all, - z=state.z, - z_ref=self.z_ref, - hess_all=self._hess_all, - ll_u=ll_u, - u=u, - n=self._n, - m=self.m, - l=self._l, - sign=self._sign, - sign_sum=state.z_and_sign["sign_sum"]) - - else: - return self._sample_fn(state, model_args, model_kwargs) - - -class NUTS(HMCECS): - """ - Hamiltonian Monte Carlo inference, using the No U-Turn Sampler (NUTS) - with adaptive path length and mass matrix adaptation. - - **References:** - - 1. *MCMC Using Hamiltonian Dynamics*, - Radford M. Neal - 2. *The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo*, - Matthew D. Hoffman, and Andrew Gelman. - 3. *A Conceptual Introduction to Hamiltonian Monte Carlo`*, - Michael Betancourt - - :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. - If model is provided, `potential_fn` will be inferred using the model. - :param potential_fn: Python callable that computes the potential energy - given input parameters. The input parameters to `potential_fn` can be - any python collection type, provided that `init_params` argument to - `init_kernel` has the same type. - :param kinetic_fn: Python callable that returns the kinetic energy given - inverse mass matrix and momentum. If not provided, the default is - euclidean kinetic energy. - :param float step_size: Determines the size of a single step taken by the - verlet integrator while computing the trajectory using Hamiltonian - dynamics. If not specified, it will be set to 1. - :param bool adapt_step_size: A flag to decide if we want to adapt step_size - during warm-up phase using Dual Averaging scheme. - :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass - matrix during warm-up phase using Welford scheme. - :param bool dense_mass: A flag to decide if mass matrix is dense or - diagonal (default when ``dense_mass=False``) - :param float target_accept_prob: Target acceptance probability for step size - adaptation using Dual Averaging. Increasing this value will lead to a smaller - step size, hence the sampling will be slower but more robust. Default to 0.8. - :param float trajectory_length: Length of a MCMC trajectory for HMC. This arg has - no effect in NUTS sampler. - :param int max_tree_depth: Max depth of the binary tree created during the doubling - scheme of NUTS sampler. Defaults to 10. - :param callable init_strategy: a per-site initialization function. - See :ref:`init_strategy` section for available functions. - :param bool find_heuristic_step_size: whether to a heuristic function to adjust the - step size at the beginning of each adaptation window. Defaults to False. - """ - - def __init__(self, - model=None, - potential_fn=None, - kinetic_fn=None, - step_size=1.0, - adapt_step_size=True, - adapt_mass_matrix=True, - dense_mass=False, - target_accept_prob=0.8, - trajectory_length=None, - max_tree_depth=10, - init_strategy=init_to_uniform, - find_heuristic_step_size=False): - super(NUTS, self).__init__(potential_fn=potential_fn, model=model, kinetic_fn=kinetic_fn, - step_size=step_size, adapt_step_size=adapt_step_size, - adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass, - target_accept_prob=target_accept_prob, - trajectory_length=trajectory_length, - init_strategy=init_strategy, - find_heuristic_step_size=find_heuristic_step_size) - self._max_tree_depth = max_tree_depth - self._algo = 'NUTS' diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/hmcecs_utils.py index 6f1cf8ff1..15b5ba0ad 100644 --- a/numpyro/contrib/hmcecs_utils.py +++ b/numpyro/contrib/hmcecs_utils.py @@ -1,179 +1,14 @@ from collections import namedtuple -from functools import partial import jax import jax.numpy as jnp -from jax import grad, value_and_grad -from jax.tree_util import tree_multimap -import numpyro.distributions as dist -from numpyro.distributions.util import is_identically_one -from numpyro.handlers import substitute, trace -from numpyro.util import ravel_pytree +from numpyro.primitives import Messenger, _subsample_fn IntegratorState = namedtuple('IntegratorState', ['z', 'r', 'potential_energy', 'z_grad']) IntegratorState.__new__.__defaults__ = (None,) * len(IntegratorState._fields) -def model_args_sub(u, model_args): - """Subsample observations and features according to u subsample indexes""" - args = [] - for arg in model_args: - if isinstance(arg, jnp.ndarray) and arg.shape[0] > len(u): - args.append(jnp.take(arg, u, axis=0)) - else: - args.append(arg) - return tuple(args) - - -def model_kwargs_sub(u, kwargs): - """Subsample observations and features""" - for key_arg, val_arg in kwargs.items(): - if key_arg == "observations" or key_arg == "features": - kwargs[key_arg] = jnp.take(val_arg, u, axis=0) - return kwargs - - -def log_density_obs_hmcecs(model, model_args, model_kwargs, params): - model = substitute(model, data=params) - model_trace = trace(model).get_trace(*model_args, **model_kwargs) - # model = substitute(model, data=params) - # with plate_to_enum_plate(): - # model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs) - log_joint = jnp.array(0.) - for site in model_trace.values(): - if site['type'] == 'sample' and site['is_observed'] and not isinstance(site['fn'], dist.PRNGIdentity): - value = site['value'] - intermediates = site['intermediates'] - scale = site['scale'] - if intermediates: - log_prob = site['fn'].log_prob(value, intermediates) - else: - log_prob = site['fn'].log_prob(value) - if (scale is not None) and (not is_identically_one(scale)): - log_prob = scale * log_prob - # log_joint += log_prob #TODO: log_joint += jnp.sum(log_prob) ?---> gives a single number - log_joint = log_joint + jnp.sum(log_prob) - - return log_joint, model_trace - - -def log_density_prior_hmcecs(model, model_args, model_kwargs, params): - """ - (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given - latent values ``params``. - - :param model: Python callable containing NumPyro primitives. - :param tuple model_args: args provided to the model. - :param dict model_kwargs: kwargs provided to the model. - :param dict params: dictionary of current parameter values keyed by site - name. - :return: log of joint density and a corresponding model trace - """ - model = substitute(model, data=params) - model_trace = trace(model).get_trace(*model_args, **model_kwargs) - # model = substitute(model, data=params) - # with plate_to_enum_plate(): - # model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs) - log_joint = jnp.array(0.) - for site in model_trace.values(): - if site['type'] == 'sample' and not isinstance(site['fn'], dist.PRNGIdentity) and not site['is_observed']: - value = site['value'] - intermediates = site['intermediates'] - scale = site['scale'] - if intermediates: - log_prob = site['fn'].log_prob(value, intermediates) - else: - log_prob = site['fn'].log_prob(value) - - if (scale is not None) and (not is_identically_one(scale)): - log_prob = scale * log_prob - - log_prob = jnp.sum(log_prob) - log_joint = log_joint + log_prob - return log_joint, model_trace - - -def reducer(accum, d): - accum.update(d) - return accum - - -def tuplemerge(*dictionaries): - from functools import reduce - - merged = reduce(reducer, dictionaries, {}) - - return namedtuple('HMCCombinedState', merged)(**merged) # <==== Gist of the gist - - -def potential_est(model, model_args, model_kwargs, z, n, m, proxy_fn, proxy_u_fn): - """Computes the estimation of the likelihood of the potential - :param: proxy_U_fn : Function to calculate the covariates that correct the subsample likelihood""" - ll_sub, _ = log_density_obs_hmcecs(model, model_args, {}, z) # log likelihood for subsample with current theta - - diff = ll_sub - proxy_u_fn(z=z, model_args=model_args, model_kwargs=model_kwargs) - - l_hat = proxy_fn(z) + n / m * diff - - sigma = n ** 2 / m * jnp.var(diff) - - ll_prior, _ = log_density_prior_hmcecs(model, model_args, model_kwargs, z) - - return (-l_hat + .5 * sigma) - ll_prior - - -def velocity_verlet_hmcecs(potential_fn, kinetic_fn, grad_potential_fn=None): - r""" - Second order symplectic integrator that uses the velocity verlet algorithm - for position `z` and momentum `r`. - - :param potential_fn: Python callable that computes the potential energy - given input parameters. The input parameters to `potential_fn` can be - any python collection type. If HMCECS is used the gradient of the potential - energy funtion is calculated - :param kinetic_fn: Python callable that returns the kinetic energy given - inverse mass matrix and momentum. - :return: a pair of (`init_fn`, `update_fn`). - """ - - compute_value_grad = value_and_grad(potential_fn) if grad_potential_fn is None \ - else lambda z: (potential_fn(z), grad_potential_fn(z)) - - def init_fn(z, r, potential_energy=None, z_grad=None): - """ - :param z: Position of the particle. - :param r: Momentum of the particle. - :param potential_energy: Potential energy at `z`. - :param z_grad: gradient of potential energy at `z`. - :return: initial state for the integrator. - """ - if potential_energy is None or z_grad is None: - potential_energy, z_grad = compute_value_grad(z) - - return IntegratorState(z, r, potential_energy, z_grad) - - def update_fn(step_size, inverse_mass_matrix, state): - """ - :param float step_size: Size of a single step. - :param inverse_mass_matrix: Inverse of mass matrix, which is used to - calculate kinetic energy. - :param state: Current state of the integrator. - :return: new state for the integrator. - """ - z, r, _, z_grad = state - - r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad) # r(n+1/2) - r_grad = grad(kinetic_fn, argnums=1)(inverse_mass_matrix, r) - z = tree_multimap(lambda z, r_grad: z + step_size * r_grad, z, r_grad) # z(n+1) - potential_energy, z_grad = compute_value_grad(z) - r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad) # r(n+1) - # return IntegratorState(z, r, potential_energy, z_grad) - return IntegratorState(z, r, potential_energy, z_grad) - - return init_fn, update_fn - - def init_near_values(site=None, values={}): """Initialize the sampling to a noisy map estimate of the parameters""" from functools import partial @@ -194,101 +29,96 @@ def init_near_values(site=None, values={}): return init_to_uniform(site) -def taylor_proxy(z_ref, model, ll_ref, jac_all, hess_all): - """Corrects the subsample likelihood using covariates the taylor expansion - :param z_ref = reference estimate (e.g MAP) of the model's parameters - :param model = model likelihood - :param ll_ref = reference loglikelihood - :param jac_all= Jacobian vector of the entire dataset - :param hess_all = Hessian matrix of the entire dataset""" - - def proxy(z, *args, **kwargs): - z_flat, _ = ravel_pytree(z) - zref_flat, _ = ravel_pytree(z_ref) - z_diff = z_flat - zref_flat - return jnp.sum(ll_ref) + jac_all.T @ z_diff + .5 * z_diff.T @ hess_all @ z_diff - - def proxy_u(z, model_args, model_kwargs, *args, **kwargs): - z_flat, _ = ravel_pytree(z) - zref_flat, _ = ravel_pytree(z_ref) - z_diff = z_flat - zref_flat - - ld_fn = lambda args: jnp.sum(partial(log_density_obs_hmcecs, model, model_args, model_kwargs)(args)[0]) - - ll_sub, jac_sub = jax.value_and_grad(ld_fn)(z_ref) - k, = jac_all.shape - hess_sub, _ = ravel_pytree(jax.hessian(ld_fn)(z_ref)) - jac_sub, _ = ravel_pytree(jac_sub) - - return ll_sub + jac_sub @ z_diff + .5 * z_diff @ hess_sub.reshape((k, k)) @ z_diff.T - - return proxy, proxy_u - - -def svi_proxy(svi, model_args, model_kwargs): - def proxy(z, *args, **kwargs): - z_ref = svi.guide.expectation(z) - ll, _ = log_density_obs_hmcecs(svi.model, model_args, model_kwargs, z_ref) - return ll - - def proxy_u(z, model_args, model_kwargs, *args, **kwargs): - z_ref = svi.guide.expectation(z) - ll, _ = log_density_prior_hmcecs(svi.model, model_args, model_kwargs, z_ref) - return ll - - return proxy, proxy_u - - -def neural_proxy(): - return None - - -def split_list(lst, n): - """Pair up the split model arguments back.""" - for i in range(0, len(lst), n): - if i + n < len(lst): - yield tuple(map(lst.__getitem__, [i, i + n])) - else: - break - +# TODO: SVI PROXY (previous code) +# def svi_proxy(svi, model_args, model_kwargs): +# def proxy(z, *args, **kwargs): +# z_ref = svi.guide.expectation(z) +# ll, _ = log_density_obs_hmcecs(svi.model, model_args, model_kwargs, z_ref) +# return ll +# +# def proxy_u(z, model_args, model_kwargs, *args, **kwargs): +# z_ref = svi.guide.expectation(z) +# ll, _ = log_density_prior_hmcecs(svi.model, model_args, model_kwargs, z_ref) +# return ll +# +# return proxy, proxy_u + + +def _extract_params(distribution): + params, _ = distribution.tree_flatten() + return params + + +class estimator(Messenger): + def __init__(self, fn, estimators, plate_sizes): + self.estimators = estimators + self.plate_sizes = plate_sizes + super(estimator, self).__init__(fn) + + def process_message(self, msg): + if msg['type'] == 'sample' and msg['is_observed'] and msg['cond_indep_stack']: + log_prob = msg['fn'].log_prob + msg['scale'] = 1. + msg['fn'].log_prob = lambda *args, **kwargs: \ + self.estimators[msg['name']](*args, name=msg['name'], z=_extract_params(msg['fn']), log_prob=log_prob, + sizes=self.plate_sizes[msg['cond_indep_stack'][0].name], + **kwargs) # TODO: check multiple levels + + +def taylor_proxy(ref_trace, ll_ref, jac_all, hess_all): + def proxy(name, z): + z_ref = _extract_params(ref_trace[name]['fn']) + jac, hess = jac_all[name], hess_all[name] + log_like = jnp.array(0.) + for argnum in range(len(z_ref)): + z_diff = z[argnum] - z_ref[argnum] + j, h = jac[argnum], hess[argnum] + k, = j.shape + log_like += j.T @ z_diff + .5 * z_diff.T @ h.reshape(k, k) @ z_diff + return ll_ref[name].sum() + log_like + + def uproxy(name, value, z): + ref_dist = ref_trace[name]['fn'] + z_ref, aux_data = ref_dist.tree_flatten() + + log_prob = lambda *params: ref_dist.tree_unflatten(aux_data, params).log_prob(value).sum() + log_like = jnp.array(0.) + for argnum in range(len(z_ref)): + z_diff = z[argnum] - z_ref[argnum] + jac = jax.jacobian(log_prob, argnum)(*z_ref) + k, = jac.shape + hess = jax.hessian(log_prob, argnum)(*z_ref) + log_like += jac @ z_diff + .5 * z_diff @ hess.reshape(k, k) @ z_diff.T + + return log_prob(*z_ref).sum() + log_like + + return proxy, uproxy + + +class subsample_size(Messenger): + def __init__(self, fn, plate_sizes, rng_key=None): + super(subsample_size, self).__init__(fn) + self.plate_sizes = plate_sizes + self.rng_key = rng_key + + def process_message(self, msg): + if msg['type'] == 'plate' and msg['args'] and msg["args"][0] > msg["args"][1]: + if msg['name'] in self.plate_sizes: + msg['args'] = self.plate_sizes[msg['name']] + msg['value'] = _subsample_fn(*msg['args'], self.rng_key) if msg["args"][1] < msg["args"][ + 0] else jnp.arange(msg["args"][0]) + + +def difference_estimator_fn(value, name, z, sizes, log_prob, proxy_fn, uproxy_fn, *args, **kwargs, ): + n, m, g = sizes + ll_sub = log_prob(value).sum() + diff = ll_sub - uproxy_fn(name, value, z) + l_hat = proxy_fn(name, z) + n / m * diff + sigma = n ** 2 / m * jnp.var(diff) + return l_hat - .5 * sigma -def signed_estimator(model, model_args, model_kwargs, z, l, proxy_fn, proxy_u_fn): - """ - Estimate the gradient potential estimate - :param model: Likelihood function - :param model_args: Subsample of model arguments [l,m,n_feats] - :param model_kwargs: - :param z: Model parameters estimates - :param l: Lambda number of subsamples (u indexes) - :param proxy_fn: - :param proxy_u_fn: - :return: - neg_ll: Negative likelihood estimate of the potential - sign: Sign of the likelihood estimate over the subsamples, it will be used after all the samples are collected - """ - import itertools - xis = 0. - sign = 1. - d = 0 - a = d - l # For a fixed λ, V[LbB] is minimized at a = d − λ. Quiroz 2018c - model_args = [args_l for args_l in model_args if len(args_l[0]) != 0] # remove empty lambda blocks - for args_l in model_args: # Iterate over each of the lambda groups of model args - block_len = args_l[0].shape[0] - args_l = [jnp.split(arg, arg.shape[0]) for arg in args_l] # split the arrays of blocks - args_l = list(itertools.chain.from_iterable(args_l)) # Join list of lists - args_l = [arr.squeeze(axis=0) for arr in args_l] - args_l = list(split_list(args_l, block_len)) - for args_l_b in args_l: - ll_sub, _ = log_density_obs_hmcecs(model, args_l_b, {}, z) # log likelihood for each u subsample - xi = (jnp.exp(ll_sub - proxy_u_fn(z=z, model_args=args_l_b, model_kwargs=model_kwargs)) - a) / l - sign *= jnp.prod(jnp.sign(xi)) - xis += jnp.sum(jnp.abs(xi)) - lhat = proxy_fn(z) + (a + l) / l + xis - prior_arg = tuple([arg.reshape(arg.shape[0] * arg.shape[1], -1) for arg in model_args[ - 0]]) # Join the block subsamples, does not matter because the prior does not look t them - ll_prior, _ = log_density_prior_hmcecs(model, prior_arg, model_kwargs, - z) # the ll of the prior does not depend on the model args, so we just take some pair - # Correct the negativeloglikelihood by substracting the density of the prior --> potentialEst = -loglikeEst - dprior(theta,pfamily,priorPar1,priorPar2) - neg_ll = - lhat - ll_prior - return neg_ll, sign +def _tangent_curve(dist, value, tangent_fn): + z, aux_data = dist.tree_flatten() + log_prob = lambda *params: dist.tree_unflatten(aux_data, params).log_prob(value).sum() + return tuple(tangent_fn(log_prob, argnum)(*z) for argnum in range(len(z))) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 718564395..fec1f8157 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -133,6 +133,7 @@ class trace(Messenger): 'type': 'sample', 'value': DeviceArray(-0.20584235, dtype=float32)})]) """ + def __enter__(self): super(trace, self).__enter__() self.trace = OrderedDict() @@ -143,7 +144,7 @@ def postprocess_message(self, msg): # skip recording helper messages e.g. `control_flow`, `to_data`, `to_funsor` # which has no name return - assert not(msg['type'] == 'sample' and msg['name'] in self.trace), \ + assert not (msg['type'] == 'sample' and msg['name'] in self.trace), \ 'all sites must have unique names but got `{}` duplicated'.format(msg['name']) self.trace[msg['name']] = msg.copy() @@ -188,6 +189,7 @@ class replay(Messenger): -0.20584235 >>> assert replayed_trace['a']['value'] == exec_trace['a']['value'] """ + def __init__(self, fn=None, guide_trace=None): assert guide_trace is not None self.guide_trace = guide_trace @@ -231,6 +233,7 @@ class block(Messenger): >>> assert 'a' not in trace_block_a >>> assert 'b' in trace_block_a """ + def __init__(self, fn=None, hide_fn=None, hide=None): if hide_fn is not None: self.hide_fn = hide_fn @@ -342,6 +345,7 @@ class condition(Messenger): >>> assert exec_trace['a']['value'] == -1 >>> assert exec_trace['a']['is_observed'] """ + def __init__(self, fn=None, data=None, condition_fn=None): self.condition_fn = condition_fn self.data = data @@ -443,6 +447,7 @@ class mask(Messenger): :param mask: a boolean or a boolean-valued array for masking elementwise log probability of sample sites (`True` includes a site, `False` excludes a site). """ + def __init__(self, fn=None, mask=True): if lax.dtype(mask) != 'bool': raise ValueError("`mask` should be a bool array.") @@ -479,6 +484,7 @@ class reparam(Messenger): :class:`~numpyro.infer.reparam.Reparam` or None. :type config: dict or callable """ + def __init__(self, fn=None, config=None): assert isinstance(config, dict) or callable(config) self.config = config @@ -521,6 +527,7 @@ class scale(Messenger): :param float scale: a positive scaling factor """ + def __init__(self, fn=None, scale=1.): if not_jax_tracer(scale): if scale <= 0: @@ -557,6 +564,7 @@ class scope(Messenger): :param fn: Python callable with NumPyro primitives. :param str prefix: a string to prepend to sample names """ + def __init__(self, fn=None, prefix=''): self.prefix = prefix super().__init__(fn) @@ -607,6 +615,7 @@ class seed(Messenger): >>> y = handlers.seed(model, rng_seed=1)() >>> assert x == y """ + def __init__(self, fn=None, rng_seed=None): if isinstance(rng_seed, int) or (isinstance(rng_seed, jnp.ndarray) and not jnp.shape(rng_seed)): rng_seed = random.PRNGKey(rng_seed) @@ -617,7 +626,7 @@ def __init__(self, fn=None, rng_seed=None): def process_message(self, msg): if (msg['type'] == 'sample' and not msg['is_observed'] and - msg['kwargs']['rng_key'] is None) or msg['type'] in ['prng_key', 'plate', 'control_flow']: + msg['kwargs']['rng_key'] is None) or msg['type'] in ['prng_key', 'plate', 'control_flow']: # no need to create a new key when value is available if msg['value'] is not None: return @@ -660,6 +669,7 @@ class substitute(Messenger): >>> exec_trace = trace(substitute(model, {'a': -1})).get_trace() >>> assert exec_trace['a']['value'] == -1 """ + def __init__(self, fn=None, data=None, substitute_fn=None): self.substitute_fn = substitute_fn self.data = data @@ -729,6 +739,7 @@ class do(Messenger): >>> assert not exec_trace['z'].get('stop', None) >>> assert z_square == 1 """ + def __init__(self, fn=None, data=None): self.data = data self._intervener_id = str(id(self)) diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 742f75bfc..3fe5d73e7 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -1,15 +1,14 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import warnings from collections import namedtuple from functools import partial -import warnings +import jax.numpy as jnp import numpy as np - from jax import device_get, lax, random, value_and_grad from jax.flatten_util import ravel_pytree -import jax.numpy as jnp import numpyro from numpyro.distributions.constraints import _GreaterThan, _Interval, real, real_vector @@ -458,7 +457,7 @@ def initialize_model(rng_key, model, for w in ws: # at site information to the warning message w.message.args = ("Site {}: {}".format(site["name"], w.message.args[0]),) \ - + w.message.args[1:] + + w.message.args[1:] warnings.showwarning(w.message, w.category, w.filename, w.lineno, file=w.file, line=w.line) raise RuntimeError("Cannot find valid initial parameters. Please check your model again.") @@ -467,7 +466,6 @@ def initialize_model(rng_key, model, def _predictive(rng_key, model, posterior_samples, batch_shape, return_sites=None, parallel=True, model_args=(), model_kwargs={}): - def single_prediction(val): rng_key, samples = val model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace( From dafaa6efae532b39205b5ce1072179c319799797 Mon Sep 17 00:00:00 2001 From: Ola Date: Fri, 8 Jan 2021 14:53:19 +0100 Subject: [PATCH 39/93] renamed hmcecs_utils to ecs_utils and added todos. --- numpyro/contrib/ecs.py | 12 ++++++++++-- numpyro/contrib/{hmcecs_utils.py => ecs_utils.py} | 0 2 files changed, 10 insertions(+), 2 deletions(-) rename numpyro/contrib/{hmcecs_utils.py => ecs_utils.py} (100%) diff --git a/numpyro/contrib/ecs.py b/numpyro/contrib/ecs.py index 0393094e0..e68b6ffa6 100644 --- a/numpyro/contrib/ecs.py +++ b/numpyro/contrib/ecs.py @@ -3,11 +3,11 @@ from collections import namedtuple import jax.numpy as jnp -from jax import device_put, lax, random, partial, jit, jacobian, hessian, make_jaxpr +from jax import device_put, lax, random, partial, jit, jacobian, hessian import numpyro import numpyro.distributions as dist -from numpyro.contrib.hmcecs_utils import ( +from numpyro.contrib.ecs_utils import ( init_near_values, difference_estimator_fn, taylor_proxy, @@ -68,6 +68,11 @@ def _update_block(rng_key, u, n, m, g): class ECS(MCMCKernel): + """ Energy conserving subsampling as first described in [1]. + + ** Reference: ** + 1. *Hamiltonian Monte Carlo with Energy ConservingSubsampling* by Dang, Khue-Dang et al. + """ sample_field = "uz" def __init__(self, inner_kernel, estimator_fn=None, proxy_gen_fn=None, z_ref=None): @@ -97,9 +102,12 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): u = {name: site["value"] for name, site in prototype_trace.items() if site["type"] == "plate" and site["args"][0] > site["args"][1]} + # TODO: estimate good block size self._plate_sizes = {name: prototype_trace[name]["args"] + (min(prototype_trace[name]["args"][1] // 2, 100),) for name in u} + # Precompute Jaccobian and Hessian for Taylor Proxy + # TODO: check proxy type and branch plate_sizes_all = {name: (prototype_trace[name]["args"][0], prototype_trace[name]["args"][0]) for name in u} with subsample_size(model, plate_sizes_all): ref_trace = trace(substitute(model, data=z_ref)).get_trace(*model_args, **model_kwargs) diff --git a/numpyro/contrib/hmcecs_utils.py b/numpyro/contrib/ecs_utils.py similarity index 100% rename from numpyro/contrib/hmcecs_utils.py rename to numpyro/contrib/ecs_utils.py From 243e7bc0016503ddf3540347dd74910aad6b8ebc Mon Sep 17 00:00:00 2001 From: Ola Date: Tue, 12 Jan 2021 09:55:55 +0100 Subject: [PATCH 40/93] debugging taylor expansion. --- examples/hmcecs/higgs_new.py | 48 ++++++++++++++++++++++++++++++ numpyro/contrib/ecs.py | 57 ++++++++---------------------------- 2 files changed, 61 insertions(+), 44 deletions(-) create mode 100644 examples/hmcecs/higgs_new.py diff --git a/examples/hmcecs/higgs_new.py b/examples/hmcecs/higgs_new.py new file mode 100644 index 000000000..d42504f5a --- /dev/null +++ b/examples/hmcecs/higgs_new.py @@ -0,0 +1,48 @@ +import jax.numpy as jnp +from jax import random +from sklearn.datasets import load_breast_cancer + +import numpyro +import numpyro.distributions as dist +from numpyro.contrib.ecs import ECS +from numpyro.contrib.ecs_utils import difference_estimator_fn, taylor_proxy +from numpyro.infer import MCMC, NUTS + + +def breast_cancer_data(): + dataset = load_breast_cancer() + feats = dataset.data + feats = (feats - feats.mean(0)) / feats.std(0) + feats = jnp.hstack((feats, jnp.ones((feats.shape[0], 1)))) + return feats, dataset.target + + +def log_reg_model(features, obs): + n, m = features.shape + theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), .5 * jnp.ones(m))) + with numpyro.plate('N', n, subsample_size=75) as idx: + batch_feats = numpyro.subsample(features, event_dim=1) + batch_obs = numpyro.subsample(obs, event_dim=1) + numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(batch_feats, theta)), obs=batch_obs) + + +def plain_log_reg_model(features, obs): + n, m = features.shape + theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), 2 * jnp.ones(m))) + numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(features, theta)), obs=obs) + + +if __name__ == '__main__': + data, obs = breast_cancer_data() + + # Get reference parameters + kernel = NUTS(plain_log_reg_model) + mcmc = MCMC(kernel, 500, 500) + mcmc.run(random.PRNGKey(1), data, obs) + z_ref = {k: v.mean(0) for k, v in mcmc.get_samples().items()} + + # Compute HMCECS + kernel = ECS(NUTS(log_reg_model), estimator_fn=difference_estimator_fn, proxy_gen_fn=taylor_proxy, z_ref=z_ref) + mcmc = MCMC(kernel, 500, 500) + mcmc.run(random.PRNGKey(0), data, obs, extra_fields=("accept_prob",)) + mcmc.print_summary(exclude_deterministic=False) diff --git a/numpyro/contrib/ecs.py b/numpyro/contrib/ecs.py index e68b6ffa6..5d21722d0 100644 --- a/numpyro/contrib/ecs.py +++ b/numpyro/contrib/ecs.py @@ -5,25 +5,22 @@ import jax.numpy as jnp from jax import device_put, lax, random, partial, jit, jacobian, hessian -import numpyro -import numpyro.distributions as dist from numpyro.contrib.ecs_utils import ( init_near_values, - difference_estimator_fn, - taylor_proxy, estimator, subsample_size, _tangent_curve ) from numpyro.handlers import substitute, trace, seed -from numpyro.infer import MCMC, NUTS, log_likelihood +from numpyro.infer import log_likelihood from numpyro.infer.mcmc import MCMCKernel from numpyro.util import identity HMC_ECS_State = namedtuple("HMC_ECS_State", "uz, hmc_state, accept_prob, rng_key") """ - **uz** - a dict of current subsample indices and the current latent values - - **hmc_state** - current hmc_state + - **hmc_state** - current hmc_stat log_like += j.T @ z_diff + .5 * z_diff.T @ h.reshape(k, k) @ z_diff +e - **accept_prob** - acceptance probability of the proposal subsample indices - **rng_key** - random key to generate new subsample indices """ @@ -109,29 +106,28 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): # Precompute Jaccobian and Hessian for Taylor Proxy # TODO: check proxy type and branch plate_sizes_all = {name: (prototype_trace[name]["args"][0], prototype_trace[name]["args"][0]) for name in u} - with subsample_size(model, plate_sizes_all): - ref_trace = trace(substitute(model, data=z_ref)).get_trace(*model_args, **model_kwargs) + with subsample_size(self.model, plate_sizes_all): + ref_trace = trace(substitute(self.model, data=self._z_ref)).get_trace(*model_args, **model_kwargs) jac_all = {name: _tangent_curve(site['fn'], site['value'], jacobian) for name, site in ref_trace.items() - if - (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} + if (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} hess_all = {name: _tangent_curve(site['fn'], site['value'], hessian) for name, site in ref_trace.items() - if - (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} + if (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} ll_ref = {name: site['fn'].log_prob(site['value']) for name, site in ref_trace.items() if (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - ref_trace = trace(substitute(model, data={**z_ref, **u})).get_trace(*model_args, - **model_kwargs) # TODO: check reparam + ref_trace = trace(substitute(self.model, data={**self._z_ref, **u})).get_trace(*model_args, + **model_kwargs) # TODO: check reparam proxy_fn, uproxy_fn = self._proxy_gen_fn(ref_trace, ll_ref, jac_all, hess_all) + print(jac_all) estimators = {name: partial(self._estimator_fn, proxy_fn=proxy_fn, uproxy_fn=uproxy_fn) for name, site in prototype_trace.items() if (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - self.inner_kernel._model = _wrap_est_model(model, estimators, self._plate_sizes) + self.inner_kernel._model = _wrap_est_model(self.model, estimators, self._plate_sizes) init_params = {name: init_near_values(site, self._z_ref) for name, site in prototype_trace.items()} model_kwargs["_subsample_sites"] = u - hmc_state = self.inner_kernel.init(key_z, num_warmup, init_params, - model_args, model_kwargs) + + hmc_state = self.inner_kernel.init(key_z, num_warmup, init_params, model_args, model_kwargs) uz = {**u, **hmc_state.z} return device_put(HMC_ECS_State(uz, hmc_state, 1., rng_key)) @@ -156,30 +152,3 @@ def sample(self, state, model_args, model_kwargs): hmc_state = self.inner_kernel.sample(state.hmc_state, model_args, model_kwargs) uz = {**u, **hmc_state.z} return HMC_ECS_State(uz, hmc_state, accept_prob, rng_key) - - -def model(data, *args, **kwargs): - x = numpyro.sample("x", dist.Normal(0., 1.)) - with numpyro.plate("N", data.shape[0], subsample_size=1000): - batch = numpyro.subsample(data, event_dim=0) - numpyro.sample("obs", dist.Normal(x, 1.), obs=batch) - - -def plain_model(data, *args, **kwargs): - x = numpyro.sample("x", dist.Normal(0., 1.)) - numpyro.sample("obs", dist.Normal(x, 1.), obs=data) - - -if __name__ == '__main__': - data = random.normal(random.PRNGKey(1), (10_000,)) + 1 - # Get reference parameters - kernel = NUTS(plain_model) - mcmc = MCMC(kernel, 500, 500) - mcmc.run(random.PRNGKey(1), data) - mcmc.print_summary(exclude_deterministic=False) - z_ref = {k: v.mean() for k, v in mcmc.get_samples().items()} - # Compute HMCECS - kernel = ECS(NUTS(model), estimator_fn=difference_estimator_fn, proxy_gen_fn=taylor_proxy, z_ref=z_ref) - mcmc = MCMC(kernel, 1500, 1500) - mcmc.run(random.PRNGKey(0), data, extra_fields=("accept_prob",)) - mcmc.print_summary(exclude_deterministic=False) From c4252bb087e67b7a89cedfb60c19ddc4f771b725 Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 13 Jan 2021 15:01:26 +0100 Subject: [PATCH 41/93] Updated comments with reference and added test for num_blocks={} (there was a bug). --- numpyro/infer/hmc.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py index 965509e16..a89f535d9 100644 --- a/numpyro/infer/hmc.py +++ b/numpyro/infer/hmc.py @@ -31,6 +31,8 @@ - **potential_energy** - Potential energy computed at the given value of ``z``. - **energy** - Sum of potential energy and kinetic energy of the current state. - **num_steps** - Number of steps in the Hamiltonian trajectory (for diagnostics). + In NUTS sampler, the tree depth of a trajectory can be computed from this field + with `tree_depth = np.log2(num_steps).astype(int) + 1`. - **accept_prob** - Acceptance probability of the proposal. Note that ``z`` does not correspond to the proposal if it is rejected. - **mean_accept_prob** - Mean acceptance probability until current iteration @@ -121,12 +123,10 @@ def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'): >>> true_coefs = jnp.array([1., 2., 3.]) >>> data = random.normal(random.PRNGKey(2), (2000, 3)) - >>> dim = 3 >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3)) >>> >>> def model(data, labels): - ... coefs_mean = jnp.zeros(dim) - ... coefs = numpyro.sample('beta', dist.Normal(coefs_mean, jnp.ones(3))) + ... coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(3), jnp.ones(3))) ... intercept = numpyro.sample('intercept', dist.Normal(0., 10.)) ... return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels) >>> @@ -137,7 +137,7 @@ def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'): ... num_warmup=300) >>> samples = fori_collect(0, 500, sample_kernel, hmc_state, ... transform=lambda state: model_info.postprocess_fn(state.z)) - >>> print(jnp.mean(samples['beta'], axis=0)) # doctest: +SKIP + >>> print(jnp.mean(samples['coefs'], axis=0)) # doctest: +SKIP [0.9153987 2.0754058 2.9621222] """ if kinetic_fn is None: @@ -399,6 +399,7 @@ def __init__(self, self._find_heuristic_step_size = find_heuristic_step_size # Set on first call to init self._init_fn = None + self._potential_fn_gen = None self._postprocess_fn = None self._sample_fn = None @@ -415,6 +416,7 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params): self._init_fn, self._sample_fn = hmc(potential_fn_gen=potential_fn, kinetic_fn=self._kinetic_fn, algo=self._algo) + self._potential_fn_gen = potential_fn self._postprocess_fn = postprocess_fn elif self._init_fn is None: self._init_fn, self._sample_fn = hmc(potential_fn=self._potential_fn, From 20b7350546be2f6b18cb680801a2c7ca56c2142a Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 13 Jan 2021 15:50:39 +0100 Subject: [PATCH 42/93] Added pystan --- examples/hmcecs/higgs_new.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/examples/hmcecs/higgs_new.py b/examples/hmcecs/higgs_new.py index d42504f5a..8f8adfc30 100644 --- a/examples/hmcecs/higgs_new.py +++ b/examples/hmcecs/higgs_new.py @@ -1,4 +1,5 @@ import jax.numpy as jnp +import pystan from jax import random from sklearn.datasets import load_breast_cancer @@ -32,10 +33,7 @@ def plain_log_reg_model(features, obs): numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(features, theta)), obs=obs) -if __name__ == '__main__': - data, obs = breast_cancer_data() - - # Get reference parameters +def hmcecs_model(data, obs): kernel = NUTS(plain_log_reg_model) mcmc = MCMC(kernel, 500, 500) mcmc.run(random.PRNGKey(1), data, obs) @@ -46,3 +44,27 @@ def plain_log_reg_model(features, obs): mcmc = MCMC(kernel, 500, 500) mcmc.run(random.PRNGKey(0), data, obs, extra_fields=("accept_prob",)) mcmc.print_summary(exclude_deterministic=False) + + +# Stan + +def stan_model(): + model_code = """ + data { + int D; + int N; + matrix[N, D] x; + int y[N]; + } + parameters { + vector[D] beta; + } + model { + y ~ bernoulli_logit(x * beta); + } + """ + return pystan.StanModel(model_code=model_code) + + +if __name__ == '__main__': + data, obs = breast_cancer_data() From 4d7e4ed4012033e8b7a306cebc82b91dd71ab0d6 Mon Sep 17 00:00:00 2001 From: Ola Date: Fri, 15 Jan 2021 21:52:15 +0100 Subject: [PATCH 43/93] Added components for variational proxy. --- numpyro/contrib/ecs_utils.py | 48 ++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/numpyro/contrib/ecs_utils.py b/numpyro/contrib/ecs_utils.py index 15b5ba0ad..1f2fba0a9 100644 --- a/numpyro/contrib/ecs_utils.py +++ b/numpyro/contrib/ecs_utils.py @@ -122,3 +122,51 @@ def _tangent_curve(dist, value, tangent_fn): z, aux_data = dist.tree_flatten() log_prob = lambda *params: dist.tree_unflatten(aux_data, params).log_prob(value).sum() return tuple(tangent_fn(log_prob, argnum)(*z) for argnum in range(len(z))) + + +import numpyro +import numpyro.distributions as dist +from numpyro.distributions import constraints + + +def model(data, obs): + theta = numpyro.sample("x", dist.Normal(0., 1.)) + with numpyro.plate("N", data.shape[0], subsample_size=5) as idx: + numpyro.sample("obs", dist.Bernoulli(logits=data[idx] * theta), obs=obs[idx]) + + +def guide(data, obs): + mu = numpyro.param('mu', 0., constraints=constraints.real) + numpyro.sample("x", dist.Normal(mu, .5)) + + +if __name__ == '__main__': + from numpyro.handlers import substitute, block + from numpyro.infer.util import _predictive, log_likelihood, log_density + from numpyro.contrib.ecs_utils import subsample_size + from numpyro.infer import SVI, Trace_ELBO + from jax import random + + data = random.normal(random.PRNGKey(0), (10,)) + obs = jnp.concatenate([jnp.ones(6), jnp.zeros(4)]) + optimizer = numpyro.optim.Adam(step_size=0.5) + svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) + + svi_result = svi.run(random.PRNGKey(1), 5000, data, obs) + guide = substitute(svi.guide, svi_result.params) + posterior_samples = _predictive(random.PRNGKey(2), guide, {}, + (10,), return_sites='', parallel=True, + model_args=(data, obs), model_kwargs={}) + model = subsample_size(model, {'N': (10, 10)}) + ll = log_likelihood(model, posterior_samples, data, obs) + # likelihood = {name: value.mean(1) for name, value in ll.items()} + weights = {name: jnp.mean((ll['obs'].T / ll['obs'].sum(1).T).T, 0) for name, value in ll.items()} + prior, _ = log_density(block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), + (data, obs), {}, + {n: v.mean() for n, v in posterior_samples.items()}) + variational, _ = log_density(guide, (data, obs), {}, {n: v.mean() for n, v in posterior_samples.items()}) + print(ll['obs'].mean(1).sum()) + print(variational, prior) + scale = variational - prior - ll['obs'].mean(1).sum() + print(variational - prior) + print(scale) From f867af18c3605747915e2ca65579501b201d6507 Mon Sep 17 00:00:00 2001 From: Ola Date: Fri, 15 Jan 2021 23:06:22 +0100 Subject: [PATCH 44/93] Added variational_proxy, todo: fix estimator. --- numpyro/contrib/ecs_utils.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/numpyro/contrib/ecs_utils.py b/numpyro/contrib/ecs_utils.py index 1f2fba0a9..c7dae5b26 100644 --- a/numpyro/contrib/ecs_utils.py +++ b/numpyro/contrib/ecs_utils.py @@ -1,13 +1,8 @@ -from collections import namedtuple - import jax import jax.numpy as jnp from numpyro.primitives import Messenger, _subsample_fn -IntegratorState = namedtuple('IntegratorState', ['z', 'r', 'potential_energy', 'z_grad']) -IntegratorState.__new__.__defaults__ = (None,) * len(IntegratorState._fields) - def init_near_values(site=None, values={}): """Initialize the sampling to a noisy map estimate of the parameters""" @@ -29,19 +24,18 @@ def init_near_values(site=None, values={}): return init_to_uniform(site) -# TODO: SVI PROXY (previous code) -# def svi_proxy(svi, model_args, model_kwargs): -# def proxy(z, *args, **kwargs): -# z_ref = svi.guide.expectation(z) -# ll, _ = log_density_obs_hmcecs(svi.model, model_args, model_kwargs, z_ref) -# return ll -# -# def proxy_u(z, model_args, model_kwargs, *args, **kwargs): -# z_ref = svi.guide.expectation(z) -# ll, _ = log_density_prior_hmcecs(svi.model, model_args, model_kwargs, z_ref) -# return ll -# -# return proxy, proxy_u +def variational_proxy(svi, S, weights, model_args, model_kwargs, ): + # TODO: fuse computation for S + log_posterior_prob(z) - log_prior_prob(z)? + log_posterior_prob = lambda params: log_density(svi.guide, model_args, model_kwargs, params) + log_prior_prob = lambda params: log_density(model) + + def proxy(z): + return S + log_posterior_prob(z) - log_prior_prob(z) + + def uproxy(z, subsample): + return S + weights[subsample].sum() + log_posterior_prob(z) - log_prior_prob(z) + + return proxy, uproxy def _extract_params(distribution): From 89c8ffe93663a83c0455b8a1ee65f2b8125614cf Mon Sep 17 00:00:00 2001 From: Ola Date: Sat, 16 Jan 2021 16:49:31 +0100 Subject: [PATCH 45/93] Integrated variational proxy into ecs. --- numpyro/contrib/ecs.py | 59 ++++++++++++++++++++++----------- numpyro/contrib/ecs_utils.py | 63 +++++------------------------------- 2 files changed, 49 insertions(+), 73 deletions(-) diff --git a/numpyro/contrib/ecs.py b/numpyro/contrib/ecs.py index 5d21722d0..380d039c0 100644 --- a/numpyro/contrib/ecs.py +++ b/numpyro/contrib/ecs.py @@ -11,9 +11,11 @@ subsample_size, _tangent_curve ) -from numpyro.handlers import substitute, trace, seed +from numpyro.contrib.ecs_utils import taylor_proxy, variational_proxy, difference_estimator_fn +from numpyro.handlers import substitute, trace, seed, block from numpyro.infer import log_likelihood from numpyro.infer.mcmc import MCMCKernel +from numpyro.infer.util import _predictive, log_density from numpyro.util import identity HMC_ECS_State = namedtuple("HMC_ECS_State", "uz, hmc_state, accept_prob, rng_key") @@ -72,13 +74,14 @@ class ECS(MCMCKernel): """ sample_field = "uz" - def __init__(self, inner_kernel, estimator_fn=None, proxy_gen_fn=None, z_ref=None): + def __init__(self, inner_kernel, proxy, ref=None, guide=None): self.inner_kernel = copy.copy(inner_kernel) self.inner_kernel._model = inner_kernel.model - self._proxy_gen_fn = proxy_gen_fn - self._estimator_fn = estimator_fn - self._z_ref = z_ref + self._guide = guide + self._proxy = proxy + self._ref = ref self._plate_sizes = None + self._estimator = difference_estimator_fn @property def model(self): @@ -106,20 +109,40 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): # Precompute Jaccobian and Hessian for Taylor Proxy # TODO: check proxy type and branch plate_sizes_all = {name: (prototype_trace[name]["args"][0], prototype_trace[name]["args"][0]) for name in u} - with subsample_size(self.model, plate_sizes_all): - ref_trace = trace(substitute(self.model, data=self._z_ref)).get_trace(*model_args, **model_kwargs) - jac_all = {name: _tangent_curve(site['fn'], site['value'], jacobian) for name, site in ref_trace.items() - if (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - hess_all = {name: _tangent_curve(site['fn'], site['value'], hessian) for name, site in ref_trace.items() - if (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - ll_ref = {name: site['fn'].log_prob(site['value']) for name, site in ref_trace.items() if - (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - - ref_trace = trace(substitute(self.model, data={**self._z_ref, **u})).get_trace(*model_args, - **model_kwargs) # TODO: check reparam - proxy_fn, uproxy_fn = self._proxy_gen_fn(ref_trace, ll_ref, jac_all, hess_all) + if self._proxy == 'taylor': + with subsample_size(self.model, plate_sizes_all): + ref_trace = trace(substitute(self.model, data=self._z_ref)).get_trace(*model_args, **model_kwargs) + jac_all = {name: _tangent_curve(site['fn'], site['value'], jacobian) for name, site in ref_trace.items() + if (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} + hess_all = {name: _tangent_curve(site['fn'], site['value'], hessian) for name, site in ref_trace.items() + if (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} + ll_ref = {name: site['fn'].log_prob(site['value']) for name, site in ref_trace.items() if + (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} + + ref_trace = trace(substitute(self.model, data={**self._z_ref, **u})).get_trace(*model_args, + **model_kwargs) # TODO: check reparam + proxy_fn, uproxy_fn = taylor_proxy(ref_trace, ll_ref, jac_all, hess_all) + elif self._proxy == 'variational': + num_samples = 10 # TODO: heuristic for this + guide = substitute(self._guide, self._ref) + posterior_samples = _predictive(random.PRNGKey(2), guide, {}, + (num_samples,), return_sites='', parallel=True, + model_args=model_args, model_kwargs=model_kwargs) + with subsample_size(self.model, plate_sizes_all): + model = subsample_size(self.model, plate_sizes_all) + ll = log_likelihood(model, posterior_samples, *model_args, **model_kwargs) + # TODO: fix multiple likehoods + weights = {name: jnp.mean((value.T / value.sum(1).T).T, 0) for name, value in + ll.items()} # TODO: fix broadcast + prior, _ = log_density(block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), + model_args, model_kwargs, posterior_samples) + variational, _ = log_density(guide, model_args, model_kwargs, posterior_samples) + evidence = {name: variational / num_samples - prior / num_samples - ll.mean(1).sum() for name, ll in + ll.items()} + + proxy_fn, uproxy_fn = variational_proxy(self.model, self._guide, evidence, weights, model_args, + model_kwargs) - print(jac_all) estimators = {name: partial(self._estimator_fn, proxy_fn=proxy_fn, uproxy_fn=uproxy_fn) for name, site in prototype_trace.items() if (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} diff --git a/numpyro/contrib/ecs_utils.py b/numpyro/contrib/ecs_utils.py index c7dae5b26..eb719d75a 100644 --- a/numpyro/contrib/ecs_utils.py +++ b/numpyro/contrib/ecs_utils.py @@ -1,6 +1,7 @@ import jax import jax.numpy as jnp +from numpyro.infer.util import log_density from numpyro.primitives import Messenger, _subsample_fn @@ -24,16 +25,16 @@ def init_near_values(site=None, values={}): return init_to_uniform(site) -def variational_proxy(svi, S, weights, model_args, model_kwargs, ): +def variational_proxy(model, guide, evidence, weights, model_args, model_kwargs, ): # TODO: fuse computation for S + log_posterior_prob(z) - log_prior_prob(z)? - log_posterior_prob = lambda params: log_density(svi.guide, model_args, model_kwargs, params) - log_prior_prob = lambda params: log_density(model) + log_posterior_prob = lambda params: log_density(guide, model_args, model_kwargs, params) + log_prior_prob = lambda params: log_density(model, model_args, model_kwargs, params) - def proxy(z): - return S + log_posterior_prob(z) - log_prior_prob(z) + def proxy(name, z): + return evidence[name] + log_posterior_prob(z) - log_prior_prob(z) - def uproxy(z, subsample): - return S + weights[subsample].sum() + log_posterior_prob(z) - log_prior_prob(z) + def uproxy(name, z, subsample): + return evidence[name] + weights[subsample].sum() + log_posterior_prob(z) - log_prior_prob(z) return proxy, uproxy @@ -116,51 +117,3 @@ def _tangent_curve(dist, value, tangent_fn): z, aux_data = dist.tree_flatten() log_prob = lambda *params: dist.tree_unflatten(aux_data, params).log_prob(value).sum() return tuple(tangent_fn(log_prob, argnum)(*z) for argnum in range(len(z))) - - -import numpyro -import numpyro.distributions as dist -from numpyro.distributions import constraints - - -def model(data, obs): - theta = numpyro.sample("x", dist.Normal(0., 1.)) - with numpyro.plate("N", data.shape[0], subsample_size=5) as idx: - numpyro.sample("obs", dist.Bernoulli(logits=data[idx] * theta), obs=obs[idx]) - - -def guide(data, obs): - mu = numpyro.param('mu', 0., constraints=constraints.real) - numpyro.sample("x", dist.Normal(mu, .5)) - - -if __name__ == '__main__': - from numpyro.handlers import substitute, block - from numpyro.infer.util import _predictive, log_likelihood, log_density - from numpyro.contrib.ecs_utils import subsample_size - from numpyro.infer import SVI, Trace_ELBO - from jax import random - - data = random.normal(random.PRNGKey(0), (10,)) - obs = jnp.concatenate([jnp.ones(6), jnp.zeros(4)]) - optimizer = numpyro.optim.Adam(step_size=0.5) - svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) - - svi_result = svi.run(random.PRNGKey(1), 5000, data, obs) - guide = substitute(svi.guide, svi_result.params) - posterior_samples = _predictive(random.PRNGKey(2), guide, {}, - (10,), return_sites='', parallel=True, - model_args=(data, obs), model_kwargs={}) - model = subsample_size(model, {'N': (10, 10)}) - ll = log_likelihood(model, posterior_samples, data, obs) - # likelihood = {name: value.mean(1) for name, value in ll.items()} - weights = {name: jnp.mean((ll['obs'].T / ll['obs'].sum(1).T).T, 0) for name, value in ll.items()} - prior, _ = log_density(block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), - (data, obs), {}, - {n: v.mean() for n, v in posterior_samples.items()}) - variational, _ = log_density(guide, (data, obs), {}, {n: v.mean() for n, v in posterior_samples.items()}) - print(ll['obs'].mean(1).sum()) - print(variational, prior) - scale = variational - prior - ll['obs'].mean(1).sum() - print(variational - prior) - print(scale) From 1403dfeba1ed413f0ba3e4df0c11a62c67b22df0 Mon Sep 17 00:00:00 2001 From: ola Date: Mon, 18 Jan 2021 11:22:47 +0100 Subject: [PATCH 46/93] checkpoint: before redoing estimator. --- numpyro/contrib/ecs.py | 20 +++++++++++------ numpyro/contrib/ecs_utils.py | 19 +++++++++------- numpyro/contrib/trace_struct.py | 39 +++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 15 deletions(-) create mode 100644 numpyro/contrib/trace_struct.py diff --git a/numpyro/contrib/ecs.py b/numpyro/contrib/ecs.py index 380d039c0..b219caac2 100644 --- a/numpyro/contrib/ecs.py +++ b/numpyro/contrib/ecs.py @@ -74,11 +74,13 @@ class ECS(MCMCKernel): """ sample_field = "uz" - def __init__(self, inner_kernel, proxy, ref=None, guide=None): + def __init__(self, inner_kernel, proxy, model_struct, ref=None, guide=None): + assert proxy in ('taylor', 'variational') self.inner_kernel = copy.copy(inner_kernel) self.inner_kernel._model = inner_kernel.model self._guide = guide self._proxy = proxy + self._model_struct = model_struct self._ref = ref self._plate_sizes = None self._estimator = difference_estimator_fn @@ -107,7 +109,6 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): for name in u} # Precompute Jaccobian and Hessian for Taylor Proxy - # TODO: check proxy type and branch plate_sizes_all = {name: (prototype_trace[name]["args"][0], prototype_trace[name]["args"][0]) for name in u} if self._proxy == 'taylor': with subsample_size(self.model, plate_sizes_all): @@ -123,9 +124,10 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): **model_kwargs) # TODO: check reparam proxy_fn, uproxy_fn = taylor_proxy(ref_trace, ll_ref, jac_all, hess_all) elif self._proxy == 'variational': + pos_key, guide_key, rng_key = random.split(rng_key, 3) num_samples = 10 # TODO: heuristic for this guide = substitute(self._guide, self._ref) - posterior_samples = _predictive(random.PRNGKey(2), guide, {}, + posterior_samples = _predictive(random.pos_key, guide, {}, (num_samples,), return_sites='', parallel=True, model_args=model_args, model_kwargs=model_kwargs) with subsample_size(self.model, plate_sizes_all): @@ -140,14 +142,18 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): evidence = {name: variational / num_samples - prior / num_samples - ll.mean(1).sum() for name, ll in ll.items()} - proxy_fn, uproxy_fn = variational_proxy(self.model, self._guide, evidence, weights, model_args, - model_kwargs) + guide_trace = trace(seed(self._guide, guide_key)).get_trace(model_args, model_kwargs) - estimators = {name: partial(self._estimator_fn, proxy_fn=proxy_fn, uproxy_fn=uproxy_fn) + proxy_fn, uproxy_fn = variational_proxy(guide_trace, evidence, weights, self._model_struct) + else: + # TODO: alternatives + raise NotImplementedError + + estimators = {name: partial(self._estimator, proxy_fn=proxy_fn, uproxy_fn=uproxy_fn) for name, site in prototype_trace.items() if (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} self.inner_kernel._model = _wrap_est_model(self.model, estimators, self._plate_sizes) - init_params = {name: init_near_values(site, self._z_ref) for name, site in prototype_trace.items()} + init_params = {name: init_near_values(site, self._ref) for name, site in prototype_trace.items()} model_kwargs["_subsample_sites"] = u hmc_state = self.inner_kernel.init(key_z, num_warmup, init_params, model_args, model_kwargs) diff --git a/numpyro/contrib/ecs_utils.py b/numpyro/contrib/ecs_utils.py index eb719d75a..cc2aa5431 100644 --- a/numpyro/contrib/ecs_utils.py +++ b/numpyro/contrib/ecs_utils.py @@ -1,7 +1,6 @@ import jax import jax.numpy as jnp -from numpyro.infer.util import log_density from numpyro.primitives import Messenger, _subsample_fn @@ -25,16 +24,20 @@ def init_near_values(site=None, values={}): return init_to_uniform(site) -def variational_proxy(model, guide, evidence, weights, model_args, model_kwargs, ): - # TODO: fuse computation for S + log_posterior_prob(z) - log_prior_prob(z)? - log_posterior_prob = lambda params: log_density(guide, model_args, model_kwargs, params) - log_prior_prob = lambda params: log_density(model, model_args, model_kwargs, params) - +def variational_proxy(model_trace, guide_trace, evidence, weights, model_struct): def proxy(name, z): - return evidence[name] + log_posterior_prob(z) - log_prior_prob(z) + successors = model_struct.soccessor[name] + log_prob = jnp.array(0.) + for succ in successors: + log_prob += guide_trace[succ]['fn'].log_prob(z) - model_trace[succ]['fn'].log_prob(z) + return evidence[name] + log_prob def uproxy(name, z, subsample): - return evidence[name] + weights[subsample].sum() + log_posterior_prob(z) - log_prior_prob(z) + successors = model_struct.soccessor[name] + log_prob = jnp.array(0.) + for succ in successors: + log_prob += guide_trace[succ]['fn'].log_prob(z) - model_trace[succ]['fn'].log_prob(z) + return evidence[name] + weights[subsample].sum() * log_prob return proxy, uproxy diff --git a/numpyro/contrib/trace_struct.py b/numpyro/contrib/trace_struct.py new file mode 100644 index 000000000..34737992b --- /dev/null +++ b/numpyro/contrib/trace_struct.py @@ -0,0 +1,39 @@ +from collections import OrderedDict + + +class TraceStructure: + """ + Graph structure denoting the relationship among pyro primitives in the execution path. + """ + + def __init__(self): + self.nodes = OrderedDict() + self._successors = OrderedDict() + self._predecessors = OrderedDict() + + def __contains__(self, site): + return site in self.nodes + + def add_edge(self, from_site, to_site): + for site in (from_site, to_site): + if site not in self: + self.add_node(site) + + self._successors[from_site].add(to_site) + self._predecessors[to_site].add(to_site) + + def add_node(self, site_name, **kwargs): + if site_name in self: + # TODO: handle reused name! + pass + self.nodes[site_name] = kwargs + self._successors[site_name] = set() + self.__predecessors[site_name] = set() + + def predecessor(self, site): + return self._predecessors[site] + + def successor(self, site): + return self._successors[site] + + # TODO: remove edge From 1c6af82af720cb35d8cc40e06acb0e23eb2b4677 Mon Sep 17 00:00:00 2001 From: ola Date: Mon, 18 Jan 2021 16:48:58 +0100 Subject: [PATCH 47/93] Variational proxy running! --- numpyro/contrib/ecs.py | 34 +++++---- numpyro/contrib/ecs_utils.py | 110 +++++++++++++++++---------- numpyro/infer/mcmc.py | 139 +++++++++++++++++++++++++---------- 3 files changed, 191 insertions(+), 92 deletions(-) diff --git a/numpyro/contrib/ecs.py b/numpyro/contrib/ecs.py index b219caac2..cd0794594 100644 --- a/numpyro/contrib/ecs.py +++ b/numpyro/contrib/ecs.py @@ -11,7 +11,7 @@ subsample_size, _tangent_curve ) -from numpyro.contrib.ecs_utils import taylor_proxy, variational_proxy, difference_estimator_fn +from numpyro.contrib.ecs_utils import taylor_proxy, variational_proxy, DifferenceEstimator from numpyro.handlers import substitute, trace, seed, block from numpyro.infer import log_likelihood from numpyro.infer.mcmc import MCMCKernel @@ -34,11 +34,11 @@ """ -def _wrap_est_model(model, estimators, plate_sizes): +def _wrap_est_model(model, estimators, predecessors): def fn(*args, **kwargs): subsample_values = kwargs.pop("_subsample_sites", {}) with substitute(data=subsample_values): - with estimator(model, estimators, plate_sizes): + with estimator(model, estimators, predecessors): model(*args, **kwargs) return fn @@ -83,7 +83,7 @@ def __init__(self, inner_kernel, proxy, model_struct, ref=None, guide=None): self._model_struct = model_struct self._ref = ref self._plate_sizes = None - self._estimator = difference_estimator_fn + self._estimator = DifferenceEstimator @property def model(self): @@ -108,9 +108,9 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): self._plate_sizes = {name: prototype_trace[name]["args"] + (min(prototype_trace[name]["args"][1] // 2, 100),) for name in u} - # Precompute Jaccobian and Hessian for Taylor Proxy plate_sizes_all = {name: (prototype_trace[name]["args"][0], prototype_trace[name]["args"][0]) for name in u} if self._proxy == 'taylor': + # Precompute Jaccobian and Hessian for Taylor Proxy with subsample_size(self.model, plate_sizes_all): ref_trace = trace(substitute(self.model, data=self._z_ref)).get_trace(*model_args, **model_kwargs) jac_all = {name: _tangent_curve(site['fn'], site['value'], jacobian) for name, site in ref_trace.items() @@ -120,19 +120,19 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): ll_ref = {name: site['fn'].log_prob(site['value']) for name, site in ref_trace.items() if (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - ref_trace = trace(substitute(self.model, data={**self._z_ref, **u})).get_trace(*model_args, - **model_kwargs) # TODO: check reparam + ref_trace = trace(substitute(self.model, data={**self._z_ref, **u})).get_trace(*model_args, **model_kwargs) proxy_fn, uproxy_fn = taylor_proxy(ref_trace, ll_ref, jac_all, hess_all) elif self._proxy == 'variational': pos_key, guide_key, rng_key = random.split(rng_key, 3) num_samples = 10 # TODO: heuristic for this guide = substitute(self._guide, self._ref) - posterior_samples = _predictive(random.pos_key, guide, {}, + posterior_samples = _predictive(pos_key, guide, {}, (num_samples,), return_sites='', parallel=True, model_args=model_args, model_kwargs=model_kwargs) with subsample_size(self.model, plate_sizes_all): model = subsample_size(self.model, plate_sizes_all) ll = log_likelihood(model, posterior_samples, *model_args, **model_kwargs) + # TODO: fix multiple likehoods weights = {name: jnp.mean((value.T / value.sum(1).T).T, 0) for name, value in ll.items()} # TODO: fix broadcast @@ -140,22 +140,26 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): model_args, model_kwargs, posterior_samples) variational, _ = log_density(guide, model_args, model_kwargs, posterior_samples) evidence = {name: variational / num_samples - prior / num_samples - ll.mean(1).sum() for name, ll in - ll.items()} + ll.items()} # TODO: must depend on structure! guide_trace = trace(seed(self._guide, guide_key)).get_trace(model_args, model_kwargs) - - proxy_fn, uproxy_fn = variational_proxy(guide_trace, evidence, weights, self._model_struct) + proxy_fn, uproxy_fn = variational_proxy(guide_trace, evidence, weights) else: - # TODO: alternatives raise NotImplementedError - estimators = {name: partial(self._estimator, proxy_fn=proxy_fn, uproxy_fn=uproxy_fn) + estimators = {name: self._estimator(name=name, + proxy=proxy_fn, uproxy=uproxy_fn, + plate_name=site['cond_indep_stack'][0].name, + plate_size=self._plate_sizes[site['cond_indep_stack'][0].name]) for name, site in prototype_trace.items() if (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - self.inner_kernel._model = _wrap_est_model(self.model, estimators, self._plate_sizes) + + predecessors = {name: self._model_struct[name] for name in estimators} + + self.inner_kernel._model = _wrap_est_model(self.model, estimators, predecessors) + init_params = {name: init_near_values(site, self._ref) for name, site in prototype_trace.items()} model_kwargs["_subsample_sites"] = u - hmc_state = self.inner_kernel.init(key_z, num_warmup, init_params, model_args, model_kwargs) uz = {**u, **hmc_state.z} return device_put(HMC_ECS_State(uz, hmc_state, 1., rng_key)) diff --git a/numpyro/contrib/ecs_utils.py b/numpyro/contrib/ecs_utils.py index cc2aa5431..c418b6636 100644 --- a/numpyro/contrib/ecs_utils.py +++ b/numpyro/contrib/ecs_utils.py @@ -1,9 +1,17 @@ +from collections import OrderedDict, defaultdict + import jax import jax.numpy as jnp from numpyro.primitives import Messenger, _subsample_fn +def _tangent_curve(dist, value, tangent_fn): + z, aux_data = dist.tree_flatten() + log_prob = lambda *params: dist.tree_unflatten(aux_data, params).log_prob(value).sum() + return tuple(tangent_fn(log_prob, argnum)(*z) for argnum in range(len(z))) + + def init_near_values(site=None, values={}): """Initialize the sampling to a noisy map estimate of the parameters""" from functools import partial @@ -24,43 +32,36 @@ def init_near_values(site=None, values={}): return init_to_uniform(site) -def variational_proxy(model_trace, guide_trace, evidence, weights, model_struct): - def proxy(name, z): - successors = model_struct.soccessor[name] - log_prob = jnp.array(0.) - for succ in successors: - log_prob += guide_trace[succ]['fn'].log_prob(z) - model_trace[succ]['fn'].log_prob(z) - return evidence[name] + log_prob - - def uproxy(name, z, subsample): - successors = model_struct.soccessor[name] - log_prob = jnp.array(0.) - for succ in successors: - log_prob += guide_trace[succ]['fn'].log_prob(z) - model_trace[succ]['fn'].log_prob(z) - return evidence[name] + weights[subsample].sum() * log_prob - - return proxy, uproxy - - def _extract_params(distribution): params, _ = distribution.tree_flatten() return params class estimator(Messenger): - def __init__(self, fn, estimators, plate_sizes): + def __init__(self, fn, estimators, predecessors): self.estimators = estimators - self.plate_sizes = plate_sizes + self.predecessors = predecessors + self.predecessor_sites = defaultdict(OrderedDict) + self._successors = None + super(estimator, self).__init__(fn) - def process_message(self, msg): - if msg['type'] == 'sample' and msg['is_observed'] and msg['cond_indep_stack']: - log_prob = msg['fn'].log_prob - msg['scale'] = 1. - msg['fn'].log_prob = lambda *args, **kwargs: \ - self.estimators[msg['name']](*args, name=msg['name'], z=_extract_params(msg['fn']), log_prob=log_prob, - sizes=self.plate_sizes[msg['cond_indep_stack'][0].name], - **kwargs) # TODO: check multiple levels + @property + def successors(self): + if getattr(self, '_successors') is None: + successors = {} + for site_name, preds in self.predecessors.items(): + successors.update({pred_name: site_name for pred_name in preds}) # TODO: handle shared priors + self._successors = successors + return self._successors + + def postprocess_message(self, msg): + name = msg['name'] + if name in self.successors: + self.predecessor_sites[self.successors[name]][name] = msg.copy() + + if msg['type'] == 'sample' and msg['is_observed'] and msg['cond_indep_stack']: # TODO: is subsampled + msg['fn'] = self.estimators[name](msg['fn'], self.predecessor_sites[name]) def taylor_proxy(ref_trace, ll_ref, jac_all, hess_all): @@ -107,16 +108,47 @@ def process_message(self, msg): 0] else jnp.arange(msg["args"][0]) -def difference_estimator_fn(value, name, z, sizes, log_prob, proxy_fn, uproxy_fn, *args, **kwargs, ): - n, m, g = sizes - ll_sub = log_prob(value).sum() - diff = ll_sub - uproxy_fn(name, value, z) - l_hat = proxy_fn(name, z) + n / m * diff - sigma = n ** 2 / m * jnp.var(diff) - return l_hat - .5 * sigma +class DifferenceEstimator: + def __init__(self, name, proxy, uproxy, plate_name, plate_size): + self._name = name + self.plate_name = plate_name + self.size = plate_size + self.proxy = proxy + self.uproxy = uproxy + self.subsample = None + self._dist = None + self._predecessors = None + + def __call__(self, dist, predecessors): + self.dist = dist + self.predecessors = predecessors + + def log_prob(self, value): + n, m, g = self.size + ll_sub = self.dist.log_prob(value).sum() + diff = ll_sub - self.uproxy(name=self._name, + value=value, + subsample=self.predecessors[self.plate_name], + predecessors=self.predecessors) + l_hat = self.proxy(self._name) + n / m * diff + sigma = n ** 2 / m * jnp.var(diff) + return l_hat - .5 * sigma + + +def variational_proxy(guide_trace, evidence, weights): + def _log_like(predecessors): + log_prob = jnp.array(0.) + for pred in predecessors: + if pred['type'] == 'sample': + val = pred['value'] + name = pred['name'] + log_prob += guide_trace[name]['fn'].log_prob(val) - pred['fn'].log_prob(val) + return log_prob + def proxy(name, predecessors, *args, **kwargs): + return evidence[name] + _log_like(predecessors) -def _tangent_curve(dist, value, tangent_fn): - z, aux_data = dist.tree_flatten() - log_prob = lambda *params: dist.tree_unflatten(aux_data, params).log_prob(value).sum() - return tuple(tangent_fn(log_prob, argnum)(*z) for argnum in range(len(z))) + def uproxy(name, predecessors, subsample, *args, **kwargs): + return evidence[name] + weights[name][subsample].sum() * _log_like(predecessors) + + return proxy, uproxy diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 8fd343144..47fa55a28 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -39,10 +39,10 @@ class MCMCKernel(ABC): >>> import numpyro.distributions as dist >>> from numpyro.infer import MCMC - >>> MHState = namedtuple("MHState", ["z", "rng_key"]) + >>> MHState = namedtuple("MHState", ["u", "rng_key"]) >>> class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel): - ... sample_field = "z" + ... sample_field = "u" ... ... def __init__(self, potential_fn, step_size=0.1): ... self.potential_fn = potential_fn @@ -52,12 +52,12 @@ class MCMCKernel(ABC): ... return MHState(init_params, rng_key) ... ... def sample(self, state, model_args, model_kwargs): - ... z, rng_key = state + ... u, rng_key = state ... rng_key, key_proposal, key_accept = random.split(rng_key, 3) - ... z_proposal = dist.Normal(z, self.step_size).sample(key_proposal) - ... accept_prob = jnp.exp(self.potential_fn(z) - self.potential_fn(z_proposal)) - ... z_new = jnp.where(dist.Uniform().sample(key_accept) < accept_prob, z_proposal, z) - ... return MHState(z_new, rng_key) + ... u_proposal = dist.Normal(u, self.step_size).sample(key_proposal) + ... accept_prob = jnp.exp(self.potential_fn(u) - self.potential_fn(u_proposal)) + ... u_new = jnp.where(dist.Uniform().sample(key_accept) < accept_prob, u_proposal, u) + ... return MHState(u_new, rng_key) >>> def f(x): ... return ((x - 2) ** 2).sum() @@ -66,6 +66,7 @@ class MCMCKernel(ABC): >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000) >>> mcmc.run(random.PRNGKey(0), init_params=jnp.array([1., 2.])) >>> samples = mcmc.get_samples() + >>> mcmc.print_summary() # doctest: +SKIP """ def postprocess_fn(self, model_args, model_kwargs): """ @@ -207,6 +208,9 @@ class MCMC(object): and :class:`~numpyro.infer.mcmc.NUTS` are available. :param int num_warmup: Number of warmup steps. :param int num_samples: Number of samples to generate from the Markov chain. + :param int thinning: Positive integer that controls the fraction of post-warmup samples that are + retained. For example if thinning is 2 then every other sample is retained. + Defaults to 1, i.e. no thinning. :param int num_chains: Number of Number of MCMC chains to run. By default, chains will be run in parallel using :func:`jax.pmap`, failing which, chains will be run in sequence. @@ -230,6 +234,7 @@ def __init__(self, num_warmup, num_samples, num_chains=1, + thinning=1, postprocess_fn=None, chain_method='parallel', progress_bar=True, @@ -240,6 +245,9 @@ def __init__(self, self.num_warmup = num_warmup self.num_samples = num_samples self.num_chains = num_chains + if not isinstance(thinning, int) or thinning < 1: + raise ValueError('thinning must be a positive integer') + self.thinning = thinning self.postprocess_fn = postprocess_fn if chain_method not in ['parallel', 'vectorized', 'sequential']: raise ValueError('Only supporting the following methods to draw chains:' @@ -263,7 +271,7 @@ def __init__(self, self._collection_params = {} self._set_collection_params() - def _get_cached_fn(self): + def _get_cached_fns(self): if self._jit_model_args: args, kwargs = (None,), (None,) else: @@ -271,20 +279,36 @@ def _get_cached_fn(self): kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(self._kwargs.items()))) key = args + kwargs try: - fn = self._cache.get(key, None) + fns = self._cache.get(key, None) # If unhashable arguments are provided, proceed normally # without caching except TypeError: - fn, key = None, None - if fn is None: + fns, key = None, None + if fns is None: + + def laxmap_postprocess_fn(states, args, kwargs): + if self.postprocess_fn is None: + body_fn = self.sampler.postprocess_fn(args, kwargs) + else: + body_fn = self.postprocess_fn + if self.chain_method == "vectorized" and self.num_chains > 1: + body_fn = vmap(body_fn) + + return lax.map(body_fn, states) + if self._jit_model_args: - fn = partial(_sample_fn_jit_args, sampler=self.sampler) + sample_fn = partial(_sample_fn_jit_args, sampler=self.sampler) + postprocess_fn = jit(laxmap_postprocess_fn) else: - fn = partial(_sample_fn_nojit_args, sampler=self.sampler, - args=self._args, kwargs=self._kwargs) + sample_fn = partial(_sample_fn_nojit_args, sampler=self.sampler, + args=self._args, kwargs=self._kwargs) + postprocess_fn = jit(partial(laxmap_postprocess_fn, + args=self._args, kwargs=self._kwargs)) + + fns = sample_fn, postprocess_fn if key is not None: - self._cache[key] = fn - return fn + self._cache[key] = fns + return fns def _get_cached_init_state(self, rng_key, args, kwargs): rng_key = (_hashable(rng_key),) @@ -302,24 +326,23 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields): if init_state is None: init_state = self.sampler.init(rng_key, self.num_warmup, init_params, model_args=args, model_kwargs=kwargs) - if self.postprocess_fn is None: - postprocess_fn = self.sampler.postprocess_fn(args, kwargs) - else: - postprocess_fn = self.postprocess_fn + sample_fn, postprocess_fn = self._get_cached_fns() diagnostics = lambda x: self.sampler.get_diagnostics_str(x[0]) if rng_key.ndim == 1 else '' # noqa: E731 init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,) lower_idx = self._collection_params["lower"] upper_idx = self._collection_params["upper"] phase = self._collection_params["phase"] - + collection_size = self._collection_params["collection_size"] + collection_size = collection_size if collection_size is None else collection_size // self.thinning collect_vals = fori_collect(lower_idx, upper_idx, - self._get_cached_fn(), + sample_fn, init_val, transform=_collect_fn(collect_fields), progbar=self.progress_bar, return_last_val=True, - collection_size=self._collection_params["collection_size"], + thinning=self.thinning, + collection_size=collection_size, progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase), diagnostics_fn=diagnostics) states, last_val = collect_vals @@ -328,21 +351,16 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields): if len(collect_fields) == 1: states = (states,) states = dict(zip(collect_fields, states)) - #print(states) # Apply constraints if number of samples is non-zero - #print(self._sample_field) site_values = tree_flatten(states[self._sample_field])[0] - #print(site_values) # XXX: lax.map still works if some arrays have 0 size # so we only need to filter out the case site_value.shape[0] == 0 # (which happens when lower_idx==upper_idx) - print(self._sample_field) - #print(states[self._sample_field]) if len(site_values) > 0 and jnp.shape(site_values[0])[0] > 0: - if self.chain_method == "vectorized" and self.num_chains > 1: - postprocess_fn = vmap(postprocess_fn) - print(states[self._sample_field]) - states[self._sample_field] = lax.map(postprocess_fn, states[self._sample_field]) + if self._jit_model_args: + states[self._sample_field] = postprocess_fn(states[self._sample_field], args, kwargs) + else: + states[self._sample_field] = postprocess_fn(states[self._sample_field]) return states, last_state def _set_collection_params(self, lower=None, upper=None, collection_size=None, phase=None): @@ -364,11 +382,42 @@ def _compile(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): except TypeError: pass + @property + def post_warmup_state(self): + """ + The state before the sampling phase. If this attribute is not None, + :meth:`run` will skip the warmup phase and start with the state + specified in this attribute. + + .. note:: This attribute can be used to sequentially draw MCMC samples. For example, + + .. code-block:: python + + mcmc = MCMC(NUTS(model), 100, 100) + mcmc.run(random.PRNGKey(0)) + first_100_samples = mcmc.get_samples() + mcmc.post_warmup_state = mcmc.last_state + mcmc.run(mcmc.post_warmup_state.rng_key) # or mcmc.run(random.PRNGKey(1)) + second_100_samples = mcmc.get_samples() + """ + return self._warmup_state + + @post_warmup_state.setter + def post_warmup_state(self, state): + self._warmup_state = state + + @property + def last_state(self): + """ + The final MCMC state at the end of the sampling phase. + """ + return self._last_state + def warmup(self, rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs): """ - Run the MCMC warmup adaptation phase. After this call, the :meth:`run` method - will skip the warmup adaptation phase. To run `warmup` again for the new data, - it is required to run :meth:`warmup` again. + Run the MCMC warmup adaptation phase. After this call, `self.warmup_state` will be set + and the :meth:`run` method will skip the warmup adaptation phase. To run `warmup` again + for the new data, it is required to run :meth:`warmup` again. :param random.PRNGKey rng_key: Random number generator key to be used for the sampling. :param args: Arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init` method. @@ -497,12 +546,26 @@ def get_extra_fields(self, group_by_chain=False): return {k: v for k, v in states.items() if k != self._sample_field} def print_summary(self, prob=0.9, exclude_deterministic=True): + """ + Print the statistics of posterior samples collected during running this MCMC instance. + + :param float prob: the probability mass of samples within the credible interval. + :param bool exclude_deterministic: whether or not print out the statistics + at deterministic sites. + """ # Exclude deterministic sites by default sites = self._states[self._sample_field] if isinstance(sites, dict) and exclude_deterministic: - sites = {k: v for k, v in self._states[self._sample_field].items() - if k in self._last_state.z} + state_sample_field = attrgetter(self._sample_field)(self._last_state) + # XXX: there might be the case that state.z is not a dictionary but + # its postprocessed value `sites` is a dictionary. + # TODO: in general, when both `sites` and `state.z` are dictionaries, + # they can have different key names, not necessary due to deterministic + # behavior. We might revise this logic if needed in the future. + if isinstance(state_sample_field, dict): + sites = {k: v for k, v in self._states[self._sample_field].items() + if k in state_sample_field} print_summary(sites, prob=prob) extra_fields = self.get_extra_fields() if 'diverging' in extra_fields: - print("Number of divergences: {}".format(jnp.sum(extra_fields['diverging']))) \ No newline at end of file + print("Number of divergences: {}".format(jnp.sum(extra_fields['diverging']))) From 2a8cc23ef16eed359ea0ba77aa17fc11177387ca Mon Sep 17 00:00:00 2001 From: Ola Date: Mon, 18 Jan 2021 22:02:48 +0100 Subject: [PATCH 48/93] Fixed minor bugs and example of hmcecs with variational proxy on logistic regression. --- examples/hmcecs/higgs_new.py | 64 +++++++++++++++++------------------- examples/logistic_hmcecs.py | 2 +- numpyro/contrib/ecs.py | 2 +- numpyro/contrib/ecs_utils.py | 2 ++ 4 files changed, 34 insertions(+), 36 deletions(-) diff --git a/examples/hmcecs/higgs_new.py b/examples/hmcecs/higgs_new.py index 8f8adfc30..188348eef 100644 --- a/examples/hmcecs/higgs_new.py +++ b/examples/hmcecs/higgs_new.py @@ -1,13 +1,12 @@ import jax.numpy as jnp -import pystan from jax import random from sklearn.datasets import load_breast_cancer import numpyro import numpyro.distributions as dist from numpyro.contrib.ecs import ECS -from numpyro.contrib.ecs_utils import difference_estimator_fn, taylor_proxy -from numpyro.infer import MCMC, NUTS +from numpyro.distributions import constraints +from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO def breast_cancer_data(): @@ -21,50 +20,47 @@ def breast_cancer_data(): def log_reg_model(features, obs): n, m = features.shape theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), .5 * jnp.ones(m))) - with numpyro.plate('N', n, subsample_size=75) as idx: + with numpyro.plate('N', n, subsample_size=75): batch_feats = numpyro.subsample(features, event_dim=1) - batch_obs = numpyro.subsample(obs, event_dim=1) - numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(batch_feats, theta)), obs=batch_obs) + batch_obs = numpyro.subsample(obs, event_dim=0) + numpyro.sample('obs', dist.Bernoulli(logits=theta @ batch_feats.T), obs=batch_obs) -def plain_log_reg_model(features, obs): - n, m = features.shape - theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), 2 * jnp.ones(m))) - numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(features, theta)), obs=obs) +def log_reg_guide(feature, obs): + _, m = feature.shape + mean = numpyro.param('mean', jnp.zeros(m), constraints=constraints.real) + var = numpyro.param('var', jnp.ones(m), constraints=constraints.positive) + numpyro.sample('theta', dist.continuous.Normal(mean, var)) def hmcecs_model(data, obs): - kernel = NUTS(plain_log_reg_model) - mcmc = MCMC(kernel, 500, 500) - mcmc.run(random.PRNGKey(1), data, obs) - z_ref = {k: v.mean(0) for k, v in mcmc.get_samples().items()} + optimizer = numpyro.optim.Adam(step_size=0.005) + svi = SVI(log_reg_model, log_reg_guide, optimizer, loss=Trace_ELBO()) + svi_result = svi.run(random.PRNGKey(1), 1000, data, obs) # Compute HMCECS - kernel = ECS(NUTS(log_reg_model), estimator_fn=difference_estimator_fn, proxy_gen_fn=taylor_proxy, z_ref=z_ref) - mcmc = MCMC(kernel, 500, 500) + kernel = ECS(NUTS(log_reg_model), + proxy='variational', + model_struct={'obs': ['theta']}, + ref=svi_result.params, + guide=svi.guide) + mcmc = MCMC(kernel, 1500, 8500) mcmc.run(random.PRNGKey(0), data, obs, extra_fields=("accept_prob",)) mcmc.print_summary(exclude_deterministic=False) +def plain_log_reg_model(features, obs): + n, m = features.shape + theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), .5 * jnp.ones(m))) + numpyro.sample('obs', dist.Bernoulli(logits=theta @ features.T), obs=obs) -# Stan - -def stan_model(): - model_code = """ - data { - int D; - int N; - matrix[N, D] x; - int y[N]; - } - parameters { - vector[D] beta; - } - model { - y ~ bernoulli_logit(x * beta); - } - """ - return pystan.StanModel(model_code=model_code) +def hmc(data, obs): + kernel = NUTS(log_reg_model) + mcmc = MCMC(kernel, 1500, 8500) + mcmc.run(random.PRNGKey(0), data, obs, extra_fields=("accept_prob",)) + mcmc.print_summary(exclude_deterministic=False) if __name__ == '__main__': data, obs = breast_cancer_data() + # hmcecs_model(data, obs) + hmc(data, obs) \ No newline at end of file diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 190483371..1ffd81fa7 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -133,7 +133,7 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, map_key, post_key = jax.random.split(map_key) z_ref, svi, svi_state = svi_map(model, map_key, feats=feats[:factor_SVI], obs=obs[:factor_SVI], num_epochs=num_epochs, batch_size=batch_size) - z_ref = svi.guide.sample_posterior(post_key, svi.get_params(svi_state), (100,)) + z_ref = svi.log_reg_guide.sample_posterior(post_key, svi.get_params(svi_state), (100,)) z_ref = {name: value.mean(0) for name, value in z_ref.items()} #highlight: AutoDiagonalNormal does not have auto_ in front of the parmeters save_obj(z_ref,"{}/MAP_Dict_Samples_Proxy_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), diff --git a/numpyro/contrib/ecs.py b/numpyro/contrib/ecs.py index cd0794594..b1422b500 100644 --- a/numpyro/contrib/ecs.py +++ b/numpyro/contrib/ecs.py @@ -142,7 +142,7 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): evidence = {name: variational / num_samples - prior / num_samples - ll.mean(1).sum() for name, ll in ll.items()} # TODO: must depend on structure! - guide_trace = trace(seed(self._guide, guide_key)).get_trace(model_args, model_kwargs) + guide_trace = trace(seed(self._guide, guide_key)).get_trace(*model_args, **model_kwargs) proxy_fn, uproxy_fn = variational_proxy(guide_trace, evidence, weights) else: raise NotImplementedError diff --git a/numpyro/contrib/ecs_utils.py b/numpyro/contrib/ecs_utils.py index c418b6636..1b853bcf4 100644 --- a/numpyro/contrib/ecs_utils.py +++ b/numpyro/contrib/ecs_utils.py @@ -56,6 +56,8 @@ def successors(self): return self._successors def postprocess_message(self, msg): + if 'name' not in msg: + return name = msg['name'] if name in self.successors: self.predecessor_sites[self.successors[name]][name] = msg.copy() From 7c41cee8ac9eeeabdcc0be88090bdd5c90236758 Mon Sep 17 00:00:00 2001 From: ola Date: Tue, 26 Jan 2021 13:33:21 +0100 Subject: [PATCH 49/93] merging --- examples/covtype.py | 2 +- examples/hmcecs/higgs.py | 142 ------------------------- examples/hmcecs/higgs_new.py | 66 ------------ examples/hmcecs/logistic_regression.py | 135 +++++++++++++++++++++++ examples/hmcecs/mnist_bnn.py | 3 +- numpyro/examples/datasets.py | 2 +- 6 files changed, 138 insertions(+), 212 deletions(-) delete mode 100644 examples/hmcecs/higgs.py delete mode 100644 examples/hmcecs/higgs_new.py create mode 100644 examples/hmcecs/logistic_regression.py diff --git a/examples/covtype.py b/examples/covtype.py index fdac66a04..a15966bd4 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -64,7 +64,7 @@ def main(args): parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")') parser.add_argument('--num-chains', nargs='?', default=1, type=int) parser.add_argument('--algo', default='NUTS', type=str, help='whether to run "HMC" or "NUTS"') - parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".') + parser.add_argument('--device', default='gpu', type=str, help='use "cpu" or "gpu".') args = parser.parse_args() numpyro.set_platform(args.device) diff --git a/examples/hmcecs/higgs.py b/examples/hmcecs/higgs.py deleted file mode 100644 index 394248b8b..000000000 --- a/examples/hmcecs/higgs.py +++ /dev/null @@ -1,142 +0,0 @@ -""" Logistic regression model as implemetned in https://arxiv.org/pdf/1708.00955.pdf with Higgs Dataset """ -# !/usr/bin/env python -from collections import namedtuple - -import jax -import jax.numpy as jnp -import jax.numpy as np_jax -from jax.tree_util import tree_map -from sklearn.model_selection import train_test_split - -import numpyro -import numpyro.distributions as dist -from examples.logistic_hmcecs_svi import svi_map -from numpyro import optim -from numpyro.contrib.autoguide_hmcecs import AutoDiagonalNormal -from numpyro.contrib.hmcecs import HMCECS -from numpyro.diagnostics import summary -from numpyro.examples.datasets import _load_higgs -from numpyro.infer import NUTS, MCMC -from numpyro.infer.elbo import Trace_ELBO -from numpyro.infer.svi import SVI - -numpyro.set_platform("gpu") - -DataLoaderState = namedtuple("DataLoaderState", ('iteration', 'rng_key', 'indexes', 'max_iter')) - - -def dataloader(*xs, batch_size=32, train_size=None, test_size=None, shuffle=True): - assert len(xs) > 1 - splitxs = train_test_split(*xs, train_size=train_size, test_size=test_size) - trainxs, testxs = splitxs[0::2], splitxs[1::2] - max_train_iter, max_test_iter = len(trainxs[0]) // batch_size, len(testxs[0]) // batch_size - - def make_dataset(dxs, max_iter): - def init(rng_key): - return DataLoaderState(0, rng_key, jnp.arange(len(dxs[0])), max_iter) - - def next_step(state): - - iteration = state.iteration % state.max_iter - batch = tuple(x[state.indexes[iteration * batch_size:(iteration + 1) * batch_size]] - for x in dxs) - if iteration + 1 == state.max_iter: - shuffle_rng_key, rng_key = jax.random.split(state.rng_key) - if shuffle: - indexes = jax.random.shuffle(shuffle_rng_key, state.indexes) - else: - indexes = state.indexes - return batch, DataLoaderState(state.iteration + 1, rng_key, indexes, state.max_iter) - else: - return batch, DataLoaderState(state.iteration + 1, state.rng_key, state.indexes, state.max_iter) - - return init, next_step - - return make_dataset(trainxs, max_train_iter), make_dataset(testxs, max_test_iter), testxs - - -def svi_map(model, rng_key, feats, obs, num_epochs, batch_size): - guide = AutoDiagonalNormal(model) - svi = SVI(model, guide, optim.Adam(0.0003), loss=Trace_ELBO()) - svi_rng_key, data_rng_key = jax.random.split(rng_key) - (init_train, next_train), _, _ = dataloader(feats, obs, train_size=0.9, batch_size=batch_size) - batch_fn = jax.jit(svi.update) - svi_state = None - data_state = init_train(data_rng_key) - num_batches = 0 - for _ in range(num_epochs): - for j in range(data_state.max_iter): - xs, data_state = next_train(data_state) - if svi_state is None: - svi_state = svi.init(svi_rng_key, *xs) - svi_state, _ = batch_fn(svi_state, *xs) - num_batches += 1 - return svi, svi_state - - -def infer_nuts(rng_key, features, obs, samples, warmup): - kernel = NUTS(model=logistic_regression, target_accept_prob=0.8) - mcmc = MCMC(kernel, num_warmup=warmup, num_samples=samples) - mcmc.run(rng_key, features, obs) - samples = mcmc.get_samples() - samples = tree_map(lambda x: x[None, ...], samples) - r_hat_average = np_jax.sum(summary(samples)["theta"]["r_hat"]) / len(summary(samples)["theta"]["r_hat"]) - - return mcmc.get_samples(), r_hat_average - - -def infer_hmcecs(rng_key, obs, features, m=None, g=None, n_samples=None, warmup=None, algo="NUTS", - subsample_method=None, map_method=None, proxy="taylor", estimator=None, num_epochs=None): - hmcecs_key, map_key = jax.random.split(rng_key) - n, _ = features.shape - - svi = None - if map_method == "nuts": - samples, r_hat_average = infer_nuts(map_key, features, obs, samples=10, warmup=5) - z_ref = {key: value.mean(0) for key, value in samples.items()} - elif map_method == "svi": - map_key, post_key = jax.random.split(map_key) - svi, svi_state = svi_map(logistic_regression, - map_key, - feats=features, - obs=obs, - num_epochs=num_epochs, - batch_size=256) - z_ref = svi.guide.sample_posterior(post_key, svi.get_params(svi_state), (100,)) - z_ref = {name: value.mean(0) for name, value in z_ref.items()} - - kernel = HMCECS(model=logistic_regression, z_ref=z_ref, m=m, g=g, algo=algo.upper(), - subsample_method=subsample_method, proxy=proxy, svi_fn=svi, - estimator=estimator, target_accept_prob=0.8) - - mcmc = MCMC(kernel, num_warmup=warmup, num_samples=n_samples, num_chains=1) - mcmc.run(rng_key, features, obs) - - return mcmc.get_samples() - - -def logistic_regression(features, obs): - n, m = features.shape - theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), 2 * jnp.ones(m))) - numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(features, theta)), obs=obs) - - -def higgs_data(): - return _load_higgs() - - -if __name__ == '__main__': - rng_key = jax.random.PRNGKey(37) - obs, feats = higgs_data() - num_examples = 1000 - - est_posterior_ECS = infer_hmcecs(rng_key, obs[:num_examples], feats[:num_examples], - n_samples=10, - warmup=5, - m=30, g=5, - algo='nuts', - subsample_method="perturb", - proxy='svi', - estimator='', - map_method='svi', - num_epochs=100) diff --git a/examples/hmcecs/higgs_new.py b/examples/hmcecs/higgs_new.py deleted file mode 100644 index 188348eef..000000000 --- a/examples/hmcecs/higgs_new.py +++ /dev/null @@ -1,66 +0,0 @@ -import jax.numpy as jnp -from jax import random -from sklearn.datasets import load_breast_cancer - -import numpyro -import numpyro.distributions as dist -from numpyro.contrib.ecs import ECS -from numpyro.distributions import constraints -from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO - - -def breast_cancer_data(): - dataset = load_breast_cancer() - feats = dataset.data - feats = (feats - feats.mean(0)) / feats.std(0) - feats = jnp.hstack((feats, jnp.ones((feats.shape[0], 1)))) - return feats, dataset.target - - -def log_reg_model(features, obs): - n, m = features.shape - theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), .5 * jnp.ones(m))) - with numpyro.plate('N', n, subsample_size=75): - batch_feats = numpyro.subsample(features, event_dim=1) - batch_obs = numpyro.subsample(obs, event_dim=0) - numpyro.sample('obs', dist.Bernoulli(logits=theta @ batch_feats.T), obs=batch_obs) - - -def log_reg_guide(feature, obs): - _, m = feature.shape - mean = numpyro.param('mean', jnp.zeros(m), constraints=constraints.real) - var = numpyro.param('var', jnp.ones(m), constraints=constraints.positive) - numpyro.sample('theta', dist.continuous.Normal(mean, var)) - - -def hmcecs_model(data, obs): - optimizer = numpyro.optim.Adam(step_size=0.005) - svi = SVI(log_reg_model, log_reg_guide, optimizer, loss=Trace_ELBO()) - svi_result = svi.run(random.PRNGKey(1), 1000, data, obs) - - # Compute HMCECS - kernel = ECS(NUTS(log_reg_model), - proxy='variational', - model_struct={'obs': ['theta']}, - ref=svi_result.params, - guide=svi.guide) - mcmc = MCMC(kernel, 1500, 8500) - mcmc.run(random.PRNGKey(0), data, obs, extra_fields=("accept_prob",)) - mcmc.print_summary(exclude_deterministic=False) - -def plain_log_reg_model(features, obs): - n, m = features.shape - theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), .5 * jnp.ones(m))) - numpyro.sample('obs', dist.Bernoulli(logits=theta @ features.T), obs=obs) - -def hmc(data, obs): - kernel = NUTS(log_reg_model) - mcmc = MCMC(kernel, 1500, 8500) - mcmc.run(random.PRNGKey(0), data, obs, extra_fields=("accept_prob",)) - mcmc.print_summary(exclude_deterministic=False) - - -if __name__ == '__main__': - data, obs = breast_cancer_data() - # hmcecs_model(data, obs) - hmc(data, obs) \ No newline at end of file diff --git a/examples/hmcecs/logistic_regression.py b/examples/hmcecs/logistic_regression.py new file mode 100644 index 000000000..a889d7d73 --- /dev/null +++ b/examples/hmcecs/logistic_regression.py @@ -0,0 +1,135 @@ +import os +import pathlib +import pickle +from datetime import datetime +from time import time + +import jax.numpy as jnp +import numpy as np +from jax import random, device_get +from pandas_plink import read_plink1_bin +from sklearn.datasets import load_breast_cancer + +import numpyro +import numpyro.distributions as dist +from numpyro.contrib.ecs import ECS +from numpyro.distributions import constraints +from numpyro.examples.datasets import _load_higgs +from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_sample + +os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" + +numpyro.set_platform("gpu") + + +def summary(dataset, name, mcmc, sample_time, svi_time=0.): + n_eff_mean = np.mean([numpyro.diagnostics.effective_sample_size(device_get(v)) + for v in mcmc.get_samples(True).values()]) + pickle.dump(mcmc.get_samples(True), open(f'{dataset}/{name}_posterior_samples.pkl', 'wb')) + step_field = 'num_steps' if name == 'hmc' else 'hmc_state.num_steps' + num_step = np.sum(mcmc.get_extra_fields()[step_field]) + accpt_prob = 1. + if name == 'ecs': + accpt_prob = np.mean(mcmc.get_extra_fields()['accept_prob']) + + with open(f'{dataset}/{name}_chain_stats.txt', 'w') as f: + print('sample_time', 'svi_time', 'n_eff_mean', 'gibbs_accpt_prob', 'tot_num_steps', 'time_per_step', + 'time_per_eff', + sep=',', file=f) + print(sample_time, svi_time, n_eff_mean, accpt_prob, num_step, sample_time / num_step, sample_time / n_eff_mean, + sep=',', file=f) + + +def higgs_data(): + obs, data = _load_higgs() + return data, obs + + +def breast_cancer_data(): + dataset = load_breast_cancer() + feats = dataset.data + feats = (feats - feats.mean(0)) / feats.std(0) + feats = jnp.hstack((feats, jnp.ones((feats.shape[0], 1)))) + return feats, dataset.target + + +def copsac_data(): + data_folder = pathlib.Path('data') + bim_file = str(data_folder / 'Sim_data_3.bim') + fam_file = str(data_folder / 'Sim_data_3.fam') + bed_file = str(data_folder / 'Sim_data_3.bed') + data = read_plink1_bin(bed_file, bim_file, fam_file) + + return jnp.array(data.values), jnp.array(data['trait'].astype(int)) + + +def log_reg_model(features, obs, subsample_size): + n, m = features.shape + theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), .5 * jnp.ones(m))) + with numpyro.plate('N', n, subsample_size=subsample_size): + batch_feats = numpyro.subsample(features, event_dim=1) + batch_obs = numpyro.subsample(obs, event_dim=0) + numpyro.sample('obs', dist.Bernoulli(logits=theta @ batch_feats.T), obs=batch_obs) + + +def log_reg_guide(feature, obs, subsample_size): + _, m = feature.shape + mean = numpyro.param('mean', jnp.zeros(m), constraints=constraints.real) + # var = numpyro.param('var', jnp.ones(m), constraints=constraints.positive) + numpyro.sample('theta', dist.continuous.Normal(mean, .5)) + + +def hmcecs_model(dataset, data, obs, subsample_size): + optimizer = numpyro.optim.Adam(step_size=5e-5) + svi = SVI(log_reg_model, log_reg_guide, optimizer, loss=Trace_ELBO()) + start = time() + svi_result = svi.run(random.PRNGKey(2), 1000, data, obs, subsample_size, ) + svi_time = time() - start + + pickle.dump(svi_result.params, open(f'{dataset}/svi_params.pkl', 'wb')) + params = svi_result.params + + # Compute HMCECS + kernel = ECS(NUTS(log_reg_model), + proxy='variational', + model_struct={'obs': ['theta']}, + ref=params, + guide=log_reg_guide) + mcmc = MCMC(kernel, 10000, 10000) + start = time() + mcmc.run(random.PRNGKey(3), data, obs, subsample_size, extra_fields=("accept_prob", + "hmc_state.accept_prob", + "hmc_state.num_steps")) + print(mcmc.get_extra_fields(True)['hmc_state.accept_prob']) + summary(dataset, 'ecs', mcmc, time() - start, svi_time=svi_time) + + +def plain_log_reg_model(features, obs): + n, m = features.shape + theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), .5 * jnp.ones(m))) + numpyro.sample('obs', dist.Bernoulli(logits=theta @ features.T), obs=obs) + + +def hmc(dataset, data, obs): + kernel = NUTS(plain_log_reg_model, init_strategy=init_to_sample) + mcmc = MCMC(kernel, 100, 200) + mcmc._compile(random.PRNGKey(0), data, obs, extra_fields=("num_steps",)) + start = time() + mcmc.run(random.PRNGKey(0), data, obs, extra_fields=('num_steps',)) + summary(dataset, 'hmc', mcmc, time() - start) + + +if __name__ == '__main__': + + datasets = ('copsac',) + load_data = {'breast': breast_cancer_data, 'higgs': higgs_data, 'copsac': copsac_data} + subsample_sizes = {'breast': 75, 'higgs': 1300, 'copsac': 1000} + data, obs = breast_cancer_data() + + for dataset in datasets: + dir = f'{dataset}_{datetime.now().strftime("%Y_%m_%d_%H%M%S")}' + if not os.path.exists(dir): + os.mkdir(dir) + data, obs = load_data[dataset]() + hmcecs_model(dir, data, obs, subsample_sizes[dataset]) + hmc(dir, data, obs) diff --git a/examples/hmcecs/mnist_bnn.py b/examples/hmcecs/mnist_bnn.py index 6973218c5..4f120757a 100644 --- a/examples/hmcecs/mnist_bnn.py +++ b/examples/hmcecs/mnist_bnn.py @@ -165,12 +165,11 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.4.1') parser = argparse.ArgumentParser(description="Bayesian neural network example") parser.add_argument("-n", "--num-samples", nargs="?", default=20, type=int) parser.add_argument("--num-warmup", nargs='?', default=10, type=int) parser.add_argument("--num-chains", nargs='?', default=1, type=int) - parser.add_argument("--num-data", nargs='?', default=1000, type=int) + parser.add_argument("--num-data", nargs='?', default=10000, type=int) parser.add_argument("--num-hidden", nargs='?', default=5, type=int) parser.add_argument("--device", default='gpu', type=str, help='use "cpu" or "gpu".') args = parser.parse_args() diff --git a/numpyro/examples/datasets.py b/numpyro/examples/datasets.py index b20ff7038..28aba1e1f 100644 --- a/numpyro/examples/datasets.py +++ b/numpyro/examples/datasets.py @@ -220,7 +220,7 @@ def _load_higgs(): file_path = os.path.join(DATA_DIR, 'HIGGS.csv.gz') df = pd.read_csv(file_path, header=None) obs, feats = df.iloc[:, 0], df.iloc[:, 1:] - return obs.to_numpy(), feats.to_numpy() + return obs.to_numpy().astype(int), feats.to_numpy() def _load(dset): From dd0426cb7d67c7346e371dd96c3ffb14ac22f17e Mon Sep 17 00:00:00 2001 From: ola Date: Tue, 26 Jan 2021 14:56:26 +0100 Subject: [PATCH 50/93] Refactored taylor_estimator into taylor_proxy and a difference estimator. --- numpyro/handlers.py | 68 ++++++++++++++- numpyro/infer/hmc_gibbs.py | 166 +++++++++++++++++++++++++++++++++++-- 2 files changed, 221 insertions(+), 13 deletions(-) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 604856b8c..1bcda0625 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -1,6 +1,5 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 - """ This provides a small set of effect handlers in NumPyro that are modeled after Pyro's `poutine `_ module. @@ -76,16 +75,17 @@ -874.89813 """ -from collections import OrderedDict import warnings +from collections import OrderedDict +from functools import partial +import jax.numpy as jnp import numpy as np - from jax import lax, random -import jax.numpy as jnp import numpyro from numpyro.distributions.distribution import COERCIONS +from numpyro.infer.util import _unconstrain_reparam from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack, plate from numpyro.util import not_jax_tracer @@ -390,6 +390,7 @@ class infer_config(Messenger): :param fn: a stochastic function (callable containing NumPyro primitive calls) :param config_fn: a callable taking a site and returning an infer dict """ + def __init__(self, fn=None, config_fn=None): super().__init__(fn) self.config_fn = config_fn @@ -797,3 +798,62 @@ def process_message(self, msg): msg['value'] = intervention msg['is_observed'] = True msg['stop'] = True + + +class estimate_likelihood(numpyro.primitives.Messenger): + def __init__(self, fn=None, estimator=None): + # estimate_likelihood: accept likelihood tuple (fn, value, subsample_name, subsample_dim, subsample_idx) + # and current unconstrained params + # and returns log of the bias-corrected likelihood + assert estimator is not None + super().__init__(fn) + self.estimator = estimator + self.params = None + self.likelihoods = {} + self.subsample_plates = {} + + def __enter__(self): + # trace(substitute(substitute(control_variate(model), unconstrained_reparam))) + for handler in numpyro.primitives._PYRO_STACK[::-1]: + if isinstance(handler, substitute) and isinstance(handler.substitute_fn, partial) \ + and handler.substitute_fn.func is _unconstrain_reparam: + self.params = handler.substitute_fn.args[0] + break + return super().__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + # make sure exit trackback is nice if an error happens + super().__exit__(exc_type, exc_value, traceback) + if exc_type is not None: + return + + if self.params is None: + return + + # add numpyro.factor; ideally, we will want to skip this computation when making prediction + # see: https://github.com/pyro-ppl/pyro/issues/2744 + numpyro.factor("_biased_corrected_log_likelihood", self.estimator(self.likelihoods, self.params)) + + # clean up + self.params = None + self.likelihoods = {} + self.subsample_plates = {} + + def process_message(self, msg): + if self.params is None: + return + + if msg["type"] == "sample" and msg["is_observed"]: + assert msg["name"] not in self.params + # store the likelihood for the estimator + for frame in msg["cond_indep_stack"]: + if frame.name in self.subsample_plates: + if msg["name"] in self.likelihoods: + raise RuntimeError(f"Multiple subsample plates at site {msg['name']} " + "are not allowed. Please reshape your data.") + subsample_idx = self.subsample_plates[frame.name] + self.likelihoods[msg["name"]] = (msg["fn"], msg["value"], frame.name, frame.dim, subsample_idx) + # mask the current likelihood + msg["fn"] = msg["fn"].mask(False) + elif msg["type"] == "plate" and msg["args"][0] > msg["args"][1]: + self.subsample_plates[msg["name"]] = msg["value"] diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 89f1c89d3..f81337b70 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -1,17 +1,19 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import namedtuple import copy +import warnings +from collections import defaultdict, namedtuple from functools import partial -from jax import device_put, grad, jacfwd, ops, random, value_and_grad import jax.numpy as jnp +from jax import device_put, jacfwd, jacobian, grad, hessian, ops, random, value_and_grad from jax.scipy.special import expit -from numpyro.handlers import condition, seed, substitute, trace +from numpyro.handlers import block, condition, seed, substitute, trace, estimate_likelihood from numpyro.infer.hmc import HMC from numpyro.infer.mcmc import MCMCKernel +from numpyro.infer.util import _unconstrain_reparam from numpyro.util import cond, fori_loop, identity, ravel_pytree HMCGibbsState = namedtuple("HMCGibbsState", "z, hmc_state, rng_key") @@ -247,7 +249,6 @@ def _discrete_modified_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, s def _discrete_gibbs_fn(potential_fn, support_sizes, proposal_fn): - def gibbs_fn(rng_key, gibbs_sites, hmc_sites, pe): # get support_sizes of gibbs_sites support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites}) @@ -398,7 +399,6 @@ def potential_fn(z_gibbs, z_hmc): def _subsample_gibbs_fn(potential_fn, plate_sizes, num_blocks=1): - def gibbs_fn(rng_key, gibbs_sites, hmc_sites, pe): assert set(gibbs_sites) == set(plate_sizes) u_new = {} @@ -476,20 +476,25 @@ class HMCECS(HMCGibbs): >>> assert abs(jnp.mean(samples) - 1.) < 0.1 """ - def __init__(self, inner_kernel, *, num_blocks=1): + + def __init__(self, inner_kernel, *, estimator=None, num_blocks=1): super().__init__(inner_kernel, lambda *args: None, None) self._num_blocks = num_blocks + self._estimator = estimator def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): model_kwargs = {} if model_kwargs is None else model_kwargs.copy() rng_key, key_u = random.split(rng_key) self._prototype_trace = trace(seed(self.model, key_u)).get_trace(*model_args, **model_kwargs) - self._plate_sizes = { + self._subsample_plate_sizes = { name: site["args"] for name, site in self._prototype_trace.items() if site["type"] == "plate" and site["args"][0] > site["args"][1] # i.e. size > subsample_size } - self._gibbs_sites = list(self._plate_sizes.keys()) + self._gibbs_sites = list(self._subsample_plate_sizes.keys()) + if self._estimator is not None: + estimator = self._estimator + self.inner_kernel._model = estimate_likelihood(self.inner_kernel._model, estimator) return super().init(rng_key, num_warmup, init_params, model_args, model_kwargs) def sample(self, state, model_args, model_kwargs): @@ -505,7 +510,7 @@ def potential_fn(z_gibbs, z_hmc): model_kwargs_ = model_kwargs.copy() model_kwargs_["_gibbs_sites"] = z_gibbs - gibbs_fn = _subsample_gibbs_fn(potential_fn, self._plate_sizes, self._num_blocks) + gibbs_fn = _subsample_gibbs_fn(potential_fn, self._subsample_plate_sizes, self._num_blocks) z_gibbs, pe = gibbs_fn(rng_key=rng_gibbs, gibbs_sites=z_gibbs, hmc_sites=z_hmc, pe=state.hmc_state.potential_energy) @@ -521,3 +526,146 @@ def potential_fn(z_gibbs, z_hmc): z = {**z_gibbs, **hmc_state.z} return HMCGibbsState(z, hmc_state, rng_key) + + +def difference_estimator(rng_key, model, model_args, model_kwargs, proxy_fn): + # subsample_plate_sizes: name -> (size, subsample_size) + prototype_trace = trace(seed(model, rng_key)).get_trace(*model_args, **model_kwargs) + subsample_plate_sizes = { + name: site["args"] + for name, site in prototype_trace.items() + if site["type"] == "plate" and site["args"][0] > site["args"][1] + } + + def estimator(likelihoods, params): + subsample_log_liks = defaultdict(float) + subsample_indices = {} + for (fn, value, name, subsample_dim, subsample_idx) in likelihoods.values(): + subsample_log_liks[name] += _sum_all_except_at_dim(fn.log_prob(value), subsample_dim) + if name not in subsample_indices: + subsample_indices[name] = subsample_idx + + log_lik_sum = 0. + + proxy_value_all, proxy_value_subsample = proxy_fn(params, subsample_indices) + + for name, subsample_log_lik in subsample_log_liks.items(): # loop over all subsample sites + n, m = subsample_plate_sizes[name] + + diff = subsample_log_lik - proxy_value_subsample[name] + + unbiased_log_lik = proxy_value_all[name] + n * jnp.mean(diff) + variance = n ** 2 / m * jnp.var(diff) + log_lik_sum += unbiased_log_lik - 0.5 * variance + return log_lik_sum + + return estimator + + +def taylor_proxy(rng_key, model, model_args, model_kwargs, reference_params, using_lookup=False): + prototype_trace = trace(seed(model, rng_key)).get_trace(*model_args, **model_kwargs) + subsample_plate_sizes = { + name: site["args"] + for name, site in prototype_trace.items() + if site["type"] == "plate" and site["args"][0] > site["args"][1] # i.e. size > subsample_size + } + # subsample_plate_sizes: name -> (size, subsample_size) + ref_params_flat, unravel_fn = ravel_pytree(reference_params) + + def log_likelihood(params_flat, subsample_indices=None): + if subsample_indices is None: + subsample_indices = {k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()} + params = unravel_fn(params_flat) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + with block(), trace() as tr, substitute(data=subsample_indices), \ + substitute(substitute_fn=partial(_unconstrain_reparam, params)): + model(*model_args, **model_kwargs) + + log_lik = defaultdict(float) + for site in tr.values(): + if site["type"] == "sample" and site["is_observed"]: + for frame in site["cond_indep_stack"]: + if frame.name in subsample_plate_sizes: + log_lik[frame.name] += _sum_all_except_at_dim( + site["fn"].log_prob(site["value"]), frame.dim) + return log_lik + + def log_likelihood_sum(params_flat, subsample_indices=None): + return {k: v.sum() for k, v in log_likelihood(params_flat, subsample_indices).items()} + + # those stats are dict keyed by subsample names + if using_lookup: + ref_log_likelihoods = log_likelihood(ref_params_flat) # n + # NB: use jacfwd (instead of jacobian/jacrev) when out_dim >> in_dim + ref_log_likelihood_grads = jacfwd(log_likelihood)(ref_params_flat) + ref_log_likelihood_hessians = jacfwd(jacfwd(log_likelihood))(ref_params_flat) # n x 55 x 55 + ref_log_likelihoods_sum = {k: v.sum(0) for k, v in ref_log_likelihoods.items()} + ref_log_likelihood_grads_sum = {k: v.sum(0) for k, v in ref_log_likelihood_grads.items()} + ref_log_likelihood_hessians_sum = {k: v.sum(0) for k, v in ref_log_likelihood_hessians.items()} # 55 x 55 + else: + ref_log_likelihoods_sum = log_likelihood_sum(ref_params_flat) + ref_log_likelihood_grads_sum = jacobian(log_likelihood_sum)(ref_params_flat) + ref_log_likelihood_hessians_sum = hessian(log_likelihood_sum)(ref_params_flat) + + def proxy_fn(params, subsample_indices): + params_flat, _ = ravel_pytree(params) + params_diff = params_flat - ref_params_flat + if using_lookup: + # NB: in GPU, indexing here is expensive, it is better to compute likelihood, grad, hessian directly + # m x 55 x 55 (m ~ sqrt(n) ~ 1000) + ref_subsample_log_lik = {k: v[subsample_indices[k]] + for k, v in ref_log_likelihoods.items()} + ref_subsample_log_lik_grad = {k: v[subsample_indices[k]] + for k, v in ref_log_likelihood_grads.items()} + ref_subsample_log_lik_hessian = {k: v[subsample_indices[k]] + for k, v in ref_log_likelihood_hessians.items()} + else: + ref_subsample_log_lik = log_likelihood_sum(ref_params_flat, subsample_indices) + ref_subsample_log_lik_grad = jacobian(log_likelihood_sum)(ref_params_flat, subsample_indices) + ref_subsample_log_lik_hessian = hessian(log_likelihood_sum)(ref_params_flat, subsample_indices) + + proxy_sum = defaultdict(float) + proxy_subsample = defaultdict(float) + for name, subsample_idx in subsample_indices.items(): + proxy_subsample[name] = ref_subsample_log_lik[name] + \ + jnp.dot(ref_subsample_log_lik_grad[name], params_diff) + \ + 0.5 * jnp.dot(jnp.dot(ref_subsample_log_lik_hessian[name], params_diff), + params_diff) + + proxy_subsample[name] = ref_log_likelihoods_sum[name] + \ + jnp.dot(ref_log_likelihood_grads_sum[name], params_diff) + \ + 0.5 * jnp.dot(jnp.dot(ref_log_likelihood_hessians_sum[name], params_diff), + params_diff) + return proxy_sum, proxy_subsample + + return proxy_fn + + +def _sum_all_except_at_dim(x, dim): + x = x.reshape((-1,) + x.shape[dim:]).sum(0) + return x.reshape(x.shape[:1] + (-1,)).sum(-1) + + +def variational_proxy(rng_key, model, model_args, model_kwargs, subsample_plate_sizes, reference_params, using_lookup=False): + pos_key, guide_key, rng_key = random.split(rng_key, 3) + num_samples = 10 # TODO: heuristic for this + guide = substitute(self._guide, self._ref) + posterior_samples = _predictive(pos_key, guide, {}, + (num_samples,), return_sites='', parallel=True, + model_args=model_args, model_kwargs=model_kwargs) + with subsample_size(self.model, plate_sizes_all): + model = subsample_size(self.model, plate_sizes_all) + ll = log_likelihood(model, posterior_samples, *model_args, **model_kwargs) + + # TODO: fix multiple likehoods + weights = {name: jnp.mean((value.T / value.sum(1).T).T, 0) for name, value in + ll.items()} # TODO: fix broadcast + prior, _ = log_density(block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), + model_args, model_kwargs, posterior_samples) + variational, _ = log_density(guide, model_args, model_kwargs, posterior_samples) + evidence = {name: variational / num_samples - prior / num_samples - ll.mean(1).sum() for name, ll in + ll.items()} # TODO: must depend on structure! + + guide_trace = trace(seed(self._guide, guide_key)).get_trace(*model_args, **model_kwargs) + proxy_fn, uproxy_fn = variational_proxy(guide_trace, evidence, weights) From 60e0912e6dea99829f497a19c6403042ad32fba6 Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 27 Jan 2021 09:36:51 +0100 Subject: [PATCH 51/93] Sketched variational proxy in hmc_gibbs. --- examples/hmcecs/logistic_regression.py | 45 ++++---- examples/logistic_hmcecs.py | 2 +- numpyro/handlers.py | 58 ---------- numpyro/infer/hmc_gibbs.py | 142 +++++++++++++++++++++---- 4 files changed, 148 insertions(+), 99 deletions(-) diff --git a/examples/hmcecs/logistic_regression.py b/examples/hmcecs/logistic_regression.py index a889d7d73..ae1a48232 100644 --- a/examples/hmcecs/logistic_regression.py +++ b/examples/hmcecs/logistic_regression.py @@ -12,14 +12,15 @@ import numpyro import numpyro.distributions as dist -from numpyro.contrib.ecs import ECS from numpyro.distributions import constraints from numpyro.examples.datasets import _load_higgs from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_sample +from numpyro.infer.hmc_gibbs import HMCECS, difference_estimator, variational_proxy, taylor_proxy +from numpyro.infer.util import _predictive os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" -numpyro.set_platform("gpu") +numpyro.set_platform("cpu") def summary(dataset, name, mcmc, sample_time, svi_time=0.): @@ -63,7 +64,7 @@ def copsac_data(): return jnp.array(data.values), jnp.array(data['trait'].astype(int)) -def log_reg_model(features, obs, subsample_size): +def model(features, obs, subsample_size): n, m = features.shape theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), .5 * jnp.ones(m))) with numpyro.plate('N', n, subsample_size=subsample_size): @@ -72,35 +73,43 @@ def log_reg_model(features, obs, subsample_size): numpyro.sample('obs', dist.Bernoulli(logits=theta @ batch_feats.T), obs=batch_obs) -def log_reg_guide(feature, obs, subsample_size): +def guide(feature, obs, subsample_size): _, m = feature.shape mean = numpyro.param('mean', jnp.zeros(m), constraints=constraints.real) # var = numpyro.param('var', jnp.ones(m), constraints=constraints.positive) numpyro.sample('theta', dist.continuous.Normal(mean, .5)) -def hmcecs_model(dataset, data, obs, subsample_size): +def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='variational'): + model_args, model_kwargs = (data, obs, subsample_size), {} + + svi_key, proxy_key, estimator_key, mcmc_key = random.split(random.PRNGKey(0), 4) optimizer = numpyro.optim.Adam(step_size=5e-5) - svi = SVI(log_reg_model, log_reg_guide, optimizer, loss=Trace_ELBO()) + svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) start = time() - svi_result = svi.run(random.PRNGKey(2), 1000, data, obs, subsample_size, ) + svi_result = svi.run(svi_key, 1000, *model_args) svi_time = time() - start pickle.dump(svi_result.params, open(f'{dataset}/svi_params.pkl', 'wb')) params = svi_result.params + if proxy_name == 'taylor': + proxy_key, ref_key = random.split(proxy_key) + ref_params = _predictive(ref_key, guide, {}, (1,), return_sites='', parallel=True, + model_args=model_args, model_kwargs=model_kwargs) + proxy_fn = taylor_proxy(proxy_key, model, model_args, model_kwargs, ref_params) + + else: + proxy_fn = variational_proxy(proxy_key, model, model_args, model_kwargs, guide, params) + estimator = difference_estimator(estimator_key, model, model_args, model_kwargs, proxy_fn) + # Compute HMCECS - kernel = ECS(NUTS(log_reg_model), - proxy='variational', - model_struct={'obs': ['theta']}, - ref=params, - guide=log_reg_guide) - mcmc = MCMC(kernel, 10000, 10000) + + kernel = HMCECS(NUTS(model), estimator=estimator) + mcmc = MCMC(kernel, 1000, 1000) start = time() - mcmc.run(random.PRNGKey(3), data, obs, subsample_size, extra_fields=("accept_prob", - "hmc_state.accept_prob", + mcmc.run(random.PRNGKey(3), data, obs, subsample_size, extra_fields=("hmc_state.accept_prob", "hmc_state.num_steps")) - print(mcmc.get_extra_fields(True)['hmc_state.accept_prob']) summary(dataset, 'ecs', mcmc, time() - start, svi_time=svi_time) @@ -121,15 +130,15 @@ def hmc(dataset, data, obs): if __name__ == '__main__': - datasets = ('copsac',) load_data = {'breast': breast_cancer_data, 'higgs': higgs_data, 'copsac': copsac_data} subsample_sizes = {'breast': 75, 'higgs': 1300, 'copsac': 1000} data, obs = breast_cancer_data() - for dataset in datasets: + for dataset in load_data.keys(): dir = f'{dataset}_{datetime.now().strftime("%Y_%m_%d_%H%M%S")}' if not os.path.exists(dir): os.mkdir(dir) data, obs = load_data[dataset]() hmcecs_model(dir, data, obs, subsample_sizes[dataset]) hmc(dir, data, obs) + exit() diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py index 1ffd81fa7..190483371 100644 --- a/examples/logistic_hmcecs.py +++ b/examples/logistic_hmcecs.py @@ -133,7 +133,7 @@ def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None, map_key, post_key = jax.random.split(map_key) z_ref, svi, svi_state = svi_map(model, map_key, feats=feats[:factor_SVI], obs=obs[:factor_SVI], num_epochs=num_epochs, batch_size=batch_size) - z_ref = svi.log_reg_guide.sample_posterior(post_key, svi.get_params(svi_state), (100,)) + z_ref = svi.guide.sample_posterior(post_key, svi.get_params(svi_state), (100,)) z_ref = {name: value.mean(0) for name, value in z_ref.items()} #highlight: AutoDiagonalNormal does not have auto_ in front of the parmeters save_obj(z_ref,"{}/MAP_Dict_Samples_Proxy_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 1bcda0625..d69cf4f48 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -85,7 +85,6 @@ import numpyro from numpyro.distributions.distribution import COERCIONS -from numpyro.infer.util import _unconstrain_reparam from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack, plate from numpyro.util import not_jax_tracer @@ -800,60 +799,3 @@ def process_message(self, msg): msg['stop'] = True -class estimate_likelihood(numpyro.primitives.Messenger): - def __init__(self, fn=None, estimator=None): - # estimate_likelihood: accept likelihood tuple (fn, value, subsample_name, subsample_dim, subsample_idx) - # and current unconstrained params - # and returns log of the bias-corrected likelihood - assert estimator is not None - super().__init__(fn) - self.estimator = estimator - self.params = None - self.likelihoods = {} - self.subsample_plates = {} - - def __enter__(self): - # trace(substitute(substitute(control_variate(model), unconstrained_reparam))) - for handler in numpyro.primitives._PYRO_STACK[::-1]: - if isinstance(handler, substitute) and isinstance(handler.substitute_fn, partial) \ - and handler.substitute_fn.func is _unconstrain_reparam: - self.params = handler.substitute_fn.args[0] - break - return super().__enter__() - - def __exit__(self, exc_type, exc_value, traceback): - # make sure exit trackback is nice if an error happens - super().__exit__(exc_type, exc_value, traceback) - if exc_type is not None: - return - - if self.params is None: - return - - # add numpyro.factor; ideally, we will want to skip this computation when making prediction - # see: https://github.com/pyro-ppl/pyro/issues/2744 - numpyro.factor("_biased_corrected_log_likelihood", self.estimator(self.likelihoods, self.params)) - - # clean up - self.params = None - self.likelihoods = {} - self.subsample_plates = {} - - def process_message(self, msg): - if self.params is None: - return - - if msg["type"] == "sample" and msg["is_observed"]: - assert msg["name"] not in self.params - # store the likelihood for the estimator - for frame in msg["cond_indep_stack"]: - if frame.name in self.subsample_plates: - if msg["name"] in self.likelihoods: - raise RuntimeError(f"Multiple subsample plates at site {msg['name']} " - "are not allowed. Please reshape your data.") - subsample_idx = self.subsample_plates[frame.name] - self.likelihoods[msg["name"]] = (msg["fn"], msg["value"], frame.name, frame.dim, subsample_idx) - # mask the current likelihood - msg["fn"] = msg["fn"].mask(False) - elif msg["type"] == "plate" and msg["args"][0] > msg["args"][1]: - self.subsample_plates[msg["name"]] = msg["value"] diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index f81337b70..ed704ee91 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -10,10 +10,11 @@ from jax import device_put, jacfwd, jacobian, grad, hessian, ops, random, value_and_grad from jax.scipy.special import expit -from numpyro.handlers import block, condition, seed, substitute, trace, estimate_likelihood +import numpyro +from numpyro.handlers import block, condition, seed, substitute, trace from numpyro.infer.hmc import HMC from numpyro.infer.mcmc import MCMCKernel -from numpyro.infer.util import _unconstrain_reparam +from numpyro.infer.util import _unconstrain_reparam, _predictive, log_density from numpyro.util import cond, fori_loop, identity, ravel_pytree HMCGibbsState = namedtuple("HMCGibbsState", "z, hmc_state, rng_key") @@ -647,25 +648,122 @@ def _sum_all_except_at_dim(x, dim): return x.reshape(x.shape[:1] + (-1,)).sum(-1) -def variational_proxy(rng_key, model, model_args, model_kwargs, subsample_plate_sizes, reference_params, using_lookup=False): +def variational_proxy(rng_key, model, model_args, model_kwargs, + guide, reference_params, + subsample_plate_sizes, num_samples=10): pos_key, guide_key, rng_key = random.split(rng_key, 3) - num_samples = 10 # TODO: heuristic for this - guide = substitute(self._guide, self._ref) - posterior_samples = _predictive(pos_key, guide, {}, - (num_samples,), return_sites='', parallel=True, + guide = substitute(guide, reference_params) + + def log_likelihood(params, subsample_indices=None): + params_flat, unravel_fn = ravel_pytree(params) + if subsample_indices is None: + subsample_indices = {k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()} + params = unravel_fn(params_flat) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + with block(), trace() as tr, substitute(data=subsample_indices), \ + substitute(substitute_fn=partial(_unconstrain_reparam, params)): + model(*model_args, **model_kwargs) + + log_lik = defaultdict(float) + for site in tr.values(): + if site["type"] == "sample" and site["is_observed"]: + for frame in site["cond_indep_stack"]: + if frame.name in subsample_plate_sizes: + log_lik[frame.name] += _sum_all_except_at_dim( + site["fn"].log_prob(site["value"]), frame.dim) + return log_lik + + def log_prior(params): + prior_prob, _ = log_density(block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), + model_args, model_kwargs, params) + return prior_prob + + def log_posterior(params): + posterior_prob, _ = log_density(guide, model_args, model_kwargs, params) + return posterior_prob + + # TODO: get MAP from guide! + posterior_samples = _predictive(pos_key, guide, {}, (num_samples,), return_sites='', parallel=True, model_args=model_args, model_kwargs=model_kwargs) - with subsample_size(self.model, plate_sizes_all): - model = subsample_size(self.model, plate_sizes_all) - ll = log_likelihood(model, posterior_samples, *model_args, **model_kwargs) - - # TODO: fix multiple likehoods - weights = {name: jnp.mean((value.T / value.sum(1).T).T, 0) for name, value in - ll.items()} # TODO: fix broadcast - prior, _ = log_density(block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), - model_args, model_kwargs, posterior_samples) - variational, _ = log_density(guide, model_args, model_kwargs, posterior_samples) - evidence = {name: variational / num_samples - prior / num_samples - ll.mean(1).sum() for name, ll in - ll.items()} # TODO: must depend on structure! - - guide_trace = trace(seed(self._guide, guide_key)).get_trace(*model_args, **model_kwargs) - proxy_fn, uproxy_fn = variational_proxy(guide_trace, evidence, weights) + log_likelihood_ref = log_likelihood(posterior_samples) + + weights = {name: log_like.mean(1) / log_like.mean(1).sum() for name, log_like in log_likelihood_ref.items()} + + log_prior_prob = log_prior(posterior_samples) + log_posterior_prob = log_posterior(posterior_samples) + + evidence = {name: (log_posterior_prob - log_prior_prob - log_like.sum(0)).mean() # [1] - [1] - [10] + for name, log_like in log_likelihood_ref.items()} + + def proxy_fn(params, subsample_indices): + proxy_sum = defaultdict(float) + proxy_subsample = defaultdict(float) + log_prior_prob = log_prior(params) + log_posterior_prob = log_prior(params) + for name, subsample_idx in subsample_indices.items(): + proxy_sum[name] = evidence[name] + log_posterior_prob - log_prior_prob + proxy_subsample[name] = evidence[name] + \ + weights[name][subsample_idx].sum() * (log_posterior_prob - log_prior_prob) + return proxy_sum, proxy_subsample + + return proxy_fn + + +class estimate_likelihood(numpyro.primitives.Messenger): + def __init__(self, fn=None, estimator=None): + # estimate_likelihood: accept likelihood tuple (fn, value, subsample_name, subsample_dim, subsample_idx) + # and current unconstrained params + # and returns log of the bias-corrected likelihood + assert estimator is not None + super().__init__(fn) + self.estimator = estimator + self.params = None + self.likelihoods = {} + self.subsample_plates = {} + + def __enter__(self): + # trace(substitute(substitute(control_variate(model), unconstrained_reparam))) + for handler in numpyro.primitives._PYRO_STACK[::-1]: + if isinstance(handler, substitute) and isinstance(handler.substitute_fn, partial) \ + and handler.substitute_fn.func is _unconstrain_reparam: + self.params = handler.substitute_fn.args[0] + break + return super().__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + # make sure exit trackback is nice if an error happens + super().__exit__(exc_type, exc_value, traceback) + if exc_type is not None: + return + + if self.params is None: + return + + # add numpyro.factor; ideally, we will want to skip this computation when making prediction + # see: https://github.com/pyro-ppl/pyro/issues/2744 + numpyro.factor("_biased_corrected_log_likelihood", self.estimator(self.likelihoods, self.params)) + + # clean up + self.params = None + self.likelihoods = {} + self.subsample_plates = {} + + def process_message(self, msg): + if self.params is None: + return + + if msg["type"] == "sample" and msg["is_observed"]: + assert msg["name"] not in self.params + # store the likelihood for the estimator + for frame in msg["cond_indep_stack"]: + if frame.name in self.subsample_plates: + if msg["name"] in self.likelihoods: + raise RuntimeError(f"Multiple subsample plates at site {msg['name']} " + "are not allowed. Please reshape your data.") + subsample_idx = self.subsample_plates[frame.name] + self.likelihoods[msg["name"]] = (msg["fn"], msg["value"], frame.name, frame.dim, subsample_idx) + # mask the current likelihood + msg["fn"] = msg["fn"].mask(False) + elif msg["type"] == "plate" and msg["args"][0] > msg["args"][1]: + self.subsample_plates[msg["name"]] = msg["value"] From 85957dc002588239bf632b0442c93dc85c56ca08 Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 27 Jan 2021 11:38:15 +0100 Subject: [PATCH 52/93] Variational proxy running. --- numpyro/infer/hmc_gibbs.py | 42 +++++++++++++++++++++++++------------- numpyro/primitives.py | 8 ++++++++ 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index ed704ee91..dfa2a64e9 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -648,12 +648,18 @@ def _sum_all_except_at_dim(x, dim): return x.reshape(x.shape[:1] + (-1,)).sum(-1) -def variational_proxy(rng_key, model, model_args, model_kwargs, - guide, reference_params, - subsample_plate_sizes, num_samples=10): +def variational_proxy(rng_key, model, model_args, model_kwargs, guide, reference_params, num_samples=10): + prototype_trace = trace(seed(model, rng_key)).get_trace(*model_args, **model_kwargs) + subsample_plate_sizes = { + name: site["args"] + for name, site in prototype_trace.items() + if site["type"] == "plate" and site["args"][0] > site["args"][1] # i.e. size > subsample_size + } + pos_key, guide_key, rng_key = random.split(rng_key, 3) guide = substitute(guide, reference_params) + # factor out? def log_likelihood(params, subsample_indices=None): params_flat, unravel_fn = ravel_pytree(params) if subsample_indices is None: @@ -674,33 +680,40 @@ def log_likelihood(params, subsample_indices=None): site["fn"].log_prob(site["value"]), frame.dim) return log_lik - def log_prior(params): - prior_prob, _ = log_density(block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), - model_args, model_kwargs, params) - return prior_prob - def log_posterior(params): - posterior_prob, _ = log_density(guide, model_args, model_kwargs, params) + with numpyro.primitives.inner_stack(): + posterior_prob, _ = log_density(guide, model_args, model_kwargs, params) return posterior_prob + def log_prior(params): + with numpyro.primitives.inner_stack(): + prior_prob, _ = log_density(block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), + model_args, model_kwargs, params) + return prior_prob + # TODO: get MAP from guide! posterior_samples = _predictive(pos_key, guide, {}, (num_samples,), return_sites='', parallel=True, model_args=model_args, model_kwargs=model_kwargs) log_likelihood_ref = log_likelihood(posterior_samples) - weights = {name: log_like.mean(1) / log_like.mean(1).sum() for name, log_like in log_likelihood_ref.items()} + posterior_samples = {**posterior_samples, **{k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()}} + + weights = {name: log_like / log_like.sum() for name, log_like in log_likelihood_ref.items()} - log_prior_prob = log_prior(posterior_samples) - log_posterior_prob = log_posterior(posterior_samples) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=UserWarning) + log_prior_prob = log_prior(posterior_samples) + log_posterior_prob = log_posterior(posterior_samples) - evidence = {name: (log_posterior_prob - log_prior_prob - log_like.sum(0)).mean() # [1] - [1] - [10] + evidence = {name: (log_posterior_prob - log_prior_prob - log_like.sum()) / num_samples for name, log_like in log_likelihood_ref.items()} def proxy_fn(params, subsample_indices): + params = {**params, **subsample_indices} proxy_sum = defaultdict(float) proxy_subsample = defaultdict(float) log_prior_prob = log_prior(params) - log_posterior_prob = log_prior(params) + log_posterior_prob = log_posterior(params) for name, subsample_idx in subsample_indices.items(): proxy_sum[name] = evidence[name] + log_posterior_prob - log_prior_prob proxy_subsample[name] = evidence[name] + \ @@ -710,6 +723,7 @@ def proxy_fn(params, subsample_indices): return proxy_fn + class estimate_likelihood(numpyro.primitives.Messenger): def __init__(self, fn=None, estimator=None): # estimate_likelihood: accept likelihood tuple (fn, value, subsample_name, subsample_dim, subsample_idx) diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 108f7f476..c77fe150e 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -18,6 +18,14 @@ CondIndepStackFrame = namedtuple('CondIndepStackFrame', ['name', 'dim', 'size']) +@contextmanager +def inner_stack(): + global _PYRO_STACK + current_stack = _PYRO_STACK + _PYRO_STACK = [] + yield + _PYRO_STACK = current_stack + def apply_stack(msg): pointer = 0 for pointer, handler in enumerate(reversed(_PYRO_STACK)): From c01738f43e3f1fc5740c5af88d1b075d66c287c0 Mon Sep 17 00:00:00 2001 From: ola Date: Fri, 29 Jan 2021 17:25:16 +0100 Subject: [PATCH 53/93] Examples. --- examples/hmcecs/cifar10.py | 152 +++++++++++++++++++++++++ examples/hmcecs/data/data.npy | Bin 0 -> 6528 bytes examples/hmcecs/logistic_regression.py | 22 ++-- examples/hmcecs/protein.py | 139 ++++++++++++++++++++++ numpyro/infer/hmc_gibbs.py | 10 +- 5 files changed, 307 insertions(+), 16 deletions(-) create mode 100644 examples/hmcecs/cifar10.py create mode 100644 examples/hmcecs/data/data.npy create mode 100644 examples/hmcecs/protein.py diff --git a/examples/hmcecs/cifar10.py b/examples/hmcecs/cifar10.py new file mode 100644 index 000000000..5bf42edae --- /dev/null +++ b/examples/hmcecs/cifar10.py @@ -0,0 +1,152 @@ +import os +import pickle +import tarfile +from time import time +from urllib.request import urlretrieve + +import numpy as np +from flax import nn +from flax.nn.activation import selu, softmax +from jax import random, device_get + +import numpyro +import numpyro.distributions as dist +from numpyro.contrib.module import random_flax_module +from numpyro.infer import MCMC, NUTS, init_to_median + + +def cifar10(path=None): + r"""Return (train_images, train_labels, test_images, test_labels). + + Args: + path (str): Directory containing CIFAR-10. Default is + /home/USER/data/cifar10 or C:\Users\USER\data\cifar10. + Create if nonexistant. Download CIFAR-10 if missing. + + Returns: + Tuple of (train_images, train_labels, test_images, test_labels), each + a matrix. Rows are examples. Columns of images are pixel values, + with the order (red -> blue -> green). Columns of labels are a + onehot encoding of the correct class. + """ + url = 'https://www.cs.toronto.edu/~kriz/' + tar = 'cifar-10-binary.tar.gz' + files = ['cifar-10-batches-bin/data_batch_1.bin', + 'cifar-10-batches-bin/data_batch_2.bin', + 'cifar-10-batches-bin/data_batch_3.bin', + 'cifar-10-batches-bin/data_batch_4.bin', + 'cifar-10-batches-bin/data_batch_5.bin', + 'cifar-10-batches-bin/test_batch.bin'] + + if path is None: + # Set path to /home/USER/data/mnist or C:\Users\USER\data\mnist + path = os.path.join(os.path.expanduser('~'), 'data', 'cifar10') + + # Create path if it doesn't exist + os.makedirs(path, exist_ok=True) + + # Download tarfile if missing + if tar not in os.listdir(path): + urlretrieve(''.join((url, tar)), os.path.join(path, tar)) + print("Downloaded %s to %s" % (tar, path)) + + # Load data from tarfile + with tarfile.open(os.path.join(path, tar)) as tar_object: + # Each file contains 10,000 color images and 10,000 labels + fsize = 10000 * (32 * 32 * 3) + 10000 + + # There are 6 files (5 train and 1 test) + buffr = np.zeros(fsize * 6, dtype='uint8') + + # Get members of tar corresponding to data files + # -- The tar contains README's and other extraneous stuff + members = [file for file in tar_object if file.name in files] + + # Sort those members by name + # -- Ensures we load train data in the proper order + # -- Ensures that test data is the last file in the list + members.sort(key=lambda member: member.name) + + # Extract data from members + for i, member in enumerate(members): + # Get member as a file object + f = tar_object.extractfile(member) + # Read bytes from that file object into buffr + buffr[i * fsize:(i + 1) * fsize] = np.frombuffer(f.read(), 'B') + + # Parse data from buffer + # -- Examples are in chunks of 3,073 bytes + # -- First byte of each chunk is the label + # -- Next 32 * 32 * 3 = 3,072 bytes are its corresponding image + + # Labels are the first byte of every chunk + labels = buffr[::3073] + + # Pixels are everything remaining after we delete the labels + pixels = np.delete(buffr, np.arange(0, buffr.size, 3073)) + images = pixels.reshape((-1, 32, 32, 3)).astype('float32') / 255 + + # Split into train and test + train_images, test_images = images[:50000], images[50000:] + train_labels, test_labels = labels[:50000], labels[50000:] + + return train_images, train_labels, test_images, test_labels + + +def summary(dataset, name, mcmc, sample_time, svi_time=0., plates={}): + n_eff_mean = np.mean([numpyro.diagnostics.effective_sample_size(device_get(v)) + for k, v in mcmc.get_samples(True).items() if k not in plates]) + pickle.dump(mcmc.get_samples(True), open(f'{dataset}/{name}_posterior_samples.pkl', 'wb')) + step_field = 'num_steps' if name == 'hmc' else 'hmc_state.num_steps' + num_step = np.sum(mcmc.get_extra_fields()[step_field]) + accpt_prob = 1. + + with open(f'{dataset}/{name}_chain_stats.txt', 'w') as f: + print('sample_time', 'svi_time', 'n_eff_mean', 'gibbs_accpt_prob', 'tot_num_steps', 'time_per_step', + 'time_per_eff', + sep=',', file=f) + print(sample_time, svi_time, n_eff_mean, accpt_prob, num_step, sample_time / num_step, sample_time / n_eff_mean, + sep=',', file=f) + + +class Network(nn.Module): + """ Scaling Hamiltonian Monte Carlo Inference for Bayesian Neural Networks with Symmetric Splitting + Adam D. Cobb, Brian Jalaian (2020) """ + + def apply(self, x, out_channels): + c1 = selu(nn.Conv(x, features=6, kernel_size=(4, 4))) + max1 = nn.max_pool(c1, window_shape=(2, 2)) + c2 = nn.activation.selu(nn.Conv(max1, features=16, kernel_size=(4, 4))) + max2 = nn.max_pool(c2, window_shape=(2, 2)) + l1 = selu(nn.Dense(max2.reshape(x.shape[0], -1), features=400)) + l2 = selu(nn.Dense(l1, features=120)) + l3 = selu(nn.Dense(l2, features=84)) + l4 = softmax(nn.Dense(l3, features=out_channels)) + return l4 + + +def model(data, obs): + module = Network.partial(out_channels=10) + net = random_flax_module('conv_nn', module, dist.Normal(0, 1.), input_shape=data.shape) + + if obs is not None: + obs = obs[..., None] + numpyro.sample('obs', dist.Categorical(logits=net(data)), obs=obs) + + +def hmc(dataset, data, obs): + kernel = NUTS(model, init_strategy=init_to_median) + mcmc = MCMC(kernel, 100, 100) + mcmc._compile(random.PRNGKey(0), data, obs, extra_fields=("num_steps",)) + start = time() + mcmc.run(random.PRNGKey(0), data, obs, extra_fields=('num_steps',)) + summary(dataset, 'hmc', mcmc, time() - start) + + +def main(): + train_data, train_labels, test_data, test_labels = cifar10() + hmc('cifar10', train_data[:1000], train_labels[:1000]) + + +if __name__ == '__main__': + main() diff --git a/examples/hmcecs/data/data.npy b/examples/hmcecs/data/data.npy new file mode 100644 index 0000000000000000000000000000000000000000..09fd6d42cb58c97323c3de77c67a6752b5731027 GIT binary patch literal 6528 zcmbVQ`9n^Qk3LWH8v^h~}vM8vMAbQ0xIcr465sO~On^U_QeAlh3gGM{8ZGN$mNjvNN89-I=Cp>D}45?YDWGC1swa2lW~?m;+9-t*rF_hS zFe}!kYxAV=tTx#xbef7?PtUBPp|U+DaqV++im{0!uRfxbw-Y|!n{CY6r2kqPd0G7Z zuP6)iOD=lUgfi+v?hVwh0%lzB*Vbm>&&w&W%*694n}{Z~w(sEgZDkfDGfueZN@2}- z+s}JdhOA8olxx86=+HrqT8x*B>y$(B{6{Z0e&zTK4OBr%U0JIBQ&Wna-)UWs>>OO8 z%(vIG=h0q&l<+i5;^Jdq#EuL2xech=dY-Ruiz&(4T2mS1rMq5r6TD~V^Oy5fkSg3& zanO$2&v1Pcvb?nDNXL3E9@Cikbx9iyhz5@T0;v+-)f&CkTZH`;=$4AmwnW|b4O==8 zvmZ&h%YiCWU;AfnHugU$Q2}llW)XiU;Q7KG!4eebu1r2si}gnx-y{adZ+&9txM93R ztzkW?ecS&-Vo!{hS5_ehx#mGi>iTx<-|gYe5ZGBXYqvkfE4qXE{dLn^(w62(uCS&Qh#(3ujhO7 z8j#=mU#;hl7_#g67rr*2M`=Nem;dDWOZ_Q5&GUIf%t`;7RjM=0*fX8XW z0x4uhEtqEa1g{tFXr+eo6Xt%NkZO)2J9?d5YUa|LEBG} z!=S?d-XVAFA9>J$MyGn}m+iiT^*}rqNuhoF)M;<~7P0;iLQ^F)6J=?4KUZRYdX7|r z@n6z!w_UM+RMGcF)IU40zsM2uAqWvSp(~s95AbTMm_Iv15Nd3eU=s(7qd!So~vWM&4D7XxrSiHZ**)UxRO{qv!i5301E2DD4M z-91GAviK;{Nd~?zUn-_2;`#LD;96ACy|g1Rxs4qsVt3O(IE0T*ugCR;R*#j#8{$=Y z-hRv<9eQ1gEZp*4bCa<@^lyC$@=4o0HYDg3#yk3>0wRw!)y|!a`RDsQl!0MgR{OV1 zPHz`))gvFfWsgmZa6KXG_sAf5L_*I9Ev{dJf?zqEFupVGPb!zsC!-n>?_}Lp^{mHu z|JGGZzV*2F^1?MdpSNwC3Si$oW#mdO{!<*ZkjI~Xp!3Hkc0RvDpYi{y=Toz6j8_ok zEkcKn&NR7p)R1Jx1to`?Vg9Ea&yMRDAOH2uM&u`|sb2HloMOg_(%O17ty(P44&wA3 zO43MGEZFv|8Ly{Kc4%O6x7iWr$+%weCh5zd!|0Nade)NzC zDV&%Xb7#wr2Z*tGRzY$|I&E{OvJjt7v~E(s&8SE1&0$ziL|}jzof!DkYwJA3uIGu* z%HfyJn_0(rIKOE{u^4T9n0Q{SFJ<>5E;02r*MIqmiF4)be7>&feTTV~0ue7i%D9>*CAsl}}1(TnFj5?}^~_9iO9sJ7tIbrpPfKUg=IX@IF3^dG5A{jgOF8 ztN@caVSQZcYBBz;zEY?kZ=NnZhu71AZxo=NY<1{oFpf8`&b<+N3>rGk;w09)z-qe~ zb*gvTr-ztP?0&;1O3~H9v8pFK%vqZtIvUX9AcL+>Q#{V^X-1>!I%Vasiw*2~bebF_po|%*@VPI>@X$ND>m+-Kj7}?wzQ6r1v`h~N%<&aOT zwYchz&nL$DHp7SATVk9Ra`@^GGkQF0vdu$`<3TJpSHYzvJMY*=b9g7_sUh*~WA#Mf z#=DZ$V5VxByjb#t#YYZpQ$T5d#lEz7tPeqfgA}z-(p@@rRRcTD4?5lg88at)HO$5J zma@H}fWdAK!;UV%ds(FB|-ijKem**;P~(y_BNu1b-yO6i~* z@KY%)D}8zpgqUw$mbDZjw(fUv{8Y{26`m!Hvif`$u8r$o{E7!NN;tA^&N9m%u9alH6>tx`yTK4Pa$q!RP@b}kKi?ISP5 zoW=b*e}1tX+(I)V{(Xn_ARPa`0i`yMnYtww?zhDuP7aP^A6|Bz z!s(~fN`xG8s}I@q#`Tljy1wG_I^e9W&YR=&wHAV`5y*}OL<}btg z_?V>vji}7PJP^lA(8#AzX#7&odQa}Yx&ONnEh=nHyL{7}V)&;P9;8v~>NQ^8J1}2@ zpa=!LU(j#tj8UDee}Z||VpvZ*BqiVA^!O}_M&EyCMeLZ!)nm`4QrJ@J;oLNWtDpNF zG$3rr>1?;d`k+*0O|aOTx8h?PzR##{2`UJklC#t4D_+mHh?hdYbCxDnkMaFO$_N#> zv_-C|u*Ul2o5jn(dvufQ_k7GBWp_>m%4@oZ6R&XV3I9eAIMy%GAC|@XPZtEp;8B(6 zn6wz5FIb!_hfQsxEynG^dgE)fc*r~l*jpU>R)}a?*lu>C7UwU~=qy1GY&BWKR$_cK>{G$q)%%(%!#OUV%c6Xej&v^8*e zq#IixYhX=K0M|bfQOBjQn>V>?i5K>tT(gpf&*k|BeQ)6U!4FEIQJ?&OBlSBudrLnW z?Hqo3uk}k@|M}*XA~a}p`J3Dxm~Vp5;2qrU*FL7!kYwDWmlSKEI@Rmk!nvGYWJ~H% zwH5X4uWfjoy1!iw!ZUMAYKuOzd{CriGbA1gdtiML`_DIUXh6C7Pk!bm;Borxfo2fS z4*K(|BOfvI$-kNR+TenXO%>m8{~%0#L!%QD2M>4r^_Csy)9Ym5QSqXCb}~MX-*H3% zg%*Rf-`8P2_|bdSKz2+J_EvEBgYb$Bx~l@lwynZ^3S5}}#yVo@I~@g%53eFp1&ck! zs)cgyeY<)wlOMzFen;tIe#mXQ^=Qh^e%|LMWBvuB8U80!Job+Kjq8zM^A!mMH{NW! zEx~#f%(&VJ#mRajhTh@wYrH}UIy-uwajnJg8M0qrCD?rk8>zPu_vb>_v}VX#t~7V~ z2jdfN&!Zv8ZEc_EM(mH^U9<)Y2gIM6$;0^grP(qF4?S??ejtZ;{w*0~ALtc*Y+p0W zx8S~$67<8%%Uo61KmK!NGsx^}r|+M^<;&mOnYw?%bwRQ5CU(AX@K70i?DKlL^)c?f zkmRp`1v7{~<=2kClWV&g+y zpQV5?WxXeLrDFWlyAdrAXnW+*thYT*A5{ysqr zDUv~Y>&vddbUdFXvc%||>*c)c|Ld0nd?fHs^BKnzAF%$2zKLSA_2rx&-^;Nrh^uA3 zX9TndOgx403d=7wLtX2k+pjb@ex%1&DdZ`qwC2~8vHp`|2C3jg!o?Q9|8TwI8yuA& zBg6TTelZw7@rmg-FUqQh)vV$C|Ncn@SFFzD_%Gx7+ZDU&(BidUY=*|;dM=nbNC7X} z?gj;icd+{l&d=6B&-eXK8b#s$g!W65K}EWH?n?sKBfiBmCAhrL@KyeJhnV@qfz=Yo z_S(AeFvmAFBojW;9VCs#lecx+nzAy-dI+j z7W9S7@4Z`P$dBLW-G+JGe9weNWcBRs5}6OaUueHP1@Mlk3t#re^@6fp&-4qGl0)@X zIKN24G%b`~&+Bb5yb9xE-y1!Huhe}Vi}xcfMIuyEIos@c65fwk^+y3+Hv$Sr{o>E6KZFVM$~@x{YZ0998cjZ zT8yOQKCU}{6YF2-sUt>meS`H>1@6b_f(9krcDYj%#r3m-966H@kB;SkIF9wn8$MqF zzCYr%kKW*ZiT5E>1s&rr2a*Li--R11nD22lUHOU(j9(D z`diV$kvKlo9~-8>yU}Mt75Cm26dY?prEkX0b6JPKZwPj^&`{;O=e>6V)(hWxEW`IQ zr@=aVupa5x^EF@#ib{!oAspcF%2@-T7s(jEaXRaPyd`CmOB|lw~ zk8RrGxCGF+Awk#HV?M|N9*s!pO6Mo3xc*S1bM+%~ss!ba8ofNtn$x4(N)3qa9n9P(;^L9Se4m*b zQCVc0-^St>L@sCm!ZG*>)y>^sn|u{evCQDF?SpW>2v4QT;B)K#QQclRU#NClDJssp z`?A9c;}tGYG2fTo4Ga;-;QNh2CuPv5%liDS9PA&dzE+P)QU@+J(B5xp>Hjp zw+H6~wR52u^%pLURh4n|{pt)6(u#{74C{gQD^Mt;uxNG4t(yTj{=8Rw8hi_T8upFD z^<8Mf+($S5JTUUgbU8_C*2B+3hAoiiVvp$S0Wl&R?!O bzqXkjCs)suLb0FO**E)ff5Wq2?u-8c@H$K_ literal 0 HcmV?d00001 diff --git a/examples/hmcecs/logistic_regression.py b/examples/hmcecs/logistic_regression.py index ae1a48232..fd6940114 100644 --- a/examples/hmcecs/logistic_regression.py +++ b/examples/hmcecs/logistic_regression.py @@ -14,7 +14,7 @@ import numpyro.distributions as dist from numpyro.distributions import constraints from numpyro.examples.datasets import _load_higgs -from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_sample +from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_sample, init_to_median from numpyro.infer.hmc_gibbs import HMCECS, difference_estimator, variational_proxy, taylor_proxy from numpyro.infer.util import _predictive @@ -23,15 +23,13 @@ numpyro.set_platform("cpu") -def summary(dataset, name, mcmc, sample_time, svi_time=0.): +def summary(dataset, name, mcmc, sample_time, svi_time=0., plates={}): n_eff_mean = np.mean([numpyro.diagnostics.effective_sample_size(device_get(v)) - for v in mcmc.get_samples(True).values()]) + for k, v in mcmc.get_samples(True).items() if k not in plates]) pickle.dump(mcmc.get_samples(True), open(f'{dataset}/{name}_posterior_samples.pkl', 'wb')) step_field = 'num_steps' if name == 'hmc' else 'hmc_state.num_steps' num_step = np.sum(mcmc.get_extra_fields()[step_field]) accpt_prob = 1. - if name == 'ecs': - accpt_prob = np.mean(mcmc.get_extra_fields()['accept_prob']) with open(f'{dataset}/{name}_chain_stats.txt', 'w') as f: print('sample_time', 'svi_time', 'n_eff_mean', 'gibbs_accpt_prob', 'tot_num_steps', 'time_per_step', @@ -80,7 +78,7 @@ def guide(feature, obs, subsample_size): numpyro.sample('theta', dist.continuous.Normal(mean, .5)) -def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='variational'): +def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='taylor'): model_args, model_kwargs = (data, obs, subsample_size), {} svi_key, proxy_key, estimator_key, mcmc_key = random.split(random.PRNGKey(0), 4) @@ -110,7 +108,7 @@ def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='variational'): start = time() mcmc.run(random.PRNGKey(3), data, obs, subsample_size, extra_fields=("hmc_state.accept_prob", "hmc_state.num_steps")) - summary(dataset, 'ecs', mcmc, time() - start, svi_time=svi_time) + summary(dataset, 'ecs', mcmc, time() - start, svi_time=svi_time, plates={'N': ''}) def plain_log_reg_model(features, obs): @@ -120,8 +118,8 @@ def plain_log_reg_model(features, obs): def hmc(dataset, data, obs): - kernel = NUTS(plain_log_reg_model, init_strategy=init_to_sample) - mcmc = MCMC(kernel, 100, 200) + kernel = NUTS(plain_log_reg_model,trajectory_length=1.2, init_strategy=init_to_median) + mcmc = MCMC(kernel, 100, 100) mcmc._compile(random.PRNGKey(0), data, obs, extra_fields=("num_steps",)) start = time() mcmc.run(random.PRNGKey(0), data, obs, extra_fields=('num_steps',)) @@ -130,8 +128,8 @@ def hmc(dataset, data, obs): if __name__ == '__main__': - load_data = {'breast': breast_cancer_data, 'higgs': higgs_data, 'copsac': copsac_data} - subsample_sizes = {'breast': 75, 'higgs': 1300, 'copsac': 1000} + load_data = {'higgs': higgs_data, 'breast': breast_cancer_data, 'copsac': copsac_data} + subsample_sizes = {'higgs': 1300, 'copsac': 1000, 'breast': 75, } data, obs = breast_cancer_data() for dataset in load_data.keys(): @@ -139,6 +137,6 @@ def hmc(dataset, data, obs): if not os.path.exists(dir): os.mkdir(dir) data, obs = load_data[dataset]() - hmcecs_model(dir, data, obs, subsample_sizes[dataset]) + # hmcecs_model(dir, data, obs, subsample_sizes[dataset]) hmc(dir, data, obs) exit() diff --git a/examples/hmcecs/protein.py b/examples/hmcecs/protein.py new file mode 100644 index 000000000..b3dc3af7c --- /dev/null +++ b/examples/hmcecs/protein.py @@ -0,0 +1,139 @@ +from pathlib import Path + +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from flax import nn +from flax.nn.activation import tanh +from jax import random, vmap + +import numpyro +import numpyro.distributions as dist +from numpyro import handlers +from numpyro.contrib.module import random_flax_module +from numpyro.infer import MCMC, NUTS, init_to_median + +uci_base_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/' + + +def visualize(train_data, train_obs, test_data, predictions): + fs = 16 + + m = predictions.mean(0) + s = predictions.std(0) + # s_al = (pred_list[200:].var(0).to('cpu') + tau_out ** -1) ** 0.5 + + f, ax = plt.subplots(1, 1, figsize=(8, 4)) + + # Get upper and lower confidence bounds + lower, upper = (m - s * 2).flatten(), (m + s * 2).flatten() + # + aleotoric + # lower_al, upper_al = (m - s_al*2).flatten(), (m + s_al*2).flatten() + + # Plot training data as black stars + ax.plot(train_data, train_obs, 'k*', rasterized=True) + # Plot predictive means as blue line + ax.plot(test_data, m, 'b', rasterized=True) + # Shade between the lower and upper confidence bounds + ax.fill_between(test_data, lower, upper, alpha=0.5, rasterized=True) + # ax.fill_between(X_test.flatten().numpy(), lower_al.numpy(), upper_al.numpy(), alpha=0.2, rasterized=True) + ax.set_ylim([-2, 2]) + ax.set_xlim([-2, 2]) + plt.grid() + ax.legend(['Observed Data', 'Mean', 'Epistemic'], fontsize=fs) + ax.tick_params(axis='both', which='major', labelsize=14) + ax.tick_params(axis='both', which='minor', labelsize=14) + + bbox = {'facecolor': 'white', 'alpha': 0.8, 'pad': 1, 'boxstyle': 'round', 'edgecolor': 'black'} + + plt.tight_layout() + # plt.savefig('plots/full_hmc.pdf', rasterized=True) + + plt.show() + + +def load_agw_1d(get_feats=False): + def features(x): + return np.hstack([x[:, None] / 2.0, (x[:, None] / 2.0) ** 2]) + + data = np.load(str(Path(__file__).parent / 'hmcecs' / 'data' / 'data.npy')) + x, y = data[:, 0], data[:, 1] + y = y[:, None] + f = features(x) + + x_means, x_stds = x.mean(axis=0), x.std(axis=0) + y_means, y_stds = y.mean(axis=0), y.std(axis=0) + f_means, f_stds = f.mean(axis=0), f.std(axis=0) + + X = ((x - x_means) / x_stds).astype(np.float32) + Y = ((y - y_means) / y_stds).astype(np.float32) + F = ((f - f_means) / f_stds).astype(np.float32) + + if get_feats: + return F, Y + + return X[:, None], Y + + +def protein(): + # from hughsalimbeni/bayesian_benchmarks + # N, D, name = 45730, 9, 'protein' + url = uci_base_url + '00265/CASP.csv' + + data = pd.read_csv(url).values + return data[:, 1:], data[:, 0].reshape(-1, 1) + + +class Network(nn.Module): + def apply(self, x, out_channels): + l1 = tanh(nn.Dense(x, features=100)) + l2 = tanh(nn.Dense(l1, features=100)) + l3 = tanh(nn.Dense(l2, features=100)) + means = nn.Dense(l3, features=out_channels) + return means + + +def model(data, obs=None): + module = Network.partial(out_channels=1) + net = random_flax_module('fnn', module, dist.Normal(0, 1.), input_shape=data.shape) + + if obs is not None: + obs = obs[..., None] + + prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0)) + sigma_obs = 1.0 / jnp.sqrt(prec_obs) # prior + + numpyro.sample('obs', dist.Normal(net(data), sigma_obs), obs=obs) + + +def hmc(dataset, data, obs, warmup, num_sample): + kernel = NUTS(model, init_strategy=init_to_median) + mcmc = MCMC(kernel, warmup, num_sample) + mcmc._compile(random.PRNGKey(0), data, obs, extra_fields=("num_steps",)) + mcmc.run(random.PRNGKey(0), data, obs, extra_fields=('num_steps',)) + return mcmc.get_samples() + + +# helper function for prediction +def predict(model, rng_key, samples, *args, **kwargs): + model = handlers.substitute(handlers.seed(model, rng_key), samples) + # note that Y will be sampled in the model because we pass Y=None here + model_trace = handlers.trace(model).get_trace(*args, **kwargs) + return model_trace['obs']['value'] + + +def main(): + data, obs = load_agw_1d() + warmup = 100 + num_samples = 100 + test_data = np.linspace(-2, 2, 500).reshape(-1, 1) + samples = hmc('protein', data, obs, warmup, num_samples) + vmap_args = (samples, random.split(random.PRNGKey(1), num_samples)) + predictions = vmap(lambda samples, rng_key: predict(model, rng_key, samples, test_data))(*vmap_args) + predictions = predictions[..., 0] + visualize(data, obs, np.squeeze(test_data), predictions) + + +if __name__ == '__main__': + main() diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index dfa2a64e9..7a3aaddb9 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -570,6 +570,9 @@ def taylor_proxy(rng_key, model, model_args, model_kwargs, reference_params, usi for name, site in prototype_trace.items() if site["type"] == "plate" and site["args"][0] > site["args"][1] # i.e. size > subsample_size } + + reference_params = {k:v for k,v in reference_params.items() if k in prototype_trace} + # subsample_plate_sizes: name -> (size, subsample_size) ref_params_flat, unravel_fn = ravel_pytree(reference_params) @@ -687,11 +690,11 @@ def log_posterior(params): def log_prior(params): with numpyro.primitives.inner_stack(): - prior_prob, _ = log_density(block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), - model_args, model_kwargs, params) + prior_prob, _ = log_density( + block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), + model_args, model_kwargs, params) return prior_prob - # TODO: get MAP from guide! posterior_samples = _predictive(pos_key, guide, {}, (num_samples,), return_sites='', parallel=True, model_args=model_args, model_kwargs=model_kwargs) log_likelihood_ref = log_likelihood(posterior_samples) @@ -723,7 +726,6 @@ def proxy_fn(params, subsample_indices): return proxy_fn - class estimate_likelihood(numpyro.primitives.Messenger): def __init__(self, fn=None, estimator=None): # estimate_likelihood: accept likelihood tuple (fn, value, subsample_name, subsample_dim, subsample_idx) From 2151895c2a1413069313684bf54323392d755086 Mon Sep 17 00:00:00 2001 From: Ola Date: Fri, 29 Jan 2021 19:20:15 +0100 Subject: [PATCH 54/93] Moved estimate_likelihood --- numpyro/handlers.py | 58 ----------------------------- numpyro/infer/hmc_gibbs.py | 75 ++++++++++++++++++++++++++++++++++---- 2 files changed, 68 insertions(+), 65 deletions(-) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 1bcda0625..d69cf4f48 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -85,7 +85,6 @@ import numpyro from numpyro.distributions.distribution import COERCIONS -from numpyro.infer.util import _unconstrain_reparam from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack, plate from numpyro.util import not_jax_tracer @@ -800,60 +799,3 @@ def process_message(self, msg): msg['stop'] = True -class estimate_likelihood(numpyro.primitives.Messenger): - def __init__(self, fn=None, estimator=None): - # estimate_likelihood: accept likelihood tuple (fn, value, subsample_name, subsample_dim, subsample_idx) - # and current unconstrained params - # and returns log of the bias-corrected likelihood - assert estimator is not None - super().__init__(fn) - self.estimator = estimator - self.params = None - self.likelihoods = {} - self.subsample_plates = {} - - def __enter__(self): - # trace(substitute(substitute(control_variate(model), unconstrained_reparam))) - for handler in numpyro.primitives._PYRO_STACK[::-1]: - if isinstance(handler, substitute) and isinstance(handler.substitute_fn, partial) \ - and handler.substitute_fn.func is _unconstrain_reparam: - self.params = handler.substitute_fn.args[0] - break - return super().__enter__() - - def __exit__(self, exc_type, exc_value, traceback): - # make sure exit trackback is nice if an error happens - super().__exit__(exc_type, exc_value, traceback) - if exc_type is not None: - return - - if self.params is None: - return - - # add numpyro.factor; ideally, we will want to skip this computation when making prediction - # see: https://github.com/pyro-ppl/pyro/issues/2744 - numpyro.factor("_biased_corrected_log_likelihood", self.estimator(self.likelihoods, self.params)) - - # clean up - self.params = None - self.likelihoods = {} - self.subsample_plates = {} - - def process_message(self, msg): - if self.params is None: - return - - if msg["type"] == "sample" and msg["is_observed"]: - assert msg["name"] not in self.params - # store the likelihood for the estimator - for frame in msg["cond_indep_stack"]: - if frame.name in self.subsample_plates: - if msg["name"] in self.likelihoods: - raise RuntimeError(f"Multiple subsample plates at site {msg['name']} " - "are not allowed. Please reshape your data.") - subsample_idx = self.subsample_plates[frame.name] - self.likelihoods[msg["name"]] = (msg["fn"], msg["value"], frame.name, frame.dim, subsample_idx) - # mask the current likelihood - msg["fn"] = msg["fn"].mask(False) - elif msg["type"] == "plate" and msg["args"][0] > msg["args"][1]: - self.subsample_plates[msg["name"]] = msg["value"] diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index f81337b70..95c70c8aa 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -10,10 +10,11 @@ from jax import device_put, jacfwd, jacobian, grad, hessian, ops, random, value_and_grad from jax.scipy.special import expit -from numpyro.handlers import block, condition, seed, substitute, trace, estimate_likelihood +import numpyro +from numpyro.handlers import block, condition, seed, substitute, trace, Messenger from numpyro.infer.hmc import HMC from numpyro.infer.mcmc import MCMCKernel -from numpyro.infer.util import _unconstrain_reparam +from numpyro.infer.util import _unconstrain_reparam, _predictive, log_density from numpyro.util import cond, fori_loop, identity, ravel_pytree HMCGibbsState = namedtuple("HMCGibbsState", "z, hmc_state, rng_key") @@ -564,6 +565,7 @@ def estimator(likelihoods, params): def taylor_proxy(rng_key, model, model_args, model_kwargs, reference_params, using_lookup=False): prototype_trace = trace(seed(model, rng_key)).get_trace(*model_args, **model_kwargs) + reference_params = {k: v for k, v in reference_params.items() if k in prototype_trace} subsample_plate_sizes = { name: site["args"] for name, site in prototype_trace.items() @@ -647,16 +649,16 @@ def _sum_all_except_at_dim(x, dim): return x.reshape(x.shape[:1] + (-1,)).sum(-1) -def variational_proxy(rng_key, model, model_args, model_kwargs, subsample_plate_sizes, reference_params, using_lookup=False): +def variational_proxy(rng_key, model, model_args, model_kwargs, guide, reference_params): pos_key, guide_key, rng_key = random.split(rng_key, 3) num_samples = 10 # TODO: heuristic for this - guide = substitute(self._guide, self._ref) + guide = substitute(guide, reference_params) posterior_samples = _predictive(pos_key, guide, {}, (num_samples,), return_sites='', parallel=True, model_args=model_args, model_kwargs=model_kwargs) - with subsample_size(self.model, plate_sizes_all): - model = subsample_size(self.model, plate_sizes_all) - ll = log_likelihood(model, posterior_samples, *model_args, **model_kwargs) + + model = subsample_size(self.model, plate_sizes_all) + ll = log_likelihood(model, posterior_samples, *model_args, **model_kwargs) # TODO: fix multiple likehoods weights = {name: jnp.mean((value.T / value.sum(1).T).T, 0) for name, value in @@ -669,3 +671,62 @@ def variational_proxy(rng_key, model, model_args, model_kwargs, subsample_plate_ guide_trace = trace(seed(self._guide, guide_key)).get_trace(*model_args, **model_kwargs) proxy_fn, uproxy_fn = variational_proxy(guide_trace, evidence, weights) + + +class estimate_likelihood(Messenger): + def __init__(self, fn=None, estimator=None): + # estimate_likelihood: accept likelihood tuple (fn, value, subsample_name, subsample_dim, subsample_idx) + # and current unconstrained params + # and returns log of the bias-corrected likelihood + assert estimator is not None + super().__init__(fn) + self.estimator = estimator + self.params = None + self.likelihoods = {} + self.subsample_plates = {} + + def __enter__(self): + # trace(substitute(substitute(control_variate(model), unconstrained_reparam))) + for handler in numpyro.primitives._PYRO_STACK[::-1]: + if isinstance(handler, substitute) and isinstance(handler.substitute_fn, partial) \ + and handler.substitute_fn.func is _unconstrain_reparam: + self.params = handler.substitute_fn.args[0] + break + return super().__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + # make sure exit trackback is nice if an error happens + super().__exit__(exc_type, exc_value, traceback) + if exc_type is not None: + return + + if self.params is None: + return + + # add numpyro.factor; ideally, we will want to skip this computation when making prediction + # see: https://github.com/pyro-ppl/pyro/issues/2744 + numpyro.factor("_biased_corrected_log_likelihood", self.estimator(self.likelihoods, self.params)) + + # clean up + self.params = None + self.likelihoods = {} + self.subsample_plates = {} + + def process_message(self, msg): + if self.params is None: + return + + if msg["type"] == "sample" and msg["is_observed"]: + assert msg["name"] not in self.params + # store the likelihood for the estimator + for frame in msg["cond_indep_stack"]: + if frame.name in self.subsample_plates: + if msg["name"] in self.likelihoods: + raise RuntimeError(f"Multiple subsample plates at site {msg['name']} " + "are not allowed. Please reshape your data.") + subsample_idx = self.subsample_plates[frame.name] + self.likelihoods[msg["name"]] = (msg["fn"], msg["value"], frame.name, frame.dim, subsample_idx) + # mask the current likelihood + msg["fn"] = msg["fn"].mask(False) + elif msg["type"] == "plate" and msg["args"][0] > msg["args"][1]: + self.subsample_plates[msg["name"]] = msg["value"] From e46cb40e94c83580e75a49d35bd0ea9eb8af7321 Mon Sep 17 00:00:00 2001 From: ola Date: Sun, 31 Jan 2021 16:41:52 +0100 Subject: [PATCH 55/93] Added two moons --- examples/hmcecs/lda.py | 60 ------ examples/hmcecs/logistic_regression.py | 8 +- examples/hmcecs/mnist_bnn.py | 180 ------------------ examples/hmcecs/{protein.py => regression.py} | 32 ++-- examples/hmcecs/two_moons.py | 85 +++++++++ 5 files changed, 107 insertions(+), 258 deletions(-) delete mode 100644 examples/hmcecs/lda.py delete mode 100644 examples/hmcecs/mnist_bnn.py rename examples/hmcecs/{protein.py => regression.py} (83%) create mode 100644 examples/hmcecs/two_moons.py diff --git a/examples/hmcecs/lda.py b/examples/hmcecs/lda.py deleted file mode 100644 index 25a642c4f..000000000 --- a/examples/hmcecs/lda.py +++ /dev/null @@ -1,60 +0,0 @@ -import sys - -from jax.experimental import stax -from sklearn.datasets import fetch_20newsgroups -from sklearn.feature_extraction.text import CountVectorizer - -import jax -import jax.numpy as jnp -from sklearn.utils import shuffle - -import numpyro -import numpyro.distributions as dist - -import numpy as np - -from numpyro.contrib.indexing import Vindex - - -def lda(doc_words, lengths, num_topics=20, num_words=100, num_max_elements=10, - num_hidden=100): - num_docs = doc_words.shape[0] - topic_word_probs = numpyro.sample('topic_word_probs', - dist.Dirichlet(jnp.ones((num_topics, num_words)) / num_words).to_event(1)) + 1e-7 - element_plate = numpyro.plate('words', num_max_elements, dim=-1) - with numpyro.plate('documents', num_docs, dim=-2): - document_topic_probs = numpyro.sample('topic_probs', dist.Dirichlet(jnp.ones(num_topics) / num_topics)) - with element_plate: - word_topic = numpyro.sample('word_topic', dist.Categorical(document_topic_probs)) - numpyro.sample('word', dist.Categorical(Vindex(topic_word_probs)[word_topic]), obs=doc_words) - - -def lda_guide(doc_words, lengths, num_topics=20, num_words=100, num_max_elements=10, - num_hidden=100): - num_docs = doc_words.shape[0] - topic_word_probs_val = numpyro.param('topic_word_probs_val', jnp.ones((num_topics, num_words)), - constraint=dist.constraints.simplex) - _topic_word_probs = numpyro.sample('topic_word_probs', dist.Delta(topic_word_probs_val).to_event(1)) - amortize_nn = numpyro.module('amortize_nn', stax.serial( - stax.Dense(num_hidden), - stax.Relu, - stax.Dense(num_topics), - stax.Softmax - ), (num_docs, num_max_elements)) - document_topic_probs_vals = amortize_nn(doc_words)[..., None, :] + 1e-7 - _document_topic_probs = numpyro.sample('topic_probs', dist.Delta(document_topic_probs_vals)) - - -def main(_argv): - newsgroups = fetch_20newsgroups()['data'] - num_words = 300 - count_vectorizer = CountVectorizer(max_df=.95, min_df=.01, - token_pattern=r'(?u)\b[^\d\W]\w+\b', - max_features=num_words, - stop_words='english') - newsgroups_docs = count_vectorizer.fit_transform(newsgroups) - rng_key = jax.random.PRNGKey(37) - - -if __name__ == '__main__': - main(sys.argv) diff --git a/examples/hmcecs/logistic_regression.py b/examples/hmcecs/logistic_regression.py index fd6940114..2e20580f4 100644 --- a/examples/hmcecs/logistic_regression.py +++ b/examples/hmcecs/logistic_regression.py @@ -14,13 +14,13 @@ import numpyro.distributions as dist from numpyro.distributions import constraints from numpyro.examples.datasets import _load_higgs -from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_sample, init_to_median +from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_median from numpyro.infer.hmc_gibbs import HMCECS, difference_estimator, variational_proxy, taylor_proxy from numpyro.infer.util import _predictive os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" -numpyro.set_platform("cpu") +numpyro.set_platform("gpu") def summary(dataset, name, mcmc, sample_time, svi_time=0., plates={}): @@ -78,7 +78,7 @@ def guide(feature, obs, subsample_size): numpyro.sample('theta', dist.continuous.Normal(mean, .5)) -def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='taylor'): +def hmcecs_model( dataset, data, obs, subsample_size, proxy_name='taylor'): model_args, model_kwargs = (data, obs, subsample_size), {} svi_key, proxy_key, estimator_key, mcmc_key = random.split(random.PRNGKey(0), 4) @@ -118,7 +118,7 @@ def plain_log_reg_model(features, obs): def hmc(dataset, data, obs): - kernel = NUTS(plain_log_reg_model,trajectory_length=1.2, init_strategy=init_to_median) + kernel = NUTS(plain_log_reg_model, trajectory_length=1.2, init_strategy=init_to_median) mcmc = MCMC(kernel, 100, 100) mcmc._compile(random.PRNGKey(0), data, obs, extra_fields=("num_steps",)) start = time() diff --git a/examples/hmcecs/mnist_bnn.py b/examples/hmcecs/mnist_bnn.py deleted file mode 100644 index 4f120757a..000000000 --- a/examples/hmcecs/mnist_bnn.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright Contributors to the Pyro project. -# SPDX-License-Identifier: Apache-2.0 - -""" -Example: Bayesian Neural Network -================================ - -We demonstrate how to use NUTS to do inference on a simple (small) -Bayesian neural network with two hidden layers. -""" - -import argparse -import time - -import jax.numpy as jnp -import jax.random as random -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -from flax import nn -from jax import vmap - -import numpyro -import numpyro.distributions as dist -from numpyro import handlers -from numpyro.contrib.module import random_flax_module -from numpyro.examples.datasets import load_dataset, MNIST -from numpyro.infer import MCMC, NUTS - -matplotlib.use('Agg') # noqa: E402 - - -class Network(nn.Module): - def apply(self, x, hid_channels, out_channels): - l1 = nn.relu(nn.Dense(x, features=hid_channels)) - l2 = nn.relu(nn.Dense(l1, features=hid_channels)) - logits = nn.Dense(l2, features=out_channels) - return logits - - -def mnist_model(features, hid_channels, obs=None): - module = Network.partial(hid_channels=hid_channels, out_channels=10) - net = random_flax_module('snn', module, dist.Normal(0, 1.), input_shape=features.shape) - if obs is not None: - obs = obs[..., None] - numpyro.sample('obs', dist.Categorical(logits=net(features)), obs=obs) - - -def mnist_data(split='train'): - mnist_init, mnist_batch = load_dataset(MNIST, split=split) - _, idxs = mnist_init() - X, Y = mnist_batch(0, idxs) - _, m, _ = X.shape - X = X.reshape(-1, m ** 2) - return X, Y - - -def mnist_main(args): - hid_channels = 32 - X, Y = mnist_data() - rng_key, rng_key_predict = random.split(random.PRNGKey(37)) - samples = run_inference(mnist_model, args, rng_key, X[:args.num_data], hid_channels, Y[:args.num_data]) - - # predict Y_test at inputs X_test - vmap_args = (samples, random.split(rng_key_predict, args.num_samples * args.num_chains)) - X, Y = mnist_data('test') - predictions = vmap(lambda samples, rng_key: predict(mnist_model, rng_key, samples, X[:100], hid_channels))( - *vmap_args) - predictions = predictions[..., 0] - - -class RegNetwork(nn.Module): - def apply(self, x, hid_channels, out_channels): - l1 = nn.tanh(nn.Dense(x, features=hid_channels)) - l2 = nn.tahn(nn.Dense(l1, features=hid_channels)) - mean = nn.Dense(l2, features=out_channels) - return mean - - -def reg_model(features, obs, hid_channels): - in_channels, out_channels = features.shape[1], 1 - module = Network.partial(hid_channels=hid_channels, out_channels=out_channels) - - net = random_flax_module('snn', module, dist.Normal(0, 1.), input_shape=()) - mean = net(features) - - # we put a prior on the observation noise - prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0)) - sigma_obs = 1.0 / jnp.sqrt(prec_obs) # prior - - numpyro.sample("Y", dist.Normal(mean, sigma_obs), obs=obs[..., None]) - - -# helper function for HMC inference -def run_inference(model, args, rng_key, X, Y, D_H): - start = time.time() - kernel = NUTS(model) - mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains) - mcmc.run(rng_key, X, Y, D_H) - mcmc.print_summary() - print('\nMCMC elapsed time:', time.time() - start) - return mcmc.get_samples() - - -# helper function for prediction -def predict(model, rng_key, samples, *args, **kwargs): - model = handlers.substitute(handlers.seed(model, rng_key), samples) - # note that Y will be sampled in the model because we pass Y=None here - model_trace = handlers.trace(model).get_trace(*args, **kwargs) - return model_trace['obs']['value'] - - -# create artificial regression dataset -def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500): - D_Y = 1 # create 1d outputs - np.random.seed(0) - X = jnp.linspace(-1, 1, N) - X = jnp.power(X[:, np.newaxis], jnp.arange(D_X)) - W = 0.5 * np.random.randn(D_X) - Y = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + X[:, 1], 2.0) * jnp.sin(4.0 * X[:, 1]) - Y += sigma_obs * np.random.randn(N) - Y = Y[:, np.newaxis] - Y -= jnp.mean(Y) - Y /= jnp.std(Y) - - assert X.shape == (N, D_X) - assert Y.shape == (N, D_Y) - - X_test = jnp.linspace(-1.3, 1.3, N_test) - X_test = jnp.power(X_test[:, np.newaxis], jnp.arange(D_X)) - - return X, Y, X_test - - -def main(args): - N, D_X, D_H = args.num_data, 3, args.num_hidden - X, Y, X_test = get_data(N=N, D_X=D_X) - - # do inference - rng_key, rng_key_predict = random.split(random.PRNGKey(0)) - samples = run_inference(reg_model, args, rng_key, X, Y, D_H) - - # predict Y_test at inputs X_test - vmap_args = (samples, random.split(rng_key_predict, args.num_samples * args.num_chains)) - predictions = vmap(lambda samples, rng_key: predict(reg_model, rng_key, samples, X_test, D_H))(*vmap_args) - predictions = predictions[..., 0] - - # compute mean prediction and confidence interval around median - mean_prediction = jnp.mean(predictions, axis=0) - percentiles = np.percentile(predictions, [5.0, 95.0], axis=0) - - # make plots - fig, ax = plt.subplots(1, 1) - - # plot training data - ax.plot(X[:, 1], Y[:, 0], 'kx') - # plot 90% confidence level of predictions - ax.fill_between(X_test[:, 1], percentiles[0, :], percentiles[1, :], color='lightblue') - # plot mean prediction - ax.plot(X_test[:, 1], mean_prediction, 'blue', ls='solid', lw=2.0) - ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI") - - plt.savefig('bnn_plot.pdf') - plt.tight_layout() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Bayesian neural network example") - parser.add_argument("-n", "--num-samples", nargs="?", default=20, type=int) - parser.add_argument("--num-warmup", nargs='?', default=10, type=int) - parser.add_argument("--num-chains", nargs='?', default=1, type=int) - parser.add_argument("--num-data", nargs='?', default=10000, type=int) - parser.add_argument("--num-hidden", nargs='?', default=5, type=int) - parser.add_argument("--device", default='gpu', type=str, help='use "cpu" or "gpu".') - args = parser.parse_args() - - numpyro.set_platform(args.device) - numpyro.set_host_device_count(args.num_chains) - - mnist_main(args) diff --git a/examples/hmcecs/protein.py b/examples/hmcecs/regression.py similarity index 83% rename from examples/hmcecs/protein.py rename to examples/hmcecs/regression.py index b3dc3af7c..b9c1770f8 100644 --- a/examples/hmcecs/protein.py +++ b/examples/hmcecs/regression.py @@ -5,14 +5,14 @@ import numpy as np import pandas as pd from flax import nn -from flax.nn.activation import tanh +from flax.nn.activation import relu, tanh from jax import random, vmap import numpyro import numpyro.distributions as dist from numpyro import handlers from numpyro.contrib.module import random_flax_module -from numpyro.infer import MCMC, NUTS, init_to_median +from numpyro.infer import MCMC, NUTS, init_to_sample uci_base_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/' @@ -57,7 +57,7 @@ def load_agw_1d(get_feats=False): def features(x): return np.hstack([x[:, None] / 2.0, (x[:, None] / 2.0) ** 2]) - data = np.load(str(Path(__file__).parent / 'hmcecs' / 'data' / 'data.npy')) + data = np.load(str(Path(__file__).parent / 'data' / 'data.npy')) x, y = data[:, 0], data[:, 1] y = y[:, None] f = features(x) @@ -87,31 +87,35 @@ def protein(): class Network(nn.Module): def apply(self, x, out_channels): - l1 = tanh(nn.Dense(x, features=100)) - l2 = tanh(nn.Dense(l1, features=100)) - l3 = tanh(nn.Dense(l2, features=100)) - means = nn.Dense(l3, features=out_channels) + l1 = relu(nn.Dense(x, features=100)) + l2 = relu(nn.Dense(l1, features=100)) + means = nn.Dense(l2, features=out_channels) return means +def nonlin(x): + return tanh(x) + + def model(data, obs=None): module = Network.partial(out_channels=1) - net = random_flax_module('fnn', module, dist.Normal(0, 1.), input_shape=data.shape) + + net = random_flax_module('fnn', module, dist.Normal(0, 2.), input_shape=data.shape[1]) if obs is not None: obs = obs[..., None] - prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0)) + prec_obs = numpyro.sample("prec_obs", dist.Normal(110.4, .1)) sigma_obs = 1.0 / jnp.sqrt(prec_obs) # prior numpyro.sample('obs', dist.Normal(net(data), sigma_obs), obs=obs) def hmc(dataset, data, obs, warmup, num_sample): - kernel = NUTS(model, init_strategy=init_to_median) + kernel = NUTS(model, max_tree_depth=4, step_size=.0005, init_strategy=init_to_sample) mcmc = MCMC(kernel, warmup, num_sample) - mcmc._compile(random.PRNGKey(0), data, obs, extra_fields=("num_steps",)) - mcmc.run(random.PRNGKey(0), data, obs, extra_fields=('num_steps',)) + mcmc.run(random.PRNGKey(37), data, obs, extra_fields=('num_steps',)) + print(mcmc.print_summary()) return mcmc.get_samples() @@ -125,8 +129,8 @@ def predict(model, rng_key, samples, *args, **kwargs): def main(): data, obs = load_agw_1d() - warmup = 100 - num_samples = 100 + warmup = 20 + num_samples = 10 test_data = np.linspace(-2, 2, 500).reshape(-1, 1) samples = hmc('protein', data, obs, warmup, num_samples) vmap_args = (samples, random.split(random.PRNGKey(1), num_samples)) diff --git a/examples/hmcecs/two_moons.py b/examples/hmcecs/two_moons.py new file mode 100644 index 000000000..0bca42f97 --- /dev/null +++ b/examples/hmcecs/two_moons.py @@ -0,0 +1,85 @@ +import argparse +import os + +from matplotlib.gridspec import GridSpec +import matplotlib.pyplot as plt +import seaborn as sns + +import jax +from jax import random +import jax.numpy as jnp +from jax.scipy.special import logsumexp + +import numpyro +from numpyro import optim +from numpyro.diagnostics import print_summary +import numpyro.distributions as dist +from numpyro.distributions import constraints +from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO +from numpyro.infer.autoguide import AutoBNAFNormal +from numpyro.infer.reparam import NeuTraReparam + + +class DualMoonDistribution(dist.Distribution): + support = constraints.real_vector + + def __init__(self): + super(DualMoonDistribution, self).__init__(event_shape=(2,)) + + def sample(self, key, sample_shape=()): + # it is enough to return an arbitrary sample with correct shape + return jnp.zeros(sample_shape + self.event_shape) + + def log_prob(self, x): + term1 = 0.5 * ((jnp.linalg.norm(x, axis=-1) - 2) / 0.4) ** 2 + term2 = -0.5 * ((x[..., :1] + jnp.array([-2., 2.])) / 0.6) ** 2 + pe = term1 - logsumexp(term2, axis=-1) + return -pe + + +def dual_moon_model(): + numpyro.sample('x', DualMoonDistribution()) + + +def guide(): + var = numpyro.param('var', jnp.eye(2, dtype=jnp.float32), constraints=constraints.corr_matrix) + mean = numpyro.param('mean', jnp.zeros(2, dtype=jnp.float32), constraints=constraints.real_vector) + numpyro.sample('x', dist.MultivariateNormal(mean, var)) + + +def visualize(samples): + print(samples.shape) + print(samples) + sns.kdeplot(x=samples[:, 0], y=samples[:, 1]) + plt.show() + + +def two_moons(rng_key, noise, shape): + def make_circle(data, radius, center): + return jnp.sqrt(radius ** 2 - (data - center) ** 2) + + # TODO: finish compute density + + noise_key, uni_key = random.split(rng_key) + uni_samples = jax.random.uniform(uni_key, shape) + noise = noise * jax.random.normal(noise_key, shape) + upper = uni_samples[:shape[0] // 2] - .25 + upper_noise = noise[:shape[0] // 2] + lower_noise = noise[shape[0] // 2:] + lower = uni_samples[shape[0] // 2:] + .25 + upper = jnp.vstack((upper, -make_circle(upper, .5, .25) + .1)).T + lower = jnp.vstack((lower, make_circle(lower, .5, .75) - .1)).T + + plt.scatter(upper[:, 0], upper[:, 1]) + plt.scatter(lower[:, 0], lower[:, 1]) + plt.gca().set_aspect('equal', adjustable='box') + plt.show() + + +if __name__ == '__main__': + sim_key, guide_key, mcmc_key = random.split(random.PRNGKey(0), 3) + two_moons(sim_key, noise=.05, shape=(1000,)) + dm = DualMoonDistribution() + samples = dm.sample(sim_key, (10000,)) + + # visualize(samples) From 41bda0c8e44fd5fd1da2b7757cb3385ef370197c Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 1 Feb 2021 15:00:31 -0600 Subject: [PATCH 56/93] add gibbs_state and fix bugs --- numpyro/infer/hmc_gibbs.py | 234 ++++++++++++++++++++++++------------- numpyro/infer/util.py | 2 +- 2 files changed, 157 insertions(+), 79 deletions(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 7a3aaddb9..7bca63edd 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -7,7 +7,7 @@ from functools import partial import jax.numpy as jnp -from jax import device_put, jacfwd, jacobian, grad, hessian, ops, random, value_and_grad +from jax import device_put, jacfwd, jacobian, grad, hessian, lax, ops, random, value_and_grad from jax.scipy.special import expit import numpyro @@ -399,30 +399,41 @@ def potential_fn(z_gibbs, z_hmc): return HMCGibbsState(z, hmc_state, rng_key) -def _subsample_gibbs_fn(potential_fn, plate_sizes, num_blocks=1): - def gibbs_fn(rng_key, gibbs_sites, hmc_sites, pe): - assert set(gibbs_sites) == set(plate_sizes) - u_new = {} - for name in gibbs_sites: - size, subsample_size = plate_sizes[name] - rng_key, subkey, block_key = random.split(rng_key, 3) - block_size = subsample_size // num_blocks +def _block_update(plate_sizes, num_blocks, rng_key, gibbs_sites, gibbs_state): + u_new = {} + for name, subsample_idx in gibbs_sites.items(): + size, subsample_size = plate_sizes[name] + rng_key, subkey, block_key = random.split(rng_key, 3) + block_size = (subsample_size - 1) // num_blocks + 1 + pad = block_size - (subsample_size - 1) % block_size - 1 - chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks) - new_idx = random.randint(subkey, minval=0, maxval=size, shape=(subsample_size,)) - block_mask = jnp.arange(subsample_size) // block_size == chosen_block + # subsample_size = 7, num_blocks=3, block_size=3: 3 + 3 + 1 + chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks) + new_idx = random.randint(subkey, minval=0, maxval=size, shape=(block_size,)) + subsample_idx_padded = jnp.pad(subsample_idx, (0, pad)) # size = 9 + start = chosen_block * block_size + subsample_idx_padded = lax.dynamic_update_slice_in_dim( + subsample_idx_padded, new_idx, start, 0) - u_new[name] = jnp.where(block_mask, new_idx, gibbs_sites[name]) + u_new[name] = subsample_idx_padded[:subsample_size] # size = 7 + return u_new, gibbs_state - # given a fixed hmc_sites, pe_new - pe_curr = loglik_new - loglik_curr - pe_new = potential_fn(u_new, hmc_sites) - accept_prob = jnp.clip(jnp.exp(pe - pe_new), a_max=1.0) - gibbs_sites, pe = cond(random.bernoulli(rng_key, accept_prob), - (u_new, pe_new), identity, - (gibbs_sites, pe), identity) - return gibbs_sites, pe - return gibbs_fn +HMCECSState = namedtuple("HMCECSState", "z, hmc_state, rng_key, gibbs_state, accept_prob") +DiffEstState = namedtuple("DiffEstState", "ref_subsample_log_liks, " + "ref_subsample_log_lik_grads, ref_subsample_log_lik_hessians") # TODO: rename to shorter names? +# TODO: add Variational State which caches the current subsample weights +BlockPoissonEstState = namedtuple("BlockPoissonEstState", "block_rng_keys, sign") + + +def _wrap_gibbs_state(model): + def wrapped_fn(*args, **kwargs): + # this is to let estimate_likelihood handler knows what is the current gibbs_state + msg = {"type": "_gibbs_state", "value": kwargs.pop("_gibbs_state", ())} + numpyro.primitives.apply_stack(msg) + return model(*args, **kwargs) + + return wrapped_fn class HMCECS(HMCGibbs): @@ -478,11 +489,22 @@ class HMCECS(HMCGibbs): """ - def __init__(self, inner_kernel, *, estimator=None, num_blocks=1): + def __init__(self, inner_kernel, *, num_blocks=1, estimator=None): # TODO: estimator -> proxy super().__init__(inner_kernel, lambda *args: None, None) self._num_blocks = num_blocks self._estimator = estimator + def postprocess_fn(self, args, kwargs): + def fn(z): + model_kwargs = {} if kwargs is None else kwargs.copy() + hmc_sites = {k: v for k, v in z.items() if k not in self._gibbs_sites} + gibbs_sites = {k: v for k, v in z.items() if k in self._gibbs_sites} + model_kwargs["_gibbs_sites"] = gibbs_sites + hmc_sites = self.inner_kernel.postprocess_fn(args, model_kwargs)(hmc_sites) + return hmc_sites + + return fn + def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): model_kwargs = {} if model_kwargs is None else model_kwargs.copy() rng_key, key_u = random.split(rng_key) @@ -494,12 +516,26 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): } self._gibbs_sites = list(self._subsample_plate_sizes.keys()) if self._estimator is not None: + # TODO: expose gibbs_init, gibbs_update where gibbs_init is used to create an initial + # gibbs_state, and gibbs_update is used to give a new proposal given the current + # gibbs_sites+gibbs_state given a new rng_key estimator = self._estimator self.inner_kernel._model = estimate_likelihood(self.inner_kernel._model, estimator) - return super().init(rng_key, num_warmup, init_params, model_args, model_kwargs) + + z_gibbs = {name: site["value"] for name, site in self._prototype_trace.items() + if name in self._gibbs_sites} + rng_key, rng_state = random.split(rng_key) + gibbs_state = gibbs_init(rng_state, z_gibbs) + else: + self._gibbs_update = partial(_block_update, self._subsample_plate_sizes, self._num_blocks) + gibbs_state = () + + model_kwargs["_gibbs_state"] = gibbs_state + state = super().init(rng_key, num_warmup, init_params, model_args, model_kwargs) + return HMCECSState(state.z, state.hmc_state, state.rng_key, gibbs_state, jnp.array(0.)) def sample(self, state, model_args, model_kwargs): - model_kwargs = {} if model_kwargs is None else model_kwargs + model_kwargs = {} if model_kwargs is None else model_kwargs.copy() rng_key, rng_gibbs = random.split(state.rng_key) def potential_fn(z_gibbs, z_hmc): @@ -507,26 +543,31 @@ def potential_fn(z_gibbs, z_hmc): *model_args, _gibbs_sites=z_gibbs, **model_kwargs)(z_hmc) z_gibbs = {k: v for k, v in state.z.items() if k not in state.hmc_state.z} - z_hmc = {k: v for k, v in state.z.items() if k in state.hmc_state.z} - model_kwargs_ = model_kwargs.copy() - model_kwargs_["_gibbs_sites"] = z_gibbs + z_gibbs_new, gibbs_state_new = self._gibbs_update(rng_key, z_gibbs, state.gibbs_state) - gibbs_fn = _subsample_gibbs_fn(potential_fn, self._subsample_plate_sizes, self._num_blocks) - z_gibbs, pe = gibbs_fn(rng_key=rng_gibbs, gibbs_sites=z_gibbs, hmc_sites=z_hmc, - pe=state.hmc_state.potential_energy) + # given a fixed hmc_sites, pe_new - pe_curr = loglik_new - loglik_curr + pe = state.hmc_state.potential_energy + pe_new = potential_fn(z_gibbs_new, gibbs_state_new, state.hmc_state.z) + accept_prob = jnp.clip(jnp.exp(pe - pe_new), a_max=1.0) + z_gibbs, gibbs_state, pe = cond(random.bernoulli(rng_key, accept_prob), + (z_gibbs_new, gibbs_state_new, pe_new), identity, + (z_gibbs, state.gibbs_state, pe), identity) + # TODO (very low priority): move this to the above cond, only compute grad when accepting if self.inner_kernel._forward_mode_differentiation: z_grad = jacfwd(partial(potential_fn, z_gibbs))(state.hmc_state.z) else: z_grad = grad(partial(potential_fn, z_gibbs))(state.hmc_state.z) hmc_state = state.hmc_state._replace(z_grad=z_grad, potential_energy=pe) - model_kwargs_["_gibbs_sites"] = z_gibbs - hmc_state = self.inner_kernel.sample(hmc_state, model_args, model_kwargs_) + model_kwargs["_gibbs_sites"] = z_gibbs + model_kwargs["_gibbs_state"] = gibbs_state + hmc_state = self.inner_kernel.sample(hmc_state, model_args, model_kwargs) z = {**z_gibbs, **hmc_state.z} - - return HMCGibbsState(z, hmc_state, rng_key) + # TODO: post update gibbs_state to update sign in Block Poisson estimator + # extra_fields=('gibbs_state.sign',) + return HMCECSState(z, hmc_state, rng_key, gibbs_state, accept_prob) def difference_estimator(rng_key, model, model_args, model_kwargs, proxy_fn): @@ -538,7 +579,7 @@ def difference_estimator(rng_key, model, model_args, model_kwargs, proxy_fn): if site["type"] == "plate" and site["args"][0] > site["args"][1] } - def estimator(likelihoods, params): + def estimator(likelihoods, params, gibbs_state): subsample_log_liks = defaultdict(float) subsample_indices = {} for (fn, value, name, subsample_dim, subsample_idx) in likelihoods.values(): @@ -548,7 +589,7 @@ def estimator(likelihoods, params): log_lik_sum = 0. - proxy_value_all, proxy_value_subsample = proxy_fn(params, subsample_indices) + proxy_value_all, proxy_value_subsample = proxy_fn(params, subsample_indices, gibbs_state) for name, subsample_log_lik in subsample_log_liks.items(): # loop over all subsample sites n, m = subsample_plate_sizes[name] @@ -563,7 +604,7 @@ def estimator(likelihoods, params): return estimator -def taylor_proxy(rng_key, model, model_args, model_kwargs, reference_params, using_lookup=False): +def taylor_proxy(rng_key, model, model_args, model_kwargs, reference_params, using_lookup=False, num_blocks=1): prototype_trace = trace(seed(model, rng_key)).get_trace(*model_args, **model_kwargs) subsample_plate_sizes = { name: site["args"] @@ -571,7 +612,7 @@ def taylor_proxy(rng_key, model, model_args, model_kwargs, reference_params, usi if site["type"] == "plate" and site["args"][0] > site["args"][1] # i.e. size > subsample_size } - reference_params = {k:v for k,v in reference_params.items() if k in prototype_trace} + reference_params = {k:v for k,v in reference_params.items() if k in prototype_trace} # subsample_plate_sizes: name -> (size, subsample_size) ref_params_flat, unravel_fn = ravel_pytree(reference_params) @@ -599,51 +640,82 @@ def log_likelihood_sum(params_flat, subsample_indices=None): return {k: v.sum() for k, v in log_likelihood(params_flat, subsample_indices).items()} # those stats are dict keyed by subsample names - if using_lookup: - ref_log_likelihoods = log_likelihood(ref_params_flat) # n - # NB: use jacfwd (instead of jacobian/jacrev) when out_dim >> in_dim - ref_log_likelihood_grads = jacfwd(log_likelihood)(ref_params_flat) - ref_log_likelihood_hessians = jacfwd(jacfwd(log_likelihood))(ref_params_flat) # n x 55 x 55 - ref_log_likelihoods_sum = {k: v.sum(0) for k, v in ref_log_likelihoods.items()} - ref_log_likelihood_grads_sum = {k: v.sum(0) for k, v in ref_log_likelihood_grads.items()} - ref_log_likelihood_hessians_sum = {k: v.sum(0) for k, v in ref_log_likelihood_hessians.items()} # 55 x 55 - else: - ref_log_likelihoods_sum = log_likelihood_sum(ref_params_flat) - ref_log_likelihood_grads_sum = jacobian(log_likelihood_sum)(ref_params_flat) - ref_log_likelihood_hessians_sum = hessian(log_likelihood_sum)(ref_params_flat) + ref_log_likelihoods_sum = log_likelihood_sum(ref_params_flat) + ref_log_likelihood_grads_sum = jacobian(log_likelihood_sum)(ref_params_flat) + ref_log_likelihood_hessians_sum = hessian(log_likelihood_sum)(ref_params_flat) - def proxy_fn(params, subsample_indices): + def gibbs_init(rng_key, gibbs_sites): + ref_subsample_log_liks = log_likelihood(ref_params_flat, gibbs_sites) + ref_subsample_log_lik_grads = jacfwd(log_likelihood)(ref_params_flat, gibbs_sites) + ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(ref_params_flat, gibbs_sites) + return DiffEstState(ref_subsample_log_liks, ref_subsample_log_lik_grads, ref_subsample_log_lik_hessians) + + def gibbs_update(rng_key, gibbs_sites, gibbs_state): + u_new = {} + pads = {} + new_idxs = {} + starts = {} + new_states = defaultdict(dict) + for name, subsample_idx in gibbs_sites.items(): + size, subsample_size = subsample_plate_sizes[name] + rng_key, subkey, block_key = random.split(rng_key, 3) + block_size = (subsample_size - 1) // num_blocks + 1 + pad = block_size - (subsample_size - 1) % block_size - 1 + + chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks) + new_idx = random.randint(subkey, minval=0, maxval=size, shape=(block_size,)) + subsample_idx_padded = jnp.pad(subsample_idx, (0, pad)) + start = chosen_block * block_size + subsample_idx_padded = lax.dynamic_update_slice_in_dim( + subsample_idx_padded, new_idx, start, 0) + + u_new[name] = subsample_idx_padded[:subsample_size] + pads[name] = pad + new_idxs[name] = new_idx + starts[name] = start + + ref_subsample_log_liks = log_likelihood(ref_params_flat, new_idxs) + ref_subsample_log_lik_grads = jacfwd(log_likelihood)(ref_params_flat, new_idxs) + ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(ref_params_flat, new_idxs) + for stat, new_block_values, last_values in zip( + ["log_liks", "grads", "hessians"], + [ref_subsample_log_liks, + ref_subsample_log_lik_grads, + ref_subsample_log_lik_hessians], + [gibbs_state.ref_subsample_log_liks, + gibbs_state.ref_subsample_log_lik_grads, + gibbs_state.ref_subsample_log_lik_hessians]): + for name, subsample_idx in gibbs_sites.items(): + size, subsample_size = subsample_plate_sizes[name] + pad, new_idx, start = pads[name], new_idxs[name], starts[name] + new_value = jnp.pad(last_values[name], [(0, pad)] + [(0, 0)] * (jnp.ndim(last_values[name]) - 1)) + new_value = lax.dynamic_update_slice_in_dim( + new_value, new_block_values[name], start, 0) + new_states[stat][name] = new_value[:subsample_size] + gibbs_state = DiffEstState(new_states["log_liks"], new_states["grads"], new_states["hessians"]) + return u_new, gibbs_state + + def proxy_fn(params, subsample_indices, gibbs_state): params_flat, _ = ravel_pytree(params) params_diff = params_flat - ref_params_flat - if using_lookup: - # NB: in GPU, indexing here is expensive, it is better to compute likelihood, grad, hessian directly - # m x 55 x 55 (m ~ sqrt(n) ~ 1000) - ref_subsample_log_lik = {k: v[subsample_indices[k]] - for k, v in ref_log_likelihoods.items()} - ref_subsample_log_lik_grad = {k: v[subsample_indices[k]] - for k, v in ref_log_likelihood_grads.items()} - ref_subsample_log_lik_hessian = {k: v[subsample_indices[k]] - for k, v in ref_log_likelihood_hessians.items()} - else: - ref_subsample_log_lik = log_likelihood_sum(ref_params_flat, subsample_indices) - ref_subsample_log_lik_grad = jacobian(log_likelihood_sum)(ref_params_flat, subsample_indices) - ref_subsample_log_lik_hessian = hessian(log_likelihood_sum)(ref_params_flat, subsample_indices) + + ref_subsample_log_liks = gibbs_state.ref_subsample_log_liks + ref_subsample_log_lik_grads = gibbs_state.ref_subsample_log_lik_grads + ref_subsample_log_lik_hessians = gibbs_state.ref_subsample_log_lik_hessians proxy_sum = defaultdict(float) proxy_subsample = defaultdict(float) for name, subsample_idx in subsample_indices.items(): - proxy_subsample[name] = ref_subsample_log_lik[name] + \ - jnp.dot(ref_subsample_log_lik_grad[name], params_diff) + \ - 0.5 * jnp.dot(jnp.dot(ref_subsample_log_lik_hessian[name], params_diff), - params_diff) - - proxy_subsample[name] = ref_log_likelihoods_sum[name] + \ - jnp.dot(ref_log_likelihood_grads_sum[name], params_diff) + \ - 0.5 * jnp.dot(jnp.dot(ref_log_likelihood_hessians_sum[name], params_diff), - params_diff) + proxy_subsample[name] = ref_subsample_log_liks[name] + \ + jnp.dot(ref_subsample_log_lik_grads[name], params_diff) + \ + 0.5 * jnp.dot(jnp.dot(ref_subsample_log_lik_hessians[name], params_diff), params_diff) + + proxy_sum[name] = ref_log_likelihoods_sum[name] + \ + jnp.dot(ref_log_likelihood_grads_sum[name], params_diff) + \ + 0.5 * jnp.dot(jnp.dot(ref_log_likelihood_hessians_sum[name], params_diff), params_diff) return proxy_sum, proxy_subsample - return proxy_fn + return proxy_fn, gibbs_init, gibbs_update def _sum_all_except_at_dim(x, dim): @@ -728,7 +800,7 @@ def proxy_fn(params, subsample_indices): class estimate_likelihood(numpyro.primitives.Messenger): def __init__(self, fn=None, estimator=None): - # estimate_likelihood: accept likelihood tuple (fn, value, subsample_name, subsample_dim, subsample_idx) + # estimate_likelihood: accept likelihood tuple (fn, value, subsample_name, subsample_dim) # and current unconstrained params # and returns log of the bias-corrected likelihood assert estimator is not None @@ -737,6 +809,7 @@ def __init__(self, fn=None, estimator=None): self.params = None self.likelihoods = {} self.subsample_plates = {} + self.gibbs_state = None def __enter__(self): # trace(substitute(substitute(control_variate(model), unconstrained_reparam))) @@ -758,17 +831,23 @@ def __exit__(self, exc_type, exc_value, traceback): # add numpyro.factor; ideally, we will want to skip this computation when making prediction # see: https://github.com/pyro-ppl/pyro/issues/2744 - numpyro.factor("_biased_corrected_log_likelihood", self.estimator(self.likelihoods, self.params)) + numpyro.factor("_biased_corrected_log_likelihood", + self.estimator(self.likelihoods, self.params, self.gibbs_state)) # clean up self.params = None self.likelihoods = {} self.subsample_plates = {} + self.gibbs_state = None def process_message(self, msg): if self.params is None: return + if msg["type"] == "_gibbs_state": + self.gibbs_state = msg["value"] + return + if msg["type"] == "sample" and msg["is_observed"]: assert msg["name"] not in self.params # store the likelihood for the estimator @@ -777,8 +856,7 @@ def process_message(self, msg): if msg["name"] in self.likelihoods: raise RuntimeError(f"Multiple subsample plates at site {msg['name']} " "are not allowed. Please reshape your data.") - subsample_idx = self.subsample_plates[frame.name] - self.likelihoods[msg["name"]] = (msg["fn"], msg["value"], frame.name, frame.dim, subsample_idx) + self.likelihoods[msg["name"]] = (msg["fn"], msg["value"], frame.name, frame.dim) # mask the current likelihood msg["fn"] = msg["fn"].mask(False) elif msg["type"] == "plate" and msg["args"][0] > msg["args"][1]: diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index b38b19bed..5fc3cd24b 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -432,7 +432,7 @@ def initialize_model(rng_key, model, # substitute param sites from model_trace to model so # we don't need to generate again parameters of `numpyro.module` model = substitute(model, data={k: site["value"] for k, site in model_trace.items() - if site["type"] in ["param", "plate"]}) + if site["type"] in ["param"]}) constrained_values = {k: v['value'] for k, v in model_trace.items() if v['type'] == 'sample' and not v['is_observed'] and not v['fn'].is_discrete} From ab7888e166993cb05ba3cbd62e57130d48f67144 Mon Sep 17 00:00:00 2001 From: ola Date: Mon, 1 Feb 2021 23:45:07 +0100 Subject: [PATCH 57/93] Integrated taylor proxy and updated API. --- examples/hmcecs/logistic_regression.py | 6 +- numpyro/infer/hmc_gibbs.py | 413 +++++++++++++------------ 2 files changed, 214 insertions(+), 205 deletions(-) diff --git a/examples/hmcecs/logistic_regression.py b/examples/hmcecs/logistic_regression.py index 2e20580f4..edafca859 100644 --- a/examples/hmcecs/logistic_regression.py +++ b/examples/hmcecs/logistic_regression.py @@ -15,7 +15,7 @@ from numpyro.distributions import constraints from numpyro.examples.datasets import _load_higgs from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_median -from numpyro.infer.hmc_gibbs import HMCECS, difference_estimator, variational_proxy, taylor_proxy +from numpyro.infer.hmc_gibbs import HMCECS, perturbed_method, variational_proxy, taylor_proxy from numpyro.infer.util import _predictive os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" @@ -99,11 +99,11 @@ def hmcecs_model( dataset, data, obs, subsample_size, proxy_name='taylor'): else: proxy_fn = variational_proxy(proxy_key, model, model_args, model_kwargs, guide, params) - estimator = difference_estimator(estimator_key, model, model_args, model_kwargs, proxy_fn) + estimator = perturbed_method(estimator_key, model, model_args, model_kwargs, proxy_fn) # Compute HMCECS - kernel = HMCECS(NUTS(model), estimator=estimator) + kernel = HMCECS(NUTS(model), proxy=estimator) mcmc = MCMC(kernel, 1000, 1000) start = time() mcmc.run(random.PRNGKey(3), data, obs, subsample_size, extra_fields=("hmc_state.accept_prob", diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 7bca63edd..34f574ed8 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -420,9 +420,9 @@ def _block_update(plate_sizes, num_blocks, rng_key, gibbs_sites, gibbs_state): HMCECSState = namedtuple("HMCECSState", "z, hmc_state, rng_key, gibbs_state, accept_prob") -DiffEstState = namedtuple("DiffEstState", "ref_subsample_log_liks, " - "ref_subsample_log_lik_grads, ref_subsample_log_lik_hessians") # TODO: rename to shorter names? -# TODO: add Variational State which caches the current subsample weights +TaylorProxyState = namedtuple("TaylorProxyState", "ref_subsample_log_liks, " + "ref_subsample_log_lik_grads, ref_subsample_log_lik_hessians") # TODO: rename to shorter names? +VariationalProxyState = namedtuple('VariationalProxyState', 'subsample_weights') BlockPoissonEstState = namedtuple("BlockPoissonEstState", "block_rng_keys, sign") @@ -489,10 +489,12 @@ class HMCECS(HMCGibbs): """ - def __init__(self, inner_kernel, *, num_blocks=1, estimator=None): # TODO: estimator -> proxy + def __init__(self, inner_kernel, *, num_blocks=1, proxy=None, method='perturbed'): super().__init__(inner_kernel, lambda *args: None, None) + assert method in ['perturbed'] self._num_blocks = num_blocks - self._estimator = estimator + self._proxy = proxy + self._method = method def postprocess_fn(self, args, kwargs): def fn(z): @@ -515,15 +517,17 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): if site["type"] == "plate" and site["args"][0] > site["args"][1] # i.e. size > subsample_size } self._gibbs_sites = list(self._subsample_plate_sizes.keys()) - if self._estimator is not None: - # TODO: expose gibbs_init, gibbs_update where gibbs_init is used to create an initial - # gibbs_state, and gibbs_update is used to give a new proposal given the current - # gibbs_sites+gibbs_state given a new rng_key - estimator = self._estimator - self.inner_kernel._model = estimate_likelihood(self.inner_kernel._model, estimator) - - z_gibbs = {name: site["value"] for name, site in self._prototype_trace.items() - if name in self._gibbs_sites} + if self._proxy is not None: + rng_key, proxy_key, method_key = random.split(rng_key, 3) + proxy_fn, gibbs_init, gibbs_update = self._proxy(rng_key, + self.model, + model_args, + model_kwargs, + num_blocks=self._num_blocks) + method = perturbed_method(method_key, self.model, model_args, model_kwargs, proxy_fn) + self.inner_kernel._model = estimate_likelihood(self.inner_kernel._model, method) + + z_gibbs = {name: site["value"] for name, site in self._prototype_trace.items() if name in self._gibbs_sites} rng_key, rng_state = random.split(rng_key) gibbs_state = gibbs_init(rng_state, z_gibbs) else: @@ -570,7 +574,7 @@ def potential_fn(z_gibbs, z_hmc): return HMCECSState(z, hmc_state, rng_key, gibbs_state, accept_prob) -def difference_estimator(rng_key, model, model_args, model_kwargs, proxy_fn): +def perturbed_method(rng_key, model, model_args, model_kwargs, proxy_fn): # subsample_plate_sizes: name -> (size, subsample_size) prototype_trace = trace(seed(model, rng_key)).get_trace(*model_args, **model_kwargs) subsample_plate_sizes = { @@ -581,15 +585,12 @@ def difference_estimator(rng_key, model, model_args, model_kwargs, proxy_fn): def estimator(likelihoods, params, gibbs_state): subsample_log_liks = defaultdict(float) - subsample_indices = {} for (fn, value, name, subsample_dim, subsample_idx) in likelihoods.values(): subsample_log_liks[name] += _sum_all_except_at_dim(fn.log_prob(value), subsample_dim) - if name not in subsample_indices: - subsample_indices[name] = subsample_idx log_lik_sum = 0. - proxy_value_all, proxy_value_subsample = proxy_fn(params, subsample_indices, gibbs_state) + proxy_value_all, proxy_value_subsample = proxy_fn(params, subsample_log_liks.keys(), gibbs_state) for name, subsample_log_lik in subsample_log_liks.items(): # loop over all subsample sites n, m = subsample_plate_sizes[name] @@ -604,118 +605,123 @@ def estimator(likelihoods, params, gibbs_state): return estimator -def taylor_proxy(rng_key, model, model_args, model_kwargs, reference_params, using_lookup=False, num_blocks=1): - prototype_trace = trace(seed(model, rng_key)).get_trace(*model_args, **model_kwargs) - subsample_plate_sizes = { - name: site["args"] - for name, site in prototype_trace.items() - if site["type"] == "plate" and site["args"][0] > site["args"][1] # i.e. size > subsample_size - } - - reference_params = {k:v for k,v in reference_params.items() if k in prototype_trace} - - # subsample_plate_sizes: name -> (size, subsample_size) - ref_params_flat, unravel_fn = ravel_pytree(reference_params) +def taylor_proxy(reference_params): + def construct_proxy_fn(rng_key, model, model_args, model_kwargs, num_blocks=1): + prototype_trace = trace(seed(model, rng_key)).get_trace(*model_args, **model_kwargs) + subsample_plate_sizes = { + name: site["args"] + for name, site in prototype_trace.items() + if site["type"] == "plate" and site["args"][0] > site["args"][1] # i.e. size > subsample_size + } - def log_likelihood(params_flat, subsample_indices=None): - if subsample_indices is None: - subsample_indices = {k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()} - params = unravel_fn(params_flat) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - with block(), trace() as tr, substitute(data=subsample_indices), \ - substitute(substitute_fn=partial(_unconstrain_reparam, params)): - model(*model_args, **model_kwargs) - - log_lik = defaultdict(float) - for site in tr.values(): - if site["type"] == "sample" and site["is_observed"]: - for frame in site["cond_indep_stack"]: - if frame.name in subsample_plate_sizes: - log_lik[frame.name] += _sum_all_except_at_dim( - site["fn"].log_prob(site["value"]), frame.dim) - return log_lik - - def log_likelihood_sum(params_flat, subsample_indices=None): - return {k: v.sum() for k, v in log_likelihood(params_flat, subsample_indices).items()} - - # those stats are dict keyed by subsample names - ref_log_likelihoods_sum = log_likelihood_sum(ref_params_flat) - ref_log_likelihood_grads_sum = jacobian(log_likelihood_sum)(ref_params_flat) - ref_log_likelihood_hessians_sum = hessian(log_likelihood_sum)(ref_params_flat) - - def gibbs_init(rng_key, gibbs_sites): - ref_subsample_log_liks = log_likelihood(ref_params_flat, gibbs_sites) - ref_subsample_log_lik_grads = jacfwd(log_likelihood)(ref_params_flat, gibbs_sites) - ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(ref_params_flat, gibbs_sites) - return DiffEstState(ref_subsample_log_liks, ref_subsample_log_lik_grads, ref_subsample_log_lik_hessians) - - def gibbs_update(rng_key, gibbs_sites, gibbs_state): - u_new = {} - pads = {} - new_idxs = {} - starts = {} - new_states = defaultdict(dict) - for name, subsample_idx in gibbs_sites.items(): - size, subsample_size = subsample_plate_sizes[name] - rng_key, subkey, block_key = random.split(rng_key, 3) - block_size = (subsample_size - 1) // num_blocks + 1 - pad = block_size - (subsample_size - 1) % block_size - 1 - - chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks) - new_idx = random.randint(subkey, minval=0, maxval=size, shape=(block_size,)) - subsample_idx_padded = jnp.pad(subsample_idx, (0, pad)) - start = chosen_block * block_size - subsample_idx_padded = lax.dynamic_update_slice_in_dim( - subsample_idx_padded, new_idx, start, 0) - - u_new[name] = subsample_idx_padded[:subsample_size] - pads[name] = pad - new_idxs[name] = new_idx - starts[name] = start - - ref_subsample_log_liks = log_likelihood(ref_params_flat, new_idxs) - ref_subsample_log_lik_grads = jacfwd(log_likelihood)(ref_params_flat, new_idxs) - ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(ref_params_flat, new_idxs) - for stat, new_block_values, last_values in zip( - ["log_liks", "grads", "hessians"], - [ref_subsample_log_liks, - ref_subsample_log_lik_grads, - ref_subsample_log_lik_hessians], - [gibbs_state.ref_subsample_log_liks, - gibbs_state.ref_subsample_log_lik_grads, - gibbs_state.ref_subsample_log_lik_hessians]): + # TODO: map reference params to unconstraint_params + + # subsample_plate_sizes: name -> (size, subsample_size) + ref_params_flat, unravel_fn = ravel_pytree(reference_params) + + def log_likelihood(params_flat, subsample_indices=None): + if subsample_indices is None: + subsample_indices = {k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()} + params = unravel_fn(params_flat) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + with block(), trace() as tr, substitute(data=subsample_indices), \ + substitute(substitute_fn=partial(_unconstrain_reparam, params)): + model(*model_args, **model_kwargs) + + log_lik = defaultdict(float) + for site in tr.values(): + if site["type"] == "sample" and site["is_observed"]: + for frame in site["cond_indep_stack"]: + if frame.name in subsample_plate_sizes: + log_lik[frame.name] += _sum_all_except_at_dim( + site["fn"].log_prob(site["value"]), frame.dim) + return log_lik + + def log_likelihood_sum(params_flat, subsample_indices=None): + return {k: v.sum() for k, v in log_likelihood(params_flat, subsample_indices).items()} + + # those stats are dict keyed by subsample names + ref_log_likelihoods_sum = log_likelihood_sum(ref_params_flat) + ref_log_likelihood_grads_sum = jacobian(log_likelihood_sum)(ref_params_flat) + ref_log_likelihood_hessians_sum = hessian(log_likelihood_sum)(ref_params_flat) + + def gibbs_init(rng_key, gibbs_sites): + ref_subsample_log_liks = log_likelihood(ref_params_flat, gibbs_sites) + ref_subsample_log_lik_grads = jacfwd(log_likelihood)(ref_params_flat, gibbs_sites) + ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(ref_params_flat, gibbs_sites) + return TaylorProxyState(ref_subsample_log_liks, ref_subsample_log_lik_grads, ref_subsample_log_lik_hessians) + + def gibbs_update(rng_key, gibbs_sites, gibbs_state): + u_new = {} + pads = {} + new_idxs = {} + starts = {} + new_states = defaultdict(dict) for name, subsample_idx in gibbs_sites.items(): size, subsample_size = subsample_plate_sizes[name] - pad, new_idx, start = pads[name], new_idxs[name], starts[name] - new_value = jnp.pad(last_values[name], [(0, pad)] + [(0, 0)] * (jnp.ndim(last_values[name]) - 1)) - new_value = lax.dynamic_update_slice_in_dim( - new_value, new_block_values[name], start, 0) - new_states[stat][name] = new_value[:subsample_size] - gibbs_state = DiffEstState(new_states["log_liks"], new_states["grads"], new_states["hessians"]) - return u_new, gibbs_state - - def proxy_fn(params, subsample_indices, gibbs_state): - params_flat, _ = ravel_pytree(params) - params_diff = params_flat - ref_params_flat - - ref_subsample_log_liks = gibbs_state.ref_subsample_log_liks - ref_subsample_log_lik_grads = gibbs_state.ref_subsample_log_lik_grads - ref_subsample_log_lik_hessians = gibbs_state.ref_subsample_log_lik_hessians - - proxy_sum = defaultdict(float) - proxy_subsample = defaultdict(float) - for name, subsample_idx in subsample_indices.items(): - proxy_subsample[name] = ref_subsample_log_liks[name] + \ - jnp.dot(ref_subsample_log_lik_grads[name], params_diff) + \ - 0.5 * jnp.dot(jnp.dot(ref_subsample_log_lik_hessians[name], params_diff), params_diff) - - proxy_sum[name] = ref_log_likelihoods_sum[name] + \ - jnp.dot(ref_log_likelihood_grads_sum[name], params_diff) + \ - 0.5 * jnp.dot(jnp.dot(ref_log_likelihood_hessians_sum[name], params_diff), params_diff) - return proxy_sum, proxy_subsample - - return proxy_fn, gibbs_init, gibbs_update + rng_key, subkey, block_key = random.split(rng_key, 3) + block_size = (subsample_size - 1) // num_blocks + 1 + pad = block_size - (subsample_size - 1) % block_size - 1 + + chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks) + new_idx = random.randint(subkey, minval=0, maxval=size, shape=(block_size,)) + subsample_idx_padded = jnp.pad(subsample_idx, (0, pad)) + start = chosen_block * block_size + subsample_idx_padded = lax.dynamic_update_slice_in_dim( + subsample_idx_padded, new_idx, start, 0) + + u_new[name] = subsample_idx_padded[:subsample_size] + pads[name] = pad + new_idxs[name] = new_idx + starts[name] = start + + ref_subsample_log_liks = log_likelihood(ref_params_flat, new_idxs) + ref_subsample_log_lik_grads = jacfwd(log_likelihood)(ref_params_flat, new_idxs) + ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(ref_params_flat, new_idxs) + for stat, new_block_values, last_values in zip( + ["log_liks", "grads", "hessians"], + [ref_subsample_log_liks, + ref_subsample_log_lik_grads, + ref_subsample_log_lik_hessians], + [gibbs_state.ref_subsample_log_liks, + gibbs_state.ref_subsample_log_lik_grads, + gibbs_state.ref_subsample_log_lik_hessians]): + for name, subsample_idx in gibbs_sites.items(): + size, subsample_size = subsample_plate_sizes[name] + pad, new_idx, start = pads[name], new_idxs[name], starts[name] + new_value = jnp.pad(last_values[name], [(0, pad)] + [(0, 0)] * (jnp.ndim(last_values[name]) - 1)) + new_value = lax.dynamic_update_slice_in_dim( + new_value, new_block_values[name], start, 0) + new_states[stat][name] = new_value[:subsample_size] + gibbs_state = TaylorProxyState(new_states["log_liks"], new_states["grads"], new_states["hessians"]) + return u_new, gibbs_state + + def proxy_fn(params, subsample_lik_sites, gibbs_state): + params_flat, _ = ravel_pytree(params) + params_diff = params_flat - ref_params_flat + + ref_subsample_log_liks = gibbs_state.ref_subsample_log_liks + ref_subsample_log_lik_grads = gibbs_state.ref_subsample_log_lik_grads + ref_subsample_log_lik_hessians = gibbs_state.ref_subsample_log_lik_hessians + + proxy_sum = defaultdict(float) + proxy_subsample = defaultdict(float) + for name in subsample_lik_sites: + proxy_subsample[name] = ref_subsample_log_liks[name] + \ + jnp.dot(ref_subsample_log_lik_grads[name], params_diff) + \ + 0.5 * jnp.dot(jnp.dot(ref_subsample_log_lik_hessians[name], params_diff), + params_diff) + + proxy_sum[name] = ref_log_likelihoods_sum[name] + \ + jnp.dot(ref_log_likelihood_grads_sum[name], params_diff) + \ + 0.5 * jnp.dot(jnp.dot(ref_log_likelihood_hessians_sum[name], params_diff), + params_diff) + return proxy_sum, proxy_subsample + + return proxy_fn, gibbs_init, gibbs_update + + return construct_proxy_fn def _sum_all_except_at_dim(x, dim): @@ -723,89 +729,92 @@ def _sum_all_except_at_dim(x, dim): return x.reshape(x.shape[:1] + (-1,)).sum(-1) -def variational_proxy(rng_key, model, model_args, model_kwargs, guide, reference_params, num_samples=10): - prototype_trace = trace(seed(model, rng_key)).get_trace(*model_args, **model_kwargs) - subsample_plate_sizes = { - name: site["args"] - for name, site in prototype_trace.items() - if site["type"] == "plate" and site["args"][0] > site["args"][1] # i.e. size > subsample_size - } +def variational_proxy(guide, guide_params, num_samples=10): + def construct_proxy_fn(rng_key, model, model_args, model_kwargs, num_blocks=1): + prototype_trace = trace(seed(model, rng_key)).get_trace(*model_args, **model_kwargs) + subsample_plate_sizes = { + name: site["args"] + for name, site in prototype_trace.items() + if site["type"] == "plate" and site["args"][0] > site["args"][1] # i.e. size > subsample_size + } - pos_key, guide_key, rng_key = random.split(rng_key, 3) - guide = substitute(guide, reference_params) + pos_key, guide_key, rng_key = random.split(rng_key, 3) + guide_with_params = substitute(guide, guide_params) + + # factor out? + def log_likelihood(params, subsample_indices=None): + params_flat, unravel_fn = ravel_pytree(params) + if subsample_indices is None: + subsample_indices = {k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()} + params = unravel_fn(params_flat) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + with block(), trace() as tr, substitute(data=subsample_indices), \ + substitute(substitute_fn=partial(_unconstrain_reparam, params)): + model(*model_args, **model_kwargs) + + log_lik = defaultdict(float) + for site in tr.values(): + if site["type"] == "sample" and site["is_observed"]: + for frame in site["cond_indep_stack"]: + if frame.name in subsample_plate_sizes: + log_lik[frame.name] += _sum_all_except_at_dim( + site["fn"].log_prob(site["value"]), frame.dim) + return log_lik + + def log_posterior(params): + with numpyro.primitives.inner_stack(): + posterior_prob, _ = log_density(guide_with_params, model_args, model_kwargs, params) + return posterior_prob + + def log_prior(params): + with numpyro.primitives.inner_stack(): + prior_prob, _ = log_density( + block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), + model_args, model_kwargs, params) + return prior_prob + + posterior_samples = _predictive(pos_key, guide_with_params, {}, (num_samples,), return_sites='', parallel=True, + model_args=model_args, model_kwargs=model_kwargs) + log_likelihood_ref = log_likelihood(posterior_samples) + + posterior_samples = {**posterior_samples, **{k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()}} + + weights = {name: log_like / log_like.sum() for name, log_like in log_likelihood_ref.items()} - # factor out? - def log_likelihood(params, subsample_indices=None): - params_flat, unravel_fn = ravel_pytree(params) - if subsample_indices is None: - subsample_indices = {k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()} - params = unravel_fn(params_flat) with warnings.catch_warnings(): - warnings.simplefilter("ignore") - with block(), trace() as tr, substitute(data=subsample_indices), \ - substitute(substitute_fn=partial(_unconstrain_reparam, params)): - model(*model_args, **model_kwargs) - - log_lik = defaultdict(float) - for site in tr.values(): - if site["type"] == "sample" and site["is_observed"]: - for frame in site["cond_indep_stack"]: - if frame.name in subsample_plate_sizes: - log_lik[frame.name] += _sum_all_except_at_dim( - site["fn"].log_prob(site["value"]), frame.dim) - return log_lik - - def log_posterior(params): - with numpyro.primitives.inner_stack(): - posterior_prob, _ = log_density(guide, model_args, model_kwargs, params) - return posterior_prob - - def log_prior(params): - with numpyro.primitives.inner_stack(): - prior_prob, _ = log_density( - block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), - model_args, model_kwargs, params) - return prior_prob - - posterior_samples = _predictive(pos_key, guide, {}, (num_samples,), return_sites='', parallel=True, - model_args=model_args, model_kwargs=model_kwargs) - log_likelihood_ref = log_likelihood(posterior_samples) - - posterior_samples = {**posterior_samples, **{k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()}} - - weights = {name: log_like / log_like.sum() for name, log_like in log_likelihood_ref.items()} - - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning) - log_prior_prob = log_prior(posterior_samples) - log_posterior_prob = log_posterior(posterior_samples) - - evidence = {name: (log_posterior_prob - log_prior_prob - log_like.sum()) / num_samples - for name, log_like in log_likelihood_ref.items()} - - def proxy_fn(params, subsample_indices): - params = {**params, **subsample_indices} - proxy_sum = defaultdict(float) - proxy_subsample = defaultdict(float) - log_prior_prob = log_prior(params) - log_posterior_prob = log_posterior(params) - for name, subsample_idx in subsample_indices.items(): - proxy_sum[name] = evidence[name] + log_posterior_prob - log_prior_prob - proxy_subsample[name] = evidence[name] + \ - weights[name][subsample_idx].sum() * (log_posterior_prob - log_prior_prob) - return proxy_sum, proxy_subsample - - return proxy_fn + warnings.filterwarnings('ignore', category=UserWarning) + log_prior_prob = log_prior(posterior_samples) + log_posterior_prob = log_posterior(posterior_samples) + + evidence = {name: (log_posterior_prob - log_prior_prob - log_like.sum()) / num_samples + for name, log_like in log_likelihood_ref.items()} + + def proxy_fn(params, subsample_indices, gibbs_state): + params = {**params, **subsample_indices} + proxy_sum = defaultdict(float) + proxy_subsample = defaultdict(float) + log_prior_prob = log_prior(params) + log_posterior_prob = log_posterior(params) + for name, subsample_idx in subsample_indices.items(): + proxy_sum[name] = evidence[name] + log_posterior_prob - log_prior_prob + proxy_subsample[name] = evidence[name] + \ + weights[name][subsample_idx].sum() * (log_posterior_prob - log_prior_prob) + return proxy_sum, proxy_subsample + + return proxy_fn + + return construct_proxy_fn class estimate_likelihood(numpyro.primitives.Messenger): - def __init__(self, fn=None, estimator=None): + def __init__(self, fn=None, method=None): # estimate_likelihood: accept likelihood tuple (fn, value, subsample_name, subsample_dim) # and current unconstrained params # and returns log of the bias-corrected likelihood - assert estimator is not None + assert method is not None super().__init__(fn) - self.estimator = estimator + self.method = method self.params = None self.likelihoods = {} self.subsample_plates = {} @@ -832,7 +841,7 @@ def __exit__(self, exc_type, exc_value, traceback): # add numpyro.factor; ideally, we will want to skip this computation when making prediction # see: https://github.com/pyro-ppl/pyro/issues/2744 numpyro.factor("_biased_corrected_log_likelihood", - self.estimator(self.likelihoods, self.params, self.gibbs_state)) + self.method(self.likelihoods, self.params, self.gibbs_state)) # clean up self.params = None From a9d2c0eca97b00019ac5e4feecc30e4eafc4f057 Mon Sep 17 00:00:00 2001 From: ola Date: Tue, 2 Feb 2021 00:31:22 +0100 Subject: [PATCH 58/93] Bugs fixed and taylor working! --- examples/hmcecs/logistic_regression.py | 17 +++++++------- numpyro/infer/hmc_gibbs.py | 31 +++++++++++++++----------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/examples/hmcecs/logistic_regression.py b/examples/hmcecs/logistic_regression.py index edafca859..5c18b8f79 100644 --- a/examples/hmcecs/logistic_regression.py +++ b/examples/hmcecs/logistic_regression.py @@ -78,7 +78,7 @@ def guide(feature, obs, subsample_size): numpyro.sample('theta', dist.continuous.Normal(mean, .5)) -def hmcecs_model( dataset, data, obs, subsample_size, proxy_name='taylor'): +def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='taylor'): model_args, model_kwargs = (data, obs, subsample_size), {} svi_key, proxy_key, estimator_key, mcmc_key = random.split(random.PRNGKey(0), 4) @@ -95,15 +95,14 @@ def hmcecs_model( dataset, data, obs, subsample_size, proxy_name='taylor'): proxy_key, ref_key = random.split(proxy_key) ref_params = _predictive(ref_key, guide, {}, (1,), return_sites='', parallel=True, model_args=model_args, model_kwargs=model_kwargs) - proxy_fn = taylor_proxy(proxy_key, model, model_args, model_kwargs, ref_params) + ref_params.pop('mean') + proxy_fn = taylor_proxy(ref_params) else: - proxy_fn = variational_proxy(proxy_key, model, model_args, model_kwargs, guide, params) - estimator = perturbed_method(estimator_key, model, model_args, model_kwargs, proxy_fn) + proxy_fn = variational_proxy(guide, params) # Compute HMCECS - - kernel = HMCECS(NUTS(model), proxy=estimator) + kernel = HMCECS(NUTS(model), proxy=proxy_fn) mcmc = MCMC(kernel, 1000, 1000) start = time() mcmc.run(random.PRNGKey(3), data, obs, subsample_size, extra_fields=("hmc_state.accept_prob", @@ -128,7 +127,7 @@ def hmc(dataset, data, obs): if __name__ == '__main__': - load_data = {'higgs': higgs_data, 'breast': breast_cancer_data, 'copsac': copsac_data} + load_data = {'breast': breast_cancer_data, 'higgs': higgs_data, 'copsac': copsac_data} subsample_sizes = {'higgs': 1300, 'copsac': 1000, 'breast': 75, } data, obs = breast_cancer_data() @@ -137,6 +136,6 @@ def hmc(dataset, data, obs): if not os.path.exists(dir): os.mkdir(dir) data, obs = load_data[dataset]() - # hmcecs_model(dir, data, obs, subsample_sizes[dataset]) - hmc(dir, data, obs) + hmcecs_model(dir, data, obs, subsample_sizes[dataset]) + # hmc(dir, data, obs) exit() diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 34f574ed8..b39fba6c6 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -491,7 +491,9 @@ class HMCECS(HMCGibbs): def __init__(self, inner_kernel, *, num_blocks=1, proxy=None, method='perturbed'): super().__init__(inner_kernel, lambda *args: None, None) - assert method in ['perturbed'] + + assert method in {'perturbed'} + self.inner_kernel._model = _wrap_gibbs_state(self.inner_kernel._model) self._num_blocks = num_blocks self._proxy = proxy self._method = method @@ -519,11 +521,11 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): self._gibbs_sites = list(self._subsample_plate_sizes.keys()) if self._proxy is not None: rng_key, proxy_key, method_key = random.split(rng_key, 3) - proxy_fn, gibbs_init, gibbs_update = self._proxy(rng_key, - self.model, - model_args, - model_kwargs, - num_blocks=self._num_blocks) + proxy_fn, gibbs_init, self._gibbs_update = self._proxy(rng_key, + self.model, + model_args, + model_kwargs, + num_blocks=self._num_blocks) method = perturbed_method(method_key, self.model, model_args, model_kwargs, proxy_fn) self.inner_kernel._model = estimate_likelihood(self.inner_kernel._model, method) @@ -542,9 +544,9 @@ def sample(self, state, model_args, model_kwargs): model_kwargs = {} if model_kwargs is None else model_kwargs.copy() rng_key, rng_gibbs = random.split(state.rng_key) - def potential_fn(z_gibbs, z_hmc): + def potential_fn(z_gibbs, gibbs_state, z_hmc): return self.inner_kernel._potential_fn_gen( - *model_args, _gibbs_sites=z_gibbs, **model_kwargs)(z_hmc) + *model_args, _gibbs_sites=z_gibbs, _gibbs_state=gibbs_state, **model_kwargs)(z_hmc) z_gibbs = {k: v for k, v in state.z.items() if k not in state.hmc_state.z} z_gibbs_new, gibbs_state_new = self._gibbs_update(rng_key, z_gibbs, state.gibbs_state) @@ -559,9 +561,9 @@ def potential_fn(z_gibbs, z_hmc): # TODO (very low priority): move this to the above cond, only compute grad when accepting if self.inner_kernel._forward_mode_differentiation: - z_grad = jacfwd(partial(potential_fn, z_gibbs))(state.hmc_state.z) + z_grad = jacfwd(partial(potential_fn, z_gibbs, gibbs_state))(state.hmc_state.z) else: - z_grad = grad(partial(potential_fn, z_gibbs))(state.hmc_state.z) + z_grad = grad(partial(potential_fn, z_gibbs, gibbs_state))(state.hmc_state.z) hmc_state = state.hmc_state._replace(z_grad=z_grad, potential_energy=pe) model_kwargs["_gibbs_sites"] = z_gibbs @@ -585,7 +587,7 @@ def perturbed_method(rng_key, model, model_args, model_kwargs, proxy_fn): def estimator(likelihoods, params, gibbs_state): subsample_log_liks = defaultdict(float) - for (fn, value, name, subsample_dim, subsample_idx) in likelihoods.values(): + for (fn, value, name, subsample_dim) in likelihoods.values(): subsample_log_liks[name] += _sum_all_except_at_dim(fn.log_prob(value), subsample_dim) log_lik_sum = 0. @@ -629,13 +631,16 @@ def log_likelihood(params_flat, subsample_indices=None): substitute(substitute_fn=partial(_unconstrain_reparam, params)): model(*model_args, **model_kwargs) - log_lik = defaultdict(float) + log_lik = {} for site in tr.values(): if site["type"] == "sample" and site["is_observed"]: for frame in site["cond_indep_stack"]: - if frame.name in subsample_plate_sizes: + if frame.name in log_lik: log_lik[frame.name] += _sum_all_except_at_dim( site["fn"].log_prob(site["value"]), frame.dim) + else: + log_lik[frame.name] = _sum_all_except_at_dim( + site["fn"].log_prob(site["value"]), frame.dim) return log_lik def log_likelihood_sum(params_flat, subsample_indices=None): From e4bf26338a854c81e90997edd868c629b8ebef5d Mon Sep 17 00:00:00 2001 From: ola Date: Tue, 2 Feb 2021 01:13:42 +0100 Subject: [PATCH 59/93] Updated variational proxy to new API. --- examples/hmcecs/logistic_regression.py | 2 +- numpyro/infer/hmc_gibbs.py | 88 +++++++++++++++++--------- 2 files changed, 58 insertions(+), 32 deletions(-) diff --git a/examples/hmcecs/logistic_regression.py b/examples/hmcecs/logistic_regression.py index 5c18b8f79..e30ffaef4 100644 --- a/examples/hmcecs/logistic_regression.py +++ b/examples/hmcecs/logistic_regression.py @@ -78,7 +78,7 @@ def guide(feature, obs, subsample_size): numpyro.sample('theta', dist.continuous.Normal(mean, .5)) -def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='taylor'): +def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='variational'): model_args, model_kwargs = (data, obs, subsample_size), {} svi_key, proxy_key, estimator_key, mcmc_key = random.split(random.PRNGKey(0), 4) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index b39fba6c6..dc7371b8f 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -407,18 +407,43 @@ def _block_update(plate_sizes, num_blocks, rng_key, gibbs_sites, gibbs_state): block_size = (subsample_size - 1) // num_blocks + 1 pad = block_size - (subsample_size - 1) % block_size - 1 - # subsample_size = 7, num_blocks=3, block_size=3: 3 + 3 + 1 chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks) new_idx = random.randint(subkey, minval=0, maxval=size, shape=(block_size,)) - subsample_idx_padded = jnp.pad(subsample_idx, (0, pad)) # size = 9 + subsample_idx_padded = jnp.pad(subsample_idx, (0, pad)) start = chosen_block * block_size subsample_idx_padded = lax.dynamic_update_slice_in_dim( subsample_idx_padded, new_idx, start, 0) - u_new[name] = subsample_idx_padded[:subsample_size] # size = 7 + u_new[name] = subsample_idx_padded[:subsample_size] return u_new, gibbs_state +def _block_update_proxy(num_blocks, rng_key, gibbs_sites, subsample_plate_sizes): + u_new = {} + pads = {} + new_idxs = {} + starts = {} + for name, subsample_idx in gibbs_sites.items(): + # TODO: merge with _block_update + size, subsample_size = subsample_plate_sizes[name] + rng_key, subkey, block_key = random.split(rng_key, 3) + block_size = (subsample_size - 1) // num_blocks + 1 + pad = block_size - (subsample_size - 1) % block_size - 1 + + chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks) + new_idx = random.randint(subkey, minval=0, maxval=size, shape=(block_size,)) + subsample_idx_padded = jnp.pad(subsample_idx, (0, pad)) + start = chosen_block * block_size + subsample_idx_padded = lax.dynamic_update_slice_in_dim( + subsample_idx_padded, new_idx, start, 0) + + u_new[name] = subsample_idx_padded[:subsample_size] + pads[name] = pad + new_idxs[name] = new_idx + starts[name] = start + return u_new, pads, new_idxs, starts + + HMCECSState = namedtuple("HMCECSState", "z, hmc_state, rng_key, gibbs_state, accept_prob") TaylorProxyState = namedtuple("TaylorProxyState", "ref_subsample_log_liks, " "ref_subsample_log_lik_grads, ref_subsample_log_lik_hessians") # TODO: rename to shorter names? @@ -658,29 +683,9 @@ def gibbs_init(rng_key, gibbs_sites): return TaylorProxyState(ref_subsample_log_liks, ref_subsample_log_lik_grads, ref_subsample_log_lik_hessians) def gibbs_update(rng_key, gibbs_sites, gibbs_state): - u_new = {} - pads = {} - new_idxs = {} - starts = {} - new_states = defaultdict(dict) - for name, subsample_idx in gibbs_sites.items(): - size, subsample_size = subsample_plate_sizes[name] - rng_key, subkey, block_key = random.split(rng_key, 3) - block_size = (subsample_size - 1) // num_blocks + 1 - pad = block_size - (subsample_size - 1) % block_size - 1 - - chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks) - new_idx = random.randint(subkey, minval=0, maxval=size, shape=(block_size,)) - subsample_idx_padded = jnp.pad(subsample_idx, (0, pad)) - start = chosen_block * block_size - subsample_idx_padded = lax.dynamic_update_slice_in_dim( - subsample_idx_padded, new_idx, start, 0) - - u_new[name] = subsample_idx_padded[:subsample_size] - pads[name] = pad - new_idxs[name] = new_idx - starts[name] = start + u_new, pads, new_idxs, starts = _block_update_proxy(num_blocks, rng_key, gibbs_sites, subsample_plate_sizes) + new_states = defaultdict(dict) ref_subsample_log_liks = log_likelihood(ref_params_flat, new_idxs) ref_subsample_log_lik_grads = jacfwd(log_likelihood)(ref_params_flat, new_idxs) ref_subsample_log_lik_hessians = jacfwd(jacfwd(log_likelihood))(ref_params_flat, new_idxs) @@ -795,19 +800,40 @@ def log_prior(params): evidence = {name: (log_posterior_prob - log_prior_prob - log_like.sum()) / num_samples for name, log_like in log_likelihood_ref.items()} - def proxy_fn(params, subsample_indices, gibbs_state): - params = {**params, **subsample_indices} + def gibbs_init(rng_key, gibbs_sites): + return VariationalProxyState({name: weights[subsample] for name, subsample in gibbs_sites.items()}) + + def gibbs_update(rng_key, gibbs_sites, gibbs_state): + u_new, pads, new_idxs, starts = _block_update_proxy(num_blocks, rng_key, gibbs_sites, subsample_plate_sizes) + + new_subsample_weights = {} + for name, subsample_weights in gibbs_sites.subsample_weights.items(): + size, subsample_size = subsample_plate_sizes[name] # TODO: fix doublication! + pad, new_idx, start = pads[name], new_idxs[name], starts[name] + new_value = jnp.pad(subsample_weights[name], + [(0, pad)] + [(0, 0)] * (jnp.ndim(subsample_weights[name]) - 1)) + new_value = lax.dynamic_update_slice_in_dim(new_value, weights[name][new_idx], start, 0) + new_subsample_weights[name] = new_value[:subsample_size] + gibbs_state = VariationalProxyState(new_subsample_weights) + return u_new, gibbs_state + + dummy_subsample = {k: jnp.arange(v[1]) for k, v in subsample_plate_sizes.items()} + + def proxy_fn(params, subsample_lik_sites, gibbs_state): + + params = {**params, **dummy_subsample} proxy_sum = defaultdict(float) proxy_subsample = defaultdict(float) log_prior_prob = log_prior(params) log_posterior_prob = log_posterior(params) - for name, subsample_idx in subsample_indices.items(): + + for name in subsample_lik_sites: proxy_sum[name] = evidence[name] + log_posterior_prob - log_prior_prob - proxy_subsample[name] = evidence[name] + \ - weights[name][subsample_idx].sum() * (log_posterior_prob - log_prior_prob) + proxy_subsample[name] = evidence[name] + gibbs_state.subsample_weights.sum() * ( + log_posterior_prob - log_prior_prob) return proxy_sum, proxy_subsample - return proxy_fn + return proxy_fn, gibbs_init, gibbs_update return construct_proxy_fn From 508e96a73fb147169c87a8f15211ca6b1c485e71 Mon Sep 17 00:00:00 2001 From: ola Date: Tue, 2 Feb 2021 01:34:59 +0100 Subject: [PATCH 60/93] Variational proxy running on breast cancer! --- examples/hmcecs/logistic_regression.py | 2 +- numpyro/infer/hmc_gibbs.py | 17 ++++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/hmcecs/logistic_regression.py b/examples/hmcecs/logistic_regression.py index e30ffaef4..fb114d6ca 100644 --- a/examples/hmcecs/logistic_regression.py +++ b/examples/hmcecs/logistic_regression.py @@ -78,7 +78,7 @@ def guide(feature, obs, subsample_size): numpyro.sample('theta', dist.continuous.Normal(mean, .5)) -def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='variational'): +def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='vari'): model_args, model_kwargs = (data, obs, subsample_size), {} svi_key, proxy_key, estimator_key, mcmc_key = random.split(random.PRNGKey(0), 4) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index dc7371b8f..c0dc5c888 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -700,6 +700,7 @@ def gibbs_update(rng_key, gibbs_sites, gibbs_state): for name, subsample_idx in gibbs_sites.items(): size, subsample_size = subsample_plate_sizes[name] pad, new_idx, start = pads[name], new_idxs[name], starts[name] + print(last_values, type(last_values)) new_value = jnp.pad(last_values[name], [(0, pad)] + [(0, 0)] * (jnp.ndim(last_values[name]) - 1)) new_value = lax.dynamic_update_slice_in_dim( new_value, new_block_values[name], start, 0) @@ -774,7 +775,8 @@ def log_likelihood(params, subsample_indices=None): def log_posterior(params): with numpyro.primitives.inner_stack(): - posterior_prob, _ = log_density(guide_with_params, model_args, model_kwargs, params) + guide_kwargs = {k: v for k, v in model_kwargs.items() if k != '_gibbs_state'} + posterior_prob, _ = log_density(guide_with_params, model_args, guide_kwargs, params) return posterior_prob def log_prior(params): @@ -801,17 +803,18 @@ def log_prior(params): for name, log_like in log_likelihood_ref.items()} def gibbs_init(rng_key, gibbs_sites): - return VariationalProxyState({name: weights[subsample] for name, subsample in gibbs_sites.items()}) + return VariationalProxyState( + {name: weights[name][subsample_idx] for name, subsample_idx in gibbs_sites.items()}) def gibbs_update(rng_key, gibbs_sites, gibbs_state): u_new, pads, new_idxs, starts = _block_update_proxy(num_blocks, rng_key, gibbs_sites, subsample_plate_sizes) new_subsample_weights = {} - for name, subsample_weights in gibbs_sites.subsample_weights.items(): - size, subsample_size = subsample_plate_sizes[name] # TODO: fix doublication! + for name, subsample_weights in gibbs_state.subsample_weights.items(): + size, subsample_size = subsample_plate_sizes[name] # TODO: fix duplication! pad, new_idx, start = pads[name], new_idxs[name], starts[name] - new_value = jnp.pad(subsample_weights[name], - [(0, pad)] + [(0, 0)] * (jnp.ndim(subsample_weights[name]) - 1)) + new_value = jnp.pad(subsample_weights, + [(0, pad)] + [(0, 0)] * (jnp.ndim(subsample_weights) - 1)) new_value = lax.dynamic_update_slice_in_dim(new_value, weights[name][new_idx], start, 0) new_subsample_weights[name] = new_value[:subsample_size] gibbs_state = VariationalProxyState(new_subsample_weights) @@ -829,7 +832,7 @@ def proxy_fn(params, subsample_lik_sites, gibbs_state): for name in subsample_lik_sites: proxy_sum[name] = evidence[name] + log_posterior_prob - log_prior_prob - proxy_subsample[name] = evidence[name] + gibbs_state.subsample_weights.sum() * ( + proxy_subsample[name] = evidence[name] + gibbs_state.subsample_weights[name].sum() * ( log_posterior_prob - log_prior_prob) return proxy_sum, proxy_subsample From 37619058af23a798bcf6300345b4b93cc7a2f24f Mon Sep 17 00:00:00 2001 From: ola Date: Tue, 2 Feb 2021 10:34:17 +0100 Subject: [PATCH 61/93] Working regression --- examples/hmcecs/regression.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/examples/hmcecs/regression.py b/examples/hmcecs/regression.py index b9c1770f8..27ff82edb 100644 --- a/examples/hmcecs/regression.py +++ b/examples/hmcecs/regression.py @@ -87,8 +87,8 @@ def protein(): class Network(nn.Module): def apply(self, x, out_channels): - l1 = relu(nn.Dense(x, features=100)) - l2 = relu(nn.Dense(l1, features=100)) + l1 = tanh(nn.Dense(x, features=100)) + l2 = tanh(nn.Dense(l1, features=100)) means = nn.Dense(l2, features=out_channels) return means @@ -100,19 +100,16 @@ def nonlin(x): def model(data, obs=None): module = Network.partial(out_channels=1) - net = random_flax_module('fnn', module, dist.Normal(0, 2.), input_shape=data.shape[1]) + net = random_flax_module('fnn', module, dist.Normal(0, 1.), input_shape=data.shape[1]) - if obs is not None: - obs = obs[..., None] - - prec_obs = numpyro.sample("prec_obs", dist.Normal(110.4, .1)) + prec_obs = numpyro.sample("prec_obs", dist.LogNormal(jnp.log(110.4), .0001)) sigma_obs = 1.0 / jnp.sqrt(prec_obs) # prior - numpyro.sample('obs', dist.Normal(net(data), sigma_obs), obs=obs) + numpyro.sample('obs', dist.Normal(net(data), 1 / jnp.sqrt(110.4)), obs=obs) def hmc(dataset, data, obs, warmup, num_sample): - kernel = NUTS(model, max_tree_depth=4, step_size=.0005, init_strategy=init_to_sample) + kernel = NUTS(model, max_tree_depth=5, step_size=.0005, init_strategy=init_to_sample) mcmc = MCMC(kernel, warmup, num_sample) mcmc.run(random.PRNGKey(37), data, obs, extra_fields=('num_steps',)) print(mcmc.print_summary()) @@ -129,8 +126,8 @@ def predict(model, rng_key, samples, *args, **kwargs): def main(): data, obs = load_agw_1d() - warmup = 20 - num_samples = 10 + warmup = 200 + num_samples = 1000 test_data = np.linspace(-2, 2, 500).reshape(-1, 1) samples = hmc('protein', data, obs, warmup, num_samples) vmap_args = (samples, random.split(random.PRNGKey(1), num_samples)) From a6950463495eb3d16598440047129571e717f0cf Mon Sep 17 00:00:00 2001 From: ola Date: Tue, 2 Feb 2021 16:07:56 +0100 Subject: [PATCH 62/93] Fixed problems in variational; todo rethink dummy_sample ([] doesn't work). Rethink post(theta), prior(theta). --- examples/hmcecs/logistic_regression.py | 62 +++++++++++++++----------- examples/hmcecs/regression.py | 26 +++++------ numpyro/infer/hmc_gibbs.py | 15 +++---- numpyro/primitives.py | 7 +-- 4 files changed, 59 insertions(+), 51 deletions(-) diff --git a/examples/hmcecs/logistic_regression.py b/examples/hmcecs/logistic_regression.py index fb114d6ca..459c44fa6 100644 --- a/examples/hmcecs/logistic_regression.py +++ b/examples/hmcecs/logistic_regression.py @@ -14,27 +14,24 @@ import numpyro.distributions as dist from numpyro.distributions import constraints from numpyro.examples.datasets import _load_higgs -from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_median +from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_median, init_to_value, HMC from numpyro.infer.hmc_gibbs import HMCECS, perturbed_method, variational_proxy, taylor_proxy from numpyro.infer.util import _predictive os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" -numpyro.set_platform("gpu") - def summary(dataset, name, mcmc, sample_time, svi_time=0., plates={}): n_eff_mean = np.mean([numpyro.diagnostics.effective_sample_size(device_get(v)) for k, v in mcmc.get_samples(True).items() if k not in plates]) pickle.dump(mcmc.get_samples(True), open(f'{dataset}/{name}_posterior_samples.pkl', 'wb')) - step_field = 'num_steps' if name == 'hmc' else 'hmc_state.num_steps' + step_field = 'num_steps' if name in ['hmc', 'nuts'] else 'hmc_state.num_steps' num_step = np.sum(mcmc.get_extra_fields()[step_field]) - accpt_prob = 1. + accpt_prob = np.mean(mcmc.get_extra_fields()['accept_prob']) if 'ecs' in name else 1. with open(f'{dataset}/{name}_chain_stats.txt', 'w') as f: print('sample_time', 'svi_time', 'n_eff_mean', 'gibbs_accpt_prob', 'tot_num_steps', 'time_per_step', - 'time_per_eff', - sep=',', file=f) + 'time_per_eff', sep=',', file=f) print(sample_time, svi_time, n_eff_mean, accpt_prob, num_step, sample_time / num_step, sample_time / n_eff_mean, sep=',', file=f) @@ -91,11 +88,11 @@ def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='vari'): pickle.dump(svi_result.params, open(f'{dataset}/svi_params.pkl', 'wb')) params = svi_result.params + proxy_key, ref_key = random.split(proxy_key) + ref_params = _predictive(ref_key, guide, {}, (1,), return_sites='', parallel=True, + model_args=model_args, model_kwargs=model_kwargs) + ref_params.pop('mean') if proxy_name == 'taylor': - proxy_key, ref_key = random.split(proxy_key) - ref_params = _predictive(ref_key, guide, {}, (1,), return_sites='', parallel=True, - model_args=model_args, model_kwargs=model_kwargs) - ref_params.pop('mean') proxy_fn = taylor_proxy(ref_params) else: @@ -105,9 +102,10 @@ def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='vari'): kernel = HMCECS(NUTS(model), proxy=proxy_fn) mcmc = MCMC(kernel, 1000, 1000) start = time() - mcmc.run(random.PRNGKey(3), data, obs, subsample_size, extra_fields=("hmc_state.accept_prob", + mcmc.run(random.PRNGKey(3), data, obs, subsample_size, extra_fields=("accept_prob", "hmc_state.num_steps")) - summary(dataset, 'ecs', mcmc, time() - start, svi_time=svi_time, plates={'N': ''}) + summary(dataset, f'ecs_{proxy_name}', mcmc, time() - start, svi_time=svi_time, plates={'N': ''}) + return ref_params def plain_log_reg_model(features, obs): @@ -116,9 +114,18 @@ def plain_log_reg_model(features, obs): numpyro.sample('obs', dist.Bernoulli(logits=theta @ features.T), obs=obs) -def hmc(dataset, data, obs): - kernel = NUTS(plain_log_reg_model, trajectory_length=1.2, init_strategy=init_to_median) - mcmc = MCMC(kernel, 100, 100) +def nuts(dataset, data, obs, ref_param): + kernel = NUTS(plain_log_reg_model, trajectory_length=1.2, init_strategy=init_to_value(values=ref_param)) + mcmc = MCMC(kernel, 1000, 1000) + mcmc._compile(random.PRNGKey(0), data, obs, extra_fields=("num_steps",)) + start = time() + mcmc.run(random.PRNGKey(0), data, obs, extra_fields=('num_steps',)) + summary(dataset, 'nuts', mcmc, time() - start) + + +def hmc(dataset, data, obs, ref_param): + kernel = HMC(plain_log_reg_model, trajectory_length=1.2, init_strategy=init_to_value(values=ref_param)) + mcmc = MCMC(kernel, 1000, 1000) mcmc._compile(random.PRNGKey(0), data, obs, extra_fields=("num_steps",)) start = time() mcmc.run(random.PRNGKey(0), data, obs, extra_fields=('num_steps',)) @@ -127,15 +134,18 @@ def hmc(dataset, data, obs): if __name__ == '__main__': - load_data = {'breast': breast_cancer_data, 'higgs': higgs_data, 'copsac': copsac_data} - subsample_sizes = {'higgs': 1300, 'copsac': 1000, 'breast': 75, } + load_data = {'higgs': higgs_data} # 'breast': breast_cancer_data , 'copsac': copsac_data} + subsample_sizes = {'higgs': 1300, 'breast': 75, } # 'copsac': 1000, data, obs = breast_cancer_data() - for dataset in load_data.keys(): - dir = f'{dataset}_{datetime.now().strftime("%Y_%m_%d_%H%M%S")}' - if not os.path.exists(dir): - os.mkdir(dir) - data, obs = load_data[dataset]() - hmcecs_model(dir, data, obs, subsample_sizes[dataset]) - # hmc(dir, data, obs) - exit() + for platform in ['gpu', 'cpu']: + numpyro.set_platform(platform) + for dataset in load_data.keys(): + dir = f'{platform}_{dataset}_{datetime.now().strftime("%Y_%m_%d_%H%M%S")}' + if not os.path.exists(dir): + os.mkdir(dir) + data, obs = load_data[dataset]() + ref_param = hmcecs_model(dir, data, obs, subsample_sizes[dataset], proxy_name='taylor') + ref_param = hmcecs_model(dir, data, obs, subsample_sizes[dataset], proxy_name='variational') + hmc(dir, data, obs, ref_param) + nuts(dir, data, obs, ref_param) diff --git a/examples/hmcecs/regression.py b/examples/hmcecs/regression.py index 27ff82edb..41250ea55 100644 --- a/examples/hmcecs/regression.py +++ b/examples/hmcecs/regression.py @@ -16,32 +16,30 @@ uci_base_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/' +numpyro.set_platform("gpu") + def visualize(train_data, train_obs, test_data, predictions): - fs = 16 + fs = 14 m = predictions.mean(0) - s = predictions.std(0) - # s_al = (pred_list[200:].var(0).to('cpu') + tau_out ** -1) ** 0.5 + percentiles = np.percentile(predictions, [2.5, 97.5], axis=0) f, ax = plt.subplots(1, 1, figsize=(8, 4)) # Get upper and lower confidence bounds - lower, upper = (m - s * 2).flatten(), (m + s * 2).flatten() - # + aleotoric - # lower_al, upper_al = (m - s_al*2).flatten(), (m + s_al*2).flatten() + lower, upper = (percentiles[0, :]).flatten(), (percentiles[1, :]).flatten() # Plot training data as black stars - ax.plot(train_data, train_obs, 'k*', rasterized=True) + ax.plot(train_data, train_obs, 'x', marker='x', color='forestgreen', rasterized=True, label='Observed Data') # Plot predictive means as blue line - ax.plot(test_data, m, 'b', rasterized=True) + ax.plot(test_data, m, 'b', rasterized=True, label="Mean Prediction") # Shade between the lower and upper confidence bounds - ax.fill_between(test_data, lower, upper, alpha=0.5, rasterized=True) - # ax.fill_between(X_test.flatten().numpy(), lower_al.numpy(), upper_al.numpy(), alpha=0.2, rasterized=True) - ax.set_ylim([-2, 2]) + ax.fill_between(test_data, lower, upper, alpha=0.5, rasterized=True, label='95% C.I.') + ax.set_ylim([-2.5, 2.5]) ax.set_xlim([-2, 2]) plt.grid() - ax.legend(['Observed Data', 'Mean', 'Epistemic'], fontsize=fs) + ax.legend(fontsize=fs) ax.tick_params(axis='both', which='major', labelsize=14) ax.tick_params(axis='both', which='minor', labelsize=14) @@ -105,14 +103,14 @@ def model(data, obs=None): prec_obs = numpyro.sample("prec_obs", dist.LogNormal(jnp.log(110.4), .0001)) sigma_obs = 1.0 / jnp.sqrt(prec_obs) # prior - numpyro.sample('obs', dist.Normal(net(data), 1 / jnp.sqrt(110.4)), obs=obs) + numpyro.sample('obs', dist.Normal(net(data), sigma_obs), obs=obs) def hmc(dataset, data, obs, warmup, num_sample): kernel = NUTS(model, max_tree_depth=5, step_size=.0005, init_strategy=init_to_sample) mcmc = MCMC(kernel, warmup, num_sample) mcmc.run(random.PRNGKey(37), data, obs, extra_fields=('num_steps',)) - print(mcmc.print_summary()) + mcmc.print_summary() return mcmc.get_samples() diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index c0dc5c888..da54849d1 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -700,7 +700,6 @@ def gibbs_update(rng_key, gibbs_sites, gibbs_state): for name, subsample_idx in gibbs_sites.items(): size, subsample_size = subsample_plate_sizes[name] pad, new_idx, start = pads[name], new_idxs[name], starts[name] - print(last_values, type(last_values)) new_value = jnp.pad(last_values[name], [(0, pad)] + [(0, 0)] * (jnp.ndim(last_values[name]) - 1)) new_value = lax.dynamic_update_slice_in_dim( new_value, new_block_values[name], start, 0) @@ -774,16 +773,16 @@ def log_likelihood(params, subsample_indices=None): return log_lik def log_posterior(params): - with numpyro.primitives.inner_stack(): + with block(): guide_kwargs = {k: v for k, v in model_kwargs.items() if k != '_gibbs_state'} posterior_prob, _ = log_density(guide_with_params, model_args, guide_kwargs, params) return posterior_prob def log_prior(params): - with numpyro.primitives.inner_stack(): + with block(): prior_prob, _ = log_density( - block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), - model_args, model_kwargs, params) + block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), model_args, + model_kwargs, params) return prior_prob posterior_samples = _predictive(pos_key, guide_with_params, {}, (num_samples,), return_sites='', parallel=True, @@ -831,9 +830,9 @@ def proxy_fn(params, subsample_lik_sites, gibbs_state): log_posterior_prob = log_posterior(params) for name in subsample_lik_sites: - proxy_sum[name] = evidence[name] + log_posterior_prob - log_prior_prob - proxy_subsample[name] = evidence[name] + gibbs_state.subsample_weights[name].sum() * ( - log_posterior_prob - log_prior_prob) + proxy_sum[name] = log_posterior_prob - log_prior_prob - evidence[name] + proxy_subsample[name] = gibbs_state.subsample_weights[name].sum() * ( + log_posterior_prob - log_prior_prob) - evidence[name] return proxy_sum, proxy_subsample return proxy_fn, gibbs_init, gibbs_update diff --git a/numpyro/primitives.py b/numpyro/primitives.py index c77fe150e..2104b73dc 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -26,6 +26,7 @@ def inner_stack(): yield _PYRO_STACK = current_stack + def apply_stack(msg): pointer = 0 for pointer, handler in enumerate(reversed(_PYRO_STACK)): @@ -251,7 +252,7 @@ def body_fn(val, idx): i_p1 = size - idx i = i_p1 - 1 j = random.randint(rng_keys[idx], (), 0, i_p1) - val = ops.index_update(val, ops.index[[i, j], ], val[ops.index[[j, i], ]]) + val = ops.index_update(val, ops.index[[i, j],], val[ops.index[[j, i],]]) return val, None val, _ = lax.scan(body_fn, jnp.arange(size), jnp.arange(subsample_size)) @@ -316,7 +317,7 @@ def _subsample(name, size, subsample_size, dim): if subsample_size is not None and subsample_size != subsample.shape[0]: warnings.warn("subsample_size does not match len(subsample), {} vs {}.".format( subsample_size, len(subsample)) + - " Did you accidentally use different subsample_size in the model and guide?") + " Did you accidentally use different subsample_size in the model and guide?") cond_indep_stack = msg['cond_indep_stack'] occupied_dims = {f.dim for f in cond_indep_stack} if dim is None: @@ -382,7 +383,7 @@ def postprocess_message(self, msg): raise ValueError( "Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}" .format(self.name, self.size, self.dim, statement, shape)) - if self.subsample_size < self.size: + elif self.subsample_size < self.size: value = msg["value"] new_value = jnp.take(value, self._indices, dim) msg["value"] = new_value From 896cd197035052e53a6c3ea6ca8f68b4776e1513 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Tue, 2 Feb 2021 11:52:02 -0600 Subject: [PATCH 63/93] add covtype example --- examples/covtype.py | 90 ++++++++++++++++++++++++++++++++------ numpyro/infer/hmc_gibbs.py | 23 +++++----- 2 files changed, 88 insertions(+), 25 deletions(-) diff --git a/examples/covtype.py b/examples/covtype.py index 0778bb7fe..b7c5a36f3 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -4,13 +4,18 @@ import argparse import time +import matplotlib.pyplot as plt + from jax import random import jax.numpy as jnp import numpyro import numpyro.distributions as dist from numpyro.examples.datasets import COVTYPE, load_dataset -from numpyro.infer import MCMC, NUTS +from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SVI, Trace_ELBO, init_to_value +from numpyro.infer.autoguide import AutoBNAFNormal +from numpyro.infer.hmc_gibbs import taylor_proxy +from numpyro.infer.reparam import NeuTraReparam def _load_dataset(): @@ -30,25 +35,76 @@ def _load_dataset(): print("Data shape:", features.shape) print("Label distribution: {} has label 1, {} has label 0" .format(labels.sum(), N - labels.sum())) - return features, labels + return features[::5], labels[::5] -def model(data, labels): +def model(data, labels, subsample_size=None): dim = data.shape[1] coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim))) - logits = jnp.dot(data, coefs) - return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels) + with numpyro.plate("N", data.shape[0], subsample_size=subsample_size) as idx: + logits = jnp.dot(data[idx], coefs) + return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels[idx]) def benchmark_hmc(args, features, labels): - step_size = jnp.sqrt(0.5 / features.shape[0]) - trajectory_length = step_size * args.num_steps rng_key = random.PRNGKey(1) start = time.time() - kernel = NUTS(model, trajectory_length=trajectory_length) - mcmc = MCMC(kernel, 0, args.num_samples) - mcmc.run(rng_key, features, labels) - mcmc.print_summary() + if args.algo == "HMC": + step_size = jnp.sqrt(0.5 / features.shape[0]) + trajectory_length = step_size * args.num_steps + kernel = HMC(model, step_size=step_size, trajectory_length=trajectory_length, adapt_step_size=False, + dense_mass=args.dense_mass) + subsample_size = None + elif args.algo == "NUTS": + kernel = NUTS(model, dense_mass=args.dense_mass) + subsample_size = None + elif args.algo == "HMCECS": + # a MAP estimate at the following source + # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117 + ref_params = {"coefs": jnp.array([ + +2.03420663e+00, -3.53567265e-02, -1.49223924e-01, -3.07049364e-01, + -1.00028366e-01, -1.46827862e-01, -1.64167881e-01, -4.20344204e-01, + +9.47479829e-02, -1.12681836e-02, +2.64442056e-01, -1.22087866e-01, + -6.00568838e-02, -3.79419506e-01, -1.06668741e-01, -2.97053963e-01, + -2.05253899e-01, -4.69537191e-02, -2.78072730e-02, -1.43250525e-01, + -6.77954629e-02, -4.34899796e-03, +5.90927452e-02, +7.23133609e-02, + +1.38526391e-02, -1.24497898e-01, -1.50733739e-02, -2.68872194e-02, + -1.80925727e-02, +3.47936489e-02, +4.03552800e-02, -9.98773426e-03, + +6.20188080e-02, +1.15002751e-01, +1.32145107e-01, +2.69109547e-01, + +2.45785132e-01, +1.19035013e-01, -2.59744357e-02, +9.94279515e-04, + +3.39266285e-02, -1.44057125e-02, -6.95222765e-02, -7.52013028e-02, + +1.21171586e-01, +2.29205526e-02, +1.47308692e-01, -8.34354162e-02, + -9.34122875e-02, -2.97472421e-02, -3.03937674e-01, -1.70958012e-01, + -1.59496680e-01, -1.88516974e-01, -1.20889175e+00])} + subsample_size = 1000 + inner_kernel = NUTS(model, init_strategy=init_to_value(values=ref_params), + dense_mass=args.dense_mass) + # note: if num_blocks=100, we'll update 10 index at each MCMC step + # so it took 50000 MCMC steps to iterative the whole dataset + kernel = HMCECS(inner_kernel, num_blocks=100, proxy=taylor_proxy(ref_params)) + elif args.algo == "FlowHMCECS": + subsample_size = 1000 + guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8]) + svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO()) + params, losses = svi.run(random.PRNGKey(2), 2000, features, labels) + plt.plot(losses) + + neutra = NeuTraReparam(guide, params) + neutra_model = neutra.reparam(model) + neutra_ref_params = {"auto_shared_latent": jnp.zeros(55)} + # no need to adapt mass matrix if the flow does a good job + inner_kernel = NUTS(neutra_model, init_strategy=init_to_value(values=neutra_ref_params), + adapt_mass_matrix=False) + kernel = HMCECS(inner_kernel, num_blocks=100, proxy=taylor_proxy(neutra_ref_params)) + else: + raise ValueError("Invalid algorithm, either 'HMC', 'NUTS', or 'HMCECS'.") + mcmc = MCMC(kernel, args.num_warmup, args.num_samples) + from numpyro.util import control_flow_prims_disabled + from jax import disable_jit + with disable_jit(), control_flow_prims_disabled(): + mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",)) + print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"])) + mcmc.print_summary(exclude_deterministic=False) print('\nMCMC elapsed time:', time.time() - start) @@ -60,14 +116,20 @@ def main(args): if __name__ == '__main__': assert numpyro.__version__.startswith('0.5.0') parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-n', '--num-samples', default=100, type=int, help='number of samples') + parser.add_argument('-n', '--num-samples', default=1000, type=int, help='number of samples') + parser.add_argument('--num-warmup', default=1000, type=int, help='number of warmup steps') parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")') parser.add_argument('--num-chains', nargs='?', default=1, type=int) - parser.add_argument('--algo', default='NUTS', type=str, help='whether to run "HMC" or "NUTS"') - parser.add_argument('--device', default='gpu', type=str, help='use "cpu" or "gpu".') + parser.add_argument('--algo', default='HMCECS', type=str, + help='whether to run "HMC", "NUTS", "HMCECS", or "FlowHMCECS"') + parser.add_argument('--dense-mass', action="store_true") + parser.add_argument('--x64', action="store_true") + parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".') args = parser.parse_args() numpyro.set_platform(args.device) numpyro.set_host_device_count(args.num_chains) + if args.x64: + numpyro.enable_x64() main(args) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index da54849d1..1ec8c51a4 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -445,8 +445,9 @@ def _block_update_proxy(num_blocks, rng_key, gibbs_sites, subsample_plate_sizes) HMCECSState = namedtuple("HMCECSState", "z, hmc_state, rng_key, gibbs_state, accept_prob") +# TODO: rename to shorter names? TaylorProxyState = namedtuple("TaylorProxyState", "ref_subsample_log_liks, " - "ref_subsample_log_lik_grads, ref_subsample_log_lik_hessians") # TODO: rename to shorter names? + "ref_subsample_log_lik_grads, ref_subsample_log_lik_hessians") VariationalProxyState = namedtuple('VariationalProxyState', 'subsample_weights') BlockPoissonEstState = namedtuple("BlockPoissonEstState", "block_rng_keys, sign") @@ -699,7 +700,7 @@ def gibbs_update(rng_key, gibbs_sites, gibbs_state): gibbs_state.ref_subsample_log_lik_hessians]): for name, subsample_idx in gibbs_sites.items(): size, subsample_size = subsample_plate_sizes[name] - pad, new_idx, start = pads[name], new_idxs[name], starts[name] + pad, start = pads[name], starts[name] new_value = jnp.pad(last_values[name], [(0, pad)] + [(0, 0)] * (jnp.ndim(last_values[name]) - 1)) new_value = lax.dynamic_update_slice_in_dim( new_value, new_block_values[name], start, 0) @@ -718,15 +719,15 @@ def proxy_fn(params, subsample_lik_sites, gibbs_state): proxy_sum = defaultdict(float) proxy_subsample = defaultdict(float) for name in subsample_lik_sites: - proxy_subsample[name] = ref_subsample_log_liks[name] + \ - jnp.dot(ref_subsample_log_lik_grads[name], params_diff) + \ - 0.5 * jnp.dot(jnp.dot(ref_subsample_log_lik_hessians[name], params_diff), - params_diff) - - proxy_sum[name] = ref_log_likelihoods_sum[name] + \ - jnp.dot(ref_log_likelihood_grads_sum[name], params_diff) + \ - 0.5 * jnp.dot(jnp.dot(ref_log_likelihood_hessians_sum[name], params_diff), - params_diff) + proxy_subsample[name] = (ref_subsample_log_liks[name] + + jnp.dot(ref_subsample_log_lik_grads[name], params_diff) + + 0.5 * jnp.dot(jnp.dot(ref_subsample_log_lik_hessians[name], params_diff), + params_diff)) + + proxy_sum[name] = (ref_log_likelihoods_sum[name] + + jnp.dot(ref_log_likelihood_grads_sum[name], params_diff) + + 0.5 * jnp.dot(jnp.dot(ref_log_likelihood_hessians_sum[name], params_diff), + params_diff)) return proxy_sum, proxy_subsample return proxy_fn, gibbs_init, gibbs_update From f5e8894b9d0d6e86f28eb336c0fb7be626e5d33e Mon Sep 17 00:00:00 2001 From: Du Phan Date: Tue, 2 Feb 2021 16:48:00 -0600 Subject: [PATCH 64/93] fix some bugs to substitute empty subsample indices and add some FIXME --- examples/covtype.py | 55 +++++++++++++++----------- examples/hmcecs/logistic_regression.py | 4 +- numpyro/infer/autoguide.py | 15 +++---- numpyro/infer/hmc_gibbs.py | 48 +++++++++++----------- numpyro/infer/util.py | 3 +- numpyro/primitives.py | 6 +-- 6 files changed, 69 insertions(+), 62 deletions(-) diff --git a/examples/covtype.py b/examples/covtype.py index b7c5a36f3..c89db50cc 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -13,8 +13,8 @@ import numpyro.distributions as dist from numpyro.examples.datasets import COVTYPE, load_dataset from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SVI, Trace_ELBO, init_to_value -from numpyro.infer.autoguide import AutoBNAFNormal -from numpyro.infer.hmc_gibbs import taylor_proxy +from numpyro.infer.autoguide import AutoBNAFNormal, AutoNormal +from numpyro.infer.hmc_gibbs import taylor_proxy, variational_proxy from numpyro.infer.reparam import NeuTraReparam @@ -49,6 +49,23 @@ def model(data, labels, subsample_size=None): def benchmark_hmc(args, features, labels): rng_key = random.PRNGKey(1) start = time.time() + # a MAP estimate at the following source + # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117 + ref_params = {"coefs": jnp.array([ + +2.03420663e+00, -3.53567265e-02, -1.49223924e-01, -3.07049364e-01, + -1.00028366e-01, -1.46827862e-01, -1.64167881e-01, -4.20344204e-01, + +9.47479829e-02, -1.12681836e-02, +2.64442056e-01, -1.22087866e-01, + -6.00568838e-02, -3.79419506e-01, -1.06668741e-01, -2.97053963e-01, + -2.05253899e-01, -4.69537191e-02, -2.78072730e-02, -1.43250525e-01, + -6.77954629e-02, -4.34899796e-03, +5.90927452e-02, +7.23133609e-02, + +1.38526391e-02, -1.24497898e-01, -1.50733739e-02, -2.68872194e-02, + -1.80925727e-02, +3.47936489e-02, +4.03552800e-02, -9.98773426e-03, + +6.20188080e-02, +1.15002751e-01, +1.32145107e-01, +2.69109547e-01, + +2.45785132e-01, +1.19035013e-01, -2.59744357e-02, +9.94279515e-04, + +3.39266285e-02, -1.44057125e-02, -6.95222765e-02, -7.52013028e-02, + +1.21171586e-01, +2.29205526e-02, +1.47308692e-01, -8.34354162e-02, + -9.34122875e-02, -2.97472421e-02, -3.03937674e-01, -1.70958012e-01, + -1.59496680e-01, -1.88516974e-01, -1.20889175e+00])} if args.algo == "HMC": step_size = jnp.sqrt(0.5 / features.shape[0]) trajectory_length = step_size * args.num_steps @@ -59,23 +76,6 @@ def benchmark_hmc(args, features, labels): kernel = NUTS(model, dense_mass=args.dense_mass) subsample_size = None elif args.algo == "HMCECS": - # a MAP estimate at the following source - # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117 - ref_params = {"coefs": jnp.array([ - +2.03420663e+00, -3.53567265e-02, -1.49223924e-01, -3.07049364e-01, - -1.00028366e-01, -1.46827862e-01, -1.64167881e-01, -4.20344204e-01, - +9.47479829e-02, -1.12681836e-02, +2.64442056e-01, -1.22087866e-01, - -6.00568838e-02, -3.79419506e-01, -1.06668741e-01, -2.97053963e-01, - -2.05253899e-01, -4.69537191e-02, -2.78072730e-02, -1.43250525e-01, - -6.77954629e-02, -4.34899796e-03, +5.90927452e-02, +7.23133609e-02, - +1.38526391e-02, -1.24497898e-01, -1.50733739e-02, -2.68872194e-02, - -1.80925727e-02, +3.47936489e-02, +4.03552800e-02, -9.98773426e-03, - +6.20188080e-02, +1.15002751e-01, +1.32145107e-01, +2.69109547e-01, - +2.45785132e-01, +1.19035013e-01, -2.59744357e-02, +9.94279515e-04, - +3.39266285e-02, -1.44057125e-02, -6.95222765e-02, -7.52013028e-02, - +1.21171586e-01, +2.29205526e-02, +1.47308692e-01, -8.34354162e-02, - -9.34122875e-02, -2.97472421e-02, -3.03937674e-01, -1.70958012e-01, - -1.59496680e-01, -1.88516974e-01, -1.20889175e+00])} subsample_size = 1000 inner_kernel = NUTS(model, init_strategy=init_to_value(values=ref_params), dense_mass=args.dense_mass) @@ -88,6 +88,7 @@ def benchmark_hmc(args, features, labels): svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO()) params, losses = svi.run(random.PRNGKey(2), 2000, features, labels) plt.plot(losses) + plt.show() neutra = NeuTraReparam(guide, params) neutra_model = neutra.reparam(model) @@ -96,13 +97,21 @@ def benchmark_hmc(args, features, labels): inner_kernel = NUTS(neutra_model, init_strategy=init_to_value(values=neutra_ref_params), adapt_mass_matrix=False) kernel = HMCECS(inner_kernel, num_blocks=100, proxy=taylor_proxy(neutra_ref_params)) + elif args.algo == "HMCVECS": + subsample_size = 1000 + guide = AutoNormal(model) + svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO()) + params, losses = svi.run(random.PRNGKey(2), 2000, features, labels) + plt.plot(losses) + plt.show() + + inner_kernel = NUTS(model, init_strategy=init_to_value(values=ref_params), + dense_mass=args.dense_mass) + kernel = HMCECS(inner_kernel, num_blocks=100, proxy=variational_proxy(guide, params, num_samples=100)) else: raise ValueError("Invalid algorithm, either 'HMC', 'NUTS', or 'HMCECS'.") mcmc = MCMC(kernel, args.num_warmup, args.num_samples) - from numpyro.util import control_flow_prims_disabled - from jax import disable_jit - with disable_jit(), control_flow_prims_disabled(): - mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",)) + mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",)) print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"])) mcmc.print_summary(exclude_deterministic=False) print('\nMCMC elapsed time:', time.time() - start) diff --git a/examples/hmcecs/logistic_regression.py b/examples/hmcecs/logistic_regression.py index 459c44fa6..0784d6dff 100644 --- a/examples/hmcecs/logistic_regression.py +++ b/examples/hmcecs/logistic_regression.py @@ -15,7 +15,7 @@ from numpyro.distributions import constraints from numpyro.examples.datasets import _load_higgs from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_median, init_to_value, HMC -from numpyro.infer.hmc_gibbs import HMCECS, perturbed_method, variational_proxy, taylor_proxy +from numpyro.infer.hmc_gibbs import HMCECS, variational_proxy, taylor_proxy from numpyro.infer.util import _predictive os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" @@ -89,6 +89,7 @@ def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='vari'): params = svi_result.params proxy_key, ref_key = random.split(proxy_key) + # FIXME should we substitute params to here; or even better using the optimized mean for taylor proxy? ref_params = _predictive(ref_key, guide, {}, (1,), return_sites='', parallel=True, model_args=model_args, model_kwargs=model_kwargs) ref_params.pop('mean') @@ -138,6 +139,7 @@ def hmc(dataset, data, obs, ref_param): subsample_sizes = {'higgs': 1300, 'breast': 75, } # 'copsac': 1000, data, obs = breast_cancer_data() + # FIXME: can we change platform in a JAX program? for platform in ['gpu', 'cpu']: numpyro.set_platform(platform) for dataset in load_data.keys(): diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index c42bb3098..bd4966bdd 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -214,19 +214,14 @@ def __call__(self, *args, **kwargs): event_dim=event_dim) site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim) - if site["fn"].support in [constraints.real, constraints.real_vector]: + if site["fn"].support is constraints.real \ + or (isinstance(site["fn"].support, constraints.independent) and + site["fn"].support is constraints.real): result[name] = numpyro.sample(name, site_fn) else: - unconstrained_value = numpyro.sample("{}_unconstrained".format(name), site_fn, - infer={"is_auxiliary": True}) - transform = biject_to(site['fn'].support) - value = transform(unconstrained_value) - log_density = - transform.log_abs_det_jacobian(unconstrained_value, value) - log_density = sum_rightmost(log_density, - jnp.ndim(log_density) - jnp.ndim(value) + site["fn"].event_dim) - delta_dist = dist.Delta(value, log_density=log_density, event_dim=site["fn"].event_dim) - result[name] = numpyro.sample(name, delta_dist) + guide_dist = dist.TransformedDistribution(site_fn, transform) + result[name] = numpyro.sample(name, guide_dist) return result diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 1ec8c51a4..314b7d7a9 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -7,7 +7,7 @@ from functools import partial import jax.numpy as jnp -from jax import device_put, jacfwd, jacobian, grad, hessian, lax, ops, random, value_and_grad +from jax import device_put, jacfwd, jacobian, grad, hessian, lax, ops, random, value_and_grad, vmap from jax.scipy.special import expit import numpyro @@ -742,6 +742,8 @@ def _sum_all_except_at_dim(x, dim): def variational_proxy(guide, guide_params, num_samples=10): def construct_proxy_fn(rng_key, model, model_args, model_kwargs, num_blocks=1): + # TODO: assert that there is no auxiliary latent variable in the guide + model_kwargs = model_kwargs.copy() prototype_trace = trace(seed(model, rng_key)).get_trace(*model_args, **model_kwargs) subsample_plate_sizes = { name: site["args"] @@ -775,30 +777,28 @@ def log_likelihood(params, subsample_indices=None): def log_posterior(params): with block(): - guide_kwargs = {k: v for k, v in model_kwargs.items() if k != '_gibbs_state'} - posterior_prob, _ = log_density(guide_with_params, model_args, guide_kwargs, params) + posterior_prob, _ = log_density(guide_with_params, model_args, model_kwargs, params) return posterior_prob def log_prior(params): - with block(): - prior_prob, _ = log_density( - block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), model_args, - model_kwargs, params) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=UserWarning) + dummy_subsample = {k: jnp.array([], dtype=jnp.int32) for k in subsample_plate_sizes} + with block(), substitute(data=dummy_subsample): + prior_prob, _ = log_density(model, model_args, model_kwargs, params) return prior_prob - posterior_samples = _predictive(pos_key, guide_with_params, {}, (num_samples,), return_sites='', parallel=True, - model_args=model_args, model_kwargs=model_kwargs) - log_likelihood_ref = log_likelihood(posterior_samples) - - posterior_samples = {**posterior_samples, **{k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()}} + return_sites = [k for k, site in prototype_trace.items() + if site["type"] == "sample" and not site["is_observed"]] + posterior_samples = _predictive(pos_key, guide_with_params, {}, (num_samples,), return_sites=return_sites, + parallel=True, model_args=model_args, model_kwargs=model_kwargs) + log_likelihood_ref = vmap(log_likelihood)(posterior_samples) + log_likelihood_ref = {k: v.sum(0) for k, v in log_likelihood_ref.items()} weights = {name: log_like / log_like.sum() for name, log_like in log_likelihood_ref.items()} - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning) - log_prior_prob = log_prior(posterior_samples) - log_posterior_prob = log_posterior(posterior_samples) - + log_prior_prob = vmap(log_prior)(posterior_samples).sum() + log_posterior_prob = vmap(log_posterior)(posterior_samples).sum() evidence = {name: (log_posterior_prob - log_prior_prob - log_like.sum()) / num_samples for name, log_like in log_likelihood_ref.items()} @@ -820,20 +820,20 @@ def gibbs_update(rng_key, gibbs_sites, gibbs_state): gibbs_state = VariationalProxyState(new_subsample_weights) return u_new, gibbs_state - dummy_subsample = {k: jnp.arange(v[1]) for k, v in subsample_plate_sizes.items()} - def proxy_fn(params, subsample_lik_sites, gibbs_state): - params = {**params, **dummy_subsample} - proxy_sum = defaultdict(float) - proxy_subsample = defaultdict(float) + proxy_sum = {} + proxy_subsample = {} + # TODO: convert params to constrained space log_prior_prob = log_prior(params) log_posterior_prob = log_posterior(params) for name in subsample_lik_sites: proxy_sum[name] = log_posterior_prob - log_prior_prob - evidence[name] - proxy_subsample[name] = gibbs_state.subsample_weights[name].sum() * ( - log_posterior_prob - log_prior_prob) - evidence[name] + # FIXME: what is a correct formula? + # proxy_subsample[name] = gibbs_state.subsample_weights[name] * ( + # log_posterior_prob - log_prior_prob) - evidence[name] + proxy_subsample[name] = gibbs_state.subsample_weights[name] * proxy_sum[name] return proxy_sum, proxy_subsample return proxy_fn, gibbs_init, gibbs_update diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 5fc3cd24b..5da8883a7 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -460,7 +460,8 @@ def initialize_model(rng_key, model, init_strategy = _init_to_unconstrained_value(values=unconstrained_values) prototype_params = transform_fn(inv_transforms, constrained_values, invert=True) (init_params, pe, grad), is_valid = find_valid_initial_params( - rng_key, model, + rng_key, substitute(model, data={k: site["value"] for k, site in model_trace.items() + if site["type"] in ["plate"]}), init_strategy=init_strategy, enum=has_enumerate_support, model_args=model_args, diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 2104b73dc..92c7f6135 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -252,7 +252,7 @@ def body_fn(val, idx): i_p1 = size - idx i = i_p1 - 1 j = random.randint(rng_keys[idx], (), 0, i_p1) - val = ops.index_update(val, ops.index[[i, j],], val[ops.index[[j, i],]]) + val = ops.index_update(val, ops.index[[i, j], ], val[ops.index[[j, i], ]]) return val, None val, _ = lax.scan(body_fn, jnp.arange(size), jnp.arange(subsample_size)) @@ -365,7 +365,7 @@ def process_message(self, msg): msg['fn'] = msg['fn'].expand(batch_shape) if self.size != self.subsample_size: scale = 1. if msg['scale'] is None else msg['scale'] - msg['scale'] = scale * self.size / self.subsample_size + msg['scale'] = scale * (self.size / self.subsample_size if self.subsample_size else 1) def postprocess_message(self, msg): if msg["type"] in ("subsample", "param") and self.dim is not None: @@ -382,7 +382,7 @@ def postprocess_message(self, msg): statement = "numpyro.subsample(..., event_dim={})".format(event_dim) raise ValueError( "Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}" - .format(self.name, self.size, self.dim, statement, shape)) + .format(self.name, self.size, self.dim, statement, shape)) elif self.subsample_size < self.size: value = msg["value"] new_value = jnp.take(value, self._indices, dim) From b57122872ffd7d3572393156b76259a73252e4cf Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 3 Feb 2021 11:34:37 +0100 Subject: [PATCH 65/93] FIXED ELBO computation and changed the weight scheme in variational proxy to w_i = softmax(E_{z~Q}[l(x_i,z)]). --- examples/covtype.py | 2 +- examples/hmcecs/logistic_regression.py | 28 ++++++++++++++------------ numpyro/infer/hmc_gibbs.py | 27 ++++++++++++++----------- 3 files changed, 31 insertions(+), 26 deletions(-) diff --git a/examples/covtype.py b/examples/covtype.py index c89db50cc..38a8a9abe 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -107,7 +107,7 @@ def benchmark_hmc(args, features, labels): inner_kernel = NUTS(model, init_strategy=init_to_value(values=ref_params), dense_mass=args.dense_mass) - kernel = HMCECS(inner_kernel, num_blocks=100, proxy=variational_proxy(guide, params, num_samples=100)) + kernel = HMCECS(inner_kernel, num_blocks=100, proxy=variational_proxy(guide, params, num_particles=100)) else: raise ValueError("Invalid algorithm, either 'HMC', 'NUTS', or 'HMCECS'.") mcmc = MCMC(kernel, args.num_warmup, args.num_samples) diff --git a/examples/hmcecs/logistic_regression.py b/examples/hmcecs/logistic_regression.py index 0784d6dff..201980d57 100644 --- a/examples/hmcecs/logistic_regression.py +++ b/examples/hmcecs/logistic_regression.py @@ -20,6 +20,9 @@ os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" +platform = 'gpu' +numpyro.set_platform(platform) + def summary(dataset, name, mcmc, sample_time, svi_time=0., plates={}): n_eff_mean = np.mean([numpyro.diagnostics.effective_sample_size(device_get(v)) @@ -82,7 +85,7 @@ def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='vari'): optimizer = numpyro.optim.Adam(step_size=5e-5) svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) start = time() - svi_result = svi.run(svi_key, 1000, *model_args) + svi_result = svi.run(svi_key, 10000, *model_args) svi_time = time() - start pickle.dump(svi_result.params, open(f'{dataset}/svi_params.pkl', 'wb')) @@ -135,19 +138,18 @@ def hmc(dataset, data, obs, ref_param): if __name__ == '__main__': - load_data = {'higgs': higgs_data} # 'breast': breast_cancer_data , 'copsac': copsac_data} + load_data = { 'breast': breast_cancer_data} #,'higgs': higgs_data} , 'copsac': copsac_data} subsample_sizes = {'higgs': 1300, 'breast': 75, } # 'copsac': 1000, data, obs = breast_cancer_data() # FIXME: can we change platform in a JAX program? - for platform in ['gpu', 'cpu']: - numpyro.set_platform(platform) - for dataset in load_data.keys(): - dir = f'{platform}_{dataset}_{datetime.now().strftime("%Y_%m_%d_%H%M%S")}' - if not os.path.exists(dir): - os.mkdir(dir) - data, obs = load_data[dataset]() - ref_param = hmcecs_model(dir, data, obs, subsample_sizes[dataset], proxy_name='taylor') - ref_param = hmcecs_model(dir, data, obs, subsample_sizes[dataset], proxy_name='variational') - hmc(dir, data, obs, ref_param) - nuts(dir, data, obs, ref_param) + for dataset in load_data.keys(): + dir = f'{platform}_{dataset}_{datetime.now().strftime("%Y_%m_%d_%H%M%S")}' + if not os.path.exists(dir): + os.mkdir(dir) + data, obs = load_data[dataset]() + ref_param = hmcecs_model(dir, data, obs, subsample_sizes[dataset], proxy_name='variational') + # ref_param = hmcecs_model(dir, data, obs, subsample_sizes[dataset], proxy_name='taylor') + # hmc(dir, data, obs, ref_param) + # nuts(dir, data, obs, ref_param) + exit() \ No newline at end of file diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 314b7d7a9..b696a065d 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -5,6 +5,7 @@ import warnings from collections import defaultdict, namedtuple from functools import partial +import jax import jax.numpy as jnp from jax import device_put, jacfwd, jacobian, grad, hessian, lax, ops, random, value_and_grad, vmap @@ -740,7 +741,7 @@ def _sum_all_except_at_dim(x, dim): return x.reshape(x.shape[:1] + (-1,)).sum(-1) -def variational_proxy(guide, guide_params, num_samples=10): +def variational_proxy(guide, guide_params, num_particles=10): def construct_proxy_fn(rng_key, model, model_args, model_kwargs, num_blocks=1): # TODO: assert that there is no auxiliary latent variable in the guide model_kwargs = model_kwargs.copy() @@ -790,17 +791,21 @@ def log_prior(params): return_sites = [k for k, site in prototype_trace.items() if site["type"] == "sample" and not site["is_observed"]] - posterior_samples = _predictive(pos_key, guide_with_params, {}, (num_samples,), return_sites=return_sites, + posterior_samples = _predictive(pos_key, guide_with_params, {}, (num_particles,), return_sites=return_sites, parallel=True, model_args=model_args, model_kwargs=model_kwargs) log_likelihood_ref = vmap(log_likelihood)(posterior_samples) - log_likelihood_ref = {k: v.sum(0) for k, v in log_likelihood_ref.items()} - weights = {name: log_like / log_like.sum() for name, log_like in log_likelihood_ref.items()} + log_prior_prob = vmap(log_prior)(posterior_samples) + log_posterior_prob = vmap(log_posterior)(posterior_samples) - log_prior_prob = vmap(log_prior)(posterior_samples).sum() - log_posterior_prob = vmap(log_posterior)(posterior_samples).sum() - evidence = {name: (log_posterior_prob - log_prior_prob - log_like.sum()) / num_samples - for name, log_like in log_likelihood_ref.items()} + # softmax(E_{z~Q}[l(x_i,z)]) + weights = {name: jax.nn.softmax(jnp.exp(log_posterior_prob) @ log_like / num_particles) for name, log_like in + log_likelihood_ref.items()} + + # ELBO = exp(log(Q(z)) @ (log(L(z)) + log(pi(z)) - log(Q(z))) + elbo = { + name: jnp.exp(log_posterior_prob) @ (log_prior_prob + log_like.sum(1) - log_posterior_prob) / num_particles + for name, log_like in log_likelihood_ref.items()} def gibbs_init(rng_key, gibbs_sites): return VariationalProxyState( @@ -829,10 +834,8 @@ def proxy_fn(params, subsample_lik_sites, gibbs_state): log_posterior_prob = log_posterior(params) for name in subsample_lik_sites: - proxy_sum[name] = log_posterior_prob - log_prior_prob - evidence[name] - # FIXME: what is a correct formula? - # proxy_subsample[name] = gibbs_state.subsample_weights[name] * ( - # log_posterior_prob - log_prior_prob) - evidence[name] + proxy_sum[name] = log_posterior_prob - log_prior_prob - elbo[name] + # w_i = exp(E_{z~Q}[l(w_i, z)]) / sum_j^n exp(E_{z~Q}[l(w_j, z)]) proxy_subsample[name] = gibbs_state.subsample_weights[name] * proxy_sum[name] return proxy_sum, proxy_subsample From 7d9cd118f7c4ecaeda542ccc0ae22a93dbd40c68 Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 3 Feb 2021 11:48:13 +0100 Subject: [PATCH 66/93] fixed proxy_sum and added equations. --- numpyro/infer/hmc_gibbs.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index b696a065d..117f236d1 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -834,7 +834,10 @@ def proxy_fn(params, subsample_lik_sites, gibbs_state): log_posterior_prob = log_posterior(params) for name in subsample_lik_sites: - proxy_sum[name] = log_posterior_prob - log_prior_prob - elbo[name] + # Q(z) = L(z)pi(z)/p(x) => L(z) = p(x)/Q(z)pi(z) >= exp(elbo)/Q(z)pi(z) => + # log(L(z)) = elbo - Q(z) - pi(z) + proxy_sum[name] = elbo[name] - log_posterior_prob - log_prior_prob + # w_i = exp(E_{z~Q}[l(w_i, z)]) / sum_j^n exp(E_{z~Q}[l(w_j, z)]) proxy_subsample[name] = gibbs_state.subsample_weights[name] * proxy_sum[name] return proxy_sum, proxy_subsample From 5dbac85f66d8fd29b25f8a5465dac0e6ca48150b Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 3 Feb 2021 16:59:47 +0100 Subject: [PATCH 67/93] VECS working with AutoNormal on BreastCancer. --- examples/covtype.py | 9 +- examples/hmcecs/covtype.py | 73 ------------- examples/hmcecs/logistic_regression.py | 21 ++-- examples/hmcecs/regression.py | 142 ++++++++++++++++++------- numpyro/infer/hmc.py | 2 +- numpyro/infer/hmc_gibbs.py | 13 ++- 6 files changed, 130 insertions(+), 130 deletions(-) delete mode 100644 examples/hmcecs/covtype.py diff --git a/examples/covtype.py b/examples/covtype.py index 38a8a9abe..3b65ffc87 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -11,9 +11,10 @@ import numpyro import numpyro.distributions as dist +from numpyro.distributions import constraints from numpyro.examples.datasets import COVTYPE, load_dataset from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SVI, Trace_ELBO, init_to_value -from numpyro.infer.autoguide import AutoBNAFNormal, AutoNormal +from numpyro.infer.autoguide import AutoBNAFNormal, AutoNormal, AutoDiagonalNormal from numpyro.infer.hmc_gibbs import taylor_proxy, variational_proxy from numpyro.infer.reparam import NeuTraReparam @@ -129,11 +130,11 @@ def main(args): parser.add_argument('--num-warmup', default=1000, type=int, help='number of warmup steps') parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")') parser.add_argument('--num-chains', nargs='?', default=1, type=int) - parser.add_argument('--algo', default='HMCECS', type=str, - help='whether to run "HMC", "NUTS", "HMCECS", or "FlowHMCECS"') + parser.add_argument('--algo', default='HMCVECS', type=str, + help='whether to run "HMCECS", "NUTS", "HMCECS", or "FlowHMCECS"') parser.add_argument('--dense-mass', action="store_true") parser.add_argument('--x64', action="store_true") - parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".') + parser.add_argument('--device', default='gpu', type=str, help='use "cpu" or "gpu".') args = parser.parse_args() numpyro.set_platform(args.device) diff --git a/examples/hmcecs/covtype.py b/examples/hmcecs/covtype.py deleted file mode 100644 index fdac66a04..000000000 --- a/examples/hmcecs/covtype.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright Contributors to the Pyro project. -# SPDX-License-Identifier: Apache-2.0 - -import argparse -import time - -from jax import random -import jax.numpy as jnp - -import numpyro -import numpyro.distributions as dist -from numpyro.examples.datasets import COVTYPE, load_dataset -from numpyro.infer import MCMC, NUTS - - -def _load_dataset(): - _, fetch = load_dataset(COVTYPE, shuffle=False) - features, labels = fetch() - - # normalize features and add intercept - features = (features - features.mean(0)) / features.std(0) - features = jnp.hstack([features, jnp.ones((features.shape[0], 1))]) - - # make binary feature - _, counts = jnp.unique(labels, return_counts=True) - specific_category = jnp.argmax(counts) - labels = (labels == specific_category) - - N, dim = features.shape - print("Data shape:", features.shape) - print("Label distribution: {} has label 1, {} has label 0" - .format(labels.sum(), N - labels.sum())) - return features, labels - - -def model(data, labels): - dim = data.shape[1] - coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim))) - logits = jnp.dot(data, coefs) - return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels) - - -def benchmark_hmc(args, features, labels): - step_size = jnp.sqrt(0.5 / features.shape[0]) - trajectory_length = step_size * args.num_steps - rng_key = random.PRNGKey(1) - start = time.time() - kernel = NUTS(model, trajectory_length=trajectory_length) - mcmc = MCMC(kernel, 0, args.num_samples) - mcmc.run(rng_key, features, labels) - mcmc.print_summary() - print('\nMCMC elapsed time:', time.time() - start) - - -def main(args): - features, labels = _load_dataset() - benchmark_hmc(args, features, labels) - - -if __name__ == '__main__': - assert numpyro.__version__.startswith('0.4.1') - parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-n', '--num-samples', default=100, type=int, help='number of samples') - parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")') - parser.add_argument('--num-chains', nargs='?', default=1, type=int) - parser.add_argument('--algo', default='NUTS', type=str, help='whether to run "HMC" or "NUTS"') - parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".') - args = parser.parse_args() - - numpyro.set_platform(args.device) - numpyro.set_host_device_count(args.num_chains) - - main(args) diff --git a/examples/hmcecs/logistic_regression.py b/examples/hmcecs/logistic_regression.py index 201980d57..ad754d9a3 100644 --- a/examples/hmcecs/logistic_regression.py +++ b/examples/hmcecs/logistic_regression.py @@ -14,9 +14,10 @@ import numpyro.distributions as dist from numpyro.distributions import constraints from numpyro.examples.datasets import _load_higgs -from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_median, init_to_value, HMC +from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_median, init_to_value, HMC, autoguide from numpyro.infer.hmc_gibbs import HMCECS, variational_proxy, taylor_proxy from numpyro.infer.util import _predictive +import matplotlib.pyplot as plt os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" @@ -83,19 +84,24 @@ def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='vari'): svi_key, proxy_key, estimator_key, mcmc_key = random.split(random.PRNGKey(0), 4) optimizer = numpyro.optim.Adam(step_size=5e-5) + guide = autoguide.AutoNormal(model) svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) start = time() - svi_result = svi.run(svi_key, 10000, *model_args) + params, losses = svi.run(svi_key, 10000, *model_args) svi_time = time() - start + plt.plot(losses) + plt.show() - pickle.dump(svi_result.params, open(f'{dataset}/svi_params.pkl', 'wb')) - params = svi_result.params + pickle.dump(params, open(f'{dataset}/svi_params.pkl', 'wb')) + params = params proxy_key, ref_key = random.split(proxy_key) # FIXME should we substitute params to here; or even better using the optimized mean for taylor proxy? ref_params = _predictive(ref_key, guide, {}, (1,), return_sites='', parallel=True, model_args=model_args, model_kwargs=model_kwargs) - ref_params.pop('mean') + + ref_params = {k: v for k, v in ref_params.items() if k in ['theta']} + if proxy_name == 'taylor': proxy_fn = taylor_proxy(ref_params) @@ -108,6 +114,7 @@ def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='vari'): start = time() mcmc.run(random.PRNGKey(3), data, obs, subsample_size, extra_fields=("accept_prob", "hmc_state.num_steps")) + mcmc.print_summary() summary(dataset, f'ecs_{proxy_name}', mcmc, time() - start, svi_time=svi_time, plates={'N': ''}) return ref_params @@ -138,7 +145,7 @@ def hmc(dataset, data, obs, ref_param): if __name__ == '__main__': - load_data = { 'breast': breast_cancer_data} #,'higgs': higgs_data} , 'copsac': copsac_data} + load_data = {'breast': breast_cancer_data} # ,'higgs': higgs_data} , 'copsac': copsac_data} subsample_sizes = {'higgs': 1300, 'breast': 75, } # 'copsac': 1000, data, obs = breast_cancer_data() @@ -152,4 +159,4 @@ def hmc(dataset, data, obs, ref_param): # ref_param = hmcecs_model(dir, data, obs, subsample_sizes[dataset], proxy_name='taylor') # hmc(dir, data, obs, ref_param) # nuts(dir, data, obs, ref_param) - exit() \ No newline at end of file + exit() diff --git a/examples/hmcecs/regression.py b/examples/hmcecs/regression.py index 41250ea55..eaf988992 100644 --- a/examples/hmcecs/regression.py +++ b/examples/hmcecs/regression.py @@ -1,3 +1,4 @@ +import argparse from pathlib import Path import jax.numpy as jnp @@ -12,14 +13,41 @@ import numpyro.distributions as dist from numpyro import handlers from numpyro.contrib.module import random_flax_module -from numpyro.infer import MCMC, NUTS, init_to_sample + +from numpyro.infer import MCMC, NUTS, init_to_sample, HMC +import time + +import matplotlib.pyplot as plt + +from jax import random +import jax.numpy as jnp + +import numpyro +import numpyro.distributions as dist +from numpyro.distributions import constraints +from numpyro.examples.datasets import COVTYPE, load_dataset +from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SVI, Trace_ELBO, init_to_value +from numpyro.infer.autoguide import AutoBNAFNormal, AutoNormal +from numpyro.infer.hmc_gibbs import taylor_proxy, variational_proxy +from numpyro.infer.reparam import NeuTraReparam uci_base_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/' -numpyro.set_platform("gpu") +numpyro.set_platform("cpu") + +def visualize(alg, train_data, train_obs, samples, num_samples): + # helper function for prediction + def predict(model, rng_key, samples, *args, **kwargs): + model = handlers.substitute(handlers.seed(model, rng_key), samples) + # note that Y will be sampled in the model because we pass Y=None here + model_trace = handlers.trace(model).get_trace(*args, **kwargs) + return model_trace['obs']['value'] -def visualize(train_data, train_obs, test_data, predictions): + test_data = np.linspace(-2, 2, 500).reshape(-1, 1) + vmap_args = (samples, random.split(random.PRNGKey(1), num_samples)) + predictions = vmap(lambda samples, rng_key: predict(model, rng_key, samples, test_data))(*vmap_args) + predictions = predictions[..., 0] fs = 14 m = predictions.mean(0) @@ -43,10 +71,8 @@ def visualize(train_data, train_obs, test_data, predictions): ax.tick_params(axis='both', which='major', labelsize=14) ax.tick_params(axis='both', which='minor', labelsize=14) - bbox = {'facecolor': 'white', 'alpha': 0.8, 'pad': 1, 'boxstyle': 'round', 'edgecolor': 'black'} - plt.tight_layout() - # plt.savefig('plots/full_hmc.pdf', rasterized=True) + plt.savefig(f'plots/regression_{alg}.pdf', rasterized=True) plt.show() @@ -74,15 +100,6 @@ def features(x): return X[:, None], Y -def protein(): - # from hughsalimbeni/bayesian_benchmarks - # N, D, name = 45730, 9, 'protein' - url = uci_base_url + '00265/CASP.csv' - - data = pd.read_csv(url).values - return data[:, 1:], data[:, 0].reshape(-1, 1) - - class Network(nn.Module): def apply(self, x, out_channels): l1 = tanh(nn.Dense(x, features=100)) @@ -95,44 +112,87 @@ def nonlin(x): return tanh(x) -def model(data, obs=None): +def model(data, obs=None, subsample_size=None): module = Network.partial(out_channels=1) - net = random_flax_module('fnn', module, dist.Normal(0, 1.), input_shape=data.shape[1]) prec_obs = numpyro.sample("prec_obs", dist.LogNormal(jnp.log(110.4), .0001)) sigma_obs = 1.0 / jnp.sqrt(prec_obs) # prior - numpyro.sample('obs', dist.Normal(net(data), sigma_obs), obs=obs) + with numpyro.plate('N', data.shape[0], subsample_size=subsample_size) as idx: + numpyro.sample('obs', dist.Normal(net(data[idx]), sigma_obs), obs=obs[idx]) -def hmc(dataset, data, obs, warmup, num_sample): - kernel = NUTS(model, max_tree_depth=5, step_size=.0005, init_strategy=init_to_sample) - mcmc = MCMC(kernel, warmup, num_sample) - mcmc.run(random.PRNGKey(37), data, obs, extra_fields=('num_steps',)) - mcmc.print_summary() +def benchmark_hmc(args, features, labels): + features = jnp.array(features) + labels = jnp.array(labels) + start = time.time() + rng_key, ref_key = random.split(random.PRNGKey(1)) + subsample_size = 40 + guide = AutoNormal(model) + svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO()) + params, losses = svi.run(random.PRNGKey(2), 2000, features, labels, subsample_size) + plt.plot(losses) + plt.show() + ref_params = svi.guide.sample_posterior(ref_key, params, (1,)) + print(ref_params) + if args.alg == "HMC": + step_size = jnp.sqrt(0.5 / features.shape[0]) + trajectory_length = step_size * args.num_steps + kernel = HMC(model, step_size=step_size, trajectory_length=trajectory_length, adapt_step_size=False, + dense_mass=args.dense_mass) + subsample_size = None + elif args.alg == "NUTS": + kernel = NUTS(model, dense_mass=args.dense_mass) + subsample_size = None + elif args.alg == "HMCECS": + subsample_size = 40 + inner_kernel = NUTS(model, init_strategy=init_to_value(values=ref_params), + dense_mass=args.dense_mass) + kernel = HMCECS(inner_kernel, num_blocks=100, proxy=taylor_proxy(ref_params)) + elif args.alg == 'HMCVECS': + subsample_size = 40 + guide = AutoNormal(model) + svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO()) + params, losses = svi.run(random.PRNGKey(2), 2000, features, labels, subsample_size) + plt.plot(losses) + plt.show() + + inner_kernel = NUTS(model, init_strategy=init_to_value(values=ref_params), + dense_mass=args.dense_mass) + kernel = HMCECS(inner_kernel, num_blocks=100, proxy=variational_proxy(guide, params, num_particles=100)) + else: + raise ValueError('Alg not in HMC, NUTS, HMCECS, or HMCVECS.') + mcmc = MCMC(kernel, args.num_warmup, args.num_samples) + mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",)) + print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"])) + mcmc.print_summary(exclude_deterministic=False) + print('\nMCMC elapsed time:', time.time() - start) return mcmc.get_samples() -# helper function for prediction -def predict(model, rng_key, samples, *args, **kwargs): - model = handlers.substitute(handlers.seed(model, rng_key), samples) - # note that Y will be sampled in the model because we pass Y=None here - model_trace = handlers.trace(model).get_trace(*args, **kwargs) - return model_trace['obs']['value'] - - -def main(): +def main(args): data, obs = load_agw_1d() - warmup = 200 - num_samples = 1000 - test_data = np.linspace(-2, 2, 500).reshape(-1, 1) - samples = hmc('protein', data, obs, warmup, num_samples) - vmap_args = (samples, random.split(random.PRNGKey(1), num_samples)) - predictions = vmap(lambda samples, rng_key: predict(model, rng_key, samples, test_data))(*vmap_args) - predictions = predictions[..., 0] - visualize(data, obs, np.squeeze(test_data), predictions) + samples = benchmark_hmc(args, data, obs) + visualize(args.alg, data, obs, samples, args.num_samples) if __name__ == '__main__': - main() + parser = argparse.ArgumentParser(description="parse args") + parser.add_argument('-n', '--num-samples', default=1000, type=int, help='number of samples') + parser.add_argument('--num-warmup', default=200, type=int, help='number of warmup steps') + parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")') + parser.add_argument('--num-chains', nargs='?', default=1, type=int) + parser.add_argument('--alg', default='NUTS', type=str, + help='whether to run "HMCVECS", "HMC", "NUTS", or "HMCECS"') + parser.add_argument('--dense-mass', action="store_true") + parser.add_argument('--x64', action="store_true") + parser.add_argument('--device', default='gpu', type=str, help='use "cpu" or "gpu".') + args = parser.parse_args() + + numpyro.set_platform(args.device) + numpyro.set_host_device_count(args.num_chains) + if args.x64: + numpyro.enable_x64() + + main(args) diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py index 8e9d9302b..3f0541532 100644 --- a/numpyro/infer/hmc.py +++ b/numpyro/infer/hmc.py @@ -120,7 +120,7 @@ def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'): >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist - >>> from numpyro.infer.hmc import hmc + >>> from numpyro.infer.benchmark_hmc import hmc >>> from numpyro.infer.util import initialize_model >>> from numpyro.util import fori_collect diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 117f236d1..73165c84a 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -777,8 +777,11 @@ def log_likelihood(params, subsample_indices=None): return log_lik def log_posterior(params): - with block(): - posterior_prob, _ = log_density(guide_with_params, model_args, model_kwargs, params) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=UserWarning) + dummy_subsample = {k: jnp.array([], dtype=jnp.int32) for k in subsample_plate_sizes} + with block(), substitute(data=dummy_subsample): + posterior_prob, _ = log_density(guide_with_params, model_args, model_kwargs, params) return posterior_prob def log_prior(params): @@ -799,12 +802,14 @@ def log_prior(params): log_posterior_prob = vmap(log_posterior)(posterior_samples) # softmax(E_{z~Q}[l(x_i,z)]) - weights = {name: jax.nn.softmax(jnp.exp(log_posterior_prob) @ log_like / num_particles) for name, log_like in + weights = {name: jax.nn.softmax(jnp.exp(log_posterior_prob / num_particles) @ log_like / num_particles) for + name, log_like in log_likelihood_ref.items()} # ELBO = exp(log(Q(z)) @ (log(L(z)) + log(pi(z)) - log(Q(z))) elbo = { - name: jnp.exp(log_posterior_prob) @ (log_prior_prob + log_like.sum(1) - log_posterior_prob) / num_particles + name: jnp.exp(log_posterior_prob / num_particles) @ ( + log_prior_prob + log_like.sum(1) - log_posterior_prob) / num_particles for name, log_like in log_likelihood_ref.items()} def gibbs_init(rng_key, gibbs_sites): From bb783f8b503b0782ea4ce887b3de8868fd0d63ff Mon Sep 17 00:00:00 2001 From: ola Date: Wed, 3 Feb 2021 17:53:00 +0100 Subject: [PATCH 68/93] Using Likelihood as weight. --- numpyro/infer/hmc_gibbs.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 73165c84a..bb4aa70ca 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -802,14 +802,12 @@ def log_prior(params): log_posterior_prob = vmap(log_posterior)(posterior_samples) # softmax(E_{z~Q}[l(x_i,z)]) - weights = {name: jax.nn.softmax(jnp.exp(log_posterior_prob / num_particles) @ log_like / num_particles) for - name, log_like in + weights = {name: jax.nn.softmax(log_like.sum(0) / num_particles) for name, log_like in log_likelihood_ref.items()} # ELBO = exp(log(Q(z)) @ (log(L(z)) + log(pi(z)) - log(Q(z))) elbo = { - name: jnp.exp(log_posterior_prob / num_particles) @ ( - log_prior_prob + log_like.sum(1) - log_posterior_prob) / num_particles + name: jnp.exp(log_posterior_prob/num_particles) @ (log_prior_prob + log_like.sum(1) - log_posterior_prob) / num_particles for name, log_like in log_likelihood_ref.items()} def gibbs_init(rng_key, gibbs_sites): From 7644a086a71fe94062790c73bd8d02cc44895ba3 Mon Sep 17 00:00:00 2001 From: ola Date: Thu, 4 Feb 2021 13:39:25 +0100 Subject: [PATCH 69/93] factored out VECS --- examples/covtype.py | 21 ++----- examples/hmcecs/regression.py | 42 +++---------- examples/hmcecs/two_moons.py | 85 -------------------------- numpyro/infer/hmc_gibbs.py | 111 ---------------------------------- 4 files changed, 12 insertions(+), 247 deletions(-) delete mode 100644 examples/hmcecs/two_moons.py diff --git a/examples/covtype.py b/examples/covtype.py index 3b65ffc87..1e7ccf80d 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -4,18 +4,16 @@ import argparse import time +import jax.numpy as jnp import matplotlib.pyplot as plt - from jax import random -import jax.numpy as jnp import numpyro import numpyro.distributions as dist -from numpyro.distributions import constraints from numpyro.examples.datasets import COVTYPE, load_dataset from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SVI, Trace_ELBO, init_to_value -from numpyro.infer.autoguide import AutoBNAFNormal, AutoNormal, AutoDiagonalNormal -from numpyro.infer.hmc_gibbs import taylor_proxy, variational_proxy +from numpyro.infer.autoguide import AutoBNAFNormal +from numpyro.infer.hmc_gibbs import taylor_proxy from numpyro.infer.reparam import NeuTraReparam @@ -98,17 +96,6 @@ def benchmark_hmc(args, features, labels): inner_kernel = NUTS(neutra_model, init_strategy=init_to_value(values=neutra_ref_params), adapt_mass_matrix=False) kernel = HMCECS(inner_kernel, num_blocks=100, proxy=taylor_proxy(neutra_ref_params)) - elif args.algo == "HMCVECS": - subsample_size = 1000 - guide = AutoNormal(model) - svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO()) - params, losses = svi.run(random.PRNGKey(2), 2000, features, labels) - plt.plot(losses) - plt.show() - - inner_kernel = NUTS(model, init_strategy=init_to_value(values=ref_params), - dense_mass=args.dense_mass) - kernel = HMCECS(inner_kernel, num_blocks=100, proxy=variational_proxy(guide, params, num_particles=100)) else: raise ValueError("Invalid algorithm, either 'HMC', 'NUTS', or 'HMCECS'.") mcmc = MCMC(kernel, args.num_warmup, args.num_samples) @@ -130,7 +117,7 @@ def main(args): parser.add_argument('--num-warmup', default=1000, type=int, help='number of warmup steps') parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")') parser.add_argument('--num-chains', nargs='?', default=1, type=int) - parser.add_argument('--algo', default='HMCVECS', type=str, + parser.add_argument('--algo', default='HMCECS', type=str, help='whether to run "HMCECS", "NUTS", "HMCECS", or "FlowHMCECS"') parser.add_argument('--dense-mass', action="store_true") parser.add_argument('--x64', action="store_true") diff --git a/examples/hmcecs/regression.py b/examples/hmcecs/regression.py index eaf988992..924555885 100644 --- a/examples/hmcecs/regression.py +++ b/examples/hmcecs/regression.py @@ -1,40 +1,25 @@ import argparse +import time from pathlib import Path import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np -import pandas as pd from flax import nn -from flax.nn.activation import relu, tanh -from jax import random, vmap +from flax.nn.activation import tanh +from jax import random +from jax import vmap import numpyro import numpyro.distributions as dist from numpyro import handlers from numpyro.contrib.module import random_flax_module - -from numpyro.infer import MCMC, NUTS, init_to_sample, HMC -import time - -import matplotlib.pyplot as plt - -from jax import random -import jax.numpy as jnp - -import numpyro -import numpyro.distributions as dist -from numpyro.distributions import constraints -from numpyro.examples.datasets import COVTYPE, load_dataset from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SVI, Trace_ELBO, init_to_value -from numpyro.infer.autoguide import AutoBNAFNormal, AutoNormal -from numpyro.infer.hmc_gibbs import taylor_proxy, variational_proxy -from numpyro.infer.reparam import NeuTraReparam +from numpyro.infer.autoguide import AutoNormal +from numpyro.infer.hmc_gibbs import taylor_proxy uci_base_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/' -numpyro.set_platform("cpu") - def visualize(alg, train_data, train_obs, samples, num_samples): # helper function for prediction @@ -150,19 +135,8 @@ def benchmark_hmc(args, features, labels): inner_kernel = NUTS(model, init_strategy=init_to_value(values=ref_params), dense_mass=args.dense_mass) kernel = HMCECS(inner_kernel, num_blocks=100, proxy=taylor_proxy(ref_params)) - elif args.alg == 'HMCVECS': - subsample_size = 40 - guide = AutoNormal(model) - svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO()) - params, losses = svi.run(random.PRNGKey(2), 2000, features, labels, subsample_size) - plt.plot(losses) - plt.show() - - inner_kernel = NUTS(model, init_strategy=init_to_value(values=ref_params), - dense_mass=args.dense_mass) - kernel = HMCECS(inner_kernel, num_blocks=100, proxy=variational_proxy(guide, params, num_particles=100)) else: - raise ValueError('Alg not in HMC, NUTS, HMCECS, or HMCVECS.') + raise ValueError('Alg not in HMC, NUTS, HMCECS.') mcmc = MCMC(kernel, args.num_warmup, args.num_samples) mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",)) print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"])) @@ -184,7 +158,7 @@ def main(args): parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")') parser.add_argument('--num-chains', nargs='?', default=1, type=int) parser.add_argument('--alg', default='NUTS', type=str, - help='whether to run "HMCVECS", "HMC", "NUTS", or "HMCECS"') + help='whether to run "HMC", "NUTS", or "HMCECS"') parser.add_argument('--dense-mass', action="store_true") parser.add_argument('--x64', action="store_true") parser.add_argument('--device', default='gpu', type=str, help='use "cpu" or "gpu".') diff --git a/examples/hmcecs/two_moons.py b/examples/hmcecs/two_moons.py deleted file mode 100644 index 0bca42f97..000000000 --- a/examples/hmcecs/two_moons.py +++ /dev/null @@ -1,85 +0,0 @@ -import argparse -import os - -from matplotlib.gridspec import GridSpec -import matplotlib.pyplot as plt -import seaborn as sns - -import jax -from jax import random -import jax.numpy as jnp -from jax.scipy.special import logsumexp - -import numpyro -from numpyro import optim -from numpyro.diagnostics import print_summary -import numpyro.distributions as dist -from numpyro.distributions import constraints -from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO -from numpyro.infer.autoguide import AutoBNAFNormal -from numpyro.infer.reparam import NeuTraReparam - - -class DualMoonDistribution(dist.Distribution): - support = constraints.real_vector - - def __init__(self): - super(DualMoonDistribution, self).__init__(event_shape=(2,)) - - def sample(self, key, sample_shape=()): - # it is enough to return an arbitrary sample with correct shape - return jnp.zeros(sample_shape + self.event_shape) - - def log_prob(self, x): - term1 = 0.5 * ((jnp.linalg.norm(x, axis=-1) - 2) / 0.4) ** 2 - term2 = -0.5 * ((x[..., :1] + jnp.array([-2., 2.])) / 0.6) ** 2 - pe = term1 - logsumexp(term2, axis=-1) - return -pe - - -def dual_moon_model(): - numpyro.sample('x', DualMoonDistribution()) - - -def guide(): - var = numpyro.param('var', jnp.eye(2, dtype=jnp.float32), constraints=constraints.corr_matrix) - mean = numpyro.param('mean', jnp.zeros(2, dtype=jnp.float32), constraints=constraints.real_vector) - numpyro.sample('x', dist.MultivariateNormal(mean, var)) - - -def visualize(samples): - print(samples.shape) - print(samples) - sns.kdeplot(x=samples[:, 0], y=samples[:, 1]) - plt.show() - - -def two_moons(rng_key, noise, shape): - def make_circle(data, radius, center): - return jnp.sqrt(radius ** 2 - (data - center) ** 2) - - # TODO: finish compute density - - noise_key, uni_key = random.split(rng_key) - uni_samples = jax.random.uniform(uni_key, shape) - noise = noise * jax.random.normal(noise_key, shape) - upper = uni_samples[:shape[0] // 2] - .25 - upper_noise = noise[:shape[0] // 2] - lower_noise = noise[shape[0] // 2:] - lower = uni_samples[shape[0] // 2:] + .25 - upper = jnp.vstack((upper, -make_circle(upper, .5, .25) + .1)).T - lower = jnp.vstack((lower, make_circle(lower, .5, .75) - .1)).T - - plt.scatter(upper[:, 0], upper[:, 1]) - plt.scatter(lower[:, 0], lower[:, 1]) - plt.gca().set_aspect('equal', adjustable='box') - plt.show() - - -if __name__ == '__main__': - sim_key, guide_key, mcmc_key = random.split(random.PRNGKey(0), 3) - two_moons(sim_key, noise=.05, shape=(1000,)) - dm = DualMoonDistribution() - samples = dm.sample(sim_key, (10000,)) - - # visualize(samples) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index bb4aa70ca..e56def376 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -446,10 +446,8 @@ def _block_update_proxy(num_blocks, rng_key, gibbs_sites, subsample_plate_sizes) HMCECSState = namedtuple("HMCECSState", "z, hmc_state, rng_key, gibbs_state, accept_prob") -# TODO: rename to shorter names? TaylorProxyState = namedtuple("TaylorProxyState", "ref_subsample_log_liks, " "ref_subsample_log_lik_grads, ref_subsample_log_lik_hessians") -VariationalProxyState = namedtuple('VariationalProxyState', 'subsample_weights') BlockPoissonEstState = namedtuple("BlockPoissonEstState", "block_rng_keys, sign") @@ -741,115 +739,6 @@ def _sum_all_except_at_dim(x, dim): return x.reshape(x.shape[:1] + (-1,)).sum(-1) -def variational_proxy(guide, guide_params, num_particles=10): - def construct_proxy_fn(rng_key, model, model_args, model_kwargs, num_blocks=1): - # TODO: assert that there is no auxiliary latent variable in the guide - model_kwargs = model_kwargs.copy() - prototype_trace = trace(seed(model, rng_key)).get_trace(*model_args, **model_kwargs) - subsample_plate_sizes = { - name: site["args"] - for name, site in prototype_trace.items() - if site["type"] == "plate" and site["args"][0] > site["args"][1] # i.e. size > subsample_size - } - - pos_key, guide_key, rng_key = random.split(rng_key, 3) - guide_with_params = substitute(guide, guide_params) - - # factor out? - def log_likelihood(params, subsample_indices=None): - params_flat, unravel_fn = ravel_pytree(params) - if subsample_indices is None: - subsample_indices = {k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()} - params = unravel_fn(params_flat) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - with block(), trace() as tr, substitute(data=subsample_indices), \ - substitute(substitute_fn=partial(_unconstrain_reparam, params)): - model(*model_args, **model_kwargs) - - log_lik = defaultdict(float) - for site in tr.values(): - if site["type"] == "sample" and site["is_observed"]: - for frame in site["cond_indep_stack"]: - if frame.name in subsample_plate_sizes: - log_lik[frame.name] += _sum_all_except_at_dim( - site["fn"].log_prob(site["value"]), frame.dim) - return log_lik - - def log_posterior(params): - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning) - dummy_subsample = {k: jnp.array([], dtype=jnp.int32) for k in subsample_plate_sizes} - with block(), substitute(data=dummy_subsample): - posterior_prob, _ = log_density(guide_with_params, model_args, model_kwargs, params) - return posterior_prob - - def log_prior(params): - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning) - dummy_subsample = {k: jnp.array([], dtype=jnp.int32) for k in subsample_plate_sizes} - with block(), substitute(data=dummy_subsample): - prior_prob, _ = log_density(model, model_args, model_kwargs, params) - return prior_prob - - return_sites = [k for k, site in prototype_trace.items() - if site["type"] == "sample" and not site["is_observed"]] - posterior_samples = _predictive(pos_key, guide_with_params, {}, (num_particles,), return_sites=return_sites, - parallel=True, model_args=model_args, model_kwargs=model_kwargs) - log_likelihood_ref = vmap(log_likelihood)(posterior_samples) - - log_prior_prob = vmap(log_prior)(posterior_samples) - log_posterior_prob = vmap(log_posterior)(posterior_samples) - - # softmax(E_{z~Q}[l(x_i,z)]) - weights = {name: jax.nn.softmax(log_like.sum(0) / num_particles) for name, log_like in - log_likelihood_ref.items()} - - # ELBO = exp(log(Q(z)) @ (log(L(z)) + log(pi(z)) - log(Q(z))) - elbo = { - name: jnp.exp(log_posterior_prob/num_particles) @ (log_prior_prob + log_like.sum(1) - log_posterior_prob) / num_particles - for name, log_like in log_likelihood_ref.items()} - - def gibbs_init(rng_key, gibbs_sites): - return VariationalProxyState( - {name: weights[name][subsample_idx] for name, subsample_idx in gibbs_sites.items()}) - - def gibbs_update(rng_key, gibbs_sites, gibbs_state): - u_new, pads, new_idxs, starts = _block_update_proxy(num_blocks, rng_key, gibbs_sites, subsample_plate_sizes) - - new_subsample_weights = {} - for name, subsample_weights in gibbs_state.subsample_weights.items(): - size, subsample_size = subsample_plate_sizes[name] # TODO: fix duplication! - pad, new_idx, start = pads[name], new_idxs[name], starts[name] - new_value = jnp.pad(subsample_weights, - [(0, pad)] + [(0, 0)] * (jnp.ndim(subsample_weights) - 1)) - new_value = lax.dynamic_update_slice_in_dim(new_value, weights[name][new_idx], start, 0) - new_subsample_weights[name] = new_value[:subsample_size] - gibbs_state = VariationalProxyState(new_subsample_weights) - return u_new, gibbs_state - - def proxy_fn(params, subsample_lik_sites, gibbs_state): - - proxy_sum = {} - proxy_subsample = {} - # TODO: convert params to constrained space - log_prior_prob = log_prior(params) - log_posterior_prob = log_posterior(params) - - for name in subsample_lik_sites: - # Q(z) = L(z)pi(z)/p(x) => L(z) = p(x)/Q(z)pi(z) >= exp(elbo)/Q(z)pi(z) => - # log(L(z)) = elbo - Q(z) - pi(z) - proxy_sum[name] = elbo[name] - log_posterior_prob - log_prior_prob - - # w_i = exp(E_{z~Q}[l(w_i, z)]) / sum_j^n exp(E_{z~Q}[l(w_j, z)]) - proxy_subsample[name] = gibbs_state.subsample_weights[name] * proxy_sum[name] - return proxy_sum, proxy_subsample - - return proxy_fn, gibbs_init, gibbs_update - - return construct_proxy_fn - - class estimate_likelihood(numpyro.primitives.Messenger): def __init__(self, fn=None, method=None): # estimate_likelihood: accept likelihood tuple (fn, value, subsample_name, subsample_dim) From a997acbe562e9689516a78d87b5b2ce64f6a0e36 Mon Sep 17 00:00:00 2001 From: ola Date: Fri, 5 Feb 2021 13:36:08 +0100 Subject: [PATCH 70/93] Added simple test case. --- examples/hmcecs/__init__.py | 0 examples/hmcecs/cifar10.py | 152 ---------------------- examples/hmcecs/data/data.npy | Bin 6528 -> 0 bytes examples/hmcecs/logistic_regression.py | 162 ----------------------- examples/hmcecs/regression.py | 172 ------------------------- test/test_hmc_gibbs.py | 32 ++++- 6 files changed, 29 insertions(+), 489 deletions(-) delete mode 100644 examples/hmcecs/__init__.py delete mode 100644 examples/hmcecs/cifar10.py delete mode 100644 examples/hmcecs/data/data.npy delete mode 100644 examples/hmcecs/logistic_regression.py delete mode 100644 examples/hmcecs/regression.py diff --git a/examples/hmcecs/__init__.py b/examples/hmcecs/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/hmcecs/cifar10.py b/examples/hmcecs/cifar10.py deleted file mode 100644 index 5bf42edae..000000000 --- a/examples/hmcecs/cifar10.py +++ /dev/null @@ -1,152 +0,0 @@ -import os -import pickle -import tarfile -from time import time -from urllib.request import urlretrieve - -import numpy as np -from flax import nn -from flax.nn.activation import selu, softmax -from jax import random, device_get - -import numpyro -import numpyro.distributions as dist -from numpyro.contrib.module import random_flax_module -from numpyro.infer import MCMC, NUTS, init_to_median - - -def cifar10(path=None): - r"""Return (train_images, train_labels, test_images, test_labels). - - Args: - path (str): Directory containing CIFAR-10. Default is - /home/USER/data/cifar10 or C:\Users\USER\data\cifar10. - Create if nonexistant. Download CIFAR-10 if missing. - - Returns: - Tuple of (train_images, train_labels, test_images, test_labels), each - a matrix. Rows are examples. Columns of images are pixel values, - with the order (red -> blue -> green). Columns of labels are a - onehot encoding of the correct class. - """ - url = 'https://www.cs.toronto.edu/~kriz/' - tar = 'cifar-10-binary.tar.gz' - files = ['cifar-10-batches-bin/data_batch_1.bin', - 'cifar-10-batches-bin/data_batch_2.bin', - 'cifar-10-batches-bin/data_batch_3.bin', - 'cifar-10-batches-bin/data_batch_4.bin', - 'cifar-10-batches-bin/data_batch_5.bin', - 'cifar-10-batches-bin/test_batch.bin'] - - if path is None: - # Set path to /home/USER/data/mnist or C:\Users\USER\data\mnist - path = os.path.join(os.path.expanduser('~'), 'data', 'cifar10') - - # Create path if it doesn't exist - os.makedirs(path, exist_ok=True) - - # Download tarfile if missing - if tar not in os.listdir(path): - urlretrieve(''.join((url, tar)), os.path.join(path, tar)) - print("Downloaded %s to %s" % (tar, path)) - - # Load data from tarfile - with tarfile.open(os.path.join(path, tar)) as tar_object: - # Each file contains 10,000 color images and 10,000 labels - fsize = 10000 * (32 * 32 * 3) + 10000 - - # There are 6 files (5 train and 1 test) - buffr = np.zeros(fsize * 6, dtype='uint8') - - # Get members of tar corresponding to data files - # -- The tar contains README's and other extraneous stuff - members = [file for file in tar_object if file.name in files] - - # Sort those members by name - # -- Ensures we load train data in the proper order - # -- Ensures that test data is the last file in the list - members.sort(key=lambda member: member.name) - - # Extract data from members - for i, member in enumerate(members): - # Get member as a file object - f = tar_object.extractfile(member) - # Read bytes from that file object into buffr - buffr[i * fsize:(i + 1) * fsize] = np.frombuffer(f.read(), 'B') - - # Parse data from buffer - # -- Examples are in chunks of 3,073 bytes - # -- First byte of each chunk is the label - # -- Next 32 * 32 * 3 = 3,072 bytes are its corresponding image - - # Labels are the first byte of every chunk - labels = buffr[::3073] - - # Pixels are everything remaining after we delete the labels - pixels = np.delete(buffr, np.arange(0, buffr.size, 3073)) - images = pixels.reshape((-1, 32, 32, 3)).astype('float32') / 255 - - # Split into train and test - train_images, test_images = images[:50000], images[50000:] - train_labels, test_labels = labels[:50000], labels[50000:] - - return train_images, train_labels, test_images, test_labels - - -def summary(dataset, name, mcmc, sample_time, svi_time=0., plates={}): - n_eff_mean = np.mean([numpyro.diagnostics.effective_sample_size(device_get(v)) - for k, v in mcmc.get_samples(True).items() if k not in plates]) - pickle.dump(mcmc.get_samples(True), open(f'{dataset}/{name}_posterior_samples.pkl', 'wb')) - step_field = 'num_steps' if name == 'hmc' else 'hmc_state.num_steps' - num_step = np.sum(mcmc.get_extra_fields()[step_field]) - accpt_prob = 1. - - with open(f'{dataset}/{name}_chain_stats.txt', 'w') as f: - print('sample_time', 'svi_time', 'n_eff_mean', 'gibbs_accpt_prob', 'tot_num_steps', 'time_per_step', - 'time_per_eff', - sep=',', file=f) - print(sample_time, svi_time, n_eff_mean, accpt_prob, num_step, sample_time / num_step, sample_time / n_eff_mean, - sep=',', file=f) - - -class Network(nn.Module): - """ Scaling Hamiltonian Monte Carlo Inference for Bayesian Neural Networks with Symmetric Splitting - Adam D. Cobb, Brian Jalaian (2020) """ - - def apply(self, x, out_channels): - c1 = selu(nn.Conv(x, features=6, kernel_size=(4, 4))) - max1 = nn.max_pool(c1, window_shape=(2, 2)) - c2 = nn.activation.selu(nn.Conv(max1, features=16, kernel_size=(4, 4))) - max2 = nn.max_pool(c2, window_shape=(2, 2)) - l1 = selu(nn.Dense(max2.reshape(x.shape[0], -1), features=400)) - l2 = selu(nn.Dense(l1, features=120)) - l3 = selu(nn.Dense(l2, features=84)) - l4 = softmax(nn.Dense(l3, features=out_channels)) - return l4 - - -def model(data, obs): - module = Network.partial(out_channels=10) - net = random_flax_module('conv_nn', module, dist.Normal(0, 1.), input_shape=data.shape) - - if obs is not None: - obs = obs[..., None] - numpyro.sample('obs', dist.Categorical(logits=net(data)), obs=obs) - - -def hmc(dataset, data, obs): - kernel = NUTS(model, init_strategy=init_to_median) - mcmc = MCMC(kernel, 100, 100) - mcmc._compile(random.PRNGKey(0), data, obs, extra_fields=("num_steps",)) - start = time() - mcmc.run(random.PRNGKey(0), data, obs, extra_fields=('num_steps',)) - summary(dataset, 'hmc', mcmc, time() - start) - - -def main(): - train_data, train_labels, test_data, test_labels = cifar10() - hmc('cifar10', train_data[:1000], train_labels[:1000]) - - -if __name__ == '__main__': - main() diff --git a/examples/hmcecs/data/data.npy b/examples/hmcecs/data/data.npy deleted file mode 100644 index 09fd6d42cb58c97323c3de77c67a6752b5731027..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6528 zcmbVQ`9n^Qk3LWH8v^h~}vM8vMAbQ0xIcr465sO~On^U_QeAlh3gGM{8ZGN$mNjvNN89-I=Cp>D}45?YDWGC1swa2lW~?m;+9-t*rF_hS zFe}!kYxAV=tTx#xbef7?PtUBPp|U+DaqV++im{0!uRfxbw-Y|!n{CY6r2kqPd0G7Z zuP6)iOD=lUgfi+v?hVwh0%lzB*Vbm>&&w&W%*694n}{Z~w(sEgZDkfDGfueZN@2}- z+s}JdhOA8olxx86=+HrqT8x*B>y$(B{6{Z0e&zTK4OBr%U0JIBQ&Wna-)UWs>>OO8 z%(vIG=h0q&l<+i5;^Jdq#EuL2xech=dY-Ruiz&(4T2mS1rMq5r6TD~V^Oy5fkSg3& zanO$2&v1Pcvb?nDNXL3E9@Cikbx9iyhz5@T0;v+-)f&CkTZH`;=$4AmwnW|b4O==8 zvmZ&h%YiCWU;AfnHugU$Q2}llW)XiU;Q7KG!4eebu1r2si}gnx-y{adZ+&9txM93R ztzkW?ecS&-Vo!{hS5_ehx#mGi>iTx<-|gYe5ZGBXYqvkfE4qXE{dLn^(w62(uCS&Qh#(3ujhO7 z8j#=mU#;hl7_#g67rr*2M`=Nem;dDWOZ_Q5&GUIf%t`;7RjM=0*fX8XW z0x4uhEtqEa1g{tFXr+eo6Xt%NkZO)2J9?d5YUa|LEBG} z!=S?d-XVAFA9>J$MyGn}m+iiT^*}rqNuhoF)M;<~7P0;iLQ^F)6J=?4KUZRYdX7|r z@n6z!w_UM+RMGcF)IU40zsM2uAqWvSp(~s95AbTMm_Iv15Nd3eU=s(7qd!So~vWM&4D7XxrSiHZ**)UxRO{qv!i5301E2DD4M z-91GAviK;{Nd~?zUn-_2;`#LD;96ACy|g1Rxs4qsVt3O(IE0T*ugCR;R*#j#8{$=Y z-hRv<9eQ1gEZp*4bCa<@^lyC$@=4o0HYDg3#yk3>0wRw!)y|!a`RDsQl!0MgR{OV1 zPHz`))gvFfWsgmZa6KXG_sAf5L_*I9Ev{dJf?zqEFupVGPb!zsC!-n>?_}Lp^{mHu z|JGGZzV*2F^1?MdpSNwC3Si$oW#mdO{!<*ZkjI~Xp!3Hkc0RvDpYi{y=Toz6j8_ok zEkcKn&NR7p)R1Jx1to`?Vg9Ea&yMRDAOH2uM&u`|sb2HloMOg_(%O17ty(P44&wA3 zO43MGEZFv|8Ly{Kc4%O6x7iWr$+%weCh5zd!|0Nade)NzC zDV&%Xb7#wr2Z*tGRzY$|I&E{OvJjt7v~E(s&8SE1&0$ziL|}jzof!DkYwJA3uIGu* z%HfyJn_0(rIKOE{u^4T9n0Q{SFJ<>5E;02r*MIqmiF4)be7>&feTTV~0ue7i%D9>*CAsl}}1(TnFj5?}^~_9iO9sJ7tIbrpPfKUg=IX@IF3^dG5A{jgOF8 ztN@caVSQZcYBBz;zEY?kZ=NnZhu71AZxo=NY<1{oFpf8`&b<+N3>rGk;w09)z-qe~ zb*gvTr-ztP?0&;1O3~H9v8pFK%vqZtIvUX9AcL+>Q#{V^X-1>!I%Vasiw*2~bebF_po|%*@VPI>@X$ND>m+-Kj7}?wzQ6r1v`h~N%<&aOT zwYchz&nL$DHp7SATVk9Ra`@^GGkQF0vdu$`<3TJpSHYzvJMY*=b9g7_sUh*~WA#Mf z#=DZ$V5VxByjb#t#YYZpQ$T5d#lEz7tPeqfgA}z-(p@@rRRcTD4?5lg88at)HO$5J zma@H}fWdAK!;UV%ds(FB|-ijKem**;P~(y_BNu1b-yO6i~* z@KY%)D}8zpgqUw$mbDZjw(fUv{8Y{26`m!Hvif`$u8r$o{E7!NN;tA^&N9m%u9alH6>tx`yTK4Pa$q!RP@b}kKi?ISP5 zoW=b*e}1tX+(I)V{(Xn_ARPa`0i`yMnYtww?zhDuP7aP^A6|Bz z!s(~fN`xG8s}I@q#`Tljy1wG_I^e9W&YR=&wHAV`5y*}OL<}btg z_?V>vji}7PJP^lA(8#AzX#7&odQa}Yx&ONnEh=nHyL{7}V)&;P9;8v~>NQ^8J1}2@ zpa=!LU(j#tj8UDee}Z||VpvZ*BqiVA^!O}_M&EyCMeLZ!)nm`4QrJ@J;oLNWtDpNF zG$3rr>1?;d`k+*0O|aOTx8h?PzR##{2`UJklC#t4D_+mHh?hdYbCxDnkMaFO$_N#> zv_-C|u*Ul2o5jn(dvufQ_k7GBWp_>m%4@oZ6R&XV3I9eAIMy%GAC|@XPZtEp;8B(6 zn6wz5FIb!_hfQsxEynG^dgE)fc*r~l*jpU>R)}a?*lu>C7UwU~=qy1GY&BWKR$_cK>{G$q)%%(%!#OUV%c6Xej&v^8*e zq#IixYhX=K0M|bfQOBjQn>V>?i5K>tT(gpf&*k|BeQ)6U!4FEIQJ?&OBlSBudrLnW z?Hqo3uk}k@|M}*XA~a}p`J3Dxm~Vp5;2qrU*FL7!kYwDWmlSKEI@Rmk!nvGYWJ~H% zwH5X4uWfjoy1!iw!ZUMAYKuOzd{CriGbA1gdtiML`_DIUXh6C7Pk!bm;Borxfo2fS z4*K(|BOfvI$-kNR+TenXO%>m8{~%0#L!%QD2M>4r^_Csy)9Ym5QSqXCb}~MX-*H3% zg%*Rf-`8P2_|bdSKz2+J_EvEBgYb$Bx~l@lwynZ^3S5}}#yVo@I~@g%53eFp1&ck! zs)cgyeY<)wlOMzFen;tIe#mXQ^=Qh^e%|LMWBvuB8U80!Job+Kjq8zM^A!mMH{NW! zEx~#f%(&VJ#mRajhTh@wYrH}UIy-uwajnJg8M0qrCD?rk8>zPu_vb>_v}VX#t~7V~ z2jdfN&!Zv8ZEc_EM(mH^U9<)Y2gIM6$;0^grP(qF4?S??ejtZ;{w*0~ALtc*Y+p0W zx8S~$67<8%%Uo61KmK!NGsx^}r|+M^<;&mOnYw?%bwRQ5CU(AX@K70i?DKlL^)c?f zkmRp`1v7{~<=2kClWV&g+y zpQV5?WxXeLrDFWlyAdrAXnW+*thYT*A5{ysqr zDUv~Y>&vddbUdFXvc%||>*c)c|Ld0nd?fHs^BKnzAF%$2zKLSA_2rx&-^;Nrh^uA3 zX9TndOgx403d=7wLtX2k+pjb@ex%1&DdZ`qwC2~8vHp`|2C3jg!o?Q9|8TwI8yuA& zBg6TTelZw7@rmg-FUqQh)vV$C|Ncn@SFFzD_%Gx7+ZDU&(BidUY=*|;dM=nbNC7X} z?gj;icd+{l&d=6B&-eXK8b#s$g!W65K}EWH?n?sKBfiBmCAhrL@KyeJhnV@qfz=Yo z_S(AeFvmAFBojW;9VCs#lecx+nzAy-dI+j z7W9S7@4Z`P$dBLW-G+JGe9weNWcBRs5}6OaUueHP1@Mlk3t#re^@6fp&-4qGl0)@X zIKN24G%b`~&+Bb5yb9xE-y1!Huhe}Vi}xcfMIuyEIos@c65fwk^+y3+Hv$Sr{o>E6KZFVM$~@x{YZ0998cjZ zT8yOQKCU}{6YF2-sUt>meS`H>1@6b_f(9krcDYj%#r3m-966H@kB;SkIF9wn8$MqF zzCYr%kKW*ZiT5E>1s&rr2a*Li--R11nD22lUHOU(j9(D z`diV$kvKlo9~-8>yU}Mt75Cm26dY?prEkX0b6JPKZwPj^&`{;O=e>6V)(hWxEW`IQ zr@=aVupa5x^EF@#ib{!oAspcF%2@-T7s(jEaXRaPyd`CmOB|lw~ zk8RrGxCGF+Awk#HV?M|N9*s!pO6Mo3xc*S1bM+%~ss!ba8ofNtn$x4(N)3qa9n9P(;^L9Se4m*b zQCVc0-^St>L@sCm!ZG*>)y>^sn|u{evCQDF?SpW>2v4QT;B)K#QQclRU#NClDJssp z`?A9c;}tGYG2fTo4Ga;-;QNh2CuPv5%liDS9PA&dzE+P)QU@+J(B5xp>Hjp zw+H6~wR52u^%pLURh4n|{pt)6(u#{74C{gQD^Mt;uxNG4t(yTj{=8Rw8hi_T8upFD z^<8Mf+($S5JTUUgbU8_C*2B+3hAoiiVvp$S0Wl&R?!O bzqXkjCs)suLb0FO**E)ff5Wq2?u-8c@H$K_ diff --git a/examples/hmcecs/logistic_regression.py b/examples/hmcecs/logistic_regression.py deleted file mode 100644 index ad754d9a3..000000000 --- a/examples/hmcecs/logistic_regression.py +++ /dev/null @@ -1,162 +0,0 @@ -import os -import pathlib -import pickle -from datetime import datetime -from time import time - -import jax.numpy as jnp -import numpy as np -from jax import random, device_get -from pandas_plink import read_plink1_bin -from sklearn.datasets import load_breast_cancer - -import numpyro -import numpyro.distributions as dist -from numpyro.distributions import constraints -from numpyro.examples.datasets import _load_higgs -from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_median, init_to_value, HMC, autoguide -from numpyro.infer.hmc_gibbs import HMCECS, variational_proxy, taylor_proxy -from numpyro.infer.util import _predictive -import matplotlib.pyplot as plt - -os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" - -platform = 'gpu' -numpyro.set_platform(platform) - - -def summary(dataset, name, mcmc, sample_time, svi_time=0., plates={}): - n_eff_mean = np.mean([numpyro.diagnostics.effective_sample_size(device_get(v)) - for k, v in mcmc.get_samples(True).items() if k not in plates]) - pickle.dump(mcmc.get_samples(True), open(f'{dataset}/{name}_posterior_samples.pkl', 'wb')) - step_field = 'num_steps' if name in ['hmc', 'nuts'] else 'hmc_state.num_steps' - num_step = np.sum(mcmc.get_extra_fields()[step_field]) - accpt_prob = np.mean(mcmc.get_extra_fields()['accept_prob']) if 'ecs' in name else 1. - - with open(f'{dataset}/{name}_chain_stats.txt', 'w') as f: - print('sample_time', 'svi_time', 'n_eff_mean', 'gibbs_accpt_prob', 'tot_num_steps', 'time_per_step', - 'time_per_eff', sep=',', file=f) - print(sample_time, svi_time, n_eff_mean, accpt_prob, num_step, sample_time / num_step, sample_time / n_eff_mean, - sep=',', file=f) - - -def higgs_data(): - obs, data = _load_higgs() - return data, obs - - -def breast_cancer_data(): - dataset = load_breast_cancer() - feats = dataset.data - feats = (feats - feats.mean(0)) / feats.std(0) - feats = jnp.hstack((feats, jnp.ones((feats.shape[0], 1)))) - return feats, dataset.target - - -def copsac_data(): - data_folder = pathlib.Path('data') - bim_file = str(data_folder / 'Sim_data_3.bim') - fam_file = str(data_folder / 'Sim_data_3.fam') - bed_file = str(data_folder / 'Sim_data_3.bed') - data = read_plink1_bin(bed_file, bim_file, fam_file) - - return jnp.array(data.values), jnp.array(data['trait'].astype(int)) - - -def model(features, obs, subsample_size): - n, m = features.shape - theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), .5 * jnp.ones(m))) - with numpyro.plate('N', n, subsample_size=subsample_size): - batch_feats = numpyro.subsample(features, event_dim=1) - batch_obs = numpyro.subsample(obs, event_dim=0) - numpyro.sample('obs', dist.Bernoulli(logits=theta @ batch_feats.T), obs=batch_obs) - - -def guide(feature, obs, subsample_size): - _, m = feature.shape - mean = numpyro.param('mean', jnp.zeros(m), constraints=constraints.real) - # var = numpyro.param('var', jnp.ones(m), constraints=constraints.positive) - numpyro.sample('theta', dist.continuous.Normal(mean, .5)) - - -def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='vari'): - model_args, model_kwargs = (data, obs, subsample_size), {} - - svi_key, proxy_key, estimator_key, mcmc_key = random.split(random.PRNGKey(0), 4) - optimizer = numpyro.optim.Adam(step_size=5e-5) - guide = autoguide.AutoNormal(model) - svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) - start = time() - params, losses = svi.run(svi_key, 10000, *model_args) - svi_time = time() - start - plt.plot(losses) - plt.show() - - pickle.dump(params, open(f'{dataset}/svi_params.pkl', 'wb')) - params = params - - proxy_key, ref_key = random.split(proxy_key) - # FIXME should we substitute params to here; or even better using the optimized mean for taylor proxy? - ref_params = _predictive(ref_key, guide, {}, (1,), return_sites='', parallel=True, - model_args=model_args, model_kwargs=model_kwargs) - - ref_params = {k: v for k, v in ref_params.items() if k in ['theta']} - - if proxy_name == 'taylor': - proxy_fn = taylor_proxy(ref_params) - - else: - proxy_fn = variational_proxy(guide, params) - - # Compute HMCECS - kernel = HMCECS(NUTS(model), proxy=proxy_fn) - mcmc = MCMC(kernel, 1000, 1000) - start = time() - mcmc.run(random.PRNGKey(3), data, obs, subsample_size, extra_fields=("accept_prob", - "hmc_state.num_steps")) - mcmc.print_summary() - summary(dataset, f'ecs_{proxy_name}', mcmc, time() - start, svi_time=svi_time, plates={'N': ''}) - return ref_params - - -def plain_log_reg_model(features, obs): - n, m = features.shape - theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), .5 * jnp.ones(m))) - numpyro.sample('obs', dist.Bernoulli(logits=theta @ features.T), obs=obs) - - -def nuts(dataset, data, obs, ref_param): - kernel = NUTS(plain_log_reg_model, trajectory_length=1.2, init_strategy=init_to_value(values=ref_param)) - mcmc = MCMC(kernel, 1000, 1000) - mcmc._compile(random.PRNGKey(0), data, obs, extra_fields=("num_steps",)) - start = time() - mcmc.run(random.PRNGKey(0), data, obs, extra_fields=('num_steps',)) - summary(dataset, 'nuts', mcmc, time() - start) - - -def hmc(dataset, data, obs, ref_param): - kernel = HMC(plain_log_reg_model, trajectory_length=1.2, init_strategy=init_to_value(values=ref_param)) - mcmc = MCMC(kernel, 1000, 1000) - mcmc._compile(random.PRNGKey(0), data, obs, extra_fields=("num_steps",)) - start = time() - mcmc.run(random.PRNGKey(0), data, obs, extra_fields=('num_steps',)) - summary(dataset, 'hmc', mcmc, time() - start) - - -if __name__ == '__main__': - - load_data = {'breast': breast_cancer_data} # ,'higgs': higgs_data} , 'copsac': copsac_data} - subsample_sizes = {'higgs': 1300, 'breast': 75, } # 'copsac': 1000, - data, obs = breast_cancer_data() - - # FIXME: can we change platform in a JAX program? - for dataset in load_data.keys(): - dir = f'{platform}_{dataset}_{datetime.now().strftime("%Y_%m_%d_%H%M%S")}' - if not os.path.exists(dir): - os.mkdir(dir) - data, obs = load_data[dataset]() - ref_param = hmcecs_model(dir, data, obs, subsample_sizes[dataset], proxy_name='variational') - # ref_param = hmcecs_model(dir, data, obs, subsample_sizes[dataset], proxy_name='taylor') - # hmc(dir, data, obs, ref_param) - # nuts(dir, data, obs, ref_param) - exit() diff --git a/examples/hmcecs/regression.py b/examples/hmcecs/regression.py deleted file mode 100644 index 924555885..000000000 --- a/examples/hmcecs/regression.py +++ /dev/null @@ -1,172 +0,0 @@ -import argparse -import time -from pathlib import Path - -import jax.numpy as jnp -import matplotlib.pyplot as plt -import numpy as np -from flax import nn -from flax.nn.activation import tanh -from jax import random -from jax import vmap - -import numpyro -import numpyro.distributions as dist -from numpyro import handlers -from numpyro.contrib.module import random_flax_module -from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SVI, Trace_ELBO, init_to_value -from numpyro.infer.autoguide import AutoNormal -from numpyro.infer.hmc_gibbs import taylor_proxy - -uci_base_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/' - - -def visualize(alg, train_data, train_obs, samples, num_samples): - # helper function for prediction - def predict(model, rng_key, samples, *args, **kwargs): - model = handlers.substitute(handlers.seed(model, rng_key), samples) - # note that Y will be sampled in the model because we pass Y=None here - model_trace = handlers.trace(model).get_trace(*args, **kwargs) - return model_trace['obs']['value'] - - test_data = np.linspace(-2, 2, 500).reshape(-1, 1) - vmap_args = (samples, random.split(random.PRNGKey(1), num_samples)) - predictions = vmap(lambda samples, rng_key: predict(model, rng_key, samples, test_data))(*vmap_args) - predictions = predictions[..., 0] - fs = 14 - - m = predictions.mean(0) - percentiles = np.percentile(predictions, [2.5, 97.5], axis=0) - - f, ax = plt.subplots(1, 1, figsize=(8, 4)) - - # Get upper and lower confidence bounds - lower, upper = (percentiles[0, :]).flatten(), (percentiles[1, :]).flatten() - - # Plot training data as black stars - ax.plot(train_data, train_obs, 'x', marker='x', color='forestgreen', rasterized=True, label='Observed Data') - # Plot predictive means as blue line - ax.plot(test_data, m, 'b', rasterized=True, label="Mean Prediction") - # Shade between the lower and upper confidence bounds - ax.fill_between(test_data, lower, upper, alpha=0.5, rasterized=True, label='95% C.I.') - ax.set_ylim([-2.5, 2.5]) - ax.set_xlim([-2, 2]) - plt.grid() - ax.legend(fontsize=fs) - ax.tick_params(axis='both', which='major', labelsize=14) - ax.tick_params(axis='both', which='minor', labelsize=14) - - plt.tight_layout() - plt.savefig(f'plots/regression_{alg}.pdf', rasterized=True) - - plt.show() - - -def load_agw_1d(get_feats=False): - def features(x): - return np.hstack([x[:, None] / 2.0, (x[:, None] / 2.0) ** 2]) - - data = np.load(str(Path(__file__).parent / 'data' / 'data.npy')) - x, y = data[:, 0], data[:, 1] - y = y[:, None] - f = features(x) - - x_means, x_stds = x.mean(axis=0), x.std(axis=0) - y_means, y_stds = y.mean(axis=0), y.std(axis=0) - f_means, f_stds = f.mean(axis=0), f.std(axis=0) - - X = ((x - x_means) / x_stds).astype(np.float32) - Y = ((y - y_means) / y_stds).astype(np.float32) - F = ((f - f_means) / f_stds).astype(np.float32) - - if get_feats: - return F, Y - - return X[:, None], Y - - -class Network(nn.Module): - def apply(self, x, out_channels): - l1 = tanh(nn.Dense(x, features=100)) - l2 = tanh(nn.Dense(l1, features=100)) - means = nn.Dense(l2, features=out_channels) - return means - - -def nonlin(x): - return tanh(x) - - -def model(data, obs=None, subsample_size=None): - module = Network.partial(out_channels=1) - net = random_flax_module('fnn', module, dist.Normal(0, 1.), input_shape=data.shape[1]) - - prec_obs = numpyro.sample("prec_obs", dist.LogNormal(jnp.log(110.4), .0001)) - sigma_obs = 1.0 / jnp.sqrt(prec_obs) # prior - - with numpyro.plate('N', data.shape[0], subsample_size=subsample_size) as idx: - numpyro.sample('obs', dist.Normal(net(data[idx]), sigma_obs), obs=obs[idx]) - - -def benchmark_hmc(args, features, labels): - features = jnp.array(features) - labels = jnp.array(labels) - start = time.time() - rng_key, ref_key = random.split(random.PRNGKey(1)) - subsample_size = 40 - guide = AutoNormal(model) - svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO()) - params, losses = svi.run(random.PRNGKey(2), 2000, features, labels, subsample_size) - plt.plot(losses) - plt.show() - ref_params = svi.guide.sample_posterior(ref_key, params, (1,)) - print(ref_params) - if args.alg == "HMC": - step_size = jnp.sqrt(0.5 / features.shape[0]) - trajectory_length = step_size * args.num_steps - kernel = HMC(model, step_size=step_size, trajectory_length=trajectory_length, adapt_step_size=False, - dense_mass=args.dense_mass) - subsample_size = None - elif args.alg == "NUTS": - kernel = NUTS(model, dense_mass=args.dense_mass) - subsample_size = None - elif args.alg == "HMCECS": - subsample_size = 40 - inner_kernel = NUTS(model, init_strategy=init_to_value(values=ref_params), - dense_mass=args.dense_mass) - kernel = HMCECS(inner_kernel, num_blocks=100, proxy=taylor_proxy(ref_params)) - else: - raise ValueError('Alg not in HMC, NUTS, HMCECS.') - mcmc = MCMC(kernel, args.num_warmup, args.num_samples) - mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",)) - print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"])) - mcmc.print_summary(exclude_deterministic=False) - print('\nMCMC elapsed time:', time.time() - start) - return mcmc.get_samples() - - -def main(args): - data, obs = load_agw_1d() - samples = benchmark_hmc(args, data, obs) - visualize(args.alg, data, obs, samples, args.num_samples) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description="parse args") - parser.add_argument('-n', '--num-samples', default=1000, type=int, help='number of samples') - parser.add_argument('--num-warmup', default=200, type=int, help='number of warmup steps') - parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")') - parser.add_argument('--num-chains', nargs='?', default=1, type=int) - parser.add_argument('--alg', default='NUTS', type=str, - help='whether to run "HMC", "NUTS", or "HMCECS"') - parser.add_argument('--dense-mass', action="store_true") - parser.add_argument('--x64', action="store_true") - parser.add_argument('--device', default='gpu', type=str, help='use "cpu" or "gpu".') - args = parser.parse_args() - - numpyro.set_platform(args.device) - numpyro.set_host_device_count(args.num_chains) - if args.x64: - numpyro.enable_x64() - - main(args) diff --git a/test/test_hmc_gibbs.py b/test/test_hmc_gibbs.py index 9b0737599..22dd82ef4 100644 --- a/test/test_hmc_gibbs.py +++ b/test/test_hmc_gibbs.py @@ -3,18 +3,18 @@ from functools import partial +import jax.numpy as jnp import numpy as np -from numpy.testing import assert_allclose import pytest - from jax import random -import jax.numpy as jnp from jax.scipy.linalg import cho_factor, cho_solve, inv, solve_triangular +from numpy.testing import assert_allclose import numpyro import numpyro.distributions as dist from numpyro.handlers import plate from numpyro.infer import HMC, HMCECS, MCMC, NUTS, DiscreteHMCGibbs, HMCGibbs +from numpyro.infer.hmc_gibbs import taylor_proxy def _linear_regression_gibbs_fn(X, XX, XY, Y, rng_key, gibbs_sites, hmc_sites): @@ -240,3 +240,29 @@ def model(data): kernel = HMCECS(NUTS(model), num_blocks=10) mcmc = MCMC(kernel, 10, 10) mcmc.run(random.PRNGKey(0), data) + + +@pytest.mark.parametrize('kernel_cls', [HMC, NUTS]) +@pytest.mark.parametrize('num_block', [1, 2, 50]) +@pytest.mark.parametrize('subsample_size', [50, 150]) +def test_hmcecs_normal_normal(kernel_cls, num_block, subsample_size): + true_loc = 0.3 + num_warmup, num_samples = 200, 200 + data = true_loc + dist.Normal().sample(random.PRNGKey(1), (10000,)) + + def model(data, subsample_size): + mean = numpyro.sample('mean', dist.Normal()) + with numpyro.plate('batch', data.shape[0], subsample_size=subsample_size): + sub_data = numpyro.subsample(data, 0) + numpyro.sample("obs", dist.Normal(mean, 1), obs=sub_data) + + ref_params = {'mean': true_loc + dist.Normal(true_loc, 5e-2).sample(random.PRNGKey(0))} + proxy_fn = taylor_proxy(ref_params) + + kernel = HMCECS(kernel_cls(model), proxy=proxy_fn) + mcmc = MCMC(kernel, num_warmup, num_samples) + mcmc.run(random.PRNGKey(0), data, subsample_size) + + samples = mcmc.get_samples() + assert_allclose(np.mean(mcmc.get_samples()['mean']), true_loc, atol=0.1) + assert len(samples['mean']) == num_samples From fb9503563a64668607560adfbf88625f43538041 Mon Sep 17 00:00:00 2001 From: ola Date: Fri, 5 Feb 2021 13:43:56 +0100 Subject: [PATCH 71/93] Cleaned. --- .gitignore | 5 +- examples/Running_Tests.sh | 25 ----- numpyro/contrib/ecs.py | 187 -------------------------------- numpyro/contrib/ecs_utils.py | 156 -------------------------- numpyro/contrib/trace_struct.py | 39 ------- numpyro/distributions/util.py | 2 +- numpyro/primitives.py | 9 -- 7 files changed, 2 insertions(+), 421 deletions(-) delete mode 100644 examples/Running_Tests.sh delete mode 100644 numpyro/contrib/ecs.py delete mode 100644 numpyro/contrib/ecs_utils.py delete mode 100644 numpyro/contrib/trace_struct.py diff --git a/.gitignore b/.gitignore index 2ff5ff5fc..dbc16a61d 100644 --- a/.gitignore +++ b/.gitignore @@ -33,7 +33,4 @@ numpyro/examples/.data # docs docs/build -docs/.DS_Store - -examples/HIGGS.csv.gz -examples/PLOTS* +docs/.DS_Store \ No newline at end of file diff --git a/examples/Running_Tests.sh b/examples/Running_Tests.sh deleted file mode 100644 index 3c38ec25e..000000000 --- a/examples/Running_Tests.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/sh -#python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo NUTS -algo NUTS -map_init NUTS & -#python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo NUTS -algo NUTS -map_init HMC & -#python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo NUTS -algo NUTS -map_init SVI & #Slow, wrong number of epochs,repeat - -echo NUTS,HMC,NUTS -python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo NUTS -algo HMC -map_init NUTS & -echo NUTS,HMC,HMC -#python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo NUTS -algo HMC -map_init HMC & -echo NUTS,HMC,SVI -#python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo NUTS -algo HMC -map_init SVI & - -echo HMC,NUTS,NUTS -python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo HMC -algo NUTS -map_init NUTS & -echo HMC,NUTS,HMC -python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo HMC -algo NUTS -map_init HMC & -echo HMC,NUTS,SVI -python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo HMC -algo NUTS -map_init SVI & - -echo HMC,HMC,NUTS -python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo HMC -algo HMC -map_init NUTS & -echo HMC,HMC,HMC -python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo HMC -algo HMC -map_init HMC & -echo HMC,HMC,SVI -python logistic_hmcecs.py -num_samples 100 -num_warmup 50 -ecs_algo HMC -algo HMC -map_init SVI & diff --git a/numpyro/contrib/ecs.py b/numpyro/contrib/ecs.py deleted file mode 100644 index b1422b500..000000000 --- a/numpyro/contrib/ecs.py +++ /dev/null @@ -1,187 +0,0 @@ -""" Based on fehiepsi implementation: https://gist.github.com/fehiepsi/b4a5a80b245600b99467a0264be05fd5 """ -import copy -from collections import namedtuple - -import jax.numpy as jnp -from jax import device_put, lax, random, partial, jit, jacobian, hessian - -from numpyro.contrib.ecs_utils import ( - init_near_values, - estimator, - subsample_size, - _tangent_curve -) -from numpyro.contrib.ecs_utils import taylor_proxy, variational_proxy, DifferenceEstimator -from numpyro.handlers import substitute, trace, seed, block -from numpyro.infer import log_likelihood -from numpyro.infer.mcmc import MCMCKernel -from numpyro.infer.util import _predictive, log_density -from numpyro.util import identity - -HMC_ECS_State = namedtuple("HMC_ECS_State", "uz, hmc_state, accept_prob, rng_key") -""" - - **uz** - a dict of current subsample indices and the current latent values - - **hmc_state** - current hmc_stat log_like += j.T @ z_diff + .5 * z_diff.T @ h.reshape(k, k) @ z_diff -e - - **accept_prob** - acceptance probability of the proposal subsample indices - - **rng_key** - random key to generate new subsample indices -""" - -""" Notes: -- [x] init(...) ] -sample(...) - will use check_potential handler method! -""" - - -def _wrap_est_model(model, estimators, predecessors): - def fn(*args, **kwargs): - subsample_values = kwargs.pop("_subsample_sites", {}) - with substitute(data=subsample_values): - with estimator(model, estimators, predecessors): - model(*args, **kwargs) - - return fn - - -@partial(jit, static_argnums=(2, 3, 4)) -def _update_block(rng_key, u, n, m, g): - """Returns indexes of the new subsample. The update mechanism selects blocks of indices within the subsample to be updated. - The number of indexes to be updated depend on the block size, higher block size more correlation among elements in the subsample. - :param rng_key: - :param u: subsample indexes - :param n: total number of data - :param m: subsample size - :param g: number of subsample blocks - """ - - rng_key_block, rng_key_index = random.split(rng_key) - - chosen_block = random.randint(rng_key_block, shape=(), minval=0, maxval=g + 1) - new_idx = random.randint(rng_key_index, minval=0, maxval=n, shape=(m,)) - block_mask = (jnp.arange(m) // g == chosen_block).astype(int) - rest_mask = (block_mask - 1) ** 2 - - u_new = u * rest_mask + block_mask * new_idx - return u_new - - -class ECS(MCMCKernel): - """ Energy conserving subsampling as first described in [1]. - - ** Reference: ** - 1. *Hamiltonian Monte Carlo with Energy ConservingSubsampling* by Dang, Khue-Dang et al. - """ - sample_field = "uz" - - def __init__(self, inner_kernel, proxy, model_struct, ref=None, guide=None): - assert proxy in ('taylor', 'variational') - self.inner_kernel = copy.copy(inner_kernel) - self.inner_kernel._model = inner_kernel.model - self._guide = guide - self._proxy = proxy - self._model_struct = model_struct - self._ref = ref - self._plate_sizes = None - self._estimator = DifferenceEstimator - - @property - def model(self): - return self.inner_kernel._model - - def postprocess_fn(self, args, kwargs): - def fn(uz): - z = {k: v for k, v in uz.items() if k not in self._plate_sizes} - return self.inner_kernel.postprocess_fn(args, kwargs)(z) - - return fn - - def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): - model_kwargs = {} if model_kwargs is None else model_kwargs.copy() - rng_key, key_u, key_z = random.split(rng_key, 3) - - prototype_trace = trace(seed(self.model, key_u)).get_trace(*model_args, **model_kwargs) - u = {name: site["value"] for name, site in prototype_trace.items() - if site["type"] == "plate" and site["args"][0] > site["args"][1]} - - # TODO: estimate good block size - self._plate_sizes = {name: prototype_trace[name]["args"] + (min(prototype_trace[name]["args"][1] // 2, 100),) - for name in u} - - plate_sizes_all = {name: (prototype_trace[name]["args"][0], prototype_trace[name]["args"][0]) for name in u} - if self._proxy == 'taylor': - # Precompute Jaccobian and Hessian for Taylor Proxy - with subsample_size(self.model, plate_sizes_all): - ref_trace = trace(substitute(self.model, data=self._z_ref)).get_trace(*model_args, **model_kwargs) - jac_all = {name: _tangent_curve(site['fn'], site['value'], jacobian) for name, site in ref_trace.items() - if (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - hess_all = {name: _tangent_curve(site['fn'], site['value'], hessian) for name, site in ref_trace.items() - if (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - ll_ref = {name: site['fn'].log_prob(site['value']) for name, site in ref_trace.items() if - (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - - ref_trace = trace(substitute(self.model, data={**self._z_ref, **u})).get_trace(*model_args, **model_kwargs) - proxy_fn, uproxy_fn = taylor_proxy(ref_trace, ll_ref, jac_all, hess_all) - elif self._proxy == 'variational': - pos_key, guide_key, rng_key = random.split(rng_key, 3) - num_samples = 10 # TODO: heuristic for this - guide = substitute(self._guide, self._ref) - posterior_samples = _predictive(pos_key, guide, {}, - (num_samples,), return_sites='', parallel=True, - model_args=model_args, model_kwargs=model_kwargs) - with subsample_size(self.model, plate_sizes_all): - model = subsample_size(self.model, plate_sizes_all) - ll = log_likelihood(model, posterior_samples, *model_args, **model_kwargs) - - # TODO: fix multiple likehoods - weights = {name: jnp.mean((value.T / value.sum(1).T).T, 0) for name, value in - ll.items()} # TODO: fix broadcast - prior, _ = log_density(block(model, hide_fn=lambda site: site['type'] == 'sample' and site['is_observed']), - model_args, model_kwargs, posterior_samples) - variational, _ = log_density(guide, model_args, model_kwargs, posterior_samples) - evidence = {name: variational / num_samples - prior / num_samples - ll.mean(1).sum() for name, ll in - ll.items()} # TODO: must depend on structure! - - guide_trace = trace(seed(self._guide, guide_key)).get_trace(*model_args, **model_kwargs) - proxy_fn, uproxy_fn = variational_proxy(guide_trace, evidence, weights) - else: - raise NotImplementedError - - estimators = {name: self._estimator(name=name, - proxy=proxy_fn, uproxy=uproxy_fn, - plate_name=site['cond_indep_stack'][0].name, - plate_size=self._plate_sizes[site['cond_indep_stack'][0].name]) - for name, site in prototype_trace.items() if - (site['type'] == 'sample' and site['is_observed'] and site['cond_indep_stack'])} - - predecessors = {name: self._model_struct[name] for name in estimators} - - self.inner_kernel._model = _wrap_est_model(self.model, estimators, predecessors) - - init_params = {name: init_near_values(site, self._ref) for name, site in prototype_trace.items()} - model_kwargs["_subsample_sites"] = u - hmc_state = self.inner_kernel.init(key_z, num_warmup, init_params, model_args, model_kwargs) - uz = {**u, **hmc_state.z} - return device_put(HMC_ECS_State(uz, hmc_state, 1., rng_key)) - - def sample(self, state, model_args, model_kwargs): - model_kwargs = {} if model_kwargs is None else model_kwargs.copy() - rng_key, key_u = random.split(state.rng_key) - u = {k: v for k, v in state.uz.items() if k in self._plate_sizes} - u_new = {} - for name, (size, subsample_size, num_blocks) in self._plate_sizes.items(): - key_u, subkey = random.split(key_u) - u_new[name] = _update_block(subkey, u[name], size, subsample_size, - num_blocks) # TODO: dynamically adjust block size - sample = self.postprocess_fn(model_args, model_kwargs)(state.hmc_state.z) - u_loglik = log_likelihood(self.model, sample, *model_args, batch_ndims=0, **model_kwargs, _subsample_sites=u) - u_loglik = sum(v.sum() for v in u_loglik.values()) - u_new_loglik = log_likelihood(self.model, sample, *model_args, batch_ndims=0, **model_kwargs, - _subsample_sites=u_new) - u_new_loglik = sum(v.sum() for v in u_new_loglik.values()) - accept_prob = jnp.clip(jnp.exp(u_new_loglik - u_loglik), a_max=1.0) - u = lax.cond(random.bernoulli(key_u, accept_prob), u_new, identity, u, identity) - model_kwargs["_subsample_sites"] = u - hmc_state = self.inner_kernel.sample(state.hmc_state, model_args, model_kwargs) - uz = {**u, **hmc_state.z} - return HMC_ECS_State(uz, hmc_state, accept_prob, rng_key) diff --git a/numpyro/contrib/ecs_utils.py b/numpyro/contrib/ecs_utils.py deleted file mode 100644 index 1b853bcf4..000000000 --- a/numpyro/contrib/ecs_utils.py +++ /dev/null @@ -1,156 +0,0 @@ -from collections import OrderedDict, defaultdict - -import jax -import jax.numpy as jnp - -from numpyro.primitives import Messenger, _subsample_fn - - -def _tangent_curve(dist, value, tangent_fn): - z, aux_data = dist.tree_flatten() - log_prob = lambda *params: dist.tree_unflatten(aux_data, params).log_prob(value).sum() - return tuple(tangent_fn(log_prob, argnum)(*z) for argnum in range(len(z))) - - -def init_near_values(site=None, values={}): - """Initialize the sampling to a noisy map estimate of the parameters""" - from functools import partial - - from numpyro.distributions.continuous import Normal - from numpyro.infer.initialization import init_to_uniform - - if site is None: - return partial(init_near_values(values=values)) - - if site['type'] == 'sample' and not site['is_observed']: - if site['name'] in values: - try: - rng_key = site['kwargs'].get('rng_key') - sample_shape = site['kwargs'].get('sample_shape') - return values[site['name']] + Normal(0., 1e-3).sample(rng_key, sample_shape) - except: - return init_to_uniform(site) - - -def _extract_params(distribution): - params, _ = distribution.tree_flatten() - return params - - -class estimator(Messenger): - def __init__(self, fn, estimators, predecessors): - self.estimators = estimators - self.predecessors = predecessors - self.predecessor_sites = defaultdict(OrderedDict) - self._successors = None - - super(estimator, self).__init__(fn) - - @property - def successors(self): - if getattr(self, '_successors') is None: - successors = {} - for site_name, preds in self.predecessors.items(): - successors.update({pred_name: site_name for pred_name in preds}) # TODO: handle shared priors - self._successors = successors - return self._successors - - def postprocess_message(self, msg): - if 'name' not in msg: - return - name = msg['name'] - if name in self.successors: - self.predecessor_sites[self.successors[name]][name] = msg.copy() - - if msg['type'] == 'sample' and msg['is_observed'] and msg['cond_indep_stack']: # TODO: is subsampled - msg['fn'] = self.estimators[name](msg['fn'], self.predecessor_sites[name]) - - -def taylor_proxy(ref_trace, ll_ref, jac_all, hess_all): - def proxy(name, z): - z_ref = _extract_params(ref_trace[name]['fn']) - jac, hess = jac_all[name], hess_all[name] - log_like = jnp.array(0.) - for argnum in range(len(z_ref)): - z_diff = z[argnum] - z_ref[argnum] - j, h = jac[argnum], hess[argnum] - k, = j.shape - log_like += j.T @ z_diff + .5 * z_diff.T @ h.reshape(k, k) @ z_diff - return ll_ref[name].sum() + log_like - - def uproxy(name, value, z): - ref_dist = ref_trace[name]['fn'] - z_ref, aux_data = ref_dist.tree_flatten() - - log_prob = lambda *params: ref_dist.tree_unflatten(aux_data, params).log_prob(value).sum() - log_like = jnp.array(0.) - for argnum in range(len(z_ref)): - z_diff = z[argnum] - z_ref[argnum] - jac = jax.jacobian(log_prob, argnum)(*z_ref) - k, = jac.shape - hess = jax.hessian(log_prob, argnum)(*z_ref) - log_like += jac @ z_diff + .5 * z_diff @ hess.reshape(k, k) @ z_diff.T - - return log_prob(*z_ref).sum() + log_like - - return proxy, uproxy - - -class subsample_size(Messenger): - def __init__(self, fn, plate_sizes, rng_key=None): - super(subsample_size, self).__init__(fn) - self.plate_sizes = plate_sizes - self.rng_key = rng_key - - def process_message(self, msg): - if msg['type'] == 'plate' and msg['args'] and msg["args"][0] > msg["args"][1]: - if msg['name'] in self.plate_sizes: - msg['args'] = self.plate_sizes[msg['name']] - msg['value'] = _subsample_fn(*msg['args'], self.rng_key) if msg["args"][1] < msg["args"][ - 0] else jnp.arange(msg["args"][0]) - - -class DifferenceEstimator: - def __init__(self, name, proxy, uproxy, plate_name, plate_size): - self._name = name - self.plate_name = plate_name - self.size = plate_size - self.proxy = proxy - self.uproxy = uproxy - self.subsample = None - self._dist = None - self._predecessors = None - - def __call__(self, dist, predecessors): - self.dist = dist - self.predecessors = predecessors - - def log_prob(self, value): - n, m, g = self.size - ll_sub = self.dist.log_prob(value).sum() - diff = ll_sub - self.uproxy(name=self._name, - value=value, - subsample=self.predecessors[self.plate_name], - predecessors=self.predecessors) - l_hat = self.proxy(self._name) + n / m * diff - sigma = n ** 2 / m * jnp.var(diff) - return l_hat - .5 * sigma - - -def variational_proxy(guide_trace, evidence, weights): - def _log_like(predecessors): - log_prob = jnp.array(0.) - for pred in predecessors: - if pred['type'] == 'sample': - val = pred['value'] - name = pred['name'] - log_prob += guide_trace[name]['fn'].log_prob(val) - pred['fn'].log_prob(val) - return log_prob - - def proxy(name, predecessors, *args, **kwargs): - return evidence[name] + _log_like(predecessors) - - def uproxy(name, predecessors, subsample, *args, **kwargs): - return evidence[name] + weights[name][subsample].sum() * _log_like(predecessors) - - return proxy, uproxy diff --git a/numpyro/contrib/trace_struct.py b/numpyro/contrib/trace_struct.py deleted file mode 100644 index 34737992b..000000000 --- a/numpyro/contrib/trace_struct.py +++ /dev/null @@ -1,39 +0,0 @@ -from collections import OrderedDict - - -class TraceStructure: - """ - Graph structure denoting the relationship among pyro primitives in the execution path. - """ - - def __init__(self): - self.nodes = OrderedDict() - self._successors = OrderedDict() - self._predecessors = OrderedDict() - - def __contains__(self, site): - return site in self.nodes - - def add_edge(self, from_site, to_site): - for site in (from_site, to_site): - if site not in self: - self.add_node(site) - - self._successors[from_site].add(to_site) - self._predecessors[to_site].add(to_site) - - def add_node(self, site_name, **kwargs): - if site_name in self: - # TODO: handle reused name! - pass - self.nodes[site_name] = kwargs - self._successors[site_name] = set() - self.__predecessors[site_name] = set() - - def predecessor(self, site): - return self._predecessors[site] - - def successor(self, site): - return self._successors[site] - - # TODO: remove edge diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index d2ed17849..07329e631 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -543,4 +543,4 @@ def wrapper(self, *args, **kwargs): log_prob = jnp.where(mask, log_prob, -jnp.inf) return log_prob - return wrapper \ No newline at end of file + return wrapper diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 92c7f6135..a577e7446 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -18,15 +18,6 @@ CondIndepStackFrame = namedtuple('CondIndepStackFrame', ['name', 'dim', 'size']) -@contextmanager -def inner_stack(): - global _PYRO_STACK - current_stack = _PYRO_STACK - _PYRO_STACK = [] - yield - _PYRO_STACK = current_stack - - def apply_stack(msg): pointer = 0 for pointer, handler in enumerate(reversed(_PYRO_STACK)): From e1150eacb5f9f4fde06a88008ecab64e7f6e86c4 Mon Sep 17 00:00:00 2001 From: ola Date: Fri, 5 Feb 2021 13:45:36 +0100 Subject: [PATCH 72/93] Removed old HMCECS logistic examples. --- .gitignore | 2 +- examples/logistic_hmcecs.py | 364 -------------------------------- examples/logistic_hmcecs_svi.py | 63 ------ 3 files changed, 1 insertion(+), 428 deletions(-) delete mode 100644 examples/logistic_hmcecs.py delete mode 100644 examples/logistic_hmcecs_svi.py diff --git a/.gitignore b/.gitignore index dbc16a61d..4259586b1 100644 --- a/.gitignore +++ b/.gitignore @@ -33,4 +33,4 @@ numpyro/examples/.data # docs docs/build -docs/.DS_Store \ No newline at end of file +docs/.DS_Store diff --git a/examples/logistic_hmcecs.py b/examples/logistic_hmcecs.py deleted file mode 100644 index 190483371..000000000 --- a/examples/logistic_hmcecs.py +++ /dev/null @@ -1,364 +0,0 @@ -""" Logistic regression model as implemetned in https://arxiv.org/pdf/1708.00955.pdf with Higgs Dataset """ -#!/usr/bin/env python -import jax -import jax.numpy as jnp -import numpyro -import numpyro.distributions as dist -from numpyro.infer import NUTS, MCMC, Predictive,HMC -import sys, os -from jax.config import config -import datetime,time -import argparse -import numpy as np -from numpyro.distributions.kl import kl_divergence -from matplotlib.pyplot import cm -#remember to export the path of the project -sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/contrib/') -sys.path.append('/home/lys/Dropbox/PhD/numpyro/numpyro/examples/') - -from hmcecs import HMCECS -from hmcecs_utils import poisson_samples_correction -#from numpyro.contrib.hmcecs import HMC - -from sklearn.datasets import load_breast_cancer -from datasets import _load_higgs -#from numpyro.examples.datasets import _load_higgs -from logistic_hmcecs_svi import svi_map -import jax.numpy as np_jax -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sns -import time -from numpyro.diagnostics import summary -from jax.tree_util import tree_flatten,tree_map - -numpyro.set_platform("cpu") - -def breast_cancer_data(): - dataset = load_breast_cancer() - feats = dataset.data - feats = (feats - feats.mean(0)) / feats.std(0) - feats = jnp.hstack((feats, jnp.ones((feats.shape[0], 1)))) - - return feats, dataset.target - - -def higgs_data(): - observations,features = _load_higgs() - return features,observations -def save_obj(obj, name): - import _pickle as cPickle - import bz2 - with bz2.BZ2File(name, "wb") as f: - cPickle.dump(obj, f) -def load_obj(name): - import _pickle as cPickle - import bz2 - data = bz2.BZ2File(name, "rb") - data = cPickle.load(data) - - return data - -def model(feats, obs): - """ Logistic regression model - - """ - n, m = feats.shape - theta = numpyro.sample('theta', dist.continuous.Normal(jnp.zeros(m), 2 * jnp.ones(m))) - numpyro.sample('obs', dist.Bernoulli(logits=jnp.matmul(feats, theta)), obs=obs) - -def infer_nuts(rng_key, feats, obs, samples, warmup ): - kernel = NUTS(model=model,target_accept_prob=0.8) - mcmc = MCMC(kernel, num_warmup=warmup, num_samples=samples) - mcmc.run(rng_key, feats, obs) - #mcmc.print_summary() - samples = mcmc.get_samples() - samples = tree_map(lambda x: x[None, ...], samples) - r_hat_average = np_jax.sum(summary(samples)["theta"]["r_hat"])/len(summary(samples)["theta"]["r_hat"]) - - return mcmc.get_samples(), r_hat_average - - - - -def infer_hmc(rng_key, feats, obs, samples, warmup ): - kernel = HMC(model=model,target_accept_prob=0.8) - mcmc = MCMC(kernel, num_warmup=warmup, num_samples=samples) - mcmc.run(rng_key, feats, obs) - #mcmc.print_summary() - samples = mcmc.get_samples() - samples = tree_map(lambda x: x[None, ...], samples) - r_hat_average = np_jax.sum(summary(samples)["theta"]["r_hat"])/len(summary(samples)["theta"]["r_hat"]) - - return mcmc.get_samples(), r_hat_average - - - - - -def infer_hmcecs(rng_key, feats, obs, m=None,g=None,n_samples=None, warmup=None,algo="NUTS",subsample_method=None,map_method=None,proxy="taylor",estimator=None,num_epochs=None,postprocess_fn=None ): - hmcecs_key, map_key = jax.random.split(rng_key) - n, _ = feats.shape - file_hyperparams = open("PLOTS_{}/Hyperparameters_{}.txt".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms"),now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), "a") - - if subsample_method=="perturb" and proxy== "taylor": - map_samples = 10 - map_warmup = 5 - factor_NUTS = 50 - if map_method == "NUTS": - print("Running NUTS for map estimation {} + {} samples".format(map_samples,map_warmup)) - file_hyperparams.write('MAP samples : {} \n'.format(map_samples)) - file_hyperparams.write('MAP warmup : {} \n'.format(map_warmup)) - samples,r_hat_average = infer_nuts(map_key, feats[:factor_NUTS], obs[:factor_NUTS],samples=map_samples,warmup=map_warmup) - z_ref = {key: value.mean(0) for key, value in samples.items()} - if map_method == "HMC": - print("Running HMC for map estimation") - file_hyperparams.write('MAP samples : {} \n'.format(map_samples)) - file_hyperparams.write('MAP warmup : {} \n'.format(map_warmup)) - samples, r_hat_average = infer_hmc(map_key, feats[:factor_NUTS], obs[:factor_NUTS], samples=map_samples, warmup=map_warmup) - z_ref = {key: value.mean(0) for key, value in samples.items()} - if map_method == "SVI": - print("Running SVI for map estimation") - file_hyperparams.write('SVI epochs : {} \n'.format(num_epochs)) - z_ref = svi_map(model, map_key, feats=feats, obs=obs,num_epochs=num_epochs,batch_size = m) - z_ref = {k[5:]: v for k, v in z_ref.items()} #highlight: [5:] is to skip the "auto" part - svi = None - save_obj(z_ref,"{}/MAP_Dict_Samples_MAP_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), map_method)) - print("Running MCMC subsampling with Taylor proxy") - elif subsample_method =="perturb" and proxy=="svi": - factor_SVI = obs.shape[0] - batch_size = 32 #int(factor_SVI//10) - print("Running SVI for map estimation with svi proxy") - file_hyperparams.write('SVI epochs : {} \n'.format(num_epochs)) - map_key, post_key = jax.random.split(map_key) - z_ref, svi, svi_state = svi_map(model, map_key, feats=feats[:factor_SVI], obs=obs[:factor_SVI], - num_epochs=num_epochs, batch_size=batch_size) - z_ref = svi.guide.sample_posterior(post_key, svi.get_params(svi_state), (100,)) - z_ref = {name: value.mean(0) for name, value in z_ref.items()} #highlight: AutoDiagonalNormal does not have auto_ in front of the parmeters - - save_obj(z_ref,"{}/MAP_Dict_Samples_Proxy_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), - proxy)) - print("Running MCMC subsampling with SVI proxy") - - else: - z_ref = None - svi = None - - start = time.time() - extra_fields = [] - if estimator == "poisson": - postprocess_fn = None # poisson_samples_correction - extra_fields = ("sign",) - kernel = HMCECS(model=model,z_ref=z_ref,m=m,g=g,algo=algo, - subsample_method=subsample_method,proxy=proxy,svi_fn=svi, - estimator = estimator,target_accept_prob=0.8)#,postprocess_fn=postprocess_fn) - - mcmc = MCMC(kernel,num_warmup=warmup,num_samples=n_samples,num_chains=1,postprocess_fn=postprocess_fn) - mcmc.run(rng_key,feats,obs,extra_fields=extra_fields) - extra_fields = mcmc.get_extra_fields() - stop = time.time() - file_hyperparams.write('MCMC/NUTS elapsed time {}: {} \n'.format(subsample_method,time.time() - start)) - file_hyperparams.write('Effective size {}: {}\n'.format(subsample_method,n_samples)) - file_hyperparams.write('Warm up size {}: {}\n'.format(subsample_method,warmup)) - file_hyperparams.write('Subsample size (m): {}\n'.format(m)) - file_hyperparams.write('Block size (g): {}\n'.format(g)) - file_hyperparams.write('Data size (n): {}\n'.format(feats.shape[0])) - file_hyperparams.write('Estimator: {}\n'.format(estimator)) - file_hyperparams.write('...........................................\n') - file_hyperparams.close() - #print(mcmc.get_samples().keys()) - save_obj(mcmc.get_samples(),"{}/MCMC_Dict_Samples_{}_m_{}.pkl".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")),subsample_method,m)) - - return mcmc.get_samples() - - - -def Determine_best_sample_size(rng_key,feats,obs): - """Determine amount of effective sample size for z_map initialization""" - effective_sample_list=[5,10,20,30,50] - r_hat_average_list=[] - for effective_sample in effective_sample_list: - samples, r_hat_average = infer_nuts(rng_key,feats,obs,effective_sample,warmup=6) - r_hat_average_list.append(r_hat_average) - - plt.plot(effective_sample_list,r_hat_average_list) - plt.xlabel(r"Effective sample size") - plt.ylabel(r"$\hat{r}$") - plt.title("Determine best effective sample size for z_map") - plt.savefig("{}/Best_effective_size_z_map.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")))) -def Plot(samples_ECS,samples_NUTS,ecs_algo,algo,proxy,estimator,m,kl=None): - if estimator : - label = "ECS-{}-{} proxy-{} estimator".format(ecs_algo, proxy, estimator) - else: - label = "ECS-{}-{} proxy".format(ecs_algo, proxy) - for sample in [0,7,15,25]: - plt.figure(sample + m +3) - #samples = pd.DataFrame.from_records(samples,index="theta") - - sns.kdeplot(data=samples_ECS["theta"][sample],color="r",label=label) - sns.kdeplot(data=samples_NUTS["theta"][sample],color="b",label="{}".format(algo)) - #if kl != None: - # sns.kdeplot(data=kl, color="g", label="KL; m: {}".format(m)) - plt.xlabel(r"$\theta") - plt.ylabel("Density") - plt.legend() - plt.title(r"$\theta$ {} m: {} Density plot".format(sample,str(m))) - plt.savefig("{}/KDE_plot_theta_{}_m_{}.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")),sample,str(m))) - plt.clf() -def Folders(folder_name): - """ Folder for all the generated images It will updated everytime!!! Save the previous folder before running again. Creates folder in current directory""" - import os - import shutil - basepath = os.getcwd() - if not basepath: - newpath = folder_name - else: - newpath = basepath + "/%s" % folder_name - - if not os.path.exists(newpath): - try: - original_umask = os.umask(0) - os.makedirs(newpath, 0o777) - finally: - os.umask(original_umask) - else: - shutil.rmtree(newpath) # removes all the subdirectories! - os.makedirs(newpath,0o777) -def Plot_KL(map_method,ecs_algo,algo,proxy,estimator,n_samples,n_warmup,epochs): - factor_ECS= 50 #obs.shape[0] - m = [int(np_jax.sqrt(obs[:factor_ECS].shape[0])),2*int(np_jax.sqrt(obs[:factor_ECS].shape[0])),4*int(np_jax.sqrt(obs[:factor_ECS].shape[0])),8*int(np_jax.sqrt(obs[:factor_ECS].shape[0]))] - g = 5 - factor_NUTS = 50 - colors = cm.rainbow(np.linspace(0, 1, len(m))) - run_test = False - if run_test: - print("Running standard NUTS") - est_posterior_NUTS = infer_hmcecs(rng_key, feats=feats[:factor_NUTS], obs=obs[:factor_NUTS], - n_samples=n_samples, warmup=n_warmup, m="all", g=g, algo=algo) - for m_val, color in zip(m,colors): - est_posterior_ECS = infer_hmcecs(rng_key, feats=feats[:factor_ECS], obs=obs[:factor_ECS], - n_samples=n_samples, - warmup=n_warmup, - m=m_val, g=g, - algo=ecs_algo, - subsample_method="perturb", - proxy=proxy, - estimator=estimator, - map_method=map_method, - num_epochs=epochs) - - p = dist.Normal(est_posterior_ECS["theta"]) - q = dist.Normal(est_posterior_NUTS["theta"]) - kl = kl_divergence(p, q) - - Plot(samples_ECS=est_posterior_ECS, - samples_NUTS=est_posterior_NUTS, - ecs_algo= ecs_algo, - algo=algo, - proxy= proxy, - estimator = estimator, - m = m_val, - kl=kl) - exit() - -def Tests(map_method,ecs_algo,algo,estimator,n_samples,n_warmup,epochs,proxy): - m = int(np_jax.sqrt(obs.shape[0])*2) - g= 5 - est_posterior_ECS = infer_hmcecs(rng_key, feats=feats, obs=obs, - n_samples=n_samples, - warmup=n_warmup, - m =m, - g=g, - algo=ecs_algo, - subsample_method="perturb", - proxy=proxy, - estimator = estimator, - map_method = map_method, - num_epochs=epochs) - est_posterior_NUTS = infer_hmcecs(rng_key, - feats=feats, - obs=obs, - n_samples=n_samples, - warmup=n_warmup, - m =m, - g=g, - algo=algo) - - Plot(est_posterior_ECS,est_posterior_NUTS,ecs_algo,algo,proxy,estimator,m) - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument('-num_samples', nargs='?', default=10,type=int) - parser.add_argument('-num_warmup', nargs='?', default=5, type=int) - parser.add_argument('-ecs_algo', nargs='?', default="NUTS", type=str) - parser.add_argument('-ecs_proxy', nargs='?', default="taylor", type=str) - parser.add_argument('-algo', nargs='?', default="HMC", type=str) - parser.add_argument('-estimator', nargs='?', default="poisson", type=str) - parser.add_argument('-map_init', nargs='?', default="NUTS", type=str) - parser.add_argument("-epochs",default=10,type=int) - args = parser.parse_args() - - - rng_key = jax.random.PRNGKey(37) - - rng_key, feat_key, obs_key = jax.random.split(rng_key, 3) - if args.ecs_proxy == "svi": - args.map_init = "SVI" - - now = datetime.datetime.now() - Folders("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms"))) - file_hyperparams = open("PLOTS_{}/Hyperparameters_{}.txt".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms"), - now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), "a") - file_hyperparams.write('ECS algo : {} \n'.format(args.ecs_algo)) - file_hyperparams.write('algo : {} \n'.format(args.algo)) - file_hyperparams.write('ECS proxy : {} \n'.format(args.ecs_proxy)) - file_hyperparams.write('MAP init : {} \n'.format(args.map_init)) - - higgs = False - if higgs: - feats,obs = higgs_data() - file_hyperparams.write('Dataset : HIGGS \n') - - else: - feats, obs = breast_cancer_data() - file_hyperparams.write('Dataset : BREAST CANCER DATA \n') - - file_hyperparams.close() - config.update('jax_disable_jit', True) - - #Determine_best_sample_size(rng_key,feats[:100],obs[:100]) - #Tests(args.map_init,args.ecs_algo,args.algo,args.num_samples,args.num_warmup,args.epochs,args.ecs_proxy) - Plot_KL(args.map_init,args.ecs_algo,args.algo,args.ecs_proxy,args.estimator,args.num_samples,args.num_warmup,args.epochs) - - - exit() - samples_ECS_3316 = load_obj("/home/lys/Dropbox/PhD/numpyro/examples/PLOTS_2020_10_09_11h41min24s333577ms_DONOTREMOVE/MCMC_Dict_Samples_perturb_m_3316.pkl") - samples_ECS_6632 = load_obj("/home/lys/Dropbox/PhD/numpyro/examples/PLOTS_2020_10_09_11h41min24s333577ms_DONOTREMOVE/MCMC_Dict_Samples_perturb_m_6632.pkl") - samples_ECS_132264 = load_obj("/home/lys/Dropbox/PhD/numpyro/examples/PLOTS_2020_10_09_11h41min24s333577ms_DONOTREMOVE/MCMC_Dict_Samples_perturb_m_13264.pkl") - - samples_HMC = load_obj("/home/lys/Dropbox/PhD/numpyro/examples/PLOTS_2020_10_09_11h41min24s333577ms_DONOTREMOVE/MCMC_Dict_Samples_None_m_all.pkl") - - p = dist.Normal(samples_ECS_3316["theta"]) - q = dist.Normal(samples_HMC["theta"]) - kl = kl_divergence(p, q) - Plot(samples_ECS=samples_ECS_3316, - samples_NUTS=samples_HMC, - ecs_algo= args.ecs_algo, - algo=args.algo, - proxy= args.proxy, - estimator = args.estimator, - m = 3316, - kl=kl) - - # samples = pd.DataFrame.from_records(samples,index="theta") - # sns.kdeplot(data=kl, color=color, label="m : ".format(m_val)) - # plt.figure(m_val) - # plt.xlabel(r"$\theta") - # plt.ylabel("Density") - # plt.legend() - # plt.title(r"$\theta$ KL-divergence") - # plt.savefig("{}/KL_divergence_m_{}.png".format("PLOTS_{}".format(now.strftime("%Y_%m_%d_%Hh%Mmin%Ss%fms")), - # str(m_val))) - - diff --git a/examples/logistic_hmcecs_svi.py b/examples/logistic_hmcecs_svi.py deleted file mode 100644 index 3e67ab226..000000000 --- a/examples/logistic_hmcecs_svi.py +++ /dev/null @@ -1,63 +0,0 @@ -import jax.numpy as np_jax -import numpy as np -from jax import lax -def load_dataset(observations,features, batch_size=None, shuffle=True): - - arrays = (observations,features) - num_records = observations.shape[0] - idxs = np_jax.arange(num_records) - if not batch_size: - batch_size = num_records - - def init(): - return num_records // batch_size, np.random.permutation(idxs) if shuffle else idxs - - def get_batch(i=0, idxs=idxs): - ret_idx = lax.dynamic_slice_in_dim(idxs, i * batch_size, batch_size) - batch_data = np_jax.take(arrays[0], ret_idx, axis=0) - batch_matrix =np_jax.take(np_jax.take(arrays[1], ret_idx, axis=0),ret_idx,axis=1) - return (batch_data,batch_matrix) - return init, get_batch -def svi_map(model, rng_key, feats,obs,num_epochs,batch_size): - """ - MLE in numpy: https://medium.com/@rrfd/what-is-maximum-likelihood-estimation-examples-in-python-791153818030i - Cost function: -log (likelihood(parameters|data) - Calculate pdf of the parameter|data under the distribution - """ - from jax import random, jit - from numpyro import optim - from numpyro.infer.elbo import RenyiELBO, ELBO - from numpyro.infer.svi import SVI - from numpyro.util import fori_loop - import time - import numpyro - numpyro.set_platform("gpu") - - from autoguide_hmcecs import AutoDelta, AutoDiagonalNormal - n, _ = feats.shape - #guide = AutoDelta(model) - guide = AutoDiagonalNormal(model) - #loss = RenyiELBO(alpha=2, num_particles=1) - loss = ELBO() - svi = SVI(model, guide, optim.Adam(0.0003), loss=loss) - svi_state = svi.init( rng_key,feats,obs) - train_init, train_fetch = load_dataset(obs,feats, batch_size=batch_size) - num_train, train_idx = train_init() - - @jit - def epoch_train(svi_state): - def body_fn(i, val): - batch_obs = train_fetch(i, train_idx)[0] - batch_feats = train_fetch(i, train_idx)[1] - loss_sum, svi_state = val - svi_state, loss = svi.update(svi_state, feats,obs) - loss_sum += loss - return loss_sum, svi_state - - return fori_loop(0, n, body_fn, (0., svi_state)) - - for i in range(num_epochs): - t_start = time.time() - train_loss, svi_state = epoch_train(svi_state) - print("Epoch {}: loss = {} ({:.2f} s.)".format(i, train_loss, time.time() - t_start)) - return svi.get_params(svi_state), svi, svi_state \ No newline at end of file From c0a1c4c9d17552c1d7f92eca3bff11897ec8c277 Mon Sep 17 00:00:00 2001 From: ola Date: Fri, 5 Feb 2021 13:46:23 +0100 Subject: [PATCH 73/93] removed old autoguide --- numpyro/contrib/autoguide_hmcecs.py | 737 ---------------------------- 1 file changed, 737 deletions(-) delete mode 100644 numpyro/contrib/autoguide_hmcecs.py diff --git a/numpyro/contrib/autoguide_hmcecs.py b/numpyro/contrib/autoguide_hmcecs.py deleted file mode 100644 index 9de321a6a..000000000 --- a/numpyro/contrib/autoguide_hmcecs.py +++ /dev/null @@ -1,737 +0,0 @@ -# Copyright Contributors to the Pyro project. -# SPDX-License-Identifier: Apache-2.0 - -# Adapted from pyro.infer.autoguide -from abc import ABC, abstractmethod -import warnings - -from jax import hessian, lax, random, tree_map -from jax.experimental import stax -from jax.flatten_util import ravel_pytree -import jax.numpy as jnp - -import numpyro -from numpyro import handlers -from numpyro.nn.auto_reg_nn import AutoregressiveNN -from numpyro.nn.block_neural_arn import BlockNeuralAutoregressiveNN -import numpyro.distributions as dist -from numpyro.distributions import constraints -from numpyro.distributions.flows import BlockNeuralAutoregressiveTransform, InverseAutoregressiveTransform -from numpyro.distributions.transforms import ( - AffineTransform, - ComposeTransform, - LowerCholeskyAffine, - PermuteTransform, - UnpackTransform, - biject_to -) -from numpyro.distributions.util import cholesky_of_inverse, sum_rightmost -from numpyro.infer.elbo import ELBO -from numpyro.infer.util import initialize_model, init_to_uniform, find_valid_initial_params -from numpyro.util import not_jax_tracer -from contextlib import ExitStack - -__all__ = [ - 'AutoContinuous', - 'AutoGuide', - 'AutoDiagonalNormal', - 'AutoLaplaceApproximation', - 'AutoLowRankMultivariateNormal', - 'AutoMultivariateNormal', - 'AutoBNAFNormal', - 'AutoIAFNormal', - 'AutoDelta' -] - - -class ReinitGuide(ABC): - @abstractmethod - def init_params(self): - raise NotImplementedError - - @abstractmethod - def find_params(self, rng_keys, *args, **kwargs): - raise NotImplementedError - - -class AutoGuide(ABC): - """ - Base class for automatic guides. - - Derived classes must implement the :meth:`__call__` method. - - :param callable model: a pyro model - :param str prefix: a prefix that will be prefixed to all param internal sites - """ - - def __init__(self, model, prefix='auto', create_plates=None): - assert isinstance(prefix, str) - self.model = model - self.prefix = prefix - self.prototype_trace = None - self._prototype_frames = {} - self.create_plates = create_plates - - @abstractmethod - def __call__(self, *args, **kwargs): - """ - A guide with the same ``*args, **kwargs`` as the base ``model``. - - :return: A dict mapping sample site name to sampled value. - :rtype: dict - """ - raise NotImplementedError - - @abstractmethod - def sample_posterior(self, rng_key, params, *args, **kwargs): - """ - Generate samples from the approximate posterior over the latent - sites in the model. - - :param jax.random.PRNGKey rng_key: PRNG seed. - :param params: Current parameters of model and autoguide. - :param sample_shape: (keyword argument) shape of samples to be drawn. - :return: batch of samples from the approximate posterior. - """ - raise NotImplementedError - - @abstractmethod - def _sample_latent(self, *args, **kwargs): - """ - Samples an encoded latent given the same ``*args, **kwargs`` as the - base ``model``. - """ - raise NotImplementedError - - def _setup_prototype(self, *args, **kwargs): - # run the model so we can inspect its structure - rng_key = random.PRNGKey(0) - # rng_key = numpyro.rng_key("_{}_rng_key_setup".format(self.prefix)) - model = handlers.seed(self.model, rng_key) - self.prototype_trace = handlers.block(handlers.trace(model).get_trace)(*args, **kwargs) - self._args = args - self._kwargs = kwargs - for _, site in self.prototype_trace.items(): - if site['type'] != 'sample' or site['is_observed']: - continue - for frame in site['cond_indep_stack']: - if frame.vectorized: - self._prototype_frames[frame.name] = frame - else: - raise NotImplementedError("AutoGuide does not support sequential numpyro.plate") - - def _create_plates(self, *args, **kwargs): - if self.create_plates is None: - self.plates = {} - else: - plates = self.create_plates(*args, **kwargs) - if isinstance(plates, numpyro.plate): - plates = [plates] - assert all(isinstance(p, numpyro.plate) for p in plates), \ - "create_plates() returned a non-plate" - self.plates = {p.name: p for p in plates} - for name, frame in sorted(self._prototype_frames.items()): - if name not in self.plates: - self.plates[name] = numpyro.plate(name, frame.size, dim=frame.dim) - return self.plates - - -class AutoContinuous(AutoGuide): - """ - Base class for implementations of continuous-valued Automatic - Differentiation Variational Inference [1]. - - Each derived class implements its own :meth:`_get_posterior` method. - - Assumes model structure and latent dimension are fixed, and all latent - variables are continuous. - - **Reference:** - - 1. *Automatic Differentiation Variational Inference*, - Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. - Blei - - :param callable model: A NumPyro model. - :param str prefix: a prefix that will be prefixed to all param internal sites. - :param callable init_strategy: A per-site initialization function. - See :ref:`init_strategy` section for available functions. - """ - - def __init__(self, model, prefix="auto", init_strategy=init_to_uniform): - self.init_strategy = init_strategy - super(AutoContinuous, self).__init__(model, prefix=prefix) - - def _setup_prototype(self, *args, **kwargs): - rng_key = random.PRNGKey(0) - # rng_key = numpyro.rng_key("_{}_rng_key_setup".format(self.prefix)) - with handlers.block(): - init_params, _, self._postprocess_fn, self.prototype_trace = initialize_model( - rng_key, self.model, - init_strategy=self.init_strategy, - dynamic_args=False, - model_args=args, - model_kwargs=kwargs) - - self._init_latent, unpack_latent = ravel_pytree(init_params[0]) - # this is to match the behavior of Pyro, where we can apply - # unpack_latent for a batch of samples - self._unpack_latent = UnpackTransform(unpack_latent) - self.latent_dim = jnp.size(self._init_latent) - if self.latent_dim == 0: - raise RuntimeError('{} found no latent variables; Use an empty guide instead' - .format(type(self).__name__)) - - @abstractmethod - def _get_posterior(self): - raise NotImplementedError - - def _sample_latent(self, *args, **kwargs): - sample_shape = kwargs.pop('sample_shape', ()) - posterior = self._get_posterior() - return numpyro.sample("_{}_latent".format(self.prefix), posterior, sample_shape=sample_shape) - - def expectation(self, latent): - """Computes the expectation/probabilities of the parameters of the guide. The expectation over the variance over the latent space is bounded - using the reparametrization trick""" - if self.prototype_trace is None: - raise ValueError() # TODO: fix value error - - result = {} - for name, unconstrained_value in latent.items(): - site = self.prototype_trace[name] - transform = biject_to(site['fn'].support) - - value = transform(unconstrained_value) - log_density = - transform.log_abs_det_jacobian(unconstrained_value, value) - event_ndim = len(site['fn'].event_shape) - log_density = sum_rightmost(log_density, - jnp.ndim(log_density) - jnp.ndim(value) + event_ndim) - prob = jnp.exp(log_density) - result[name] = prob * value - return result - - def __call__(self, *args, **kwargs): - """ - An automatic guide with the same ``*args, **kwargs`` as the base ``model``. - - :return: A dict mapping sample site name to sampled value. - :rtype: dict - """ - if self.prototype_trace is None: - # run model to inspect the model structure - self._setup_prototype(*args, **kwargs) - - latent = self._sample_latent(*args, **kwargs) - - # unpack continuous latent samples - result = {} - - for name, unconstrained_value in self._unpack_latent(latent).items(): - site = self.prototype_trace[name] - transform = biject_to(site['fn'].support) - value = transform(unconstrained_value) - log_density = - transform.log_abs_det_jacobian(unconstrained_value, value) - event_ndim = len(site['fn'].event_shape) - log_density = sum_rightmost(log_density, - jnp.ndim(log_density) - jnp.ndim(value) + event_ndim) - delta_dist = dist.Delta(value, log_density=log_density, event_dim=event_ndim) - result[name] = numpyro.sample(name, delta_dist) - - return result - - def _unpack_and_constrain(self, latent_sample, params): - def unpack_single_latent(latent): - unpacked_samples = self._unpack_latent(latent) - # add param sites in model - unpacked_samples.update({k: v for k, v in params.items() if k in self.prototype_trace - and v['type'] == 'param'}) - return self._postprocess_fn(unpacked_samples) - - sample_shape = jnp.shape(latent_sample)[:-1] - if sample_shape: - latent_sample = jnp.reshape(latent_sample, (-1, jnp.shape(latent_sample)[-1])) - unpacked_samples = lax.map(unpack_single_latent, latent_sample) - return tree_map(lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]), - unpacked_samples) - else: - return unpack_single_latent(latent_sample) - - def get_base_dist(self): - """ - Returns the base distribution of the posterior when reparameterized - as a :class:`~numpyro.distributions.distribution.TransformedDistribution`. This - should not depend on the model's `*args, **kwargs`. - """ - raise NotImplementedError - - def get_transform(self, params): - """ - Returns the transformation learned by the guide to generate samples from the unconstrained - (approximate) posterior. - - :param dict params: Current parameters of model and autoguide. - The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` - method from :class:`~numpyro.infer.svi.SVI`. - :return: the transform of posterior distribution - :rtype: :class:`~numpyro.distributions.transforms.Transform` - """ - posterior = handlers.substitute(self._get_posterior, params)() - assert isinstance(posterior, dist.TransformedDistribution), \ - "posterior is not a transformed distribution" - if len(posterior.transforms) > 0: - return ComposeTransform(posterior.transforms) - else: - return posterior.transforms[0] - - def get_posterior(self, params): - """ - Returns the posterior distribution. - - :param dict params: Current parameters of model and autoguide. - The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` - method from :class:`~numpyro.infer.svi.SVI`. - """ - base_dist = self.get_base_dist() - transform = self.get_transform(params) - return dist.TransformedDistribution(base_dist, transform) - - def sample_posterior(self, rng_key, params, sample_shape=()): - """ - Get samples from the learned posterior. - - :param jax.random.PRNGKey rng_key: random key to be used draw samples. - :param dict params: Current parameters of model and autoguide. - The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` - method from :class:`~numpyro.infer.svi.SVI`. - :param tuple sample_shape: batch shape of each latent sample, defaults to (). - :return: a dict containing samples drawn the this guide. - :rtype: dict - """ - latent_sample = handlers.substitute( - handlers.seed(self._sample_latent, rng_key), params)(sample_shape=sample_shape) - return self._unpack_and_constrain(latent_sample, params) - - def median(self, params): - """ - Returns the posterior median value of each latent variable. - - :param dict params: A dict containing parameter values. - The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` - method from :class:`~numpyro.infer.svi.SVI`. - :return: A dict mapping sample site name to median tensor. - :rtype: dict - """ - raise NotImplementedError - - def quantiles(self, params, quantiles): - """ - Returns posterior quantiles each latent variable. Example:: - - print(guide.quantiles(opt_state, [0.05, 0.5, 0.95])) - - :param dict params: A dict containing parameter values. - The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` - method from :class:`~numpyro.infer.svi.SVI`. - :param list quantiles: A list of requested quantiles between 0 and 1. - :return: A dict mapping sample site name to a list of quantile values. - :rtype: dict - """ - raise NotImplementedError - - -class AutoDiagonalNormal(AutoContinuous): - """ - This implementation of :class:`AutoContinuous` uses a Normal distribution - with a diagonal covariance matrix to construct a guide over the entire - latent space. The guide does not depend on the model's ``*args, **kwargs``. - - Usage:: - - guide = AutoDiagonalNormal(model, ...) - svi = SVI(model, guide, ...) - """ - - def __init__(self, model, prefix="auto", init_strategy=init_to_uniform, init_scale=0.1): - if init_scale <= 0: - raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) - self._init_scale = init_scale - super().__init__(model, prefix, init_strategy) - - def _get_posterior(self): - loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent) - scale = numpyro.param('{}_scale'.format(self.prefix), - jnp.full(self.latent_dim, self._init_scale), - constraint=constraints.positive) - return dist.Normal(loc, scale) - - def get_base_dist(self): - return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1) - - def get_transform(self, params): - loc = params['{}_loc'.format(self.prefix)] - scale = params['{}_scale'.format(self.prefix)] - return AffineTransform(loc, scale, domain=constraints.real_vector) - - def get_posterior(self, params): - """ - Returns a diagonal Normal posterior distribution. - """ - transform = self.get_transform(params) - return dist.Normal(transform.loc, transform.scale) - - def median(self, params): - loc = params['{}_loc'.format(self.prefix)] - return self._unpack_and_constrain(loc, params) - - def quantiles(self, params, quantiles): - quantiles = jnp.array(quantiles)[..., None] - latent = self.get_posterior(params).icdf(quantiles) - return self._unpack_and_constrain(latent, params) - - -class AutoMultivariateNormal(AutoContinuous): - """ - This implementation of :class:`AutoContinuous` uses a MultivariateNormal - distribution to construct a guide over the entire latent space. - The guide does not depend on the model's ``*args, **kwargs``. - - Usage:: - - guide = AutoMultivariateNormal(model, ...) - svi = SVI(model, guide, ...) - """ - - def __init__(self, model, prefix="auto", init_strategy=init_to_uniform, init_scale=0.1): - if init_scale <= 0: - raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) - self._init_scale = init_scale - super().__init__(model, prefix, init_strategy) - - def _get_posterior(self): - loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent) - scale_tril = numpyro.param('{}_scale_tril'.format(self.prefix), - jnp.identity(self.latent_dim) * self._init_scale, - constraint=constraints.lower_cholesky) - return dist.MultivariateNormal(loc, scale_tril=scale_tril) - - def get_base_dist(self): - return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1) - - def get_transform(self, params): - loc = params['{}_loc'.format(self.prefix)] - scale_tril = params['{}_scale_tril'.format(self.prefix)] - return LowerCholeskyAffine(loc, scale_tril) # TODO: Changed MultivariateAffineTransform to LowerCholeskyAffine - - def get_posterior(self, params): - """ - Returns a multivariate Normal posterior distribution. - """ - transform = self.get_transform(params) - return dist.MultivariateNormal(transform.loc, transform.scale_tril) - - def median(self, params): - loc = params['{}_loc'.format(self.prefix)] - return self._unpack_and_constrain(loc, params) - - def quantiles(self, params, quantiles): - transform = self.get_transform(params) - quantiles = jnp.array(quantiles)[..., None] - latent = dist.Normal(transform.loc, jnp.diagonal(transform.scale_tril)).icdf(quantiles) - return self._unpack_and_constrain(latent, params) - - -class AutoLowRankMultivariateNormal(AutoContinuous): - """ - This implementation of :class:`AutoContinuous` uses a LowRankMultivariateNormal - distribution to construct a guide over the entire latent space. - The guide does not depend on the model's ``*args, **kwargs``. - - Usage:: - - guide = AutoLowRankMultivariateNormal(model, rank=2, ...) - svi = SVI(model, guide, ...) - """ - - def __init__(self, model, prefix="auto", init_strategy=init_to_uniform, init_scale=0.1, rank=None): - if init_scale <= 0: - raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) - self._init_scale = init_scale - self.rank = rank - super(AutoLowRankMultivariateNormal, self).__init__( - model, prefix=prefix, init_strategy=init_strategy) - - def _get_posterior(self, *args, **kwargs): - rank = int(round(self.latent_dim ** 0.5)) if self.rank is None else self.rank - loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent) - cov_factor = numpyro.param('{}_cov_factor'.format(self.prefix), jnp.zeros((self.latent_dim, rank))) - scale = numpyro.param('{}_scale'.format(self.prefix), - jnp.full(self.latent_dim, self._init_scale), - constraint=constraints.positive) - cov_diag = scale * scale - cov_factor = cov_factor * scale[..., None] - return dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag) - - def get_base_dist(self): - return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1) - - def get_transform(self, params): - posterior = self.get_posterior(params) - return LowerCholeskyAffine(posterior.loc, posterior.scale_tril) - - def get_posterior(self, params): - """ - Returns a lowrank multivariate Normal posterior distribution. - """ - loc = params['{}_loc'.format(self.prefix)] - cov_factor = params['{}_cov_factor'.format(self.prefix)] - scale = params['{}_scale'.format(self.prefix)] - cov_diag = scale * scale - cov_factor = cov_factor * scale[..., None] - return dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag) - - def median(self, params): - loc = params['{}_loc'.format(self.prefix)] - return self._unpack_and_constrain(loc, params) - - def quantiles(self, params, quantiles): - transform = self.get_transform(params) - quantiles = jnp.array(quantiles)[..., None] - latent = dist.Normal(transform.loc, jnp.diagonal(transform.scale_tril)).icdf(quantiles) - return self._unpack_and_constrain(latent, params) - - -class AutoLaplaceApproximation(AutoContinuous): - r""" - Laplace approximation (quadratic approximation) approximates the posterior - :math:`\log p(z | x)` by a multivariate normal distribution in the - unconstrained space. Under the hood, it uses Delta distributions to - construct a MAP guide over the entire (unconstrained) latent space. Its - covariance is given by the inverse of the hessian of :math:`-\log p(x, z)` - at the MAP point of `z`. - - Usage:: - - guide = AutoLaplaceApproximation(model, ...) - svi = SVI(model, guide, ...) - """ - - def _setup_prototype(self, *args, **kwargs): - super(AutoLaplaceApproximation, self)._setup_prototype(*args, **kwargs) - - def loss_fn(params): - # we are doing maximum likelihood, so only require `num_particles=1` and an arbitrary rng_key. - return ELBO().loss(random.PRNGKey(0), params, self.model, self, *args, **kwargs) - - self._loss_fn = loss_fn - - def _get_posterior(self, *args, **kwargs): - # sample from Delta guide - loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent) - return dist.Delta(loc, event_dim=1) - - def get_base_dist(self): - return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1) - - def get_transform(self, params): - def loss_fn(z): - params1 = params.copy() - params1['{}_loc'.format(self.prefix)] = z - return self._loss_fn(params1) - - loc = params['{}_loc'.format(self.prefix)] - precision = hessian(loss_fn)(loc) - scale_tril = cholesky_of_inverse(precision) - if not_jax_tracer(scale_tril): - if jnp.any(jnp.isnan(scale_tril)): - warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior" - " samples from AutoLaplaceApproxmiation will be constant (equal to" - " the MAP point).") - scale_tril = jnp.where(jnp.isnan(scale_tril), 0., scale_tril) - return LowerCholeskyAffine(loc, scale_tril) - - def get_posterior(self, params): - """ - Returns a multivariate Normal posterior distribution. - """ - transform = self.get_transform(params) - return dist.MultivariateNormal(transform.loc, scale_tril=transform.scale_tril) - - def sample_posterior(self, rng_key, params, sample_shape=()): - latent_sample = self.get_posterior(params).sample(rng_key, sample_shape) - return self._unpack_and_constrain(latent_sample, params) - - def median(self, params): - loc = params['{}_loc'.format(self.prefix)] - return self._unpack_and_constrain(loc, params) - - def quantiles(self, params, quantiles): - transform = self.get_transform(params) - quantiles = jnp.array(quantiles)[..., None] - latent = dist.Normal(transform.loc, jnp.diagonal(transform.scale_tril)).icdf(quantiles) - return self._unpack_and_constrain(latent, params) - - -class AutoIAFNormal(AutoContinuous): - """ - This implementation of :class:`AutoContinuous` uses a Diagonal Normal - distribution transformed via a - :class:`~numpyro.distributions.flows.InverseAutoregressiveTransform` - to construct a guide over the entire latent space. The guide does not - depend on the model's ``*args, **kwargs``. - - Usage:: - - guide = AutoIAFNormal(model, hidden_dims=[20], skip_connections=True, ...) - svi = SVI(model, guide, ...) - - :param callable model: a generative model. - :param str prefix: a prefix that will be prefixed to all param internal sites. - :param callable init_strategy: A per-site initialization function. - :param int num_flows: the number of flows to be used, defaults to 3. - :param list hidden_dims: the dimensionality of the hidden units per layer. - Defaults to ``[latent_dim, latent_dim]``. - :param bool skip_connections: whether to add skip connections from the input to the - output of each flow. Defaults to False. - :param callable nonlinearity: the nonlinearity to use in the feedforward network. - Defaults to :func:`jax.experimental.stax.Elu`. - """ - - def __init__(self, model, prefix="auto", init_strategy=init_to_uniform, - num_flows=3, hidden_dims=None, skip_connections=False, nonlinearity=stax.Elu): - self.num_flows = num_flows - # 2-layer, stax.Elu, skip_connections=False by default following the experiments in - # IAF paper (https://arxiv.org/abs/1606.04934) - # and Neutra paper (https://arxiv.org/abs/1903.03704) - self._hidden_dims = hidden_dims - self._skip_connections = skip_connections - self._nonlinearity = nonlinearity - super(AutoIAFNormal, self).__init__(model, prefix=prefix, init_strategy=init_strategy) - - def _get_posterior(self): - if self.latent_dim == 1: - raise ValueError('latent dim = 1. Consider using AutoDiagonalNormal instead') - hidden_dims = [self.latent_dim, self.latent_dim] if self._hidden_dims is None else self._hidden_dims - flows = [] - for i in range(self.num_flows): - if i > 0: - flows.append(PermuteTransform(jnp.arange(self.latent_dim)[::-1])) - arn = AutoregressiveNN(self.latent_dim, hidden_dims, - permutation=jnp.arange(self.latent_dim), - skip_connections=self._skip_connections, - nonlinearity=self._nonlinearity) - arnn = numpyro.module('{}_arn__{}'.format(self.prefix, i), arn, (self.latent_dim,)) - flows.append(InverseAutoregressiveTransform(arnn)) - return dist.TransformedDistribution(self.get_base_dist(), flows) - - def get_base_dist(self): - return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1) - - -class AutoBNAFNormal(AutoContinuous): - """ - This implementation of :class:`AutoContinuous` uses a Diagonal Normal - distribution transformed via a - :class:`~numpyro.distributions.flows.BlockNeuralAutoregressiveTransform` - to construct a guide over the entire latent space. The guide does not - depend on the model's ``*args, **kwargs``. - - Usage:: - - guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[50, 50], ...) - svi = SVI(model, guide, ...) - - **References** - - 1. *Block Neural Autoregressive Flow*, - Nicola De Cao, Ivan Titov, Wilker Aziz - - :param callable model: a generative model. - :param str prefix: a prefix that will be prefixed to all param internal sites. - :param callable init_strategy: A per-site initialization function. - :param int num_flows: the number of flows to be used, defaults to 3. - :param list hidden_factors: Hidden layer i has ``hidden_factors[i]`` hidden units per - input dimension. This corresponds to both :math:`a` and :math:`b` in reference [1]. - The elements of hidden_factors must be integers. - """ - - def __init__(self, model, prefix="auto", init_strategy=init_to_uniform, num_flows=1, - hidden_factors=[8, 8]): - self.num_flows = num_flows - self._hidden_factors = hidden_factors - super(AutoBNAFNormal, self).__init__(model, prefix=prefix, init_strategy=init_strategy) - - def _get_posterior(self): - if self.latent_dim == 1: - raise ValueError('latent dim = 1. Consider using AutoDiagonalNormal instead') - flows = [] - for i in range(self.num_flows): - if i > 0: - flows.append(PermuteTransform(jnp.arange(self.latent_dim)[::-1])) - residual = "gated" if i < (self.num_flows - 1) else None - arn = BlockNeuralAutoregressiveNN(self.latent_dim, self._hidden_factors, residual) - arnn = numpyro.module('{}_arn__{}'.format(self.prefix, i), arn, (self.latent_dim,)) - flows.append(BlockNeuralAutoregressiveTransform(arnn)) - return dist.TransformedDistribution(self.get_base_dist(), flows) - - def get_base_dist(self): - return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1) - - -class AutoDelta(AutoGuide, ReinitGuide): - def __init__(self, model, *, prefix='auto', init_strategy=init_to_uniform(), create_plates=None): - self.init_strategy = init_strategy - self._param_map = None - self._init_params = None - super(AutoDelta, self).__init__(model, prefix=prefix, create_plates=create_plates) - - def init_params(self): - return self._init_params - - def __call__(self, *args, **kwargs): - if self.prototype_trace is None: - self._setup_prototype(*args, **kwargs) - plates = self._create_plates(*args, **kwargs) - result = {} - for name, site in self.prototype_trace.items(): - if site['type'] != 'sample' or site['is_observed']: - continue - with ExitStack() as stack: - for frame in site['cond_indep_stack']: - stack.enter_context(plates[frame.name]) - if site['intermediates']: - event_dim = len(site['fn'].base_dist.event_shape) - else: - event_dim = len(site['fn'].event_shape) - param_name, param_val, constraint = self._param_map[name] - val_param = numpyro.param(param_name, param_val, constraint=constraint) - result[name] = numpyro.sample(name, dist.Delta(val_param, event_dim=event_dim)) - return result - - def _sample_latent(self, *args, **kwargs): - raise NotImplementedError - - def sample_posterior(self, rng_key, *args, **kwargs): - raise NotImplementedError - - def find_params(self, rng_keys, *args, **kwargs): - params = {site['name']: site['value'] for site in self.prototype_trace.values() - if site['type'] == 'sample' and not site['is_observed']} - (init_params, _, _), _ = handlers.block(find_valid_initial_params)(rng_keys, self.model, - init_strategy=self.init_strategy, - model_args=args, - model_kwargs=kwargs, - prototype_params=params) - for name, site in self.prototype_trace.items(): - if site['type'] == 'sample' and not site['is_observed']: - param_name = "{}_{}".format(self.prefix, name) - param_val = biject_to(site['fn'].support)(init_params[name]) - params[name] = (param_name, param_val, site['fn'].support) - self._param_map = params - self._init_params = {param: (val, constr) for param, val, constr in self._param_map.values()} - - def _setup_prototype(self, *args, **kwargs): - super(AutoDelta, self)._setup_prototype(*args, **kwargs) - # rng_key = numpyro.rng_key("_{}_rng_key_init".format(self.prefix)) - rng_key = random.PRNGKey(1) - self.find_params(rng_key, *args, **kwargs) From 5583162c63dfd3e4301307bed7cf701ffa23d269 Mon Sep 17 00:00:00 2001 From: ola Date: Fri, 5 Feb 2021 15:46:53 +0100 Subject: [PATCH 74/93] Fixed linting. --- numpyro/distributions/continuous.py | 2 +- numpyro/handlers.py | 3 +-- numpyro/infer/hmc_gibbs.py | 5 ++--- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 0ba90f45f..c2a6ecdc9 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -962,7 +962,7 @@ def sample(self, key, sample_shape=()): @validate_sample def log_prob(self, value): - normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale) #TODO:Added jnp.abs + normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale) # TODO:Added jnp.abs value_scaled = (value - self.loc) / self.scale return -0.5 * value_scaled ** 2 - normalize_term diff --git a/numpyro/handlers.py b/numpyro/handlers.py index d69cf4f48..75f29f22b 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -77,7 +77,6 @@ import warnings from collections import OrderedDict -from functools import partial import jax.numpy as jnp import numpy as np @@ -658,8 +657,8 @@ def __init__(self, fn=None, rng_seed=None): def process_message(self, msg): if (msg['type'] == 'sample' and not msg['is_observed'] and msg['kwargs']['rng_key'] is None) or msg['type'] in ['prng_key', 'plate', 'control_flow']: - # no need to create a new key when value is available if msg['value'] is not None: + # no need to create a new key when value is available return self.rng_key, rng_key_sample = random.split(self.rng_key) msg['kwargs']['rng_key'] = rng_key_sample diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index e56def376..595f9c12c 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -5,17 +5,16 @@ import warnings from collections import defaultdict, namedtuple from functools import partial -import jax import jax.numpy as jnp -from jax import device_put, jacfwd, jacobian, grad, hessian, lax, ops, random, value_and_grad, vmap +from jax import device_put, jacfwd, jacobian, grad, hessian, lax, ops, random, value_and_grad from jax.scipy.special import expit import numpyro from numpyro.handlers import block, condition, seed, substitute, trace from numpyro.infer.hmc import HMC from numpyro.infer.mcmc import MCMCKernel -from numpyro.infer.util import _unconstrain_reparam, _predictive, log_density +from numpyro.infer.util import _unconstrain_reparam from numpyro.util import cond, fori_loop, identity, ravel_pytree HMCGibbsState = namedtuple("HMCGibbsState", "z, hmc_state, rng_key") From 48cd7efdfaeed31270764cc7ab41c05eb6ade86c Mon Sep 17 00:00:00 2001 From: Ola Date: Fri, 5 Feb 2021 18:21:27 +0100 Subject: [PATCH 75/93] fixed lint. --- numpyro/handlers.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 75f29f22b..75ebe14f8 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -655,8 +655,8 @@ def __init__(self, fn=None, rng_seed=None): super(seed, self).__init__(fn) def process_message(self, msg): - if (msg['type'] == 'sample' and not msg['is_observed'] and - msg['kwargs']['rng_key'] is None) or msg['type'] in ['prng_key', 'plate', 'control_flow']: + if (msg['type'] == 'sample' and not msg['is_observed'] and msg['kwargs']['rng_key'] is None) \ + or msg['type'] in ['prng_key', 'plate', 'control_flow']: if msg['value'] is not None: # no need to create a new key when value is available return @@ -796,5 +796,3 @@ def process_message(self, msg): msg['value'] = intervention msg['is_observed'] = True msg['stop'] = True - - From ff28eb0d9062af4e28c3435d0862698372333051 Mon Sep 17 00:00:00 2001 From: Ola Date: Sun, 7 Feb 2021 22:08:12 +0100 Subject: [PATCH 76/93] Remove Poisson, factored out pandas for loading HIGGs dataset, added SA to covtype. --- examples/covtype.py | 11 ++++++++--- numpyro/examples/datasets.py | 20 ++++++++++++-------- numpyro/infer/hmc.py | 2 +- numpyro/infer/hmc_gibbs.py | 6 ++---- numpyro/primitives.py | 2 +- 5 files changed, 24 insertions(+), 17 deletions(-) diff --git a/examples/covtype.py b/examples/covtype.py index 1e7ccf80d..d0ac8e96d 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -11,7 +11,7 @@ import numpyro import numpyro.distributions as dist from numpyro.examples.datasets import COVTYPE, load_dataset -from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SVI, Trace_ELBO, init_to_value +from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SVI, Trace_ELBO, init_to_value, SA from numpyro.infer.autoguide import AutoBNAFNormal from numpyro.infer.hmc_gibbs import taylor_proxy from numpyro.infer.reparam import NeuTraReparam @@ -34,7 +34,7 @@ def _load_dataset(): print("Data shape:", features.shape) print("Label distribution: {} has label 1, {} has label 0" .format(labels.sum(), N - labels.sum())) - return features[::5], labels[::5] + return features, labels def model(data, labels, subsample_size=None): @@ -81,6 +81,11 @@ def benchmark_hmc(args, features, labels): # note: if num_blocks=100, we'll update 10 index at each MCMC step # so it took 50000 MCMC steps to iterative the whole dataset kernel = HMCECS(inner_kernel, num_blocks=100, proxy=taylor_proxy(ref_params)) + elif args.algo == "SA": + # NB: this kernel requires large num_warmup and num_samples + # and running on GPU is much faster than on CPU + kernel = SA(model, adapt_state_size=1000, init_strategy=init_to_value(values=ref_params)) + subsample_size = None elif args.algo == "FlowHMCECS": subsample_size = 1000 guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8]) @@ -118,7 +123,7 @@ def main(args): parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")') parser.add_argument('--num-chains', nargs='?', default=1, type=int) parser.add_argument('--algo', default='HMCECS', type=str, - help='whether to run "HMCECS", "NUTS", "HMCECS", or "FlowHMCECS"') + help='whether to run "HMCECS", "NUTS", "HMCECS", "SA" or "FlowHMCECS"') parser.add_argument('--dense-mass', action="store_true") parser.add_argument('--x64', action="store_true") parser.add_argument('--device', default='gpu', type=str, help='use "cpu" or "gpu".') diff --git a/numpyro/examples/datasets.py b/numpyro/examples/datasets.py index f51a40748..4d008ada8 100644 --- a/numpyro/examples/datasets.py +++ b/numpyro/examples/datasets.py @@ -11,9 +11,9 @@ from urllib.parse import urlparse from urllib.request import urlretrieve import zipfile +import io import numpy as np -import pandas as pd from jax import device_put, lax from jax.interpreters.xla import DeviceArray @@ -34,12 +34,10 @@ 'https://d2hg8soec8ck9v.cloudfront.net/datasets/covtype.zip', ]) - DIPPER_VOLE = dset('dipper_vole', [ 'https://github.com/pyro-ppl/datasets/blob/master/dipper_vole.zip?raw=true', ]) - MNIST = dset('mnist', [ 'https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/train-images-idx3-ubyte.gz', 'https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/train-labels-idx1-ubyte.gz', @@ -84,7 +82,7 @@ def _load_baseball(): def train_test_split(file): train, test, player_names = [], [], [] with open(file, 'r') as f: - csv_reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) + csv_reader = csv.reader(f, delimiter='\t', quoting=csv.QUOTE_NONE) for row in csv_reader: player_names.append(row['FirstName'] + ' ' + row['LastName']) at_bats, hits = row['At-Bats'], row['Hits'] @@ -239,12 +237,18 @@ def _load_jsb_chorales(): def _load_higgs(): - warnings.warn("Downloading 2.6 GB dataset") + warnings.warn("Higgs is a 2.6 GB dataset") _download(HIGGS) + file_path = os.path.join(DATA_DIR, 'HIGGS.csv.gz') - df = pd.read_csv(file_path, header=None) - obs, feats = df.iloc[:, 0], df.iloc[:, 1:] - return obs.to_numpy().astype(int), feats.to_numpy() + with io.TextIOWrapper(gzip.open(file_path, 'rb')) as f: + csv_reader = csv.reader(f, delimiter=',', quoting=csv.QUOTE_NONE) + obs = [] + data = [] + for row in csv_reader: + obs.append(row[0]) + data.append(row[1:]) + return np.stack(obs), np.stack(data) def _load(dset): diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py index 3f0541532..8e9d9302b 100644 --- a/numpyro/infer/hmc.py +++ b/numpyro/infer/hmc.py @@ -120,7 +120,7 @@ def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'): >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist - >>> from numpyro.infer.benchmark_hmc import hmc + >>> from numpyro.infer.hmc import hmc >>> from numpyro.infer.util import initialize_model >>> from numpyro.util import fori_collect diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 595f9c12c..7494806aa 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -447,7 +447,6 @@ def _block_update_proxy(num_blocks, rng_key, gibbs_sites, subsample_plate_sizes) HMCECSState = namedtuple("HMCECSState", "z, hmc_state, rng_key, gibbs_state, accept_prob") TaylorProxyState = namedtuple("TaylorProxyState", "ref_subsample_log_liks, " "ref_subsample_log_lik_grads, ref_subsample_log_lik_hessians") -BlockPoissonEstState = namedtuple("BlockPoissonEstState", "block_rng_keys, sign") def _wrap_gibbs_state(model): @@ -513,14 +512,13 @@ class HMCECS(HMCGibbs): """ - def __init__(self, inner_kernel, *, num_blocks=1, proxy=None, method='perturbed'): + def __init__(self, inner_kernel, *, num_blocks=1, proxy=None): super().__init__(inner_kernel, lambda *args: None, None) - assert method in {'perturbed'} self.inner_kernel._model = _wrap_gibbs_state(self.inner_kernel._model) self._num_blocks = num_blocks self._proxy = proxy - self._method = method + self._method = 'perturbed' def postprocess_fn(self, args, kwargs): def fn(z): diff --git a/numpyro/primitives.py b/numpyro/primitives.py index a577e7446..01c174888 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -374,7 +374,7 @@ def postprocess_message(self, msg): raise ValueError( "Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}" .format(self.name, self.size, self.dim, statement, shape)) - elif self.subsample_size < self.size: + if self.subsample_size < self.size: value = msg["value"] new_value = jnp.take(value, self._indices, dim) msg["value"] = new_value From 5febf7e0ff4f30144052fb161316da82297a5c6c Mon Sep 17 00:00:00 2001 From: Ola Date: Sun, 7 Feb 2021 22:23:33 +0100 Subject: [PATCH 77/93] Fixed _block_update refactor. Missing new test cases, 2 more TODOs. --- numpyro/infer/hmc_gibbs.py | 52 ++++++++++++++------------------------ 1 file changed, 19 insertions(+), 33 deletions(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 7494806aa..f6f90810f 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -399,48 +399,36 @@ def potential_fn(z_gibbs, z_hmc): return HMCGibbsState(z, hmc_state, rng_key) +def _update_block(rng_key, num_blocks, subsample_idx, plate_size): + size, subsample_size = plate_size + rng_key, subkey, block_key = random.split(rng_key, 3) + block_size = (subsample_size - 1) // num_blocks + 1 + pad = block_size - (subsample_size - 1) % block_size - 1 + + chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks) + new_idx = random.randint(subkey, minval=0, maxval=size, shape=(block_size,)) + subsample_idx_padded = jnp.pad(subsample_idx, (0, pad)) + start = chosen_block * block_size + subsample_idx_padded = lax.dynamic_update_slice_in_dim( + subsample_idx_padded, new_idx, start, 0) + return rng_key, subsample_idx_padded[:subsample_size], pad, new_idx, start + + def _block_update(plate_sizes, num_blocks, rng_key, gibbs_sites, gibbs_state): u_new = {} for name, subsample_idx in gibbs_sites.items(): - size, subsample_size = plate_sizes[name] - rng_key, subkey, block_key = random.split(rng_key, 3) - block_size = (subsample_size - 1) // num_blocks + 1 - pad = block_size - (subsample_size - 1) % block_size - 1 - - chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks) - new_idx = random.randint(subkey, minval=0, maxval=size, shape=(block_size,)) - subsample_idx_padded = jnp.pad(subsample_idx, (0, pad)) - start = chosen_block * block_size - subsample_idx_padded = lax.dynamic_update_slice_in_dim( - subsample_idx_padded, new_idx, start, 0) - - u_new[name] = subsample_idx_padded[:subsample_size] + rng_key, u_new[name], *_ = _update_block(rng_key, num_blocks, subsample_idx, plate_sizes[name]) return u_new, gibbs_state -def _block_update_proxy(num_blocks, rng_key, gibbs_sites, subsample_plate_sizes): +def _block_update_proxy(num_blocks, rng_key, gibbs_sites, plate_sizes): u_new = {} pads = {} new_idxs = {} starts = {} for name, subsample_idx in gibbs_sites.items(): - # TODO: merge with _block_update - size, subsample_size = subsample_plate_sizes[name] - rng_key, subkey, block_key = random.split(rng_key, 3) - block_size = (subsample_size - 1) // num_blocks + 1 - pad = block_size - (subsample_size - 1) % block_size - 1 - - chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks) - new_idx = random.randint(subkey, minval=0, maxval=size, shape=(block_size,)) - subsample_idx_padded = jnp.pad(subsample_idx, (0, pad)) - start = chosen_block * block_size - subsample_idx_padded = lax.dynamic_update_slice_in_dim( - subsample_idx_padded, new_idx, start, 0) - - u_new[name] = subsample_idx_padded[:subsample_size] - pads[name] = pad - new_idxs[name] = new_idx - starts[name] = start + rng_key, u_new[name], pads[name], new_idxs[name], starts[name] = _update_block(rng_key, num_blocks, + subsample_idx, plate_sizes[name]) return u_new, pads, new_idxs, starts @@ -593,8 +581,6 @@ def potential_fn(z_gibbs, gibbs_state, z_hmc): hmc_state = self.inner_kernel.sample(hmc_state, model_args, model_kwargs) z = {**z_gibbs, **hmc_state.z} - # TODO: post update gibbs_state to update sign in Block Poisson estimator - # extra_fields=('gibbs_state.sign',) return HMCECSState(z, hmc_state, rng_key, gibbs_state, accept_prob) From 2c261739513eae2205e7ae17ccfa8ffeca991371 Mon Sep 17 00:00:00 2001 From: Ola Date: Sun, 7 Feb 2021 22:27:18 +0100 Subject: [PATCH 78/93] fixed isort --- docs/source/conf.py | 2 +- examples/covtype.py | 5 +++-- numpyro/contrib/funsor/enum_messenger.py | 1 + numpyro/examples/datasets.py | 7 ++++--- numpyro/handlers.py | 5 +++-- numpyro/infer/hmc_gibbs.py | 6 +++--- numpyro/infer/util.py | 4 ++-- test/test_hmc_gibbs.py | 5 +++-- 8 files changed, 20 insertions(+), 15 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 4f0aa83a2..adf47cdd4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,7 +9,6 @@ import nbsphinx import sphinx_rtd_theme - # import pkg_resources # -*- coding: utf-8 -*- @@ -33,6 +32,7 @@ # HACK: This is to ensure that local functions are documented by sphinx. from numpyro.infer.hmc import hmc # noqa: E402 + hmc(None, None) # -- Project information ----------------------------------------------------- diff --git a/examples/covtype.py b/examples/covtype.py index d0ac8e96d..f96498161 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -4,14 +4,15 @@ import argparse import time -import jax.numpy as jnp import matplotlib.pyplot as plt + from jax import random +import jax.numpy as jnp import numpyro import numpyro.distributions as dist from numpyro.examples.datasets import COVTYPE, load_dataset -from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SVI, Trace_ELBO, init_to_value, SA +from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SA, SVI, Trace_ELBO, init_to_value from numpyro.infer.autoguide import AutoBNAFNormal from numpyro.infer.hmc_gibbs import taylor_proxy from numpyro.infer.reparam import NeuTraReparam diff --git a/numpyro/contrib/funsor/enum_messenger.py b/numpyro/contrib/funsor/enum_messenger.py index 461809cf1..aa9eb900c 100644 --- a/numpyro/contrib/funsor/enum_messenger.py +++ b/numpyro/contrib/funsor/enum_messenger.py @@ -7,6 +7,7 @@ from jax import lax import jax.numpy as jnp + import funsor from numpyro.handlers import infer_config from numpyro.handlers import trace as OrigTraceMessenger diff --git a/numpyro/examples/datasets.py b/numpyro/examples/datasets.py index 4d008ada8..0e7b8d68e 100644 --- a/numpyro/examples/datasets.py +++ b/numpyro/examples/datasets.py @@ -1,19 +1,20 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +from collections import namedtuple import csv import gzip +import io import os import pickle import struct -import warnings -from collections import namedtuple from urllib.parse import urlparse from urllib.request import urlretrieve +import warnings import zipfile -import io import numpy as np + from jax import device_put, lax from jax.interpreters.xla import DeviceArray diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 75ebe14f8..a40572504 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -75,12 +75,13 @@ -874.89813 """ -import warnings from collections import OrderedDict +import warnings -import jax.numpy as jnp import numpy as np + from jax import lax, random +import jax.numpy as jnp import numpyro from numpyro.distributions.distribution import COERCIONS diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index f6f90810f..8a9d36476 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -1,13 +1,13 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -import copy -import warnings from collections import defaultdict, namedtuple +import copy from functools import partial +import warnings +from jax import device_put, grad, hessian, jacfwd, jacobian, lax, ops, random, value_and_grad import jax.numpy as jnp -from jax import device_put, jacfwd, jacobian, grad, hessian, lax, ops, random, value_and_grad from jax.scipy.special import expit import numpyro diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 5da8883a7..8b4227819 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -1,15 +1,15 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -import warnings from collections import namedtuple from functools import partial +import warnings -import jax.numpy as jnp import numpy as np from jax import device_get, jacfwd, lax, random, value_and_grad from jax.flatten_util import ravel_pytree +import jax.numpy as jnp import numpyro from numpyro.distributions import constraints diff --git a/test/test_hmc_gibbs.py b/test/test_hmc_gibbs.py index 22dd82ef4..9b8bbe3bf 100644 --- a/test/test_hmc_gibbs.py +++ b/test/test_hmc_gibbs.py @@ -3,12 +3,13 @@ from functools import partial -import jax.numpy as jnp import numpy as np +from numpy.testing import assert_allclose import pytest + from jax import random +import jax.numpy as jnp from jax.scipy.linalg import cho_factor, cho_solve, inv, solve_triangular -from numpy.testing import assert_allclose import numpyro import numpyro.distributions as dist From 63592929d6e9e8f25620b3662df5115c4b332f44 Mon Sep 17 00:00:00 2001 From: Ola Date: Mon, 8 Feb 2021 14:30:53 +0100 Subject: [PATCH 79/93] Fixed comments, some 3 TODOs left. --- numpyro/infer/hmc_gibbs.py | 30 ++++++++---------------------- numpyro/primitives.py | 2 +- 2 files changed, 9 insertions(+), 23 deletions(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 8a9d36476..f51a85168 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -474,6 +474,7 @@ class HMCECS(HMCGibbs): :param inner_kernel: One of :class:`~numpyro.infer.hmc.HMC` or :class:`~numpyro.infer.hmc.NUTS`. :param int num_blocks: Number of blocks to partition subsample into. + :param callable proxy: TODO: add description. **Example** @@ -531,12 +532,12 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): self._gibbs_sites = list(self._subsample_plate_sizes.keys()) if self._proxy is not None: rng_key, proxy_key, method_key = random.split(rng_key, 3) - proxy_fn, gibbs_init, self._gibbs_update = self._proxy(rng_key, + proxy_fn, gibbs_init, self._gibbs_update = self._proxy(self._subsample_plate_sizes, self.model, model_args, - model_kwargs, + model_kwargs.copy(), num_blocks=self._num_blocks) - method = perturbed_method(method_key, self.model, model_args, model_kwargs, proxy_fn) + method = perturbed_method(self._subsample_plate_sizes, proxy_fn) self.inner_kernel._model = estimate_likelihood(self.inner_kernel._model, method) z_gibbs = {name: site["value"] for name, site in self._prototype_trace.items() if name in self._gibbs_sites} @@ -584,15 +585,7 @@ def potential_fn(z_gibbs, gibbs_state, z_hmc): return HMCECSState(z, hmc_state, rng_key, gibbs_state, accept_prob) -def perturbed_method(rng_key, model, model_args, model_kwargs, proxy_fn): - # subsample_plate_sizes: name -> (size, subsample_size) - prototype_trace = trace(seed(model, rng_key)).get_trace(*model_args, **model_kwargs) - subsample_plate_sizes = { - name: site["args"] - for name, site in prototype_trace.items() - if site["type"] == "plate" and site["args"][0] > site["args"][1] - } - +def perturbed_method(subsample_plate_sizes, proxy_fn): def estimator(likelihoods, params, gibbs_state): subsample_log_liks = defaultdict(float) for (fn, value, name, subsample_dim) in likelihoods.values(): @@ -616,17 +609,9 @@ def estimator(likelihoods, params, gibbs_state): def taylor_proxy(reference_params): - def construct_proxy_fn(rng_key, model, model_args, model_kwargs, num_blocks=1): - prototype_trace = trace(seed(model, rng_key)).get_trace(*model_args, **model_kwargs) - subsample_plate_sizes = { - name: site["args"] - for name, site in prototype_trace.items() - if site["type"] == "plate" and site["args"][0] > site["args"][1] # i.e. size > subsample_size - } - + def construct_proxy_fn(subsample_plate_sizes, model, model_args, model_kwargs, num_blocks=1): # TODO: map reference params to unconstraint_params - # subsample_plate_sizes: name -> (size, subsample_size) ref_params_flat, unravel_fn = ravel_pytree(reference_params) def log_likelihood(params_flat, subsample_indices=None): @@ -736,8 +721,9 @@ def __init__(self, fn=None, method=None): self.gibbs_state = None def __enter__(self): - # trace(substitute(substitute(control_variate(model), unconstrained_reparam))) for handler in numpyro.primitives._PYRO_STACK[::-1]: + # the potential_fn in HMC makes the PYRO_STACK nested like trace(...); so we can extract the + # unconstrained_params from the _unconstrain_reparam substitute_fn if isinstance(handler, substitute) and isinstance(handler.substitute_fn, partial) \ and handler.substitute_fn.func is _unconstrain_reparam: self.params = handler.substitute_fn.args[0] diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 01c174888..3353f4966 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -304,7 +304,7 @@ def _subsample(name, size, subsample_size, dim): } apply_stack(msg) subsample = msg['value'] - subsample_size = msg['args'][1] # TODO: rewrite plate + subsample_size = msg['args'][1] if subsample_size is not None and subsample_size != subsample.shape[0]: warnings.warn("subsample_size does not match len(subsample), {} vs {}.".format( subsample_size, len(subsample)) + From 0b886d7d2227d9e255199d536297178319c53538 Mon Sep 17 00:00:00 2001 From: Ola Date: Mon, 8 Feb 2021 17:13:42 +0100 Subject: [PATCH 80/93] Conditioned gradient computation and moved to unconstraint sapce for ref. params. --- numpyro/infer/hmc_gibbs.py | 39 ++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index f51a85168..4d2abb762 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -471,10 +471,13 @@ class HMCECS(HMCGibbs): Quiroz, M., Kohn, R., Villani, M., & Tran, M. N. (2018) 3. *The Block Pseudo-Margional Sampler*, Tran, M.-N., Kohn, R., Quiroz, M. Villani, M. (2017) + 4. *The Fundamental Incompatibility ofScalable Hamiltonian Monte Carlo and Naive Data Subsampling* + Betancourt, M. (2015) :param inner_kernel: One of :class:`~numpyro.infer.hmc.HMC` or :class:`~numpyro.infer.hmc.NUTS`. :param int num_blocks: Number of blocks to partition subsample into. - :param callable proxy: TODO: add description. + :param proxy: Either :function `~numpyro.infer.hmc_gibbs.taylor_proxy` for likelihood estimation, + or, None for naive (in-between trajectory) subsampling as outlined in [4]. **Example** @@ -566,16 +569,19 @@ def potential_fn(z_gibbs, gibbs_state, z_hmc): pe = state.hmc_state.potential_energy pe_new = potential_fn(z_gibbs_new, gibbs_state_new, state.hmc_state.z) accept_prob = jnp.clip(jnp.exp(pe - pe_new), a_max=1.0) - z_gibbs, gibbs_state, pe = cond(random.bernoulli(rng_key, accept_prob), + transition = random.bernoulli(rng_key, accept_prob) + z_gibbs, gibbs_state, pe = cond(transition, (z_gibbs_new, gibbs_state_new, pe_new), identity, (z_gibbs, state.gibbs_state, pe), identity) - # TODO (very low priority): move this to the above cond, only compute grad when accepting - if self.inner_kernel._forward_mode_differentiation: - z_grad = jacfwd(partial(potential_fn, z_gibbs, gibbs_state))(state.hmc_state.z) - else: - z_grad = grad(partial(potential_fn, z_gibbs, gibbs_state))(state.hmc_state.z) - hmc_state = state.hmc_state._replace(z_grad=z_grad, potential_energy=pe) + hmc_state = cond(transition, # only update gradient an d + (potential_fn, z_gibbs, gibbs_state), lambda fn, zgibbs, gstate: + state.hmc_state._replace(z_grad=self.grad_mode(partial(potential_fn, + zgibbs, + gstate))(state.hmc_state.z), + potential_energy=pe), + (potential_fn, z_gibbs, gibbs_state), lambda fn, zgibbs, gstate: state.hmc_state + ) model_kwargs["_gibbs_sites"] = z_gibbs model_kwargs["_gibbs_state"] = gibbs_state @@ -584,6 +590,16 @@ def potential_fn(z_gibbs, gibbs_state, z_hmc): z = {**z_gibbs, **hmc_state.z} return HMCECSState(z, hmc_state, rng_key, gibbs_state, accept_prob) + @property + def grad_mode(self): + if not hasattr(self, '_grad_mode'): + if self.inner_kernel._forward_mode_differentiation: + grad_fn = jacfwd + else: + grad_fn = grad + self._grad_mode = grad_fn + return self._grad_mode + def perturbed_method(subsample_plate_sizes, proxy_fn): def estimator(likelihoods, params, gibbs_state): @@ -610,9 +626,12 @@ def estimator(likelihoods, params, gibbs_state): def taylor_proxy(reference_params): def construct_proxy_fn(subsample_plate_sizes, model, model_args, model_kwargs, num_blocks=1): - # TODO: map reference params to unconstraint_params + with block(), trace as tr, substitute(substitute_fn=partial(_unconstrain_reparam, reference_params)): + model(*model_args, **model_kwargs) - ref_params_flat, unravel_fn = ravel_pytree(reference_params) + # map to unconstraint_space # TODO: check this + ref_params = {name: site['value'] for name, site in tr if name in reference_params} + ref_params_flat, unravel_fn = ravel_pytree(ref_params) def log_likelihood(params_flat, subsample_indices=None): if subsample_indices is None: From 469a1f268b823fcef9f9a4b3b10363d9d7e128a2 Mon Sep 17 00:00:00 2001 From: Ola Date: Mon, 8 Feb 2021 21:30:05 +0100 Subject: [PATCH 81/93] Fixed test for HMCECS and bumped jaxlib version. --- numpyro/infer/hmc_gibbs.py | 38 +++++++++++++++++++------------------- setup.py | 2 +- test/test_hmc_gibbs.py | 2 +- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 4d2abb762..d10796fa3 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -511,6 +511,7 @@ def __init__(self, inner_kernel, *, num_blocks=1, proxy=None): self._num_blocks = num_blocks self._proxy = proxy self._method = 'perturbed' + self._grad_fn = None def postprocess_fn(self, args, kwargs): def fn(z): @@ -535,7 +536,8 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): self._gibbs_sites = list(self._subsample_plate_sizes.keys()) if self._proxy is not None: rng_key, proxy_key, method_key = random.split(rng_key, 3) - proxy_fn, gibbs_init, self._gibbs_update = self._proxy(self._subsample_plate_sizes, + proxy_fn, gibbs_init, self._gibbs_update = self._proxy(self._prototype_trace, + self._subsample_plate_sizes, self.model, model_args, model_kwargs.copy(), @@ -574,13 +576,15 @@ def potential_fn(z_gibbs, gibbs_state, z_hmc): (z_gibbs_new, gibbs_state_new, pe_new), identity, (z_gibbs, state.gibbs_state, pe), identity) - hmc_state = cond(transition, # only update gradient an d - (potential_fn, z_gibbs, gibbs_state), lambda fn, zgibbs, gstate: - state.hmc_state._replace(z_grad=self.grad_mode(partial(potential_fn, - zgibbs, - gstate))(state.hmc_state.z), - potential_energy=pe), - (potential_fn, z_gibbs, gibbs_state), lambda fn, zgibbs, gstate: state.hmc_state + state.hmc_state._replace(z_grad=self.grad_fn(partial(potential_fn, + z_gibbs, + gibbs_state))(state.hmc_state.z), + potential_energy=pe), + hmc_state = cond(transition, # only update gradient and potential energy when block is updated + state.hmc_state._replace( + z_grad=self.grad_fn(partial(potential_fn, z_gibbs, gibbs_state))(state.hmc_state.z), + potential_energy=pe), identity, + state.hmc_state, identity ) model_kwargs["_gibbs_sites"] = z_gibbs @@ -591,14 +595,14 @@ def potential_fn(z_gibbs, gibbs_state, z_hmc): return HMCECSState(z, hmc_state, rng_key, gibbs_state, accept_prob) @property - def grad_mode(self): + def grad_fn(self): if not hasattr(self, '_grad_mode'): if self.inner_kernel._forward_mode_differentiation: grad_fn = jacfwd else: grad_fn = grad - self._grad_mode = grad_fn - return self._grad_mode + self._grad_fn = grad_fn + return self._grad_fn def perturbed_method(subsample_plate_sizes, proxy_fn): @@ -625,12 +629,9 @@ def estimator(likelihoods, params, gibbs_state): def taylor_proxy(reference_params): - def construct_proxy_fn(subsample_plate_sizes, model, model_args, model_kwargs, num_blocks=1): - with block(), trace as tr, substitute(substitute_fn=partial(_unconstrain_reparam, reference_params)): - model(*model_args, **model_kwargs) - - # map to unconstraint_space # TODO: check this - ref_params = {name: site['value'] for name, site in tr if name in reference_params} + def construct_proxy_fn(prototype_trace, subsample_plate_sizes, model, model_args, model_kwargs, num_blocks=1): + ref_params = {name: _unconstrain_reparam(reference_params, site) for name, site in prototype_trace.items() + if name in reference_params} ref_params_flat, unravel_fn = ravel_pytree(ref_params) def log_likelihood(params_flat, subsample_indices=None): @@ -639,8 +640,7 @@ def log_likelihood(params_flat, subsample_indices=None): params = unravel_fn(params_flat) with warnings.catch_warnings(): warnings.simplefilter("ignore") - with block(), trace() as tr, substitute(data=subsample_indices), \ - substitute(substitute_fn=partial(_unconstrain_reparam, params)): + with block(), trace() as tr, substitute(data=subsample_indices), substitute(data=params): model(*model_args, **model_kwargs) log_lik = {} diff --git a/setup.py b/setup.py index bd20fd1b0..b1642cedb 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ # TODO: pin to a specific version for the release (until JAX's API becomes stable) 'jax==0.2.8', # check min version here: https://github.com/google/jax/blob/master/jax/lib/__init__.py#L26 - 'jaxlib==0.1.59', + 'jaxlib==0.1.60', 'tqdm', ], extras_require={ diff --git a/test/test_hmc_gibbs.py b/test/test_hmc_gibbs.py index 9b8bbe3bf..b4a6b2abb 100644 --- a/test/test_hmc_gibbs.py +++ b/test/test_hmc_gibbs.py @@ -207,9 +207,9 @@ def model(probs, locs): assert_allclose(jnp.var(samples["c"]), 1.03, atol=0.1) -@pytest.mark.parametrize('kernel_cls', [HMC, NUTS]) @pytest.mark.parametrize('num_blocks', [1, 2, 50, 100]) def test_subsample_gibbs_partitioning(kernel_cls, num_blocks): + # TODO: fix test to new API def model(obs): with plate('N', obs.shape[0], subsample_size=100) as idx: numpyro.sample('x', dist.Normal(0, 1), obs=obs[idx]) From 0fb8d01517c3fecb5239baa2b58c1a6489be14a8 Mon Sep 17 00:00:00 2001 From: Ola Date: Mon, 8 Feb 2021 21:50:30 +0100 Subject: [PATCH 82/93] Fixed test. --- test/test_hmc_gibbs.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/test/test_hmc_gibbs.py b/test/test_hmc_gibbs.py index b4a6b2abb..0fb143349 100644 --- a/test/test_hmc_gibbs.py +++ b/test/test_hmc_gibbs.py @@ -208,27 +208,24 @@ def model(probs, locs): @pytest.mark.parametrize('num_blocks', [1, 2, 50, 100]) -def test_subsample_gibbs_partitioning(kernel_cls, num_blocks): - # TODO: fix test to new API - def model(obs): - with plate('N', obs.shape[0], subsample_size=100) as idx: - numpyro.sample('x', dist.Normal(0, 1), obs=obs[idx]) - - obs = random.normal(random.PRNGKey(0), (10000,)) / 100 - kernel = HMCECS(kernel_cls(model), num_blocks=num_blocks) - state = kernel.init(random.PRNGKey(1), 10, None, model_args=(obs,), model_kwargs=None) - gibbs_sites = {'N': jnp.arange(100)} - - def potential_fn(z_gibbs, z_hmc): - return kernel.inner_kernel._potential_fn_gen(obs, _gibbs_sites=z_gibbs)(z_hmc) - - gibbs_fn = numpyro.infer.hmc_gibbs._subsample_gibbs_fn(potential_fn, kernel._plate_sizes, num_blocks) - new_gibbs_sites, _ = gibbs_fn(random.PRNGKey(2), gibbs_sites, state.hmc_state.z, - state.hmc_state.potential_energy) # accept_prob > .999 +def test_block_update_partitioning(num_blocks): + plate_size = 10000, 100 + + plate_sizes = {'N': plate_size} + gibbs_sites = {'N': jnp.arange(plate_size[1])} + gibbs_state = {} + + new_gibbs_sites, new_gibbs_state = numpyro.infer.hmc_gibbs._block_update(plate_sizes, + num_blocks, + random.PRNGKey(2), + gibbs_sites, + gibbs_state) block_size = 100 // num_blocks for name in gibbs_sites: assert block_size == jnp.not_equal(gibbs_sites[name], new_gibbs_sites[name]).sum() + assert gibbs_state == new_gibbs_state + def test_enum_subsample_smoke(): def model(data): From 5a6f629265d04df5d9c4cbec87b44e2c8a0b4da6 Mon Sep 17 00:00:00 2001 From: Ola Date: Mon, 8 Feb 2021 22:36:39 +0100 Subject: [PATCH 83/93] Fixed lint. --- test/test_hmc_gibbs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_hmc_gibbs.py b/test/test_hmc_gibbs.py index 0fb143349..d3718f879 100644 --- a/test/test_hmc_gibbs.py +++ b/test/test_hmc_gibbs.py @@ -13,7 +13,6 @@ import numpyro import numpyro.distributions as dist -from numpyro.handlers import plate from numpyro.infer import HMC, HMCECS, MCMC, NUTS, DiscreteHMCGibbs, HMCGibbs from numpyro.infer.hmc_gibbs import taylor_proxy From 29f37089263d08e265bf0630dfba27749cf9b783 Mon Sep 17 00:00:00 2001 From: Ola Date: Tue, 9 Feb 2021 21:37:58 +0100 Subject: [PATCH 84/93] Corrected taylor_proxy works in unconstraint space. Added docstring and changed Norm-Norm to 3-dim latent param. --- numpyro/examples/datasets.py | 2 +- numpyro/infer/hmc_gibbs.py | 58 ++++++++++++++++++++---------------- setup.py | 4 +-- test/test_hmc_gibbs.py | 10 +++---- 4 files changed, 41 insertions(+), 33 deletions(-) diff --git a/numpyro/examples/datasets.py b/numpyro/examples/datasets.py index 0e7b8d68e..18ace29e3 100644 --- a/numpyro/examples/datasets.py +++ b/numpyro/examples/datasets.py @@ -83,7 +83,7 @@ def _load_baseball(): def train_test_split(file): train, test, player_names = [], [], [] with open(file, 'r') as f: - csv_reader = csv.reader(f, delimiter='\t', quoting=csv.QUOTE_NONE) + csv_reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) for row in csv_reader: player_names.append(row['FirstName'] + ' ' + row['LastName']) at_bats, hits = row['At-Bats'], row['Hits'] diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index d10796fa3..4939a0b34 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -11,6 +11,7 @@ from jax.scipy.special import expit import numpyro +from numpyro.distributions.transforms import biject_to from numpyro.handlers import block, condition, seed, substitute, trace from numpyro.infer.hmc import HMC from numpyro.infer.mcmc import MCMCKernel @@ -510,8 +511,7 @@ def __init__(self, inner_kernel, *, num_blocks=1, proxy=None): self.inner_kernel._model = _wrap_gibbs_state(self.inner_kernel._model) self._num_blocks = num_blocks self._proxy = proxy - self._method = 'perturbed' - self._grad_fn = None + self._grad_ = None def postprocess_fn(self, args, kwargs): def fn(z): @@ -535,7 +535,6 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): } self._gibbs_sites = list(self._subsample_plate_sizes.keys()) if self._proxy is not None: - rng_key, proxy_key, method_key = random.split(rng_key, 3) proxy_fn, gibbs_init, self._gibbs_update = self._proxy(self._prototype_trace, self._subsample_plate_sizes, self.model, @@ -572,20 +571,15 @@ def potential_fn(z_gibbs, gibbs_state, z_hmc): pe_new = potential_fn(z_gibbs_new, gibbs_state_new, state.hmc_state.z) accept_prob = jnp.clip(jnp.exp(pe - pe_new), a_max=1.0) transition = random.bernoulli(rng_key, accept_prob) - z_gibbs, gibbs_state, pe = cond(transition, - (z_gibbs_new, gibbs_state_new, pe_new), identity, - (z_gibbs, state.gibbs_state, pe), identity) - - state.hmc_state._replace(z_grad=self.grad_fn(partial(potential_fn, - z_gibbs, - gibbs_state))(state.hmc_state.z), - potential_energy=pe), - hmc_state = cond(transition, # only update gradient and potential energy when block is updated - state.hmc_state._replace( - z_grad=self.grad_fn(partial(potential_fn, z_gibbs, gibbs_state))(state.hmc_state.z), - potential_energy=pe), identity, - state.hmc_state, identity - ) + grad_ = jacfwd if self.inner_kernel._forward_mode_differentiation else grad + z_gibbs, gibbs_state, pe, z_grad = cond(transition, + (z_gibbs_new, gibbs_state_new, pe_new), + lambda vals: vals + (grad_(partial(potential_fn, + vals[0], + vals[1]))(state.hmc_state.z),), + (z_gibbs, state.gibbs_state, pe, state.hmc_state.z_grad), identity) + + hmc_state = state.hmc_state._replace(z_grad=z_grad, potential_energy=pe) model_kwargs["_gibbs_sites"] = z_gibbs model_kwargs["_gibbs_state"] = gibbs_state @@ -595,14 +589,14 @@ def potential_fn(z_gibbs, gibbs_state, z_hmc): return HMCECSState(z, hmc_state, rng_key, gibbs_state, accept_prob) @property - def grad_fn(self): + def grad_(self): if not hasattr(self, '_grad_mode'): if self.inner_kernel._forward_mode_differentiation: - grad_fn = jacfwd + grad_ = jacfwd else: - grad_fn = grad - self._grad_fn = grad_fn - return self._grad_fn + grad_ = grad + self._grad_ = grad_ + return self._grad_ def perturbed_method(subsample_plate_sizes, proxy_fn): @@ -629,9 +623,21 @@ def estimator(likelihoods, params, gibbs_state): def taylor_proxy(reference_params): + """ Control variate for unbiased log likelihood estimation using a Taylor expansion around a reference + parameter. Suggest for subsampling in [1]. + + :param reference_params: Model parameterization at MLE or MAP-estimate. + + ** References: ** + + [1] Towards scaling up Markov chainMonte Carlo: an adaptive subsampling approach + Bardenet., R., Doucet, A., Holmes, C. (2014) + """ + def construct_proxy_fn(prototype_trace, subsample_plate_sizes, model, model_args, model_kwargs, num_blocks=1): - ref_params = {name: _unconstrain_reparam(reference_params, site) for name, site in prototype_trace.items() - if name in reference_params} + ref_params = {name: biject_to(prototype_trace[name]["fn"].support).inv(value) + for name, value in reference_params.items()} + ref_params_flat, unravel_fn = ravel_pytree(ref_params) def log_likelihood(params_flat, subsample_indices=None): @@ -640,7 +646,9 @@ def log_likelihood(params_flat, subsample_indices=None): params = unravel_fn(params_flat) with warnings.catch_warnings(): warnings.simplefilter("ignore") - with block(), trace() as tr, substitute(data=subsample_indices), substitute(data=params): + + with block(), trace() as tr, substitute(data=subsample_indices), substitute(data=params), \ + substitute(substitute_fn=partial(_unconstrain_reparam, params)): model(*model_args, **model_kwargs) log_lik = {} diff --git a/setup.py b/setup.py index b1642cedb..79c086e28 100644 --- a/setup.py +++ b/setup.py @@ -33,9 +33,9 @@ author='Uber AI Labs', install_requires=[ # TODO: pin to a specific version for the release (until JAX's API becomes stable) - 'jax==0.2.8', + 'jax=>0.2.8', # check min version here: https://github.com/google/jax/blob/master/jax/lib/__init__.py#L26 - 'jaxlib==0.1.60', + 'jaxlib=>0.1.59', 'tqdm', ], extras_require={ diff --git a/test/test_hmc_gibbs.py b/test/test_hmc_gibbs.py index d3718f879..204ab2650 100644 --- a/test/test_hmc_gibbs.py +++ b/test/test_hmc_gibbs.py @@ -243,13 +243,13 @@ def model(data): @pytest.mark.parametrize('num_block', [1, 2, 50]) @pytest.mark.parametrize('subsample_size', [50, 150]) def test_hmcecs_normal_normal(kernel_cls, num_block, subsample_size): - true_loc = 0.3 + true_loc = jnp.array([0.3, 0.1, 0.9]) num_warmup, num_samples = 200, 200 - data = true_loc + dist.Normal().sample(random.PRNGKey(1), (10000,)) + data = true_loc + dist.Normal(jnp.zeros(3,), jnp.ones(3,)).sample(random.PRNGKey(1), (10000,)) def model(data, subsample_size): - mean = numpyro.sample('mean', dist.Normal()) - with numpyro.plate('batch', data.shape[0], subsample_size=subsample_size): + mean = numpyro.sample('mean', dist.Normal().expand((3,)).to_event(1)) + with numpyro.plate('batch', data.shape[0], dim=-2, subsample_size=subsample_size): sub_data = numpyro.subsample(data, 0) numpyro.sample("obs", dist.Normal(mean, 1), obs=sub_data) @@ -261,5 +261,5 @@ def model(data, subsample_size): mcmc.run(random.PRNGKey(0), data, subsample_size) samples = mcmc.get_samples() - assert_allclose(np.mean(mcmc.get_samples()['mean']), true_loc, atol=0.1) + assert_allclose(np.mean(mcmc.get_samples()['mean'], axis=0), true_loc, atol=0.1) assert len(samples['mean']) == num_samples From 2aef8565946f32f6fc643e3bba65908599457551 Mon Sep 17 00:00:00 2001 From: Ola Date: Tue, 9 Feb 2021 21:41:09 +0100 Subject: [PATCH 85/93] Flipped syntax for geq in setup.py --- setup.py | 4 ++-- test/test_hmc_gibbs.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 79c086e28..dcad69e77 100644 --- a/setup.py +++ b/setup.py @@ -33,9 +33,9 @@ author='Uber AI Labs', install_requires=[ # TODO: pin to a specific version for the release (until JAX's API becomes stable) - 'jax=>0.2.8', + 'jax>=0.2.8', # check min version here: https://github.com/google/jax/blob/master/jax/lib/__init__.py#L26 - 'jaxlib=>0.1.59', + 'jaxlib>=0.1.59', 'tqdm', ], extras_require={ diff --git a/test/test_hmc_gibbs.py b/test/test_hmc_gibbs.py index 204ab2650..409ad7f12 100644 --- a/test/test_hmc_gibbs.py +++ b/test/test_hmc_gibbs.py @@ -245,7 +245,7 @@ def model(data): def test_hmcecs_normal_normal(kernel_cls, num_block, subsample_size): true_loc = jnp.array([0.3, 0.1, 0.9]) num_warmup, num_samples = 200, 200 - data = true_loc + dist.Normal(jnp.zeros(3,), jnp.ones(3,)).sample(random.PRNGKey(1), (10000,)) + data = true_loc + dist.Normal(jnp.zeros(3, ), jnp.ones(3, )).sample(random.PRNGKey(1), (10000,)) def model(data, subsample_size): mean = numpyro.sample('mean', dist.Normal().expand((3,)).to_event(1)) From cc2669e0a97700cb65ec67280509c44058f55dd9 Mon Sep 17 00:00:00 2001 From: Ola Date: Tue, 9 Feb 2021 22:02:08 +0100 Subject: [PATCH 86/93] Made default device for covtype example cpu. --- examples/covtype.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/covtype.py b/examples/covtype.py index f96498161..702b0d151 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -127,7 +127,7 @@ def main(args): help='whether to run "HMCECS", "NUTS", "HMCECS", "SA" or "FlowHMCECS"') parser.add_argument('--dense-mass', action="store_true") parser.add_argument('--x64', action="store_true") - parser.add_argument('--device', default='gpu', type=str, help='use "cpu" or "gpu".') + parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".') args = parser.parse_args() numpyro.set_platform(args.device) From 63323ec28702241102dc6a72e18cc495692bbfc5 Mon Sep 17 00:00:00 2001 From: Ola Date: Wed, 10 Feb 2021 14:18:00 +0100 Subject: [PATCH 87/93] Added taylor proxy test. --- numpyro/distributions/continuous.py | 2 +- numpyro/infer/hmc_gibbs.py | 16 ++------- test/test_hmc_gibbs.py | 50 ++++++++++++++++++++++++++++- 3 files changed, 52 insertions(+), 16 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index c2a6ecdc9..3f6dbc5dc 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -962,7 +962,7 @@ def sample(self, key, sample_shape=()): @validate_sample def log_prob(self, value): - normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale) # TODO:Added jnp.abs + normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale) value_scaled = (value - self.loc) / self.scale return -0.5 * value_scaled ** 2 - normalize_term diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 4939a0b34..408af1e1b 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -511,7 +511,6 @@ def __init__(self, inner_kernel, *, num_blocks=1, proxy=None): self.inner_kernel._model = _wrap_gibbs_state(self.inner_kernel._model) self._num_blocks = num_blocks self._proxy = proxy - self._grad_ = None def postprocess_fn(self, args, kwargs): def fn(z): @@ -588,16 +587,6 @@ def potential_fn(z_gibbs, gibbs_state, z_hmc): z = {**z_gibbs, **hmc_state.z} return HMCECSState(z, hmc_state, rng_key, gibbs_state, accept_prob) - @property - def grad_(self): - if not hasattr(self, '_grad_mode'): - if self.inner_kernel._forward_mode_differentiation: - grad_ = jacfwd - else: - grad_ = grad - self._grad_ = grad_ - return self._grad_ - def perturbed_method(subsample_plate_sizes, proxy_fn): def estimator(likelihoods, params, gibbs_state): @@ -646,9 +635,8 @@ def log_likelihood(params_flat, subsample_indices=None): params = unravel_fn(params_flat) with warnings.catch_warnings(): warnings.simplefilter("ignore") - - with block(), trace() as tr, substitute(data=subsample_indices), substitute(data=params), \ - substitute(substitute_fn=partial(_unconstrain_reparam, params)): + params = {name: biject_to(prototype_trace[name]["fn"].support)(value) for name, value in params.items()} + with block(), trace() as tr, substitute(data=subsample_indices), substitute(data=params): model(*model_args, **model_kwargs) log_lik = {} diff --git a/test/test_hmc_gibbs.py b/test/test_hmc_gibbs.py index 409ad7f12..c89f24197 100644 --- a/test/test_hmc_gibbs.py +++ b/test/test_hmc_gibbs.py @@ -2,12 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial +import math import numpy as np from numpy.testing import assert_allclose import pytest -from jax import random +from jax import random, vmap, jacrev, hessian import jax.numpy as jnp from jax.scipy.linalg import cho_factor, cho_solve, inv, solve_triangular @@ -263,3 +264,50 @@ def model(data, subsample_size): samples = mcmc.get_samples() assert_allclose(np.mean(mcmc.get_samples()['mean'], axis=0), true_loc, atol=0.1) assert len(samples['mean']) == num_samples + + +@pytest.mark.parametrize('subsample_size', [5, 10, 15]) +def test_taylor_proxy_norm(subsample_size): + data_key, tr_key, rng_key = random.split(random.PRNGKey(0), 3) + ref_params = jnp.array([0.1, 0.5, -0.2]) + sigma = .1 + + data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(data_key, (100,)) + n, _ = data.shape + + def model(data, subsample_size): + mean = numpyro.sample('mean', dist.Normal(ref_params, jnp.ones_like(ref_params))) + with numpyro.plate('data', data.shape[0], subsample_size=subsample_size, dim=-2) as idx: + numpyro.sample('obs', dist.Normal(mean, sigma), obs=data[idx]) + + log_prob_fn = lambda params: vmap(dist.Normal(params, sigma).log_prob)(data).sum(-1) + log_prob = log_prob_fn(ref_params) + log_norm_jac = jacrev(log_prob_fn)(ref_params) + log_norm_hessian = hessian(log_prob_fn)(ref_params) + + tr = numpyro.handlers.trace(numpyro.handlers.seed(model, tr_key)).get_trace(data, subsample_size) + plate_sizes = {'data': (n, subsample_size)} + + proxy_constructer = numpyro.infer.hmc_gibbs.taylor_proxy({'mean': ref_params}) + proxy_fn, gibbs_init, gibbs_update = proxy_constructer(tr, plate_sizes, model, (data, subsample_size), {}) + + def taylor_expand_2nd_order(idx, pos): + return log_prob[idx] + (log_norm_jac[idx] @ pos) + .5 * (pos @ log_norm_hessian[idx]) @ pos + + def taylor_expand_2nd_order_sum(pos): + return log_prob.sum() + log_norm_jac.sum(0) @ pos + .5 * pos @ log_norm_hessian.sum(0) @ pos + + for _ in range(5): + split_key, perturbe_key, rng_key = random.split(rng_key, 3) + perturbe_params = ref_params + dist.Normal(.1, 0.01).sample(perturbe_key, ref_params.shape) + subsample_idx = random.randint(rng_key, (subsample_size,), 0, n) + gibbs_site = {'data': subsample_idx} + proxy_state = gibbs_init(None, gibbs_site) + actual_proxy_sum, actual_proxy_sub = proxy_fn({'data': perturbe_params}, ['data'], proxy_state) + assert_allclose(actual_proxy_sub['data'], + taylor_expand_2nd_order(subsample_idx, perturbe_params - ref_params), rtol=1e-5) + assert_allclose(actual_proxy_sum['data'], taylor_expand_2nd_order_sum(perturbe_params - ref_params), rtol=1e-5) + + +def test_estimate_likelihood(): + pass From 89a99a2b576d84e86fadfb5603ac751450fdfc50 Mon Sep 17 00:00:00 2001 From: Ola Date: Wed, 10 Feb 2021 23:45:43 +0100 Subject: [PATCH 88/93] Added test for variance. --- numpyro/infer/hmc_gibbs.py | 2 +- test/test_hmc_gibbs.py | 36 ++++++++++++++++++++++++++++-------- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 408af1e1b..eec683a09 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -615,7 +615,7 @@ def taylor_proxy(reference_params): """ Control variate for unbiased log likelihood estimation using a Taylor expansion around a reference parameter. Suggest for subsampling in [1]. - :param reference_params: Model parameterization at MLE or MAP-estimate. + :param dict reference_params: Model parameterization at MLE or MAP-estimate. ** References: ** diff --git a/test/test_hmc_gibbs.py b/test/test_hmc_gibbs.py index c89f24197..247df89a9 100644 --- a/test/test_hmc_gibbs.py +++ b/test/test_hmc_gibbs.py @@ -1,14 +1,12 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 - from functools import partial -import math import numpy as np from numpy.testing import assert_allclose import pytest -from jax import random, vmap, jacrev, hessian +from jax import hessian, jacrev, random, vmap import jax.numpy as jnp from jax.scipy.linalg import cho_factor, cho_solve, inv, solve_triangular @@ -288,8 +286,8 @@ def model(data, subsample_size): tr = numpyro.handlers.trace(numpyro.handlers.seed(model, tr_key)).get_trace(data, subsample_size) plate_sizes = {'data': (n, subsample_size)} - proxy_constructer = numpyro.infer.hmc_gibbs.taylor_proxy({'mean': ref_params}) - proxy_fn, gibbs_init, gibbs_update = proxy_constructer(tr, plate_sizes, model, (data, subsample_size), {}) + proxy_constructor = taylor_proxy({'mean': ref_params}) + proxy_fn, gibbs_init, gibbs_update = proxy_constructor(tr, plate_sizes, model, (data, subsample_size), {}) def taylor_expand_2nd_order(idx, pos): return log_prob[idx] + (log_norm_jac[idx] @ pos) + .5 * (pos @ log_norm_hessian[idx]) @ pos @@ -299,7 +297,7 @@ def taylor_expand_2nd_order_sum(pos): for _ in range(5): split_key, perturbe_key, rng_key = random.split(rng_key, 3) - perturbe_params = ref_params + dist.Normal(.1, 0.01).sample(perturbe_key, ref_params.shape) + perturbe_params = ref_params + dist.Normal(.1, 0.1).sample(perturbe_key, ref_params.shape) subsample_idx = random.randint(rng_key, (subsample_size,), 0, n) gibbs_site = {'data': subsample_idx} proxy_state = gibbs_init(None, gibbs_site) @@ -309,5 +307,27 @@ def taylor_expand_2nd_order_sum(pos): assert_allclose(actual_proxy_sum['data'], taylor_expand_2nd_order_sum(perturbe_params - ref_params), rtol=1e-5) -def test_estimate_likelihood(): - pass +@pytest.mark.parametrize('kernel_cls', [HMC, NUTS]) +def test_estimate_likelihood(kernel_cls): + data_key, tr_key, sub_key, rng_key = random.split(random.PRNGKey(0), 4) + ref_params = jnp.array([0.1, 0.5, -0.2]) + sigma = .1 + data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(data_key, (10_000,)) + n, _ = data.shape + num_warmup = 200 + num_samples = 1000 + num_blocks = 20 + + def model(data): + mean = numpyro.sample('mean', dist.Normal(ref_params, jnp.ones_like(ref_params))) + with numpyro.plate('N', data.shape[0], subsample_size=10, dim=-2) as idx: + numpyro.sample('obs', dist.Normal(mean, sigma), obs=data[idx]) + + proxy_fn = taylor_proxy({'mean': ref_params}) + kernel = HMCECS(kernel_cls(model), proxy=proxy_fn, num_blocks=num_blocks) + mcmc = MCMC(kernel, num_warmup, num_samples) + + mcmc.run(random.PRNGKey(0), data, extra_fields=['hmc_state.potential_energy']) + + pes = mcmc.get_extra_fields()['hmc_state.potential_energy'] + assert jnp.var(pes) < 2. From 0131f3e984b316454b27eab7ff7b181870f237d1 Mon Sep 17 00:00:00 2001 From: Ola Date: Wed, 10 Feb 2021 23:57:17 +0100 Subject: [PATCH 89/93] Fixed lint. --- test/test_hmc_gibbs.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/test/test_hmc_gibbs.py b/test/test_hmc_gibbs.py index 247df89a9..dfc684b2b 100644 --- a/test/test_hmc_gibbs.py +++ b/test/test_hmc_gibbs.py @@ -2,13 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial +import jax.numpy as jnp import numpy as np -from numpy.testing import assert_allclose import pytest - from jax import hessian, jacrev, random, vmap -import jax.numpy as jnp from jax.scipy.linalg import cho_factor, cho_solve, inv, solve_triangular +from numpy.testing import assert_allclose import numpyro import numpyro.distributions as dist @@ -278,10 +277,12 @@ def model(data, subsample_size): with numpyro.plate('data', data.shape[0], subsample_size=subsample_size, dim=-2) as idx: numpyro.sample('obs', dist.Normal(mean, sigma), obs=data[idx]) - log_prob_fn = lambda params: vmap(dist.Normal(params, sigma).log_prob)(data).sum(-1) - log_prob = log_prob_fn(ref_params) - log_norm_jac = jacrev(log_prob_fn)(ref_params) - log_norm_hessian = hessian(log_prob_fn)(ref_params) + def log_prob(params): + return vmap(dist.Normal(params, sigma).log_prob)(data).sum(-1) + + log_prob = log_prob(ref_params) + log_norm_jac = jacrev(log_prob)(ref_params) + log_norm_hessian = hessian(log_prob)(ref_params) tr = numpyro.handlers.trace(numpyro.handlers.seed(model, tr_key)).get_trace(data, subsample_size) plate_sizes = {'data': (n, subsample_size)} From c697e913726f29b7ff0f1d0616bae88cbcc40a0d Mon Sep 17 00:00:00 2001 From: Ola Date: Thu, 11 Feb 2021 00:28:37 +0100 Subject: [PATCH 90/93] Added all log_density computation to test_estimate_likelihood and assert variance on difference between sub and all. --- test/test_hmc_gibbs.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/test/test_hmc_gibbs.py b/test/test_hmc_gibbs.py index dfc684b2b..1ee783af2 100644 --- a/test/test_hmc_gibbs.py +++ b/test/test_hmc_gibbs.py @@ -13,6 +13,7 @@ import numpyro.distributions as dist from numpyro.infer import HMC, HMCECS, MCMC, NUTS, DiscreteHMCGibbs, HMCGibbs from numpyro.infer.hmc_gibbs import taylor_proxy +from numpyro.infer.util import log_density def _linear_regression_gibbs_fn(X, XX, XY, Y, rng_key, gibbs_sites, hmc_sites): @@ -308,6 +309,7 @@ def taylor_expand_2nd_order_sum(pos): assert_allclose(actual_proxy_sum['data'], taylor_expand_2nd_order_sum(perturbe_params - ref_params), rtol=1e-5) +@pytest.mark.filterwarnings('ignore::UserWarning') @pytest.mark.parametrize('kernel_cls', [HMC, NUTS]) def test_estimate_likelihood(kernel_cls): data_key, tr_key, sub_key, rng_key = random.split(random.PRNGKey(0), 4) @@ -316,12 +318,12 @@ def test_estimate_likelihood(kernel_cls): data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(data_key, (10_000,)) n, _ = data.shape num_warmup = 200 - num_samples = 1000 + num_samples = 200 num_blocks = 20 def model(data): mean = numpyro.sample('mean', dist.Normal(ref_params, jnp.ones_like(ref_params))) - with numpyro.plate('N', data.shape[0], subsample_size=10, dim=-2) as idx: + with numpyro.plate('N', data.shape[0], subsample_size=100, dim=-2) as idx: numpyro.sample('obs', dist.Normal(mean, sigma), obs=data[idx]) proxy_fn = taylor_proxy({'mean': ref_params}) @@ -331,4 +333,7 @@ def model(data): mcmc.run(random.PRNGKey(0), data, extra_fields=['hmc_state.potential_energy']) pes = mcmc.get_extra_fields()['hmc_state.potential_energy'] - assert jnp.var(pes) < 2. + samples = mcmc.get_samples() + pes_full = vmap(lambda sample: log_density(model, (data,), {}, {**sample, **{'N': jnp.arange(n)}})[0])(samples) + + assert jnp.var(jnp.exp(-pes - pes_full)) < 1. From 400389ffc5354779643adeb5a4b5b1fcd68ea9fd Mon Sep 17 00:00:00 2001 From: Ola Date: Thu, 11 Feb 2021 00:30:43 +0100 Subject: [PATCH 91/93] Fixed typo and isort. --- numpyro/infer/hmc_gibbs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index eec683a09..ef6881123 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -472,7 +472,7 @@ class HMCECS(HMCGibbs): Quiroz, M., Kohn, R., Villani, M., & Tran, M. N. (2018) 3. *The Block Pseudo-Margional Sampler*, Tran, M.-N., Kohn, R., Quiroz, M. Villani, M. (2017) - 4. *The Fundamental Incompatibility ofScalable Hamiltonian Monte Carlo and Naive Data Subsampling* + 4. *The Fundamental Incompatibility of Scalable Hamiltonian Monte Carlo and Naive Data Subsampling* Betancourt, M. (2015) :param inner_kernel: One of :class:`~numpyro.infer.hmc.HMC` or :class:`~numpyro.infer.hmc.NUTS`. From 2e1dddb94d4b60aeb95653c5653230bc8761c773 Mon Sep 17 00:00:00 2001 From: Ola Date: Thu, 11 Feb 2021 00:36:44 +0100 Subject: [PATCH 92/93] isort not included in previous commit. --- test/test_hmc_gibbs.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_hmc_gibbs.py b/test/test_hmc_gibbs.py index 1ee783af2..e8d02d301 100644 --- a/test/test_hmc_gibbs.py +++ b/test/test_hmc_gibbs.py @@ -2,12 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial -import jax.numpy as jnp import numpy as np +from numpy.testing import assert_allclose import pytest + from jax import hessian, jacrev, random, vmap +import jax.numpy as jnp from jax.scipy.linalg import cho_factor, cho_solve, inv, solve_triangular -from numpy.testing import assert_allclose import numpyro import numpyro.distributions as dist From 6474df3ff5d254f737d1245cf6301429d159440c Mon Sep 17 00:00:00 2001 From: Ola Date: Thu, 11 Feb 2021 07:54:39 +0100 Subject: [PATCH 93/93] Fixed shadowing log_prob. --- test/test_hmc_gibbs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_hmc_gibbs.py b/test/test_hmc_gibbs.py index e8d02d301..e15eb47e0 100644 --- a/test/test_hmc_gibbs.py +++ b/test/test_hmc_gibbs.py @@ -279,12 +279,12 @@ def model(data, subsample_size): with numpyro.plate('data', data.shape[0], subsample_size=subsample_size, dim=-2) as idx: numpyro.sample('obs', dist.Normal(mean, sigma), obs=data[idx]) - def log_prob(params): + def log_prob_fn(params): return vmap(dist.Normal(params, sigma).log_prob)(data).sum(-1) - log_prob = log_prob(ref_params) - log_norm_jac = jacrev(log_prob)(ref_params) - log_norm_hessian = hessian(log_prob)(ref_params) + log_prob = log_prob_fn(ref_params) + log_norm_jac = jacrev(log_prob_fn)(ref_params) + log_norm_hessian = hessian(log_prob_fn)(ref_params) tr = numpyro.handlers.trace(numpyro.handlers.seed(model, tr_key)).get_trace(data, subsample_size) plate_sizes = {'data': (n, subsample_size)}