From 290d57fca192ed481996040b34f356d5a71a8690 Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Tue, 13 May 2025 18:25:39 +0800 Subject: [PATCH 1/2] Update intrin_rule_llvm.cc --- src/target/llvm/intrin_rule_llvm.cc | 32 +++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 0f05ff8fda4d..3b45a93f4501 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -200,6 +200,38 @@ TVM_REGISTER_OP("tir.atan") return asin(x / denom); }); +TVM_REGISTER_OP("tir.asinh") + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + using tir::make_const; + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr) << "Invalid call node in asinh legalization"; + const PrimExpr& x = call->args[0]; + PrimExpr one = make_const(x.dtype(), 1.0); + PrimExpr sqrt_val = sqrt(x * x + one); + return log(x + sqrt_val); + }); + +TVM_REGISTER_OP("tir.acosh") + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + using tir::make_const; + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr) << "Invalid call node in acosh legalization"; + const PrimExpr& x = call->args[0]; + PrimExpr one = make_const(x.dtype(), 1.0); + PrimExpr sqrt_val = sqrt(x * x - one); + return log(x + sqrt_val); +}); + +TVM_REGISTER_OP("tir.atanh") + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + using tir::make_const; + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr) << "Invalid call node in atanh legalization"; + const PrimExpr& x = call->args[0]; + PrimExpr one = make_const(x.dtype(), 1.0); + return (log(one + x) - log(one - x)) * make_const(x.dtype(), 0.5); +}); + TVM_REGISTER_OP("tir.clz").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); From 94f88a5734a64f37026f41c439e4e61859a753ca Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Tue, 13 May 2025 18:46:22 +0800 Subject: [PATCH 2/2] Update intrin_rule_llvm.cc --- src/target/llvm/intrin_rule_llvm.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 3b45a93f4501..eefa547efdee 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -202,13 +202,13 @@ TVM_REGISTER_OP("tir.atan") TVM_REGISTER_OP("tir.asinh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tir::make_const; - const tir::CallNode* call = e.as(); - ICHECK(call != nullptr) << "Invalid call node in asinh legalization"; - const PrimExpr& x = call->args[0]; - PrimExpr one = make_const(x.dtype(), 1.0); - PrimExpr sqrt_val = sqrt(x * x + one); - return log(x + sqrt_val); + using tir::make_const; + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr) << "Invalid call node in asinh legalization"; + const PrimExpr& x = call->args[0]; + PrimExpr one = make_const(x.dtype(), 1.0); + PrimExpr sqrt_val = sqrt(x * x + one); + return log(x + sqrt_val); }); TVM_REGISTER_OP("tir.acosh") @@ -220,7 +220,7 @@ TVM_REGISTER_OP("tir.acosh") PrimExpr one = make_const(x.dtype(), 1.0); PrimExpr sqrt_val = sqrt(x * x - one); return log(x + sqrt_val); -}); + }); TVM_REGISTER_OP("tir.atanh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { @@ -230,7 +230,7 @@ TVM_REGISTER_OP("tir.atanh") const PrimExpr& x = call->args[0]; PrimExpr one = make_const(x.dtype(), 1.0); return (log(one + x) - log(one - x)) * make_const(x.dtype(), 0.5); -}); + }); TVM_REGISTER_OP("tir.clz").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { const tir::CallNode* call = e.as();