Add compile_forward_sampling_function
lucianopaz committed May 16, 2022
1 parent ab593b1 commit 10e06bc
Showing 4 changed files with 412 additions and 40 deletions.
41 changes: 40 additions & 1 deletion pymc/distributions/distribution.py
@@ -363,7 +363,46 @@ def dist(
        return rv_out


class SymbolicDistribution:
class RvOpTagger:
    def __init__(self, descriptor):
        """A class that wraps the rv_op classmethod to add a special tag value to the outputs.

        Refer to the descriptor protocol docs for info on why we need to wrap the __get__
        method as we had to do here (https://docs.python.org/3/howto/descriptor.html).
        """
        self.wrapped = descriptor
        if hasattr(self.wrapped, "__set__"):
            self.__set__ = self.wrapped.__set__
        if hasattr(self.wrapped, "__del__"):
            self.__del__ = self.wrapped.__del__
        if hasattr(self.wrapped, "__doc__"):
            self.__doc__ = self.wrapped.__doc__

    def __get__(self, obj, objtype=None):
        f = self.wrapped.__get__(obj, objtype)

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            outputs = f(*args, **kwargs)
            if isinstance(outputs, tuple):
                for output in outputs:
                    if isinstance(output, Variable):
                        output.tag.symbolic_distribution = True
            elif isinstance(outputs, Variable):
                outputs.tag.symbolic_distribution = True
            return outputs

        return wrapped


class SymbolicDistributionMeta(type):
    def __new__(meta, classname, bases, class_dict):
        if "rv_op" in class_dict:
            class_dict["rv_op"] = RvOpTagger(class_dict["rv_op"])
        return super().__new__(meta, classname, bases, class_dict)


class SymbolicDistribution(metaclass=SymbolicDistributionMeta):
    """Symbolic statistical distribution

    While traditional PyMC distributions are represented by a single RandomVariable
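The descriptor-plus-metaclass pattern above can be illustrated with a minimal, self-contained sketch (plain Python only; FakeVariable, OpTagger, Meta, and Dist are illustrative stand-ins, not PyMC names):

    import functools
    import types

    class FakeVariable:
        """Stand-in for aesara.graph.basic.Variable with a .tag namespace."""
        def __init__(self):
            self.tag = types.SimpleNamespace()

    class OpTagger:
        """Wraps a descriptor (e.g. a classmethod) so its outputs get tagged."""
        def __init__(self, descriptor):
            self.wrapped = descriptor

        def __get__(self, obj, objtype=None):
            # Delegate binding to the wrapped descriptor, then tag its outputs
            f = self.wrapped.__get__(obj, objtype)

            @functools.wraps(f)
            def tagged(*args, **kwargs):
                out = f(*args, **kwargs)
                out.tag.symbolic_distribution = True
                return out

            return tagged

    class Meta(type):
        def __new__(meta, name, bases, class_dict):
            if "rv_op" in class_dict:
                class_dict["rv_op"] = OpTagger(class_dict["rv_op"])
            return super().__new__(meta, name, bases, class_dict)

    class Dist(metaclass=Meta):
        @classmethod
        def rv_op(cls):
            return FakeVariable()

    assert Dist.rv_op().tag.symbolic_distribution is True

The wrapping happens in __get__ rather than at class creation time because rv_op is itself a descriptor (a classmethod); binding has to be delegated to it before the call can be intercepted.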
217 changes: 178 additions & 39 deletions pymc/sampling.py
@@ -43,16 +43,22 @@
import numpy as np
import xarray

from aesara.graph.basic import Constant, Variable
from aesara.tensor import TensorVariable
from aesara import tensor as at
from aesara.graph.basic import Apply, Constant, Variable, general_toposort, walk
from aesara.graph.fg import FunctionGraph
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.var import (
    RandomGeneratorSharedVariable,
    RandomStateSharedVariable,
)
from aesara.tensor.sharedvar import SharedVariable
from arviz import InferenceData
from fastprogress.fastprogress import progress_bar
from typing_extensions import TypeAlias

import pymc as pm

from pymc.aesaraf import change_rv_size, compile_pymc, inputvars, walk_model
from pymc.aesaraf import change_rv_size, compile_pymc
from pymc.backends.arviz import _DefaultTrace
from pymc.backends.base import BaseTrace, MultiTrace
from pymc.backends.ndarray import NDArray
@@ -75,6 +81,7 @@
    get_default_varnames,
    get_untransformed_name,
    is_transformed_name,
    point_wrapper,
)
from pymc.vartypes import discrete_types

@@ -83,6 +90,7 @@
__all__ = [
    "sample",
    "iter_sample",
    "compile_forward_sampling_function",
    "sample_posterior_predictive",
    "sample_posterior_predictive_w",
    "init_nuts",
@@ -1534,6 +1542,157 @@ def stop_tuning(step):
    return step


def get_vars_in_point_list(trace, model):
    """Get the list of Variable instances in the model that have values stored in the trace."""
    if not isinstance(trace, MultiTrace):
        names_in_trace = list(trace[0])
    else:
        names_in_trace = trace.varnames
    vars_in_trace = [model[v] for v in names_in_trace]
    return vars_in_trace
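
Note that the trace argument can be either a MultiTrace or a plain sequence of point dictionaries; in the latter case the variable names are read off the first point. A tiny illustration (hypothetical values):

    # A list-of-points "trace" with a single free variable named "beta"
    fake_trace = [{"beta": 0.1}, {"beta": 0.2}]
    # get_vars_in_point_list(fake_trace, model) would return [model["beta"]]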


def compile_forward_sampling_function(
    outputs: List[Variable],
    vars_in_trace: List[Variable],
    givens_dict: Optional[Dict[Variable, Any]] = None,
    **kwargs,
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
    """Compile a function to draw samples, conditioned on the values of some variables.

    The goal of this function is to walk the aesara computational graph from the list
    of output nodes down to the root nodes, and then compile a function that will produce
    values for these output nodes. The compiled function will take as inputs the subset of
    variables in ``vars_in_trace`` that are deemed to not be **volatile**.

    Volatile variables are variables whose values could change between runs of the
    compiled function. These variables are: variables in the outputs list, SharedVariable
    instances, RandomVariable instances that have volatile parameters, RandomVariables
    that are not in the ``vars_in_trace`` list, and any other type of variable that
    depends on volatile variables.

    Concretely, this function can be used to compile a function to sample from the
    posterior predictive distribution of a model that has variables that are conditioned
    on ``MutableData`` instances. The variables that depend on the mutable data will be
    considered volatile, and as such, they won't be included as inputs into the compiled
    function. This means that if they have values stored in the posterior, these values
    will be ignored and new values will be computed (in the case of deterministics and
    potentials) or sampled (in the case of random variables).

    This function also enables a way to impute values for any variable in the
    computational graph that produces the desired outputs: the ``givens_dict``. This
    dictionary can be used to set the ``givens`` argument of the aesara function
    compilation. This will essentially replace a node in the computational graph with
    any other expression that has the same type as the desired node. Passing variables
    in the ``givens_dict`` is considered an intervention that might lead to different
    variable values from those that could have been seen during inference; as such,
    **any variable that is passed in the ``givens_dict`` will be considered volatile**.

    Parameters
    ----------
    outputs : List[aesara.graph.basic.Variable]
        The list of variables that will be returned by the compiled function
    vars_in_trace : List[aesara.graph.basic.Variable]
        The list of variables that are assumed to have values stored in the trace
    givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]]
        A dictionary that maps tensor variables to the values that should be used to
        replace them in the compiled function. The types of the key and value should
        match, or an error will be raised during compilation.
    """
    if givens_dict is None:
        givens_dict = {}
    # We need a function graph to walk the clients and propagate the volatile property
    fg = FunctionGraph(outputs=outputs, clone=False)

    # Walk the graph from inputs to outputs and tag the volatile variables
    nodes: List[Union[Variable, Apply]] = general_toposort(
        fg.outputs, deps=lambda x: x.get_parents()
    )
    volatile_nodes: Set[Any] = set()
    for node in nodes:
        if isinstance(node, SharedVariable) or node in fg.outputs or node in givens_dict:
            volatile_nodes.add(node)
        is_volatile = node in volatile_nodes

        # Propagate the volatile property to the node's clients
        if is_volatile:
            try:
                clients = [c for c, _ in fg.clients[node]]  # type: ignore
            except KeyError:
                # Some Apply nodes don't show up in the clients of a FunctionGraph,
                # so we take the Apply node's outputs instead
                clients = node.outputs  # type: ignore
            for client in clients:
                if (
                    isinstance(client, Variable)
                    and client.owner is not None
                    and isinstance(client.owner.op, RandomVariable)
                    and any(
                        rv_param in volatile_nodes
                        for rv_param in client.owner.inputs
                        if not isinstance(
                            rv_param, (RandomStateSharedVariable, RandomGeneratorSharedVariable)
                        )
                    )
                ):
                    # Random variables that have volatile parameters are also volatile
                    volatile_nodes.add(client)
                elif not (
                    client in vars_in_trace
                    and isinstance(client, Variable)
                    and client.owner is not None
                    and (
                        isinstance(client.owner.op, RandomVariable)
                        or getattr(client.tag, "symbolic_distribution", False)
                    )
                ):
                    # Other variables that aren't random variables in the trace are volatile
                    volatile_nodes.add(client)

    # Collect the function inputs by walking the graph from the outputs. Inputs will be:
    # 1. Random variables that are not volatile
    # 2. Variables that have no owner and are not constant or shared
    inputs = []

    def expand(node):
        if isinstance(node, Apply):
            # Apply nodes are never inputs
            return node.get_parents()
        assert isinstance(node, Variable)
        if (
            (
                node.owner is None and not isinstance(node, (Constant, SharedVariable))
            )  # Variables without owners that are not constant or shared
            or (
                node.owner is not None and isinstance(node.owner.op, RandomVariable)
            )  # Random variables
            or (node in vars_in_trace)  # Variables in the trace
        ):
            if node not in volatile_nodes:
                # This test also admits variables without owners that are not constant
                # or shared, because those nodes are never considered volatile
                inputs.append(node)
        return node.get_parents()

    # walk produces a generator, so we have to exhaust it in a list to traverse
    # the entire graph
    list(walk(fg.outputs, expand))

    # Populate the givens list
    givens = [
        (
            node,
            value
            if isinstance(value, (Variable, Apply))
            else at.constant(value, dtype=getattr(node, "dtype", None), name=node.name),
        )
        for node, value in givens_dict.items()
    ]

    return compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs)
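
A usage sketch of the new function (hypothetical model; exact draws will vary): because x is a MutableData SharedVariable, everything downstream of it is volatile, so mu is recomputed and y is resampled on every call, while beta (listed in vars_in_trace) becomes a named input of the compiled function:

    import numpy as np
    import pymc as pm
    from pymc.sampling import compile_forward_sampling_function

    with pm.Model() as model:
        x = pm.MutableData("x", np.array([1.0, 2.0, 3.0]))
        beta = pm.Normal("beta", 0.0, 1.0)
        mu = pm.Deterministic("mu", beta * x)
        y = pm.Normal("y", mu=mu, sigma=1.0)

    fn = compile_forward_sampling_function(
        outputs=[model["mu"], model["y"]],
        vars_in_trace=[model["beta"]],
    )
    mu_val, y_val = fn(beta=0.5)  # mu_val == 0.5 * x; y_val is a fresh draw

Passing givens_dict={model["beta"]: some_constant} instead would pin beta to that expression and, per the docstring above, mark it volatile so it stops being an input.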


def sample_posterior_predictive(
    trace,
    samples: Optional[int] = None,
@@ -1718,38 +1877,22 @@ def sample_posterior_predictive(
            return trace
        return {}

    inputs: Sequence[TensorVariable]
    input_names: Sequence[str]
    if not isinstance(_trace, MultiTrace):
        names_in_trace = list(_trace[0])
    else:
        names_in_trace = _trace.varnames
    inputs_and_names = [
        (rv, rv.name)
        for rv in walk_model(vars_to_sample, walk_past_rvs=True)
        if rv not in vars_to_sample
        and rv in model.named_vars.values()
        and not isinstance(rv, (Constant, SharedVariable))
        and rv.name in names_in_trace
    ]
    if inputs_and_names:
        inputs, input_names = zip(*inputs_and_names)
    else:
        inputs, input_names = [], []

    if size is not None:
        vars_to_sample = [change_rv_size(v, size, expand=True) for v in vars_to_sample]

    vars_in_trace = get_vars_in_point_list(_trace, model)

    if compile_kwargs is None:
        compile_kwargs = {}

    sampler_fn = compile_pymc(
        inputs,
        vars_to_sample,
        allow_input_downcast=True,
        accept_inplace=True,
        on_unused_input="ignore",
        **compile_kwargs,
    )
    compile_kwargs.setdefault("allow_input_downcast", True)
    compile_kwargs.setdefault("accept_inplace", True)

    sampler_fn = point_wrapper(
        compile_forward_sampling_function(
            outputs=vars_to_sample,
            vars_in_trace=vars_in_trace,
            givens_dict=None,
            **compile_kwargs,
        )
    )

    ppc_trace_t = _DefaultTrace(samples)
@@ -1775,7 +1918,7 @@
            else:
                param = _trace[idx % len_trace]

            values = sampler_fn(*(param[n] for n in input_names))
            values = sampler_fn(**param)

            for k, v in zip(vars_, values):
                ppc_trace_t.insert(k.name, v, idx)
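
point_wrapper (imported from pymc.util above) is what lets sampler_fn(**param) receive a full trace point: keys that are not inputs of the compiled function, such as the volatile variables, have to be silently dropped. A rough sketch of the behavior it needs to provide here (illustrative re-implementation, not the actual pymc.util code):

    import functools

    def point_wrapper_sketch(core_function):
        # Names of the compiled aesara function's inputs
        input_names = {getattr(v, "name", None) for v in core_function.maker.fgraph.inputs}

        @functools.wraps(core_function)
        def wrapped(**point):
            # Keep only the entries the compiled function actually takes
            return core_function(**{k: v for k, v in point.items() if k in input_names})

        return wrapped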
@@ -2063,17 +2206,13 @@ def sample_prior_predictive(
        names.append(rv_var.name)
        vars_to_sample.append(rv_var)

    inputs = [i for i in inputvars(vars_to_sample) if not isinstance(i, (Constant, SharedVariable))]

    if compile_kwargs is None:
        compile_kwargs = {}
    compile_kwargs.setdefault("allow_input_downcast", True)
    compile_kwargs.setdefault("accept_inplace", True)

    sampler_fn = compile_pymc(
        inputs,
        vars_to_sample,
        allow_input_downcast=True,
        accept_inplace=True,
        **compile_kwargs,
    )
    sampler_fn = compile_forward_sampling_function(
        vars_to_sample, vars_in_trace=[], givens_dict=None, **compile_kwargs
    )

    values = zip(*(sampler_fn() for i in range(samples)))
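
Since vars_in_trace is empty here, every random variable in the graph is volatile, the compiled function takes no inputs, and each call produces one fresh joint draw from the prior. End-to-end usage is unchanged, e.g.:

    import pymc as pm

    with pm.Model():
        a = pm.Normal("a", 0.0, 1.0)
        b = pm.Normal("b", mu=a, sigma=1.0)
        prior = pm.sample_prior_predictive(samples=100)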