diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index ed33807db9..faca1f9ba8 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -140,4 +140,18 @@ def optimize_for_ort( ) # Apply the ORT pattern rewrite rules. rewrite(model, ORT_PATTERN_REWRITE_RULES) + + passes = ir.passes.Sequential( + # Apply the ORT optimization passes. + # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L172 + common_passes.ClearMetadataAndDocStringPass(), + # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L139 + common_passes.LiftConstantsToInitializersPass(lift_all_constants=False, size_limit=1), + common_passes.RemoveInitializersFromInputsPass(), + common_passes.ShapeInferencePass(), + common_passes.CheckerPass(), + ) + assert passes.in_place + result = passes(model) + assert result.model is model return model, fusion_count