Skip to content

Commit ef47462

Browse files
authored
[SPIRV] Start adding support for int128 (#170798)
LLVM has pretty thorough support for `int128`, and it has started seeing some use. Even thouth we already have support for the `SPV_ALTERA_arbitrary_precision_integers` extension, the BE was oddly capping integer width to 64-bits. This patch adds partial support for lowering 128-bit integers to `OpTypeInt 128`. Some work remains to be done around legalisation support and validating constant uses (e.g. cases that get lowered to `OpSpecConstantOp`).
1 parent 4f79552 commit ef47462

File tree

10 files changed

+181
-13
lines changed

10 files changed

+181
-13
lines changed

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
5050
unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16;
5151
const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
5252

53+
if (MI->getOpcode() == SPIRV::OpConstantI && NumVarOps > 2) {
54+
// SPV_ALTERA_arbitrary_precision_integers allows for integer widths greater
55+
// than 64, which will be encoded via multiple operands.
56+
for (unsigned I = StartIndex; I != MI->getNumOperands(); ++I)
57+
O << ' ' << MI->getOperand(I).getImm();
58+
return;
59+
}
60+
5361
assert((NumVarOps == 1 || NumVarOps == 2) &&
5462
"Unsupported number of bits for literal variable");
5563

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,14 @@
1313

1414
#include "SPIRVCommandLine.h"
1515
#include "MCTargetDesc/SPIRVBaseInfo.h"
16+
#include "llvm/ADT/STLExtras.h"
1617
#include "llvm/TargetParser/Triple.h"
17-
#include <algorithm>
18+
19+
#include <functional>
1820
#include <map>
21+
#include <string>
22+
#include <utility>
23+
#include <vector>
1924

2025
#define DEBUG_TYPE "spirv-commandline"
2126

@@ -176,7 +181,17 @@ bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
176181
std::set<SPIRV::Extension::Extension> &Vals) {
177182
SmallVector<StringRef, 10> Tokens;
178183
ArgValue.split(Tokens, ",", -1, false);
179-
std::sort(Tokens.begin(), Tokens.end());
184+
llvm::sort(Tokens, [](auto &&LHS, auto &&RHS) {
185+
// We want to ensure that we handle "all" first, to ensure that any
186+
// subsequent disablement actually behaves as expected i.e. given
187+
// --spv-ext=all,-foo, we first enable all and then disable foo; this should
188+
// be revisited and simplified.
189+
if (LHS == "all")
190+
return true;
191+
if (RHS == "all")
192+
return false;
193+
return !(RHS < LHS);
194+
});
180195

181196
std::set<SPIRV::Extension::Extension> EnabledExtensions;
182197

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -151,22 +151,22 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
151151
}
152152

153153
unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
154-
if (Width > 64)
155-
report_fatal_error("Unsupported integer width!");
156154
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
157155
if (ST.canUseExtension(
158156
SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers) ||
159-
ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4))
157+
(Width == 4 && ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)))
160158
return Width;
161159
if (Width <= 8)
162-
Width = 8;
160+
return 8;
163161
else if (Width <= 16)
164-
Width = 16;
162+
return 16;
165163
else if (Width <= 32)
166-
Width = 32;
167-
else
168-
Width = 64;
169-
return Width;
164+
return 32;
165+
else if (Width <= 64)
166+
return 64;
167+
else if (Width <= 128)
168+
return 128;
169+
reportFatalUsageError("Unsupported Integer width!");
170170
}
171171

172172
SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
@@ -413,7 +413,7 @@ Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
413413
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
414414
.addDef(Res)
415415
.addUse(getSPIRVTypeID(SpvType));
416-
addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB);
416+
addNumImm(CI->getValue(), MIB);
417417
} else {
418418
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
419419
.addDef(Res)

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
4848
const LLT s16 = LLT::scalar(16);
4949
const LLT s32 = LLT::scalar(32);
5050
const LLT s64 = LLT::scalar(64);
51+
const LLT s128 = LLT::scalar(128);
5152

5253
const LLT v16s64 = LLT::fixed_vector(16, 64);
5354
const LLT v16s32 = LLT::fixed_vector(16, 32);
@@ -307,7 +308,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
307308
typeInSet(1, allPtrsScalarsAndVectors)));
308309

309310
getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE})
310-
.legalFor({s1})
311+
.legalFor({s1, s128})
311312
.legalFor(allFloatAndIntScalarsAndPtrs)
312313
.legalFor(allowedVectorTypes)
313314
.moreElementsToNextPow2(0)

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,6 +1436,21 @@ void addInstrRequirements(const MachineInstr &MI,
14361436
Reqs.addCapability(SPIRV::Capability::Int16);
14371437
else if (BitWidth == 8)
14381438
Reqs.addCapability(SPIRV::Capability::Int8);
1439+
else if (BitWidth == 4 &&
1440+
ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
1441+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_int4);
1442+
Reqs.addCapability(SPIRV::Capability::Int4TypeINTEL);
1443+
} else if (BitWidth != 32) {
1444+
if (!ST.canUseExtension(
1445+
SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers))
1446+
reportFatalUsageError(
1447+
"OpTypeInt type with a width other than 8, 16, 32 or 64 bits "
1448+
"requires the following SPIR-V extension: "
1449+
"SPV_ALTERA_arbitrary_precision_integers");
1450+
Reqs.addExtension(
1451+
SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers);
1452+
Reqs.addCapability(SPIRV::Capability::ArbitraryPrecisionIntegersALTERA);
1453+
}
14391454
break;
14401455
}
14411456
case SPIRV::OpDot: {

llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,13 @@ void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
171171
// Asm Printer needs this info to print 64-bit operands correctly
172172
MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH64);
173173
return;
174+
} else if (Bitwidth <= 128) {
175+
uint32_t LowBits = Imm.getRawData()[0] & 0xffffffff;
176+
uint32_t MidBits0 = (Imm.getRawData()[0] >> 32) & 0xffffffff;
177+
uint32_t MidBits1 = Imm.getRawData()[1] & 0xffffffff;
178+
uint32_t HighBits = (Imm.getRawData()[1] >> 32) & 0xffffffff;
179+
MIB.addImm(LowBits).addImm(MidBits0).addImm(MidBits1).addImm(HighBits);
180+
return;
174181
}
175182
report_fatal_error("Unsupported constant bitwidth");
176183
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
3+
; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_ALTERA_arbitrary_precision_integers %s -o - | FileCheck %s
4+
; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_ALTERA_arbitrary_precision_integers %s -o - -filetype=obj | spirv-val %}
5+
6+
; CHECK-ERROR: LLVM ERROR: OpTypeInt type with a width other than 8, 16, 32 or 64 bits requires the following SPIR-V extension: SPV_ALTERA_arbitrary_precision_integers
7+
8+
; CHECK: OpCapability ArbitraryPrecisionIntegersALTERA
9+
; CHECK: OpExtension "SPV_ALTERA_arbitrary_precision_integers"
10+
; CHECK: OpName %[[#TestAdd:]] "test_add"
11+
; CHECK: OpName %[[#TestSub:]] "test_sub"
12+
; CHECK: %[[#Int128Ty:]] = OpTypeInt 128 0
13+
; CHECK: %[[#Const64Int128:]] = OpConstant %[[#Int128Ty]] 64 0 0 0
14+
15+
; CHECK: %[[#TestAdd]] = OpFunction
16+
define spir_func void @test_add(i64 %AL, i64 %AH, i64 %BL, i64 %BH, ptr %RL, ptr %RH) {
17+
entry:
18+
; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
19+
; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
20+
; CHECK: {{.*}} = OpShiftLeftLogical %[[#Int128Ty]] {{%[0-9]+}} %[[#Const64Int128]]
21+
; CHECK: {{.*}} = OpBitwiseOr %[[#Int128Ty]]
22+
; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
23+
; CHECK: {{.*}} = OpIAdd %[[#Int128Ty]]
24+
%tmp1 = zext i64 %AL to i128
25+
%tmp23 = zext i64 %AH to i128
26+
%tmp4 = shl i128 %tmp23, 64
27+
%tmp5 = or i128 %tmp4, %tmp1
28+
%tmp67 = zext i64 %BL to i128
29+
%tmp89 = zext i64 %BH to i128
30+
%tmp11 = shl i128 %tmp89, 64
31+
%tmp12 = or i128 %tmp11, %tmp67
32+
%tmp15 = add i128 %tmp12, %tmp5
33+
%tmp1617 = trunc i128 %tmp15 to i64
34+
store i64 %tmp1617, ptr %RL
35+
%tmp21 = lshr i128 %tmp15, 64
36+
%tmp2122 = trunc i128 %tmp21 to i64
37+
store i64 %tmp2122, ptr %RH
38+
ret void
39+
; CHECK: OpFunctionEnd
40+
}
41+
42+
; CHECK: %[[#TestSub]] = OpFunction
43+
define spir_func void @test_sub(i64 %AL, i64 %AH, i64 %BL, i64 %BH, ptr %RL, ptr %RH) {
44+
entry:
45+
; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
46+
; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
47+
; CHECK: {{.*}} = OpShiftLeftLogical %[[#Int128Ty]] {{%[0-9]+}} %[[#Const64Int128]]
48+
; CHECK: {{.*}} = OpBitwiseOr %[[#Int128Ty]]
49+
; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
50+
; CHECK: {{.*}} = OpISub %[[#Int128Ty]]
51+
%tmp1 = zext i64 %AL to i128
52+
%tmp23 = zext i64 %AH to i128
53+
%tmp4 = shl i128 %tmp23, 64
54+
%tmp5 = or i128 %tmp4, %tmp1
55+
%tmp67 = zext i64 %BL to i128
56+
%tmp89 = zext i64 %BH to i128
57+
%tmp11 = shl i128 %tmp89, 64
58+
%tmp12 = or i128 %tmp11, %tmp67
59+
%tmp15 = sub i128 %tmp5, %tmp12
60+
%tmp1617 = trunc i128 %tmp15 to i64
61+
store i64 %tmp1617, ptr %RL
62+
%tmp21 = lshr i128 %tmp15, 64
63+
%tmp2122 = trunc i128 %tmp21 to i64
64+
store i64 %tmp2122, ptr %RH
65+
ret void
66+
; CHECK: OpFunctionEnd
67+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
3+
; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_ALTERA_arbitrary_precision_integers %s -o - | FileCheck %s
4+
; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_ALTERA_arbitrary_precision_integers %s -o - -filetype=obj | spirv-val %}
5+
6+
; CHECK-ERROR: LLVM ERROR: OpTypeInt type with a width other than 8, 16, 32 or 64 bits requires the following SPIR-V extension: SPV_ALTERA_arbitrary_precision_integers
7+
8+
; CHECK: OpCapability ArbitraryPrecisionIntegersALTERA
9+
; CHECK: OpExtension "SPV_ALTERA_arbitrary_precision_integers"
10+
; CHECK: OpName %[[#Foo:]] "foo"
11+
; CHECK: %[[#Int128Ty:]] = OpTypeInt 128 0
12+
13+
; CHECK: %[[#Foo]] = OpFunction
14+
define i64 @foo(i64 %x, i64 %y, i32 %amt) {
15+
; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
16+
; CHECK: {{.*}} = OpSConvert %[[#Int128Ty]]
17+
; CHECK: {{.*}} = OpBitwiseOr %[[#Int128Ty]]
18+
; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
19+
; CHECK: {{.*}} = OpShiftRightLogical %[[#Int128Ty]]
20+
%tmp0 = zext i64 %x to i128
21+
%tmp1 = sext i64 %y to i128
22+
%tmp2 = or i128 %tmp0, %tmp1
23+
%tmp7 = zext i32 13 to i128
24+
%tmp3 = lshr i128 %tmp2, %tmp7
25+
%tmp4 = trunc i128 %tmp3 to i64
26+
ret i64 %tmp4
27+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
3+
; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_ALTERA_arbitrary_precision_integers %s -o - | FileCheck %s
4+
; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_ALTERA_arbitrary_precision_integers %s -o - -filetype=obj | spirv-val %}
5+
6+
; CHECK-ERROR: LLVM ERROR: OpTypeInt type with a width other than 8, 16, 32 or 64 bits requires the following SPIR-V extension: SPV_ALTERA_arbitrary_precision_integers
7+
8+
; CHECK: OpCapability ArbitraryPrecisionIntegersALTERA
9+
; CHECK: OpExtension "SPV_ALTERA_arbitrary_precision_integers"
10+
; CHECK: OpName %[[#Test:]] "test"
11+
; CHECK: OpName %[[#Exit:]] "exit"
12+
; CHECK: %[[#Int128Ty:]] = OpTypeInt 128 0
13+
; CHECK: %[[#UndefInt128:]] = OpUndef %[[#Int128Ty]]
14+
15+
; CHECK: %[[#Test]] = OpFunction
16+
define void @test() {
17+
entry:
18+
; CHECK: OpSwitch %[[#UndefInt128]] %[[#Exit]] 0 0 3 0 %[[#Exit]] 0 0 5 0 %[[#Exit]] 0 0 4 0 %[[#Exit]] 0 0 8 0 %[[#Exit]]
19+
switch i128 poison, label %exit [
20+
i128 55340232221128654848, label %exit
21+
i128 92233720368547758080, label %exit
22+
i128 73786976294838206464, label %exit
23+
i128 147573952589676412928, label %exit
24+
]
25+
exit:
26+
unreachable
27+
}

llvm/test/CodeGen/SPIRV/extensions/enable-all-extensions-but-one.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=all,-SPV_ALTERA_arbitrary_precision_integers %s -o - | FileCheck %s
2+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=-SPV_ALTERA_arbitrary_precision_integers,all %s -o - | FileCheck %s
23
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=KHR %s -o - | FileCheck %s
34
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=khr %s -o - | FileCheck %s
45

0 commit comments

Comments
 (0)