Skip to content

Commit

Permalink
Encapsulate NUTS updates in an OpFromGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Nov 22, 2022
1 parent 0be02fb commit e33be86
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 6 deletions.
50 changes: 48 additions & 2 deletions aemcmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@
_default_transformed_rv,
)
from aesara import config
from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import graph_inputs
from aesara.graph.type import Constant
from aesara.tensor.random import RandomStream
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorVariable

from aemcmc.types import SamplingStep

NUTSStateType = Tuple[TensorVariable, TensorVariable, TensorVariable]
NUTSKernelType = Callable[
[NUTSStateType],
Expand All @@ -29,7 +34,18 @@
]


def construct_sampler(
class NUTSKernel(SamplingStep):
"""An `Op` that represents the update of one or many random variables
with the NUTS sampling algorithm.
"""

def __init__(self, inputs, outputs, sampled_rvs, parameters):
name = "NUTS sampler"
super().__init__(inputs, outputs, sampled_rvs, name, parameters)


def step(
srng: RandomStream,
to_sample_rvs: Dict[RandomVariable, TensorVariable],
realized_rvs_to_values: Dict[RandomVariable, TensorVariable],
Expand Down Expand Up @@ -63,7 +79,7 @@ def construct_sampler(
"""

# Create the initial values for the random variables that are assigned this
# Get the initial values for the random variables that are assigned this
# sampling step.
initial_values = to_sample_rvs.values()

Expand Down Expand Up @@ -131,6 +147,36 @@ def logprob_fn(q):
)


def construct_sampler(
srng: RandomStream,
to_sample_rvs: Dict[RandomVariable, TensorVariable],
realized_rvs_to_values: Dict[RandomVariable, TensorVariable],
) -> Tuple[Dict[RandomVariable, TensorVariable], Dict, Dict[str, TensorVariable],]:

results, updates, parameters = step(srng, to_sample_rvs, realized_rvs_to_values)

# Build an `Op` that represents the NUTS sampling step
update_outputs = list(updates.values())
outputs = list(results.values()) + update_outputs
inputs = [
var_in
for var_in in graph_inputs(outputs)
if not isinstance(var_in, Constant) and not isinstance(var_in, SharedVariable)
]
nuts_op = NUTSKernel(inputs, outputs, list(to_sample_rvs.keys()), parameters)

posterior = nuts_op(*inputs)
results = {rv: posterior[i] for i, rv in enumerate(to_sample_rvs)}

updates_input = posterior[0].owner.inputs[len(inputs) :]
updates_output = posterior[len(results) :]
updates = {
updates_input[i]: update_out for i, update_out in enumerate(updates_output)
}

return results, updates, parameters


def get_transform(rv: TensorVariable):
"""Get the default transform associated with the random variable."""
transform = _default_transformed_rv(rv.owner.op, rv.owner)
Expand Down
6 changes: 2 additions & 4 deletions tests/test_nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ def test_nuts():
to_sample_rvs = {mu_rv: mu_vv, sigma_rv: sigma_vv}
observed = {Y_rv: y_vv}

state_at, updates, parameters = nuts.construct_sampler(
srng, to_sample_rvs, observed
)
sampling_step, updates, parameters = nuts.step(srng, to_sample_rvs, observed)

# Make sure that the state is properly initialized
sample_steps = [state_at[rv] for rv in to_sample_rvs]
sample_steps = [sampling_step[rv] for rv in to_sample_rvs]
state_fn = aesara.function(
(
mu_vv,
Expand Down

0 comments on commit e33be86

Please sign in to comment.