Skip to content

Commit 3a136cb

Browse files
committed
[SYCL][SPIR-V] Drop Unnecessary Matrix Use parameter
Last parameter 'Use' is optional in SPIR-V, but is not optional in DPCPP headers. If it has Unnecessary value - skip it Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
1 parent 9f89247 commit 3a136cb

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

clang/lib/CodeGen/CodeGenTypes.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,14 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
5757
if (auto TemplateDecl = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
5858
ArrayRef<TemplateArgument> TemplateArgs =
5959
TemplateDecl->getTemplateArgs().asArray();
60+
constexpr size_t MaxMatrixParameter = 6;
61+
assert(TemplateArgs.size() <= MaxMatrixParameter &&
62+
"Too many template parameters for JointMatrixINTEL type");
6063
OS << "spirv.JointMatrixINTEL.";
61-
for (auto &TemplateArg : TemplateArgs) {
62-
OS << "_";
63-
if (TemplateArg.getKind() == TemplateArgument::Type) {
64-
llvm::Type *TTy = ConvertType(TemplateArg.getAsType());
64+
for (size_t I = 0; I != TemplateArgs.size(); ++I) {
65+
if (TemplateArgs[I].getKind() == TemplateArgument::Type) {
66+
OS << "_";
67+
llvm::Type *TTy = ConvertType(TemplateArgs[I].getAsType());
6568
if (TTy->isIntegerTy()) {
6669
switch (TTy->getIntegerBitWidth()) {
6770
case 8:
@@ -91,8 +94,14 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
9194
OS << LlvmTyName;
9295
} else
9396
TTy->print(OS, false, true);
94-
} else if (TemplateArg.getKind() == TemplateArgument::Integral)
95-
OS << TemplateArg.getAsIntegral();
97+
} else if (TemplateArgs[I].getKind() == TemplateArgument::Integral) {
98+
const auto IntTemplateParam = TemplateArgs[I].getAsIntegral();
99+
// Last parameter 'Use' is optional in SPIR-V, but is not optional
100+
// in DPCPP headers. If it has Unnecessary value - skip it
101+
constexpr size_t Unnecessary = 3;
102+
if (!(I == MaxMatrixParameter && IntTemplateParam == Unnecessary))
103+
OS << "_" << IntTemplateParam;
104+
}
96105
}
97106
Ty->setName(OS.str());
98107
return;

clang/test/CodeGenSYCL/matrix.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@
55
#include <stdint.h>
66

77
namespace __spv {
8-
template <typename T, size_t R, size_t C, uint32_t U, uint32_t S>
8+
template <typename T, size_t R, size_t C, uint32_t L, uint32_t S, uint32_t U>
99
struct __spirv_JointMatrixINTEL;
1010
}
1111

1212
// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1
13-
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1> *matrix) {}
13+
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1, 3> *matrix) {}
1414

1515
// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0
16-
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0> *matrix) {}
16+
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0, 3> *matrix) {}
1717

1818
// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0
19-
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0> *matrix) {}
19+
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0, 3> *matrix) {}
2020

2121
namespace sycl {
2222
class half {};
@@ -25,10 +25,13 @@ namespace sycl {
2525
typedef sycl::half my_half;
2626

2727
// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0
28-
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0> *matrix) {}
28+
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0, 3> *matrix) {}
2929

3030
// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0
31-
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0> *matrix) {}
31+
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0, 3> *matrix) {}
3232

3333
// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0
34-
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {}
34+
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0, 3> *matrix) {}
35+
36+
// CHECK: @_Z2f7{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0_1
37+
void f7(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0, 1> *matrix) {}

0 commit comments

Comments
 (0)