diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index 8626c20c68f..8f6942458ed 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -463,7 +463,9 @@ def find_measurable_scans(fgraph, node): # We must also replace any lingering references to the old RVs by the new measurable RVS # For example if we had measurable out1 = exp(normal()) and out2 = out1 - x # We need to replace references of original out1 by the new MeasurableExp(normal()) - inner_outs = node.op.inner_outputs.copy() + clone_fgraph = node.op.fgraph.clone() + inner_inps = clone_fgraph.inputs + inner_outs = clone_fgraph.outputs inner_rvs_replacements = [] for idx, new_inner_rv in zip(valued_output_idxs, inner_rvs, strict=True): old_inner_rv = inner_outs[idx] @@ -474,8 +476,7 @@ def find_measurable_scans(fgraph, node): clone=False, ) toposort_replace(temp_fgraph, inner_rvs_replacements) - inner_outs = temp_fgraph.outputs[: len(inner_outs)] - op = MeasurableScan(node.op.inner_inputs, inner_outs, node.op.info, mode=copy(node.op.mode)) + op = MeasurableScan(inner_inps, inner_outs, node.op.info, mode=copy(node.op.mode)) new_outs = op.make_node(*node.inputs).outputs old_outs = node.outputs diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py index 381eed221d1..17fb198ca21 100644 --- a/tests/logprob/test_scan.py +++ b/tests/logprob/test_scan.py @@ -550,3 +550,25 @@ def test_scan_multiple_output_types(): test_value, [a + b for a, b in itertools.pairwise([1, 1, *test_value[:-1]])] ), ) + + +def test_generative_graph_unchanged(): + # Regression test where creating the IR would overwrite the original Scan inner fgraph + + def step(eps_tm1): + x = pt.random.normal(0, eps_tm1) + eps_t = x - 0 + return (x, eps_t), {x.owner.inputs[0]: x.owner.outputs[0]} + + [xs, _], update = pytensor.scan(step, outputs_info=[None, pt.ones(())], n_steps=5) + + before = xs.dprint(file="str") + + xs_value = np.ones(5) + expected_logp = stats.norm.logpdf(xs_value, 0, 1) + for i in range(2): + xs_logp = logp(xs, xs_value) + np.testing.assert_allclose(xs_logp.eval(), expected_logp) + + after = xs.dprint(file="str") + assert before == after