Skip to content

Commit ae4c668

Browse files
authored
Extend rewriter to handle subgraphs (#2494)
This extends the rewriter to apply fusions in nested subgraphs as well. It is currently limited to patterns that completely lie within a single graph (with special exceptions for constants/variables). --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 3af04e9 commit ae4c668

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

onnxscript/rewriter/_matcher.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,16 @@ def _match_value(
183183
self, pattern_value: _pattern_ir.ValuePattern, value: ir.Value | None
184184
) -> bool:
185185
"""Match an IR value against a ValuePattern instance."""
186+
if value is not None and value.graph is not self._graph_or_function:
187+
if not isinstance(
188+
pattern_value, (_pattern_ir.Var, _pattern_ir.Constant, _pattern_ir.AnyValue)
189+
):
190+
# If the pattern value is a Var, Constant, or AnyValue, we allow it to match
191+
# values from other graphs. Otherwise, we fail the match.
192+
return self.fail(
193+
f"Value {value.name} is not in the graph {self._graph_or_function.name}. "
194+
f"Pattern matches crossing graph boundaries are not supported."
195+
)
186196
if isinstance(pattern_value, _pattern_ir.AnyValue):
187197
return True
188198

@@ -352,6 +362,7 @@ def match(
352362
complications which require careful consideration.
353363
"""
354364
self._tracer = tracer
365+
self._graph_or_function = graph_or_function[0].graph
355366
if self.pattern.has_single_output_node:
356367
self._init_match(verbose)
357368
return self._match_single_output_node(

onnxscript/rewriter/_rewrite_rule.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,18 @@ def _apply_to_graph_or_function(
736736
count += 1
737737
break
738738

739+
# Apply rewrite rules to subgraphs of the node.
740+
for attr in node.attributes.values():
741+
if attr.type == ir.AttributeType.GRAPH:
742+
count += self._apply_to_graph_or_function(
743+
model, attr.value, verbose=verbose, tracer=tracer
744+
)
745+
elif attr.type == ir.AttributeType.GRAPHS:
746+
for graph in attr.value:
747+
count += self._apply_to_graph_or_function(
748+
model, graph, verbose=verbose, tracer=tracer
749+
)
750+
739751
for rule in self.rules:
740752
if rule.graph_post_visitor:
741753
rule.graph_post_visitor()

0 commit comments

Comments
 (0)