Skip to content
38 changes: 38 additions & 0 deletions src/target/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,44 @@ TVM_REGISTER_OP("tir.sinh")
return ret;
});

TVM_REGISTER_OP("tir.asin")
.set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr);
const PrimExpr& x = call->args[0];
PrimExpr x2 = x * x;
PrimExpr term1 = x;
PrimExpr term3 = term1 * x2 / make_const(x.dtype(), 6);
PrimExpr term5 = term3 * x2 * make_const(x.dtype(), 9) / make_const(x.dtype(), 40);
PrimExpr term7 = term5 * x2 * make_const(x.dtype(), 25) / make_const(x.dtype(), 112);
PrimExpr term9 = term7 * x2 * make_const(x.dtype(), 1225) / make_const(x.dtype(), 3456);
PrimExpr term11 = term9 * x2 * make_const(x.dtype(), 3969) / make_const(x.dtype(), 28160);
return term1 + term3 + term5 + term7 + term9 + term11;
});

TVM_REGISTER_OP("tir.acos")
.set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr) << "Invalid call node in acos legalization";
const PrimExpr& x = call->args[0];
PrimExpr half_pi = make_const(x.dtype(), M_PI / 2);
PrimExpr asin_x = asin(x);
return half_pi - asin_x;
});

TVM_REGISTER_OP("tir.atan")
.set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr) << "Invalid call node in atan legalization";
const PrimExpr& x = call->args[0];
PrimExpr one = make_const(x.dtype(), 1.0);
PrimExpr denom = sqrt(x * x + one);
return asin(x / denom);
});

TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr);
Expand Down
6 changes: 3 additions & 3 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,9 +447,9 @@ def test_bitwise_shift(direction: str):
"Sinh",
"Cosh",
"Tanh",
"Asin",
"Acos",
"Atan",
# "Asin", // TODO @jikechao, fix the precision loss due to the Taylor approximation
# "Acos",
# "Atan",
"Asinh",
"Acosh",
"Atanh",
Expand Down
7 changes: 4 additions & 3 deletions tests/python/tir-base/test_tir_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_unary_intrin():
(tvm.tir.atanh, lambda x: np.arctanh(x)),
]

def run_test(tvm_intrin, np_func):
def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5):
m = te.var(
"m",
)
Expand All @@ -98,10 +98,11 @@ def run_test(tvm_intrin, np_func):
a = tvm.nd.array(np.random.uniform(0.1, 0.5, size=n).astype(A.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
func(a, b)
tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-5, rtol=1e-5)
tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=atol, rtol=rtol)

for func in test_funcs:
run_test(*func)
atol = rtol = 1e-3 if func[0].__name__ in ["asin", "acos", "atan"] else 1e-5
run_test(*func, atol, rtol)


def test_binary_intrin():
Expand Down
Loading