From 63b1560342e2b4ac759ce9e00554978641745bd5 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 15 Aug 2025 22:15:09 -0700 Subject: [PATCH 1/2] Add Erf-based Gelu fusion rule Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/gelu.py | 19 +++++++++-- onnxscript/rewriter/ort_fusions/gelu_test.py | 33 ++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gelu.py b/onnxscript/rewriter/ort_fusions/gelu.py index d31f4ef749..f9aa62727e 100644 --- a/onnxscript/rewriter/ort_fusions/gelu.py +++ b/onnxscript/rewriter/ort_fusions/gelu.py @@ -7,6 +7,7 @@ from onnxscript.rewriter import _fusion_utils, pattern _sqrt_two_over_pi = math.sqrt(2.0 / math.pi) +_sqrt_two = math.sqrt(2.0) class GeluTanhFusion(pattern.RewriteRuleClassBase): @@ -27,9 +28,23 @@ def rewrite(self, op, x): return op.FastGelu(x, _domain="com.microsoft") -_rule = GeluTanhFusion.rule() +class GeluErfFusion(pattern.RewriteRuleClassBase): + def pattern(self, op, x): + # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) + t1 = op.Div(x, _sqrt_two) + t2 = op.Erf(t1) + t3 = op.Add(t2, 1.0) + t4 = op.Mul(x, t3) + result = op.Mul(t4, 0.5) + return result + + def rewrite(self, op, x): + return op.Gelu(x, _domain="com.microsoft") + -gelu_rules = pattern.RewriteRuleSet([_rule]) +_tanh_rule = GeluTanhFusion.rule() +_erf_rule = GeluErfFusion.rule() +gelu_rules = pattern.RewriteRuleSet([_tanh_rule, _erf_rule]) fuse_gelu = _fusion_utils.apply_fusion_rules(gelu_rules) diff --git a/onnxscript/rewriter/ort_fusions/gelu_test.py b/onnxscript/rewriter/ort_fusions/gelu_test.py index 1ab6486c87..9726e39756 100644 --- a/onnxscript/rewriter/ort_fusions/gelu_test.py +++ b/onnxscript/rewriter/ort_fusions/gelu_test.py @@ -52,6 +52,39 @@ def gelu_model(x): optimized_output = test_utils.ort_run("Optimized", model, input) test_utils.assert_allclose(original_output, optimized_output) + def test_gelu_erf_fusion(self): + _sqrt_two = math.sqrt(2.0) + + @script() + def gelu_erf_model(x): + # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) + t1 = op.Div(x, _sqrt_two) + t2 = op.Erf(t1) + t3 = op.Add(t2, 1.0) + t4 = op.Mul(x, t3) + result = op.Mul(t4, 0.5) + return result + + model_proto = gelu_erf_model.to_model_proto( + input_types=[FLOAT[10]], output_types=[FLOAT[10]] + ) + model = ir.serde.deserialize_model(model_proto) + + # Eliminate redundant CastLike ops: + optimize(model) + + input = {"x": np.random.randn(10).astype(np.float32)} + original_output = test_utils.ort_run("Original", model, input) + + fuse_gelu(model) + remove_unused_nodes(model) + + self.assertEqual(len(model.graph), 1) + self.assertEqual(model.graph.node(0).op_type, "Gelu") + + optimized_output = test_utils.ort_run("Optimized", model, input) + test_utils.assert_allclose(original_output, optimized_output) + if __name__ == "__main__": unittest.main() From f6e8009d9f152deff7790ecd4a2c9e88a15a9c57 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 19 Aug 2025 13:14:13 -0700 Subject: [PATCH 2/2] Use upper case constant names Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/gelu.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gelu.py b/onnxscript/rewriter/ort_fusions/gelu.py index f9aa62727e..f4f27a03b5 100644 --- a/onnxscript/rewriter/ort_fusions/gelu.py +++ b/onnxscript/rewriter/ort_fusions/gelu.py @@ -6,8 +6,8 @@ from onnxscript.rewriter import _fusion_utils, pattern -_sqrt_two_over_pi = math.sqrt(2.0 / math.pi) -_sqrt_two = math.sqrt(2.0) +_SQRT_TWO_OVER_PI = math.sqrt(2.0 / math.pi) +_SQRT_TWO = math.sqrt(2.0) class GeluTanhFusion(pattern.RewriteRuleClassBase): @@ -17,7 +17,7 @@ def pattern(self, op, x): t2 = op.Mul(0.044715, t1) t3 = op.Add(x, t2) - t4 = op.Mul(_sqrt_two_over_pi, t3) + t4 = op.Mul(_SQRT_TWO_OVER_PI, t3) t5 = op.Tanh(t4) t6 = op.Add(t5, 1) t7 = op.Mul(0.5, t6) @@ -31,7 +31,7 @@ def rewrite(self, op, x): class GeluErfFusion(pattern.RewriteRuleClassBase): def pattern(self, op, x): # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) - t1 = op.Div(x, _sqrt_two) + t1 = op.Div(x, _SQRT_TWO) t2 = op.Erf(t1) t3 = op.Add(t2, 1.0) t4 = op.Mul(x, t3)