diff --git a/onnxscript/rewriter/ort_fusions/gelu.py b/onnxscript/rewriter/ort_fusions/gelu.py
index d31f4ef749..f4f27a03b5 100644
--- a/onnxscript/rewriter/ort_fusions/gelu.py
+++ b/onnxscript/rewriter/ort_fusions/gelu.py
@@ -6,7 +6,8 @@
 
 from onnxscript.rewriter import _fusion_utils, pattern
 
-_sqrt_two_over_pi = math.sqrt(2.0 / math.pi)
+_SQRT_TWO_OVER_PI = math.sqrt(2.0 / math.pi)
+_SQRT_TWO = math.sqrt(2.0)
 
 
 class GeluTanhFusion(pattern.RewriteRuleClassBase):
@@ -16,7 +17,7 @@
     def pattern(self, op, x):
         t2 = op.Mul(0.044715, t1)
         t3 = op.Add(x, t2)
-        t4 = op.Mul(_sqrt_two_over_pi, t3)
+        t4 = op.Mul(_SQRT_TWO_OVER_PI, t3)
         t5 = op.Tanh(t4)
         t6 = op.Add(t5, 1)
         t7 = op.Mul(0.5, t6)
@@ -27,9 +28,23 @@
     def rewrite(self, op, x):
         return op.FastGelu(x, _domain="com.microsoft")
 
 
-_rule = GeluTanhFusion.rule()
+class GeluErfFusion(pattern.RewriteRuleClassBase):
+    def pattern(self, op, x):
+        # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+        t1 = op.Div(x, _SQRT_TWO)
+        t2 = op.Erf(t1)
+        t3 = op.Add(t2, 1.0)
+        t4 = op.Mul(x, t3)
+        result = op.Mul(t4, 0.5)
+        return result
+
+    def rewrite(self, op, x):
+        return op.Gelu(x, _domain="com.microsoft")
+
 
-gelu_rules = pattern.RewriteRuleSet([_rule])
+_tanh_rule = GeluTanhFusion.rule()
+_erf_rule = GeluErfFusion.rule()
+gelu_rules = pattern.RewriteRuleSet([_tanh_rule, _erf_rule])
 
 fuse_gelu = _fusion_utils.apply_fusion_rules(gelu_rules)
diff --git a/onnxscript/rewriter/ort_fusions/gelu_test.py b/onnxscript/rewriter/ort_fusions/gelu_test.py
index 1ab6486c87..9726e39756 100644
--- a/onnxscript/rewriter/ort_fusions/gelu_test.py
+++ b/onnxscript/rewriter/ort_fusions/gelu_test.py
@@ -52,6 +52,39 @@ def gelu_model(x):
         optimized_output = test_utils.ort_run("Optimized", model, input)
         test_utils.assert_allclose(original_output, optimized_output)
 
+    def test_gelu_erf_fusion(self):
+        _sqrt_two = math.sqrt(2.0)
+
+        @script()
+        def gelu_erf_model(x):
+            # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+            t1 = op.Div(x, _sqrt_two)
+            t2 = op.Erf(t1)
+            t3 = op.Add(t2, 1.0)
+            t4 = op.Mul(x, t3)
+            result = op.Mul(t4, 0.5)
+            return result
+
+        model_proto = gelu_erf_model.to_model_proto(
+            input_types=[FLOAT[10]], output_types=[FLOAT[10]]
+        )
+        model = ir.serde.deserialize_model(model_proto)
+
+        # Eliminate redundant CastLike ops:
+        optimize(model)
+
+        input = {"x": np.random.randn(10).astype(np.float32)}
+        original_output = test_utils.ort_run("Original", model, input)
+
+        fuse_gelu(model)
+        remove_unused_nodes(model)
+
+        self.assertEqual(len(model.graph), 1)
+        self.assertEqual(model.graph.node(0).op_type, "Gelu")
+
+        optimized_output = test_utils.ort_run("Optimized", model, input)
+        test_utils.assert_allclose(original_output, optimized_output)
+
 
 if __name__ == "__main__":
     unittest.main()
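
For context, a minimal sketch of how the new erf-based rule would be exercised outside the test suite. It assumes the public import paths onnxscript.optimizer.optimize and onnxscript.rewriter.ort_fusions.gelu.fuse_gelu, and the gelu_erf model below is a hypothetical example, not part of this change:

import math

from onnxscript import FLOAT, ir, script
from onnxscript import opset18 as op
from onnxscript.optimizer import optimize
from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu


@script()
def gelu_erf(x):
    # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    t = op.Erf(op.Div(x, math.sqrt(2.0)))
    return op.Mul(op.Mul(x, op.Add(t, 1.0)), 0.5)


# Build an IR model, fold away the redundant CastLike ops that onnxscript
# inserts for Python float constants, then apply the fusion rules.
model = ir.serde.deserialize_model(
    gelu_erf.to_model_proto(input_types=[FLOAT[10]], output_types=[FLOAT[10]])
)
optimize(model)
fuse_gelu(model)
# The Div/Erf/Add/Mul/Mul subgraph should now be a single com.microsoft Gelu.
print([node.op_type for node in model.graph])

The optimize pass matters: without it, the CastLike nodes around the 1.0 and 0.5 constants sit between the ops the pattern expects to see, and the rule will not match.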