Sample remaining variables with the NUTS sampler
rlouf committed Oct 6, 2022
1 parent 47df7f5 commit 87d1aa8
Showing 4 changed files with 79 additions and 92 deletions.
26 changes: 20 additions & 6 deletions aemcmc/basic.py
@@ -1,10 +1,12 @@
from typing import Dict, Tuple

import aesara.tensor as at
from aesara.graph.basic import Variable
from aesara.graph.fg import FunctionGraph
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.var import TensorVariable

from aemcmc.nuts import construct_nuts_sampler
from aemcmc.rewriting import (
SamplerTracker,
construct_ir_fgraph,
@@ -103,14 +105,26 @@ def construct_sampler(
updates = dict(zip(update_keys, update_values))
posterior_updates.update(updates)

# We use the NUTS sampler for the remaining variables
# TODO: NUTS cannot handle RVs with discrete supports
# TODO: Track the transformations made by NUTS? It would make more sense to
# apply the transforms on the probabilistic graph, in which case we would
# only need to return the transformed graph.
if rvs_without_samplers:
# TODO: Assign NUTS to these
raise NotImplementedError(
f"Could not find a posterior samplers for {rvs_without_samplers}"
)

# TODO: Track/handle "auxiliary/augmentation" variables introduced by sample
# steps?
# TODO: Parameters should be contained in a `SamplingStep` data structure.
inverse_mass_matrix = at.vector("inverse_mass_matrix")
step_size = at.scalar("step_size")

# We use the updated values of the other rvs to compute the logprob
rvs_to_values = {rv: rvs_to_init_vals[rv] for rv in rvs_without_samplers}
rvs_to_values.update(posterior_sample_steps)

nuts_sample_steps, updates = construct_nuts_sampler(
srng, rvs_without_samplers, rvs_to_values, inverse_mass_matrix, step_size
)
posterior_sample_steps.update(nuts_sample_steps)
posterior_updates.update(updates)

return (
{
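To make the new control flow concrete, here is a hedged end-to-end sketch (the model is hypothetical; the three return values follow the updated test in tests/test_basic.py below). Where no closed-form or Gibbs sampler is found, `construct_sampler` now hands the leftover variables to `construct_nuts_sampler` instead of raising:

import aesara.tensor as at
from aesara.tensor.random.utils import RandomStream

from aemcmc.basic import construct_sampler

srng = RandomStream(0)

# Hypothetical model: no conjugate update is known for `sigma`, so it falls
# through to the NUTS branch added in this commit
sigma_rv = srng.halfnormal(0.0, 1.0, name="sigma")
Y_rv = srng.normal(0.0, sigma_rv, name="Y")

y_vv = Y_rv.clone()
y_vv.name = "y"

sample_steps, updates, initial_values = construct_sampler({Y_rv: y_vv}, srng)
# `sample_steps[sigma_rv]` is a NUTS sampling step; its extra inputs are the
# `inverse_mass_matrix` and `step_size` variables created in basic.py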
81 changes: 41 additions & 40 deletions aemcmc/nuts.py
@@ -10,10 +10,9 @@
_default_transformed_rv,
)
from aesara.tensor.random import RandomStream
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorVariable

from aemcmc.utils import ModelInfo

NUTSStateType = Tuple[TensorVariable, TensorVariable, TensorVariable]
NUTSKernelType = Callable[
[NUTSStateType],
@@ -28,97 +27,99 @@
]


def nuts(
# As available inputs (basic.py) we have
# - rvs_to_init_vals, a dictionary that maps RVs to corresponding
# value variables;
# We also need:
# - The other variables' values to update the logprob's value
def construct_nuts_sampler(
srng: RandomStream,
model: ModelInfo,
to_sample_rvs, # RVs to sample
rvs_to_values, # All RVs to values
inverse_mass_matrix: TensorVariable,
step_size: TensorVariable,
) -> Tuple[NUTSStateType, NUTSKernelType]:
) -> Tuple[Dict[RandomVariable, TensorVariable], Dict]:
"""Build a NUTS kernel and the initial state.
This function currently assumes that we will update the value of all of the
model's variables with the NUTS sampler.
Parameters
----------
model
The Aesara model whose posterior distribution we wish to sample from
passed as a `ModelInfo` instance.
rvs_to_samples
A sequence that contains the random variables whose posterior
distribution we wish to sample from.
rvs_to_values
A dictionary that maps all random variables in the model (including
those not sampled with NUTS) to their value variable.
step_size
The step size used in the symplectic integrator.
inverse_mass_matrix
One or two-dimensional array used as the inverse mass matrix that
defines the euclidean metric.
Returns
-------
A NUTS sampling step for each variable.
"""
unobserved_rvs = tuple(
rv for rv in model.rvs_to_values.keys() if rv not in model.observed_rvs
)
unobserved_rvs_to_values = {rv: model.rvs_to_values[rv] for rv in unobserved_rvs}
observed_vvs = tuple(model.rvs_to_values[rv] for rv in model.observed_rvs)

# Algorithms in the HMC family can more easily explore the posterior distribution
# when the support of each random variable's distribution is unconstrained.
# First we build the logprob graph in the transformed space.
transforms = {
vv: get_transform(rv) if rv in unobserved_rvs else None
for rv, vv in model.rvs_to_values.items()
vv: get_transform(rv) for rv, vv in rvs_to_values.items() if rv in to_sample_rvs
}

logprob_sum = joint_logprob(
model.rvs_to_values, extra_rewrites=TransformValuesRewrite(transforms)
rvs_to_values, extra_rewrites=TransformValuesRewrite(transforms)
)

# Then we transform the value variables.
transformed_vvs = {
vv: transform_forward(rv, vv, transforms[vv])
for rv, vv in unobserved_rvs_to_values.items()
for rv, vv in rvs_to_values.items()
if rv in to_sample_rvs
}

# Algorithms in `aehmc` work with flat arrays and we need to ravel parameter
# values to use them as an input to the NUTS kernel.
rp_map = RaveledParamsMap(tuple(transformed_vvs.values()))
rp_map.ref_params = tuple(transformed_vvs.keys())

# Map the value variables of the terms not sampled by NUTS to themselves,
# so that they remain inputs of the logprob graph
non_nuts_vals = {
vv: vv for rv, vv in rvs_to_values.items() if rv not in to_sample_rvs
}

# We can now write the logprob function associated with the model and build
# the NUTS kernel.
def logprob_fn(q):
unraveled_q = rp_map.unravel_params(q)
unraveled_q.update({vv: vv for vv in observed_vvs})
unraveled_q.update(non_nuts_vals)

memo = aesara.graph.basic.clone_get_equiv(
[], [logprob_sum], copy_inputs=False, copy_orphans=False, memo=unraveled_q
)

return memo[logprob_sum]

# Finally we build the NUTS sampling step
nuts_kernel = aehmc_nuts.new_kernel(srng, logprob_fn)

def step_fn(state):
"""Take one step with the NUTS kernel.
The NUTS kernel works with a state that contains the current value of
the variables, but also the current value of the potential and its
gradient, and we need to carry this state forward. We also return the
unraveled parameter values in both the original and transformed space.
"""
(new_q, new_pe, new_peg, *_), updates = nuts_kernel(
*state, step_size, inverse_mass_matrix
)
new_state = (new_q, new_pe, new_peg)
transformed_params = rp_map.unravel_params(new_q)
params = {
vv: transform_backward(rv, transformed_params[vv], transforms[vv])
for rv, vv in unobserved_rvs_to_values.items()
}
return (new_state, params, transformed_params), updates

# Finally we build the initial state
initial_q = rp_map.ravel_params(tuple(transformed_vvs.values()))
initial_state = aehmc_nuts.new_state(initial_q, logprob_fn)

return initial_state, step_fn
# TODO: Does that lead to wasteful computation? Or is it handled by Aesara?
(new_q, *_), updates = nuts_kernel(*initial_state, step_size, inverse_mass_matrix)
transformed_params = rp_map.unravel_params(new_q)
params = {
rv: transform_backward(rv, transformed_params[vv], transforms[vv])
for rv, vv in rvs_to_values.items()
if rv in to_sample_rvs
}

return params, updates


def get_transform(rv: TensorVariable):
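Three of the ideas in this file deserve a quick illustration. First, the comment about HMC and unconstrained supports: NUTS proposes in an unconstrained space, so a positive variable such as a scale is sampled through q = log(sigma), and the log-density picks up a Jacobian term. A minimal NumPy/SciPy sketch of the idea (an analogue of the concept, not the aeppl API):

import numpy as np
from scipy import stats

def logp_sigma(sigma):
    # Half-normal log-density, defined only for sigma > 0
    return stats.halfnorm.logpdf(sigma)

def logp_q(q):
    # After the change of variables sigma = exp(q), valid for any real q:
    # log p(q) = log p(exp(q)) + log|d sigma / d q| = logp_sigma(exp(q)) + q
    return logp_sigma(np.exp(q)) + q

q = -3.5           # any real-valued proposal is valid...
sigma = np.exp(q)  # ...and maps back into the support, like transform_backward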
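Second, the raveling step: `aehmc` kernels see a single flat position vector, which is what `RaveledParamsMap` provides. A self-contained sketch of the round trip it performs (again an analogue, not the class itself):

import numpy as np

params = {"mu": np.array(0.3), "tau": np.array([1.0, 2.0])}

# Ravel: concatenate the flattened parameter values into one vector
q = np.concatenate([np.atleast_1d(v).ravel() for v in params.values()])

# Unravel: split the vector back according to the remembered shapes
sizes = [np.atleast_1d(v).size for v in params.values()]
pieces = np.split(q, np.cumsum(sizes)[:-1])
unraveled = {
    name: piece.reshape(np.shape(value))
    for (name, value), piece in zip(params.items(), pieces)
}
assert unraveled["tau"].shape == (2,)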
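Third, the memo trick inside `logprob_fn`: `clone_get_equiv` clones the logprob graph while the pre-populated memo redirects chosen inputs to new sub-graphs, which is how the value variables get replaced by slices of the flat position `q`. A small runnable example of the same call, with a toy graph in place of the logprob:

import aesara
import aesara.tensor as at
from aesara.graph.basic import clone_get_equiv

x = at.scalar("x")
y = at.scalar("y")
out = x**2 + y

# Pre-populating the memo makes the clone of `out` use `2 * q` wherever
# the original graph used `x`
q = at.scalar("q")
memo = clone_get_equiv(
    [], [out], copy_inputs=False, copy_orphans=False, memo={x: 2 * q}
)

f = aesara.function([q, y], memo[out])
assert f(1.0, 1.0) == 5.0  # (2 * 1)**2 + 1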
9 changes: 5 additions & 4 deletions tests/test_basic.py
@@ -1,7 +1,6 @@
import aesara
import aesara.tensor as at
import numpy as np
import pytest
from aesara.graph.basic import graph_inputs, io_toposort
from aesara.ifelse import IfElse
from aesara.tensor.random import RandomStream
@@ -49,7 +48,7 @@ def test_closed_form_posterior_gamma_poisson():
assert isinstance(p_posterior_step.owner.op, GammaRV)


def test_no_samplers():
def test_nuts_sampler_single_variable():
srng = RandomStream(0)

size = at.lscalar("size")
@@ -59,8 +58,10 @@ def test_no_samplers():
y_vv = Y_rv.clone()
y_vv.name = "y"

with pytest.raises(NotImplementedError):
construct_sampler({Y_rv: y_vv}, srng)
sample_steps, updates, initial_values = construct_sampler({Y_rv: y_vv}, srng)
# TODO: Test something here
# TODO: Add test (or replace this one) for a multi-variable NUTS sampler
# TODO: Add test for mixed samplers


def test_create_gibbs():
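One hedged shape for the TODOs above (a suggestion only, not part of the commit; that `initial_values` is keyed by the same random variables as `sample_steps` is my assumption):

assert Y_rv not in sample_steps  # the observed variable gets no sampling step
assert all(rv in initial_values for rv in sample_steps)  # assumed key structure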
55 changes: 13 additions & 42 deletions tests/test_nuts.py
@@ -1,12 +1,8 @@
import aesara
import aesara.tensor as at
import numpy as np
import pytest
from aeppl import joint_logprob
from aesara.tensor.random import RandomStream

from aemcmc.nuts import nuts
from aemcmc.utils import ModelInfo
from aemcmc.nuts import construct_nuts_sampler


def test_nuts():
@@ -19,47 +15,22 @@ def test_nuts():
mu_vv.name = "mu_vv"
sigma_vv = sigma_rv.clone()
sigma_vv.name = "sigma_vv"
Y_at = at.scalar(name="Y_at")
y_vv = Y_rv.clone()

rvs_to_values = {mu_rv: mu_vv, sigma_rv: sigma_vv, Y_rv: Y_at}
model = ModelInfo((Y_rv,), rvs_to_values, (), ())
to_sample_rvs = [mu_rv, sigma_rv]
rvs_to_values = {mu_rv: mu_vv, sigma_rv: sigma_vv, Y_rv: y_vv}

inverse_mass_matrix = at.as_tensor([1.0, 1.0])
step_size = at.as_tensor(0.1)
state_at, step_fn = nuts(srng, model, inverse_mass_matrix, step_size)

# Make sure that the state is properly initialized
state_fn = aesara.function((mu_vv, sigma_vv, Y_at), state_at)
state = state_fn(1.0, 1.0, 1.0)

position = state[0]
assert position[0] == 1.0
assert position[1] != 1.0 # The state is in the transformed space
logprob = joint_logprob(rvs_to_values)
assert state[1] == -1 * logprob.eval({Y_at: 1.0, mu_vv: 1.0, sigma_vv: 1.0})

# Make sure that the step function updates the state
(new_state, params, transformed_params), updates = step_fn(state_at)
update_fn = aesara.function(
(mu_vv, sigma_vv, Y_at),
(
params[mu_vv],
params[sigma_vv],
transformed_params[mu_vv],
transformed_params[sigma_vv],
),
updates=updates,
sample_steps, updates = construct_nuts_sampler(
srng, to_sample_rvs, rvs_to_values, inverse_mass_matrix, step_size
)

new_position = update_fn(1.0, 1.0, 1.0)
untransformed_position = np.array([new_position[0], new_position[1]])
transformed_position = np.array([new_position[2], new_position[3]])

# Did the chain advance?
np.testing.assert_raises(
AssertionError, np.testing.assert_equal, untransformed_position, [1.0, 1.0]
)
# Make sure that the sampling steps are properly initialized
sample_steps_at = [sample_steps[rv] for rv in to_sample_rvs]
state_fn = aesara.function((mu_vv, sigma_vv, y_vv), sample_steps_at, updates=updates)
new_state = state_fn(1.0, 1.0, 1.0)

# Are the transformations applied correctly?
assert untransformed_position[0] == pytest.approx(transformed_position[0]) # mu
assert untransformed_position[1] != transformed_position[1] # sigma
# Make sure that the state has advanced
assert new_state[0] != 1.0
assert new_state[1] != 1.0
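
As a usage note, reusing the names from the test above (a sketch, not part of the commit): each call to `state_fn` performs one NUTS transition, so a chain is drawn by feeding the outputs back in; compiling with `updates=updates` is what lets the kernel's internal RNG state advance between calls.

mu, sigma = 1.0, 1.0
chain = []
for _ in range(100):
    mu, sigma = state_fn(mu, sigma, 1.0)  # the observed value y is held fixed
    chain.append((mu, sigma))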
