From 494d5c95d7a56116989d0a7f2788361cfdb11e3a Mon Sep 17 00:00:00 2001 From: Peter Yeh Date: Wed, 11 Dec 2019 01:14:36 -0800 Subject: [PATCH] update rocm intrin rule (#4499) --- src/codegen/llvm/intrin_rule_rocm.cc | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/codegen/llvm/intrin_rule_rocm.cc b/src/codegen/llvm/intrin_rule_rocm.cc index aaaf0ca99325..5ad5261c81bf 100644 --- a/src/codegen/llvm/intrin_rule_rocm.cc +++ b/src/codegen/llvm/intrin_rule_rocm.cc @@ -22,7 +22,6 @@ */ #ifdef TVM_LLVM_VERSION -#include "intrin_rule_llvm.h" #include #include #include @@ -45,27 +44,28 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { namespace llvm { TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); +.set_body(DispatchExternOCML); TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); +.set_body(DispatchExternOCML); TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); +.set_body(DispatchExternOCML); TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); +.set_body(DispatchExternOCML); TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); +.set_body(DispatchExternOCML); TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp") .set_body(DispatchExternOCML); -// On AMD GPU, fma is slower than mac -// removing fma dispatch allows backend to generate faster mac instruction +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf") +.set_body(DispatchExternOCML); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 1>); +.set_body(DispatchExternOCML); TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log") .set_body(DispatchExternOCML); @@ -78,6 +78,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh") .set_body(DispatchExternOCML); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos") +.set_body(DispatchExternOCML); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin") +.set_body(DispatchExternOCML); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan") +.set_body(DispatchExternOCML); + } // namespace llvm } // namespace codegen } // namespace tvm