|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | +"""ONNX Pattern Rewriting with domain specification |
| 4 | +
|
| 5 | +This script shows how to define a rewriting rule that targets operations |
| 6 | +from specific domains and replaces them with operations in other domains. |
| 7 | +""" |
| 8 | + |
| 9 | +import onnx |
| 10 | + |
| 11 | +import onnxscript |
| 12 | +from onnxscript import script |
| 13 | +from onnxscript.rewriter import pattern |
| 14 | +from onnxscript.values import Opset |
| 15 | + |
| 16 | +# Create an opset for the custom domain |
| 17 | +opset = Opset("custom.domain", 1) |
| 18 | + |
| 19 | + |
| 20 | +@script(opset) |
| 21 | +def create_model_with_custom_domain(input: onnxscript.FLOAT[2, 2]) -> onnxscript.FLOAT[2, 2]: |
| 22 | + """Create a model with a Relu operation in a custom domain.""" |
| 23 | + return opset.Relu(input) |
| 24 | + |
| 25 | + |
| 26 | +_model = create_model_with_custom_domain.to_model_proto() |
| 27 | +_model = onnx.shape_inference.infer_shapes(_model) |
| 28 | +onnx.checker.check_model(_model) |
| 29 | + |
| 30 | + |
| 31 | +#################################### |
| 32 | +# The target pattern |
| 33 | +# ===================== |
| 34 | + |
| 35 | + |
| 36 | +def custom_relu_pattern(op, input): |
| 37 | + # Pattern to match Relu operations from a specific domain |
| 38 | + # _domain="custom.domain" specifies we only want to match operations from this domain |
| 39 | + return op.Relu(input, _domain="custom.domain") |
| 40 | + |
| 41 | + |
| 42 | +#################################### |
| 43 | +# The replacement pattern |
| 44 | +# ===================== |
| 45 | + |
| 46 | + |
| 47 | +def standard_relu_replacement(op, input, **_): |
| 48 | + # Replace with standard ONNX Relu (default domain) |
| 49 | + return op.Relu(input) |
| 50 | + |
| 51 | + |
| 52 | +#################################### |
| 53 | +# Alternative: Replace with operation in different domain |
| 54 | +# ===================== |
| 55 | + |
| 56 | + |
| 57 | +def microsoft_relu_replacement(op, input, **_): |
| 58 | + # Replace with operation in Microsoft's domain |
| 59 | + return op.OptimizedRelu(input, _domain="com.microsoft") |
| 60 | + |
| 61 | + |
| 62 | +#################################### |
| 63 | +# Create Rewrite Rule and Apply to Model |
| 64 | +# ===================== |
| 65 | + |
| 66 | + |
| 67 | +def apply_rewrite(model): |
| 68 | + # Create rewrite rules |
| 69 | + relu_rule = pattern.RewriteRule( |
| 70 | + custom_relu_pattern, # target pattern - matches custom domain operations |
| 71 | + standard_relu_replacement, # replacement pattern - uses standard domain |
| 72 | + ) |
| 73 | + # Create a Rewrite Rule Set |
| 74 | + rewrite_rule_set = pattern.RewriteRuleSet([relu_rule]) |
| 75 | + # Apply rewrite |
| 76 | + model_with_rewrite = onnxscript.rewriter.rewrite( |
| 77 | + model, |
| 78 | + pattern_rewrite_rules=rewrite_rule_set, |
| 79 | + ) |
| 80 | + return model_with_rewrite |
| 81 | + |
| 82 | + |
| 83 | +# The rewrite rule will now match the Relu operation in the custom domain |
| 84 | +# and replace it with a standard ONNX Relu operation |
| 85 | +_model_with_rewrite = apply_rewrite(_model) |
| 86 | +onnx.checker.check_model(_model_with_rewrite) |
0 commit comments