Skip to content

Commit

Permalink
Add a custom LLVM pass to allow contraction of fast fadd/fsub with a …
Browse files Browse the repository at this point in the history
…normal fmul.

Fix #18654, Fix #22217
  • Loading branch information
yuyichao committed Jul 3, 2017
1 parent 91e8a53 commit fc5d7b8
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ endif
LLVMLINK :=

ifeq ($(JULIACODEGEN),LLVM)
SRCS += codegen jitlayers disasm debuginfo llvm-simdloop llvm-ptls llvm-late-gc-lowering llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces cgmemmgr
SRCS += codegen jitlayers disasm debuginfo llvm-simdloop llvm-ptls llvm-muladd llvm-late-gc-lowering llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces cgmemmgr
FLAGS += -I$(shell $(LLVM_CONFIG_HOST) --includedir)
LLVM_LIBS := all
ifeq ($(USE_POLLY),1)
Expand Down
20 changes: 17 additions & 3 deletions src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -723,17 +723,23 @@ static Value *emit_checked_srem_int(jl_codectx_t &ctx, Value *x, Value *den)
struct math_builder {
IRBuilder<> &ctxbuilder;
FastMathFlags old_fmf;
math_builder(jl_codectx_t &ctx, bool always_fast = false)
math_builder(jl_codectx_t &ctx, bool always_fast = false, bool contract = false)
: ctxbuilder(ctx.builder),
old_fmf(ctxbuilder.getFastMathFlags())
{
FastMathFlags fmf;
if (jl_options.fast_math != JL_OPTIONS_FAST_MATH_OFF &&
(always_fast ||
jl_options.fast_math == JL_OPTIONS_FAST_MATH_ON)) {
FastMathFlags fmf;
fmf.setUnsafeAlgebra();
ctxbuilder.setFastMathFlags(fmf);
}
#if JL_LLVM_VERSION >= 50000
if (contract)
fmf.setAllowContract(true);
#else
assert(!contract);
#endif
ctxbuilder.setFastMathFlags(fmf);
}
IRBuilder<>& operator()() const { return ctxbuilder; }
~math_builder() {
Expand Down Expand Up @@ -936,10 +942,18 @@ static Value *emit_untyped_intrinsic(jl_codectx_t &ctx, intrinsic f, Value **arg
return ctx.builder.CreateCall(fmaintr, {x, y, z});
}
case muladd_float: {
#if JL_LLVM_VERSION >= 50000
// LLVM 5.0 can create FMA in the backend for contractable fmul and fadd
// Emitting fmul and fadd here since they are easier for other LLVM passes to
// optimize.
auto mathb = math_builder(ctx, false, true);
return mathb().CreateFAdd(mathb().CreateFMul(x, y), z);
#else
assert(y->getType() == x->getType());
assert(z->getType() == y->getType());
Value *muladdintr = Intrinsic::getDeclaration(jl_Module, Intrinsic::fmuladd, makeArrayRef(t));
return ctx.builder.CreateCall(muladdintr, {x, y, z});
#endif
}

case checked_sadd_int:
Expand Down
1 change: 1 addition & 0 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ void addOptimizationPasses(legacy::PassManagerBase *PM, int opt_level)
PM->add(createDeadCodeEliminationPass());
PM->add(createLowerPTLSPass(imaging_mode));
#endif
PM->add(createCombineMulAddPass());
}

extern "C" JL_DLLEXPORT
Expand Down
1 change: 1 addition & 0 deletions src/jitlayers.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ extern JuliaOJIT *jl_ExecutionEngine;
JL_DLLEXPORT extern LLVMContext jl_LLVMContext;

Pass *createLowerPTLSPass(bool imaging_mode);
Pass *createCombineMulAddPass();
Pass *createLateLowerGCFramePass();
Pass *createLowerExcHandlersPass();
Pass *createGCInvariantVerifierPass(bool Strong);
Expand Down
123 changes: 123 additions & 0 deletions src/llvm-muladd.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// This file is a part of Julia. License is MIT: https://julialang.org/license

#define DEBUG_TYPE "combine_muladd"
#undef DEBUG
#include "llvm-version.h"

#include <llvm/IR/Value.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/IntrinsicInst.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Operator.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/Pass.h>
#include <llvm/Support/Debug.h>
#include "fix_llvm_assert.h"

#include "julia.h"

using namespace llvm;

/**
* Combine
* ```
* %v0 = fmul ... %a, %b
* %v = fadd fast ... %v0, %c
* ```
* to
* `%v = call fast @llvm.fmuladd.<...>(... %a, ... %b, ... %c)`
* when `%v0` has no other use
*/

struct CombineMulAdd : public FunctionPass {
static char ID;
CombineMulAdd() : FunctionPass(ID)
{}

private:
bool runOnFunction(Function &F) override;
};

// Return true if this function shouldn't be called again on the other operand
// This will always return false on LLVM 5.0+
static bool checkCombine(Module *m, Instruction *addOp, Value *maybeMul, Value *addend,
bool negadd, bool negres)
{
auto mulOp = dyn_cast<Instruction>(maybeMul);
if (!mulOp || mulOp->getOpcode() != Instruction::FMul)
return false;
if (!mulOp->hasOneUse())
return false;
#if JL_LLVM_VERSION >= 50000
// On 5.0+ we only need to mark the mulOp as contract and the backend will do the work for us.
auto fmf = mulOp->getFastMathFlags();
fmf.setAllowContract(true);
mulOp->copyFastMathFlags(fmf);
return false;
#else
IRBuilder<> builder(m->getContext());
builder.SetInsertPoint(addOp);
auto mul1 = mulOp->getOperand(0);
auto mul2 = mulOp->getOperand(1);
Value *muladdf = Intrinsic::getDeclaration(m, Intrinsic::fmuladd, addOp->getType());
if (negadd) {
auto newaddend = builder.CreateFNeg(addend);
// Might be a const
if (auto neginst = dyn_cast<Instruction>(newaddend))
neginst->setHasUnsafeAlgebra(true);
addend = newaddend;
}
Instruction *newv = builder.CreateCall(muladdf, {mul1, mul2, addend});
newv->setHasUnsafeAlgebra(true);
if (negres) {
// Shouldn't be a constant
newv = cast<Instruction>(builder.CreateFNeg(newv));
newv->setHasUnsafeAlgebra(true);
}
addOp->replaceAllUsesWith(newv);
addOp->eraseFromParent();
mulOp->eraseFromParent();
return true;
#endif
}

bool CombineMulAdd::runOnFunction(Function &F)
{
Module *m = F.getParent();
for (auto &BB: F) {
for (auto it = BB.begin(); it != BB.end();) {
auto &I = *it;
it++;
switch (I.getOpcode()) {
case Instruction::FAdd: {
if (!I.hasUnsafeAlgebra())
continue;
checkCombine(m, &I, I.getOperand(0), I.getOperand(1), false, false) ||
checkCombine(m, &I, I.getOperand(1), I.getOperand(0), false, false);
break;
}
case Instruction::FSub: {
if (!I.hasUnsafeAlgebra())
continue;
checkCombine(m, &I, I.getOperand(0), I.getOperand(1), true, false) ||
checkCombine(m, &I, I.getOperand(1), I.getOperand(0), true, true);
break;
}
default:
break;
}
}
}
return true;
}

char CombineMulAdd::ID = 0;
static RegisterPass<CombineMulAdd> X("CombineMulAdd", "Combine mul and add to muladd",
false /* Only looks at CFG */,
false /* Analysis Pass */);

Pass *createCombineMulAddPass()
{
return new CombineMulAdd();
}
28 changes: 28 additions & 0 deletions test/llvmpasses/muladd.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
; RUN: opt -load libjulia.so -CombineMulAdd -S %s | FileCheck %s

define double @fast_muladd1(double %a, double %b, double %c) {
top:
; CHECK: {{contract|fmuladd}}
%v1 = fmul double %a, %b
%v2 = fadd fast double %v1, %c
; CHECK: ret double
ret double %v2
}

define double @fast_mulsub1(double %a, double %b, double %c) {
top:
; CHECK: {{contract|fmuladd}}
%v1 = fmul double %a, %b
%v2 = fsub fast double %v1, %c
; CHECK: ret double
ret double %v2
}

define <2 x double> @fast_mulsub_vec1(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
top:
; CHECK: {{contract|fmuladd}}
%v1 = fmul <2 x double> %a, %b
%v2 = fsub fast <2 x double> %c, %v1
; CHECK: ret <2 x double>
ret <2 x double> %v2
}

0 comments on commit fc5d7b8

Please sign in to comment.