Skip to content

Commit

Permalink
Do not mutate Scan inner graph when deriving logprob
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 14, 2024
1 parent e75cd73 commit 136dc1c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
7 changes: 4 additions & 3 deletions pymc/logprob/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,9 @@ def find_measurable_scans(fgraph, node):
# We must also replace any lingering references to the old RVs by the new measurable RVS
# For example if we had measurable out1 = exp(normal()) and out2 = out1 - x
# We need to replace references of original out1 by the new MeasurableExp(normal())
inner_outs = node.op.inner_outputs.copy()
clone_fgraph = node.op.fgraph.clone()
inner_inps = clone_fgraph.inputs
inner_outs = clone_fgraph.outputs
inner_rvs_replacements = []
for idx, new_inner_rv in zip(valued_output_idxs, inner_rvs, strict=True):
old_inner_rv = inner_outs[idx]
Expand All @@ -474,8 +476,7 @@ def find_measurable_scans(fgraph, node):
clone=False,
)
toposort_replace(temp_fgraph, inner_rvs_replacements)
inner_outs = temp_fgraph.outputs[: len(inner_outs)]
op = MeasurableScan(node.op.inner_inputs, inner_outs, node.op.info, mode=copy(node.op.mode))
op = MeasurableScan(inner_inps, inner_outs, node.op.info, mode=copy(node.op.mode))
new_outs = op.make_node(*node.inputs).outputs

old_outs = node.outputs
Expand Down
19 changes: 19 additions & 0 deletions tests/logprob/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,3 +550,22 @@ def test_scan_multiple_output_types():
test_value, [a + b for a, b in itertools.pairwise([1, 1, *test_value[:-1]])]
),
)


def test_generative_graph_unchanged():
# Regression test where creating the IR would overwrite the original Scan inner fgraph

def step(eps_tm1):
x = pt.random.normal(0, eps_tm1)
eps_t = x - 0
return (x, eps_t), {x.owner.inputs[0]: x.owner.outputs[0]}

[xs, _], update = pytensor.scan(step, outputs_info=[None, pt.ones(())], n_steps=5)
expected_logp = stats.norm.logpdf(1, 0, 1) * 5

before = xs.dprint(file="str")
for i in range(2):
xs_logp = logp(xs, pt.ones(5)).sum()
np.testing.assert_allclose(xs_logp.eval(), expected_logp)
after = xs.dprint(file="str")
assert before == after

0 comments on commit 136dc1c

Please sign in to comment.