diff --git a/onnxscript/rewriter/_matcher.py b/onnxscript/rewriter/_matcher.py index 4993fe8232..a007926c37 100644 --- a/onnxscript/rewriter/_matcher.py +++ b/onnxscript/rewriter/_matcher.py @@ -183,6 +183,16 @@ def _match_value( self, pattern_value: _pattern_ir.ValuePattern, value: ir.Value | None ) -> bool: """Match an IR value against a ValuePattern instance.""" + if value is not None and value.graph is not self._graph_or_function: + if not isinstance( + pattern_value, (_pattern_ir.Var, _pattern_ir.Constant, _pattern_ir.AnyValue) + ): + # If the pattern value is a Var, Constant, or AnyValue, we allow it to match + # values from other graphs. Otherwise, we fail the match. + return self.fail( + f"Value {value.name} is not in the graph {self._graph_or_function.name}. " + f"Pattern matches crossing graph boundaries are not supported." + ) if isinstance(pattern_value, _pattern_ir.AnyValue): return True @@ -352,6 +362,7 @@ def match( complications which require careful consideration. """ self._tracer = tracer + self._graph_or_function = graph_or_function[0].graph if self.pattern.has_single_output_node: self._init_match(verbose) return self._match_single_output_node( diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index a2ec410e5b..9481ca5077 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -736,6 +736,18 @@ def _apply_to_graph_or_function( count += 1 break + # Apply rewrite rules to subgraphs of the node. + for attr in node.attributes.values(): + if attr.type == ir.AttributeType.GRAPH: + count += self._apply_to_graph_or_function( + model, attr.value, verbose=verbose, tracer=tracer + ) + elif attr.type == ir.AttributeType.GRAPHS: + for graph in attr.value: + count += self._apply_to_graph_or_function( + model, graph, verbose=verbose, tracer=tracer + ) + for rule in self.rules: if rule.graph_post_visitor: rule.graph_post_visitor()