Skip to content

Commit 675a6ea

Browse files
MrSidimsagainull
authored andcommitted
Start preparing for TypeJointMatrixINTEL switch (intel#1935)
The patch adds TypeJointMatrixINTELv2 which maps to new type OpCode 6184. Under new OpCode matrix type no longer has Layout parameter. The patch also moved 'scope' to optional matrix muladd instruction. The changes are done only in the consumer part to prepare the switch and make E2E switch backward compatible by preparing consumers ahead of time. Unfortunately there is no way to add a test foe this unless it's binary test, but it seems to be a bit unsafe to add this, so the patch was tested locally. Spec change: intel#8175 Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com> Original commit: KhronosGroup/SPIRV-LLVM-Translator@a6fcade
1 parent 5ffcd16 commit 675a6ea

9 files changed

+61
-19
lines changed

llvm-spirv/lib/SPIRV/OCLUtil.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,7 @@ SPIRAddressSpace getOCLOpaqueTypeAddrSpace(Op OpCode) {
890890
case OpTypeSampler:
891891
return SPIRV_SAMPLER_T_ADDR_SPACE;
892892
case internal::OpTypeJointMatrixINTEL:
893+
case internal::OpTypeJointMatrixINTELv2:
893894
return SPIRAS_Global;
894895
default:
895896
if (isSubgroupAvcINTELTypeOpCode(OpCode))

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -439,10 +439,11 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
439439
auto *MT = static_cast<SPIRVTypeJointMatrixINTEL *>(T);
440440
auto R = static_cast<SPIRVConstant *>(MT->getRows())->getZExtIntValue();
441441
auto C = static_cast<SPIRVConstant *>(MT->getColumns())->getZExtIntValue();
442-
auto L = static_cast<SPIRVConstant *>(MT->getLayout())->getZExtIntValue();
443-
auto S = static_cast<SPIRVConstant *>(MT->getScope())->getZExtIntValue();
444-
SmallVector<unsigned, 5> Params = {(unsigned)R, (unsigned)C, (unsigned)L,
445-
(unsigned)S};
442+
std::vector<unsigned> Params = {(unsigned)R, (unsigned)C};
443+
if (auto *Layout = MT->getLayout())
444+
Params.push_back(static_cast<SPIRVConstant *>(Layout)->getZExtIntValue());
445+
Params.push_back(
446+
static_cast<SPIRVConstant *>(MT->getScope())->getZExtIntValue());
446447
if (auto *Use = MT->getUse())
447448
Params.push_back(static_cast<SPIRVConstant *>(Use)->getZExtIntValue());
448449
auto *CTI = MT->getComponentTypeInterpretation();

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVEntry.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ SPIRVEntry *SPIRVEntry::create(Op OpCode) {
8484
static const OpToFactoryMapTy OpToFactoryMap(std::begin(Table),
8585
std::end(Table));
8686

87+
// TODO: To remove this when we make a switch to new version
88+
if (OpCode == internal::OpTypeJointMatrixINTELv2)
89+
OpCode = internal::OpTypeJointMatrixINTEL;
90+
8791
OpToFactoryMapTy::const_iterator Loc = OpToFactoryMap.find(OpCode);
8892
if (Loc != OpToFactoryMap.end())
8993
return Loc->second();

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,6 +1905,7 @@ class SPIRVCompositeConstruct : public SPIRVInstruction {
19051905
case OpTypeArray:
19061906
case OpTypeStruct:
19071907
case internal::OpTypeJointMatrixINTEL:
1908+
case internal::OpTypeJointMatrixINTELv2:
19081909
break;
19091910
default:
19101911
assert(false && "Invalid type");
@@ -3344,10 +3345,10 @@ class SPIRVJointMatrixINTELInst : public SPIRVJointMatrixINTELInstBase {
33443345
SPIRV##x##INTEL;
33453346
_SPIRV_OP(JointMatrixLoad, true, 6, true)
33463347
_SPIRV_OP(JointMatrixStore, false, 5, true)
3347-
_SPIRV_OP(JointMatrixMad, true, 7)
3348-
_SPIRV_OP(JointMatrixSUMad, true, 7)
3349-
_SPIRV_OP(JointMatrixUSMad, true, 7)
3350-
_SPIRV_OP(JointMatrixUUMad, true, 7)
3348+
_SPIRV_OP(JointMatrixMad, true, 6, true)
3349+
_SPIRV_OP(JointMatrixSUMad, true, 6, true)
3350+
_SPIRV_OP(JointMatrixUSMad, true, 6, true)
3351+
_SPIRV_OP(JointMatrixUUMad, true, 6, true)
33513352
// TODO: move to SPIRVJointMatrixINTELWorkItemInst
33523353
_SPIRV_OP(JointMatrixWorkItemLength, true, 4)
33533354
#undef _SPIRV_OP

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVOpCode.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@ inline bool isTypeOpCode(Op OpCode) {
219219
return (OpTypeVoid <= OC && OC <= OpTypePipe) || OC == OpTypePipeStorage ||
220220
isSubgroupAvcINTELTypeOpCode(OpCode) || OC == OpTypeVmeImageINTEL ||
221221
isVCOpCode(OpCode) || OC == internal::OpTypeTokenINTEL ||
222-
OC == internal::OpTypeJointMatrixINTEL;
222+
OC == internal::OpTypeJointMatrixINTEL ||
223+
OC == internal::OpTypeJointMatrixINTELv2;
223224
}
224225

225226
inline bool isSpecConstantOpCode(Op OpCode) {

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ _SPIRV_OP_INTERNAL(ArithmeticFenceINTEL, internal::OpArithmeticFenceINTEL)
66
_SPIRV_OP_INTERNAL(ConvertFToBF16INTEL, internal::OpConvertFToBF16INTEL)
77
_SPIRV_OP_INTERNAL(ConvertBF16ToFINTEL, internal::OpConvertBF16ToFINTEL)
88
_SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
9+
_SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
910
_SPIRV_OP_INTERNAL(JointMatrixLoadINTEL, internal::OpJointMatrixLoadINTEL)
1011
_SPIRV_OP_INTERNAL(JointMatrixStoreINTEL, internal::OpJointMatrixStoreINTEL)
1112
_SPIRV_OP_INTERNAL(JointMatrixMadINTEL, internal::OpJointMatrixMadINTEL)

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ bool SPIRVType::isTypeStruct() const { return OpCode == OpTypeStruct; }
199199
bool SPIRVType::isTypeVector() const { return OpCode == OpTypeVector; }
200200

201201
bool SPIRVType::isTypeJointMatrixINTEL() const {
202-
return OpCode == internal::OpTypeJointMatrixINTEL;
202+
return OpCode == internal::OpTypeJointMatrixINTEL ||
203+
OpCode == internal::OpTypeJointMatrixINTELv2;
203204
}
204205

205206
bool SPIRVType::isTypeVectorBool() const {
@@ -279,13 +280,20 @@ void SPIRVTypeForwardPointer::decode(std::istream &I) {
279280
}
280281

281282
SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL(
282-
SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
283+
SPIRVModule *M, SPIRVId TheId, Op OC, SPIRVType *CompType,
283284
std::vector<SPIRVValue *> Args)
284285
: SPIRVType(M, FixedWC + Args.size(), OC, TheId), CompType(CompType),
285-
Args(Args) {}
286+
Args(std::move(Args)) {}
287+
288+
SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL(
289+
SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
290+
std::vector<SPIRVValue *> Args)
291+
: SPIRVType(M, FixedWC + Args.size(), internal::OpTypeJointMatrixINTEL,
292+
TheId),
293+
CompType(CompType), Args(std::move(Args)) {}
286294

287295
SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL()
288-
: SPIRVType(OC), CompType(nullptr),
296+
: SPIRVType(internal::OpTypeJointMatrixINTEL), CompType(nullptr),
289297
Args({nullptr, nullptr, nullptr, nullptr}) {}
290298

291299
void SPIRVTypeJointMatrixINTEL::encode(spv_ostream &O) const {

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,13 +1060,18 @@ class SPIRVTypeTokenINTEL : public SPIRVType {
10601060
};
10611061

10621062
class SPIRVTypeJointMatrixINTEL : public SPIRVType {
1063+
Op OC;
10631064
SPIRVType *CompType;
10641065
std::vector<SPIRVValue *> Args;
10651066

10661067
public:
1067-
const static Op OC = internal::OpTypeJointMatrixINTEL;
10681068
const static SPIRVWord FixedWC = 3;
1069-
// Complete constructor
1069+
// Complete constructor with non-default OC
1070+
SPIRVTypeJointMatrixINTEL(SPIRVModule *M, SPIRVId TheId, Op OC,
1071+
SPIRVType *CompType,
1072+
std::vector<SPIRVValue *> Args);
1073+
1074+
// Incomplete constructor for default OC
10701075
SPIRVTypeJointMatrixINTEL(SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
10711076
std::vector<SPIRVValue *> Args);
10721077
// Incomplete constructor
@@ -1085,11 +1090,29 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
10851090
SPIRVType *getCompType() const { return CompType; }
10861091
SPIRVValue *getRows() const { return Args[0]; }
10871092
SPIRVValue *getColumns() const { return Args[1]; }
1088-
SPIRVValue *getLayout() const { return Args[2]; }
1089-
SPIRVValue *getScope() const { return Args[3]; }
1090-
SPIRVValue *getUse() const { return Args.size() > 4 ? Args[4] : nullptr; }
1093+
1094+
SPIRVValue *getLayout() const {
1095+
if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
1096+
return Args[2];
1097+
return nullptr;
1098+
}
1099+
1100+
SPIRVValue *getScope() const {
1101+
if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
1102+
return Args[3];
1103+
return Args[2];
1104+
}
1105+
1106+
SPIRVValue *getUse() const {
1107+
if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
1108+
return Args.size() > 4 ? Args[4] : nullptr;
1109+
return Args[3];
1110+
}
1111+
10911112
SPIRVValue *getComponentTypeInterpretation() const {
1092-
return Args.size() > 5 ? Args[5] : nullptr;
1113+
if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
1114+
return Args.size() > 5 ? Args[5] : nullptr;
1115+
return Args.size() > 4 ? Args[4] : nullptr;
10931116
}
10941117
};
10951118

llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ enum InternalOp {
6969
IOpJointMatrixUSMadINTEL = 6129,
7070
IOpJointMatrixUUMadINTEL = 6130,
7171
IOpArithmeticFenceINTEL = 6145,
72+
IOpTypeJointMatrixINTELv2 = 6184,
7273
IOpJointMatrixWorkItemLengthINTEL = 6410,
7374
IOpComplexFMulINTEL = 6415,
7475
IOpComplexFDivINTEL = 6416,
@@ -149,6 +150,7 @@ _SPIRV_OP(Capability, JointMatrixBF16ComponentTypeINTEL)
149150
_SPIRV_OP(Capability, JointMatrixPackedInt2ComponentTypeINTEL)
150151
_SPIRV_OP(Capability, JointMatrixPackedInt4ComponentTypeINTEL)
151152
_SPIRV_OP(Op, TypeJointMatrixINTEL)
153+
_SPIRV_OP(Op, TypeJointMatrixINTELv2)
152154
_SPIRV_OP(Op, JointMatrixLoadINTEL)
153155
_SPIRV_OP(Op, JointMatrixStoreINTEL)
154156
_SPIRV_OP(Op, JointMatrixMadINTEL)

0 commit comments

Comments
 (0)