Skip to content

[NVPTX] Basic support for fp128 as a storage type #136006

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 45 additions & 53 deletions llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,6 @@ MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
return MCOperand::createExpr(Expr);
}

static bool ShouldPassAsArray(Type *Ty) {
return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
Ty->isHalfTy() || Ty->isBFloatTy();
}

void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
const DataLayout &DL = getDataLayout();
const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
Expand All @@ -264,26 +259,21 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
return;
O << " (";

if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
!ShouldPassAsArray(Ty)) {
unsigned size = 0;
if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
size = ITy->getBitWidth();
} else {
assert(Ty->isFloatingPointTy() && "Floating point type expected here");
size = Ty->getPrimitiveSizeInBits();
}
size = promoteScalarArgumentSize(size);
O << ".param .b" << size << " func_retval0";
} else if (isa<PointerType>(Ty)) {
O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
<< " func_retval0";
} else if (ShouldPassAsArray(Ty)) {
unsigned totalsz = DL.getTypeAllocSize(Ty);
Align RetAlignment = TLI->getFunctionArgumentAlignment(
auto PrintScalarRetVal = [&](unsigned Size) {
O << ".param .b" << promoteScalarArgumentSize(Size) << " func_retval0";
};
if (shouldPassAsArray(Ty)) {
const unsigned TotalSize = DL.getTypeAllocSize(Ty);
const Align RetAlignment = TLI->getFunctionArgumentAlignment(
F, Ty, AttributeList::ReturnIndex, DL);
O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
<< totalsz << "]";
<< TotalSize << "]";
} else if (Ty->isFloatingPointTy()) {
PrintScalarRetVal(Ty->getPrimitiveSizeInBits());
} else if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
PrintScalarRetVal(ITy->getBitWidth());
} else if (isa<PointerType>(Ty)) {
PrintScalarRetVal(TLI->getPointerTy(DL).getSizeInBits());
} else
llvm_unreachable("Unknown return type");
O << ") ";
Expand Down Expand Up @@ -975,8 +965,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
O << " .align "
<< GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();

if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
(ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
if (ETy->isPointerTy() || ((ETy->isIntegerTy() || ETy->isFloatingPointTy()) &&
ETy->getScalarSizeInBits() <= 64)) {
O << " .";
// Special case: ABI requires that we use .u8 for predicates
if (ETy->isIntegerTy(1))
Expand Down Expand Up @@ -1016,6 +1006,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
// and vectors are lowered into arrays of bytes.
switch (ETy->getTypeID()) {
case Type::IntegerTyID: // Integers larger than 64 bits
case Type::FP128TyID:
case Type::StructTyID:
case Type::ArrayTyID:
case Type::FixedVectorTyID: {
Expand Down Expand Up @@ -1266,8 +1257,8 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
O << " .align "
<< GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();

// Special case for i128
if (ETy->isIntegerTy(128)) {
// Special case for i128/fp128
if (ETy->getScalarSizeInBits() == 128) {
O << " .b8 ";
getSymbol(GVar)->print(O, MAI);
O << "[16]";
Expand Down Expand Up @@ -1383,7 +1374,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
continue;
}

if (ShouldPassAsArray(Ty)) {
if (shouldPassAsArray(Ty)) {
// Just print .param .align <a> .b8 .param[size];
// <a> = optimal alignment for the element type; always multiple of
// PAL.getParamAlignment
Expand Down Expand Up @@ -1682,48 +1673,49 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
AggBuffer *aggBuffer) {
const DataLayout &DL = getDataLayout();
int Bytes;

auto ExtendBuffer = [](APInt Val, AggBuffer *Buffer) {
for (unsigned I : llvm::seq(Val.getBitWidth() / 8))
Buffer->addByte(Val.extractBitsAsZExtValue(8, I * 8));
};

// Integers of arbitrary width
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
APInt Val = CI->getValue();
for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
uint8_t Byte = Val.getLoBits(8).getZExtValue();
aggBuffer->addBytes(&Byte, 1, 1);
Val.lshrInPlace(8);
}
ExtendBuffer(CI->getValue(), aggBuffer);
return;
}

// f128
if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
if (CFP->getType()->isFP128Ty()) {
ExtendBuffer(CFP->getValueAPF().bitcastToAPInt(), aggBuffer);
return;
}
}

// Old constants
if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
if (CPV->getNumOperands())
for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
for (const auto &Op : CPV->operands())
bufferLEByte(cast<Constant>(Op), 0, aggBuffer);
return;
}

if (const ConstantDataSequential *CDS =
dyn_cast<ConstantDataSequential>(CPV)) {
if (CDS->getNumElements())
for (unsigned i = 0; i < CDS->getNumElements(); ++i)
bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
aggBuffer);
if (const auto *CDS = dyn_cast<ConstantDataSequential>(CPV)) {
for (unsigned I : llvm::seq(CDS->getNumElements()))
bufferLEByte(cast<Constant>(CDS->getElementAsConstant(I)), 0, aggBuffer);
return;
}

if (isa<ConstantStruct>(CPV)) {
if (CPV->getNumOperands()) {
StructType *ST = cast<StructType>(CPV->getType());
for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
if (i == (e - 1))
Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
DL.getTypeAllocSize(ST) -
DL.getStructLayout(ST)->getElementOffset(i);
else
Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
DL.getStructLayout(ST)->getElementOffset(i);
bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
for (unsigned I : llvm::seq(CPV->getNumOperands())) {
int EndOffset = (I + 1 == CPV->getNumOperands())
? DL.getStructLayout(ST)->getElementOffset(0) +
DL.getTypeAllocSize(ST)
: DL.getStructLayout(ST)->getElementOffset(I + 1);
int Bytes = EndOffset - DL.getStructLayout(ST)->getElementOffset(I);
bufferLEByte(cast<Constant>(CPV->getOperand(I)), Bytes, aggBuffer);
}
}
return;
Expand Down
33 changes: 14 additions & 19 deletions llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,27 +111,22 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {

// Copy Num bytes from Ptr.
// if Bytes > Num, zero fill up to Bytes.
unsigned addBytes(unsigned char *Ptr, int Num, int Bytes) {
assert((curpos + Num) <= size);
assert((curpos + Bytes) <= size);
for (int i = 0; i < Num; ++i) {
buffer[curpos] = Ptr[i];
curpos++;
}
for (int i = Num; i < Bytes; ++i) {
buffer[curpos] = 0;
curpos++;
}
return curpos;
void addBytes(const unsigned char *Ptr, unsigned Num, unsigned Bytes) {
for (unsigned I : llvm::seq(Num))
addByte(Ptr[I]);
if (Bytes > Num)
addZeros(Bytes - Num);
}

unsigned addZeros(int Num) {
assert((curpos + Num) <= size);
for (int i = 0; i < Num; ++i) {
buffer[curpos] = 0;
curpos++;
}
return curpos;
void addByte(uint8_t Byte) {
assert(curpos < size);
buffer[curpos] = Byte;
curpos++;
}

void addZeros(unsigned Num) {
for (unsigned _ : llvm::seq(Num))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like we need [[maybe_unused]] here to avoid the warning.

addByte(0);
}

void addSymbol(const Value *GVar, const Value *GVarBeforeStripping) {
Expand Down
28 changes: 10 additions & 18 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
SmallVector<uint64_t, 16> TempOffsets;

// Special case for i128 - decompose to (i64, i64)
if (Ty->isIntegerTy(128)) {
ValueVTs.push_back(EVT(MVT::i64));
ValueVTs.push_back(EVT(MVT::i64));
if (Ty->isIntegerTy(128) || Ty->isFP128Ty()) {
ValueVTs.append({MVT::i64, MVT::i64});

if (Offsets) {
Offsets->push_back(StartingOffset + 0);
Offsets->push_back(StartingOffset + 8);
}
if (Offsets)
Offsets->append({StartingOffset + 0, StartingOffset + 8});

return;
}
Expand Down Expand Up @@ -1165,11 +1162,6 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
}

static bool IsTypePassedAsArray(const Type *Ty) {
return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
Ty->isHalfTy() || Ty->isBFloatTy();
}

std::string NVPTXTargetLowering::getPrototype(
const DataLayout &DL, Type *retTy, const ArgListTy &Args,
const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
Expand All @@ -1186,7 +1178,7 @@ std::string NVPTXTargetLowering::getPrototype(
} else {
O << "(";
if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
!IsTypePassedAsArray(retTy)) {
!shouldPassAsArray(retTy)) {
unsigned size = 0;
if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
size = ITy->getBitWidth();
Expand All @@ -1203,7 +1195,7 @@ std::string NVPTXTargetLowering::getPrototype(
O << ".param .b" << size << " _";
} else if (isa<PointerType>(retTy)) {
O << ".param .b" << PtrVT.getSizeInBits() << " _";
} else if (IsTypePassedAsArray(retTy)) {
} else if (shouldPassAsArray(retTy)) {
O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
<< " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
} else {
Expand All @@ -1224,7 +1216,7 @@ std::string NVPTXTargetLowering::getPrototype(
first = false;

if (!Outs[OIdx].Flags.isByVal()) {
if (IsTypePassedAsArray(Ty)) {
if (shouldPassAsArray(Ty)) {
Align ParamAlign =
getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
O << ".param .align " << ParamAlign.value() << " .b8 ";
Expand Down Expand Up @@ -1529,7 +1521,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);

bool NeedAlign; // Does argument declaration specify alignment?
bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty);
const bool PassAsArray = IsByVal || shouldPassAsArray(Ty);
if (IsVAArg) {
if (ParamCount == FirstVAArg) {
SDValue DeclareParamOps[] = {
Expand Down Expand Up @@ -1718,7 +1710,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// .param .align N .b8 retval0[<size-in-bytes>], or
// .param .b<size-in-bits> retval0
unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
if (!IsTypePassedAsArray(RetTy)) {
if (!shouldPassAsArray(RetTy)) {
resultsz = promoteScalarArgumentSize(resultsz);
SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
Expand Down Expand Up @@ -3344,7 +3336,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(

if (theArgs[i]->use_empty()) {
// argument is dead
if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
if (shouldPassAsArray(Ty) && !Ty->isVectorTy()) {
SmallVector<EVT, 16> vtparts;

ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
Expand Down
4 changes: 0 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,4 @@ bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM) {
!isKernelFunction(*F);
}

bool Isv2x16VT(EVT VT) {
return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
}

} // namespace llvm
9 changes: 8 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,14 @@ inline unsigned promoteScalarArgumentSize(unsigned size) {

bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);

bool Isv2x16VT(EVT VT);
inline bool Isv2x16VT(EVT VT) {
return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
}

inline bool shouldPassAsArray(Type *Ty) {
return Ty->isAggregateType() || Ty->isVectorTy() ||
Ty->getScalarSizeInBits() == 128 || Ty->isHalfTy() || Ty->isBFloatTy();
}

namespace NVPTX {
inline std::string getValidPTXIdentifier(StringRef Name) {
Expand Down
56 changes: 56 additions & 0 deletions llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mcpu=sm_20 | FileCheck %s
; RUN: %if ptxas %{ llc < %s-mcpu=sm_20 | %ptxas-verify %}

target triple = "nvptx64-unknown-cuda"

define fp128 @identity(fp128 %x) {
; CHECK-LABEL: identity(
; CHECK: {
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.v2.u64 {%rd1, %rd2}, [identity_param_0];
; CHECK-NEXT: st.param.v2.b64 [func_retval0], {%rd1, %rd2};
; CHECK-NEXT: ret;
ret fp128 %x
}

define void @load_store(ptr %in, ptr %out) {
; CHECK-LABEL: load_store(
; CHECK: {
; CHECK-NEXT: .reg .b64 %rd<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u64 %rd1, [load_store_param_0];
; CHECK-NEXT: ld.u64 %rd2, [%rd1+8];
; CHECK-NEXT: ld.u64 %rd3, [%rd1];
; CHECK-NEXT: ld.param.u64 %rd4, [load_store_param_1];
; CHECK-NEXT: st.u64 [%rd4], %rd3;
; CHECK-NEXT: st.u64 [%rd4+8], %rd2;
; CHECK-NEXT: ret;
%val = load fp128, ptr %in
store fp128 %val, ptr %out
ret void
}

define void @call(fp128 %x) {
; CHECK-LABEL: call(
; CHECK: {
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.v2.u64 {%rd1, %rd2}, [call_param_0];
; CHECK-NEXT: { // callseq 0, 0
; CHECK-NEXT: .param .align 16 .b8 param0[16];
; CHECK-NEXT: st.param.v2.b64 [param0], {%rd1, %rd2};
; CHECK-NEXT: call.uni
; CHECK-NEXT: call,
; CHECK-NEXT: (
; CHECK-NEXT: param0
; CHECK-NEXT: );
; CHECK-NEXT: } // callseq 0
; CHECK-NEXT: ret;
call void @call(fp128 %x)
ret void
}
5 changes: 4 additions & 1 deletion llvm/test/CodeGen/NVPTX/global-variable-big.ll
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"

; Check that we can handle global variables of large integer type.
; Check that we can handle global variables of large integer and fp128 type.

; (lsb) 0x0102'0304'0506...0F10 (msb)
@gv = addrspace(1) externally_initialized global i128 21345817372864405881847059188222722561, align 16
; CHECK: .visible .global .align 16 .b8 gv[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

@gv_fp128 = addrspace(1) externally_initialized global fp128 0xL0807060504030201100F0E0D0C0B0A09, align 16
; CHECK: .visible .global .align 16 .b8 gv_fp128[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

; Make sure that we do not overflow on large number of elements.
; CHECK: .visible .global .align 1 .b8 large_data[4831838208];
@large_data = global [4831838208 x i8] zeroinitializer
Loading