Skip to content

Commit

Permalink
[BACKEND] Explicitly allow specialization of FMA in llvm (#407)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Sep 1, 2017
1 parent a45d3b0 commit f73c461
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
3 changes: 3 additions & 0 deletions src/codegen/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log>);

Expand Down
26 changes: 22 additions & 4 deletions src/pass/lower_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ namespace ir {
class IntrinInjecter : public IRMutator {
public:
explicit IntrinInjecter(std::string target) {
patterns_.push_back("tvm.intrin.rule." + target + ".");
if (!strncmp(target.c_str(), "llvm", 4) && target != "llvm") {
patterns_.push_back("tvm.intrin.rule.llvm.");
}
std::istringstream is(target);
std::string starget;
is >> starget;
patterns_.push_back("tvm.intrin.rule." + starget + ".");
patterns_.push_back("tvm.intrin.rule.default.");
fma_ = runtime::Registry::Get(patterns_[0] + "fma");
}

Expr Mutate_(const Call* op, const Expr& e) final {
Expand All @@ -32,6 +33,22 @@ class IntrinInjecter : public IRMutator {
return IRMutator::Mutate_(op, e);
}

Expr Mutate_(const Add* op, const Expr& e) final {
if (fma_ == nullptr || !op->type.is_float()) {
return IRMutator::Mutate_(op, e);
}
if (const Mul* mb = op->b.as<Mul>()) {
Expr r = (*fma_)(Call::make(
op->type, "fma", {mb->a, mb->b, op->a}, Call::PureIntrinsic));
if (r.defined()) return r;
} else if (const Mul* ma = op->a.as<Mul>()) {
Expr r = (*fma_)(Call::make(
op->type, "fma", {ma->a, ma->b, op->b}, Call::PureIntrinsic));
if (r.defined()) return r;
}
return IRMutator::Mutate_(op, e);
}

private:
Expr ApplyPattern(const std::string& name, const Expr& e) {
for (size_t i = 0; i < patterns_.size(); ++i) {
Expand All @@ -54,6 +71,7 @@ class IntrinInjecter : public IRMutator {
}
// patterns
std::vector<std::string> patterns_;
const PackedFunc* fma_{nullptr};
};

LoweredFunc
Expand Down

0 comments on commit f73c461

Please sign in to comment.