From f73c461f50adda0777338cb7a2af2f8a35676681 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 1 Sep 2017 11:46:45 -0700 Subject: [PATCH] [BACKEND] Explicitly allow specialization of FMA in llvm (#407) --- src/codegen/llvm/intrin_rule_llvm.cc | 3 +++ src/pass/lower_intrin.cc | 26 ++++++++++++++++++++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/codegen/llvm/intrin_rule_llvm.cc b/src/codegen/llvm/intrin_rule_llvm.cc index 569d06bb81b1..ab6a52be5093 100644 --- a/src/codegen/llvm/intrin_rule_llvm.cc +++ b/src/codegen/llvm/intrin_rule_llvm.cc @@ -67,6 +67,9 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp>); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd>); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log>); diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index 4a496063daec..43989a6cd1ac 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -16,11 +16,12 @@ namespace ir { class IntrinInjecter : public IRMutator { public: explicit IntrinInjecter(std::string target) { - patterns_.push_back("tvm.intrin.rule." + target + "."); - if (!strncmp(target.c_str(), "llvm", 4) && target != "llvm") { - patterns_.push_back("tvm.intrin.rule.llvm."); - } + std::istringstream is(target); + std::string starget; + is >> starget; + patterns_.push_back("tvm.intrin.rule." + starget + "."); patterns_.push_back("tvm.intrin.rule.default."); + fma_ = runtime::Registry::Get(patterns_[0] + "fma"); } Expr Mutate_(const Call* op, const Expr& e) final { @@ -32,6 +33,22 @@ class IntrinInjecter : public IRMutator { return IRMutator::Mutate_(op, e); } + Expr Mutate_(const Add* op, const Expr& e) final { + if (fma_ == nullptr || !op->type.is_float()) { + return IRMutator::Mutate_(op, e); + } + if (const Mul* mb = op->b.as()) { + Expr r = (*fma_)(Call::make( + op->type, "fma", {mb->a, mb->b, op->a}, Call::PureIntrinsic)); + if (r.defined()) return r; + } else if (const Mul* ma = op->a.as()) { + Expr r = (*fma_)(Call::make( + op->type, "fma", {ma->a, ma->b, op->b}, Call::PureIntrinsic)); + if (r.defined()) return r; + } + return IRMutator::Mutate_(op, e); + } + private: Expr ApplyPattern(const std::string& name, const Expr& e) { for (size_t i = 0; i < patterns_.size(); ++i) { @@ -54,6 +71,7 @@ class IntrinInjecter : public IRMutator { } // patterns std::vector patterns_; + const PackedFunc* fma_{nullptr}; }; LoweredFunc