From 16484dd2422c0fde7cfd6b8f7211daf90efc6205 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Mon, 5 Sep 2022 20:58:59 +0200
Subject: [PATCH 1/2] Sample remaining variables with the NUTS sampler

---
 aemcmc/basic.py     | 29 +++++++++++++----
 aemcmc/nuts.py      | 77 +++++++++++++++++++++----------------------
 tests/test_basic.py | 79 +++++++++++++++++++++++++++++++++++++++------
 tests/test_nuts.py  | 55 ++++++++-----------------------
 4 files changed, 142 insertions(+), 98 deletions(-)

diff --git a/aemcmc/basic.py b/aemcmc/basic.py
index 7207e7d..2e22447 100644
--- a/aemcmc/basic.py
+++ b/aemcmc/basic.py
@@ -1,10 +1,12 @@
 from typing import Dict, Tuple
 
+import aesara.tensor as at
 from aesara.graph.basic import Variable
 from aesara.graph.fg import FunctionGraph
 from aesara.tensor.random.utils import RandomStream
 from aesara.tensor.var import TensorVariable
 
+from aemcmc.nuts import construct_nuts_sampler
 from aemcmc.rewriting import (
     SamplerTracker,
     construct_ir_fgraph,
@@ -19,6 +21,7 @@ def construct_sampler(
     Dict[TensorVariable, TensorVariable],
     Dict[Variable, Variable],
     Dict[TensorVariable, TensorVariable],
+    Dict[str, TensorVariable],
 ]:
     r"""Eagerly construct a sampler for a given set of observed variables and
     their observations.
@@ -55,6 +58,7 @@ def construct_sampler(
 
     # TODO FIXME: Get/extract `Scan`-generated updates
     posterior_updates: Dict[Variable, Variable] = {}
+    parameters: Dict[str, TensorVariable] = {}
 
     rvs_without_samplers = set()
 
     for rv in fgraph.outputs:
@@ -103,14 +107,26 @@ def construct_sampler(
             updates = dict(zip(update_keys, update_values))
             posterior_updates.update(updates)
 
+    # We use the NUTS sampler for the remaining variables
+    # TODO: NUTS cannot handle RVs with discrete supports
+    # TODO: Track the transformations made by NUTS? It would make more sense to
+    # apply the transforms on the probabilistic graph, in which case we would
+    # only need to return the transformed graph.
     if rvs_without_samplers:
-        # TODO: Assign NUTS to these
-        raise NotImplementedError(
-            f"Could not find a posterior samplers for {rvs_without_samplers}"
-        )
+        inverse_mass_matrix = at.vector("inverse_mass_matrix")
+        step_size = at.scalar("step_size")
+        parameters["step_size"] = step_size
+        parameters["inverse_mass_matrix"] = inverse_mass_matrix
+
+        # We condition on the updated values of the other rvs
+        rvs_to_values = {rv: rvs_to_init_vals[rv] for rv in rvs_without_samplers}
+        rvs_to_values.update(posterior_sample_steps)
 
-    # TODO: Track/handle "auxiliary/augmentation" variables introduced by sample
-    # steps?
+        nuts_sample_steps, updates = construct_nuts_sampler(
+            srng, rvs_without_samplers, rvs_to_values, inverse_mass_matrix, step_size
+        )
+        posterior_sample_steps.update(nuts_sample_steps)
+        posterior_updates.update(updates)
 
     return (
         {
@@ -120,4 +136,5 @@ def construct_sampler(
         },
         posterior_updates,
         {new_to_old_rvs[rv]: init_var for rv, init_var in rvs_to_init_vals.items()},
+        parameters,
     )
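
With this change `construct_sampler` returns a fourth element holding the symbolic NUTS parameters. A minimal sketch of the new unpacking, consistent with the tests added below (the model here is hypothetical):

    import aesara
    from aesara.tensor.random.utils import RandomStream

    from aemcmc.basic import construct_sampler

    srng = RandomStream(0)
    tau_rv = srng.halfcauchy(0, 1, name="tau")  # falls through to the NUTS branch
    Y_rv = srng.halfcauchy(0, tau_rv, name="Y")
    y_vv = Y_rv.clone()

    sample_steps, updates, initial_values, parameters = construct_sampler(
        {Y_rv: y_vv}, srng
    )
    # When NUTS was assigned to at least one variable, `parameters` maps
    # "step_size" and "inverse_mass_matrix" to the symbolic inputs created
    # above; they must be fed when compiling the sample steps.
    sample_step = aesara.function(
        [
            initial_values[tau_rv],
            y_vv,
            parameters["step_size"],
            parameters["inverse_mass_matrix"],
        ],
        sample_steps[tau_rv],
        updates=updates,
    )
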
diff --git a/aemcmc/nuts.py b/aemcmc/nuts.py
index 6ca7d82..4ea3a3f 100644
--- a/aemcmc/nuts.py
+++ b/aemcmc/nuts.py
@@ -10,10 +10,9 @@
     _default_transformed_rv,
 )
 from aesara.tensor.random import RandomStream
+from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.var import TensorVariable
 
-from aemcmc.utils import ModelInfo
-
 NUTSStateType = Tuple[TensorVariable, TensorVariable, TensorVariable]
 NUTSKernelType = Callable[
     [NUTSStateType],
@@ -28,12 +27,13 @@
 ]
 
 
-def nuts(
+def construct_nuts_sampler(
     srng: RandomStream,
-    model: ModelInfo,
+    to_sample_rvs,  # RVs to sample
+    rvs_to_values,  # All RVs to values
     inverse_mass_matrix: TensorVariable,
     step_size: TensorVariable,
-) -> Tuple[NUTSStateType, NUTSKernelType]:
+) -> Tuple[Dict[RandomVariable, TensorVariable], Dict]:
     """Build a NUTS kernel and the initial state.
 
     This function currently assumes that we will update the value of all of the
@@ -41,38 +41,40 @@
 
     Parameters
     ----------
-    model
-        The Aesara model whose posterior distribution we wish to sample from
-        passed as a `ModelInfo` instance.
+    to_sample_rvs
+        A sequence that contains the random variables whose posterior
+        distribution we wish to sample from.
+    rvs_to_values
+        A dictionary that maps all random variables in the model (including
+        those not sampled with NUTS) to their value variable.
     step_size
         The step size used in the symplectic integrator.
     inverse_mass_matrix
         One or two-dimensional array used as the inverse mass matrix that
        defines the euclidean metric.
 
+    Returns
+    -------
+    A NUTS sampling step for each variable, along with the associated updates.
+
     """
-    unobserved_rvs = tuple(
-        rv for rv in model.rvs_to_values.keys() if rv not in model.observed_rvs
-    )
-    unobserved_rvs_to_values = {rv: model.rvs_to_values[rv] for rv in unobserved_rvs}
-    observed_vvs = tuple(model.rvs_to_values[rv] for rv in model.observed_rvs)
 
     # Algorithms in the HMC family can more easily explore the posterior distribution
     # when the support of each random variable's distribution is unconstrained.
     # First we build the logprob graph in the transformed space.
     transforms = {
-        vv: get_transform(rv) if rv in unobserved_rvs else None
-        for rv, vv in model.rvs_to_values.items()
+        vv: get_transform(rv) for rv, vv in rvs_to_values.items() if rv in to_sample_rvs
     }
     logprob_sum = joint_logprob(
-        model.rvs_to_values, extra_rewrites=TransformValuesRewrite(transforms)
+        rvs_to_values, extra_rewrites=TransformValuesRewrite(transforms)
    )
 
     # Then we transform the value variables.
     transformed_vvs = {
         vv: transform_forward(rv, vv, transforms[vv])
-        for rv, vv in unobserved_rvs_to_values.items()
+        for rv, vv in rvs_to_values.items()
+        if rv in to_sample_rvs
     }
 
     # Algorithms in `aehmc` work with flat arrays and we need to ravel parameter
@@ -80,45 +82,38 @@
     rp_map = RaveledParamsMap(tuple(transformed_vvs.values()))
     rp_map.ref_params = tuple(transformed_vvs.keys())
 
+    # Keep the value variables of the terms that are not sampled with NUTS
+    non_nuts_vals = {
+        vv: vv for rv, vv in rvs_to_values.items() if rv not in to_sample_rvs
+    }
+
     # We can now write the logprob function associated with the model and build
     # the NUTS kernel.
     def logprob_fn(q):
         unraveled_q = rp_map.unravel_params(q)
-        unraveled_q.update({vv: vv for vv in observed_vvs})
-
+        unraveled_q.update(non_nuts_vals)
         memo = aesara.graph.basic.clone_get_equiv(
             [], [logprob_sum], copy_inputs=False, copy_orphans=False, memo=unraveled_q
         )
         return memo[logprob_sum]
 
+    # Finally we build the NUTS sampling step
     nuts_kernel = aehmc_nuts.new_kernel(srng, logprob_fn)
 
-    def step_fn(state):
-        """Take one step with the NUTS kernel.
-
-        The NUTS kernel works with a state that contains the current value of
-        the variables, but also the current value of the potential and its
-        gradient, and we need to carry this state forward. We also return the
-        unraveled parameter values in both the original and transformed space.
-
-        """
-        (new_q, new_pe, new_peg, *_), updates = nuts_kernel(
-            *state, step_size, inverse_mass_matrix
-        )
-        new_state = (new_q, new_pe, new_peg)
-        transformed_params = rp_map.unravel_params(new_q)
-        params = {
-            vv: transform_backward(rv, transformed_params[vv], transforms[vv])
-            for rv, vv in unobserved_rvs_to_values.items()
-        }
-        return (new_state, params, transformed_params), updates
-
-    # Finally we build the initial state
     initial_q = rp_map.ravel_params(tuple(transformed_vvs.values()))
     initial_state = aehmc_nuts.new_state(initial_q, logprob_fn)
 
-    return initial_state, step_fn
+    # TODO: Does that lead to wasteful computation? Or is it handled by Aesara?
+    (new_q, *_), updates = nuts_kernel(*initial_state, step_size, inverse_mass_matrix)
+    transformed_params = rp_map.unravel_params(new_q)
+    params = {
+        rv: transform_backward(rv, transformed_params[vv], transforms[vv])
+        for rv, vv in rvs_to_values.items()
+        if rv in to_sample_rvs
+    }
+
+    return params, updates
 
 
 def get_transform(rv: TensorVariable):
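
The key mechanism in `logprob_fn` is graph substitution: the flat position `q` is unraveled into the transformed value variables, and `clone_get_equiv` clones the logprob graph with those replacements, so the clone depends on `q` alone. A minimal sketch of the same idiom on a toy graph (all names hypothetical):

    import aesara
    import aesara.tensor as at

    x = at.scalar("x")
    out = x ** 2

    q = at.scalar("q")
    # Seeding the memo with {x: 2 * q} makes the clone of `out` an
    # expression in `q` instead of `x`.
    memo = aesara.graph.basic.clone_get_equiv(
        [], [out], copy_inputs=False, copy_orphans=False, memo={x: 2 * q}
    )
    new_out = memo[out]  # equivalent to (2 * q) ** 2
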
diff --git a/tests/test_basic.py b/tests/test_basic.py
index b536bca..2fe4470 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -25,9 +25,12 @@ def test_closed_form_posterior_beta_binomial():
     y_vv = Y_rv.clone()
     y_vv.name = "y"
 
-    sample_steps, updates, initial_values = construct_sampler({Y_rv: y_vv}, srng)
+    sample_steps, updates, initial_values, parameters = construct_sampler(
+        {Y_rv: y_vv}, srng
+    )
 
     p_posterior_step = sample_steps[p_rv]
+    assert len(parameters) == 0
     assert isinstance(p_posterior_step.owner.op, BetaRV)
 
 
@@ -43,24 +46,80 @@ def test_closed_form_posterior_gamma_poisson():
     y_vv = Y_rv.clone()
     y_vv.name = "y"
 
-    sample_steps, updates, initial_values = construct_sampler({Y_rv: y_vv}, srng)
+    sample_steps, updates, initial_values, parameters = construct_sampler(
+        {Y_rv: y_vv}, srng
+    )
 
     p_posterior_step = sample_steps[l_rv]
+    assert len(parameters) == 0
     assert isinstance(p_posterior_step.owner.op, GammaRV)
 
 
-def test_no_samplers():
+@pytest.mark.parametrize("size", [1, (1,), (2, 3)])
+def test_nuts_sampler_single_variable(size):
+    """We make sure that the NUTS sampler compiles and updates the chains for
+    different sizes of the random variable.
+
+    """
     srng = RandomStream(0)
-    size = at.lscalar("size")
-    tau_rv = srng.halfcauchy(0, 1, name="tau")
-    Y_rv = srng.halfcauchy(0, tau_rv, size=size, name="Y")
+    tau_rv = srng.halfcauchy(0, 1, size=size, name="tau")
+    Y_rv = srng.halfcauchy(0, tau_rv, name="Y")
+
+    y_vv = Y_rv.clone()
+    y_vv.name = "y"
+
+    sample_steps, updates, initial_values, parameters = construct_sampler(
+        {Y_rv: y_vv}, srng
+    )
+
+    assert len(parameters) == 2
+    assert len(sample_steps) == 1
+
+    tau_post_step = sample_steps[tau_rv]
+    assert y_vv in graph_inputs([tau_post_step])
+
+    inputs = [
+        initial_values[tau_rv],
+        y_vv,
+        parameters["step_size"],
+        parameters["inverse_mass_matrix"],
+    ]
+    output = tau_post_step
+    sample_step = aesara.function(inputs, output)
+
+    tau_val = np.ones(shape=size)
+    y_val = np.ones(shape=size)
+    step_size = 1e-1
+    inverse_mass_matrix = np.ones(shape=size).flatten()
+    res = sample_step(tau_val, y_val, step_size, inverse_mass_matrix)
+    with pytest.raises(AssertionError):
+        np.testing.assert_equal(tau_val, res)
+
+
+def test_nuts_with_closed_form():
+    """Make sure that the NUTS sampler works in combination with closed-form posteriors."""
+    srng = RandomStream(0)
+
+    alpha_tt = at.scalar("alpha")
+    beta_rv = srng.halfnormal(1.0, name="beta")
+    l_rv = srng.gamma(alpha_tt, beta_rv, name="p")
+
+    Y_rv = srng.poisson(l_rv, name="Y")
 
     y_vv = Y_rv.clone()
     y_vv.name = "y"
 
-    with pytest.raises(NotImplementedError):
-        construct_sampler({Y_rv: y_vv}, srng)
+    sample_steps, updates, initial_values, parameters = construct_sampler(
+        {Y_rv: y_vv}, srng
+    )
+
+    p_posterior_step = sample_steps[l_rv]
+    assert len(parameters) == 2
+    assert len(initial_values) == 2
+    assert isinstance(p_posterior_step.owner.op, GammaRV)
+
+    assert beta_rv in sample_steps
 
 
 def test_create_gibbs():
@@ -87,7 +146,9 @@ def test_create_gibbs():
 
     sample_vars = [tau_rv, lmbda_rv, beta_rv, h_rv]
 
-    sample_steps, updates, initial_values = construct_sampler({Y_rv: y_vv}, srng)
+    sample_steps, updates, initial_values, parameters = construct_sampler(
+        {Y_rv: y_vv}, srng
+    )
 
     assert len(sample_steps) == 4
     assert updates
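
The compiled `sample_step` above performs a single NUTS transition; a chain is drawn by feeding each draw back in. A hedged continuation reusing the names from `test_nuts_sampler_single_variable` with `size = (2, 3)` (the chain length is arbitrary):

    # One NUTS transition per call: iterate to draw a chain.
    samples = []
    tau_val = np.ones(shape=(2, 3))
    y_val = np.ones(shape=(2, 3))
    for _ in range(100):
        tau_val = sample_step(tau_val, y_val, 1e-1, np.ones(6))
        samples.append(tau_val)
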
diff --git a/tests/test_nuts.py b/tests/test_nuts.py
index 75d0921..175187f 100644
--- a/tests/test_nuts.py
+++ b/tests/test_nuts.py
@@ -1,12 +1,8 @@
 import aesara
 import aesara.tensor as at
-import numpy as np
-import pytest
-from aeppl import joint_logprob
 from aesara.tensor.random import RandomStream
 
-from aemcmc.nuts import nuts
-from aemcmc.utils import ModelInfo
+from aemcmc.nuts import construct_nuts_sampler
 
 
 def test_nuts():
@@ -19,47 +15,22 @@ def test_nuts():
     mu_vv.name = "mu_vv"
     sigma_vv = sigma_rv.clone()
     sigma_vv.name = "sigma_vv"
-    Y_at = at.scalar(name="Y_at")
+    y_vv = Y_rv.clone()
 
-    rvs_to_values = {mu_rv: mu_vv, sigma_rv: sigma_vv, Y_rv: Y_at}
-    model = ModelInfo((Y_rv,), rvs_to_values, (), ())
+    to_sample_rvs = [mu_rv, sigma_rv]
+    rvs_to_values = {mu_rv: mu_vv, sigma_rv: sigma_vv, Y_rv: y_vv}
 
     inverse_mass_matrix = at.as_tensor([1.0, 1.0])
     step_size = at.as_tensor(0.1)
-    state_at, step_fn = nuts(srng, model, inverse_mass_matrix, step_size)
-
-    # Make sure that the state is properly initialized
-    state_fn = aesara.function((mu_vv, sigma_vv, Y_at), state_at)
-    state = state_fn(1.0, 1.0, 1.0)
-
-    position = state[0]
-    assert position[0] == 1.0
-    assert position[1] != 1.0  # The state is in the transformed space
-    logprob = joint_logprob(rvs_to_values)
-    assert state[1] == -1 * logprob.eval({Y_at: 1.0, mu_vv: 1.0, sigma_vv: 1.0})
-
-    # Make sure that the step function updates the state
-    (new_state, params, transformed_params), updates = step_fn(state_at)
-    update_fn = aesara.function(
-        (mu_vv, sigma_vv, Y_at),
-        (
-            params[mu_vv],
-            params[sigma_vv],
-            transformed_params[mu_vv],
-            transformed_params[sigma_vv],
-        ),
-        updates=updates,
+    state_at, step_fn = construct_nuts_sampler(
+        srng, to_sample_rvs, rvs_to_values, inverse_mass_matrix, step_size
     )
-    new_position = update_fn(1.0, 1.0, 1.0)
-    untransformed_position = np.array([new_position[0], new_position[1]])
-    transformed_position = np.array([new_position[2], new_position[3]])
-
-    # Did the chain advance?
-    np.testing.assert_raises(
-        AssertionError, np.testing.assert_equal, untransformed_position, [1.0, 1.0]
-    )
+    # Make sure that the sampling steps compile and run
+    sample_steps = [state_at[rv] for rv in to_sample_rvs]
+    state_fn = aesara.function((mu_vv, sigma_vv, y_vv), sample_steps)
+    new_state = state_fn(1.0, 1.0, 1.0)
 
-    # Are the transformations applied correctly?
-    assert untransformed_position[0] == pytest.approx(transformed_position[0])  # mu
-    assert untransformed_position[1] != transformed_position[1]  # sigma
+    # Make sure that the state has advanced
+    assert new_state[0] != 1.0
+    assert new_state[1] != 1.0
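
The assertions removed above tested values in the transformed (unconstrained) space; after this patch `construct_nuts_sampler` applies `transform_backward` before returning, so callers only ever see values on the original support. A numeric sketch of the round trip, assuming the default transform for a positive-support variable is the log:

    import numpy as np

    sigma = 1.5                   # constrained value, sigma > 0
    zeta = np.log(sigma)          # forward transform: unconstrained space
    assert np.isclose(np.exp(zeta), sigma)  # backward transform recovers it
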
From 3a783e41ba87419f53fc55de303167304064963d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Thu, 13 Oct 2022 15:03:41 +0200
Subject: [PATCH 2/2] Let NUTS initialize its parameters

---
 aemcmc/basic.py    | 11 +++--------
 aemcmc/nuts.py     | 23 ++++++++++++++---------
 tests/test_nuts.py | 20 +++++++++++++-------
 3 files changed, 30 insertions(+), 24 deletions(-)

diff --git a/aemcmc/basic.py b/aemcmc/basic.py
index 2e22447..3c6a966 100644
--- a/aemcmc/basic.py
+++ b/aemcmc/basic.py
@@ -1,6 +1,5 @@
 from typing import Dict, Tuple
 
-import aesara.tensor as at
 from aesara.graph.basic import Variable
 from aesara.graph.fg import FunctionGraph
 from aesara.tensor.random.utils import RandomStream
@@ -113,20 +112,16 @@ def construct_sampler(
     # apply the transforms on the probabilistic graph, in which case we would
     # only need to return the transformed graph.
     if rvs_without_samplers:
-        inverse_mass_matrix = at.vector("inverse_mass_matrix")
-        step_size = at.scalar("step_size")
-        parameters["step_size"] = step_size
-        parameters["inverse_mass_matrix"] = inverse_mass_matrix
-
         # We condition on the updated values of the other rvs
         rvs_to_values = {rv: rvs_to_init_vals[rv] for rv in rvs_without_samplers}
         rvs_to_values.update(posterior_sample_steps)
 
-        nuts_sample_steps, updates = construct_nuts_sampler(
-            srng, rvs_without_samplers, rvs_to_values, inverse_mass_matrix, step_size
+        nuts_sample_steps, updates, nuts_parameters = construct_nuts_sampler(
+            srng, rvs_without_samplers, rvs_to_values
         )
         posterior_sample_steps.update(nuts_sample_steps)
         posterior_updates.update(updates)
+        parameters.update(nuts_parameters)
 
     return (
         {
diff --git a/aemcmc/nuts.py b/aemcmc/nuts.py
index 4ea3a3f..22a608c 100644
--- a/aemcmc/nuts.py
+++ b/aemcmc/nuts.py
@@ -1,6 +1,7 @@
 from typing import Callable, Dict, Tuple
 
 import aesara
+import aesara.tensor as at
 from aehmc import nuts as aehmc_nuts
 from aehmc.utils import RaveledParamsMap
 from aeppl import joint_logprob
@@ -9,6 +10,7 @@
     TransformValuesRewrite,
     _default_transformed_rv,
 )
+from aesara import config
 from aesara.tensor.random import RandomStream
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.var import TensorVariable
@@ -31,9 +33,7 @@ def construct_nuts_sampler(
     srng: RandomStream,
     to_sample_rvs,  # RVs to sample
     rvs_to_values,  # All RVs to values
-    inverse_mass_matrix: TensorVariable,
-    step_size: TensorVariable,
-) -> Tuple[Dict[RandomVariable, TensorVariable], Dict]:
+) -> Tuple[Dict[RandomVariable, TensorVariable], Dict, Dict[str, TensorVariable]]:
     """Build a NUTS kernel and the initial state.
 
     This function currently assumes that we will update the value of all of the
@@ -47,11 +47,6 @@ def construct_nuts_sampler(
     rvs_to_values
         A dictionary that maps all random variables in the model (including
         those not sampled with NUTS) to their value variable.
-    step_size
-        The step size used in the symplectic integrator.
-    inverse_mass_matrix
-        One or two-dimensional array used as the inverse mass matrix that
-        defines the euclidean metric.
 
     Returns
     -------
@@ -104,6 +99,12 @@ def logprob_fn(q):
     initial_q = rp_map.ravel_params(tuple(transformed_vvs.values()))
     initial_state = aehmc_nuts.new_state(initial_q, logprob_fn)
 
+    # Create the symbolic parameters of the NUTS kernel
+    step_size = at.scalar("step_size", dtype=config.floatX)
+    inverse_mass_matrix = at.tensor(
+        name="inverse_mass_matrix", shape=initial_q.type.shape, dtype=config.floatX
+    )
+
     # TODO: Does that lead to wasteful computation? Or is it handled by Aesara?
     (new_q, *_), updates = nuts_kernel(*initial_state, step_size, inverse_mass_matrix)
     transformed_params = rp_map.unravel_params(new_q)
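
Declaring `inverse_mass_matrix` with `shape=initial_q.type.shape` lets the new input inherit whatever static shape information the raveled position carries. A small sketch of the same idiom (names hypothetical):

    import aesara.tensor as at
    from aesara import config

    q = at.vector("q")  # static shape (None,) unless known at build time
    imm = at.tensor(name="imm", shape=q.type.shape, dtype=config.floatX)
    assert imm.type.shape == q.type.shape
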
@@ -113,7 +114,11 @@ def logprob_fn(q):
         if rv in to_sample_rvs
     }
 
-    return params, updates
+    return (
+        params,
+        updates,
+        {"step_size": step_size, "inverse_mass_matrix": inverse_mass_matrix},
+    )
 
 
 def get_transform(rv: TensorVariable):
diff --git a/tests/test_nuts.py b/tests/test_nuts.py
index 175187f..d10e313 100644
--- a/tests/test_nuts.py
+++ b/tests/test_nuts.py
@@ -1,5 +1,4 @@
 import aesara
-import aesara.tensor as at
 from aesara.tensor.random import RandomStream
 
 from aemcmc.nuts import construct_nuts_sampler
@@ -20,16 +19,23 @@ def test_nuts():
     to_sample_rvs = [mu_rv, sigma_rv]
     rvs_to_values = {mu_rv: mu_vv, sigma_rv: sigma_vv, Y_rv: y_vv}
 
-    inverse_mass_matrix = at.as_tensor([1.0, 1.0])
-    step_size = at.as_tensor(0.1)
-    state_at, step_fn = construct_nuts_sampler(
-        srng, to_sample_rvs, rvs_to_values, inverse_mass_matrix, step_size
+    state_at, step_fn, parameters = construct_nuts_sampler(
+        srng, to_sample_rvs, rvs_to_values
     )
 
     # Make sure that the sampling steps compile and run
     sample_steps = [state_at[rv] for rv in to_sample_rvs]
-    state_fn = aesara.function((mu_vv, sigma_vv, y_vv), sample_steps)
-    new_state = state_fn(1.0, 1.0, 1.0)
+    state_fn = aesara.function(
+        (
+            mu_vv,
+            sigma_vv,
+            y_vv,
+            parameters["step_size"],
+            parameters["inverse_mass_matrix"],
+        ),
+        sample_steps,
+    )
+    new_state = state_fn(1.0, 1.0, 1.0, 0.01, [1.0, 1.0])
 
     # Make sure that the state has advanced
     assert new_state[0] != 1.0
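
For reference, a hedged end-to-end sketch of `construct_nuts_sampler` after both patches; the model is hypothetical and the numeric values bound at call time are arbitrary (in practice they would come from a warmup/adaptation phase):

    import aesara
    import numpy as np
    from aesara.tensor.random.utils import RandomStream

    from aemcmc.nuts import construct_nuts_sampler

    srng = RandomStream(0)
    mu_rv = srng.normal(0, 1, name="mu")
    Y_rv = srng.normal(mu_rv, 1.0, name="Y")
    mu_vv = mu_rv.clone()
    y_vv = Y_rv.clone()

    sample_steps, updates, parameters = construct_nuts_sampler(
        srng, [mu_rv], {mu_rv: mu_vv, Y_rv: y_vv}
    )
    step = aesara.function(
        [mu_vv, y_vv, parameters["step_size"], parameters["inverse_mass_matrix"]],
        sample_steps[mu_rv],
        updates=updates,
    )

    mu = 0.0
    for _ in range(100):  # one NUTS transition per call
        mu = step(mu, 0.5, 0.01, np.ones(1))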