-
Notifications
You must be signed in to change notification settings - Fork 83
Expose match functionality of rewrite-rule by extracting base classes #2447
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
ba1535a
Initial plan
Copilot c0a7ab2
Implement PatternImpl and PatternBase base classes for standalone pat…
Copilot e280541
Add example demonstrating new pattern matching functionality
Copilot b8a63e2
Return match object instead of None when pattern doesn't match initially
Copilot b771457
Address review feedback: rename parameter and add utility match method
Copilot 4bef262
Simplify test assertion to use bool(match_result) for pattern matchin…
Copilot 1238b8e
Update onnxscript/rewriter/pattern_base_test.py
gramalingam 9c1bd34
Fix lint errors in pattern_base_test.py
Copilot acb6eff
Merge branch 'main' into copilot/fix-2446
titaiwangms 5f99316
Rename PatternImpl to CompiledPattern for better API naming
Copilot aa36d46
Remove create_compiled_pattern() method and store CompiledPattern int…
Copilot fadd8c6
Rename CompiledPattern to Pattern for better public API naming
Copilot f02c7d2
Run lintrunner
gramalingam e06f31a
Replace onnx.parser.parse_model with ir.from_onnx_text in tests
Copilot bd6eb6a
Fix PatternBase initialization order issue with lazy Pattern creation
Copilot File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""Example demonstrating the new pattern matching functionality.""" | ||
|
||
import onnx.parser | ||
|
||
|
||
from onnxscript import ir | ||
from onnxscript.rewriter import pattern | ||
|
||
|
||
def example_standalone_pattern_matching(): | ||
"""Example showing how to use Pattern for standalone pattern matching.""" | ||
|
||
print("=== Standalone Pattern Matching Example ===") | ||
|
||
# Define a pattern that matches Identity nodes | ||
def identity_pattern(op, x): | ||
return op.Identity(x) | ||
|
||
# Create a Pattern for standalone pattern matching (no replacement) | ||
pattern_matcher = pattern.Pattern(identity_pattern, name="IdentityMatcher") | ||
|
||
# Create a model with an Identity node | ||
model_proto = onnx.parser.parse_model( | ||
""" | ||
<ir_version: 7, opset_import: [ "" : 17]> | ||
agraph (float[N] x) => (float[N] z) | ||
{ | ||
z = Identity(x) | ||
} | ||
""" | ||
) | ||
model = ir.serde.deserialize_model(model_proto) | ||
|
||
# Find nodes to test pattern matching against | ||
for node in model.graph: | ||
print(f"Testing pattern against {node.op_type} node...") | ||
match_result = pattern_matcher.match(model, model.graph, node) | ||
|
||
if match_result is not None: | ||
print(f" ✓ Pattern matched! Found {len(match_result.nodes)} nodes in match.") | ||
print(f" Matched node: {match_result.nodes[0].op_type}") | ||
else: | ||
print(f" ✗ Pattern did not match {node.op_type} node.") | ||
|
||
|
||
def example_class_based_pattern(): | ||
"""Example showing how to use PatternBase for class-based pattern definition.""" | ||
|
||
print("\n=== Class-Based Pattern Example ===") | ||
|
||
class IdentityPatternClass(pattern.PatternBase): | ||
"""A class-based pattern that matches Identity nodes.""" | ||
|
||
def pattern(self, op, x): | ||
return op.Identity(x) | ||
|
||
def check(self, context, x): | ||
"""Custom condition - always succeeds for this example.""" | ||
print(f" Checking condition for input: {x}") | ||
return pattern.MatchResult() # Always succeeds | ||
|
||
# Create an instance of the pattern class | ||
identity_pattern_class = IdentityPatternClass(name="ClassBasedIdentity") | ||
|
||
# The Pattern is created internally, we can use the pattern directly | ||
print(f"Created pattern matcher: {identity_pattern_class.name}") | ||
|
||
# Use it directly with the match method | ||
model_proto = onnx.parser.parse_model( | ||
""" | ||
<ir_version: 7, opset_import: [ "" : 17]> | ||
agraph (float[N] x) => (float[N] z) | ||
{ | ||
z = Identity(x) | ||
} | ||
""" | ||
) | ||
model = ir.serde.deserialize_model(model_proto) | ||
|
||
for node in model.graph: | ||
if node.op_type == "Identity": | ||
print(f"Testing class-based pattern against {node.op_type} node...") | ||
match_result = identity_pattern_class.match(model, model.graph, node) | ||
|
||
if match_result is not None: | ||
print(" ✓ Class-based pattern matched!") | ||
else: | ||
print(" ✗ Class-based pattern did not match.") | ||
|
||
|
||
def example_rewrite_rule_still_works(): | ||
"""Example showing that existing RewriteRule functionality is preserved.""" | ||
|
||
print("\n=== Existing RewriteRule Still Works ===") | ||
|
||
def identity_pattern(op, x): | ||
return op.Identity(x) | ||
|
||
def identity_replacement(op, x): | ||
return op.Identity(x) # No-op replacement | ||
|
||
# Create a RewriteRule (which now inherits from Pattern) | ||
rule = pattern.RewriteRule(identity_pattern, identity_replacement, name="IdentityRule") | ||
|
||
print(f"Created rewrite rule: {rule.name}") | ||
print(f"Rule is also a Pattern: {isinstance(rule, pattern.Pattern)}") | ||
|
||
# The rule can be used both for pattern matching and rewriting | ||
model_proto = onnx.parser.parse_model( | ||
""" | ||
<ir_version: 7, opset_import: [ "" : 17]> | ||
agraph (float[N] x) => (float[N] z) | ||
{ | ||
z = Identity(x) | ||
} | ||
""" | ||
) | ||
model = ir.serde.deserialize_model(model_proto) | ||
|
||
# Use it for just pattern matching (inherited from Pattern) | ||
for node in model.graph: | ||
if node.op_type == "Identity": | ||
print(f"Using RewriteRule for pattern matching on {node.op_type}...") | ||
match_result = rule.match(model, model.graph, node) | ||
|
||
if match_result is not None: | ||
print(" ✓ RewriteRule matched as a pattern matcher!") | ||
|
||
# Use it for rewriting (original functionality) | ||
print("Using RewriteRule for rewriting...") | ||
count = rule.apply_to_model(model) | ||
print(f" Applied rule {count} times") | ||
|
||
|
||
if __name__ == "__main__": | ||
example_standalone_pattern_matching() | ||
example_class_based_pattern() | ||
example_rewrite_rule_still_works() | ||
print("\n=== All Examples Completed ===") |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.