Skip to content

Commit

Permalink
Let NUTS initialize its parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Oct 18, 2022
1 parent 16484dd commit 3a783e4
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 24 deletions.
11 changes: 3 additions & 8 deletions aemcmc/basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, Tuple

import aesara.tensor as at
from aesara.graph.basic import Variable
from aesara.graph.fg import FunctionGraph
from aesara.tensor.random.utils import RandomStream
Expand Down Expand Up @@ -113,20 +112,16 @@ def construct_sampler(
# apply the transforms on the probabilistic graph, in which case we would
# only need to return the transformed graph.
if rvs_without_samplers:
inverse_mass_matrix = at.vector("inverse_mass_matrix")
step_size = at.scalar("step_size")
parameters["step_size"] = step_size
parameters["inverse_mass_matrix"] = inverse_mass_matrix

# We condition on the updated values of the other rvs
rvs_to_values = {rv: rvs_to_init_vals[rv] for rv in rvs_without_samplers}
rvs_to_values.update(posterior_sample_steps)

nuts_sample_steps, updates = construct_nuts_sampler(
srng, rvs_without_samplers, rvs_to_values, inverse_mass_matrix, step_size
nuts_sample_steps, updates, nuts_parameters = construct_nuts_sampler(
srng, rvs_without_samplers, rvs_to_values
)
posterior_sample_steps.update(nuts_sample_steps)
posterior_updates.update(updates)
parameters.update(nuts_parameters)

return (
{
Expand Down
23 changes: 14 additions & 9 deletions aemcmc/nuts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Callable, Dict, Tuple

import aesara
import aesara.tensor as at
from aehmc import nuts as aehmc_nuts
from aehmc.utils import RaveledParamsMap
from aeppl import joint_logprob
Expand All @@ -9,6 +10,7 @@
TransformValuesRewrite,
_default_transformed_rv,
)
from aesara import config
from aesara.tensor.random import RandomStream
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorVariable
Expand All @@ -31,9 +33,7 @@ def construct_nuts_sampler(
srng: RandomStream,
to_sample_rvs, # RVs to sample
rvs_to_values, # All RVs to values
inverse_mass_matrix: TensorVariable,
step_size: TensorVariable,
) -> Tuple[Dict[RandomVariable, TensorVariable], Dict]:
) -> Tuple[Dict[RandomVariable, TensorVariable], Dict, Dict[str, TensorVariable]]:
"""Build a NUTS kernel and the initial state.
This function currently assumes that we will update the value of all of the
Expand All @@ -47,11 +47,6 @@ def construct_nuts_sampler(
rvs_to_values
A dictionary that maps all random variables in the model (including
those not sampled with NUTS) to their value variable.
step_size
The step size used in the symplectic integrator.
inverse_mass_matrix
One or two-dimensional array used as the inverse mass matrix that
defines the euclidean metric.
Returns
-------
Expand Down Expand Up @@ -104,6 +99,12 @@ def logprob_fn(q):
initial_q = rp_map.ravel_params(tuple(transformed_vvs.values()))
initial_state = aehmc_nuts.new_state(initial_q, logprob_fn)

# Initialize the parameter values
step_size = at.scalar("step_size", dtype=config.floatX)
inverse_mass_matrix = at.tensor(
name="inverse_mass_matrix", shape=initial_q.type.shape, dtype=config.floatX
)

# TODO: Does that lead to wasteful computation? Or is it handled by Aesara?
(new_q, *_), updates = nuts_kernel(*initial_state, step_size, inverse_mass_matrix)
transformed_params = rp_map.unravel_params(new_q)
Expand All @@ -113,7 +114,11 @@ def logprob_fn(q):
if rv in to_sample_rvs
}

return params, updates
return (
params,
updates,
{"step_size": step_size, "inverse_mass_matrix": inverse_mass_matrix},
)


def get_transform(rv: TensorVariable):
Expand Down
20 changes: 13 additions & 7 deletions tests/test_nuts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import aesara
import aesara.tensor as at
from aesara.tensor.random import RandomStream

from aemcmc.nuts import construct_nuts_sampler
Expand All @@ -20,16 +19,23 @@ def test_nuts():
to_sample_rvs = [mu_rv, sigma_rv]
rvs_to_values = {mu_rv: mu_vv, sigma_rv: sigma_vv, Y_rv: y_vv}

inverse_mass_matrix = at.as_tensor([1.0, 1.0])
step_size = at.as_tensor(0.1)
state_at, step_fn = construct_nuts_sampler(
srng, to_sample_rvs, rvs_to_values, inverse_mass_matrix, step_size
state_at, step_fn, parameters = construct_nuts_sampler(
srng, to_sample_rvs, rvs_to_values
)

# Make sure that the state is properly initialized
sample_steps = [state_at[rv] for rv in to_sample_rvs]
state_fn = aesara.function((mu_vv, sigma_vv, y_vv), sample_steps)
new_state = state_fn(1.0, 1.0, 1.0)
state_fn = aesara.function(
(
mu_vv,
sigma_vv,
y_vv,
parameters["step_size"],
parameters["inverse_mass_matrix"],
),
sample_steps,
)
new_state = state_fn(1.0, 1.0, 1.0, 0.01, [1.0, 1.0])

# Make sure that the state has advanced
assert new_state[0] != 1.0
Expand Down

0 comments on commit 3a783e4

Please sign in to comment.