diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index d706a0e4422..181ad72fc7a 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -61,6 +61,7 @@
 from pymc.printing import str_for_dist
 from pymc.pytensorf import (
     collect_default_updates,
+    collect_default_updates_inner_fgraph,
     constant_fold,
     convert_observed_data,
     floatX,
@@ -300,14 +301,14 @@ def __init__(
         kwargs.setdefault("inline", True)
         super().__init__(*args, **kwargs)
 
-    def update(self, node: Node):
+    def update(self, node: Node) -> dict[Variable, Variable]:
         """Symbolic update expression for input random state variables
 
         Returns a dictionary with the symbolic expressions required for correct updating
         of random state input variables repeated function evaluations. This is used by
         `pytensorf.compile_pymc`.
         """
-        return {}
+        return collect_default_updates_inner_fgraph(node)
 
     def batch_ndim(self, node: Node) -> int:
         """Number of dimensions of the distribution's batch shape."""
@@ -705,20 +706,6 @@ class CustomSymbolicDistRV(SymbolicRandomVariable):
 
     _print_name = ("CustomSymbolicDist", "\\operatorname{CustomSymbolicDist}")
 
-    def update(self, node: Node):
-        op = node.op
-        inner_updates = collect_default_updates(
-            inputs=op.inner_inputs, outputs=op.inner_outputs, must_be_shared=False
-        )
-
-        # Map inner updates to outer inputs/outputs
-        updates = {}
-        for rng, update in inner_updates.items():
-            inp_idx = op.inner_inputs.index(rng)
-            out_idx = op.inner_outputs.index(update)
-            updates[node.inputs[inp_idx]] = node.outputs[out_idx]
-        return updates
-
 
 @_support_point.register(CustomSymbolicDistRV)
 def dist_support_point(op, rv, *args):
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index a68cd64d8a3..4f48f53534b 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -23,11 +23,13 @@
 
 from pytensor import scalar
 from pytensor.compile import Function, Mode, get_mode
+from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import grad
 from pytensor.graph import Type, rewrite_graph
 from pytensor.graph.basic import (
     Apply,
     Constant,
+    Node,
     Variable,
     clone_get_equiv,
     graph_inputs,
@@ -781,6 +783,23 @@ def reseed_rngs(
         rng.set_value(new_rng, borrow=True)
 
 
+def collect_default_updates_inner_fgraph(node: Node) -> dict[Variable, Variable]:
+    """Collect default updates from node with inner fgraph."""
+    op = node.op
+    inner_updates = collect_default_updates(
+        inputs=op.inner_inputs, outputs=op.inner_outputs, must_be_shared=False
+    )
+
+    # Map inner updates to outer inputs/outputs
+    updates = {}
+    for rng, update in inner_updates.items():
+        inp_idx = op.inner_inputs.index(rng)
+        out_idx = op.inner_outputs.index(update)
+        updates[node.inputs[inp_idx]] = node.outputs[out_idx]
+
+    return updates
+
+
 def collect_default_updates(
     outputs: Sequence[Variable],
     *,
@@ -874,9 +893,16 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
                     f"No update found for at least one RNG used in Scan Op {client.op}.\n"
                     "You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically."
                 )
+            elif isinstance(client.op, OpFromGraph):
+                try:
+                    next_rng = collect_default_updates_inner_fgraph(client)[rng]
+                except (ValueError, KeyError):
+                    raise ValueError(
+                        f"No update found for at least one RNG used in OpFromGraph Op {client.op}.\n"
+                        "You can use `pytensorf.collect_default_updates` and include those updates as outputs."
+                    )
             else:
                 # We don't know how this RNG should be updated. The user should provide an update manually
                 return None
 
         # Recurse until we find final update for RNG
diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py
index bb43063be92..aa5808c00e9 100644
--- a/tests/distributions/test_distribution.py
+++ b/tests/distributions/test_distribution.py
@@ -52,7 +52,7 @@
 from pymc.exceptions import BlockModelAccessError
 from pymc.logprob.basic import conditional_logp, logcdf, logp
 from pymc.model import Deterministic, Model
-from pymc.pytensorf import collect_default_updates
+from pymc.pytensorf import collect_default_updates, compile_pymc
 from pymc.sampling import draw, sample
 from pymc.testing import (
     BaseTestDistributionRandom,
@@ -791,6 +791,41 @@ class TestInlinedSymbolicRV(SymbolicRandomVariable):
         x_inline = TestInlinedSymbolicRV([], [Flat.dist()], ndim_supp=0)()
         assert np.isclose(logp(x_inline, 0).eval(), 0)
 
+    def test_default_update(self):
+        """Test SymbolicRandomVariable Op default to updates from inner graph."""
+
+        class SymbolicRVDefaultUpdates(SymbolicRandomVariable):
+            pass
+
+        class SymbolicRVCustomUpdates(SymbolicRandomVariable):
+            def update(self, node):
+                return {}
+
+        rng = pytensor.shared(np.random.default_rng())
+        dummy_rng = rng.type()
+        dummy_next_rng, dummy_x = pt.random.normal(rng=dummy_rng).owner.outputs
+
+        # Check that default updates work
+        next_rng, x = SymbolicRVDefaultUpdates(
+            inputs=[dummy_rng],
+            outputs=[dummy_next_rng, dummy_x],
+            ndim_supp=0,
+        )(rng)
+        fn = compile_pymc(inputs=[], outputs=x, random_seed=431)
+        assert fn() != fn()
+
+        # Check that custom updates are respected, by using one that's broken
+        next_rng, x = SymbolicRVCustomUpdates(
+            inputs=[dummy_rng],
+            outputs=[dummy_next_rng, dummy_x],
+            ndim_supp=0,
+        )(rng)
+        with pytest.raises(
+            ValueError,
+            match="No update found for at least one RNG used in SymbolicRandomVariable Op SymbolicRVCustomUpdates",
+        ):
+            compile_pymc(inputs=[], outputs=x, random_seed=431)
+
 
 def test_tag_future_warning_dist():
     # Test no unexpected warnings
diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py
index 9dcaaf94c38..bb294668eb5 100644
--- a/tests/test_pytensorf.py
+++ b/tests/test_pytensorf.py
@@ -408,28 +408,6 @@ def test_compile_pymc_updates_inputs(self):
         # Each RV adds a shared output for its rng
         assert len(fn_fgraph.outputs) == 1 + rvs_in_graph
 
-    def test_compile_pymc_symbolic_rv_update(self):
-        """Test that SymbolicRandomVariable Op update methods are used by compile_pymc"""
-
-        class NonSymbolicRV(OpFromGraph):
-            def update(self, node):
-                return {node.inputs[0]: node.outputs[0]}
-
-        rng = pytensor.shared(np.random.default_rng())
-        dummy_rng = rng.type()
-        dummy_next_rng, dummy_x = NonSymbolicRV(
-            [dummy_rng], pt.random.normal(rng=dummy_rng).owner.outputs
-        )(rng)
-
-        # Check that there are no updates at first
-        fn = compile_pymc(inputs=[], outputs=dummy_x)
-        assert fn() == fn()
-
-        # And they are enabled once the Op is registered as a SymbolicRV
-        SymbolicRandomVariable.register(NonSymbolicRV)
-        fn = compile_pymc(inputs=[], outputs=dummy_x, random_seed=431)
-        assert fn() != fn()
-
     def test_compile_pymc_symbolic_rv_missing_update(self):
         """Test that error is raised if SymbolicRandomVariable Op does not provide rule for updating RNG"""
 
@@ -588,6 +566,22 @@ def step_wo_update(x, rng):
         fn = compile_pymc([], ys, random_seed=1)
         assert not (set(fn()) & set(fn()))
 
+    def test_op_from_graph_updates(self):
+        rng = pytensor.shared(np.random.default_rng())
+        next_rng_, x_ = pt.random.normal(size=(10,), rng=rng).owner.outputs
+
+        x = OpFromGraph([], [x_])()
+        with pytest.raises(
+            ValueError,
+            match="No update found for at least one RNG used in OpFromGraph Op",
+        ):
+            collect_default_updates([x])
+
+        next_rng, x = OpFromGraph([], [next_rng_, x_])()
+        assert collect_default_updates([x]) == {rng: next_rng}
+        fn = compile_pymc([], x, random_seed=1)
+        assert not (set(fn()) & set(fn()))
+
 
 def test_replace_rng_nodes():
     rng = pytensor.shared(np.random.default_rng())
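
For reference, a minimal usage sketch of the behavior the new tests exercise (not part of the patch; it only relies on APIs that appear in the diff above: OpFromGraph, collect_default_updates, and compile_pymc). Returning the updated RNG as an extra output of an OpFromGraph is now enough for the default RNG update to be collected automatically:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.compile.builders import OpFromGraph

    from pymc.pytensorf import collect_default_updates, compile_pymc

    rng = pytensor.shared(np.random.default_rng())
    next_rng_, x_ = pt.random.normal(size=(10,), rng=rng).owner.outputs

    # The inner graph exposes both the draw and the next RNG state, so
    # collect_default_updates can map the update back to the shared RNG.
    next_rng, x = OpFromGraph([], [next_rng_, x_])()
    assert collect_default_updates([x]) == {rng: next_rng}

    # compile_pymc installs that update, so consecutive calls return fresh draws.
    fn = compile_pymc([], x, random_seed=1)
    assert not (set(fn()) & set(fn()))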