From 9e481da081eb9d5ce2055afbf775376e3689dece Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Wed, 7 Jun 2017 01:01:41 -0400 Subject: [PATCH] Add a custom LLVM pass to allow contraction of fast fadd/fsub with a normal fmul. Fix #18654, Fix #22217 --- src/Makefile | 3 +- src/intrinsics.cpp | 12 +++- src/jitlayers.cpp | 1 + src/jitlayers.h | 1 + src/llvm-muladd.cpp | 121 ++++++++++++++++++++++++++++++++++++++ test/llvmpasses/muladd.ll | 28 +++++++++ 6 files changed, 164 insertions(+), 2 deletions(-) create mode 100644 src/llvm-muladd.cpp create mode 100644 test/llvmpasses/muladd.ll diff --git a/src/Makefile b/src/Makefile index ce5e8b63a5679b..85f6e4f8f5c27e 100644 --- a/src/Makefile +++ b/src/Makefile @@ -53,7 +53,8 @@ endif LLVMLINK := ifeq ($(JULIACODEGEN),LLVM) -SRCS += codegen jitlayers disasm debuginfo llvm-simdloop llvm-ptls llvm-gcroot llvm-lower-handlers cgmemmgr +SRCS += codegen jitlayers disasm debuginfo llvm-simdloop llvm-ptls llvm-gcroot \ + llvm-lower-handlers llvm-muladd cgmemmgr FLAGS += -I$(shell $(LLVM_CONFIG_HOST) --includedir) LLVM_LIBS := all ifeq ($(USE_POLLY),1) diff --git a/src/intrinsics.cpp b/src/intrinsics.cpp index dc1f8ed89f0c08..aabbcc14be2ebb 100644 --- a/src/intrinsics.cpp +++ b/src/intrinsics.cpp @@ -955,7 +955,17 @@ static Value *emit_untyped_intrinsic(intrinsic f, Value **argvalues, size_t narg #endif } case muladd_float: { -#if JL_LLVM_VERSION >= 30400 +#if JL_LLVM_VERSION >= 50000 + // LLVM 5.0 can create FMA in the backend for contractable fmul and fadd + // Emitting fmul and fadd here since they are easier for other LLVM passes to + // optimize. 
+ auto mathb = math_builder(ctx); + auto mul = mathb().CreateFMul(x, y); + auto add = mathb().CreateFAdd(mul, z); + mul->setHasAllowContract(true); + add->setHasAllowContract(true); + return add; +#elif JL_LLVM_VERSION >= 30400 assert(y->getType() == x->getType()); assert(z->getType() == y->getType()); Value *muladdintr = Intrinsic::getDeclaration(jl_Module, Intrinsic::fmuladd, makeArrayRef(t)); diff --git a/src/jitlayers.cpp b/src/jitlayers.cpp index 08d47d66a4156d..d1f59f57c78949 100644 --- a/src/jitlayers.cpp +++ b/src/jitlayers.cpp @@ -266,6 +266,7 @@ void addOptimizationPasses(PassManager *PM) PM->add(createLoopVectorizePass()); // Vectorize loops PM->add(createInstructionCombiningPass()); // Clean up after loop vectorizer #endif + PM->add(createCombineMulAddPass()); } #ifdef USE_ORCJIT diff --git a/src/jitlayers.h b/src/jitlayers.h index 972867f0ebda67..e5b02d9270c0e4 100644 --- a/src/jitlayers.h +++ b/src/jitlayers.h @@ -249,6 +249,7 @@ JL_DLLEXPORT extern LLVMContext &jl_LLVMContext; Pass *createLowerPTLSPass(bool imaging_mode); Pass *createLowerGCFramePass(); Pass *createLowerExcHandlersPass(); +Pass *createCombineMulAddPass(); // Whether the Function is an llvm or julia intrinsic. static inline bool isIntrinsicFunction(Function *F) { diff --git a/src/llvm-muladd.cpp b/src/llvm-muladd.cpp new file mode 100644 index 00000000000000..5137c2d5ca8870 --- /dev/null +++ b/src/llvm-muladd.cpp @@ -0,0 +1,121 @@ +// This file is a part of Julia. License is MIT: https://julialang.org/license + +#define DEBUG_TYPE "combine_muladd" +#undef DEBUG +#include "llvm-version.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "fix_llvm_assert.h" + +#include "julia.h" + +using namespace llvm; + +/** + * Combine + * ``` + * %v0 = fmul ... %a, %b + * %v = fadd fast ... %v0, %c + * ``` + * to + * `%v = call fast @llvm.fmuladd.<...>(... %a, ... %b, ... 
%c)`
+ * when `%v0` has no other use
+ */
+
+struct CombineMulAdd : public FunctionPass {
+    static char ID;
+    CombineMulAdd() : FunctionPass(ID) {}
+
+private:
+    // Tries `checkCombine` on the operands of every fast `fadd`/`fsub` in `F`.
+    bool runOnFunction(Function &F) override;
+};
+
+// Try to contract `addOp` (a fast fadd/fsub) with `maybeMul`, one of its operands; nothing
+// is done unless `maybeMul` is an fmul instruction whose only use is `addOp`. `addend` is
+// the other operand; before LLVM 5.0 `negadd` negates it and `negres` negates the result.
+// Return true if this function shouldn't be called again on the other operand
+// (`addOp` was erased). This will always return false on LLVM 5.0+
+static bool checkCombine(Module *m, Instruction *addOp, Value *maybeMul, Value *addend,
+                         bool negadd, bool negres)
+{
+    auto mulOp = dyn_cast<Instruction>(maybeMul);
+    if (!mulOp || mulOp->getOpcode() != Instruction::FMul || !mulOp->hasOneUse())
+        return false;
+#if JL_LLVM_VERSION >= 50000
+    // On 5.0+ we only need to mark the mulOp as contract and the backend will do the work
+    // for us. `addOp` is already `fast` (checked by the caller); the plain fmul is not.
+    mulOp->setHasAllowContract(true);
+    return false;
+#else
+    IRBuilder<> builder(m->getContext());
+    builder.SetInsertPoint(addOp);
+    Value *muladdf = Intrinsic::getDeclaration(m, Intrinsic::fmuladd, addOp->getType());
+    if (negadd) {
+        addend = builder.CreateFNeg(addend);
+        // Might be a const
+        if (auto neginst = dyn_cast<Instruction>(addend))
+            neginst->setHasUnsafeAlgebra(true);
+    }
+    Instruction *newv =
+        builder.CreateCall(muladdf, {mulOp->getOperand(0), mulOp->getOperand(1), addend});
+    newv->setHasUnsafeAlgebra(true);
+    if (negres) {
+        // Shouldn't be a constant
+        newv = cast<Instruction>(builder.CreateFNeg(newv));
+        newv->setHasUnsafeAlgebra(true);
+    }
+    addOp->replaceAllUsesWith(newv);
+    addOp->eraseFromParent();
+    mulOp->eraseFromParent();
+    return true;
+#endif
+}
+
+bool CombineMulAdd::runOnFunction(Function &F)
+{
+    Module *m = F.getParent();
+    for (auto &BB: F) {
+        for (auto it = BB.begin(); it != BB.end();) {
+            auto &I = *it;
+            it++; // advance before visiting: checkCombine may erase `I`
+            switch (I.getOpcode()) {
+            case Instruction::FAdd: {
+                if (!I.hasUnsafeAlgebra())
+                    continue;
+                checkCombine(m, &I, I.getOperand(0), I.getOperand(1), false, false) ||
+                    checkCombine(m, &I, I.getOperand(1), I.getOperand(0), false, false);
+                break;
+            }
+            case Instruction::FSub: { // mul - c = fma(a, b, -c); c - mul = -fma(a, b, -c)
+                if (!I.hasUnsafeAlgebra())
+                    continue;
+                checkCombine(m, &I, I.getOperand(0), I.getOperand(1), true, false) ||
+                    checkCombine(m, &I, I.getOperand(1), I.getOperand(0), true, true);
+                break;
+            }
+            default:
+                break;
+            }
+        }
+    }
+    return true; // the function is always reported as modified
+}
+
+char CombineMulAdd::ID = 0;
+static RegisterPass<CombineMulAdd> X("CombineMulAdd", "Combine mul and add to muladd",
+                                     false /* Only looks at CFG */,
+                                     false /* Analysis Pass */);
+
+Pass *createCombineMulAddPass()
+{
+    return new CombineMulAdd();
+}
diff --git a/test/llvmpasses/muladd.ll b/test/llvmpasses/muladd.ll
new file mode 100644
index 00000000000000..0a2ac931d9cdeb
--- /dev/null
+++ b/test/llvmpasses/muladd.ll
@@ -0,0 +1,28 @@
+; RUN: opt -load libjulia.so -CombineMulAdd -S %s | FileCheck %s
+
+define double @fast_muladd1(double %a, double %b, double %c) {
+top:
+; CHECK: {{contract|fmuladd}}
+  %v1 = fmul double %a, %b
+  %v2 = fadd fast double %v1, %c
+; CHECK: ret double
+  ret double %v2
+}
+
+define double @fast_mulsub1(double %a, double %b, double %c) {
+top:
+; CHECK: {{contract|fmuladd}}
+  %v1 = fmul double %a, %b
+  %v2 = fsub fast double %v1, %c
+; CHECK: ret double
+  ret double %v2
+}
+
+define <2 x double> @fast_mulsub_vec1(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+top:
+; CHECK: {{contract|fmuladd}}
+  %v1 = fmul <2 x double> %a, %b
+  %v2 = fsub fast <2 x double> %c, %v1
+; CHECK: ret <2 x double>
+  ret <2 x double> %v2
+}