diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index dd1c79b1fc..1f4c0c39d8 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -4,6 +4,7 @@ import onnxscript.ir as ir import onnxscript.ir.passes.common as common_passes +import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets import onnxscript.rewriter.ort_fusions.shape_optimization as shape_optimization from onnxscript.optimizer import optimize from onnxscript.rewriter import rewrite @@ -37,9 +38,7 @@ *instance_to_group_normalization.rules.rules, # NOTE: group normalization merge silu should be applied after instance to group normalization # *group_normalization_merge_silu.rules.rules, - # NOTE: The rules below are broken: - # https://github.com/microsoft/onnxscript/pull/2317#issuecomment-2896058483 - # *fused_matmul_rule_sets.fused_matmul_rule_sets(), + *fused_matmul_rule_sets.fused_matmul_rule_sets().rules, ] diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py index c9c2480428..5082c20464 100644 --- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py @@ -327,9 +327,8 @@ def pattern(self, op, x, y): def fused_matmul_rule_sets() -> orp.RewriteRuleSet: - """Returns a set of rules introducing onnxruntime contrib obs. - This requires onnxruntime to run the model after - it is rewritten. + """Returns a set of rules introducing onnxruntime contrib ops. + This requires onnxruntime to run the model after it is rewritten. Returns: RewriteRuleSet