@@ -155,6 +155,14 @@ class SPIRVTypeInt : public SPIRVType {
155155 SPIRVCapVec getRequiredCapability () const override {
156156 SPIRVCapVec CV;
157157 switch (BitWidth) {
158+ case 4 : {
159+ if (Module->isAllowedToUseExtension (ExtensionID::SPV_INTEL_int4)) {
160+ CV.push_back (CapabilityInt4TypeINTEL);
161+ return CV;
162+ }
163+ CV.push_back (CapabilityArbitraryPrecisionIntegersINTEL);
164+ return CV;
165+ }
158166 case 8 :
159167 CV.push_back (CapabilityInt8);
160168 break ;
@@ -175,6 +183,11 @@ class SPIRVTypeInt : public SPIRVType {
175183 }
176184 std::optional<ExtensionID> getRequiredExtension () const override {
177185 switch (BitWidth) {
186+ case 4 : {
187+ if (Module->isAllowedToUseExtension (ExtensionID::SPV_INTEL_int4))
188+ return ExtensionID::SPV_INTEL_int4;
189+ return ExtensionID::SPV_INTEL_arbitrary_precision_integers;
190+ }
178191 case 8 :
179192 case 16 :
180193 case 32 :
@@ -189,7 +202,9 @@ class SPIRVTypeInt : public SPIRVType {
189202 _SPIRV_DEF_ENCDEC3 (Id, BitWidth, IsSigned)
190203 void validate() const override {
191204 SPIRVEntry::validate ();
192- assert ((BitWidth == 8 || BitWidth == 16 || BitWidth == 32 ||
205+ assert (((BitWidth == 4 &&
206+ Module->isAllowedToUseExtension (ExtensionID::SPV_INTEL_int4)) ||
207+ BitWidth == 8 || BitWidth == 16 || BitWidth == 32 ||
193208 BitWidth == 64 ||
194209 Module->isAllowedToUseExtension (
195210 ExtensionID::SPV_INTEL_arbitrary_precision_integers)) &&
@@ -1219,12 +1234,18 @@ class SPIRVTypeCooperativeMatrixKHR : public SPIRVType {
12191234 SPIRVTypeCooperativeMatrixKHR ();
12201235 _SPIRV_DCL_ENCDEC
12211236 std::optional<ExtensionID> getRequiredExtension () const override {
1237+ SPIRVType *Ty = this ->getCompType ();
1238+ if (Ty->isTypeInt () && static_cast <SPIRVTypeInt *>(Ty)->getBitWidth () == 4 )
1239+ this ->getModule ()->addExtension (ExtensionID::SPV_INTEL_int4);
12221240 return ExtensionID::SPV_KHR_cooperative_matrix;
12231241 }
12241242 SPIRVCapVec getRequiredCapability () const override {
12251243 auto CV = getVec (CapabilityCooperativeMatrixKHR);
12261244 if (CompType->isTypeFloat (16 , FPEncodingBFloat16KHR))
12271245 CV.push_back (CapabilityBFloat16CooperativeMatrixKHR);
1246+ else if (CompType->isTypeInt () &&
1247+ static_cast <SPIRVTypeInt *>(CompType)->getBitWidth () == 4 )
1248+ CV.push_back (CapabilityInt4CooperativeMatrixINTEL);
12281249 return CV;
12291250 }
12301251
0 commit comments