From fc5d7b8a9d1116f7e145631e941263f7e3cd04d9 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Wed, 7 Jun 2017 01:01:41 -0400 Subject: [PATCH] Add a custom LLVM pass to allow contraction of fast fadd/fsub with a normal fmul. Fix #18654, Fix #22217 --- src/Makefile | 2 +- src/intrinsics.cpp | 20 ++++++- src/jitlayers.cpp | 1 + src/jitlayers.h | 1 + src/llvm-muladd.cpp | 123 ++++++++++++++++++++++++++++++++++++++ test/llvmpasses/muladd.ll | 28 +++++++++ 6 files changed, 171 insertions(+), 4 deletions(-) create mode 100644 src/llvm-muladd.cpp create mode 100644 test/llvmpasses/muladd.ll diff --git a/src/Makefile b/src/Makefile index cfaf825c25f46..9e2bf245d8c43 100644 --- a/src/Makefile +++ b/src/Makefile @@ -50,7 +50,7 @@ endif LLVMLINK := ifeq ($(JULIACODEGEN),LLVM) -SRCS += codegen jitlayers disasm debuginfo llvm-simdloop llvm-ptls llvm-late-gc-lowering llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces cgmemmgr +SRCS += codegen jitlayers disasm debuginfo llvm-simdloop llvm-ptls llvm-muladd llvm-late-gc-lowering llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces cgmemmgr FLAGS += -I$(shell $(LLVM_CONFIG_HOST) --includedir) LLVM_LIBS := all ifeq ($(USE_POLLY),1) diff --git a/src/intrinsics.cpp b/src/intrinsics.cpp index d044decf25bf9..1f197b615492f 100644 --- a/src/intrinsics.cpp +++ b/src/intrinsics.cpp @@ -723,17 +723,23 @@ static Value *emit_checked_srem_int(jl_codectx_t &ctx, Value *x, Value *den) struct math_builder { IRBuilder<> &ctxbuilder; FastMathFlags old_fmf; - math_builder(jl_codectx_t &ctx, bool always_fast = false) + math_builder(jl_codectx_t &ctx, bool always_fast = false, bool contract = false) : ctxbuilder(ctx.builder), old_fmf(ctxbuilder.getFastMathFlags()) { + FastMathFlags fmf; if (jl_options.fast_math != JL_OPTIONS_FAST_MATH_OFF && (always_fast || jl_options.fast_math == JL_OPTIONS_FAST_MATH_ON)) { - FastMathFlags fmf; fmf.setUnsafeAlgebra(); - ctxbuilder.setFastMathFlags(fmf); } +#if JL_LLVM_VERSION >= 
50000 + if (contract) + fmf.setAllowContract(true); +#else + assert(!contract); +#endif + ctxbuilder.setFastMathFlags(fmf); } IRBuilder<>& operator()() const { return ctxbuilder; } ~math_builder() { @@ -936,10 +942,18 @@ static Value *emit_untyped_intrinsic(jl_codectx_t &ctx, intrinsic f, Value **arg return ctx.builder.CreateCall(fmaintr, {x, y, z}); } case muladd_float: { +#if JL_LLVM_VERSION >= 50000 + // LLVM 5.0 can create FMA in the backend for contractable fmul and fadd + // Emitting fmul and fadd here since they are easier for other LLVM passes to + // optimize. + auto mathb = math_builder(ctx, false, true); + return mathb().CreateFAdd(mathb().CreateFMul(x, y), z); +#else assert(y->getType() == x->getType()); assert(z->getType() == y->getType()); Value *muladdintr = Intrinsic::getDeclaration(jl_Module, Intrinsic::fmuladd, makeArrayRef(t)); return ctx.builder.CreateCall(muladdintr, {x, y, z}); +#endif } case checked_sadd_int: diff --git a/src/jitlayers.cpp b/src/jitlayers.cpp index 06e4b02545339..10f8a4c290c21 100644 --- a/src/jitlayers.cpp +++ b/src/jitlayers.cpp @@ -242,6 +242,7 @@ void addOptimizationPasses(legacy::PassManagerBase *PM, int opt_level) PM->add(createDeadCodeEliminationPass()); PM->add(createLowerPTLSPass(imaging_mode)); #endif + PM->add(createCombineMulAddPass()); } extern "C" JL_DLLEXPORT diff --git a/src/jitlayers.h b/src/jitlayers.h index 22418b6fc2beb..aaa11ec3847c3 100644 --- a/src/jitlayers.h +++ b/src/jitlayers.h @@ -202,6 +202,7 @@ extern JuliaOJIT *jl_ExecutionEngine; JL_DLLEXPORT extern LLVMContext jl_LLVMContext; Pass *createLowerPTLSPass(bool imaging_mode); +Pass *createCombineMulAddPass(); Pass *createLateLowerGCFramePass(); Pass *createLowerExcHandlersPass(); Pass *createGCInvariantVerifierPass(bool Strong); diff --git a/src/llvm-muladd.cpp b/src/llvm-muladd.cpp new file mode 100644 index 0000000000000..267efa0749dc7 --- /dev/null +++ b/src/llvm-muladd.cpp @@ -0,0 +1,123 @@ +// This file is a part of Julia. 
License is MIT: https://julialang.org/license
+
+#define DEBUG_TYPE "combine_muladd"
+#undef DEBUG
+#include "llvm-version.h"
+// NOTE(review): the nine <llvm/...> header names below were restored from the symbols used; confirm.
+#include <llvm/IR/Value.h>
+#include <llvm/IR/Function.h>
+#include <llvm/IR/Instructions.h>
+#include <llvm/IR/IntrinsicInst.h>
+#include <llvm/IR/Module.h>
+#include <llvm/IR/Operator.h>
+#include <llvm/IR/IRBuilder.h>
+#include <llvm/Pass.h>
+#include <llvm/Support/Debug.h>
+#include "fix_llvm_assert.h"
+
+#include "julia.h"
+
+using namespace llvm;
+
+/**
+ * Combine
+ * ```
+ * %v0 = fmul ... %a, %b
+ * %v = fadd fast ... %v0, %c
+ * ```
+ * to `%v = call fast @llvm.fmuladd.<...>(... %a, ... %b, ... %c)` when `%v0` has no other
+ * use; `fsub fast` is handled the same way with the needed negations. On LLVM 5.0+ the
+ * `fmul` is only flagged `contract` and forming the FMA is left to the backend.
+ */
+
+struct CombineMulAdd : public FunctionPass {
+    static char ID;
+    CombineMulAdd() : FunctionPass(ID)
+    {}
+
+private:
+    bool runOnFunction(Function &F) override;
+};
+
+// Try to fuse `addOp` with `maybeMul` if that is a single-use fmul. Returns true when `addOp`
+// was replaced and erased (so the other operand must not be tried); always false on LLVM 5.0+.
+static bool checkCombine(Module *m, Instruction *addOp, Value *maybeMul, Value *addend,
+                         bool negadd, bool negres) // negadd: negate addend; negres: negate result
+{
+    auto mulOp = dyn_cast<Instruction>(maybeMul);
+    if (!mulOp || mulOp->getOpcode() != Instruction::FMul)
+        return false;
+    if (!mulOp->hasOneUse()) // skip a product that has other users
+        return false;
+#if JL_LLVM_VERSION >= 50000
+    // On 5.0+ we only need to mark the mulOp as contract and the backend will do the work for us.
+    auto fmf = mulOp->getFastMathFlags();
+    fmf.setAllowContract(true);
+    mulOp->copyFastMathFlags(fmf);
+    return false;
+#else
+    IRBuilder<> builder(m->getContext());
+    builder.SetInsertPoint(addOp);
+    auto mul1 = mulOp->getOperand(0);
+    auto mul2 = mulOp->getOperand(1);
+    Value *muladdf = Intrinsic::getDeclaration(m, Intrinsic::fmuladd, addOp->getType());
+    if (negadd) {
+        auto newaddend = builder.CreateFNeg(addend);
+        // Might be a const
+        if (auto neginst = dyn_cast<Instruction>(newaddend))
+            neginst->setHasUnsafeAlgebra(true);
+        addend = newaddend;
+    }
+    Instruction *newv = builder.CreateCall(muladdf, {mul1, mul2, addend});
+    newv->setHasUnsafeAlgebra(true);
+    if (negres) {
+        // Shouldn't be a constant
+        newv = cast<Instruction>(builder.CreateFNeg(newv));
+        newv->setHasUnsafeAlgebra(true);
+    }
+    addOp->replaceAllUsesWith(newv);
+    addOp->eraseFromParent();
+    mulOp->eraseFromParent();
+    return true;
+#endif
+}
+
+bool CombineMulAdd::runOnFunction(Function &F)
+{
+    Module *m = F.getParent();
+    for (auto &BB: F) {
+        for (auto it = BB.begin(); it != BB.end();) {
+            auto &I = *it;
+            it++; // step past I first: checkCombine may erase it
+            switch (I.getOpcode()) {
+            case Instruction::FAdd: {
+                if (!I.hasUnsafeAlgebra())
+                    continue;
+                checkCombine(m, &I, I.getOperand(0), I.getOperand(1), false, false) || // a*b + c
+                    checkCombine(m, &I, I.getOperand(1), I.getOperand(0), false, false); // c + a*b
+                break;
+            }
+            case Instruction::FSub: {
+                if (!I.hasUnsafeAlgebra())
+                    continue;
+                checkCombine(m, &I, I.getOperand(0), I.getOperand(1), true, false) || // a*b - c
+                    checkCombine(m, &I, I.getOperand(1), I.getOperand(0), true, true); // c - a*b
+                break;
+            }
+            default:
+                break;
+            }
+        }
+    }
+    return true; // reported as modified unconditionally
+}
+
+char CombineMulAdd::ID = 0;
+static RegisterPass<CombineMulAdd> X("CombineMulAdd", "Combine mul and add to muladd",
+                                     false /* Only looks at CFG */,
+                                     false /* Analysis Pass */);
+
+Pass *createCombineMulAddPass()
+{
+    return new CombineMulAdd();
+}
diff --git a/test/llvmpasses/muladd.ll b/test/llvmpasses/muladd.ll
new file mode 100644
index 0000000000000..0a2ac931d9cde
--- /dev/null
+++ 
b/test/llvmpasses/muladd.ll
@@ -0,0 +1,28 @@
+; RUN: opt -load libjulia.so -CombineMulAdd -S %s | FileCheck %s
+; Each case expects either a `contract` flag or an `fmuladd` call in the output before the ret.
+define double @fast_muladd1(double %a, double %b, double %c) {
+top:
+; CHECK: {{contract|fmuladd}}
+  %v1 = fmul double %a, %b ; plain fmul, no fast-math flags
+  %v2 = fadd fast double %v1, %c ; the only user of %v1
+; CHECK: ret double
+  ret double %v2
+}
+; Scalar subtraction with the product as the first operand of the fsub.
+define double @fast_mulsub1(double %a, double %b, double %c) {
+top:
+; CHECK: {{contract|fmuladd}}
+  %v1 = fmul double %a, %b
+  %v2 = fsub fast double %v1, %c
+; CHECK: ret double
+  ret double %v2
+}
+; Vector subtraction with the product as the second operand of the fsub.
+define <2 x double> @fast_mulsub_vec1(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+top:
+; CHECK: {{contract|fmuladd}}
+  %v1 = fmul <2 x double> %a, %b
+  %v2 = fsub fast <2 x double> %c, %v1
+; CHECK: ret <2 x double>
+  ret <2 x double> %v2
+}