Skip to content

Commit 1bada0a

Browse files
authored
[NVPTX] Add IR pass for FMA transformation in the llc pipeline (#154735)
This change introduces a new IR pass in the llc pipeline for NVPTX that transforms sequences of FMUL followed by FADD or FSUB into a single FMA instruction. Currently, all FMA folding for NVPTX occurs at the DAGCombine stage, which is too late for any IR-level passes that might want to optimize or analyze FMAs. By moving this transformation earlier into the IR phase, we enable more opportunities for FMA folding, including across basic blocks. Additionally, this new pass relies on the contract instruction level fast-math flag to perform these transformations, rather than depending on the -fp-contract=fast or -enable-unsafe-fp-math options passed to llc.
1 parent 9b12f8f commit 1bada0a

File tree

6 files changed

+432
-0
lines changed

6 files changed

+432
-0
lines changed

llvm/lib/Target/NVPTX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ set(NVPTXCodeGen_sources
1818
NVPTXAssignValidGlobalNames.cpp
1919
NVPTXAtomicLower.cpp
2020
NVPTXCtorDtorLowering.cpp
21+
NVPTXIRPeephole.cpp
2122
NVPTXForwardParams.cpp
2223
NVPTXFrameLowering.cpp
2324
NVPTXGenericToNVVM.cpp

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ FunctionPass *createNVPTXLowerAllocaPass();
5252
FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable,
5353
bool NoTrapAfterNoreturn);
5454
FunctionPass *createNVPTXTagInvariantLoadsPass();
55+
FunctionPass *createNVPTXIRPeepholePass();
5556
MachineFunctionPass *createNVPTXPeephole();
5657
MachineFunctionPass *createNVPTXProxyRegErasurePass();
5758
MachineFunctionPass *createNVPTXForwardParamsPass();
@@ -75,12 +76,17 @@ void initializeNVPTXAAWrapperPassPass(PassRegistry &);
7576
void initializeNVPTXExternalAAWrapperPass(PassRegistry &);
7677
void initializeNVPTXPeepholePass(PassRegistry &);
7778
void initializeNVPTXTagInvariantLoadLegacyPassPass(PassRegistry &);
79+
void initializeNVPTXIRPeepholePass(PassRegistry &);
7880
void initializeNVPTXPrologEpilogPassPass(PassRegistry &);
7981

8082
struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
8183
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
8284
};
8385

86+
struct NVPTXIRPeepholePass : PassInfoMixin<NVPTXIRPeepholePass> {
87+
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
88+
};
89+
8490
struct NVVMReflectPass : PassInfoMixin<NVVMReflectPass> {
8591
NVVMReflectPass() : SmVersion(0) {}
8692
NVVMReflectPass(unsigned SmVersion) : SmVersion(SmVersion) {}
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
//===------ NVPTXIRPeephole.cpp - NVPTX IR Peephole --------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements IR-level peephole optimizations. These transformations
10+
// run late in the NVPTX IR pass pipeline just before the instruction selection.
11+
//
12+
// Currently, it implements the following transformation(s):
13+
// 1. FMA folding (float/double types):
14+
// Transforms FMUL+FADD/FSUB sequences into FMA intrinsics when the
15+
// 'contract' fast-math flag is present. Supported patterns:
16+
// - fadd(fmul(a, b), c) => fma(a, b, c)
17+
// - fadd(c, fmul(a, b)) => fma(a, b, c)
18+
// - fadd(fmul(a, b), fmul(c, d)) => fma(a, b, fmul(c, d))
19+
// - fsub(fmul(a, b), c) => fma(a, b, fneg(c))
20+
// - fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
21+
// - fsub(fmul(a, b), fmul(c, d)) => fma(a, b, fneg(fmul(c, d)))
22+
//
23+
//===----------------------------------------------------------------------===//
24+
25+
#include "NVPTXUtilities.h"
26+
#include "llvm/IR/IRBuilder.h"
27+
#include "llvm/IR/InstIterator.h"
28+
#include "llvm/IR/Instructions.h"
29+
#include "llvm/IR/Intrinsics.h"
30+
31+
#define DEBUG_TYPE "nvptx-ir-peephole"
32+
33+
using namespace llvm;
34+
35+
static bool tryFoldBinaryFMul(BinaryOperator *BI) {
36+
Value *Op0 = BI->getOperand(0);
37+
Value *Op1 = BI->getOperand(1);
38+
39+
auto *FMul0 = dyn_cast<BinaryOperator>(Op0);
40+
auto *FMul1 = dyn_cast<BinaryOperator>(Op1);
41+
42+
BinaryOperator *FMul = nullptr;
43+
Value *OtherOperand = nullptr;
44+
bool IsFirstOperand = false;
45+
46+
// Either Op0 or Op1 should be a valid FMul
47+
if (FMul0 && FMul0->getOpcode() == Instruction::FMul && FMul0->hasOneUse() &&
48+
FMul0->hasAllowContract()) {
49+
FMul = FMul0;
50+
OtherOperand = Op1;
51+
IsFirstOperand = true;
52+
} else if (FMul1 && FMul1->getOpcode() == Instruction::FMul &&
53+
FMul1->hasOneUse() && FMul1->hasAllowContract()) {
54+
FMul = FMul1;
55+
OtherOperand = Op0;
56+
IsFirstOperand = false;
57+
} else {
58+
return false;
59+
}
60+
61+
bool IsFSub = BI->getOpcode() == Instruction::FSub;
62+
LLVM_DEBUG({
63+
const char *OpName = IsFSub ? "FSub" : "FAdd";
64+
dbgs() << "Found " << OpName << " with FMul (single use) as "
65+
<< (IsFirstOperand ? "first" : "second") << " operand: " << *BI
66+
<< "\n";
67+
});
68+
69+
Value *MulOp0 = FMul->getOperand(0);
70+
Value *MulOp1 = FMul->getOperand(1);
71+
IRBuilder<> Builder(BI);
72+
Value *FMA = nullptr;
73+
74+
if (!IsFSub) {
75+
// fadd(fmul(a, b), c) => fma(a, b, c)
76+
// fadd(c, fmul(a, b)) => fma(a, b, c)
77+
FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
78+
{MulOp0, MulOp1, OtherOperand});
79+
} else {
80+
if (IsFirstOperand) {
81+
// fsub(fmul(a, b), c) => fma(a, b, fneg(c))
82+
Value *NegOtherOp =
83+
Builder.CreateFNegFMF(OtherOperand, BI->getFastMathFlags());
84+
FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
85+
{MulOp0, MulOp1, NegOtherOp});
86+
} else {
87+
// fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
88+
Value *NegMulOp0 =
89+
Builder.CreateFNegFMF(MulOp0, FMul->getFastMathFlags());
90+
FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
91+
{NegMulOp0, MulOp1, OtherOperand});
92+
}
93+
}
94+
95+
// Combine fast-math flags from the original instructions
96+
auto *FMAInst = cast<Instruction>(FMA);
97+
FastMathFlags BinaryFMF = BI->getFastMathFlags();
98+
FastMathFlags FMulFMF = FMul->getFastMathFlags();
99+
FastMathFlags NewFMF = FastMathFlags::intersectRewrite(BinaryFMF, FMulFMF) |
100+
FastMathFlags::unionValue(BinaryFMF, FMulFMF);
101+
FMAInst->setFastMathFlags(NewFMF);
102+
103+
LLVM_DEBUG({
104+
const char *OpName = IsFSub ? "FSub" : "FAdd";
105+
dbgs() << "Replacing " << OpName << " with FMA: " << *FMA << "\n";
106+
});
107+
BI->replaceAllUsesWith(FMA);
108+
BI->eraseFromParent();
109+
FMul->eraseFromParent();
110+
return true;
111+
}
112+
113+
static bool foldFMA(Function &F) {
114+
bool Changed = false;
115+
116+
// Iterate and process float/double FAdd/FSub instructions with allow-contract
117+
for (auto &I : llvm::make_early_inc_range(instructions(F))) {
118+
if (auto *BI = dyn_cast<BinaryOperator>(&I)) {
119+
// Only FAdd and FSub are supported.
120+
if (BI->getOpcode() != Instruction::FAdd &&
121+
BI->getOpcode() != Instruction::FSub)
122+
continue;
123+
124+
// At minimum, the instruction should have allow-contract.
125+
if (!BI->hasAllowContract())
126+
continue;
127+
128+
// Only float and double are supported.
129+
if (!BI->getType()->isFloatTy() && !BI->getType()->isDoubleTy())
130+
continue;
131+
132+
if (tryFoldBinaryFMul(BI))
133+
Changed = true;
134+
}
135+
}
136+
return Changed;
137+
}
138+
139+
namespace {
140+
141+
struct NVPTXIRPeephole : public FunctionPass {
142+
static char ID;
143+
NVPTXIRPeephole() : FunctionPass(ID) {}
144+
bool runOnFunction(Function &F) override;
145+
};
146+
147+
} // namespace
148+
149+
char NVPTXIRPeephole::ID = 0;
150+
INITIALIZE_PASS(NVPTXIRPeephole, "nvptx-ir-peephole", "NVPTX IR Peephole",
151+
false, false)
152+
153+
bool NVPTXIRPeephole::runOnFunction(Function &F) { return foldFMA(F); }
154+
155+
FunctionPass *llvm::createNVPTXIRPeepholePass() {
156+
return new NVPTXIRPeephole();
157+
}
158+
159+
PreservedAnalyses NVPTXIRPeepholePass::run(Function &F,
160+
FunctionAnalysisManager &) {
161+
if (!foldFMA(F))
162+
return PreservedAnalyses::all();
163+
164+
PreservedAnalyses PA;
165+
PA.preserveSet<CFGAnalyses>();
166+
return PA;
167+
}

llvm/lib/Target/NVPTX/NVPTXPassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,5 @@ FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass())
4040
FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass())
4141
FUNCTION_PASS("nvptx-lower-args", NVPTXLowerArgsPass(*this))
4242
FUNCTION_PASS("nvptx-tag-invariant-loads", NVPTXTagInvariantLoadsPass())
43+
FUNCTION_PASS("nvptx-ir-peephole", NVPTXIRPeepholePass())
4344
#undef FUNCTION_PASS

llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ static cl::opt<bool>
5151
cl::desc("Disable load/store vectorizer"),
5252
cl::init(false), cl::Hidden);
5353

54+
// NVPTX IR Peephole is a new pass; this option will lets us turn it off in case
55+
// we encounter some issues.
56+
static cl::opt<bool>
57+
DisableNVPTXIRPeephole("disable-nvptx-ir-peephole",
58+
cl::desc("Disable NVPTX IR Peephole"),
59+
cl::init(false), cl::Hidden);
60+
5461
// TODO: Remove this flag when we are confident with no regressions.
5562
static cl::opt<bool> DisableRequireStructuredCFG(
5663
"disable-nvptx-require-structured-cfg",
@@ -115,6 +122,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
115122
initializeNVPTXExternalAAWrapperPass(PR);
116123
initializeNVPTXPeepholePass(PR);
117124
initializeNVPTXTagInvariantLoadLegacyPassPass(PR);
125+
initializeNVPTXIRPeepholePass(PR);
118126
initializeNVPTXPrologEpilogPassPass(PR);
119127
}
120128

@@ -379,6 +387,8 @@ void NVPTXPassConfig::addIRPasses() {
379387
addPass(createLoadStoreVectorizerPass());
380388
addPass(createSROAPass());
381389
addPass(createNVPTXTagInvariantLoadsPass());
390+
if (!DisableNVPTXIRPeephole)
391+
addPass(createNVPTXIRPeepholePass());
382392
}
383393

384394
if (ST.hasPTXASUnreachableBug()) {

0 commit comments

Comments
 (0)