You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Implement MatchContext class for rewriter pattern matching (#2455)
This PR introduces the `PatternMatchContext` class to provide context
information during pattern matching in the ONNX rewriter system.
## Changes Made
### Core Implementation
- **Added `PatternMatchContext` class** in
`onnxscript/rewriter/_basics.py` with read-only properties:
- `model`: The model being matched
- `graph_or_function`: The graph or function being matched
- `main_root_node`: The main root node of the matching subgraph
- `output_values`: The output values of the matching subgraph
- `nodes`: All nodes of the matching subgraph
- **Updated pattern matching logic** in
`onnxscript/rewriter/_rewrite_rule.py` at line 134 to create and pass
`PatternMatchContext` instances to condition functions
- **Exported the new class** in the rewriter module's `__all__` list for
external use
### Usage Example
```python
def condition_with_context(context, x, y):
# Access match context information
model = context.model
main_node = context.main_root_node
matched_nodes = context.nodes
outputs = context.output_values
# Use context for advanced pattern validation
if main_node.op_type == "Mul" and len(matched_nodes) > 1:
return True
return False
rule = pattern.RewriteRule(
target_pattern,
replacement_pattern,
condition_function=condition_with_context
)
```
### Testing
- **Comprehensive test suite** in
`onnxscript/rewriter/pattern_match_context_test.py` covering:
- Property access and type validation
- Read-only behavior enforcement
- Backward compatibility with existing condition functions
- Practical usage scenarios in real pattern matching
### Backward Compatibility
- All existing condition functions continue to work unchanged
- The `context` parameter is passed as the first argument, maintaining
the existing `**match.bindings` pattern
- No breaking changes to the existing API
## Validation
- All existing rewriter tests pass (39/39 tests in pattern-related
modules)
- New functionality validated with 4 comprehensive test cases
- Integration testing confirms proper context creation and usage
Fixes#2454.
<!-- START COPILOT CODING AGENT TIPS -->
---
💡 You can make Copilot smarter by setting up custom instructions,
customizing its development environment and configuring Model Context
Protocol (MCP) servers. Learn more [Copilot coding agent
tips](https://gh.io/copilot-coding-agent-tips) in the docs.
---------
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
# Using MatchContext for Advanced Condition Checking
55
+
56
+
The `context` parameter passed to condition functions is an instance of {py:class}`onnxscript.rewriter.MatchContext`, which provides access to additional information about the pattern match that can be useful for sophisticated condition checking.
57
+
58
+
## MatchContext Properties
59
+
60
+
The MatchContext provides the following read-only properties:
61
+
62
+
-`model`: The entire ONNX model being matched
63
+
-`graph_or_function`: The specific graph or function being matched
64
+
-`root`: The root node of the matching subgraph
65
+
-`output_values`: The output values of the matching subgraph
66
+
-`nodes`: All nodes that are part of the matching subgraph
67
+
68
+
## Example Usage
69
+
70
+
Here's an example showing how to use the MatchContext to implement more sophisticated condition checking:
71
+
72
+
```python
73
+
defadvanced_condition_check(context, x, y, **_):
74
+
"""Example condition function using MatchContext."""
75
+
76
+
# Access the main node of the pattern match
77
+
main_node = context.root
78
+
79
+
# Check that the main_node does not have an attribute called "alpha"
80
+
if"alpha"in main_node.attributes:
81
+
returnFalse
82
+
83
+
# Access the broader graph context and check that x occurs as a graph-input
84
+
model = context.model
85
+
if x notin model.graph.inputs:
86
+
returnFalse
87
+
88
+
# You can inspect the matched nodes for advanced validation
89
+
for node in context.nodes:
90
+
if node.op_type =="Constant":
91
+
# Check properties of constant nodes in the match
92
+
pass
93
+
94
+
# Access output values for shape/type validation
95
+
outputs = context.output_values
96
+
iflen(outputs) >0and outputs[0].shape isnotNone:
97
+
# Validate output shapes
98
+
pass
99
+
100
+
returnTrue
101
+
```
102
+
103
+
This context information enables condition functions to make decisions based on the broader graph structure, the specific nodes involved in the match, and relationships between matched patterns and the rest of the model.
0 commit comments