Merged
23 changes: 3 additions & 20 deletions docs/tutorial/optimizer/optimize.md
@@ -15,6 +15,7 @@ onnxscript.optimizer.optimize(model)
```

### optimize API

The `onnxscript.optimizer.optimize` call takes in several optional parameters that allows the caller to further fine-tune the process of optimization.

```{eval-rst}
@@ -24,30 +25,12 @@ The `onnxscript.optimizer.optimize` call takes in several optional parameters th

## Description of optimizations applied by `onnxscript.optimizer.optimize`

-:::{table}
-:widths: auto
-:align: center
-
-| Optimization 'onnxscript.optimizer.` + .. | Description |
-| - | - |
+| Optimization | Description |
+|-------------|-------------|
| **Constant folding** <br>`constant_folding.fold_constants` | Applies constant folding optimization to the model. |
| **Constant propagation** <br>`constant_folding.fold_constants` | Applies constant propagation optimization to the model. Applied as part of the constant folding optimization. |
| **Sequence simplification** <br>`constant_folding.fold_constants` | Simplifies Sequence based ops (SequenceConstruct, ConcatFromSequence) present in the model. Applied as part of the constant folding optimization. |
| **Remove unused nodes** <br>`remove_unused.remove_unused_nodes` | Removes unused nodes from the model. |
| **Remove unused functions** <br>`remove_unused_function.remove_unused_functions` | Removes unused function protos from the model. |
| **Inline functions with unused outputs** <br>`simple_function_folding.inline_functions_with_unused_outputs` | Inlines function nodes that have unused outputs. |
| **Inline simple functions** <br>`simple_function_folding.inline_simple_functions` | Inlines simple functions based on a node count threshold. |
-:::
-
-## List of pattern rewrite rules applied by `onnxscript.optimizer.optimize`
-
-```{eval-rst}
-.. autosummary::
-:nosignatures:
-
-onnxscript.rewriter.broadcast_to_matmul
-onnxscript.rewriter.cast_constant_of_shape
-onnxscript.rewriter.gemm_to_matmul_add
-onnxscript.rewriter.no_op
-
-```
2 changes: 0 additions & 2 deletions onnxscript/rewriter/__init__.py
@@ -19,7 +19,6 @@
broadcast_to_matmul,
cast_constant_of_shape,
collapse_slices,
-gemm_to_matmul_add,
no_op,
pattern,
)
@@ -28,7 +27,6 @@
_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (
*no_op.rules.rules, # TODO: merge this rule into constant folding?
*broadcast_to_matmul.rules.rules,
-gemm_to_matmul_add.rule, # type: ignore[has-type]
*cast_constant_of_shape.rules.rules,
*collapse_slices.rules.rules,
*basic_rules.basic_optimization_rules().rules,
7 changes: 3 additions & 4 deletions onnxscript/rewriter/ort_fusions/_core.py
@@ -7,9 +7,8 @@
import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets
import onnxscript.rewriter.ort_fusions.shape_optimization as shape_optimization
from onnxscript.optimizer import optimize
-from onnxscript.rewriter import rewrite
+from onnxscript.rewriter import gemm_to_matmul_add, rewrite
from onnxscript.rewriter.ort_fusions import (
-# group_normalization_merge_silu,
instance_to_group_normalization,
softmax,
)
@@ -38,7 +37,7 @@
*instance_to_group_normalization.rules.rules,
# NOTE: group normalization merge silu should be applied after instance to group normalization
# *group_normalization_merge_silu.rules.rules,
-*fused_matmul_rule_sets.fused_matmul_rule_sets().rules,
+*fused_matmul_rule_sets.fused_matmul_rule_sets(),
]
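
Note on the hunk above: replacing `*fused_matmul_rule_sets.fused_matmul_rule_sets().rules` with `*fused_matmul_rule_sets.fused_matmul_rule_sets()` only behaves the same if the returned object can be unpacked directly, for example because it is a list or defines `__iter__`. The diff does not show which is the case. A minimal sketch of the idea, using a hypothetical `RuleSet` class and placeholder rule names that are not part of onnxscript:

```python
class RuleSet:
    """Hypothetical stand-in for a rule container (not onnxscript's class)."""

    def __init__(self, rules):
        self.rules = list(rules)

    def __iter__(self):
        # Iterating the container yields the same items as `.rules`.
        return iter(self.rules)


def make_rules():
    # Placeholder strings stand in for real rewrite rules.
    return RuleSet(["rule_a", "rule_b"])


# Old style: unpack the `.rules` attribute explicitly.
via_attribute = [*make_rules().rules]

# New style: unpack the object itself, which relies on `__iter__`.
via_iteration = [*make_rules()]
```

If the real return value were neither a sequence nor iterable, the new form would raise a `TypeError` when the list is built, so the change presumably depends on how `fused_matmul_rule_sets()` is defined elsewhere in the PR.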


@@ -130,7 +129,7 @@ def optimize_for_ort(
- The optimized `ir.Model` after applying transformer-specific fusions.
- A dictionary with a count of each of the fusions applied.
"""
-
+rewrite(model, [gemm_to_matmul_add.rule])
model, fusion_count = fuse_xformers(
model,
debug=debug,
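
For context on the `rewrite(model, [gemm_to_matmul_add.rule])` call added to `optimize_for_ort`: the diff does not show which pattern the rule matches. Going only by its name, it replaces a `Gemm` with a `MatMul` followed by an `Add`. That substitution preserves values when `Gemm` uses its default attributes (`alpha = beta = 1`, no transposes), since ONNX `Gemm` computes `alpha * A' * B' + beta * C`. A small pure-Python sketch of that identity, with helper names that are illustrative only and not onnxscript APIs:

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, c):
    """Element-wise sum of two matrices of the same shape."""
    return [[x + y for x, y in zip(row_a, row_c)] for row_a, row_c in zip(a, c)]


def gemm(a, b, c, alpha=1.0, beta=1.0):
    """Gemm without transposes: alpha * (A @ B) + beta * C."""
    prod = matmul(a, b)
    return [
        [alpha * p + beta * q for p, q in zip(row_p, row_c)]
        for row_p, row_c in zip(prod, c)
    ]


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
C = [[0.5, 0.5], [0.5, 0.5]]

gemm_out = gemm(A, B, C)                # default alpha and beta
matmul_add_out = add(matmul(A, B), C)   # the rewritten form
```

With a non-default `alpha` or `beta` the two forms diverge, so a rewrite of this kind has to check the attributes (or match a narrower pattern) before substituting.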