-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[NVPTX] Basic support for fp128 as a storage type #136006
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NVPTX] Basic support for fp128 as a storage type #136006
Conversation
@llvm/pr-subscribers-backend-nvptx Author: Alex MacLean (AlexMaclean) ChangesWhile fp128 operations are not natively supported in hardware, emulation for them is supported by Fixes: #95471 Full diff: https://github.com/llvm/llvm-project/pull/136006.diff 6 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 65cfeadc21a3b..e0f9a1ada3bc4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -249,11 +249,6 @@ MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
return MCOperand::createExpr(Expr);
}
-static bool ShouldPassAsArray(Type *Ty) {
- return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
- Ty->isHalfTy() || Ty->isBFloatTy();
-}
-
void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
const DataLayout &DL = getDataLayout();
const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
@@ -264,26 +259,21 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
return;
O << " (";
- if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
- !ShouldPassAsArray(Ty)) {
- unsigned size = 0;
- if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
- size = ITy->getBitWidth();
- } else {
- assert(Ty->isFloatingPointTy() && "Floating point type expected here");
- size = Ty->getPrimitiveSizeInBits();
- }
- size = promoteScalarArgumentSize(size);
- O << ".param .b" << size << " func_retval0";
- } else if (isa<PointerType>(Ty)) {
- O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
- << " func_retval0";
- } else if (ShouldPassAsArray(Ty)) {
- unsigned totalsz = DL.getTypeAllocSize(Ty);
- Align RetAlignment = TLI->getFunctionArgumentAlignment(
+ auto PrintScalarParam = [&](unsigned Size) {
+ O << ".param .b" << promoteScalarArgumentSize(Size) << " func_retval0";
+ };
+ if (shouldPassAsArray(Ty)) {
+ const unsigned TotalSize = DL.getTypeAllocSize(Ty);
+ const Align RetAlignment = TLI->getFunctionArgumentAlignment(
F, Ty, AttributeList::ReturnIndex, DL);
O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
- << totalsz << "]";
+ << TotalSize << "]";
+ } else if (Ty->isFloatingPointTy()) {
+ PrintScalarParam(Ty->getPrimitiveSizeInBits());
+ } else if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
+ PrintScalarParam(ITy->getBitWidth());
+ } else if (isa<PointerType>(Ty)) {
+ PrintScalarParam(TLI->getPointerTy(DL).getSizeInBits());
} else
llvm_unreachable("Unknown return type");
O << ") ";
@@ -975,8 +965,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
O << " .align "
<< GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
- if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
- (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
+ if (ETy->isPointerTy() || ((ETy->isIntegerTy() || ETy->isFloatingPointTy()) &&
+ ETy->getScalarSizeInBits() <= 64)) {
O << " .";
// Special case: ABI requires that we use .u8 for predicates
if (ETy->isIntegerTy(1))
@@ -1016,6 +1006,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
// and vectors are lowered into arrays of bytes.
switch (ETy->getTypeID()) {
case Type::IntegerTyID: // Integers larger than 64 bits
+ case Type::FP128TyID:
case Type::StructTyID:
case Type::ArrayTyID:
case Type::FixedVectorTyID: {
@@ -1266,8 +1257,8 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
O << " .align "
<< GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
- // Special case for i128
- if (ETy->isIntegerTy(128)) {
+ // Special case for i128/fp128
+ if (ETy->getScalarSizeInBits() == 128) {
O << " .b8 ";
getSymbol(GVar)->print(O, MAI);
O << "[16]";
@@ -1383,7 +1374,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
continue;
}
- if (ShouldPassAsArray(Ty)) {
+ if (shouldPassAsArray(Ty)) {
// Just print .param .align <a> .b8 .param[size];
// <a> = optimal alignment for the element type; always multiple of
// PAL.getParamAlignment
@@ -1682,29 +1673,37 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
AggBuffer *aggBuffer) {
const DataLayout &DL = getDataLayout();
- int Bytes;
- // Integers of arbitrary width
- if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
- APInt Val = CI->getValue();
+ auto BufferConstant = [&](APInt Val) {
for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
uint8_t Byte = Val.getLoBits(8).getZExtValue();
aggBuffer->addBytes(&Byte, 1, 1);
Val.lshrInPlace(8);
}
+ };
+
+ // Integers of arbitrary width
+ if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
+ BufferConstant(CI->getValue());
return;
}
+ // f128
+ if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
+ if (CFP->getType()->isFP128Ty()) {
+ BufferConstant(CFP->getValueAPF().bitcastToAPInt());
+ return;
+ }
+ }
+
// Old constants
if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
- if (CPV->getNumOperands())
- for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
- bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
+ for (const auto &Op : CPV->operands())
+ bufferLEByte(cast<Constant>(Op), 0, aggBuffer);
return;
}
- if (const ConstantDataSequential *CDS =
- dyn_cast<ConstantDataSequential>(CPV)) {
+ if (const auto *CDS = dyn_cast<ConstantDataSequential>(CPV)) {
if (CDS->getNumElements())
for (unsigned i = 0; i < CDS->getNumElements(); ++i)
bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
@@ -1716,6 +1715,7 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
if (CPV->getNumOperands()) {
StructType *ST = cast<StructType>(CPV->getType());
for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
+ int Bytes;
if (i == (e - 1))
Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
DL.getTypeAllocSize(ST) -
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 9bde2a976e164..8d26785b898f3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -246,14 +246,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
SmallVector<uint64_t, 16> TempOffsets;
// Special case for i128 - decompose to (i64, i64)
- if (Ty->isIntegerTy(128)) {
- ValueVTs.push_back(EVT(MVT::i64));
- ValueVTs.push_back(EVT(MVT::i64));
+ if (Ty->isIntegerTy(128) || Ty->isFP128Ty()) {
+ ValueVTs.append({MVT::i64, MVT::i64});
- if (Offsets) {
- Offsets->push_back(StartingOffset + 0);
- Offsets->push_back(StartingOffset + 8);
- }
+ if (Offsets)
+ Offsets->append({StartingOffset + 0, StartingOffset + 8});
return;
}
@@ -1165,11 +1162,6 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
}
-static bool IsTypePassedAsArray(const Type *Ty) {
- return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
- Ty->isHalfTy() || Ty->isBFloatTy();
-}
-
std::string NVPTXTargetLowering::getPrototype(
const DataLayout &DL, Type *retTy, const ArgListTy &Args,
const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
@@ -1186,7 +1178,7 @@ std::string NVPTXTargetLowering::getPrototype(
} else {
O << "(";
if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
- !IsTypePassedAsArray(retTy)) {
+ !shouldPassAsArray(retTy)) {
unsigned size = 0;
if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
size = ITy->getBitWidth();
@@ -1203,7 +1195,7 @@ std::string NVPTXTargetLowering::getPrototype(
O << ".param .b" << size << " _";
} else if (isa<PointerType>(retTy)) {
O << ".param .b" << PtrVT.getSizeInBits() << " _";
- } else if (IsTypePassedAsArray(retTy)) {
+ } else if (shouldPassAsArray(retTy)) {
O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
<< " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
} else {
@@ -1224,7 +1216,7 @@ std::string NVPTXTargetLowering::getPrototype(
first = false;
if (!Outs[OIdx].Flags.isByVal()) {
- if (IsTypePassedAsArray(Ty)) {
+ if (shouldPassAsArray(Ty)) {
Align ParamAlign =
getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
O << ".param .align " << ParamAlign.value() << " .b8 ";
@@ -1529,7 +1521,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
bool NeedAlign; // Does argument declaration specify alignment?
- bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty);
+ const bool PassAsArray = IsByVal || shouldPassAsArray(Ty);
if (IsVAArg) {
if (ParamCount == FirstVAArg) {
SDValue DeclareParamOps[] = {
@@ -1718,7 +1710,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// .param .align N .b8 retval0[<size-in-bytes>], or
// .param .b<size-in-bits> retval0
unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
- if (!IsTypePassedAsArray(RetTy)) {
+ if (!shouldPassAsArray(RetTy)) {
resultsz = promoteScalarArgumentSize(resultsz);
SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
@@ -3344,7 +3336,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (theArgs[i]->use_empty()) {
// argument is dead
- if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
+ if (shouldPassAsArray(Ty) && !Ty->isVectorTy()) {
SmallVector<EVT, 16> vtparts;
ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 3d9d2ae372080..b800445a3b19c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -372,8 +372,4 @@ bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM) {
!isKernelFunction(*F);
}
-bool Isv2x16VT(EVT VT) {
- return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
-}
-
} // namespace llvm
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index 9283b398a9c14..2288241ec0178 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -84,7 +84,14 @@ inline unsigned promoteScalarArgumentSize(unsigned size) {
bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);
-bool Isv2x16VT(EVT VT);
+inline bool Isv2x16VT(EVT VT) {
+ return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
+}
+
+inline bool shouldPassAsArray(Type *Ty) {
+ return Ty->isAggregateType() || Ty->isVectorTy() ||
+ Ty->getScalarSizeInBits() == 128 || Ty->isHalfTy() || Ty->isBFloatTy();
+}
namespace NVPTX {
inline std::string getValidPTXIdentifier(StringRef Name) {
diff --git a/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll b/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
new file mode 100644
index 0000000000000..5b96f4978a7cb
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mcpu=sm_20 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s-mcpu=sm_20 | %ptxas-verify %}
+
+target triple = "nvptx64-unknown-cuda"
+
+define fp128 @identity(fp128 %x) {
+; CHECK-LABEL: identity(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.v2.u64 {%rd1, %rd2}, [identity_param_0];
+; CHECK-NEXT: st.param.v2.b64 [func_retval0], {%rd1, %rd2};
+; CHECK-NEXT: ret;
+ ret fp128 %x
+}
+
+define void @load_store(ptr %in, ptr %out) {
+; CHECK-LABEL: load_store(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u64 %rd1, [load_store_param_0];
+; CHECK-NEXT: ld.u64 %rd2, [%rd1+8];
+; CHECK-NEXT: ld.u64 %rd3, [%rd1];
+; CHECK-NEXT: ld.param.u64 %rd4, [load_store_param_1];
+; CHECK-NEXT: st.u64 [%rd4], %rd3;
+; CHECK-NEXT: st.u64 [%rd4+8], %rd2;
+; CHECK-NEXT: ret;
+ %val = load fp128, ptr %in
+ store fp128 %val, ptr %out
+ ret void
+}
+
+define void @call(fp128 %x) {
+; CHECK-LABEL: call(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.v2.u64 {%rd1, %rd2}, [call_param_0];
+; CHECK-NEXT: { // callseq 0, 0
+; CHECK-NEXT: .param .align 16 .b8 param0[16];
+; CHECK-NEXT: st.param.v2.b64 [param0], {%rd1, %rd2};
+; CHECK-NEXT: call.uni
+; CHECK-NEXT: call,
+; CHECK-NEXT: (
+; CHECK-NEXT: param0
+; CHECK-NEXT: );
+; CHECK-NEXT: } // callseq 0
+; CHECK-NEXT: ret;
+ call void @call(fp128 %x)
+ ret void
+}
diff --git a/llvm/test/CodeGen/NVPTX/global-variable-big.ll b/llvm/test/CodeGen/NVPTX/global-variable-big.ll
index e8d7fb3815b79..422f721d934e0 100644
--- a/llvm/test/CodeGen/NVPTX/global-variable-big.ll
+++ b/llvm/test/CodeGen/NVPTX/global-variable-big.ll
@@ -4,12 +4,15 @@
target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"
-; Check that we can handle global variables of large integer type.
+; Check that we can handle global variables of large integer and fp128 type.
; (lsb) 0x0102'0304'0506...0F10 (msb)
@gv = addrspace(1) externally_initialized global i128 21345817372864405881847059188222722561, align 16
; CHECK: .visible .global .align 16 .b8 gv[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+@gv_fp128 = addrspace(1) externally_initialized global fp128 0xL33333333333333334004033333333333, align 16
+; CHECK: .visible .global .align 16 .b8 gv_fp128[16] = {51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 3, 4, 64};
+
; Make sure that we do not overflow on large number of elements.
; CHECK: .visible .global .align 1 .b8 large_data[4831838208];
@large_data = global [4831838208 x i8] zeroinitializer
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM in principle with few minor nits.
5851002
to
ab7d81f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with one more suggestion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ship it.
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/160/builds/16504 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/180/builds/16493 Here is the relevant piece of the build log for the reference
|
} | ||
|
||
void addZeros(unsigned Num) { | ||
for (unsigned _ : llvm::seq(Num)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like we need [[maybe_unused]]
here to avoid the warning.
ptxas appears to be complaining about PTX it got from llc. I suspect compilation may have failed somehow, as it claims that ".version" is missing. |
@AlexMaclean @Artem-B I've landed 94aa4bf to fix a warning from this PR. Now that I see the discussion here, |
@kazutakahirata, thanks for the quick fix! I'll use |
Great! Thanks for the follow-up! |
While fp128 operations are not natively supported in hardware, emulation for them is supported by nvcc. This change adds basic support for fp128 as a storage type allowing for lowering of IR containing these types. Fixes: llvm#95471
While fp128 operations are not natively supported in hardware, emulation for them is supported by nvcc. This change adds basic support for fp128 as a storage type allowing for lowering of IR containing these types. Fixes: llvm#95471
While fp128 operations are not natively supported in hardware, emulation for them is supported by
nvcc
. This change adds basic support for fp128 as a storage type allowing for lowering of IR containing these types.Fixes: #95471