diff --git a/onnxscript/rewriter/_matcher.py b/onnxscript/rewriter/_matcher.py index a007926c37..e8fc6a8a3d 100644 --- a/onnxscript/rewriter/_matcher.py +++ b/onnxscript/rewriter/_matcher.py @@ -174,6 +174,12 @@ def _match_node(self, pattern_node: _pattern_ir.NodePattern, node: ir.Node) -> b return False for i, output_value_pattern in enumerate(pattern_node.outputs): + # When trying to bind more outputs (from the pattern) than there are + # actual outputs of the candidate node, reject the node before even + # trying to index into the list of node outputs. + if i >= len(node.outputs): + return False + if not self._match.bind_value(output_value_pattern, node.outputs[i]): return False