Skip to content

Commit

Permalink
Allow Scan value transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 28, 2022
1 parent afe57f8 commit d0c676c
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 5 deletions.
2 changes: 0 additions & 2 deletions pymc/logprob/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,6 @@ def find_measurable_scans(fgraph, node):

# We're going to set those values on our `new_val_var` so that it can
# serve as a complete replacement for the old input `outer_input_var`.
# from aesara.graph import clone_replace
#
new_val_var = outer_input_var.owner.clone_with_new_inputs(
[new_val_var] + outer_input_var.owner.inputs[1:]
).default_output()
Expand Down
92 changes: 89 additions & 3 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,13 @@
import aesara.tensor as at

from aesara.gradient import DisconnectedType, jacobian
from aesara.graph.basic import Apply, Node, Variable
from aesara.graph.basic import Apply, Node, Variable, clone_replace
from aesara.graph.features import AlreadyThere, Feature
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from aesara.scalar import Add, Exp, Log, Mul
from aesara.scan.op import Scan
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.rewriting.basic import (
register_specialize,
Expand Down Expand Up @@ -219,11 +220,94 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
return trans_node.outputs


@node_rewriter(tracks=[Scan])
def transform_scan_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
    """Apply transforms to Scan value variables.

    This specialized rewrite is needed because Scan replaces the original value variables
    by a more complex graph. We want to apply the transform to the original value variable
    in this subgraph, leaving the rest intact.

    Parameters
    ----------
    fgraph
        The graph being rewritten. It must carry the ``preserve_rv_mappings`` and
        ``values_to_transforms`` attributes for the rewrite to do anything.
    node
        The ``Scan`` node under consideration.

    Returns
    -------
    The outputs of a clone of ``node`` whose ``op`` has been replaced by the result of
    ``_create_transformed_rv_op``, or ``None`` when the required attributes are missing
    from ``fgraph``, no output of ``node`` has a value variable, or no value variable
    has a transform.
    """

    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
    values_to_transforms: Optional[TransformValuesMapping] = getattr(
        fgraph, "values_to_transforms", None
    )

    if rv_map_feature is None or values_to_transforms is None:
        return None  # pragma: no cover

    # Collect the outputs of this node that have a value variable, keeping the
    # two lists aligned so they can be zipped below.
    rv_vars = []
    value_vars = []

    for out in node.outputs:
        value = rv_map_feature.rv_values.get(out, None)
        if value is None:
            continue
        rv_vars.append(out)
        value_vars.append(value)

    if not value_vars:
        return None

    # Transforms are keyed by the *original* value variables, so each current value
    # variable is first mapped back through `original_values`. The lookup must use the
    # comprehension variable `value_var`; using the leftover loop variable `value`
    # would assign the last value variable's transform to every output.
    transforms = [
        values_to_transforms.get(rv_map_feature.original_values[value_var], None)
        for value_var in value_vars
    ]

    if all(transform is None for transform in transforms):
        return None

    new_op = _create_transformed_rv_op(node.op, transforms)
    trans_node = node.clone()
    trans_node.op = new_op

    # We now assume that the old value variable represents the *transformed space*.
    # This means that we need to replace all instances of the old value variable
    # with "inversely/un-" transformed versions of itself.
    for rv_var, value_var, transform in zip(rv_vars, value_vars, transforms):
        rv_var_out_idx = node.outputs.index(rv_var)
        trans_node.outputs[rv_var_out_idx].name = rv_var.name

        if transform is None:
            continue

        # We access the original value variable and apply the transform to that
        original_value_var = rv_map_feature.original_values[value_var]
        trans_original_value_var = transform.backward(original_value_var, *trans_node.inputs)

        # We then replace the reference to the original value variable in the scan value
        # variable by the back-transform projection computed above

        # The first input corresponds to the original value variable. We are careful to
        # only clone_replace that part of the graph, as we don't want to break the
        # mappings between other rvs that are likely to be present in the rest of the
        # scan value variable graph
        # TODO: Is it true that the original value only appears in the first input
        #  and that no other RV can appear there?
        (trans_original_value_var,) = clone_replace(
            (value_var.owner.inputs[0],),
            replace={original_value_var: trans_original_value_var},
        )
        trans_value_var = value_var.owner.clone_with_new_inputs(
            inputs=[trans_original_value_var] + value_var.owner.inputs[1:]
        ).default_output()

        new_value_var = transformed_variable(trans_value_var, original_value_var)

        # Only name the new value variable when both pieces of the name are available.
        if value_var.name and getattr(transform, "name", None):
            new_value_var.name = f"{value_var.name}_{transform.name}"

        rv_map_feature.update_rv_maps(rv_var, new_value_var, trans_node.outputs[rv_var_out_idx])

    return trans_node.outputs


class TransformValuesMapping(Feature):
r"""A `Feature` that maintains a map between value variables and their transforms."""

def __init__(self, values_to_transforms):
self.values_to_transforms = values_to_transforms
self.values_to_transforms = values_to_transforms.copy()

def on_attach(self, fgraph):
if hasattr(fgraph, "values_to_transforms"):
Expand All @@ -236,6 +320,7 @@ class TransformValuesRewrite(GraphRewriter):
r"""Transforms value variables according to a map and/or per-`RandomVariable` defaults."""

default_transform_rewrite = in2out(transform_values, ignore_newtrees=True)
scan_transform_rewrite = in2out(transform_scan_values, ignore_newtrees=True)

def __init__(
self,
Expand All @@ -261,7 +346,8 @@ def add_requirements(self, fgraph):
fgraph.attach_feature(values_transforms_feature)

def apply(self, fgraph: FunctionGraph):
return self.default_transform_rewrite.rewrite(fgraph)
self.default_transform_rewrite.rewrite(fgraph)
self.scan_transform_rewrite.rewrite(fgraph)


class MeasurableTransform(MeasurableElemwise):
Expand Down
53 changes: 53 additions & 0 deletions pymc/tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,3 +801,56 @@ def test_invalid_broadcasted_transform_rv_fails():
logp = joint_logprob({y_rv: y_vv})
logp.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]})
assert False, "Should have failed before"


def test_scan_transform():
    """Test that Scan valued variables can be transformed"""

    init = at.random.beta(1, 1, name="init")
    init_vv = init.clone()

    innov, _ = aesara.scan(
        fn=lambda prev_innov: at.random.beta(prev_innov * 10, (1 - prev_innov) * 10),
        outputs_info=[init],
        n_steps=4,
    )
    innov.name = "innov"
    innov_vv = innov.clone()

    tr = TransformValuesRewrite(
        {
            init_vv: LogOddsTransform(),
            innov_vv: LogOddsTransform(),
        }
    )
    logp = factorized_joint_logprob(
        {init: init_vv, innov: innov_vv}, extra_rewrites=tr, use_jacobian=True
    )[innov_vv]
    logp_fn = aesara.function([init_vv, innov_vv], logp, on_unused_input="ignore")

    # Create an unrolled scan graph as reference
    innov = []
    prev_innov = init
    for i in range(4):
        # Interpolate the step index so each unrolled variable gets a distinct name
        next_innov = at.random.beta(prev_innov * 10, (1 - prev_innov) * 10, name=f"innov[{i}]")
        innov.append(next_innov)
        prev_innov = next_innov
    innov = at.stack(innov)
    innov.name = "innov"

    tr = TransformValuesRewrite(
        {
            init_vv: LogOddsTransform(),
            innov_vv: LogOddsTransform(),
        }
    )
    ref_logp = factorized_joint_logprob(
        {init: init_vv, innov: innov_vv}, extra_rewrites=tr, use_jacobian=True
    )[innov_vv]
    ref_logp_fn = aesara.function([init_vv, innov_vv], ref_logp, on_unused_input="ignore")

    # The compiled functions are called with keyword arguments, so the keys must match
    # the names of the value variables ("init" and "innov").
    test_point = {
        "init": np.array(-0.5),
        "innov": np.full((4,), -0.5),
    }
    np.testing.assert_allclose(logp_fn(**test_point), ref_logp_fn(**test_point))

0 comments on commit d0c676c

Please sign in to comment.