Skip to content

Commit

Permalink
Automatically retrieve updates from OpFromGraph nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Apr 8, 2024
1 parent 937e5fd commit f0390d4
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 41 deletions.
19 changes: 3 additions & 16 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from pymc.printing import str_for_dist
from pymc.pytensorf import (
collect_default_updates,
collect_default_updates_inner_fgraph,
constant_fold,
convert_observed_data,
floatX,
Expand Down Expand Up @@ -300,14 +301,14 @@ def __init__(
kwargs.setdefault("inline", True)
super().__init__(*args, **kwargs)

def update(self, node: Node):
def update(self, node: Node) -> dict[Variable, Variable]:
"""Symbolic update expression for input random state variables
Returns a dictionary with the symbolic expressions required for correct updating
of random state input variables repeated function evaluations. This is used by
`pytensorf.compile_pymc`.
"""
return {}
return collect_default_updates_inner_fgraph(node)

def batch_ndim(self, node: Node) -> int:
"""Number of dimensions of the distribution's batch shape."""
Expand Down Expand Up @@ -705,20 +706,6 @@ class CustomSymbolicDistRV(SymbolicRandomVariable):

_print_name = ("CustomSymbolicDist", "\\operatorname{CustomSymbolicDist}")

def update(self, node: Node):
op = node.op
inner_updates = collect_default_updates(
inputs=op.inner_inputs, outputs=op.inner_outputs, must_be_shared=False
)

# Map inner updates to outer inputs/outputs
updates = {}
for rng, update in inner_updates.items():
inp_idx = op.inner_inputs.index(rng)
out_idx = op.inner_outputs.index(update)
updates[node.inputs[inp_idx]] = node.outputs[out_idx]
return updates


@_support_point.register(CustomSymbolicDistRV)
def dist_support_point(op, rv, *args):
Expand Down
30 changes: 28 additions & 2 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@

from pytensor import scalar
from pytensor.compile import Function, Mode, get_mode
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import grad
from pytensor.graph import Type, rewrite_graph
from pytensor.graph.basic import (
Apply,
Constant,
Node,
Variable,
clone_get_equiv,
graph_inputs,
Expand Down Expand Up @@ -781,6 +783,23 @@ def reseed_rngs(
rng.set_value(new_rng, borrow=True)


def collect_default_updates_inner_fgraph(node: Node) -> dict[Variable, Variable]:
"""Collect default updates from node with inner fgraph."""
op = node.op
inner_updates = collect_default_updates(
inputs=op.inner_inputs, outputs=op.inner_outputs, must_be_shared=False
)

# Map inner updates to outer inputs/outputs
updates = {}
for rng, update in inner_updates.items():
inp_idx = op.inner_inputs.index(rng)
out_idx = op.inner_outputs.index(update)
updates[node.inputs[inp_idx]] = node.outputs[out_idx]

return updates


def collect_default_updates(
outputs: Sequence[Variable],
*,
Expand Down Expand Up @@ -874,9 +893,16 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
f"No update found for at least one RNG used in Scan Op {client.op}.\n"
"You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically."
)
elif isinstance(client.op, OpFromGraph):
try:
next_rng = collect_default_updates_inner_fgraph(client)[rng]
except (ValueError, KeyError):
raise ValueError(
f"No update found for at least one RNG used in OpFromGraph Op {client.op}.\n"
"You can use `pytensorf.collect_default_updates` and include those updates as outputs."
)
else:
# We don't know how this RNG should be updated (e.g., OpFromGraph).
# The user should provide an update manually
# We don't know how this RNG should be updated. The user should provide an update manually
return None

# Recurse until we find final update for RNG
Expand Down
37 changes: 36 additions & 1 deletion tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from pymc.exceptions import BlockModelAccessError
from pymc.logprob.basic import conditional_logp, logcdf, logp
from pymc.model import Deterministic, Model
from pymc.pytensorf import collect_default_updates
from pymc.pytensorf import collect_default_updates, compile_pymc
from pymc.sampling import draw, sample
from pymc.testing import (
BaseTestDistributionRandom,
Expand Down Expand Up @@ -791,6 +791,41 @@ class TestInlinedSymbolicRV(SymbolicRandomVariable):
x_inline = TestInlinedSymbolicRV([], [Flat.dist()], ndim_supp=0)()
assert np.isclose(logp(x_inline, 0).eval(), 0)

def test_default_update(self):
"""Test SymbolicRandomVariable Op default to updates from inner graph."""

class SymbolicRVDefaultUpdates(SymbolicRandomVariable):
pass

class SymbolicRVCustomUpdates(SymbolicRandomVariable):
def update(self, node):
return {}

rng = pytensor.shared(np.random.default_rng())
dummy_rng = rng.type()
dummy_next_rng, dummy_x = pt.random.normal(rng=dummy_rng).owner.outputs

# Check that default updates work
next_rng, x = SymbolicRVDefaultUpdates(
inputs=[dummy_rng],
outputs=[dummy_next_rng, dummy_x],
ndim_supp=0,
)(rng)
fn = compile_pymc(inputs=[], outputs=x, random_seed=431)
assert fn() != fn()

# Check that custom updates are respected, by using one that's broken
next_rng, x = SymbolicRVCustomUpdates(
inputs=[dummy_rng],
outputs=[dummy_next_rng, dummy_x],
ndim_supp=0,
)(rng)
with pytest.raises(
ValueError,
match="No update found for at least one RNG used in SymbolicRandomVariable Op SymbolicRVCustomUpdates",
):
compile_pymc(inputs=[], outputs=x, random_seed=431)


def test_tag_future_warning_dist():
# Test no unexpected warnings
Expand Down
38 changes: 16 additions & 22 deletions tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,28 +408,6 @@ def test_compile_pymc_updates_inputs(self):
# Each RV adds a shared output for its rng
assert len(fn_fgraph.outputs) == 1 + rvs_in_graph

def test_compile_pymc_symbolic_rv_update(self):
"""Test that SymbolicRandomVariable Op update methods are used by compile_pymc"""

class NonSymbolicRV(OpFromGraph):
def update(self, node):
return {node.inputs[0]: node.outputs[0]}

rng = pytensor.shared(np.random.default_rng())
dummy_rng = rng.type()
dummy_next_rng, dummy_x = NonSymbolicRV(
[dummy_rng], pt.random.normal(rng=dummy_rng).owner.outputs
)(rng)

# Check that there are no updates at first
fn = compile_pymc(inputs=[], outputs=dummy_x)
assert fn() == fn()

# And they are enabled once the Op is registered as a SymbolicRV
SymbolicRandomVariable.register(NonSymbolicRV)
fn = compile_pymc(inputs=[], outputs=dummy_x, random_seed=431)
assert fn() != fn()

def test_compile_pymc_symbolic_rv_missing_update(self):
"""Test that error is raised if SymbolicRandomVariable Op does not
provide rule for updating RNG"""
Expand Down Expand Up @@ -588,6 +566,22 @@ def step_wo_update(x, rng):
fn = compile_pymc([], ys, random_seed=1)
assert not (set(fn()) & set(fn()))

def test_op_from_graph_updates(self):
rng = pytensor.shared(np.random.default_rng())
next_rng_, x_ = pt.random.normal(size=(10,), rng=rng).owner.outputs

x = OpFromGraph([], [x_])()
with pytest.raises(
ValueError,
match="No update found for at least one RNG used in OpFromGraph Op",
):
collect_default_updates([x])

next_rng, x = OpFromGraph([], [next_rng_, x_])()
assert collect_default_updates([x]) == {rng: next_rng}
fn = compile_pymc([], x, random_seed=1)
assert not (set(fn()) & set(fn()))


def test_replace_rng_nodes():
rng = pytensor.shared(np.random.default_rng())
Expand Down

0 comments on commit f0390d4

Please sign in to comment.