From f46a2c9b9bbd6e0c97f653570b4c7b863a900c53 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Fri, 22 Aug 2025 13:10:07 +0200 Subject: [PATCH 1/2] [Rewriter] Prevent out of range when matching node outputs Trying to bind more outputs (from the pattern) than there are actual outputs of the candidate node now simply rejects the node before even trying to index into the list of node outputs. Signed-off-by: Christoph Berganski --- onnxscript/rewriter/_matcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/_matcher.py b/onnxscript/rewriter/_matcher.py index a007926c37..fe32ed1a1e 100644 --- a/onnxscript/rewriter/_matcher.py +++ b/onnxscript/rewriter/_matcher.py @@ -174,7 +174,7 @@ def _match_node(self, pattern_node: _pattern_ir.NodePattern, node: ir.Node) -> b return False for i, output_value_pattern in enumerate(pattern_node.outputs): - if not self._match.bind_value(output_value_pattern, node.outputs[i]): + if i >= len(node.outputs) or not self._match.bind_value(output_value_pattern, node.outputs[i]): return False return True From a79f117e762b43b519170f252e9ecbb1d3db94c3 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Tue, 26 Aug 2025 15:40:39 +0200 Subject: [PATCH 2/2] [Rewriter] Clarify rejecting nodes due to index out of range Signed-off-by: Christoph Berganski --- onnxscript/rewriter/_matcher.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/_matcher.py b/onnxscript/rewriter/_matcher.py index fe32ed1a1e..e8fc6a8a3d 100644 --- a/onnxscript/rewriter/_matcher.py +++ b/onnxscript/rewriter/_matcher.py @@ -174,7 +174,13 @@ def _match_node(self, pattern_node: _pattern_ir.NodePattern, node: ir.Node) -> b return False for i, output_value_pattern in enumerate(pattern_node.outputs): - if i >= len(node.outputs) or not self._match.bind_value(output_value_pattern, node.outputs[i]): + # When trying to bind more outputs (from the pattern) than there are + # actual outputs of the candidate node, reject the node before even + # trying to index into the list of node outputs. + if i >= len(node.outputs): + return False + + if not self._match.bind_value(output_value_pattern, node.outputs[i]): return False return True