Commit f854434

[NVPTX] Enhance vectorization of ld.param & st.param
Since function parameters and return values are passed via param space, we can force special alignment for the values held in it, which adds vectorization options. This change may only be done if the function has private or internal linkage.

Special alignment is forced during two phases:

1) Instruction selection lowering. Here we use the special alignment for function prototypes (changing the alignment of both the function's own return value and its parameters) and for call lowering (changing the alignment of both the callee's return value and its parameters).

2) The IR pass nvptx-lower-args. Here we change the alignment of byval parameters that belong to param space (or are cast to it). We only handle cases where all uses of such a parameter are loads from it. For those loads, we can change the alignment according to the special type alignment and the load offset. The load-store-vectorizer IR pass then performs vectorization where the alignment allows it.

The special alignment is calculated as the maximum of the default ABI type alignment and 16. Alignment 16 is chosen because it is the maximum size of a vectorized ld.param or st.param access.

Before specifying the special alignment, we check that it is a multiple of the alignment the value already has; otherwise the existing alignment is preserved. For example, if a value has an enforced alignment of 64, a default ABI alignment of 4, and a special alignment of 16, we preserve 64.

This patch will be followed by a refactoring patch that removes the duplicated code for handling byval and non-byval arguments.

Differential Revision: https://reviews.llvm.org/D121549
1 parent be5c3ca commit f854434

7 files changed (+1544, −55 lines)
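For illustration, the alignment rule described in the commit message (implemented by the new NVPTXTargetLowering::getFunctionParamOptimizedAlign in the NVPTXISelLowering.cpp hunk below) can be sketched in standalone C++. This is a simplified model using plain integers instead of LLVM's Align and linkage types, not the patch's code.

#include <algorithm>
#include <cassert>
#include <cstdint>

// Simplified model of the patch's rule: raise a .param value's alignment to
// at least 16 bytes (the widest vectorized ld.param/st.param access), but
// only for internal/private functions, since external callers may rely on
// the default ABI alignment.
enum class Linkage { Internal, Private, Other };

uint64_t paramOptimizedAlign(uint64_t abiTypeAlign, Linkage L) {
  if (L == Linkage::Internal || L == Linkage::Private)
    return std::max<uint64_t>(16, abiTypeAlign);
  return abiTypeAlign;
}

int main() {
  // The example from the commit message: an enforced alignment of 64 beats
  // the special alignment of max(ABI 4, 16) == 16, so 64 is preserved.
  uint64_t Special = paramOptimizedAlign(/*abiTypeAlign=*/4, Linkage::Internal);
  assert(Special == 16);
  uint64_t Final = std::max<uint64_t>(/*enforced*/ 64, Special);
  assert(Final == 64);
  (void)Final;
}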
Lines changed: 22 additions & 9 deletions
@@ -1,19 +1,32 @@
-// RUN: %clang_cc1 -triple nvptx -fcuda-is-device \
-// RUN:     -emit-llvm -o - %s \
+// RUN: %clang_cc1 -triple nvptx -fcuda-is-device -emit-llvm -o - %s \
 // RUN:   | FileCheck -check-prefix=NORDC %s
-// RUN: %clang_cc1 -triple nvptx -fcuda-is-device \
-// RUN:     -fgpu-rdc -emit-llvm -o - %s \
+// RUN: %clang_cc1 -triple nvptx -fcuda-is-device -emit-llvm -o - %s \
+// RUN:   | FileCheck -check-prefix=NORDC-NEG %s
+// RUN: %clang_cc1 -triple nvptx -fcuda-is-device -fgpu-rdc -emit-llvm -o - %s \
 // RUN:   | FileCheck -check-prefix=RDC %s
+// RUN: %clang_cc1 -triple nvptx -fcuda-is-device -fgpu-rdc -emit-llvm -o - %s \
+// RUN:   | FileCheck -check-prefix=RDC-NEG %s
 
 #include "Inputs/cuda.h"
 
-// NORDC: define internal void @_Z4funcIiEvv()
-// NORDC: define{{.*}} void @_Z6kernelIiEvv()
-// RDC: define weak_odr void @_Z4funcIiEvv()
-// RDC: define weak_odr void @_Z6kernelIiEvv()
-
 template <typename T> __device__ void func() {}
 template <typename T> __global__ void kernel() {}
 
 template __device__ void func<int>();
+// NORDC: define internal void @_Z4funcIiEvv()
+// RDC: define weak_odr void @_Z4funcIiEvv()
+
 template __global__ void kernel<int>();
+// NORDC: define void @_Z6kernelIiEvv()
+// RDC: define weak_odr void @_Z6kernelIiEvv()
+
+// Ensure that unused static device function is eliminated
+static __device__ void static_func() {}
+// NORDC-NEG-NOT: define{{.*}} void @_ZL13static_funcv()
+// RDC-NEG-NOT: define{{.*}} void @_ZL13static_funcv()
+
+// Ensure that kernel function has external or weak_odr
+// linkage regardless static specifier
+static __global__ void static_kernel() {}
+// NORDC: define void @_ZL13static_kernelv()
+// RDC: define weak_odr void @_ZL13static_kernelv()

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

Lines changed: 23 additions & 15 deletions
@@ -329,7 +329,7 @@ MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
   const DataLayout &DL = getDataLayout();
   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
-  const TargetLowering *TLI = STI.getTargetLowering();
+  const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
 
   Type *Ty = F->getReturnType();
 
@@ -363,7 +363,7 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
     unsigned totalsz = DL.getTypeAllocSize(Ty);
     unsigned retAlignment = 0;
     if (!getAlign(*F, 0, retAlignment))
-      retAlignment = DL.getABITypeAlignment(Ty);
+      retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value();
     O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
       << "]";
   } else
@@ -1348,7 +1348,8 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
   const DataLayout &DL = getDataLayout();
   const AttributeList &PAL = F->getAttributes();
   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
-  const TargetLowering *TLI = STI.getTargetLowering();
+  const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
+
   Function::const_arg_iterator I, E;
   unsigned paramIndex = 0;
   bool first = true;
@@ -1405,18 +1406,24 @@
       }
     }
 
+    auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
+                                    paramIndex](Type *Ty) -> Align {
+      Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
+      MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
+      return max(TypeAlign, ParamAlign);
+    };
+
     if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
       if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
         // Just print .param .align <a> .b8 .param[size];
-        // <a> = PAL.getparamalignment
+        // <a> = optimal alignment for the element type; always multiple of
+        //       PAL.getParamAlignment
         // size = typeallocsize of element type
-        const Align align = DL.getValueOrABITypeAlignment(
-            PAL.getParamAlignment(paramIndex), Ty);
+        Align OptimalAlign = getOptimalAlignForParam(Ty);
 
-        unsigned sz = DL.getTypeAllocSize(Ty);
-        O << "\t.param .align " << align.value() << " .b8 ";
+        O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
         printParamName(I, paramIndex, O);
-        O << "[" << sz << "]";
+        O << "[" << DL.getTypeAllocSize(Ty) << "]";
 
         continue;
       }
@@ -1492,10 +1499,11 @@
 
     if (isABI || isKernelFunc) {
       // Just print .param .align <a> .b8 .param[size];
-      // <a> = PAL.getparamalignment
+      // <a> = optimal alignment for the element type; always multiple of
+      //       PAL.getParamAlignment
      // size = typeallocsize of element type
-      Align align =
-          DL.getValueOrABITypeAlignment(PAL.getParamAlignment(paramIndex), ETy);
+      Align OptimalAlign = getOptimalAlignForParam(ETy);
+
      // Work around a bug in ptxas. When PTX code takes address of
      // byval parameter with alignment < 4, ptxas generates code to
      // spill argument into memory. Alas on sm_50+ ptxas generates
@@ -1507,10 +1515,10 @@
      // TODO: this will need to be undone when we get to support multi-TU
      // device-side compilation as it breaks ABI compatibility with nvcc.
      // Hopefully ptxas bug is fixed by then.
-      if (!isKernelFunc && align < Align(4))
-        align = Align(4);
+      if (!isKernelFunc && OptimalAlign < Align(4))
+        OptimalAlign = Align(4);
       unsigned sz = DL.getTypeAllocSize(ETy);
-      O << "\t.param .align " << align.value() << " .b8 ";
+      O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
       printParamName(I, paramIndex, O);
       O << "[" << sz << "]";
       continue;
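A quick standalone illustration of the directive printed by the loop above for a 32-byte aggregate parameter of an internal function; the concrete values and the parameter name are hypothetical, and plain printf stands in for raw_ostream.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical inputs: getFunctionParamOptimizedAlign returned 16 and the
  // IR carries an explicit align(8) attribute; the larger value wins, as in
  // the getOptimalAlignForParam lambda above.
  uint64_t TypeAlign = 16;
  uint64_t ParamAttrAlign = 8;
  uint64_t OptimalAlign = std::max(TypeAlign, ParamAttrAlign);
  uint64_t Size = 32; // stands in for DL.getTypeAllocSize(Ty)

  // Prints: ".param .align 16 .b8 _Z3foo6Params_param_0[32]"
  std::printf("\t.param .align %llu .b8 %s[%llu]\n",
              static_cast<unsigned long long>(OptimalAlign),
              "_Z3foo6Params_param_0",
              static_cast<unsigned long long>(Size));
}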

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 136 additions & 30 deletions
@@ -1302,8 +1302,8 @@ std::string NVPTXTargetLowering::getPrototype(
 
   bool first = true;
 
-  unsigned OIdx = 0;
-  for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
+  const Function *F = CB.getFunction();
+  for (unsigned i = 0, e = Args.size(), OIdx = 0; i != e; ++i, ++OIdx) {
     Type *Ty = Args[i].Ty;
     if (!first) {
       O << ", ";
@@ -1312,15 +1312,14 @@
 
     if (!Outs[OIdx].Flags.isByVal()) {
       if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
-        unsigned align = 0;
+        unsigned ParamAlign = 0;
         const CallInst *CallI = cast<CallInst>(&CB);
         // +1 because index 0 is reserved for return type alignment
-        if (!getAlign(*CallI, i + 1, align))
-          align = DL.getABITypeAlignment(Ty);
-        unsigned sz = DL.getTypeAllocSize(Ty);
-        O << ".param .align " << align << " .b8 ";
+        if (!getAlign(*CallI, i + 1, ParamAlign))
+          ParamAlign = getFunctionParamOptimizedAlign(F, Ty, DL).value();
+        O << ".param .align " << ParamAlign << " .b8 ";
         O << "_";
-        O << "[" << sz << "]";
+        O << "[" << DL.getTypeAllocSize(Ty) << "]";
         // update the index for Outs
         SmallVector<EVT, 16> vtparts;
         ComputeValueVTs(*this, DL, Ty, vtparts);
@@ -1352,11 +1351,17 @@
       continue;
     }
 
-    Align align = Outs[OIdx].Flags.getNonZeroByValAlign();
-    unsigned sz = Outs[OIdx].Flags.getByValSize();
-    O << ".param .align " << align.value() << " .b8 ";
+    Align ParamByValAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+
+    // Try to increase alignment. This code matches logic in LowerCall when
+    // alignment increase is performed to increase vectorization options.
+    Type *ETy = Args[i].IndirectType;
+    Align AlignCandidate = getFunctionParamOptimizedAlign(F, ETy, DL);
+    ParamByValAlign = std::max(ParamByValAlign, AlignCandidate);
+
+    O << ".param .align " << ParamByValAlign.value() << " .b8 ";
     O << "_";
-    O << "[" << sz << "]";
+    O << "[" << Outs[OIdx].Flags.getByValSize() << "]";
   }
   O << ");";
   return O.str();
@@ -1403,12 +1408,15 @@ Align NVPTXTargetLowering::getArgumentAlignment(SDValue Callee,
 
   // Check for function alignment information if we found that the
   // ultimate target is a Function
-  if (DirectCallee)
+  if (DirectCallee) {
     if (getAlign(*DirectCallee, Idx, Alignment))
       return Align(Alignment);
+    // If alignment information is not available, fall back to the
+    // default function param optimized type alignment
+    return getFunctionParamOptimizedAlign(DirectCallee, Ty, DL);
+  }
 
-  // Call is indirect or alignment information is not available, fall back to
-  // the ABI type alignment
+  // Call is indirect, fall back to the ABI type alignment
   return DL.getABITypeAlign(Ty);
 }
 
@@ -1569,18 +1577,26 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     }
 
     // ByVal arguments
+    // TODO: remove code duplication when handling byval and non-byval cases.
    SmallVector<EVT, 16> VTs;
    SmallVector<uint64_t, 16> Offsets;
-    assert(Args[i].IndirectType && "byval arg must have indirect type");
-    ComputePTXValueVTs(*this, DL, Args[i].IndirectType, VTs, &Offsets, 0);
+    Type *ETy = Args[i].IndirectType;
+    assert(ETy && "byval arg must have indirect type");
+    ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, 0);
 
     // declare .param .align <align> .b8 .param<n>[<size>];
     unsigned sz = Outs[OIdx].Flags.getByValSize();
     SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-    Align ArgAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+
     // The ByValAlign in the Outs[OIdx].Flags is alway set at this point,
     // so we don't need to worry about natural alignment or not.
     // See TargetLowering::LowerCallTo().
+    Align ArgAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+
+    // Try to increase alignment to enhance vectorization options.
+    const Function *F = CB->getCalledFunction();
+    Align AlignCandidate = getFunctionParamOptimizedAlign(F, ETy, DL);
+    ArgAlign = std::max(ArgAlign, AlignCandidate);
 
     // Enforce minumum alignment of 4 to work around ptxas miscompile
     // for sm_50+. See corresponding alignment adjustment in
@@ -1594,29 +1610,67 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
                         DeclareParamOps);
     InFlag = Chain.getValue(1);
+
+    auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
+    SmallVector<SDValue, 6> StoreOperands;
     for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
       EVT elemtype = VTs[j];
       int curOffset = Offsets[j];
-      unsigned PartAlign = GreatestCommonDivisor64(ArgAlign.value(), curOffset);
+      Align PartAlign = commonAlignment(ArgAlign, curOffset);
+
+      // New store.
+      if (VectorInfo[j] & PVF_FIRST) {
+        assert(StoreOperands.empty() && "Unfinished preceding store.");
+        StoreOperands.push_back(Chain);
+        StoreOperands.push_back(DAG.getConstant(paramCount, dl, MVT::i32));
+        StoreOperands.push_back(DAG.getConstant(curOffset, dl, MVT::i32));
+      }
+
       auto PtrVT = getPointerTy(DL);
       SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, OutVals[OIdx],
                                     DAG.getConstant(curOffset, dl, PtrVT));
       SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
                                    MachinePointerInfo(), PartAlign);
+
       if (elemtype.getSizeInBits() < 16) {
+        // Use 16-bit registers for small stores as it's the
+        // smallest general purpose register size supported by NVPTX.
        theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
       }
-      SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-      SDValue CopyParamOps[] = { Chain,
-                                 DAG.getConstant(paramCount, dl, MVT::i32),
-                                 DAG.getConstant(curOffset, dl, MVT::i32),
-                                 theVal, InFlag };
-      Chain = DAG.getMemIntrinsicNode(
-          NVPTXISD::StoreParam, dl, CopyParamVTs, CopyParamOps, elemtype,
-          MachinePointerInfo(), /* Align */ None, MachineMemOperand::MOStore);
 
-      InFlag = Chain.getValue(1);
+      // Record the value to store.
+      StoreOperands.push_back(theVal);
+
+      if (VectorInfo[j] & PVF_LAST) {
+        unsigned NumElts = StoreOperands.size() - 3;
+        NVPTXISD::NodeType Op;
+        switch (NumElts) {
+        case 1:
+          Op = NVPTXISD::StoreParam;
+          break;
+        case 2:
+          Op = NVPTXISD::StoreParamV2;
+          break;
+        case 4:
+          Op = NVPTXISD::StoreParamV4;
+          break;
+        default:
+          llvm_unreachable("Invalid vector info.");
+        }
+
+        StoreOperands.push_back(InFlag);
+
+        Chain = DAG.getMemIntrinsicNode(
+            Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
+            elemtype, MachinePointerInfo(), PartAlign,
+            MachineMemOperand::MOStore);
+        InFlag = Chain.getValue(1);
+
+        // Cleanup.
+        StoreOperands.clear();
+      }
     }
+    assert(StoreOperands.empty() && "Unfinished parameter store.");
     ++paramCount;
   }
 
@@ -2617,7 +2671,8 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
                                  const SmallVectorImpl<SDValue> &OutVals,
                                  const SDLoc &dl, SelectionDAG &DAG) const {
-  MachineFunction &MF = DAG.getMachineFunction();
+  const MachineFunction &MF = DAG.getMachineFunction();
+  const Function &F = MF.getFunction();
   Type *RetTy = MF.getFunction().getReturnType();
 
   bool isABI = (STI.getSmVersion() >= 20);
@@ -2632,7 +2687,9 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
 
   auto VectorInfo = VectorizePTXValueVTs(
-      VTs, Offsets, RetTy->isSized() ? DL.getABITypeAlign(RetTy) : Align(1));
+      VTs, Offsets,
+      RetTy->isSized() ? getFunctionParamOptimizedAlign(&F, RetTy, DL)
+                       : Align(1));
 
   // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
   // 32-bits are sign extended or zero extended, depending on whether
@@ -4252,6 +4309,55 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   return false;
 }
 
+/// getFunctionParamOptimizedAlign - since function arguments are passed via
+/// .param space, we may want to increase their alignment in a way that
+/// ensures that we can effectively vectorize their loads & stores. We can
+/// increase alignment only if the function has internal or has private
+/// linkage as for other linkage types callers may already rely on default
+/// alignment. To allow using 128-bit vectorized loads/stores, this function
+/// ensures that alignment is 16 or greater.
+Align NVPTXTargetLowering::getFunctionParamOptimizedAlign(
+    const Function *F, Type *ArgTy, const DataLayout &DL) const {
+  const uint64_t ABITypeAlign = DL.getABITypeAlign(ArgTy).value();
+
+  // If a function has linkage different from internal or private, we
+  // must use default ABI alignment as external users rely on it.
+  switch (F->getLinkage()) {
+  case GlobalValue::InternalLinkage:
+  case GlobalValue::PrivateLinkage: {
+    // Check that if a function has internal or private linkage
+    // it is not a kernel.
+#ifndef NDEBUG
+    const NamedMDNode *NMDN =
+        F->getParent()->getNamedMetadata("nvvm.annotations");
+    if (NMDN) {
+      for (const MDNode *MDN : NMDN->operands()) {
+        assert(MDN->getNumOperands() == 3);
+
+        const Metadata *MD0 = MDN->getOperand(0).get();
+        const auto *MDV0 = cast<ConstantAsMetadata>(MD0)->getValue();
+        const auto *MDFn = cast<Function>(MDV0);
+        if (MDFn != F)
+          continue;
+
+        const Metadata *MD1 = MDN->getOperand(1).get();
+        const MDString *MDStr = cast<MDString>(MD1);
+        if (MDStr->getString() != "kernel")
+          continue;
+
+        const Metadata *MD2 = MDN->getOperand(2).get();
+        const auto *MDV2 = cast<ConstantAsMetadata>(MD2)->getValue();
+        assert(!cast<ConstantInt>(MDV2)->isZero());
+      }
+    }
+#endif
+    return Align(std::max(uint64_t(16), ABITypeAlign));
+  }
+  default:
+    return Align(ABITypeAlign);
+  }
+}
+
 /// isLegalAddressingMode - Return true if the addressing mode represented
 /// by AM is legal for this target, for a load/store of the specified type.
 /// Used to guide target specific optimizations, like loop strength reduction
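To make the byval store vectorization above easier to follow: VectorizePTXValueVTs marks runs of elements with PVF_FIRST/PVF_LAST, and each run becomes a single StoreParam, StoreParamV2, or StoreParamV4 node, with per-element alignment derived via commonAlignment(ArgAlign, offset). The helper below is a simplified, hypothetical re-implementation for uniform-size elements, not LLVM's actual algorithm; it only shows how the forced 16-byte parameter alignment enables v4 stores.

#include <cassert>
#include <cstdint>
#include <vector>

struct Element { uint64_t Offset; uint64_t Size; };

// Analogue of commonAlignment(ArgAlign, Offset): the largest power of two
// dividing Offset, capped by the parameter's alignment.
uint64_t commonAlign(uint64_t ArgAlign, uint64_t Offset) {
  if (Offset == 0) return ArgAlign;
  uint64_t A = Offset & (~Offset + 1); // lowest set bit
  return A < ArgAlign ? A : ArgAlign;
}

// Returns the chosen run lengths (1, 2, or 4 elements per store).
std::vector<unsigned> groupStores(const std::vector<Element> &Elts,
                                  uint64_t ArgAlign) {
  std::vector<unsigned> Runs;
  for (size_t i = 0; i < Elts.size();) {
    unsigned Width = 1;
    // Prefer the widest run (4, then 2) whose members are contiguous and
    // whose combined size is covered by the first element's alignment.
    for (unsigned Cand : {4u, 2u}) {
      if (i + Cand > Elts.size())
        continue;
      bool Contig = true;
      for (unsigned j = 1; j < Cand; ++j)
        Contig = Contig && (Elts[i + j].Offset ==
                            Elts[i].Offset + j * Elts[i].Size);
      if (Contig &&
          commonAlign(ArgAlign, Elts[i].Offset) >= Cand * Elts[i].Size) {
        Width = Cand;
        break;
      }
    }
    Runs.push_back(Width);
    i += Width;
  }
  return Runs;
}

int main() {
  // Eight consecutive 4-byte elements: with the parameter forced to align 16
  // they become two v4 runs (two st.param.v4.b32), while with only the
  // default ABI alignment of 4 they stay as eight scalar st.param.b32.
  std::vector<Element> Elts;
  for (uint64_t k = 0; k < 8; ++k)
    Elts.push_back({k * 4, 4});
  auto Vec = groupStores(Elts, /*ArgAlign=*/16);
  assert(Vec.size() == 2 && Vec[0] == 4 && Vec[1] == 4);
  auto Scalar = groupStores(Elts, /*ArgAlign=*/4);
  assert(Scalar.size() == 8);
  (void)Vec;
  (void)Scalar;
}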

0 commit comments