diff --git a/aemcmc/basic.py b/aemcmc/basic.py index 7207e7d..a40da72 100644 --- a/aemcmc/basic.py +++ b/aemcmc/basic.py @@ -1,10 +1,12 @@ from typing import Dict, Tuple +import aesara.tensor as at from aesara.graph.basic import Variable from aesara.graph.fg import FunctionGraph from aesara.tensor.random.utils import RandomStream from aesara.tensor.var import TensorVariable +from aemcmc.nuts import construct_nuts_sampler from aemcmc.rewriting import ( SamplerTracker, construct_ir_fgraph, @@ -103,14 +105,26 @@ def construct_sampler( updates = dict(zip(update_keys, update_values)) posterior_updates.update(updates) + # We use the NUTS sampler for the remaining variables + # TODO: NUTS cannot handle RV with discrete supports + # TODO: Track the transformations made by NUTS? It would make more sense to + # apply the transforms on the probabilistic graph, in which case we would + # only need to return the transformed graph. if rvs_without_samplers: - # TODO: Assign NUTS to these - raise NotImplementedError( - f"Could not find a posterior samplers for {rvs_without_samplers}" - ) - # TODO: Track/handle "auxiliary/augmentation" variables introduced by sample - # steps? + # TODO: Parameters should be contained in a `SamplingStep` data structure. 
+ inverse_mass_matrix = at.vector("inverse_mass_matrix") + step_size = at.scalar("step_size") + + # We use the updated values of the other rvs to compute the logprob + rvs_to_values = {rv: rvs_to_init_vals[rv] for rv in rvs_without_samplers} + rvs_to_values.update(posterior_sample_steps) + + nuts_sample_steps, updates = construct_nuts_sampler( + srng, rvs_without_samplers, rvs_to_values, inverse_mass_matrix, step_size + ) + posterior_sample_steps.update(nuts_sample_steps) + posterior_updates.update(updates) return ( { diff --git a/aemcmc/nuts.py b/aemcmc/nuts.py index 6ca7d82..bf611d2 100644 --- a/aemcmc/nuts.py +++ b/aemcmc/nuts.py @@ -10,10 +10,9 @@ _default_transformed_rv, ) from aesara.tensor.random import RandomStream +from aesara.tensor.random.op import RandomVariable from aesara.tensor.var import TensorVariable -from aemcmc.utils import ModelInfo - NUTSStateType = Tuple[TensorVariable, TensorVariable, TensorVariable] NUTSKernelType = Callable[ [NUTSStateType], @@ -28,12 +27,18 @@ ] -def nuts( +# As available inputs (basic.py) we have +# - rvs_to_init_vals, a dictionary that maps RVs to corresponding +# value variables; +# We also need: +# - The other variables' values to update the logprob's value +def construct_nuts_sampler( srng: RandomStream, - model: ModelInfo, + to_sample_rvs, # RVs to sample + rvs_to_values, # All RVs to values inverse_mass_matrix: TensorVariable, step_size: TensorVariable, -) -> Tuple[NUTSStateType, NUTSKernelType]: +) -> Tuple[Dict[RandomVariable, TensorVariable], Dict]: """Build a NUTS kernel and the initial state. This function currently assumes that we will update the value of all of the @@ -41,38 +46,40 @@ def nuts( Parameters ---------- - model - The Aesara model whose posterior distribution we wish to sample from - passed as a `ModelInfo` instance. + to_sample_rvs + A sequence that contains the random variables whose posterior + distribution we wish to sample from. 
+ rvs_to_values + A dictionary that maps all random variables in the model (including + those not sampled with NUTS) to their value variable. step_size The step size used in the symplectic integrator. inverse_mass_matrix One or two-dimensional array used as the inverse mass matrix that defines the euclidean metric. + Returns + ------- + A NUTS sampling step for each variable. + """ - unobserved_rvs = tuple( - rv for rv in model.rvs_to_values.keys() if rv not in model.observed_rvs - ) - unobserved_rvs_to_values = {rv: model.rvs_to_values[rv] for rv in unobserved_rvs} - observed_vvs = tuple(model.rvs_to_values[rv] for rv in model.observed_rvs) # Algorithms in the HMC family can more easily explore the posterior distribution # when the support of each random variable's distribution is unconstrained. # First we build the logprob graph in the transformed space. transforms = { - vv: get_transform(rv) if rv in unobserved_rvs else None - for rv, vv in model.rvs_to_values.items() + vv: get_transform(rv) for rv, vv in rvs_to_values.items() if rv in to_sample_rvs } logprob_sum = joint_logprob( - model.rvs_to_values, extra_rewrites=TransformValuesRewrite(transforms) + rvs_to_values, extra_rewrites=TransformValuesRewrite(transforms) ) # Then we transform the value variables. transformed_vvs = { vv: transform_forward(rv, vv, transforms[vv]) - for rv, vv in unobserved_rvs_to_values.items() + for rv, vv in rvs_to_values.items() + if rv in to_sample_rvs } # Algorithms in `aehmc` work with flat arrays and we need to ravel parameter @@ -80,11 +87,16 @@ def nuts( rp_map = RaveledParamsMap(tuple(transformed_vvs.values())) rp_map.ref_params = tuple(transformed_vvs.keys()) + # Make shared variables for all the non-NUTS sampled terms + non_nuts_vals = { + vv: vv for rv, vv in rvs_to_values.items() if rv not in to_sample_rvs + } + # We can now write the logprob function associated with the model and build # the NUTS kernel. 
def logprob_fn(q): unraveled_q = rp_map.unravel_params(q) - unraveled_q.update({vv: vv for vv in observed_vvs}) + unraveled_q.update(non_nuts_vals) memo = aesara.graph.basic.clone_get_equiv( [], [logprob_sum], copy_inputs=False, copy_orphans=False, memo=unraveled_q @@ -92,33 +104,22 @@ def logprob_fn(q): return memo[logprob_sum] + # Finally we build the NUTS sampling step nuts_kernel = aehmc_nuts.new_kernel(srng, logprob_fn) - def step_fn(state): - """Take one step with the NUTS kernel. - - The NUTS kernel works with a state that contains the current value of - the variables, but also the current value of the potential and its - gradient, and we need to carry this state forward. We also return the - unraveled parameter values in both the original and transformed space. - - """ - (new_q, new_pe, new_peg, *_), updates = nuts_kernel( - *state, step_size, inverse_mass_matrix - ) - new_state = (new_q, new_pe, new_peg) - transformed_params = rp_map.unravel_params(new_q) - params = { - vv: transform_backward(rv, transformed_params[vv], transforms[vv]) - for rv, vv in unobserved_rvs_to_values.items() - } - return (new_state, params, transformed_params), updates - - # Finally we build the initial state initial_q = rp_map.ravel_params(tuple(transformed_vvs.values())) initial_state = aehmc_nuts.new_state(initial_q, logprob_fn) - return initial_state, step_fn + # TODO: Does that lead to wasteful computation? Or is it handled by Aesara? 
+ (new_q, *_), updates = nuts_kernel(*initial_state, step_size, inverse_mass_matrix) + transformed_params = rp_map.unravel_params(new_q) + params = { + rv: transform_backward(rv, transformed_params[vv], transforms[vv]) + for rv, vv in rvs_to_values.items() + if rv in to_sample_rvs + } + + return params, updates def get_transform(rv: TensorVariable): diff --git a/tests/test_basic.py b/tests/test_basic.py index b536bca..979cafb 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,7 +1,6 @@ import aesara import aesara.tensor as at import numpy as np -import pytest from aesara.graph.basic import graph_inputs, io_toposort from aesara.ifelse import IfElse from aesara.tensor.random import RandomStream @@ -49,7 +48,7 @@ def test_closed_form_posterior_gamma_poisson(): assert isinstance(p_posterior_step.owner.op, GammaRV) -def test_no_samplers(): +def test_nuts_sampler_single_variable(): srng = RandomStream(0) size = at.lscalar("size") @@ -59,8 +58,10 @@ def test_no_samplers(): y_vv = Y_rv.clone() y_vv.name = "y" - with pytest.raises(NotImplementedError): - construct_sampler({Y_rv: y_vv}, srng) + sample_steps, updates, initial_values = construct_sampler({Y_rv: y_vv}, srng) + # TODO: Test something here + # TODO: Add test (or replace this one) for a multi-variable NUTS sampler + # TODO: Add test for mixed samplers def test_create_gibbs(): diff --git a/tests/test_nuts.py b/tests/test_nuts.py index 75d0921..175187f 100644 --- a/tests/test_nuts.py +++ b/tests/test_nuts.py @@ -1,12 +1,8 @@ import aesara import aesara.tensor as at -import numpy as np -import pytest -from aeppl import joint_logprob from aesara.tensor.random import RandomStream -from aemcmc.nuts import nuts -from aemcmc.utils import ModelInfo +from aemcmc.nuts import construct_nuts_sampler def test_nuts(): @@ -19,47 +15,22 @@ def test_nuts(): mu_vv.name = "mu_vv" sigma_vv = sigma_rv.clone() sigma_vv.name = "sigma_vv" - Y_at = at.scalar(name="Y_at") + y_vv = Y_rv.clone() - rvs_to_values = {mu_rv: mu_vv, 
sigma_rv: sigma_vv, Y_rv: Y_at} - model = ModelInfo((Y_rv,), rvs_to_values, (), ()) + to_sample_rvs = [mu_rv, sigma_rv] + rvs_to_values = {mu_rv: mu_vv, sigma_rv: sigma_vv, Y_rv: y_vv} inverse_mass_matrix = at.as_tensor([1.0, 1.0]) step_size = at.as_tensor(0.1) - state_at, step_fn = nuts(srng, model, inverse_mass_matrix, step_size) - - # Make sure that the state is properly initialized - state_fn = aesara.function((mu_vv, sigma_vv, Y_at), state_at) - state = state_fn(1.0, 1.0, 1.0) - - position = state[0] - assert position[0] == 1.0 - assert position[1] != 1.0 # The state is in the transformed space - logprob = joint_logprob(rvs_to_values) - assert state[1] == -1 * logprob.eval({Y_at: 1.0, mu_vv: 1.0, sigma_vv: 1.0}) - - # Make sure that the step function updates the state - (new_state, params, transformed_params), updates = step_fn(state_at) - update_fn = aesara.function( - (mu_vv, sigma_vv, Y_at), - ( - params[mu_vv], - params[sigma_vv], - transformed_params[mu_vv], - transformed_params[sigma_vv], - ), - updates=updates, + state_at, step_fn = construct_nuts_sampler( + srng, to_sample_rvs, rvs_to_values, inverse_mass_matrix, step_size ) - new_position = update_fn(1.0, 1.0, 1.0) - untransformed_position = np.array([new_position[0], new_position[1]]) - transformed_position = np.array([new_position[2], new_position[3]]) - - # Did the chain advance? - np.testing.assert_raises( - AssertionError, np.testing.assert_equal, untransformed_position, [1.0, 1.0] - ) + # Make sure that the state is properly initialized + sample_steps = [state_at[rv] for rv in to_sample_rvs] + state_fn = aesara.function((mu_vv, sigma_vv, y_vv), sample_steps) + new_state = state_fn(1.0, 1.0, 1.0) - # Are the transformations applied correctly? - assert untransformed_position[0] == pytest.approx(transformed_position[0]) # mu - assert untransformed_position[1] != transformed_position[1] # sigma + # Make sure that the state has advanced + assert new_state[0] != 1.0 + assert new_state[1] != 1.0