diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp index f29fb445c31d..712b39e7a877 100644 --- a/clang/lib/CodeGen/CodeGenTypes.cpp +++ b/clang/lib/CodeGen/CodeGenTypes.cpp @@ -57,11 +57,15 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD, if (auto TemplateDecl = dyn_cast(RD)) { ArrayRef TemplateArgs = TemplateDecl->getTemplateArgs().asArray(); + constexpr size_t NumOfMatrixParameters = 6; + const size_t TemplateArgsSize = TemplateArgs.size(); + assert(TemplateArgsSize == NumOfMatrixParameters && + "Incorrect number of template parameters for JointMatrixINTEL"); OS << "spirv.JointMatrixINTEL."; - for (auto &TemplateArg : TemplateArgs) { - OS << "_"; - if (TemplateArg.getKind() == TemplateArgument::Type) { - llvm::Type *TTy = ConvertType(TemplateArg.getAsType()); + for (size_t I = 0; I != TemplateArgsSize; ++I) { + if (TemplateArgs[I].getKind() == TemplateArgument::Type) { + OS << "_"; + llvm::Type *TTy = ConvertType(TemplateArgs[I].getAsType()); if (TTy->isIntegerTy()) { switch (TTy->getIntegerBitWidth()) { case 8: @@ -91,8 +95,16 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD, OS << LlvmTyName; } else TTy->print(OS, false, true); - } else if (TemplateArg.getKind() == TemplateArgument::Integral) - OS << TemplateArg.getAsIntegral(); + } else if (TemplateArgs[I].getKind() == TemplateArgument::Integral) { + const auto IntTemplateParam = TemplateArgs[I].getAsIntegral(); + // Last template parameter of __spirv_JointMatrixINTEL 'Use' is + // optional in SPIR-V, so If it has 'Unnecessary' value - skip it. + // MatrixUse::Unnecessary defined as '3' in spirv_types.hpp. + constexpr size_t Unnecessary = 3; + if (!(I == NumOfMatrixParameters && + IntTemplateParam == Unnecessary)) + OS << "_" << IntTemplateParam; + } } Ty->setName(OS.str()); return; diff --git a/clang/test/CodeGenSYCL/matrix.cpp b/clang/test/CodeGenSYCL/matrix.cpp index a36151859051..1c630b8007af 100644 --- a/clang/test/CodeGenSYCL/matrix.cpp +++ b/clang/test/CodeGenSYCL/matrix.cpp @@ -5,18 +5,18 @@ #include namespace __spv { - template + template struct __spirv_JointMatrixINTEL; } // CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1 -void f1(__spv::__spirv_JointMatrixINTEL *matrix) {} +void f1(__spv::__spirv_JointMatrixINTEL *matrix) {} // CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0 -void f2(__spv::__spirv_JointMatrixINTEL *matrix) {} +void f2(__spv::__spirv_JointMatrixINTEL *matrix) {} // CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0 -void f3(__spv::__spirv_JointMatrixINTEL *matrix) {} +void f3(__spv::__spirv_JointMatrixINTEL *matrix) {} namespace sycl { class half {}; @@ -25,10 +25,13 @@ namespace sycl { typedef sycl::half my_half; // CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0 -void f4(__spv::__spirv_JointMatrixINTEL *matrix) {} +void f4(__spv::__spirv_JointMatrixINTEL *matrix) {} // CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0 -void f5(__spv::__spirv_JointMatrixINTEL *matrix) {} +void f5(__spv::__spirv_JointMatrixINTEL *matrix) {} // CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0 -void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {} +void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0, 3> *matrix) {} + +// CHECK: @_Z2f7{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0_1 +void f7(__spv::__spirv_JointMatrixINTEL *matrix) {}