Skip to content

Commit a6fcade

Browse files
authored
Start preparing for TypeJointMatrixINTEL switch (#1935)
The patch adds TypeJointMatrixINTELv2 which maps to new type OpCode 6184. Under new OpCode matrix type no longer has Layout parameter. The patch also moved 'scope' to optional matrix muladd instruction. The changes are done only in the consumer part to prepare the switch and make E2E switch backward compatible by preparing consumers ahead of time. Unfortunately there is no way to add a test foe this unless it's binary test, but it seems to be a bit unsafe to add this, so the patch was tested locally. Spec change: intel/llvm#8175 Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
1 parent f729c49 commit a6fcade

9 files changed

+61
-19
lines changed

lib/SPIRV/OCLUtil.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,7 @@ SPIRAddressSpace getOCLOpaqueTypeAddrSpace(Op OpCode) {
898898
case OpTypeSampler:
899899
return SPIRV_SAMPLER_T_ADDR_SPACE;
900900
case internal::OpTypeJointMatrixINTEL:
901+
case internal::OpTypeJointMatrixINTELv2:
901902
return SPIRAS_Global;
902903
default:
903904
if (isSubgroupAvcINTELTypeOpCode(OpCode))

lib/SPIRV/SPIRVReader.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -439,10 +439,11 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
439439
auto *MT = static_cast<SPIRVTypeJointMatrixINTEL *>(T);
440440
auto R = static_cast<SPIRVConstant *>(MT->getRows())->getZExtIntValue();
441441
auto C = static_cast<SPIRVConstant *>(MT->getColumns())->getZExtIntValue();
442-
auto L = static_cast<SPIRVConstant *>(MT->getLayout())->getZExtIntValue();
443-
auto S = static_cast<SPIRVConstant *>(MT->getScope())->getZExtIntValue();
444-
SmallVector<unsigned, 5> Params = {(unsigned)R, (unsigned)C, (unsigned)L,
445-
(unsigned)S};
442+
std::vector<unsigned> Params = {(unsigned)R, (unsigned)C};
443+
if (auto *Layout = MT->getLayout())
444+
Params.push_back(static_cast<SPIRVConstant *>(Layout)->getZExtIntValue());
445+
Params.push_back(
446+
static_cast<SPIRVConstant *>(MT->getScope())->getZExtIntValue());
446447
if (auto *Use = MT->getUse())
447448
Params.push_back(static_cast<SPIRVConstant *>(Use)->getZExtIntValue());
448449
auto *CTI = MT->getComponentTypeInterpretation();

lib/SPIRV/libSPIRV/SPIRVEntry.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ SPIRVEntry *SPIRVEntry::create(Op OpCode) {
8484
static const OpToFactoryMapTy OpToFactoryMap(std::begin(Table),
8585
std::end(Table));
8686

87+
// TODO: To remove this when we make a switch to new version
88+
if (OpCode == internal::OpTypeJointMatrixINTELv2)
89+
OpCode = internal::OpTypeJointMatrixINTEL;
90+
8791
OpToFactoryMapTy::const_iterator Loc = OpToFactoryMap.find(OpCode);
8892
if (Loc != OpToFactoryMap.end())
8993
return Loc->second();

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,6 +1905,7 @@ class SPIRVCompositeConstruct : public SPIRVInstruction {
19051905
case OpTypeArray:
19061906
case OpTypeStruct:
19071907
case internal::OpTypeJointMatrixINTEL:
1908+
case internal::OpTypeJointMatrixINTELv2:
19081909
break;
19091910
default:
19101911
assert(false && "Invalid type");
@@ -3344,10 +3345,10 @@ class SPIRVJointMatrixINTELInst : public SPIRVJointMatrixINTELInstBase {
33443345
SPIRV##x##INTEL;
33453346
_SPIRV_OP(JointMatrixLoad, true, 6, true)
33463347
_SPIRV_OP(JointMatrixStore, false, 5, true)
3347-
_SPIRV_OP(JointMatrixMad, true, 7)
3348-
_SPIRV_OP(JointMatrixSUMad, true, 7)
3349-
_SPIRV_OP(JointMatrixUSMad, true, 7)
3350-
_SPIRV_OP(JointMatrixUUMad, true, 7)
3348+
_SPIRV_OP(JointMatrixMad, true, 6, true)
3349+
_SPIRV_OP(JointMatrixSUMad, true, 6, true)
3350+
_SPIRV_OP(JointMatrixUSMad, true, 6, true)
3351+
_SPIRV_OP(JointMatrixUUMad, true, 6, true)
33513352
// TODO: move to SPIRVJointMatrixINTELWorkItemInst
33523353
_SPIRV_OP(JointMatrixWorkItemLength, true, 4)
33533354
#undef _SPIRV_OP

lib/SPIRV/libSPIRV/SPIRVOpCode.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,8 @@ inline bool isTypeOpCode(Op OpCode) {
224224
return (OpTypeVoid <= OC && OC <= OpTypePipe) || OC == OpTypePipeStorage ||
225225
isSubgroupAvcINTELTypeOpCode(OpCode) || OC == OpTypeVmeImageINTEL ||
226226
isVCOpCode(OpCode) || OC == internal::OpTypeTokenINTEL ||
227-
OC == internal::OpTypeJointMatrixINTEL;
227+
OC == internal::OpTypeJointMatrixINTEL ||
228+
OC == internal::OpTypeJointMatrixINTELv2;
228229
}
229230

230231
inline bool isSpecConstantOpCode(Op OpCode) {

lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ _SPIRV_OP_INTERNAL(ArithmeticFenceINTEL, internal::OpArithmeticFenceINTEL)
66
_SPIRV_OP_INTERNAL(ConvertFToBF16INTEL, internal::OpConvertFToBF16INTEL)
77
_SPIRV_OP_INTERNAL(ConvertBF16ToFINTEL, internal::OpConvertBF16ToFINTEL)
88
_SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
9+
_SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
910
_SPIRV_OP_INTERNAL(JointMatrixLoadINTEL, internal::OpJointMatrixLoadINTEL)
1011
_SPIRV_OP_INTERNAL(JointMatrixStoreINTEL, internal::OpJointMatrixStoreINTEL)
1112
_SPIRV_OP_INTERNAL(JointMatrixMadINTEL, internal::OpJointMatrixMadINTEL)

lib/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ bool SPIRVType::isTypeStruct() const { return OpCode == OpTypeStruct; }
199199
bool SPIRVType::isTypeVector() const { return OpCode == OpTypeVector; }
200200

201201
bool SPIRVType::isTypeJointMatrixINTEL() const {
202-
return OpCode == internal::OpTypeJointMatrixINTEL;
202+
return OpCode == internal::OpTypeJointMatrixINTEL ||
203+
OpCode == internal::OpTypeJointMatrixINTELv2;
203204
}
204205

205206
bool SPIRVType::isTypeVectorBool() const {
@@ -279,13 +280,20 @@ void SPIRVTypeForwardPointer::decode(std::istream &I) {
279280
}
280281

281282
SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL(
282-
SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
283+
SPIRVModule *M, SPIRVId TheId, Op OC, SPIRVType *CompType,
283284
std::vector<SPIRVValue *> Args)
284285
: SPIRVType(M, FixedWC + Args.size(), OC, TheId), CompType(CompType),
285-
Args(Args) {}
286+
Args(std::move(Args)) {}
287+
288+
SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL(
289+
SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
290+
std::vector<SPIRVValue *> Args)
291+
: SPIRVType(M, FixedWC + Args.size(), internal::OpTypeJointMatrixINTEL,
292+
TheId),
293+
CompType(CompType), Args(std::move(Args)) {}
286294

287295
SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL()
288-
: SPIRVType(OC), CompType(nullptr),
296+
: SPIRVType(internal::OpTypeJointMatrixINTEL), CompType(nullptr),
289297
Args({nullptr, nullptr, nullptr, nullptr}) {}
290298

291299
void SPIRVTypeJointMatrixINTEL::encode(spv_ostream &O) const {

lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,13 +1060,18 @@ class SPIRVTypeTokenINTEL : public SPIRVType {
10601060
};
10611061

10621062
class SPIRVTypeJointMatrixINTEL : public SPIRVType {
1063+
Op OC;
10631064
SPIRVType *CompType;
10641065
std::vector<SPIRVValue *> Args;
10651066

10661067
public:
1067-
const static Op OC = internal::OpTypeJointMatrixINTEL;
10681068
const static SPIRVWord FixedWC = 3;
1069-
// Complete constructor
1069+
// Complete constructor with non-default OC
1070+
SPIRVTypeJointMatrixINTEL(SPIRVModule *M, SPIRVId TheId, Op OC,
1071+
SPIRVType *CompType,
1072+
std::vector<SPIRVValue *> Args);
1073+
1074+
// Incomplete constructor for default OC
10701075
SPIRVTypeJointMatrixINTEL(SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
10711076
std::vector<SPIRVValue *> Args);
10721077
// Incomplete constructor
@@ -1085,11 +1090,29 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
10851090
SPIRVType *getCompType() const { return CompType; }
10861091
SPIRVValue *getRows() const { return Args[0]; }
10871092
SPIRVValue *getColumns() const { return Args[1]; }
1088-
SPIRVValue *getLayout() const { return Args[2]; }
1089-
SPIRVValue *getScope() const { return Args[3]; }
1090-
SPIRVValue *getUse() const { return Args.size() > 4 ? Args[4] : nullptr; }
1093+
1094+
SPIRVValue *getLayout() const {
1095+
if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
1096+
return Args[2];
1097+
return nullptr;
1098+
}
1099+
1100+
SPIRVValue *getScope() const {
1101+
if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
1102+
return Args[3];
1103+
return Args[2];
1104+
}
1105+
1106+
SPIRVValue *getUse() const {
1107+
if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
1108+
return Args.size() > 4 ? Args[4] : nullptr;
1109+
return Args[3];
1110+
}
1111+
10911112
SPIRVValue *getComponentTypeInterpretation() const {
1092-
return Args.size() > 5 ? Args[5] : nullptr;
1113+
if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
1114+
return Args.size() > 5 ? Args[5] : nullptr;
1115+
return Args.size() > 4 ? Args[4] : nullptr;
10931116
}
10941117
};
10951118

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ enum InternalOp {
6969
IOpJointMatrixUSMadINTEL = 6129,
7070
IOpJointMatrixUUMadINTEL = 6130,
7171
IOpArithmeticFenceINTEL = 6145,
72+
IOpTypeJointMatrixINTELv2 = 6184,
7273
IOpJointMatrixWorkItemLengthINTEL = 6410,
7374
IOpComplexFMulINTEL = 6415,
7475
IOpComplexFDivINTEL = 6416,
@@ -147,6 +148,7 @@ _SPIRV_OP(Capability, JointMatrixBF16ComponentTypeINTEL)
147148
_SPIRV_OP(Capability, JointMatrixPackedInt2ComponentTypeINTEL)
148149
_SPIRV_OP(Capability, JointMatrixPackedInt4ComponentTypeINTEL)
149150
_SPIRV_OP(Op, TypeJointMatrixINTEL)
151+
_SPIRV_OP(Op, TypeJointMatrixINTELv2)
150152
_SPIRV_OP(Op, JointMatrixLoadINTEL)
151153
_SPIRV_OP(Op, JointMatrixStoreINTEL)
152154
_SPIRV_OP(Op, JointMatrixMadINTEL)

0 commit comments

Comments
 (0)