Track valued/bound variables using an in-graph Op
brandonwillard committed Dec 4, 2022
1 parent d0c009d commit 03dc97e
Showing 16 changed files with 382 additions and 384 deletions.
55 changes: 55 additions & 0 deletions aeppl/abstract.py
@@ -3,11 +3,14 @@
from functools import singledispatch
from typing import Callable, List, Tuple

import aesara.tensor as at
from aesara.gradient import grad_undefined
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import Op
from aesara.graph.utils import MetaType
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.type import TensorType


class MeasurableVariable(abc.ABC):
@@ -134,3 +137,55 @@ def __init__(self, scalar_op, *args, **kwargs):


MeasurableVariable.register(MeasurableElemwise)


class ValuedVariable(Op):
r"""Represents the association of a measurable variable and its value.
A `ValuedVariable` node represents the pair :math:`(Y, y)`, where
:math:`Y` is a random variable and :math:`y \sim Y`.
Log-probability (densities) are functions over these pairs, which makes
these nodes in a graph an intermediate form that serves to construct a
log-probability from a model graph.
This intermediate form can be used as the target for rewrites that
otherwise wouldn't make sense to apply to--say--a random variable node
directly. An example is `BroadcastTo` lifting through `RandomVariable`\s.
"""

default_output = 0
view_map = {0: [0]}
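# Note: `default_output`/`view_map` above mark output 0 as a view of
# input 0, so evaluating the node simply forwards the random variable's
# value without copying it.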

def make_node(self, rv, value):

assert isinstance(rv.type, TensorType)
out_rv = rv.type()

vv = at.as_tensor_variable(value)
assert isinstance(vv.type, TensorType)

# TODO: We should probably check the `Type`s of `out_rv` and `vv`
if vv.type.dtype != rv.type.dtype:
raise TypeError(
f"Value type {vv.type} does not match random variable type {out_rv.type}"
)

return Apply(self, [rv, vv], [out_rv])

def perform(self, node, inputs, out):
out[0][0] = inputs[0]

def grad(self, inputs, outputs):
return [
grad_undefined(self, k, inp, "No gradient defined for `ValuedVariable`")
for k, inp in enumerate(inputs)
]

def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]


MeasurableVariable.register(ValuedVariable)

valued_variable = ValuedVariable()
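For illustration only (not part of the commit), a minimal sketch of how the new `Op` pairs a random variable with a value variable; the names `Y`, `y`, and `Y_vv` are hypothetical:

import aesara.tensor as at

from aeppl.abstract import valued_variable

# A normal random variable and a value variable with a matching dtype.
Y = at.random.normal(0.0, 1.0, name="Y")
y = at.scalar("y")

# Pair them inside the graph; the result has `Y`'s type and is a view of `Y`.
Y_vv = valued_variable(Y, y)
assert Y_vv.owner.op is valued_variable
assert Y_vv.type == Y.type

Since `ValuedVariable` is registered as a `MeasurableVariable`, downstream rewrites can treat `Y_vv` as measurable while seeing, in the graph itself, that it is already bound to a value.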
13 changes: 3 additions & 10 deletions aeppl/censoring.py
@@ -13,6 +13,7 @@
from aeppl.abstract import (
MeasurableElemwise,
MeasurableVariable,
ValuedVariable,
assign_custom_measurable_outputs,
)
from aeppl.logprob import CheckParameterValue, _logcdf, _logprob, logdiffexp
@@ -34,10 +35,6 @@ def find_measurable_clips(
) -> Optional[List[MeasurableClip]]:
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

if isinstance(node.op, MeasurableClip):
return None # pragma: no cover

@@ -50,7 +47,7 @@
if not (
base_var.owner
and isinstance(base_var.owner.op, MeasurableVariable)
and base_var not in rv_map_feature.rv_values
and not isinstance(base_var.owner.op, ValuedVariable)
):
return None

@@ -155,10 +152,6 @@ def find_measurable_roundings(
fgraph: FunctionGraph, node: Node
) -> Optional[List[MeasurableRound]]:

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

if isinstance(node.op, MeasurableRound):
return None # pragma: no cover

@@ -174,7 +167,7 @@
if not (
base_var.owner
and isinstance(base_var.owner.op, MeasurableVariable)
and base_var not in rv_map_feature.rv_values
and not isinstance(base_var.owner.op, ValuedVariable)
# Rounding only makes sense for continuous variables
and base_var.dtype.startswith("float")
):
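Both `find_measurable_clips` and `find_measurable_roundings` now detect already-valued variables structurally, rather than consulting the removed `preserve_rv_mappings` feature. A hedged sketch of the shared guard (the helper name `is_unvalued_measurable` is hypothetical, not part of the commit):

from aeppl.abstract import MeasurableVariable, ValuedVariable


def is_unvalued_measurable(var):
    # These rewrites only apply to a base variable that is produced by a
    # measurable node and has not already been paired with a value
    # variable via a `ValuedVariable` node.
    return (
        var.owner is not None
        and isinstance(var.owner.op, MeasurableVariable)
        and not isinstance(var.owner.op, ValuedVariable)
    )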
17 changes: 7 additions & 10 deletions aeppl/cumsum.py
@@ -4,9 +4,13 @@
from aesara.graph.rewriting.basic import node_rewriter
from aesara.tensor.extra_ops import CumOp

from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
from aeppl.abstract import (
MeasurableVariable,
ValuedVariable,
assign_custom_measurable_outputs,
)
from aeppl.logprob import _logprob, logprob
from aeppl.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from aeppl.rewriting import measurable_ir_rewrites_db


class MeasurableCumsum(CumOp):
@@ -50,20 +54,13 @@ def find_measurable_cumsums(fgraph, node) -> Optional[List[MeasurableCumsum]]:
if isinstance(node.op, MeasurableCumsum):
return None # pragma: no cover

rv_map_feature: Optional[PreserveRVMappings] = getattr(
fgraph, "preserve_rv_mappings", None
)

if rv_map_feature is None:
return None # pragma: no cover

rv = node.outputs[0]

base_rv = node.inputs[0]
if not (
base_rv.owner
and isinstance(base_rv.owner.op, MeasurableVariable)
and base_rv not in rv_map_feature.rv_values
and not isinstance(base_rv.owner.op, ValuedVariable)
):
return None # pragma: no cover

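As an illustrative sketch (assuming the rewrite above is registered, and using hypothetical variable names), a cumulative sum of an otherwise-unvalued measurable variable can be assigned its own log-probability:

import aesara.tensor as at

from aeppl.joint_logprob import conditional_logprob

# A vector of iid normals and its cumulative sum.
X = at.random.normal(0.0, 1.0, size=(3,), name="X")
Y = at.cumsum(X)
Y.name = "Y"

# `find_measurable_cumsums` rewrites `Y` into a `MeasurableCumsum`,
# from which a logprob term for `Y` can be derived.
logprobs, (y_vv,) = conditional_logprob(Y)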
156 changes: 44 additions & 112 deletions aeppl/joint_logprob.py
@@ -1,18 +1,12 @@
import warnings
from collections import deque
from typing import Dict, List, Optional, Tuple, Union

import aesara.tensor as at
from aesara import config
from aesara.graph.basic import graph_inputs, io_toposort
from aesara.graph.op import compute_test_value
from aesara.graph.rewriting.basic import GraphRewriter, NodeRewriter
from aesara.tensor.var import TensorVariable

from aeppl.abstract import get_measurable_outputs
from aeppl.abstract import ValuedVariable, get_measurable_outputs
from aeppl.logprob import _logprob
from aeppl.rewriting import construct_ir_fgraph
from aeppl.utils import rvs_to_value_vars


def conditional_logprob(
@@ -22,7 +16,7 @@ def conditional_logprob(
ir_rewriter: Optional[GraphRewriter] = None,
extra_rewrites: Optional[Union[GraphRewriter, NodeRewriter]] = None,
**kwargs,
) -> Tuple[Dict[TensorVariable, TensorVariable], List[TensorVariable]]:
) -> Tuple[Dict[TensorVariable, TensorVariable], Tuple[TensorVariable, ...]]:
r"""Create a map between random variables and their conditional log-probabilities.
The list of measurable variables implicitly defines a joint probability that
@@ -106,133 +100,71 @@
# graphs. We can thus use them to recover the original random variables to index the
# maps to the logprob graphs and value variables before returning them.
rv_values = {**original_rv_values, **realized}
vv_to_original_rvs = {vv: rv for rv, vv in rv_values.items()}
# vv_to_original_rvs = {vv: rv for rv, vv in rv_values.items()}

fgraph, rv_values, _ = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter)
fgraph, _, memo = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter)

# The interface for transformations assumes that the value variables are in
# the transformed space. To get the correct `shape` and `dtype` for the
# value variables we return we need to apply the forward transformation to
# our RV copies, and return the type of the resulting variable as a value
# variable.
vv_remapper = {}
if extra_rewrites is not None:
extra_rewrites.add_requirements(fgraph, {**original_rv_values, **realized})
extra_rewrites.add_requirements(fgraph, rv_values)
extra_rewrites.apply(fgraph)
vv_remapper = fgraph.values_to_untransformed

rv_remapper = fgraph.preserve_rv_mappings

# This is the updated random-to-value-vars map with the lifted/rewritten
# variables. The rewrites are supposed to produce new
# `MeasurableVariable`s that are amenable to `_logprob`.
updated_rv_values = rv_remapper.rv_values

# Some rewrites also transform the original value variables. This is the
# updated map from the new value variables to the original ones, which
# we want to use as the keys in the final dictionary output
original_values = rv_remapper.original_values

# When a `_logprob` has been produced for a `MeasurableVariable` node, all
# other references to it need to be replaced with its value-variable all
# throughout the `_logprob`-produced graphs. The following `dict`
# cumulatively maintains remappings for all the variables/nodes that needed
# to be recreated after replacing `MeasurableVariable`s with their
# value-variables. Since these replacements work in topological order, all
# the necessary value-variable replacements should be present for each
# node.
replacements = updated_rv_values.copy()

# To avoid cloning the value variables, we map them to themselves in the
# `replacements` `dict` (i.e. entries already existing in `replacements`
# aren't cloned)
replacements.update({v: v for v in rv_values.values()})

# Walk the graph from its inputs to its outputs and construct the
# log-probability
q = deque(fgraph.toposort())

logprob_vars = {}
value_variables = {}

while q:
node = q.popleft()
for out, old_out in zip(fgraph.outputs, rv_values.keys()):
node = out.owner

outputs = get_measurable_outputs(node.op, node)
if not outputs:
continue

if any(o not in updated_rv_values for o in outputs):
if warn_missing_rvs:
warnings.warn(
"Found a random variable that is not assigned a value variable: "
f"{node.outputs}"
)
continue

q_value_vars = [replacements[q_rv_var] for q_rv_var in outputs]

if not q_value_vars:
continue

# Replace `RandomVariable`s in the inputs with value variables.
# Also, store the results in the `replacements` map for the nodes
# that follow.
remapped_vars, _ = rvs_to_value_vars(
q_value_vars + list(node.inputs),
initial_replacements=replacements,
)
q_value_vars = remapped_vars[: len(q_value_vars)]
q_rv_inputs = remapped_vars[len(q_value_vars) :]

q_logprob_vars = _logprob(
node.op,
q_value_vars,
*q_rv_inputs,
**kwargs,
)
assert isinstance(node.op, ValuedVariable)

if not isinstance(q_logprob_vars, (list, tuple)):
q_logprob_vars = [q_logprob_vars]
rv_var, val_var = node.inputs

for q_value_var, q_logprob_var in zip(q_value_vars, q_logprob_vars):
rv_node = rv_var.owner
outputs = get_measurable_outputs(rv_node.op, rv_node)

q_value_var = original_values[q_value_var]
q_rv = vv_to_original_rvs[q_value_var]

if q_rv.name:
q_logprob_var.name = f"{q_rv.name}_logprob"
if not outputs:
raise ValueError(f"Couldn't derive a log-probability for {out}")

# TODO: This probably needs to be done outside of this loop.
# if warn_missing_rvs:
# warnings.warn(
# "Found a random variable that is not assigned a value variable: "
# f"{node.outputs}"
# )
rv_logprob = _logprob(
rv_node.op,
[val_var],
*rv_node.inputs,
**kwargs,
)

if q_rv in logprob_vars:
raise ValueError(
f"More than one logprob factor was assigned to the random variable {q_rv}"
)
if isinstance(rv_logprob, (tuple, list)):
(rv_logprob,) = rv_logprob

logprob_vars[q_rv] = q_logprob_var
if old_out.name:
rv_logprob.name = f"{old_out.name}_logprob"

q_value_var = vv_remapper.get(q_value_var, q_value_var)
value_variables[q_rv] = q_value_var
logprob_vars[old_out] = rv_logprob

# Recompute test values for the changes introduced by the
# replacements above.
if config.compute_test_value != "off":
for node in io_toposort(graph_inputs(q_logprob_vars), q_logprob_vars):
compute_test_value(node)
# # Recompute test values for the changes introduced by the
# # replacements above.
# if config.compute_test_value != "off":
# for node in io_toposort(graph_inputs([rv_logprob]), q_logprob_vars):
# compute_test_value(node)

missing_value_terms = set(vv_to_original_rvs.values()) - set(logprob_vars.keys())
if missing_value_terms:
raise RuntimeError(
f"The logprob terms of the following random variables could not be derived: {missing_value_terms}"
)
# missing_value_terms = set(vv_to_original_rvs.values()) - set(logprob_vars.keys())
# if missing_value_terms:
# raise RuntimeError(
# f"The logprob terms of the following random variables could not be derived: {missing_value_terms}"
# )

return logprob_vars, [value_variables[rv] for rv in original_rv_values.keys()]
value_vars = tuple(memo[vv] for rv, vv in rv_values.items() if rv not in realized)
return logprob_vars, value_vars


def joint_logprob(
*random_variables: List[TensorVariable],
realized: Dict[TensorVariable, TensorVariable] = {},
**kwargs,
) -> Optional[Tuple[TensorVariable, List[TensorVariable]]]:
) -> Optional[Tuple[TensorVariable, Tuple[TensorVariable, ...]]]:
"""Create a graph representing the joint log-probability/measure of a graph.
This function calls `factorized_joint_logprob` and returns the combined
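A hedged usage sketch of the updated interface (illustrative only; the variable names are hypothetical):

import aesara.tensor as at

from aeppl.joint_logprob import joint_logprob

mu = at.random.normal(0.0, 10.0, name="mu")
Y = at.random.normal(mu, 1.0, name="Y")
y = at.scalar("y")

# `Y` is realized at the given value `y`; `mu` is assigned a fresh value
# variable. Value variables now come back as a tuple, one per
# non-realized random variable, in the order requested.
logp, (mu_vv,) = joint_logprob(mu, Y, realized={Y: y})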