diff --git a/docs/tutorial/optimizer/optimize.md b/docs/tutorial/optimizer/optimize.md
index 5ceb7dfb80..8ff36f4c67 100644
--- a/docs/tutorial/optimizer/optimize.md
+++ b/docs/tutorial/optimizer/optimize.md
@@ -15,6 +15,7 @@ onnxscript.optimizer.optimize(model)
```
### optimize API
+
The `onnxscript.optimizer.optimize` call takes in several optional parameters that allow the caller to further fine-tune the process of optimization.
```{eval-rst}
@@ -24,12 +25,8 @@ The `onnxscript.optimizer.optimize` call takes in several optional parameters th
## Description of optimizations applied by `onnxscript.optimizer.optimize`
-:::{table}
-:widths: auto
-:align: center
-
-| Optimization 'onnxscript.optimizer.` + .. | Description |
-| - | - |
+| Optimization | Description |
+|-------------|-------------|
| **Constant folding**
`constant_folding.fold_constants` | Applies constant folding optimization to the model. |
| **Constant propagation**
`constant_folding.fold_constants` | Applies constant propagation optimization to the model. Applied as part of the constant folding optimization. |
| **Sequence simplification**
`constant_folding.fold_constants` | Simplifies Sequence based ops (SequenceConstruct, ConcatFromSequence) present in the model. Applied as part of the constant folding optimization. |
@@ -37,17 +34,3 @@ The `onnxscript.optimizer.optimize` call takes in several optional parameters th
| **Remove unused functions**
`remove_unused_function.remove_unused_functions` | Removes unused function protos from the model. |
| **Inline functions with unused outputs**
`simple_function_folding.inline_functions_with_unused_outputs` | Inlines function nodes that have unused outputs. |
| **Inline simple functions**
`simple_function_folding.inline_simple_functions` | Inlines simple functions based on a node count threshold. |
-:::
-
-## List of pattern rewrite rules applied by `onnxscript.optimizer.optimize`
-
-```{eval-rst}
-.. autosummary::
- :nosignatures:
-
- onnxscript.rewriter.broadcast_to_matmul
- onnxscript.rewriter.cast_constant_of_shape
- onnxscript.rewriter.gemm_to_matmul_add
- onnxscript.rewriter.no_op
-
-```
diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py
index cb0c1c70d6..fb7815bd1c 100644
--- a/onnxscript/rewriter/__init__.py
+++ b/onnxscript/rewriter/__init__.py
@@ -19,7 +19,6 @@
broadcast_to_matmul,
cast_constant_of_shape,
collapse_slices,
- gemm_to_matmul_add,
no_op,
pattern,
)
@@ -28,7 +27,6 @@
_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (
*no_op.rules.rules, # TODO: merge this rule into constant folding?
*broadcast_to_matmul.rules.rules,
- gemm_to_matmul_add.rule, # type: ignore[has-type]
*cast_constant_of_shape.rules.rules,
*collapse_slices.rules.rules,
*basic_rules.basic_optimization_rules().rules,
diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py
index e0d9331065..5fbdba42f9 100644
--- a/onnxscript/rewriter/ort_fusions/_core.py
+++ b/onnxscript/rewriter/ort_fusions/_core.py
@@ -7,9 +7,8 @@
import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets
import onnxscript.rewriter.ort_fusions.shape_optimization as shape_optimization
from onnxscript.optimizer import optimize
-from onnxscript.rewriter import rewrite
+from onnxscript.rewriter import gemm_to_matmul_add, rewrite
from onnxscript.rewriter.ort_fusions import (
- # group_normalization_merge_silu,
instance_to_group_normalization,
softmax,
)
@@ -38,7 +37,7 @@
*instance_to_group_normalization.rules.rules,
# NOTE: group normalization merge silu should be applied after instance to group normalization
# *group_normalization_merge_silu.rules.rules,
- *fused_matmul_rule_sets.fused_matmul_rule_sets().rules,
+ *fused_matmul_rule_sets.fused_matmul_rule_sets(),
]
@@ -130,7 +129,7 @@ def optimize_for_ort(
- The optimized `ir.Model` after applying transformer-specific fusions.
- A dictionary with a count of each of the fusions applied.
"""
-
+ rewrite(model, [gemm_to_matmul_add.rule])
model, fusion_count = fuse_xformers(
model,
debug=debug,