Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][Matrix] Add generation of spirv.CooperativeMatrixKHR type #13645

Merged
merged 2 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions clang/lib/CodeGen/CodeGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,22 @@ llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy,
"spirv.JointMatrixINTEL", {CompTy}, Params);
}

llvm::Type *
getCooperativeMatrixKHRExtType(llvm::Type *CompTy,
ArrayRef<TemplateArgument> TemplateArgs) {
assert(TemplateArgs.size() == 5 &&
"Wrong CooperativeMatrixKHR template parameters number");
std::vector<unsigned> Params;
for (size_t I = 1; I != TemplateArgs.size(); ++I) {
assert(TemplateArgs[I].getKind() == TemplateArgument::Integral &&
"Wrong CooperativeMatrixKHR template parameter");
Params.push_back(TemplateArgs[I].getAsIntegral().getExtValue());
}

return llvm::TargetExtType::get(
CompTy->getContext(), "spirv.CooperativeMatrixKHR", {CompTy}, Params);
}

/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
Expand Down Expand Up @@ -363,6 +379,39 @@ llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
}

/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.
/// The expected representation is:
/// target("spirv.CooperativeMatrixKHR", %element_type, %scope%, %rows%, %cols%,
/// %use%)
llvm::Type *CodeGenTypes::ConvertSPVCooperativeMatrixType(RecordDecl *RD) {
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
ArrayRef<TemplateArgument> TemplateArgs =
TemplateDecl->getTemplateArgs().asArray();
assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
"1st CooperativeMatrixKHR template parameter must be type");
llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());

if (CompTy->isStructTy()) {
StringRef LlvmTyName = CompTy->getStructName();
// Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
if (LlvmTyName.starts_with("class.sycl::") ||
LlvmTyName.starts_with("class.__sycl_internal::"))
LlvmTyName = LlvmTyName.rsplit("::").second;
if (LlvmTyName == "half") {
CompTy = llvm::Type::getHalfTy(getLLVMContext());
} else if (LlvmTyName == "tf32") {
CompTy = llvm::Type::getFloatTy(getLLVMContext());
} else if (LlvmTyName == "bfloat16") {
CompTy = llvm::Type::getInt16Ty(getLLVMContext());
} else {
llvm_unreachable("Wrong matrix base type!");
}
}
return getCooperativeMatrixKHRExtType(CompTy, TemplateArgs);
}

/// ConvertType - Convert the specified type to its LLVM form.
llvm::Type *CodeGenTypes::ConvertType(QualType T) {
T = Context.getCanonicalType(T);
Expand Down Expand Up @@ -654,6 +703,10 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
"__spv::__spirv_JointMatrixINTEL") {
ResultType = ConvertSYCLJointMatrixINTELType(RD);
break;
} else if (RD && RD->getQualifiedNameAsString() ==
"__spv::__spirv_CooperativeMatrixKHR") {
ResultType = ConvertSPVCooperativeMatrixType(RD);
break;
} else if (RD && RD->getQualifiedNameAsString() ==
"__spv::__spirv_TaskSequenceINTEL") {
ResultType = llvm::TargetExtType::get(getLLVMContext(),
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/CodeGen/CodeGenTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,15 @@ class CodeGenTypes {
/// %use%, (optional) %element_type_interpretation%)
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);

/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.
/// The expected representation is:
/// target("spirv.CooperativeMatrixKHR", %element_type, %scope%, %rows%,
/// %cols%, %use%)
///
llvm::Type *ConvertSPVCooperativeMatrixType(RecordDecl *RD);

/// GetFunctionType - Get the LLVM function type for \arg Info.
llvm::FunctionType *GetFunctionType(const CGFunctionInfo &Info);

Expand Down
41 changes: 41 additions & 0 deletions clang/test/CodeGenSYCL/cooperative_matrix.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: %clang_cc1 -triple spir64-unknown-unknown -disable-llvm-passes -emit-llvm %s -o - | FileCheck %s
// Test that SPIR-V codegen generates the expected LLVM struct name for the
// CooperativeMatrixKHR type.
#include <stddef.h>
#include <stdint.h>

namespace __spv {
template <typename T, uint32_t S, size_t R, size_t C, uint32_t U>
struct __spirv_CooperativeMatrixKHR;
}

// CHECK: @_Z2f1{{.*}}(target("spirv.CooperativeMatrixKHR", float, 3, 5, 10, 0)
void f1(__spv::__spirv_CooperativeMatrixKHR<float, 3, 5, 10, 0> *matrix) {}

// CHECK: @_Z2f2{{.*}}(target("spirv.CooperativeMatrixKHR", i64, 3, 10, 2, 1)
void f2(__spv::__spirv_CooperativeMatrixKHR<uint64_t, 3, 10, 2, 1> *matrix) {}

// CHECK: @_Z2f3{{.*}}(target("spirv.CooperativeMatrixKHR", i8, 3, 10, 2, 2)
void f3(__spv::__spirv_CooperativeMatrixKHR<char, 3, 10, 2, 2> *matrix) {}

namespace sycl {
class half {};
class bfloat16 {};
class tf32 {};
}
typedef sycl::half my_half;

// CHECK: @_Z2f4{{.*}}(target("spirv.CooperativeMatrixKHR", half, 3, 10, 2, 0)
void f4(__spv::__spirv_CooperativeMatrixKHR<my_half, 3, 10, 2, 0> *matrix) {}

// CHECK: @_Z2f5{{.*}}(target("spirv.CooperativeMatrixKHR", i16, 3, 10, 2, 0)
void f5(__spv::__spirv_CooperativeMatrixKHR<sycl::bfloat16, 3, 10, 2, 0> *matrix) {}

// CHECK: @_Z2f6{{.*}}(target("spirv.CooperativeMatrixKHR", i128, 3, 10, 2, 0)
void f6(__spv::__spirv_CooperativeMatrixKHR<_BitInt(128), 3, 10, 2, 0> *matrix) {}

// CHECK: @_Z2f7{{.*}}(target("spirv.CooperativeMatrixKHR", float, 3, 10, 2, 0)
void f7(__spv::__spirv_CooperativeMatrixKHR<sycl::tf32, 3, 10, 2, 0> *matrix) {}

// CHECK: @_Z2f8{{.*}}(target("spirv.CooperativeMatrixKHR", double, 3, 5, 10, 0)
void f8(__spv::__spirv_CooperativeMatrixKHR<double, 3, 5, 10, 0> *matrix) {}
Loading