Extend optimize_for_ort to cover passes #2274

Status: Open. Wants to merge 5 commits into base: main (showing changes from 2 commits).
20 changes: 17 additions & 3 deletions in onnxscript/rewriter/ort_fusions/_core.py

@@ -3,7 +3,7 @@
 from __future__ import annotations

 import onnxscript.ir as ir
-from onnxscript.ir.passes.common import shape_inference
+import onnxscript.ir.passes.common as common_passes
 from onnxscript.optimizer import optimize
 from onnxscript.rewriter import rewrite
 from onnxscript.rewriter.ort_fusions import (
@@ -47,7 +47,7 @@
     # TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some
     # extra shape-propagation and partial-data-propagation rules in ONNX that are not yet
     # incorporated in our optimizer.
-    shape_inference.infer_shapes(model)
+    common_passes.ShapeInferencePass()(model)
     optimize(model)
     return model
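(Side note, not part of the diff: a minimal sketch of the pass-style call, assuming the onnxscript.ir convention, also relied on later in this PR, that invoking a pass returns a result object whose .model attribute holds the processed model. The helper name infer_shapes_via_pass is made up for illustration.)

    import onnxscript.ir as ir
    import onnxscript.ir.passes.common as common_passes

    def infer_shapes_via_pass(model: ir.Model) -> ir.Model:
        # Instantiate the pass and invoke it on the model; the call returns a
        # pass result, and the processed model is read from its `.model` attribute.
        result = common_passes.ShapeInferencePass()(model)
        return result.model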

@@ -130,4 +130,18 @@
     )
     # Apply the ORT pattern rewrite rules.
     rewrite(model, ORT_PATTERN_REWRITE_RULES)
-    return model, fusion_count
+    passes = [

Collaborator (review comment on the added `passes = [` line):
Suggested change:
-    passes = [
+    passes = ir.passes.Sequential(

+        # TODO(exporter team): Fold transpose into initializers
+        # Apply the ORT optimization passes.
+        # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L172
+        common_passes.ClearMetadataAndDocStringPass(),
+        # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L139
+        common_passes.LiftConstantsToInitializersPass(lift_all_constants=False, size_limit=1),

Contributor Author (review comment on `LiftConstantsToInitializersPass`):
We have another pass called LiftSubgraphInitializersToMainGraphPass. Do we know if it's needed in genAI? @kunal-vaishnavi

Reply:
If the pass logic is in DynamoOnnxHelper, then it is used for ONNX Runtime GenAI.

Collaborator:
We don't really produce graphs with subgraph initializers. I think we are ok either way.
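(If LiftSubgraphInitializersToMainGraphPass did turn out to be needed, the natural place would presumably be the same pass list; the line below is purely illustrative, not part of this PR, and assumes that pass is exposed from the same common passes module:)

        # Hypothetical addition, only if subgraph initializers must be lifted:
        common_passes.LiftSubgraphInitializersToMainGraphPass(),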

+        common_passes.RemoveInitializersFromInputsPass(),
+        common_passes.ShapeInferencePass(),
+        common_passes.CheckerPass(),
+    ]
+    optimize_for_ort_passes = ir.passes.Sequential(*passes)

Collaborator (comment on lines +149 to +150):
Suggested change:
-    ]
-    optimize_for_ort_passes = ir.passes.Sequential(*passes)
+    )

+    result = optimize_for_ort_passes(model)
+    return result.model, fusion_count
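(Putting the two suggestions together, the reviewer's proposal builds the Sequential pass directly instead of going through an intermediate list. A rough sketch of how the tail of the function would then read; illustrative only, with names, arguments, and indentation taken from the diff rather than from any merged code:)

    # Apply the ORT optimization passes, composed directly as a Sequential pass.
    optimize_for_ort_passes = ir.passes.Sequential(
        # TODO(exporter team): Fold transpose into initializers
        common_passes.ClearMetadataAndDocStringPass(),
        common_passes.LiftConstantsToInitializersPass(lift_all_constants=False, size_limit=1),
        common_passes.RemoveInitializersFromInputsPass(),
        common_passes.ShapeInferencePass(),
        common_passes.CheckerPass(),
    )
    # Calling the Sequential pass returns a pass result; the processed model
    # is taken from `result.model`.
    result = optimize_for_ort_passes(model)
    return result.model, fusion_count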