From fab2ab0a04dbbcf317f879d407be79e1db3ed171 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 18 Jun 2025 12:42:03 -0700 Subject: [PATCH 1/2] Move gemm_to_matmul_add rule to ort fusion rules Signed-off-by: Justin Chu --- onnxscript/rewriter/__init__.py | 2 -- onnxscript/rewriter/ort_fusions/_core.py | 7 +++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index cb0c1c70d6..fb7815bd1c 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -19,7 +19,6 @@ broadcast_to_matmul, cast_constant_of_shape, collapse_slices, - gemm_to_matmul_add, no_op, pattern, ) @@ -28,7 +27,6 @@ _DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = ( *no_op.rules.rules, # TODO: merge this rule into constant folding? *broadcast_to_matmul.rules.rules, - gemm_to_matmul_add.rule, # type: ignore[has-type] *cast_constant_of_shape.rules.rules, *collapse_slices.rules.rules, *basic_rules.basic_optimization_rules().rules, diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index e0d9331065..5fbdba42f9 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -7,9 +7,8 @@ import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets import onnxscript.rewriter.ort_fusions.shape_optimization as shape_optimization from onnxscript.optimizer import optimize -from onnxscript.rewriter import rewrite +from onnxscript.rewriter import gemm_to_matmul_add, rewrite from onnxscript.rewriter.ort_fusions import ( - # group_normalization_merge_silu, instance_to_group_normalization, softmax, ) @@ -38,7 +37,7 @@ *instance_to_group_normalization.rules.rules, # NOTE: group normalization merge silu should be applied after instance to group normalization # *group_normalization_merge_silu.rules.rules, - *fused_matmul_rule_sets.fused_matmul_rule_sets().rules, + *fused_matmul_rule_sets.fused_matmul_rule_sets(), ] @@ -130,7 +129,7 @@ def optimize_for_ort( - The optimized `ir.Model` after applying transformer-specific fusions. - A dictionary with a count of each of the fusions applied. """ - + rewrite(model, [gemm_to_matmul_add.rule]) model, fusion_count = fuse_xformers( model, debug=debug, From 3eb6229d73baa960e3fceaaf23756aac62df6720 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 18 Jun 2025 12:46:15 -0700 Subject: [PATCH 2/2] update docs Signed-off-by: Justin Chu --- docs/tutorial/optimizer/optimize.md | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/docs/tutorial/optimizer/optimize.md b/docs/tutorial/optimizer/optimize.md index 5ceb7dfb80..8ff36f4c67 100644 --- a/docs/tutorial/optimizer/optimize.md +++ b/docs/tutorial/optimizer/optimize.md @@ -15,6 +15,7 @@ onnxscript.optimizer.optimize(model) ``` ### optimize API + The `onnxscript.optimizer.optimize` call takes in several optional parameters that allows the caller to further fine-tune the process of optimization. ```{eval-rst} @@ -24,12 +25,8 @@ The `onnxscript.optimizer.optimize` call takes in several optional parameters th ## Description of optimizations applied by `onnxscript.optimizer.optimize` -:::{table} -:widths: auto -:align: center - -| Optimization 'onnxscript.optimizer.` + .. | Description | -| - | - | +| Optimization | Description | +|-------------|-------------| | **Constant folding**
`constant_folding.fold_constants` | Applies constant folding optimization to the model. | | **Constant propagation**
`constant_folding.fold_constants` | Applies constant propagation optimization to the model. Applied as part of the constant folding optimization. | | **Sequence simplification**
`constant_folding.fold_constants` | Simplifies Sequence based ops (SequenceConstruct, ConcatFromSequence) present in the model. Applied as part of the constant folding optimization. | @@ -37,17 +34,3 @@ The `onnxscript.optimizer.optimize` call takes in several optional parameters th | **Remove unused functions**
`remove_unused_function.remove_unused_functions` | Removes unused function protos from the model. | | **Inline functions with unused outputs**
`simple_function_folding.inline_functions_with_unused_outputs` | Inlines function nodes that have unused outputs. | | **Inline simple functions**
`simple_function_folding.inline_simple_functions` | Inlines simple functions based on a node count threshold. | -::: - -## List of pattern rewrite rules applied by `onnxscript.optimizer.optimize` - -```{eval-rst} -.. autosummary:: - :nosignatures: - - onnxscript.rewriter.broadcast_to_matmul - onnxscript.rewriter.cast_constant_of_shape - onnxscript.rewriter.gemm_to_matmul_add - onnxscript.rewriter.no_op - -```