diff --git a/aemcmc/basic.py b/aemcmc/basic.py index 2e22447..3c6a966 100644 --- a/aemcmc/basic.py +++ b/aemcmc/basic.py @@ -1,6 +1,5 @@ from typing import Dict, Tuple -import aesara.tensor as at from aesara.graph.basic import Variable from aesara.graph.fg import FunctionGraph from aesara.tensor.random.utils import RandomStream @@ -113,20 +112,16 @@ def construct_sampler( # apply the transforms on the probabilistic graph, in which case we would # only need to return the transformed graph. if rvs_without_samplers: - inverse_mass_matrix = at.vector("inverse_mass_matrix") - step_size = at.scalar("step_size") - parameters["step_size"] = step_size - parameters["inverse_mass_matrix"] = inverse_mass_matrix - # We condition on the updated values of the other rvs rvs_to_values = {rv: rvs_to_init_vals[rv] for rv in rvs_without_samplers} rvs_to_values.update(posterior_sample_steps) - nuts_sample_steps, updates = construct_nuts_sampler( - srng, rvs_without_samplers, rvs_to_values, inverse_mass_matrix, step_size + nuts_sample_steps, updates, nuts_parameters = construct_nuts_sampler( + srng, rvs_without_samplers, rvs_to_values ) posterior_sample_steps.update(nuts_sample_steps) posterior_updates.update(updates) + parameters.update(nuts_parameters) return ( { diff --git a/aemcmc/nuts.py b/aemcmc/nuts.py index 4ea3a3f..22a608c 100644 --- a/aemcmc/nuts.py +++ b/aemcmc/nuts.py @@ -1,6 +1,7 @@ from typing import Callable, Dict, Tuple import aesara +import aesara.tensor as at from aehmc import nuts as aehmc_nuts from aehmc.utils import RaveledParamsMap from aeppl import joint_logprob @@ -9,6 +10,7 @@ TransformValuesRewrite, _default_transformed_rv, ) +from aesara import config from aesara.tensor.random import RandomStream from aesara.tensor.random.op import RandomVariable from aesara.tensor.var import TensorVariable @@ -31,9 +33,7 @@ def construct_nuts_sampler( srng: RandomStream, to_sample_rvs, # RVs to sample rvs_to_values, # All RVs to values - inverse_mass_matrix: TensorVariable, - step_size: TensorVariable, -) -> Tuple[Dict[RandomVariable, TensorVariable], Dict]: +) -> Tuple[Dict[RandomVariable, TensorVariable], Dict, Dict[str, TensorVariable]]: """Build a NUTS kernel and the initial state. This function currently assumes that we will update the value of all of the @@ -47,11 +47,6 @@ def construct_nuts_sampler( rvs_to_values A dictionary that maps all random variables in the model (including those not sampled with NUTS) to their value variable. - step_size - The step size used in the symplectic integrator. - inverse_mass_matrix - One or two-dimensional array used as the inverse mass matrix that - defines the euclidean metric. Returns ------- @@ -104,6 +99,12 @@ def logprob_fn(q): initial_q = rp_map.ravel_params(tuple(transformed_vvs.values())) initial_state = aehmc_nuts.new_state(initial_q, logprob_fn) + # Initialize the parameter values + step_size = at.scalar("step_size", dtype=config.floatX) + inverse_mass_matrix = at.tensor( + name="inverse_mass_matrix", shape=initial_q.type.shape, dtype=config.floatX + ) + # TODO: Does that lead to wasteful computation? Or is it handled by Aesara? (new_q, *_), updates = nuts_kernel(*initial_state, step_size, inverse_mass_matrix) transformed_params = rp_map.unravel_params(new_q) @@ -113,7 +114,11 @@ def logprob_fn(q): if rv in to_sample_rvs } - return params, updates + return ( + params, + updates, + {"step_size": step_size, "inverse_mass_matrix": inverse_mass_matrix}, + ) def get_transform(rv: TensorVariable): diff --git a/tests/test_nuts.py b/tests/test_nuts.py index 175187f..d10e313 100644 --- a/tests/test_nuts.py +++ b/tests/test_nuts.py @@ -1,5 +1,4 @@ import aesara -import aesara.tensor as at from aesara.tensor.random import RandomStream from aemcmc.nuts import construct_nuts_sampler @@ -20,16 +19,23 @@ def test_nuts(): to_sample_rvs = [mu_rv, sigma_rv] rvs_to_values = {mu_rv: mu_vv, sigma_rv: sigma_vv, Y_rv: y_vv} - inverse_mass_matrix = at.as_tensor([1.0, 1.0]) - step_size = at.as_tensor(0.1) - state_at, step_fn = construct_nuts_sampler( - srng, to_sample_rvs, rvs_to_values, inverse_mass_matrix, step_size + state_at, step_fn, parameters = construct_nuts_sampler( + srng, to_sample_rvs, rvs_to_values ) # Make sure that the state is properly initialized sample_steps = [state_at[rv] for rv in to_sample_rvs] - state_fn = aesara.function((mu_vv, sigma_vv, y_vv), sample_steps) - new_state = state_fn(1.0, 1.0, 1.0) + state_fn = aesara.function( + ( + mu_vv, + sigma_vv, + y_vv, + parameters["step_size"], + parameters["inverse_mass_matrix"], + ), + sample_steps, + ) + new_state = state_fn(1.0, 1.0, 1.0, 0.01, [1.0, 1.0]) # Make sure that the state has advanced assert new_state[0] != 1.0