Skip to content

Commit 75c1a4d

Browse files
Copilotjustinchubygramalingam
authored
[docs] Document rewriter pattern options (#2406)
This PR adds comprehensive documentation for the rewriter pattern options that were previously undocumented. The rewriter pattern system supports four key options for controlling pattern matching and replacement behavior: ## New Documentation Added ### `_allow_other_inputs` option - **File**: `docs/tutorial/rewriter/allow_other_inputs.md` - **Purpose**: Controls whether patterns can match nodes with additional inputs beyond those specified - **Default**: `False` (exact input matching) - **Example**: Matching `Conv` operations that may have optional bias inputs ```python def conv_pattern(op, input, weight): # Matches Conv with 2 or 3 inputs (weight + optional bias) return op.Conv(input, weight, _allow_other_inputs=True) ``` ### `_domain` option - **File**: `docs/tutorial/rewriter/domain_option.md` - **Purpose**: Specifies operator domains for pattern matching and replacement - **Use cases**: Domain-specific rewrites, migrating between operator domains - **Example**: Targeting operations from specific domains like "com.microsoft" ```python def custom_relu_pattern(op, input): # Only matches Relu from custom domain return op.Relu(input, _domain="custom.domain") ``` ### `_outputs` option - **File**: `docs/tutorial/rewriter/outputs_option.md` - **Purpose**: Specifies number and names of operation outputs - **Formats**: Integer count (`_outputs=2`) or named list (`_outputs=["first", "second"]`) - **Example**: Handling multi-output operations like `Split` ```python def split_pattern(op, input): # Matches Split operations with exactly 2 outputs return op.Split(input, num_outputs=2, axis=0, _outputs=2) ``` ### Enhanced `_allow_other_attributes` documentation - **File**: `docs/tutorial/rewriter/attributes.md` (improved formatting) - **Already documented**: Controls whether patterns match nodes with additional attributes - **Default**: `True` (allows extra attributes) ## Documentation Structure Improvements - Added "Pattern Options" section to main rewriter documentation - Integrated all option docs into the tutorial flow - Created working code examples for each option - Followed existing documentation patterns and style - All examples compile and run successfully - Documentation builds correctly with Sphinx The documentation now provides complete coverage of all rewriter pattern options with practical examples showing real-world usage patterns. Fixes #2405. > [!WARNING] > > <details> > <summary>Firewall rules blocked me from connecting to one or more addresses</summary> > > #### I tried to connect to the following addresses, but was blocked by firewall rules: > > - `docs.python.org` > - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns block) > - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 ` (dns block) > - `docs.scipy.org` > - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns block) > - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 ` (dns block) > - `matplotlib.org` > - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns block) > - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 ` (dns block) > - `numpy.org` > - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns block) > - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 ` (dns block) > - `onnx.ai` > - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns block) > - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 ` (dns block) > - `onnxruntime.ai` > - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns block) > - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 ` (dns block) > - `pytorch.org` > - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns block) > - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 ` (dns block) > > If you need me to access, download, or install something from one of these locations, you can either: > > - Configure [Actions setup steps](https://gh.io/copilot/actions-setup-steps) to set up my environment, which run before the firewall is enabled > - Add the appropriate URLs or hosts to my [firewall allow list](https://gh.io/copilot/firewall-config) > > </details> <!-- START COPILOT CODING AGENT TIPS --> --- 💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click [here](https://survey.alchemer.com/s3/8343779/Copilot-Coding-agent) to start the survey. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
1 parent 38c4468 commit 75c1a4d

File tree

8 files changed

+362
-0
lines changed

8 files changed

+362
-0
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Specifying variable inputs in the pattern
2+
3+
This section demonstrates the use of the `_allow_other_inputs` option in pattern-based rewriting.
4+
The `_allow_other_inputs` option allows the pattern to match nodes that have additional inputs
5+
beyond those specified in the pattern. If it is set to `False` (the default), then the node must
6+
have exactly the specified inputs for a successful match. If set to `True`, the pattern will
7+
match nodes that have the specified inputs plus any number of additional inputs.
8+
9+
This is particularly useful when matching operations like `Conv` that can have optional inputs
10+
(such as bias), or when creating generic patterns that should work with various input configurations.
11+
12+
```{literalinclude} examples/allow_other_inputs.py
13+
:pyobject: conv_pattern
14+
```
15+
16+
```{literalinclude} examples/allow_other_inputs.py
17+
:pyobject: conv_replacement
18+
```
19+
20+
```{literalinclude} examples/allow_other_inputs.py
21+
:pyobject: apply_rewrite
22+
```
23+
24+
In this example, the pattern matches `Conv` operations with any number of inputs. A `Conv` operation
25+
might have 2 inputs (input and weight) or 3 inputs (input, weight, and bias). By setting
26+
`_allow_other_inputs=True`, our pattern will match both cases even though we only specify 2 inputs
27+
in the pattern definition.

docs/tutorial/rewriter/attributes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ This section demonstrates the use of attribute values in pattern-based rewriting
44
First, write a target pattern and replacement pattern in a similar way to the previous examples.
55
The example pattern below will match successfully only against Dropout nodes with the
66
attribute value `training_mode` set to `False`.
7+
78
The `_allow_other_attributes` option allows the pattern to match nodes that have additional attributes
89
not specified in the pattern. If it is set to `False`, then the node must have only the specified
910
attribute values, and no other attributes, for a successful match. The default value for this
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Specifying domains in the pattern
2+
3+
This section demonstrates the use of the `_domain` option in pattern-based rewriting.
4+
The `_domain` option allows you to specify which operator domain the pattern should match against,
5+
and also allows you to create replacement operations in specific domains.
6+
7+
ONNX operators can belong to different domains:
8+
- The default ONNX domain (empty string or "ai.onnx")
9+
- Custom domains like "com.microsoft" for Microsoft-specific operations
10+
- User-defined domains for custom operations
11+
12+
## Matching operations from a specific domain
13+
14+
```{literalinclude} examples/domain_option.py
15+
:pyobject: custom_relu_pattern
16+
```
17+
18+
In this pattern, `_domain="custom.domain"` ensures that only `Relu` operations from the
19+
"custom.domain" domain will be matched, not standard ONNX `Relu` operations.
20+
21+
## Creating replacement operations in a specific domain
22+
23+
```{literalinclude} examples/domain_option.py
24+
:pyobject: microsoft_relu_replacement
25+
```
26+
27+
Here, the replacement operation is created in the "com.microsoft" domain, which might
28+
provide optimized implementations of standard operations.
29+
30+
## Complete rewrite example
31+
32+
```{literalinclude} examples/domain_option.py
33+
:pyobject: apply_rewrite
34+
```
35+
36+
This example shows how domain-specific pattern matching can be used to migrate operations
37+
between different operator domains, such as replacing custom domain operations with
38+
standard ONNX operations or vice versa.
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""ONNX Pattern Rewriting with variable number of inputs
4+
5+
This script shows how to define a rewriting rule based on patterns that
6+
can match nodes with additional inputs beyond those specified in the pattern.
7+
"""
8+
9+
import onnx
10+
11+
import onnxscript
12+
from onnxscript import FLOAT, opset18, script
13+
from onnxscript.rewriter import pattern
14+
15+
16+
@script()
17+
def original_model(A: FLOAT[2, 2], B: FLOAT[2, 2], C: FLOAT[2, 2]) -> FLOAT[2, 2]:
18+
# Conv with bias - has 3 inputs: input, weight, bias
19+
result = opset18.Conv(A, B, C)
20+
return result
21+
22+
23+
_model = original_model.to_model_proto()
24+
onnx.checker.check_model(_model)
25+
26+
27+
####################################
28+
# The target pattern
29+
# =====================
30+
31+
32+
def conv_pattern(op, input, weight):
33+
# Pattern to match Conv operations, allowing additional inputs like bias
34+
# _allow_other_inputs=True allows the pattern to match Conv with bias (3 inputs)
35+
# even though we only specify 2 inputs in the pattern
36+
return op.Conv(input, weight, _allow_other_inputs=True)
37+
38+
39+
####################################
40+
# The replacement pattern
41+
# =====================
42+
43+
44+
def conv_replacement(op, input, weight, **_):
45+
# Replace with a custom operation in a different domain
46+
return op.OptimizedConv(input, weight, _domain="custom.domain")
47+
48+
49+
####################################
50+
# Create Rewrite Rule and Apply to Model
51+
# =====================
52+
53+
54+
def apply_rewrite(model):
55+
# Create rewrite rules
56+
conv_rule = pattern.RewriteRule(
57+
conv_pattern, # target pattern
58+
conv_replacement, # replacement pattern
59+
)
60+
# Create a Rewrite Rule Set
61+
rewrite_rule_set = pattern.RewriteRuleSet([conv_rule])
62+
# Apply rewrite
63+
model_with_rewrite = onnxscript.rewriter.rewrite(
64+
model,
65+
pattern_rewrite_rules=rewrite_rule_set,
66+
)
67+
return model_with_rewrite
68+
69+
70+
_model_with_rewrite = apply_rewrite(_model)
71+
onnx.checker.check_model(_model_with_rewrite)
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""ONNX Pattern Rewriting with domain specification
4+
5+
This script shows how to define a rewriting rule that targets operations
6+
from specific domains and replaces them with operations in other domains.
7+
"""
8+
9+
import onnx
10+
11+
import onnxscript
12+
from onnxscript import script
13+
from onnxscript.rewriter import pattern
14+
from onnxscript.values import Opset
15+
16+
# Create an opset for the custom domain
17+
opset = Opset("custom.domain", 1)
18+
19+
20+
@script(opset)
21+
def create_model_with_custom_domain(input: onnxscript.FLOAT[2, 2]) -> onnxscript.FLOAT[2, 2]:
22+
"""Create a model with a Relu operation in a custom domain."""
23+
return opset.Relu(input)
24+
25+
26+
_model = create_model_with_custom_domain.to_model_proto()
27+
_model = onnx.shape_inference.infer_shapes(_model)
28+
onnx.checker.check_model(_model)
29+
30+
31+
####################################
32+
# The target pattern
33+
# =====================
34+
35+
36+
def custom_relu_pattern(op, input):
37+
# Pattern to match Relu operations from a specific domain
38+
# _domain="custom.domain" specifies we only want to match operations from this domain
39+
return op.Relu(input, _domain="custom.domain")
40+
41+
42+
####################################
43+
# The replacement pattern
44+
# =====================
45+
46+
47+
def standard_relu_replacement(op, input, **_):
48+
# Replace with standard ONNX Relu (default domain)
49+
return op.Relu(input)
50+
51+
52+
####################################
53+
# Alternative: Replace with operation in different domain
54+
# =====================
55+
56+
57+
def microsoft_relu_replacement(op, input, **_):
58+
# Replace with operation in Microsoft's domain
59+
return op.OptimizedRelu(input, _domain="com.microsoft")
60+
61+
62+
####################################
63+
# Create Rewrite Rule and Apply to Model
64+
# =====================
65+
66+
67+
def apply_rewrite(model):
68+
# Create rewrite rules
69+
relu_rule = pattern.RewriteRule(
70+
custom_relu_pattern, # target pattern - matches custom domain operations
71+
standard_relu_replacement, # replacement pattern - uses standard domain
72+
)
73+
# Create a Rewrite Rule Set
74+
rewrite_rule_set = pattern.RewriteRuleSet([relu_rule])
75+
# Apply rewrite
76+
model_with_rewrite = onnxscript.rewriter.rewrite(
77+
model,
78+
pattern_rewrite_rules=rewrite_rule_set,
79+
)
80+
return model_with_rewrite
81+
82+
83+
# The rewrite rule will now match the Relu operation in the custom domain
84+
# and replace it with a standard ONNX Relu operation
85+
_model_with_rewrite = apply_rewrite(_model)
86+
onnx.checker.check_model(_model_with_rewrite)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""ONNX Pattern Rewriting with output specification
4+
5+
This script shows how to define a rewriting rule that specifies
6+
the number and names of outputs from operations.
7+
"""
8+
9+
import onnx
10+
11+
import onnxscript
12+
from onnxscript import FLOAT, opset18, script
13+
from onnxscript.rewriter import pattern
14+
15+
16+
@script()
17+
def original_model(A: FLOAT[4, 4]) -> FLOAT[2, 4]:
18+
# Split operation that produces 2 outputs
19+
result1, _result2 = opset18.Split(A, num_outputs=2, axis=0)
20+
# We only return the first output for simplicity
21+
return result1
22+
23+
24+
_model = original_model.to_model_proto()
25+
onnx.checker.check_model(_model)
26+
27+
28+
####################################
29+
# The target pattern with multiple outputs
30+
# =====================
31+
32+
33+
def split_pattern(op, input):
34+
# Pattern to match Split operations with 2 outputs
35+
# num_outputs=2 corresponds to the attribute of the ONNX Split op
36+
# _outputs=2 is an option controlling the pattern constructor
37+
return op.Split(input, num_outputs=2, axis=0, _outputs=2)
38+
39+
40+
####################################
41+
# The replacement pattern with named outputs
42+
# =====================
43+
44+
45+
def custom_split_replacement(op, input, **_):
46+
# Replace with a custom split operation using named outputs
47+
# _outputs=["first_half", "second_half"] assigns names to the outputs
48+
# IMPORTANT: The number of outputs must match the pattern (2 outputs)
49+
return op.CustomSplit(
50+
input, _domain="custom.domain", _outputs=["first_half", "second_half"]
51+
)
52+
53+
54+
####################################
55+
# Create Rewrite Rule and Apply to Model
56+
# =====================
57+
58+
59+
def apply_rewrite(model):
60+
# Create rewrite rules
61+
split_rule = pattern.RewriteRule(
62+
split_pattern, # target pattern - matches Split with 2 outputs
63+
custom_split_replacement, # replacement pattern - uses named outputs
64+
)
65+
# Create a Rewrite Rule Set
66+
rewrite_rule_set = pattern.RewriteRuleSet([split_rule])
67+
# Apply rewrite
68+
model_with_rewrite = onnxscript.rewriter.rewrite(
69+
model,
70+
pattern_rewrite_rules=rewrite_rule_set,
71+
)
72+
return model_with_rewrite
73+
74+
75+
_model_with_rewrite = apply_rewrite(_model)
76+
onnx.checker.check_model(_model_with_rewrite)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Specifying outputs in the pattern
2+
3+
This section demonstrates the use of the `_outputs` option in pattern-based rewriting.
4+
The `_outputs` option allows you to specify the number of outputs an operation produces
5+
and optionally assign names to those outputs for easier reference in replacement patterns.
6+
7+
The `_outputs` option can be specified in two ways:
8+
- As an integer: `_outputs=2` specifies that the operation produces 2 unnamed outputs
9+
- As a list of strings/None: `_outputs=["first", "second"]` specifies 2 named outputs
10+
11+
## Matching operations with multiple outputs
12+
13+
```{literalinclude} examples/outputs_option.py
14+
:pyobject: split_pattern
15+
```
16+
17+
This pattern matches `Split` operations that produce exactly 2 outputs. The `_outputs=2`
18+
specification ensures the pattern only matches operations with this specific output count.
19+
20+
## Creating replacement operations with named outputs
21+
22+
```{literalinclude} examples/outputs_option.py
23+
:pyobject: custom_split_replacement
24+
```
25+
26+
In the replacement, `_outputs=["first_half", "second_half"]` creates two outputs with
27+
descriptive names. This can make the replacement pattern more readable and maintainable.
28+
29+
**Important**: The number of outputs in the replacement pattern must match the number of
30+
outputs in the target pattern. Since the pattern specifies `_outputs=2`, the replacement
31+
must also produce exactly 2 outputs.
32+
33+
## Complete rewrite example
34+
35+
```{literalinclude} examples/outputs_option.py
36+
:pyobject: apply_rewrite
37+
```
38+
39+
The `_outputs` option is particularly important when:
40+
- Working with operations that have variable numbers of outputs (like `Split`)
41+
- Creating custom operations that need specific output configurations
42+
- Ensuring pattern matching precision by specifying exact output counts
43+
- Improving code readability by naming outputs in replacement patterns

docs/tutorial/rewriter/rewrite_patterns.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,32 @@ There are three main components needed when rewriting patterns in the graph:
1010
2. `replacement_pattern` : Pattern to replace the original pattern with. This pattern is also written as a function using ONNXScript-like operators.
1111
3. `match_condition` (optional) : Pattern rewrite will occur only if the match condition is satisfied.
1212

13+
## Pattern Options
14+
15+
When defining patterns, you can use several special options to control how patterns match and what they produce:
16+
17+
- `_allow_other_attributes`: Controls whether the pattern allows additional attributes not specified in the pattern (default: True)
18+
- `_allow_other_inputs`: Controls whether the pattern allows additional inputs beyond those specified (default: False)
19+
- `_domain`: Specifies the operator domain for matching or creating operations
20+
- `_outputs`: Specifies the number and optionally names of outputs from an operation
21+
22+
These options are documented in detail in the following sections.
23+
1324
```{include} simple_example.md
1425
```
1526

1627
```{include} attributes.md
1728
```
1829

30+
```{include} allow_other_inputs.md
31+
```
32+
33+
```{include} domain_option.md
34+
```
35+
36+
```{include} outputs_option.md
37+
```
38+
1939
```{include} conditional_rewrite.md
2040
```
2141

0 commit comments

Comments
 (0)