Skip to content

Commit

Permalink
Add compile_forward_sampling_function
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed May 17, 2022
1 parent ab593b1 commit c346391
Show file tree
Hide file tree
Showing 3 changed files with 419 additions and 37 deletions.
227 changes: 190 additions & 37 deletions pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,22 @@
import numpy as np
import xarray

from aesara.graph.basic import Constant, Variable
from aesara.tensor import TensorVariable
from aesara import tensor as at
from aesara.graph.basic import Apply, Constant, Variable, general_toposort, walk
from aesara.graph.fg import FunctionGraph
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.var import (
RandomGeneratorSharedVariable,
RandomStateSharedVariable,
)
from aesara.tensor.sharedvar import SharedVariable
from arviz import InferenceData
from fastprogress.fastprogress import progress_bar
from typing_extensions import TypeAlias

import pymc as pm

from pymc.aesaraf import change_rv_size, compile_pymc, inputvars, walk_model
from pymc.aesaraf import change_rv_size, compile_pymc
from pymc.backends.arviz import _DefaultTrace
from pymc.backends.base import BaseTrace, MultiTrace
from pymc.backends.ndarray import NDArray
Expand All @@ -75,6 +81,7 @@
get_default_varnames,
get_untransformed_name,
is_transformed_name,
point_wrapper,
)
from pymc.vartypes import discrete_types

Expand All @@ -83,6 +90,7 @@
__all__ = [
"sample",
"iter_sample",
"compile_forward_sampling_function",
"sample_posterior_predictive",
"sample_posterior_predictive_w",
"init_nuts",
Expand Down Expand Up @@ -1534,6 +1542,166 @@ def stop_tuning(step):
return step


def get_vars_in_point_list(trace, model):
"""Get the list of Variable instances in the model that have values stored in the trace."""
if not isinstance(trace, MultiTrace):
names_in_trace = list(trace[0])
else:
names_in_trace = trace.varnames
vars_in_trace = [model[v] for v in names_in_trace]
return vars_in_trace


def compile_forward_sampling_function(
outputs: List[Variable],
vars_in_trace: List[Variable],
basic_rvs: Optional[List[Variable]] = None,
givens_dict: Optional[Dict[Variable, Any]] = None,
**kwargs,
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
"""Compile a function to draw samples, conditioned on the values of some variables.
The goal of this function is to walk the aesara computational graph from the list
of output nodes down to the root nodes, and then compile a function that will produce
values for these output nodes. The compiled function will take as inputs the subset of
variables in the ``vars_in_trace`` that are deemed to not be **volatile**.
Volatile variables are variables whose values could change between runs of the
compiled function. These variables are: variables in the outputs list, SharedVariable
instances, RandomVariable instances that have volatile parameters, RandomVariables that
are not in the ``vars_in_trace`` list, and any other type of variable that depends
on volatile variables.
Concretely, this function can be used to compile a function to sample from the
posterior predictive distribution of a model that has variables that are conditioned
on ``MutableData`` instances. The variables that depend on the mutable data will be
considered volatile, and as such, they wont be included as inputs into the compiled function.
This means that if they have values stored in the posterior, these values will be ignored
and new values will be computed (in the case of deterministics and potentials) or sampled
(in the case of random variables).
This function also enables a way to impute values for any variable in the computational
graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used
to set the ``givens`` argument of the aesara function compilation. This will essentially
replace a node in the computational graph with any other expression that has the same
type as the desired node. Passing variables in the givens_dict is considered an intervention
that might lead to different variable values from those that could have been seen during
inference, as such, **any variable that is passed in the ``givens_dict`` will be considered
volatile**.
Parameters
----------
outputs : List[aesara.graph.basic.Variable]
The list of variables that will be returned by the compiled function
vars_in_trace : List[aesara.graph.basic.Variable]
The list of variables that are assumed to have values stored in the trace
basic_rvs : Optional[List[aesara.graph.basic.Variable]]
A list of random variables that are defined in the model. This list (which could be the
output of ``model.basic_RVs``) should have a reference to the variables that should
be considered as random variable instances. This includes variables that have
a ``RandomVariable`` owner op, but also unpure random variables like Mixtures, or
Censored distributions. If ``None``, only pure random variables will be considered
as potential random variables.
givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]]
A dictionary that maps tensor variables to the values that should be used to replace them
in the compiled function. The types of the key and value should match or an error will be
raised during compilation.
"""
if givens_dict is None:
givens_dict = {}

if basic_rvs is None:
basic_rvs = []

def is_a_random_variable(node):
return node in basic_rvs or (
isinstance(node, Variable)
and node.owner is not None
and isinstance(node.owner.op, RandomVariable)
)

def distribution_not_dist_parameters(node):
return [
rv_param
for rv_param in node.owner.inputs
if (
not (
isinstance( # Random state and generator shared variables are always volatile but they don't count as distribution parameters
rv_param, (RandomStateSharedVariable, RandomGeneratorSharedVariable)
)
or (
isinstance(
getattr(getattr(rv_param, "owner", None), "op", None), RandomVariable
)
and rv_param not in basic_rvs
) # Mixture distributions have RandomVariable parameters that are not measurable (not in model.basic_RVs). These shouldn't propagate their volatility
)
)
]

# We need a function graph to walk the clients and propagate the volatile property
fg = FunctionGraph(outputs=outputs, clone=False)

# Walk the graph from inputs to outputs and tag the volatile variables
nodes: List[Variable] = general_toposort(
fg.outputs, deps=lambda x: x.owner.inputs if x.owner else []
)
volatile_nodes: Set[Any] = set()
for node in nodes:
if isinstance(node, SharedVariable) or node in fg.outputs or node in givens_dict:
volatile_nodes.add(node)
is_volatile = node in volatile_nodes

# Propagate the volatile property to the node's clients
if is_volatile:
clients = [out for c, _ in fg.clients[node] if c != "output" for out in c.outputs] # type: ignore
for client in clients:
if is_a_random_variable(client) and any(
rv_param in volatile_nodes
for rv_param in distribution_not_dist_parameters(client)
):
# Random variables that have volatile parameters are also volatile
volatile_nodes.add(client)
elif not (client in vars_in_trace and is_a_random_variable(client)):
# Other variables that aren't random variables in the trace are volatile
volatile_nodes.add(client)

# Collect the function inputs by walking the graph from the outputs. Inputs will be:
# 1. Random variables that are not volatile
# 2. Variables that have no owner and are not constant or shared
inputs = []

def expand(node):
if (
(
node.owner is None and not isinstance(node, (Constant, SharedVariable))
) # Variables without owners that are not constant or shared
or node in vars_in_trace # Variables in the trace
) and node not in volatile_nodes:
# This test will include variables without owners, and that are not constant
# or shared, because these nodes will never be considered volatile
inputs.append(node)
if node.owner:
return node.owner.inputs

# walk produces a generator, so we have to actually exhaust the generator in a list to walk
# the entire graph
list(walk(fg.outputs, expand))

# Populate the givens list
givens = [
(
node,
value
if isinstance(value, (Variable, Apply))
else at.constant(value, dtype=getattr(node, "dtype", None), name=node.name),
)
for node, value in givens_dict.items()
]

return compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs)


def sample_posterior_predictive(
trace,
samples: Optional[int] = None,
Expand Down Expand Up @@ -1718,38 +1886,23 @@ def sample_posterior_predictive(
return trace
return {}

inputs: Sequence[TensorVariable]
input_names: Sequence[str]
if not isinstance(_trace, MultiTrace):
names_in_trace = list(_trace[0])
else:
names_in_trace = _trace.varnames
inputs_and_names = [
(rv, rv.name)
for rv in walk_model(vars_to_sample, walk_past_rvs=True)
if rv not in vars_to_sample
and rv in model.named_vars.values()
and not isinstance(rv, (Constant, SharedVariable))
and rv.name in names_in_trace
]
if inputs_and_names:
inputs, input_names = zip(*inputs_and_names)
else:
inputs, input_names = [], []

if size is not None:
vars_to_sample = [change_rv_size(v, size, expand=True) for v in vars_to_sample]
vars_in_trace = get_vars_in_point_list(_trace, model)

if compile_kwargs is None:
compile_kwargs = {}

sampler_fn = compile_pymc(
inputs,
vars_to_sample,
allow_input_downcast=True,
accept_inplace=True,
on_unused_input="ignore",
**compile_kwargs,
compile_kwargs.setdefault("allow_input_downcast", True)
compile_kwargs.setdefault("accept_inplace", True)

sampler_fn = point_wrapper(
compile_forward_sampling_function(
outputs=vars_to_sample,
vars_in_trace=vars_in_trace,
basic_rvs=model.basic_RVs,
givens_dict=None,
**compile_kwargs,
)
)

ppc_trace_t = _DefaultTrace(samples)
Expand All @@ -1775,7 +1928,7 @@ def sample_posterior_predictive(
else:
param = _trace[idx % len_trace]

values = sampler_fn(*(param[n] for n in input_names))
values = sampler_fn(**param)

for k, v in zip(vars_, values):
ppc_trace_t.insert(k.name, v, idx)
Expand Down Expand Up @@ -2063,16 +2216,16 @@ def sample_prior_predictive(
names.append(rv_var.name)
vars_to_sample.append(rv_var)

inputs = [i for i in inputvars(vars_to_sample) if not isinstance(i, (Constant, SharedVariable))]

if compile_kwargs is None:
compile_kwargs = {}
compile_kwargs.setdefault("allow_input_downcast", True)
compile_kwargs.setdefault("accept_inplace", True)

sampler_fn = compile_pymc(
inputs,
sampler_fn = compile_forward_sampling_function(
vars_to_sample,
allow_input_downcast=True,
accept_inplace=True,
vars_in_trace=[],
basic_rvs=model.basic_RVs,
givens_dict=None,
**compile_kwargs,
)

Expand Down
Loading

0 comments on commit c346391

Please sign in to comment.