@@ -1302,8 +1302,8 @@ std::string NVPTXTargetLowering::getPrototype(
13021302
13031303 bool first = true ;
13041304
1305- unsigned OIdx = 0 ;
1306- for (unsigned i = 0 , e = Args.size (); i != e; ++i, ++OIdx) {
1305+ const Function *F = CB. getFunction () ;
1306+ for (unsigned i = 0 , e = Args.size (), OIdx = 0 ; i != e; ++i, ++OIdx) {
13071307 Type *Ty = Args[i].Ty ;
13081308 if (!first) {
13091309 O << " , " ;
@@ -1312,15 +1312,14 @@ std::string NVPTXTargetLowering::getPrototype(
13121312
13131313 if (!Outs[OIdx].Flags .isByVal ()) {
13141314 if (Ty->isAggregateType () || Ty->isVectorTy () || Ty->isIntegerTy (128 )) {
1315- unsigned align = 0 ;
1315+ unsigned ParamAlign = 0 ;
13161316 const CallInst *CallI = cast<CallInst>(&CB);
13171317 // +1 because index 0 is reserved for return type alignment
1318- if (!getAlign (*CallI, i + 1 , align))
1319- align = DL.getABITypeAlignment (Ty);
1320- unsigned sz = DL.getTypeAllocSize (Ty);
1321- O << " .param .align " << align << " .b8 " ;
1318+ if (!getAlign (*CallI, i + 1 , ParamAlign))
1319+ ParamAlign = getFunctionParamOptimizedAlign (F, Ty, DL).value ();
1320+ O << " .param .align " << ParamAlign << " .b8 " ;
13221321 O << " _" ;
1323- O << " [" << sz << " ]" ;
1322+ O << " [" << DL. getTypeAllocSize (Ty) << " ]" ;
13241323 // update the index for Outs
13251324 SmallVector<EVT, 16 > vtparts;
13261325 ComputeValueVTs (*this , DL, Ty, vtparts);
@@ -1352,11 +1351,17 @@ std::string NVPTXTargetLowering::getPrototype(
13521351 continue ;
13531352 }
13541353
1355- Align align = Outs[OIdx].Flags .getNonZeroByValAlign ();
1356- unsigned sz = Outs[OIdx].Flags .getByValSize ();
1357- O << " .param .align " << align.value () << " .b8 " ;
1354+ Align ParamByValAlign = Outs[OIdx].Flags .getNonZeroByValAlign ();
1355+
1356+ // Try to increase alignment. This code matches logic in LowerCall when
1357+ // alignment increase is performed to increase vectorization options.
1358+ Type *ETy = Args[i].IndirectType ;
1359+ Align AlignCandidate = getFunctionParamOptimizedAlign (F, ETy, DL);
1360+ ParamByValAlign = std::max (ParamByValAlign, AlignCandidate);
1361+
1362+ O << " .param .align " << ParamByValAlign.value () << " .b8 " ;
13581363 O << " _" ;
1359- O << " [" << sz << " ]" ;
1364+ O << " [" << Outs[OIdx]. Flags . getByValSize () << " ]" ;
13601365 }
13611366 O << " );" ;
13621367 return O.str ();
@@ -1403,12 +1408,15 @@ Align NVPTXTargetLowering::getArgumentAlignment(SDValue Callee,
14031408
14041409 // Check for function alignment information if we found that the
14051410 // ultimate target is a Function
1406- if (DirectCallee)
1411+ if (DirectCallee) {
14071412 if (getAlign (*DirectCallee, Idx, Alignment))
14081413 return Align (Alignment);
1414+ // If alignment information is not available, fall back to the
1415+ // default function param optimized type alignment
1416+ return getFunctionParamOptimizedAlign (DirectCallee, Ty, DL);
1417+ }
14091418
1410- // Call is indirect or alignment information is not available, fall back to
1411- // the ABI type alignment
1419+ // Call is indirect, fall back to the ABI type alignment
14121420 return DL.getABITypeAlign (Ty);
14131421}
14141422
@@ -1569,18 +1577,26 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
15691577 }
15701578
15711579 // ByVal arguments
1580+ // TODO: remove code duplication when handling byval and non-byval cases.
15721581 SmallVector<EVT, 16 > VTs;
15731582 SmallVector<uint64_t , 16 > Offsets;
1574- assert (Args[i].IndirectType && " byval arg must have indirect type" );
1575- ComputePTXValueVTs (*this , DL, Args[i].IndirectType , VTs, &Offsets, 0 );
1583+ Type *ETy = Args[i].IndirectType ;
1584+ assert (ETy && " byval arg must have indirect type" );
1585+ ComputePTXValueVTs (*this , DL, ETy, VTs, &Offsets, 0 );
15761586
15771587 // declare .param .align <align> .b8 .param<n>[<size>];
15781588 unsigned sz = Outs[OIdx].Flags .getByValSize ();
15791589 SDVTList DeclareParamVTs = DAG.getVTList (MVT::Other, MVT::Glue);
1580- Align ArgAlign = Outs[OIdx]. Flags . getNonZeroByValAlign ();
1590+
15811591 // The ByValAlign in the Outs[OIdx].Flags is alway set at this point,
15821592 // so we don't need to worry about natural alignment or not.
15831593 // See TargetLowering::LowerCallTo().
1594+ Align ArgAlign = Outs[OIdx].Flags .getNonZeroByValAlign ();
1595+
1596+ // Try to increase alignment to enhance vectorization options.
1597+ const Function *F = CB->getCalledFunction ();
1598+ Align AlignCandidate = getFunctionParamOptimizedAlign (F, ETy, DL);
1599+ ArgAlign = std::max (ArgAlign, AlignCandidate);
15841600
15851601 // Enforce minumum alignment of 4 to work around ptxas miscompile
15861602 // for sm_50+. See corresponding alignment adjustment in
@@ -1594,29 +1610,67 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
15941610 Chain = DAG.getNode (NVPTXISD::DeclareParam, dl, DeclareParamVTs,
15951611 DeclareParamOps);
15961612 InFlag = Chain.getValue (1 );
1613+
1614+ auto VectorInfo = VectorizePTXValueVTs (VTs, Offsets, ArgAlign);
1615+ SmallVector<SDValue, 6 > StoreOperands;
15971616 for (unsigned j = 0 , je = VTs.size (); j != je; ++j) {
15981617 EVT elemtype = VTs[j];
15991618 int curOffset = Offsets[j];
1600- unsigned PartAlign = GreatestCommonDivisor64 (ArgAlign.value (), curOffset);
1619+ Align PartAlign = commonAlignment (ArgAlign, curOffset);
1620+
1621+ // New store.
1622+ if (VectorInfo[j] & PVF_FIRST) {
1623+ assert (StoreOperands.empty () && " Unfinished preceding store." );
1624+ StoreOperands.push_back (Chain);
1625+ StoreOperands.push_back (DAG.getConstant (paramCount, dl, MVT::i32 ));
1626+ StoreOperands.push_back (DAG.getConstant (curOffset, dl, MVT::i32 ));
1627+ }
1628+
16011629 auto PtrVT = getPointerTy (DL);
16021630 SDValue srcAddr = DAG.getNode (ISD::ADD, dl, PtrVT, OutVals[OIdx],
16031631 DAG.getConstant (curOffset, dl, PtrVT));
16041632 SDValue theVal = DAG.getLoad (elemtype, dl, tempChain, srcAddr,
16051633 MachinePointerInfo (), PartAlign);
1634+
16061635 if (elemtype.getSizeInBits () < 16 ) {
1636+ // Use 16-bit registers for small stores as it's the
1637+ // smallest general purpose register size supported by NVPTX.
16071638 theVal = DAG.getNode (ISD::ANY_EXTEND, dl, MVT::i16 , theVal);
16081639 }
1609- SDVTList CopyParamVTs = DAG.getVTList (MVT::Other, MVT::Glue);
1610- SDValue CopyParamOps[] = { Chain,
1611- DAG.getConstant (paramCount, dl, MVT::i32 ),
1612- DAG.getConstant (curOffset, dl, MVT::i32 ),
1613- theVal, InFlag };
1614- Chain = DAG.getMemIntrinsicNode (
1615- NVPTXISD::StoreParam, dl, CopyParamVTs, CopyParamOps, elemtype,
1616- MachinePointerInfo (), /* Align */ None, MachineMemOperand::MOStore);
16171640
1618- InFlag = Chain.getValue (1 );
1641+ // Record the value to store.
1642+ StoreOperands.push_back (theVal);
1643+
1644+ if (VectorInfo[j] & PVF_LAST) {
1645+ unsigned NumElts = StoreOperands.size () - 3 ;
1646+ NVPTXISD::NodeType Op;
1647+ switch (NumElts) {
1648+ case 1 :
1649+ Op = NVPTXISD::StoreParam;
1650+ break ;
1651+ case 2 :
1652+ Op = NVPTXISD::StoreParamV2;
1653+ break ;
1654+ case 4 :
1655+ Op = NVPTXISD::StoreParamV4;
1656+ break ;
1657+ default :
1658+ llvm_unreachable (" Invalid vector info." );
1659+ }
1660+
1661+ StoreOperands.push_back (InFlag);
1662+
1663+ Chain = DAG.getMemIntrinsicNode (
1664+ Op, dl, DAG.getVTList (MVT::Other, MVT::Glue), StoreOperands,
1665+ elemtype, MachinePointerInfo (), PartAlign,
1666+ MachineMemOperand::MOStore);
1667+ InFlag = Chain.getValue (1 );
1668+
1669+ // Cleanup.
1670+ StoreOperands.clear ();
1671+ }
16191672 }
1673+ assert (StoreOperands.empty () && " Unfinished parameter store." );
16201674 ++paramCount;
16211675 }
16221676
@@ -2617,7 +2671,8 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
26172671 const SmallVectorImpl<ISD::OutputArg> &Outs,
26182672 const SmallVectorImpl<SDValue> &OutVals,
26192673 const SDLoc &dl, SelectionDAG &DAG) const {
2620- MachineFunction &MF = DAG.getMachineFunction ();
2674+ const MachineFunction &MF = DAG.getMachineFunction ();
2675+ const Function &F = MF.getFunction ();
26212676 Type *RetTy = MF.getFunction ().getReturnType ();
26222677
26232678 bool isABI = (STI.getSmVersion () >= 20 );
@@ -2632,7 +2687,9 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
26322687 assert (VTs.size () == OutVals.size () && " Bad return value decomposition" );
26332688
26342689 auto VectorInfo = VectorizePTXValueVTs (
2635- VTs, Offsets, RetTy->isSized () ? DL.getABITypeAlign (RetTy) : Align (1 ));
2690+ VTs, Offsets,
2691+ RetTy->isSized () ? getFunctionParamOptimizedAlign (&F, RetTy, DL)
2692+ : Align (1 ));
26362693
26372694 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
26382695 // 32-bits are sign extended or zero extended, depending on whether
@@ -4252,6 +4309,55 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
42524309 return false ;
42534310}
42544311
4312+ // / getFunctionParamOptimizedAlign - since function arguments are passed via
4313+ // / .param space, we may want to increase their alignment in a way that
4314+ // / ensures that we can effectively vectorize their loads & stores. We can
4315+ // / increase alignment only if the function has internal or has private
4316+ // / linkage as for other linkage types callers may already rely on default
4317+ // / alignment. To allow using 128-bit vectorized loads/stores, this function
4318+ // / ensures that alignment is 16 or greater.
4319+ Align NVPTXTargetLowering::getFunctionParamOptimizedAlign (
4320+ const Function *F, Type *ArgTy, const DataLayout &DL) const {
4321+ const uint64_t ABITypeAlign = DL.getABITypeAlign (ArgTy).value ();
4322+
4323+ // If a function has linkage different from internal or private, we
4324+ // must use default ABI alignment as external users rely on it.
4325+ switch (F->getLinkage ()) {
4326+ case GlobalValue::InternalLinkage:
4327+ case GlobalValue::PrivateLinkage: {
4328+ // Check that if a function has internal or private linkage
4329+ // it is not a kernel.
4330+ #ifndef NDEBUG
4331+ const NamedMDNode *NMDN =
4332+ F->getParent ()->getNamedMetadata (" nvvm.annotations" );
4333+ if (NMDN) {
4334+ for (const MDNode *MDN : NMDN->operands ()) {
4335+ assert (MDN->getNumOperands () == 3 );
4336+
4337+ const Metadata *MD0 = MDN->getOperand (0 ).get ();
4338+ const auto *MDV0 = cast<ConstantAsMetadata>(MD0)->getValue ();
4339+ const auto *MDFn = cast<Function>(MDV0);
4340+ if (MDFn != F)
4341+ continue ;
4342+
4343+ const Metadata *MD1 = MDN->getOperand (1 ).get ();
4344+ const MDString *MDStr = cast<MDString>(MD1);
4345+ if (MDStr->getString () != " kernel" )
4346+ continue ;
4347+
4348+ const Metadata *MD2 = MDN->getOperand (2 ).get ();
4349+ const auto *MDV2 = cast<ConstantAsMetadata>(MD2)->getValue ();
4350+ assert (!cast<ConstantInt>(MDV2)->isZero ());
4351+ }
4352+ }
4353+ #endif
4354+ return Align (std::max (uint64_t (16 ), ABITypeAlign));
4355+ }
4356+ default :
4357+ return Align (ABITypeAlign);
4358+ }
4359+ }
4360+
42554361// / isLegalAddressingMode - Return true if the addressing mode represented
42564362// / by AM is legal for this target, for a load/store of the specified type.
42574363// / Used to guide target specific optimizations, like loop strength reduction
0 commit comments