diff --git a/README.md b/README.md index adfc3238d0..ec3ce7bcc8 100644 --- a/README.md +++ b/README.md @@ -162,7 +162,7 @@ import onnxscript onnxscript.optimizer.optimize(onnx_model) ``` -For a detailed summary of all the optimizations applied by the optimizer call, refer to the tutorial [Optimizing a Model using the Optimizer](https://onnxscript.ai/tutorial/optimizer/optimize.html) +For a detailed summary of all the optimizations applied by the optimizer call, refer to the tutorial [Optimizing a Model using the Optimizer](https://microsoft.github.io/onnxscript/tutorial/optimizer/optimize.html) ### ONNX Rewriter @@ -205,11 +205,7 @@ model_with_rewrite_applied = onnxscript.rewriter.rewrite( return model_with_rewrite_applied ``` -For a detailed tutorial on how to create target_pattern, replacement_pattern and match_condition blocks in order to utilize the pattern-based rewriter, refer to the tutorial [Pattern-based Rewrite Using Rules](https://onnxscript.ai/tutorial/rewriter/rewrite_patterns.html) - -### Function-based rewriting - -This style of rewriting matches a `FUNCTION_KEYWORD` and `PACKAGE_NAME` provided by the user to an existing function within the graph and replaces it with a new function provided by the user. +For a detailed tutorial on how to create target_pattern, replacement_pattern and match_condition blocks in order to utilize the pattern-based rewriter, refer to the tutorial [Pattern-based Rewrite Using Rules](https://microsoft.github.io/onnxscript/tutorial/rewriter/rewrite_patterns.html) ## Development Guidelines diff --git a/docs/tutorial/rewriter/simple_example.md b/docs/tutorial/rewriter/simple_example.md index 942f0ad48f..2da32f958d 100644 --- a/docs/tutorial/rewriter/simple_example.md +++ b/docs/tutorial/rewriter/simple_example.md @@ -49,8 +49,8 @@ rule = pattern.RewriteRule( Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. The `rewriter.rewrite` call consists of three main components: 1. `model` : The original model on which the pattern rewrite rules are to be applied. This is of type `onnx.ModelProto`. -2. `function_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on function names. Steps on how to use this parameter will be covered in a different tutorial. This parameter is of type `Sequence[type[FunctionRewriteRule]]` -3. `pattern_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on a provided replacement pattern. For the purpose of this tutorial, we will be using only this parameter in conjunction with `model`. This parameter is of either one of these types: + +2. `pattern_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on a provided replacement pattern. This parameter is of either one of these types: - `Sequence[PatternRewriteRule]` - `RewriteRuleSet` diff --git a/onnxscript/rewriter/onnxruntime/__init__.py b/onnxscript/rewriter/onnxruntime/__init__.py index d6510f8a93..5069b65457 100644 --- a/onnxscript/rewriter/onnxruntime/__init__.py +++ b/onnxscript/rewriter/onnxruntime/__init__.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Any +from typing import Any, Sequence import onnx @@ -25,15 +25,12 @@ def rewrite( model_proto: onnx.ModelProto, /, - function_rules=None, - pattern_rules: list[pattern.RewriteRule] | None = None, + pattern_rules: Sequence[pattern.RewriteRule] | None = None, ) -> onnx.ModelProto: """Rewrite the model using the given rules. Args: model_proto: The model to rewrite. - function_rules: The function rewrite rules to apply. If None, the default rules - for onnxruntime are used. pattern_rules: The pattern rewrite rules to apply. If None, the default rules for onnxruntime are used. diff --git a/tools/ort_rewriter_profiling/README.md b/tools/ort_rewriter_profiling/README.md index 66f3af36bd..eefeef644e 100644 --- a/tools/ort_rewriter_profiling/README.md +++ b/tools/ort_rewriter_profiling/README.md @@ -127,14 +127,13 @@ 5. Develop optimization code. - `onnx-script/onnxscript/optimizer`: Optimizations such as constant folding, inlining, dead code elimination etc. - `onnx-script/onnxscript/rewriter`: Pattern based fusions. - - `onnx-script/onnxscript/rewriter/onnxruntime`: Onnxruntime specific pattern based fusions. - - `onnx-script/onnxscript/rewriter/onnxruntime/transformers`: Onnxruntime specific function based fusions. + - `onnx-script/onnxscript/rewriter/ort_fusions`: Onnxruntime specific pattern based fusions. - Use function unittest producer tool to create function fusion unittest. Example command to distill 4 unittests for function `LlamaSdpaAttention` from `llama_v2_7b` `dynamo` model. The unittest models are named with prefix `sdpa_llama2`: ``` - # Under onnx-script/onnxscript/rewriter/transformers - CUDA_VISIBLE_DEVICES="3" python tools/function_unittest_producer.py --model-path ../../../tools/onnx_models/llama_v2_7b_16h/dynamo_ort_rewritten/llama_v2_7b_16h_dynamo_ort_rewritten.onnx --function LlamaSdpaAttention --output-dir ../../testing/rewriter/transformers/unittest_models/ --max-outputs 4 --name sdpa_llama2 + # Under onnx-script/onnxscript/rewriter + CUDA_VISIBLE_DEVICES="3" python tools/function_unittest_producer.py --model-path ../../../tools/onnx_models/llama_v2_7b_16h/dynamo_ort_rewritten/llama_v2_7b_16h_dynamo_ort_rewritten.onnx --function LlamaSdpaAttention --output-dir ../../testing/rewriter/unittest_models/ --max-outputs 4 --name sdpa_llama2 ``` - - Create new testcase under `onnx-script/onnxscript/rewriter/transformers` with the generated unittest models. + - Create new testcase under `onnx-script/onnxscript/rewriter/ort_fusions` with the generated unittest models. ```python def test_sdpa_llama2(self): common.test_function_rewrite("sdpa_llama2", 4)