diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index e519c9eef397..15cc4450901e 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -242,6 +242,24 @@ TVM_REGISTER_OP("tir.atanh") return (log(one + x) - log(one - x)) * make_const(x.dtype(), 0.5); }); +TVM_REGISTER_OP("tir.erf").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + using tir::make_const; + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr) << "Invalid call node in erf legalization"; + const PrimExpr& x = call->args[0]; + PrimExpr abs_x = tvm::abs(x); + PrimExpr t = make_const(x.dtype(), 1.0) / + (make_const(x.dtype(), 1.0) + make_const(x.dtype(), 0.3275911) * abs_x); + PrimExpr a1 = make_const(x.dtype(), 0.254829592); + PrimExpr a2 = make_const(x.dtype(), -0.284496736); + PrimExpr a3 = make_const(x.dtype(), 1.421413741); + PrimExpr a4 = make_const(x.dtype(), -1.453152027); + PrimExpr a5 = make_const(x.dtype(), 1.061405429); + PrimExpr poly = (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t); + PrimExpr approx = make_const(x.dtype(), 1.0) - poly * exp(-abs_x * abs_x); + return tvm::tir::Select(x < 0, -approx, approx); +}); + TVM_REGISTER_OP("tir.clz").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); diff --git a/tests/python/tir-base/test_tir_intrin.py b/tests/python/tir-base/test_tir_intrin.py index 3e731f55fb0e..55f8dbed6c3c 100644 --- a/tests/python/tir-base/test_tir_intrin.py +++ b/tests/python/tir-base/test_tir_intrin.py @@ -23,6 +23,7 @@ import numpy as np import ctypes import math +import scipy def test_nearbyint(): @@ -77,6 +78,7 @@ def test_unary_intrin(): (tvm.tir.asinh, lambda x: np.arcsinh(x)), (tvm.tir.acosh, lambda x: np.arccosh(x)), (tvm.tir.atanh, lambda x: np.arctanh(x)), + (tvm.tir.erf, lambda x: scipy.special.erf(x)), ] def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5):