Skip to content

Commit b44bb84

Browse files
committed
Add ComponentTypeInterpretation for joint matrix type
It specifies how to interpret 'Component Type' when components of a joint matrix are storages for values of different types, for example float for TF32, unsigned short for bfloat16. Spec update: intel/llvm#8175 Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
1 parent 35f8b4d commit b44bb84

File tree

5 files changed

+77
-39
lines changed

5 files changed

+77
-39
lines changed

lib/SPIRV/SPIRVReader.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,28 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
445445
(unsigned)S};
446446
if (auto *Use = MT->getUse())
447447
Params.push_back(static_cast<SPIRVConstant *>(Use)->getZExtIntValue());
448+
std::string ComponentTypeName;
449+
auto *CTI = MT->getCTI();
450+
if (!CTI) {
451+
ComponentTypeName = transTypeToOCLTypeName(MT->getCompType());
452+
} else {
453+
switch (static_cast<SPIRVConstant *>(CTI)->getZExtIntValue()) {
454+
case internal::InternalJointMatrixCTI::TF32:
455+
ComponentTypeName = "tf32";
456+
break;
457+
case internal::InternalJointMatrixCTI::Bfloat16:
458+
ComponentTypeName = "bfloat16";
459+
break;
460+
case internal::InternalJointMatrixCTI::PackedInt2:
461+
case internal::InternalJointMatrixCTI::PackedInt4:
462+
// Do nothing just now
463+
break;
464+
default:
465+
llvm_unreachable("Unexpected joint matrix component type");
466+
}
467+
}
448468
return mapType(T, getSPIRVType(internal::OpTypeJointMatrixINTEL,
449-
transTypeToOCLTypeName(MT->getCompType()),
450-
Params, !UseTPT));
469+
ComponentTypeName, Params, !UseTPT));
451470
}
452471
case OpTypeForwardPointer: {
453472
SPIRVTypeForwardPointer *FP =

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,20 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
626626
// simply not true.
627627
SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
628628
SmallVector<std::string, 8> Postfixes) {
629+
auto ParseInteger = [this](StringRef Postfix) -> ConstantInt * {
630+
unsigned long long N = 0;
631+
if (consumeUnsignedInteger(Postfix, 10, N)) {
632+
BM->getErrorLog().checkError(
633+
false, SPIRVEC_InvalidLlvmModule,
634+
"TypeJointMatrixINTEL expects integer parameters");
635+
return 0;
636+
}
637+
return getUInt32(M, N);
638+
};
639+
std::vector<SPIRVValue *> Args;
640+
for (size_t I = 1; I != Postfixes.size(); ++I)
641+
Args.emplace_back(transConstant(ParseInteger(Postfixes[I])));
642+
629643
Type *ElemTy = nullptr;
630644
StringRef Ty{Postfixes[0]};
631645
auto NumBits = llvm::StringSwitch<unsigned>(Ty)
@@ -634,32 +648,27 @@ SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
634648
.Case("int", 32)
635649
.Case("long", 64)
636650
.Default(0);
637-
if (NumBits)
651+
if (NumBits) {
638652
ElemTy = IntegerType::get(M->getContext(), NumBits);
639-
else if (Ty == "half")
653+
} else if (Ty == "half") {
640654
ElemTy = Type::getHalfTy(M->getContext());
641-
else if (Ty == "float")
655+
} else if (Ty == "float") {
642656
ElemTy = Type::getFloatTy(M->getContext());
643-
else if (Ty == "double")
657+
} else if (Ty == "double") {
644658
ElemTy = Type::getDoubleTy(M->getContext());
645-
else if (Ty == "bfloat16")
659+
} else if (Ty == "bfloat16") {
646660
ElemTy = Type::getInt16Ty(M->getContext());
647-
else
661+
auto *CTI = transConstant(getUInt32(M, static_cast<uint64_t>(
662+
internal::InternalJointMatrixCTI::Bfloat16)));
663+
Args.push_back(CTI);
664+
} else if (Ty == "tf32") {
665+
ElemTy = Type::getFloatTy(M->getContext());
666+
auto *CTI = transConstant(getUInt32(M, static_cast<uint64_t>(
667+
internal::InternalJointMatrixCTI::TF32)));
668+
Args.push_back(CTI);
669+
} else {
648670
llvm_unreachable("Unexpected type for matrix!");
649-
650-
auto ParseInteger = [this](StringRef Postfix) -> ConstantInt * {
651-
unsigned long long N = 0;
652-
if (consumeUnsignedInteger(Postfix, 10, N)) {
653-
BM->getErrorLog().checkError(
654-
false, SPIRVEC_InvalidLlvmModule,
655-
"TypeJointMatrixINTEL expects integer parameters");
656-
return 0;
657-
}
658-
return getUInt32(M, N);
659-
};
660-
std::vector<SPIRVValue *> Args;
661-
for (size_t I = 1; I != Postfixes.size(); ++I)
662-
Args.emplace_back(transConstant(ParseInteger(Postfixes[I])));
671+
}
663672
return BM->addJointMatrixINTELType(transType(ElemTy), Args);
664673
}
665674

lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,7 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
10881088
SPIRVValue *getLayout() const { return Args[2]; }
10891089
SPIRVValue *getScope() const { return Args[3]; }
10901090
SPIRVValue *getUse() const { return Args.size() > 4 ? Args[4] : nullptr; }
1091+
SPIRVValue *getCTI() const { return Args.size() > 5 ? Args[5] : nullptr; }
10911092
};
10921093

10931094
} // namespace SPIRV

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,13 @@ enum InternalJointMatrixLayout {
9494

9595
enum InternalJointMatrixUse { MatrixA = 0, MatrixB = 1, Accumulator = 2 };
9696

97+
enum InternalJointMatrixCTI {
98+
TF32 = 0,
99+
Bfloat16 = 1,
100+
PackedInt2 = 2,
101+
PackedInt4 = 3
102+
};
103+
97104
enum InternalBuiltIn {
98105
IBuiltInSubDeviceIDINTEL = 6135,
99106
IBuiltInGlobalHWThreadIDINTEL = 6136,

0 commit comments

Comments
 (0)