Return sampling steps in a Sampler dataclass
rlouf committed Nov 28, 2022
1 parent e33be86 commit 613836d
Showing 4 changed files with 74 additions and 61 deletions.
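
The gist of the change: `construct_sampler` no longer returns a flat 4-tuple of dicts; it now returns a `Sampler` dataclass together with the map of initial values. A minimal before/after sketch, reusing the `Y_rv`, `y_vv`, and `srng` variables that appear in the tests below:

# Before this commit: a 4-tuple
sample_steps, updates, initial_values, parameters = construct_sampler({Y_rv: y_vv}, srng)

# After this commit: a Sampler plus the initial values
sampler, initial_values = construct_sampler({Y_rv: y_vv}, srng)
sampler.sample_steps  # Dict[TensorVariable, TensorVariable]: measure -> updated value
sampler.updates       # updates to pass to `aesara.function`
sampler.parameters    # Dict[Apply, Tuple[TensorVariable]]: step node -> its parameters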
55 changes: 36 additions & 19 deletions aemcmc/basic.py
@@ -1,6 +1,7 @@
-from typing import Dict, Tuple
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple

-from aesara.graph.basic import Variable
+from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.var import TensorVariable
@@ -14,14 +15,31 @@
)


+@dataclass(frozen=True)
+class Sampler:
+    """A class that tracks sampling steps and their parameters."""
+
+    sample_steps: Dict[TensorVariable, TensorVariable]
+    """A map between measures and their updated value under the current sampling scheme."""
+    updates: Optional[Dict[Variable, TensorVariable]] = field(default_factory=dict)
+    """Updates to be passed to `aesara.function`"""
+    parameters: Dict[Apply, Tuple[TensorVariable]] = field(default_factory=dict)
+    """Parameters needed by the sampling steps."""
+
+    samplers: List[Apply] = field(init=False)
+    """A list of sampling steps used to build this sampler."""
+
+    def __post_init__(self):
+
+        samplers = set()
+        for updated_rv in self.sample_steps.values():
+            samplers.add(updated_rv.owner)
+        super().__setattr__("samplers", list(samplers))
+
+
def construct_sampler(
    obs_rvs_to_values: Dict[TensorVariable, TensorVariable], srng: RandomStream
-) -> Tuple[
-    Dict[TensorVariable, TensorVariable],
-    Dict[Variable, Variable],
-    Dict[TensorVariable, TensorVariable],
-    Dict[str, TensorVariable],
-]:
+) -> Tuple[Sampler, Dict[TensorVariable, TensorVariable]]:
    r"""Eagerly construct a sampler for a given set of observed variables and their observations.

    Parameters
@@ -57,7 +75,7 @@ def construct_sampler(
    # TODO FIXME: Get/extract `Scan`-generated updates
    posterior_updates: Dict[Variable, Variable] = {}

-    parameters: Dict[str, TensorVariable] = {}
+    parameters: Dict[Apply, Tuple[TensorVariable]] = {}
    rvs_without_samplers = set()

    for rv in fgraph.outputs:
@@ -125,13 +143,12 @@ def construct_sampler(
        posterior_updates.update(updates)
        parameters.update(nuts_parameters)

-    return (
-        {
-            new_to_old_rvs[rv]: step
-            for rv, step in posterior_sample_steps.items()
-            if rv not in obs_rvs_to_values
-        },
-        posterior_updates,
-        {new_to_old_rvs[rv]: init_var for rv, init_var in rvs_to_init_vals.items()},
-        parameters,
-    )
+    sampling_steps = {
+        new_to_old_rvs[rv]: step
+        for rv, step in posterior_sample_steps.items()
+        if rv not in obs_rvs_to_values
+    }
+
+    return Sampler(sampling_steps, updates, parameters), {
+        new_to_old_rvs[rv]: init_var for rv, init_var in rvs_to_init_vals.items()
+    }
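
To illustrate what the new dataclass computes: `__post_init__` collects the unique `Apply` nodes that own the updated values, so several variables drawn by the same kernel collapse to a single entry in `samplers`. A standalone sketch, assuming the `aemcmc` package from this commit and using dummy Aesara variables in place of real posterior steps:

import aesara.tensor as at

from aemcmc.basic import Sampler

# Two measures whose updated values come from the same (dummy) step node
a = at.scalar("a")
b = at.scalar("b")
step = a + b  # stands in for a sampling kernel's output

sampler = Sampler({a: step, b: step})
assert len(sampler.samplers) == 1  # the shared owner node is deduplicated via a set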
13 changes: 8 additions & 5 deletions aemcmc/nuts.py
@@ -12,7 +12,7 @@
)
from aesara import config
from aesara.compile.sharedvalue import SharedVariable
-from aesara.graph.basic import graph_inputs
+from aesara.graph.basic import Apply, graph_inputs
from aesara.graph.type import Constant
from aesara.tensor.random import RandomStream
from aesara.tensor.random.op import RandomVariable
@@ -49,7 +49,9 @@ def step(
    srng: RandomStream,
    to_sample_rvs: Dict[RandomVariable, TensorVariable],
    realized_rvs_to_values: Dict[RandomVariable, TensorVariable],
-) -> Tuple[Dict[RandomVariable, TensorVariable], Dict, Dict[str, TensorVariable],]:
+) -> Tuple[
+    Dict[RandomVariable, TensorVariable], Dict, Tuple[TensorVariable, TensorVariable]
+]:
    """Build a NUTS sampling step and its initial state.

    This sampling step works with variables in their original space, to create
@@ -143,15 +145,15 @@ def logprob_fn(q):
    return (
        results,
        updates,
-        {"step_size": step_size, "inverse_mass_matrix": inverse_mass_matrix},
+        (step_size, inverse_mass_matrix),
    )


def construct_sampler(
    srng: RandomStream,
    to_sample_rvs: Dict[RandomVariable, TensorVariable],
    realized_rvs_to_values: Dict[RandomVariable, TensorVariable],
-) -> Tuple[Dict[RandomVariable, TensorVariable], Dict, Dict[str, TensorVariable],]:
+) -> Tuple[Dict[RandomVariable, TensorVariable], Dict, Dict[Apply, TensorVariable]]:

    results, updates, parameters = step(srng, to_sample_rvs, realized_rvs_to_values)

@@ -166,6 +168,7 @@ def construct_sampler(
    nuts_op = NUTSKernel(inputs, outputs, list(to_sample_rvs.keys()), parameters)

    posterior = nuts_op(*inputs)
+    node = posterior[0].owner  # TODO: Add lists in the Aesara graph
    results = {rv: posterior[i] for i, rv in enumerate(to_sample_rvs)}

    updates_input = posterior[0].owner.inputs[len(inputs) :]
@@ -174,7 +177,7 @@
        updates_input[i]: update_out for i, update_out in enumerate(updates_output)
    }

-    return results, updates, parameters
+    return results, updates, {node: parameters}


def get_transform(rv: TensorVariable):
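The practical consequence of the `nuts.py` changes: the NUTS step's `(step_size, inverse_mass_matrix)` pair is now keyed by the kernel's `Apply` node rather than by string names. A sketch of how a caller retrieves the parameters under the new scheme (variable names follow the NUTS test below):

sampler, initial_values = construct_sampler({Y_rv: y_vv}, srng)

tau_post_step = sampler.sample_steps[tau_rv]
nuts = tau_post_step.owner  # the NUTSKernel's Apply node
step_size, inverse_mass_matrix = sampler.parameters[nuts]
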
63 changes: 28 additions & 35 deletions tests/test_basic.py
@@ -29,12 +29,10 @@ def test_closed_form_posterior_beta_binomial():
    y_vv = Y_rv.clone()
    y_vv.name = "y"

-    sample_steps, updates, initial_values, parameters = construct_sampler(
-        {Y_rv: y_vv}, srng
-    )
+    sampler, initial_values = construct_sampler({Y_rv: y_vv}, srng)

-    p_posterior_step = sample_steps[p_rv]
-    assert len(parameters) == 0
+    p_posterior_step = sampler.sample_steps[p_rv]
+    assert len(sampler.parameters) == 0
    assert isinstance(p_posterior_step.owner.op, BetaRV)


@@ -50,12 +48,10 @@ def test_closed_form_posterior_gamma_poisson():
    y_vv = Y_rv.clone()
    y_vv.name = "y"

-    sample_steps, updates, initial_values, parameters = construct_sampler(
-        {Y_rv: y_vv}, srng
-    )
+    sampler, initial_values = construct_sampler({Y_rv: y_vv}, srng)

-    p_posterior_step = sample_steps[l_rv]
-    assert len(parameters) == 0
+    p_posterior_step = sampler.sample_steps[l_rv]
+    assert len(sampler.parameters) == 0
    assert isinstance(p_posterior_step.owner.op, GammaRV)


@@ -73,21 +69,20 @@ def test_nuts_sampler_single_variable(size):
    y_vv = Y_rv.clone()
    y_vv.name = "y"

-    sample_steps, updates, initial_values, parameters = construct_sampler(
-        {Y_rv: y_vv}, srng
-    )
+    sampler, initial_values = construct_sampler({Y_rv: y_vv}, srng)

-    assert len(parameters) == 2
-    assert len(sample_steps) == 1
+    assert len(sampler.sample_steps) == 1

-    tau_post_step = sample_steps[tau_rv]
+    tau_post_step = sampler.sample_steps[tau_rv]
+    nuts = tau_post_step.owner
    assert y_vv in graph_inputs([tau_post_step])
+    assert len(sampler.parameters[nuts]) == 2

    inputs = [
        initial_values[tau_rv],
        y_vv,
-        parameters["step_size"],
-        parameters["inverse_mass_matrix"],
+        sampler.parameters[nuts][0],
+        sampler.parameters[nuts][1],
    ]
    output = tau_post_step
    sample_step = aesara.function(inputs, output)
@@ -114,17 +109,17 @@ def test_nuts_with_closed_form():
    y_vv = Y_rv.clone()
    y_vv.name = "y"

-    sample_steps, updates, initial_values, parameters = construct_sampler(
-        {Y_rv: y_vv}, srng
-    )
+    sampler, initial_values = construct_sampler({Y_rv: y_vv}, srng)

+    nuts = sampler.sample_steps[beta_rv].owner
+    assert len(sampler.parameters[nuts]) == 2
+
-    p_posterior_step = sample_steps[l_rv]
+    p_posterior_step = sampler.sample_steps[l_rv]
    assert y_vv in graph_inputs([p_posterior_step])
-    assert len(parameters) == 2
    assert len(initial_values) == 2
    assert isinstance(p_posterior_step.owner.op, GammaRV)

-    assert beta_rv in sample_steps
+    assert beta_rv in sampler.sample_steps


def test_create_gibbs():
@@ -151,28 +146,26 @@

    sample_vars = [tau_rv, lmbda_rv, beta_rv, h_rv]

-    sample_steps, updates, initial_values, parameters = construct_sampler(
-        {Y_rv: y_vv}, srng
-    )
+    sampler, initial_values = construct_sampler({Y_rv: y_vv}, srng)

-    assert len(sample_steps) == 4
-    assert updates
+    assert len(sampler.sample_steps) == 4
+    assert sampler.updates

-    tau_post_step = sample_steps[tau_rv]
+    tau_post_step = sampler.sample_steps[tau_rv]
    # These are *very* rough checks of the resulting graphs
    assert isinstance(tau_post_step.owner.op, HorseshoeGibbsKernel)

-    lmbda_post_step = sample_steps[lmbda_rv]
+    lmbda_post_step = sampler.sample_steps[lmbda_rv]
    assert isinstance(lmbda_post_step.owner.op, HorseshoeGibbsKernel)

-    beta_post_step = sample_steps[beta_rv]
+    beta_post_step = sampler.sample_steps[beta_rv]
    assert isinstance(beta_post_step.owner.op, NBRegressionGibbsKernel)

-    h_post_step = sample_steps[h_rv]
+    h_post_step = sampler.sample_steps[h_rv]
    assert isinstance(h_post_step.owner.op, DispersionGibbsKernel)

    inputs = [X, a, b, y_vv] + [initial_values[rv] for rv in sample_vars]
-    outputs = [sample_steps[rv] for rv in sample_vars]
+    outputs = [sampler.sample_steps[rv] for rv in sample_vars]

    subsuming_elemwises = [
        n for n in io_toposort([], outputs) if isinstance(n.op, SubsumingElemwise)
@@ -182,7 +175,7 @@ def test_create_gibbs():
    sample_step = aesara.function(
        inputs,
        outputs,
-        updates=updates,
+        updates=sampler.updates,
        on_unused_input="ignore",
    )

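For completeness, one way the compiled Gibbs step from `test_create_gibbs` might be driven; the concrete inputs (`X_val`, `a_val`, `b_val`, `y_val`, `init`) are hypothetical values, not part of the test:

# Hypothetical driver loop: feed each draw back in as the next state
state = list(init)
draws = []
for _ in range(1000):
    state = sample_step(X_val, a_val, b_val, y_val, *state)
    draws.append(state)
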
4 changes: 2 additions & 2 deletions tests/test_nuts.py
@@ -26,8 +26,8 @@ def test_nuts():
            mu_vv,
            sigma_vv,
            y_vv,
-            parameters["step_size"],
-            parameters["inverse_mass_matrix"],
+            parameters[0],
+            parameters[1],
        ),
        sample_steps,
        updates=updates,
