Skip to content

Commit 49adcdf

Browse files
committed
Addressed review comments
1. Removed extra arguments passed to tryFoldBinaryFMul. 2. Removed temporary storage used to collect the binary instructions. 3. Made the guarding condition a little easier to read. 4. Added one more test scenario.
1 parent f1eff5c commit 49adcdf

File tree

6 files changed

+76
-45
lines changed

6 files changed

+76
-45
lines changed

llvm/lib/Target/NVPTX/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ set(NVPTXCodeGen_sources
1717
NVPTXAssignValidGlobalNames.cpp
1818
NVPTXAtomicLower.cpp
1919
NVPTXCtorDtorLowering.cpp
20-
NVPTXFoldFMA.cpp
20+
NVPTXIRPeephole.cpp
2121
NVPTXForwardParams.cpp
2222
NVPTXFrameLowering.cpp
2323
NVPTXGenericToNVVM.cpp

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ FunctionPass *createNVPTXLowerAllocaPass();
5252
FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable,
5353
bool NoTrapAfterNoreturn);
5454
FunctionPass *createNVPTXTagInvariantLoadsPass();
55-
FunctionPass *createNVPTXFoldFMAPass();
55+
FunctionPass *createNVPTXIRPeepholePass();
5656
MachineFunctionPass *createNVPTXPeephole();
5757
MachineFunctionPass *createNVPTXProxyRegErasurePass();
5858
MachineFunctionPass *createNVPTXForwardParamsPass();
@@ -77,14 +77,14 @@ void initializeNVPTXAAWrapperPassPass(PassRegistry &);
7777
void initializeNVPTXExternalAAWrapperPass(PassRegistry &);
7878
void initializeNVPTXPeepholePass(PassRegistry &);
7979
void initializeNVPTXTagInvariantLoadLegacyPassPass(PassRegistry &);
80-
void initializeNVPTXFoldFMAPass(PassRegistry &);
80+
void initializeNVPTXIRPeepholePass(PassRegistry &);
8181
void initializeNVPTXPrologEpilogPassPass(PassRegistry &);
8282

8383
struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
8484
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
8585
};
8686

87-
struct NVPTXFoldFMAPass : PassInfoMixin<NVPTXFoldFMAPass> {
87+
struct NVPTXIRPeepholePass : PassInfoMixin<NVPTXIRPeepholePass> {
8888
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
8989
};
9090

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===------ NVPTXFoldFMA.cpp - Fold FMA --------------===//
1+
//===------ NVPTXIRPeephole.cpp - NVPTX IR Peephole --------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -22,18 +22,37 @@
2222
#include "llvm/IR/Instructions.h"
2323
#include "llvm/IR/Intrinsics.h"
2424

25-
#define DEBUG_TYPE "nvptx-fold-fma"
25+
#define DEBUG_TYPE "nvptx-ir-peephole"
2626

2727
using namespace llvm;
2828

29-
static bool tryFoldBinaryFMul(BinaryOperator *BI, Value *MulOperand,
30-
Value *OtherOperand, bool IsFirstOperand,
31-
bool IsFSub) {
32-
auto *FMul = dyn_cast<BinaryOperator>(MulOperand);
33-
if (!FMul || FMul->getOpcode() != Instruction::FMul || !FMul->hasOneUse() ||
34-
!FMul->hasAllowContract())
29+
static bool tryFoldBinaryFMul(BinaryOperator *BI) {
30+
Value *Op0 = BI->getOperand(0);
31+
Value *Op1 = BI->getOperand(1);
32+
33+
auto *FMul0 = dyn_cast<BinaryOperator>(Op0);
34+
auto *FMul1 = dyn_cast<BinaryOperator>(Op1);
35+
36+
BinaryOperator *FMul = nullptr;
37+
Value *OtherOperand = nullptr;
38+
bool IsFirstOperand = false;
39+
40+
// Either Op0 or Op1 should be a valid FMul
41+
if (FMul0 && FMul0->getOpcode() == Instruction::FMul && FMul0->hasOneUse() &&
42+
FMul0->hasAllowContract()) {
43+
FMul = FMul0;
44+
OtherOperand = Op1;
45+
IsFirstOperand = true;
46+
} else if (FMul1 && FMul1->getOpcode() == Instruction::FMul &&
47+
FMul1->hasOneUse() && FMul1->hasAllowContract()) {
48+
FMul = FMul1;
49+
OtherOperand = Op0;
50+
IsFirstOperand = false;
51+
} else {
3552
return false;
53+
}
3654

55+
bool IsFSub = BI->getOpcode() == Instruction::FSub;
3756
LLVM_DEBUG({
3857
const char *OpName = IsFSub ? "FSub" : "FAdd";
3958
dbgs() << "Found " << OpName << " with FMul (single use) as "
@@ -87,10 +106,9 @@ static bool tryFoldBinaryFMul(BinaryOperator *BI, Value *MulOperand,
87106

88107
static bool foldFMA(Function &F) {
89108
bool Changed = false;
90-
SmallVector<BinaryOperator *, 16> FAddFSubInsts;
91109

92-
// Collect all float/double FAdd/FSub instructions with allow-contract
93-
for (auto &I : instructions(F)) {
110+
// Iterate and process float/double FAdd/FSub instructions with allow-contract
111+
for (auto &I : llvm::make_early_inc_range(instructions(F))) {
94112
if (auto *BI = dyn_cast<BinaryOperator>(&I)) {
95113
// Only FAdd and FSub are supported.
96114
if (BI->getOpcode() != Instruction::FAdd &&
@@ -105,42 +123,35 @@ static bool foldFMA(Function &F) {
105123
if (!BI->getType()->isFloatTy() && !BI->getType()->isDoubleTy())
106124
continue;
107125

108-
FAddFSubInsts.push_back(BI);
126+
if (tryFoldBinaryFMul(BI))
127+
Changed = true;
109128
}
110129
}
111-
112-
for (auto *BI : FAddFSubInsts) {
113-
Value *Op0 = BI->getOperand(0);
114-
Value *Op1 = BI->getOperand(1);
115-
bool IsFSub = BI->getOpcode() == Instruction::FSub;
116-
117-
if (tryFoldBinaryFMul(BI, Op0, Op1, true /*IsFirstOperand*/, IsFSub) ||
118-
tryFoldBinaryFMul(BI, Op1, Op0, false /*IsFirstOperand*/, IsFSub))
119-
Changed = true;
120-
}
121-
122130
return Changed;
123131
}
124132

125133
namespace {
126134

127-
struct NVPTXFoldFMA : public FunctionPass {
135+
struct NVPTXIRPeephole : public FunctionPass {
128136
static char ID;
129-
NVPTXFoldFMA() : FunctionPass(ID) {}
137+
NVPTXIRPeephole() : FunctionPass(ID) {}
130138
bool runOnFunction(Function &F) override;
131139
};
132140

133141
} // namespace
134142

135-
char NVPTXFoldFMA::ID = 0;
136-
INITIALIZE_PASS(NVPTXFoldFMA, "nvptx-fold-fma", "NVPTX Fold FMA", false, false)
143+
char NVPTXIRPeephole::ID = 0;
144+
INITIALIZE_PASS(NVPTXIRPeephole, "nvptx-ir-peephole", "NVPTX IR Peephole",
145+
false, false)
137146

138-
bool NVPTXFoldFMA::runOnFunction(Function &F) { return foldFMA(F); }
147+
bool NVPTXIRPeephole::runOnFunction(Function &F) { return foldFMA(F); }
139148

140-
FunctionPass *llvm::createNVPTXFoldFMAPass() { return new NVPTXFoldFMA(); }
149+
FunctionPass *llvm::createNVPTXIRPeepholePass() {
150+
return new NVPTXIRPeephole();
151+
}
141152

142-
PreservedAnalyses NVPTXFoldFMAPass::run(Function &F,
143-
FunctionAnalysisManager &) {
153+
PreservedAnalyses NVPTXIRPeepholePass::run(Function &F,
154+
FunctionAnalysisManager &) {
144155
if (!foldFMA(F))
145156
return PreservedAnalyses::all();
146157

llvm/lib/Target/NVPTX/NVPTXPassRegistry.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,5 @@ FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass())
4040
FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass())
4141
FUNCTION_PASS("nvptx-lower-args", NVPTXLowerArgsPass(*this))
4242
FUNCTION_PASS("nvptx-tag-invariant-loads", NVPTXTagInvariantLoadsPass())
43-
FUNCTION_PASS("nvptx-fold-fma", NVPTXFoldFMAPass())
43+
FUNCTION_PASS("nvptx-ir-peephole", NVPTXIRPeepholePass())
4444
#undef FUNCTION_PASS

llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,12 @@ static cl::opt<bool>
5151
cl::desc("Disable load/store vectorizer"),
5252
cl::init(false), cl::Hidden);
5353

54-
// FoldFMA is a new pass; this option will lets us turn it off in case we
55-
// encounter some issues.
56-
static cl::opt<bool> DisableFoldFMA("disable-nvptx-fold-fma",
57-
cl::desc("Disable NVPTX Fold FMA"),
58-
cl::init(false), cl::Hidden);
54+
// NVPTX IR Peephole is a new pass; this option lets us turn it off in case
55+
// we encounter some issues.
56+
static cl::opt<bool>
57+
DisableNVPTXIRPeephole("disable-nvptx-ir-peephole",
58+
cl::desc("Disable NVPTX IR Peephole"),
59+
cl::init(false), cl::Hidden);
5960

6061
// TODO: Remove this flag when we are confident with no regressions.
6162
static cl::opt<bool> DisableRequireStructuredCFG(
@@ -121,7 +122,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
121122
initializeNVPTXExternalAAWrapperPass(PR);
122123
initializeNVPTXPeepholePass(PR);
123124
initializeNVPTXTagInvariantLoadLegacyPassPass(PR);
124-
initializeNVPTXFoldFMAPass(PR);
125+
initializeNVPTXIRPeepholePass(PR);
125126
initializeNVPTXPrologEpilogPassPass(PR);
126127
}
127128

@@ -404,8 +405,8 @@ void NVPTXPassConfig::addIRPasses() {
404405
addPass(createLoadStoreVectorizerPass());
405406
addPass(createSROAPass());
406407
addPass(createNVPTXTagInvariantLoadsPass());
407-
if (!DisableFoldFMA)
408-
addPass(createNVPTXFoldFMAPass());
408+
if (!DisableNVPTXIRPeephole)
409+
addPass(createNVPTXIRPeepholePass());
409410
}
410411

411412
if (ST.hasPTXASUnreachableBug()) {

llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2-
; RUN: opt < %s -passes=nvptx-fold-fma -S | FileCheck %s
2+
; RUN: opt < %s -passes=nvptx-ir-peephole -S | FileCheck %s
33

44
target triple = "nvptx64-nvidia-cuda"
55

@@ -47,6 +47,25 @@ define float @test_fsub_fmul_fmul(float %a, float %b, float %c, float %d) {
4747
}
4848

4949

50+
; fsub(fmul(a, b), fmul(c, d)) => fma(fneg(c), d, fmul(a, b))
51+
; fmul(a, b) has multiple uses.
52+
define float @test_fsub_fmul_fmul_multiple_use(float %a, float %b, float %c, float %d) {
53+
; CHECK-LABEL: define float @test_fsub_fmul_fmul_multiple_use(
54+
; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]], float [[D:%.*]]) {
55+
; CHECK-NEXT: [[MUL1:%.*]] = fmul contract float [[A]], [[B]]
56+
; CHECK-NEXT: [[TMP1:%.*]] = fneg contract float [[C]]
57+
; CHECK-NEXT: [[TMP2:%.*]] = call contract float @llvm.fma.f32(float [[TMP1]], float [[D]], float [[MUL1]])
58+
; CHECK-NEXT: [[ADD:%.*]] = fadd float [[TMP2]], [[MUL1]]
59+
; CHECK-NEXT: ret float [[ADD]]
60+
;
61+
%mul1 = fmul contract float %a, %b
62+
%mul2 = fmul contract float %c, %d
63+
%sub = fsub contract float %mul1, %mul2
64+
%add = fadd float %sub, %mul1
65+
ret float %add
66+
}
67+
68+
5069
; fsub(fmul(a, b), c) => fma(a, b, fneg(c)) where fsub and fmul are in different BBs
5170
define float @test_fsub_fmul_different_BB(float %a, float %b, float %c, i32 %n) {
5271
; CHECK-LABEL: define float @test_fsub_fmul_different_BB(

0 commit comments

Comments
 (0)