diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 512f0ef3ab..baf6b4e381 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -2032,7 +2032,7 @@ def compare_nodes(nd_x, nd_y, common, different): def get_var_by_name( - graphs: Iterable[Variable], target_var_id: str, ids: str = "CHAR" + graphs: Iterable[Variable], target_var_id: str, include_inner_graphs: bool = False ) -> tuple[Variable, ...]: r"""Get variables in a graph using their names. @@ -2057,7 +2057,7 @@ def expand(r) -> list[Variable] | None: res = list(r.owner.inputs) - if isinstance(r.owner.op, HasInnerGraph): + if include_inner_graphs and isinstance(r.owner.op, HasInnerGraph): res.extend(r.owner.op.inner_outputs) return res diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 84ffb365b5..4eb9ba735a 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -628,6 +628,31 @@ def test_get_var_by_name(): assert res == exp_res +def test_get_var_by_name_include_inner_graphs_flag(): + r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) + o1 = MyOp(r1, r2) + o1.name = "o1" + + # Inner graph + igo_in_1 = MyVariable(4) + igo_in_2 = MyVariable(5) + igo_out_1 = MyOp(igo_in_1, igo_in_2) + igo_out_1.name = "igo1" + + igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1]) + o2 = igo(r3, o1) + + res = get_var_by_name([o1, o2], "igo1", include_inner_graphs=False) + assert ( + res == () + ), "Should not return inner graph variable when include_inner_graphs is False" + + res = get_var_by_name([o1, o2], "igo1", include_inner_graphs=True) + assert any( + v.name == "igo1" for v in res + ), "Should return inner graph variable when include_inner_graphs is True" + + def test_clone_new_inputs(): """Make sure that `Apply.clone_with_new_inputs` properly handles `Type` changes."""