Sample remaining variables with the NUTS sampler #68

Merged · 2 commits · Oct 18, 2022
24 changes: 18 additions & 6 deletions aemcmc/basic.py
@@ -5,6 +5,7 @@
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.var import TensorVariable

from aemcmc.nuts import construct_nuts_sampler
from aemcmc.rewriting import (
SamplerTracker,
construct_ir_fgraph,
@@ -19,6 +20,7 @@ def construct_sampler(
Dict[TensorVariable, TensorVariable],
Dict[Variable, Variable],
Dict[TensorVariable, TensorVariable],
Dict[str, TensorVariable],
]:
r"""Eagerly construct a sampler for a given set of observed variables and their observations.

@@ -55,6 +57,7 @@ def construct_sampler(
# TODO FIXME: Get/extract `Scan`-generated updates
posterior_updates: Dict[Variable, Variable] = {}

parameters: Dict[str, TensorVariable] = {}
rvs_without_samplers = set()

for rv in fgraph.outputs:
@@ -103,14 +106,22 @@
updates = dict(zip(update_keys, update_values))
posterior_updates.update(updates)

# We use the NUTS sampler for the remaining variables
# TODO: NUTS cannot handle RV with discrete supports
# TODO: Track the transformations made by NUTS? It would make more sense to
# apply the transforms on the probabilistic graph, in which case we would
# only need to return the transformed graph.
if rvs_without_samplers:
# TODO: Assign NUTS to these
raise NotImplementedError(
f"Could not find posterior samplers for {rvs_without_samplers}"
)
# We condition on the updated values of the other rvs
rvs_to_values = {rv: rvs_to_init_vals[rv] for rv in rvs_without_samplers}
rvs_to_values.update(posterior_sample_steps)

# TODO: Track/handle "auxiliary/augmentation" variables introduced by sample
# steps?
nuts_sample_steps, updates, nuts_parameters = construct_nuts_sampler(
srng, rvs_without_samplers, rvs_to_values
)
posterior_sample_steps.update(nuts_sample_steps)
posterior_updates.update(updates)
parameters.update(nuts_parameters)

return (
{
@@ -120,4 +131,5 @@
},
posterior_updates,
{new_to_old_rvs[rv]: init_var for rv, init_var in rvs_to_init_vals.items()},
parameters,
)
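With this change, `construct_sampler` assigns a NUTS step to every random variable that has no closed-form posterior sampler instead of raising `NotImplementedError`, and it returns a fourth element, `parameters`, which exposes the NUTS `step_size` and `inverse_mass_matrix` as symbolic inputs. A minimal usage sketch, mirroring the tests added in this PR (the model below and the import path are illustrative):

import aesara
import numpy as np
from aesara.tensor.random.utils import RandomStream

from aemcmc.basic import construct_sampler

srng = RandomStream(0)

# No closed-form conditional posterior is known for `tau`, so it is now
# assigned to the NUTS sampler.
tau_rv = srng.halfcauchy(0, 1, size=(3,), name="tau")
Y_rv = srng.halfcauchy(0, tau_rv, name="Y")

y_vv = Y_rv.clone()
y_vv.name = "y"

sample_steps, updates, initial_values, parameters = construct_sampler(
    {Y_rv: y_vv}, srng
)

# `sample_steps[tau_rv]` is the next NUTS draw of `tau`; the sampler's
# parameters are ordinary inputs of the compiled step function.
sample_step = aesara.function(
    [
        initial_values[tau_rv],
        y_vv,
        parameters["step_size"],
        parameters["inverse_mass_matrix"],
    ],
    sample_steps[tau_rv],
)

new_tau = sample_step(np.ones(3), np.ones(3), 1e-1, np.ones(3))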
96 changes: 48 additions & 48 deletions aemcmc/nuts.py
@@ -1,6 +1,7 @@
from typing import Callable, Dict, Tuple

import aesara
import aesara.tensor as at
from aehmc import nuts as aehmc_nuts
from aehmc.utils import RaveledParamsMap
from aeppl import joint_logprob
@@ -9,11 +10,11 @@
TransformValuesRewrite,
_default_transformed_rv,
)
from aesara import config
from aesara.tensor.random import RandomStream
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorVariable

from aemcmc.utils import ModelInfo

NUTSStateType = Tuple[TensorVariable, TensorVariable, TensorVariable]
NUTSKernelType = Callable[
[NUTSStateType],
@@ -28,97 +29,96 @@
]


def nuts(
def construct_nuts_sampler(
srng: RandomStream,
model: ModelInfo,
inverse_mass_matrix: TensorVariable,
step_size: TensorVariable,
) -> Tuple[NUTSStateType, NUTSKernelType]:
to_sample_rvs, # RVs to sample
rvs_to_values, # All RVs to values
) -> Tuple[Dict[RandomVariable, TensorVariable], Dict, Dict[str, TensorVariable]]:
"""Build a NUTS kernel and the initial state.

This function currently assumes that we will update the value of all of the
model's variables with the NUTS sampler.

Parameters
----------
model
The Aesara model whose posterior distribution we wish to sample from
passed as a `ModelInfo` instance.
step_size
The step size used in the symplectic integrator.
inverse_mass_matrix
One or two-dimensional array used as the inverse mass matrix that
defines the euclidean metric.
to_sample_rvs
A sequence that contains the random variables whose posterior
distribution we wish to sample from.
rvs_to_values
A dictionary that maps all random variables in the model (including
those not sampled with NUTS) to their value variable.

Returns
-------
A NUTS sampling step for each variable.

"""
unobserved_rvs = tuple(
rv for rv in model.rvs_to_values.keys() if rv not in model.observed_rvs
)
unobserved_rvs_to_values = {rv: model.rvs_to_values[rv] for rv in unobserved_rvs}
observed_vvs = tuple(model.rvs_to_values[rv] for rv in model.observed_rvs)

# Algorithms in the HMC family can more easily explore the posterior distribution
# when the support of each random variable's distribution is unconstrained.
# First we build the logprob graph in the transformed space.
transforms = {
vv: get_transform(rv) if rv in unobserved_rvs else None
for rv, vv in model.rvs_to_values.items()
vv: get_transform(rv) for rv, vv in rvs_to_values.items() if rv in to_sample_rvs
}
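# For example (illustrative): a half-Cauchy variable has positive support, so
# its default transform moves it to the log scale; NUTS then explores the
# unconstrained log-space value and `transform_backward` below maps the draws
# back to the original support.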

logprob_sum = joint_logprob(
model.rvs_to_values, extra_rewrites=TransformValuesRewrite(transforms)
rvs_to_values, extra_rewrites=TransformValuesRewrite(transforms)
)

# Then we transform the value variables.
transformed_vvs = {
vv: transform_forward(rv, vv, transforms[vv])
for rv, vv in unobserved_rvs_to_values.items()
for rv, vv in rvs_to_values.items()
if rv in to_sample_rvs
}

# Algorithms in `aehmc` work with flat arrays and we need to ravel parameter
# values to use them as an input to the NUTS kernel.
rp_map = RaveledParamsMap(tuple(transformed_vvs.values()))
rp_map.ref_params = tuple(transformed_vvs.keys())
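# Illustration (a sketch, not part of the original diff): if the ref params
# are a scalar `a` and a length-2 vector `b`, `ravel_params` concatenates
# their values into a single flat vector `q` of length 3, and
# `unravel_params(q)` recovers a dict {a: ..., b: ...} with the original
# shapes, which is what `logprob_fn` below relies on.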

# Keep the value variables of the terms that are not sampled with NUTS
# unchanged; we condition on their current values
non_nuts_vals = {
vv: vv for rv, vv in rvs_to_values.items() if rv not in to_sample_rvs
}

# We can now write the logprob function associated with the model and build
# the NUTS kernel.
def logprob_fn(q):
unraveled_q = rp_map.unravel_params(q)
unraveled_q.update({vv: vv for vv in observed_vvs})

unraveled_q.update(non_nuts_vals)
memo = aesara.graph.basic.clone_get_equiv(
[], [logprob_sum], copy_inputs=False, copy_orphans=False, memo=unraveled_q
)

return memo[logprob_sum]

# Finally we build the NUTS sampling step
nuts_kernel = aehmc_nuts.new_kernel(srng, logprob_fn)

def step_fn(state):
"""Take one step with the NUTS kernel.

The NUTS kernel works with a state that contains the current value of
the variables, but also the current value of the potential and its
gradient, and we need to carry this state forward. We also return the
unraveled parameter values in both the original and transformed space.

"""
(new_q, new_pe, new_peg, *_), updates = nuts_kernel(
*state, step_size, inverse_mass_matrix
)
new_state = (new_q, new_pe, new_peg)
transformed_params = rp_map.unravel_params(new_q)
params = {
vv: transform_backward(rv, transformed_params[vv], transforms[vv])
for rv, vv in unobserved_rvs_to_values.items()
}
return (new_state, params, transformed_params), updates

# Finally we build the initial state
initial_q = rp_map.ravel_params(tuple(transformed_vvs.values()))
initial_state = aehmc_nuts.new_state(initial_q, logprob_fn)

return initial_state, step_fn
# Initialize the parameter values
step_size = at.scalar("step_size", dtype=config.floatX)
inverse_mass_matrix = at.tensor(
name="inverse_mass_matrix", shape=initial_q.type.shape, dtype=config.floatX
)

# TODO: Does that lead to wasteful computation? Or is it handled by Aesara?
(new_q, *_), updates = nuts_kernel(*initial_state, step_size, inverse_mass_matrix)
transformed_params = rp_map.unravel_params(new_q)
params = {
rv: transform_backward(rv, transformed_params[vv], transforms[vv])
for rv, vv in rvs_to_values.items()
if rv in to_sample_rvs
}

return (
params,
updates,
{"step_size": step_size, "inverse_mass_matrix": inverse_mass_matrix},
)


def get_transform(rv: TensorVariable):
# … (unchanged context collapsed)
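For reference, the new `construct_nuts_sampler` can also be called directly. A hedged sketch under the signature shown above (the model is illustrative and mirrors the tests below; `tau_vv` is a hypothetical value variable created here for the latent term):

from aesara.tensor.random.utils import RandomStream

from aemcmc.nuts import construct_nuts_sampler

srng = RandomStream(0)

tau_rv = srng.halfcauchy(0, 1, name="tau")
Y_rv = srng.halfcauchy(0, tau_rv, name="Y")

tau_vv = tau_rv.clone()
tau_vv.name = "tau_vv"
y_vv = Y_rv.clone()
y_vv.name = "y"

# `to_sample_rvs` lists the variables NUTS should update; `rvs_to_values`
# maps every random variable in the model to its value variable.
sample_steps, updates, parameters = construct_nuts_sampler(
    srng, {tau_rv}, {tau_rv: tau_vv, Y_rv: y_vv}
)

# `sample_steps[tau_rv]` is the next draw of `tau` in the original
# (constrained) space, and `parameters` holds the symbolic `step_size` and
# `inverse_mass_matrix` to be supplied at compile/run time.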
79 changes: 70 additions & 9 deletions tests/test_basic.py
@@ -25,9 +25,12 @@ def test_closed_form_posterior_beta_binomial():
y_vv = Y_rv.clone()
y_vv.name = "y"

sample_steps, updates, initial_values = construct_sampler({Y_rv: y_vv}, srng)
sample_steps, updates, initial_values, parameters = construct_sampler(
{Y_rv: y_vv}, srng
)

p_posterior_step = sample_steps[p_rv]
assert len(parameters) == 0
assert isinstance(p_posterior_step.owner.op, BetaRV)


@@ -43,24 +46,80 @@ def test_closed_form_posterior_gamma_poisson():
y_vv = Y_rv.clone()
y_vv.name = "y"

sample_steps, updates, initial_values = construct_sampler({Y_rv: y_vv}, srng)
sample_steps, updates, initial_values, parameters = construct_sampler(
{Y_rv: y_vv}, srng
)

p_posterior_step = sample_steps[l_rv]
assert len(parameters) == 0
assert isinstance(p_posterior_step.owner.op, GammaRV)


def test_no_samplers():
@pytest.mark.parametrize("size", [1, (1,), (2, 3)])
def test_nuts_sampler_single_variable(size):
"""We make sure that the NUTS sampler compiles and updates the chains for
different sizes of the random variable.

"""
srng = RandomStream(0)

size = at.lscalar("size")
tau_rv = srng.halfcauchy(0, 1, name="tau")
Y_rv = srng.halfcauchy(0, tau_rv, size=size, name="Y")
tau_rv = srng.halfcauchy(0, 1, size=size, name="tau")
Y_rv = srng.halfcauchy(0, tau_rv, name="Y")

y_vv = Y_rv.clone()
y_vv.name = "y"

sample_steps, updates, initial_values, parameters = construct_sampler(
{Y_rv: y_vv}, srng
)

assert len(parameters) == 2
assert len(sample_steps) == 1

tau_post_step = sample_steps[tau_rv]
assert y_vv in graph_inputs([tau_post_step])

inputs = [
initial_values[tau_rv],
y_vv,
parameters["step_size"],
parameters["inverse_mass_matrix"],
]
output = tau_post_step
sample_step = aesara.function(inputs, output)

tau_val = np.ones(shape=size)
y_val = np.ones(shape=size)
step_size = 1e-1
inverse_mass_matrix = np.ones(shape=size).flatten()
res = sample_step(tau_val, y_val, step_size, inverse_mass_matrix)
with pytest.raises(AssertionError):
np.testing.assert_equal(tau_val, res)
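
# A hedged sketch, not part of this diff: to draw an actual chain, compile the
# same step with the `updates` returned by `construct_sampler` (they advance
# the RNG state) and feed each draw back in as the next value of `tau`. The
# names below continue those defined in `test_nuts_sampler_single_variable`:
#
#     chain_step = aesara.function(inputs, output, updates=updates)
#
#     tau_draw, chain = tau_val, []
#     for _ in range(100):
#         tau_draw = chain_step(tau_draw, y_val, step_size, inverse_mass_matrix)
#         chain.append(tau_draw)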


def test_nuts_with_closed_form():
"""Make sure that the NUTS sampler works in combination with closed-form posteriors."""
srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
beta_rv = srng.halfnormal(1.0, name="beta")
l_rv = srng.gamma(alpha_tt, beta_rv, name="p")

Y_rv = srng.poisson(l_rv, name="Y")

y_vv = Y_rv.clone()
y_vv.name = "y"

with pytest.raises(NotImplementedError):
construct_sampler({Y_rv: y_vv}, srng)
sample_steps, updates, initial_values, parameters = construct_sampler(
{Y_rv: y_vv}, srng
)

p_posterior_step = sample_steps[l_rv]
assert len(parameters) == 2
assert len(initial_values) == 2
assert isinstance(p_posterior_step.owner.op, GammaRV)

assert beta_rv in sample_steps


def test_create_gibbs():
@@ -87,7 +146,9 @@ def test_create_gibbs():

sample_vars = [tau_rv, lmbda_rv, beta_rv, h_rv]

sample_steps, updates, initial_values = construct_sampler({Y_rv: y_vv}, srng)
sample_steps, updates, initial_values, parameters = construct_sampler(
{Y_rv: y_vv}, srng
)

assert len(sample_steps) == 4
assert updates
# … (unchanged context collapsed)