Skip to content

Commit

Permalink
Add compile_forward_sampling_function
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed May 13, 2022
1 parent ab593b1 commit 67c2a13
Show file tree
Hide file tree
Showing 4 changed files with 323 additions and 31 deletions.
145 changes: 145 additions & 0 deletions pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import warnings

from typing import (
Any,
Callable,
Dict,
Generator,
Expand All @@ -33,6 +34,7 @@
from aeppl.abstract import MeasurableVariable
from aeppl.logprob import CheckParameterValue
from aesara import config, scalar
from aesara.compile.function.types import Function
from aesara.compile.mode import Mode, get_mode
from aesara.gradient import grad
from aesara.graph import local_optimizer
Expand All @@ -41,6 +43,7 @@
Constant,
Variable,
clone_get_equiv,
general_toposort,
graph_inputs,
vars_between,
walk,
Expand Down Expand Up @@ -82,6 +85,7 @@
"at_rng",
"take_along_axis",
"convert_observed_data",
"compile_forward_sampling_function",
]


Expand Down Expand Up @@ -1015,3 +1019,144 @@ def compile_pymc(
**kwargs,
)
return aesara_function


def compile_forward_sampling_function(
    outputs: List[Variable],
    vars_in_trace: List[Variable],
    givens_dict: Optional[Dict[Variable, Any]] = None,
    **kwargs,
) -> Function:
    """Compile a function that draws values for ``outputs`` given some trace values.

    The graph that produces ``outputs`` is traversed twice. The first pass goes
    from the root nodes towards the outputs and tags **volatile** nodes, i.e.
    nodes whose value must be recomputed or resampled on every call instead of
    being read from the trace. The second pass goes from the outputs towards
    the roots and collects the inputs of the compiled function.

    A node is tagged volatile when it is:

    * one of the ``outputs``,
    * a ``SharedVariable``,
    * a key of ``givens_dict`` (a replacement is treated as an intervention, so
      stored values can no longer be trusted),
    * a random variable with at least one volatile parameter,
    * a random variable that is not in ``vars_in_trace`` and sits downstream of
      a volatile node,
    * any other kind of node that sits downstream of a volatile node.

    The compiled function takes as inputs the non-volatile nodes that are
    either ownerless variables (neither constant nor shared), random variables,
    or members of ``vars_in_trace``.

    A typical use is posterior predictive sampling in a model whose variables
    depend on ``MutableData``: everything downstream of the mutable data is
    volatile, so the values stored for those variables are ignored and new
    ones are computed or sampled.

    Parameters
    ----------
    outputs : List[aesara.graph.basic.Variable]
        The variables whose values the compiled function returns.
    vars_in_trace : List[aesara.graph.basic.Variable]
        The variables that are assumed to have values stored in the trace.
    givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]]
        Maps graph variables to the values that should replace them in the
        compiled function. Values that are not already graph objects are
        wrapped with ``at.constant`` using the dtype and name of the variable
        they replace. Defaults to no replacements.
    **kwargs
        Forwarded unchanged to ``compile_pymc``.

    Returns
    -------
    Function
        The result of ``compile_pymc`` called with the collected inputs, the
        function graph outputs, the givens list and
        ``on_unused_input="ignore"``.

    Raises
    ------
    KeyError
        If a volatile node that is not an ``Apply`` node has no entry in the
        clients mapping of the function graph.
    """
    replacements = {} if givens_dict is None else givens_dict

    def is_random_variable(candidate):
        # True only for Variables that are produced by a RandomVariable Op
        return (
            isinstance(candidate, aesara.graph.basic.Variable)
            and candidate.owner is not None
            and isinstance(candidate.owner.op, at.random.basic.RandomVariable)
        )

    # A function graph gives access to the clients of every variable, which is
    # what allows the volatile tag to be pushed downstream
    fg = aesara.graph.FunctionGraph(outputs=outputs, clone=False)

    # First pass: roots -> outputs, tagging volatile nodes on the way
    ordered_nodes = general_toposort(fg.outputs, deps=lambda item: item.get_parents())
    givens = []
    handled_givens = set()
    volatile_nodes = set()
    for current in ordered_nodes:
        has_replacement = current in replacements
        if (
            isinstance(current, aesara.compile.SharedVariable)
            or current in fg.outputs
            or has_replacement
        ):
            volatile_nodes.add(current)

        # Record the replacement for this node, once
        if has_replacement and current not in handled_givens:
            replacement = replacements[current]
            if not isinstance(replacement, (Apply, Variable)):
                # Plain values become Constants that borrow the node's dtype and name
                replacement = at.constant(
                    replacement,
                    dtype=current.dtype,
                    name=current.name,
                )
            givens.append((current, replacement))
            handled_givens.add(current)

        # Only volatile nodes have anything to propagate
        if current not in volatile_nodes:
            continue

        try:
            downstream = [consumer for consumer, _ in fg.clients[current]]
        except KeyError:
            # Apply nodes are not keys of the clients mapping; their outputs
            # play the role of clients instead
            if not isinstance(current, aesara.graph.basic.Apply):
                raise
            downstream = current.outputs

        for client in downstream:
            if not is_random_variable(client):
                # Anything that is not a random variable inherits the tag
                volatile_nodes.add(client)
            elif any(param in volatile_nodes for param in client.owner.inputs[3:]):
                # Random variables with a volatile parameter are volatile. The
                # first three inputs are skipped (presumably rng, size and
                # dtype -- confirm against the RandomVariable Op)
                volatile_nodes.add(client)
            elif client not in vars_in_trace:
                # Random variables without stored values must be resampled
                volatile_nodes.add(client)

    # Second pass: outputs -> roots, collecting the function inputs
    inputs = []

    def expand(node):
        if isinstance(node, Apply):
            # Apply nodes are never inputs, only their parents are explored
            return node.get_parents()
        assert isinstance(node, Variable)
        is_free_root = node.owner is None and not isinstance(
            node, (at.basic.Constant, aesara.compile.SharedVariable)
        )
        is_candidate = (
            is_free_root
            or (node.owner is not None and isinstance(node.owner.op, at.random.basic.RandomVariable))
            or (node in vars_in_trace)
        )
        # Ownerless variables that are neither constant nor shared always pass
        # the volatility test below, because they are never tagged volatile
        if is_candidate and node not in volatile_nodes:
            inputs.append(node)
        return node.get_parents()

    # walk is lazy, so the generator has to be exhausted for expand to run on
    # the whole graph
    for _ in walk(fg.outputs, expand):
        pass

    return compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs)
59 changes: 29 additions & 30 deletions pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,19 @@
import xarray

from aesara.graph.basic import Constant, Variable
from aesara.tensor import TensorVariable
from aesara.tensor.sharedvar import SharedVariable
from arviz import InferenceData
from fastprogress.fastprogress import progress_bar
from typing_extensions import TypeAlias

import pymc as pm

from pymc.aesaraf import change_rv_size, compile_pymc, inputvars, walk_model
from pymc.aesaraf import (
change_rv_size,
compile_forward_sampling_function,
compile_pymc,
inputvars,
)
from pymc.backends.arviz import _DefaultTrace
from pymc.backends.base import BaseTrace, MultiTrace
from pymc.backends.ndarray import NDArray
Expand All @@ -75,6 +79,7 @@
get_default_varnames,
get_untransformed_name,
is_transformed_name,
point_wrapper,
)
from pymc.vartypes import discrete_types

Expand Down Expand Up @@ -1534,6 +1539,16 @@ def stop_tuning(step):
return step


def get_vars_in_point_list(trace, model):
    """Return the model variables whose names have values stored in ``trace``.

    For a ``MultiTrace`` the names come from its ``varnames`` attribute. Any
    other trace is indexed at position 0 and the names are whatever iterating
    over that first element yields (presumably a point dictionary keyed by
    variable name -- confirm against callers). Each name is then looked up in
    ``model`` by item access.
    """
    if isinstance(trace, MultiTrace):
        stored_names = trace.varnames
    else:
        stored_names = list(trace[0])
    return [model[name] for name in stored_names]


def sample_posterior_predictive(
trace,
samples: Optional[int] = None,
Expand Down Expand Up @@ -1718,38 +1733,22 @@ def sample_posterior_predictive(
return trace
return {}

inputs: Sequence[TensorVariable]
input_names: Sequence[str]
if not isinstance(_trace, MultiTrace):
names_in_trace = list(_trace[0])
else:
names_in_trace = _trace.varnames
inputs_and_names = [
(rv, rv.name)
for rv in walk_model(vars_to_sample, walk_past_rvs=True)
if rv not in vars_to_sample
and rv in model.named_vars.values()
and not isinstance(rv, (Constant, SharedVariable))
and rv.name in names_in_trace
]
if inputs_and_names:
inputs, input_names = zip(*inputs_and_names)
else:
inputs, input_names = [], []

if size is not None:
vars_to_sample = [change_rv_size(v, size, expand=True) for v in vars_to_sample]
vars_in_trace = get_vars_in_point_list(_trace, model)

if compile_kwargs is None:
compile_kwargs = {}

sampler_fn = compile_pymc(
inputs,
vars_to_sample,
allow_input_downcast=True,
accept_inplace=True,
on_unused_input="ignore",
**compile_kwargs,
compile_kwargs.setdefault("allow_input_downcast", True)
compile_kwargs.setdefault("accept_inplace", True)

sampler_fn = point_wrapper(
compile_forward_sampling_function(
outputs=vars_to_sample,
vars_in_trace=vars_in_trace,
givens_dict=None,
**compile_kwargs,
)
)

ppc_trace_t = _DefaultTrace(samples)
Expand All @@ -1775,7 +1774,7 @@ def sample_posterior_predictive(
else:
param = _trace[idx % len_trace]

values = sampler_fn(*(param[n] for n in input_names))
values = sampler_fn(**param)

for k, v in zip(vars_, values):
ppc_trace_t.insert(k.name, v, idx)
Expand Down
Loading

0 comments on commit 67c2a13

Please sign in to comment.