Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][Matrix] Use KHR cooperative matrix instructions instead of Intel's #13817

Merged
merged 19 commits into from
Aug 15, 2024
131 changes: 131 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

extern __DPCPP_SYCL_EXTERNAL float __spirv_RoundFToTF32INTEL(float a);

#ifndef __SPIRV_USE_COOPERATIVE_MATRIX
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
Expand Down Expand Up @@ -174,6 +175,136 @@ template <typename Ts, typename T, std::size_t R, std::size_t C,
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
Ts val, size_t i);
#else // __SPIRV_USE_COOPERATIVE_MATRIX
// Declaration of the __spirv_CooperativeMatrixLoadKHR builtin. It lives in the
// #else branch of the __SPIRV_USE_COOPERATIVE_MATRIX guard, i.e. it is only
// visible when that macro is defined.
//
// Template parameters:
//   T   - pointee type of Ptr.
//   Tp  - element type of the returned cooperative matrix.
//   R,C - size parameters of the returned matrix type (presumably rows and
//         columns -- confirm against __spirv_CooperativeMatrixKHR).
//   U   - matrix use of the returned matrix type.
//   L   - only supplies the default value of Layout.
//   S   - scope of the returned matrix type (Subgroup unless overridden).
//
// Arguments: Ptr is the pointer to read through; Layout defaults to L; Stride
// and MemOperand both default to 0.
// NOTE(review): presumably lowered to OpCooperativeMatrixLoadKHR -- confirm
// against the SPIR-V translator.
template <typename T, typename Tp, std::size_t R, std::size_t C,
          __spv::MatrixUse U,
          __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
    __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
    __spirv_CooperativeMatrixLoadKHR(T *Ptr,
                                     __spv::MatrixLayout Layout = L,
                                     std::size_t Stride = 0,
                                     int MemOperand = 0);

// Declaration of the __spirv_CooperativeMatrixStoreKHR builtin (visible only
// when __SPIRV_USE_COOPERATIVE_MATRIX is defined).
//
// T is the pointee type of Ptr; Tp, S, R, C and U make up the type of the
// matrix object being passed in. L only supplies the default value of Layout.
//
// Arguments: Ptr is the destination pointer, Object the cooperative matrix to
// write; Layout defaults to L; Stride and MemOperand default to 0.
// NOTE(review): presumably lowered to OpCooperativeMatrixStoreKHR -- confirm
// against the SPIR-V translator.
template <typename T, typename Tp, std::size_t R, std::size_t C,
          __spv::MatrixUse U,
          __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreKHR(
    T *Ptr,
    __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
    __spv::MatrixLayout Layout = L,
    std::size_t Stride = 0,
    int MemOperand = 0);

// Declaration of the __spirv_CooperativeMatrixLengthKHR builtin (visible only
// when __SPIRV_USE_COOPERATIVE_MATRIX is defined).
//
// Takes a pointer to a cooperative matrix object and returns a size value.
// The layout parameter L is part of the template signature but is not used by
// the function type itself.
// NOTE(review): presumably the number of matrix components owned by the
// calling invocation -- confirm against SPV_KHR_cooperative_matrix.
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
          __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL std::size_t __spirv_CooperativeMatrixLengthKHR(
    __spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> *);

// Declaration of the size_t-based overload of
// __spirv_CooperativeMatrixConstructCheckedINTEL (visible only when
// __SPIRV_USE_COOPERATIVE_MATRIX is defined).
//
// T is the type of the scalar argument; Tp, S, R, C and U form the returned
// cooperative matrix type. L is not used by the function type.
//
// Argument order for this overload: Value, Height, Stride, Width, CoordX,
// CoordY. The file also declares an overload that takes int32_t coordinates in
// a different order, so call sites must be careful to pick the intended one.
// NOTE(review): presumably fills a matrix with Value subject to bounds given
// by Height/Width at (CoordX, CoordY) -- confirm against the extension spec.
template <typename T, typename Tp, std::size_t R, std::size_t C,
          __spv::MatrixUse U,
          __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
    __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
    __spirv_CooperativeMatrixConstructCheckedINTEL(const T Value,
                                                   std::size_t Height,
                                                   std::size_t Stride,
                                                   std::size_t Width,
                                                   std::size_t CoordX,
                                                   std::size_t CoordY);

// Declaration of the size_t-based overload of
// __spirv_CooperativeMatrixLoadCheckedINTEL (visible only when
// __SPIRV_USE_COOPERATIVE_MATRIX is defined).
//
// T is the pointee type of Ptr; Tp, S, R, C and U form the returned
// cooperative matrix type; L only supplies the default value of Layout.
//
// Argument order for this overload: Ptr, Stride, Height, Width, CoordX,
// CoordY, then Layout (defaults to L) and MemOperand (defaults to 0). The file
// also declares an overload taking int32_t coordinates in a different order.
// NOTE(review): presumably a bounds-checked variant of the matrix load --
// confirm against the extension spec.
template <typename T, typename Tp, std::size_t R, std::size_t C,
          __spv::MatrixUse U,
          __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
    __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
    __spirv_CooperativeMatrixLoadCheckedINTEL(T *Ptr,
                                              std::size_t Stride,
                                              std::size_t Height,
                                              std::size_t Width,
                                              std::size_t CoordX,
                                              std::size_t CoordY,
                                              __spv::MatrixLayout Layout = L,
                                              int MemOperand = 0);

// Declaration of the size_t-based overload of
// __spirv_CooperativeMatrixStoreCheckedINTEL (visible only when
// __SPIRV_USE_COOPERATIVE_MATRIX is defined).
//
// T is the pointee type of Ptr; Tp, S, R, C and U form the type of the matrix
// object passed in; L only supplies the default value of Layout.
//
// Argument order for this overload: Ptr, Object, Stride, Height, Width,
// CoordX, CoordY, then Layout (defaults to L) and MemOperand (defaults to 0).
// The file also declares an overload taking int32_t coordinates in a different
// order.
// NOTE(review): presumably a bounds-checked variant of the matrix store --
// confirm against the extension spec.
template <typename T, typename Tp, std::size_t R, std::size_t C,
          __spv::MatrixUse U,
          __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
    T *Ptr, __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
    std::size_t Stride, size_t Height, size_t Width, size_t CoordX,
    size_t CoordY, __spv::MatrixLayout Layout = L, int MemOperand = 0);

// Declaration of the __spirv_CooperativeMatrixMulAddKHR builtin (visible only
// when __SPIRV_USE_COOPERATIVE_MATRIX is defined).
//
// Shapes, as encoded in the parameter types:
//   A      : element type TA, size M x K, use UA
//   B      : element type TB, size K x N, use UB
//   C      : element type TC, size M x N, use UC
//   result : same type as C
// All three operands share the scope S. The layout parameters LA, LB and LC
// are part of the template signature but do not appear in the function type.
//
// Operands defaults to 0.
// NOTE(review): presumably a bitmask built from __spv::MatrixOperands values
// -- confirm against the call sites.
template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
          std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
          __spv::MatrixUse UC,
          __spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
          __spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
          __spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
    __spv::__spirv_CooperativeMatrixKHR<TC, S, M, N, UC> *
    __spirv_CooperativeMatrixMulAddKHR(
        __spv::__spirv_CooperativeMatrixKHR<TA, S, M, K, UA> *A,
        __spv::__spirv_CooperativeMatrixKHR<TB, S, K, N, UB> *B,
        __spv::__spirv_CooperativeMatrixKHR<TC, S, M, N, UC> *C,
        std::size_t Operands = 0);

// Declaration of the __spirv_CompositeConstruct overload that yields a
// cooperative matrix (visible only when __SPIRV_USE_COOPERATIVE_MATRIX is
// defined).
//
// T is the type of the scalar argument v; Tp, S, R, C and U form the returned
// cooperative matrix type. L is not used by the function type.
// NOTE(review): presumably produces a matrix whose components all equal v --
// confirm against the SPIR-V OpCompositeConstruct rules for cooperative
// matrices.
template <typename T, typename Tp, std::size_t R, std::size_t C,
          __spv::MatrixUse U,
          __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
    __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
    __spirv_CompositeConstruct(const T v);

// TODO: replace with __spirv_CooperativeMatrixGetElementCoordINTEL when ready
//
// Declaration of __spirv_JointMatrixGetElementCoordINTEL for cooperative
// matrix objects (visible only when __SPIRV_USE_COOPERATIVE_MATRIX is
// defined). It keeps the JointMatrix name even though it operates on a
// __spirv_CooperativeMatrixKHR pointer; see the TODO above.
//
// Takes a matrix pointer and an index i and returns a two-component vector of
// uint32_t. L is not used by the function type.
// NOTE(review): which component is the row and which the column is not
// visible here -- confirm against the caller.
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
          __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
__spirv_JointMatrixGetElementCoordINTEL(
    __spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> *,
    std::size_t i);

// AccessChain followed by load/store serves to extract/insert an element
// from/to the matrix.
//
// Declaration of the __spirv_AccessChain overload for cooperative matrices
// (visible only when __SPIRV_USE_COOPERATIVE_MATRIX is defined). It takes a
// pointer to a matrix pointer plus an index i, and returns a pointer of type
// Ts*, where Ts may differ from the matrix element type T.
template <typename Ts, typename T, std::size_t R, std::size_t C,
          __spv::MatrixUse U,
          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL Ts *
__spirv_AccessChain(__spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> **,
                    std::size_t i);

// Declaration of the int32_t/uint32_t-based overload of
// __spirv_CooperativeMatrixConstructCheckedINTEL (visible only when
// __SPIRV_USE_COOPERATIVE_MATRIX is defined).
//
// T is the type of the scalar argument; Tp, S, R, C and U form the returned
// cooperative matrix type. L is not used by the function type.
//
// Argument order for this overload: CoordX, CoordY, Height, Width, Value.
// Coordinates are signed 32-bit, extents are unsigned 32-bit. The file also
// declares a size_t-based overload with a different argument order.
template <typename T, typename Tp, std::size_t R, std::size_t C,
          __spv::MatrixUse U,
          __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
    __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
    __spirv_CooperativeMatrixConstructCheckedINTEL(int32_t CoordX,
                                                   int32_t CoordY,
                                                   uint32_t Height,
                                                   uint32_t Width,
                                                   const T Value);

// Declaration of the int32_t/uint32_t-based overload of
// __spirv_CooperativeMatrixLoadCheckedINTEL (visible only when
// __SPIRV_USE_COOPERATIVE_MATRIX is defined).
//
// T is the pointee type of Ptr; Tp, S, R, C and U form the returned
// cooperative matrix type; L only supplies the default value of Layout.
//
// Argument order for this overload: Ptr, CoordX, CoordY, then the defaulted
// Layout (= L), Height (= 0), Width (= 0), Stride (= 0) and MemOperand (= 0).
// The file also declares a size_t-based overload with a different argument
// order.
template <typename T, typename Tp, std::size_t R, std::size_t C,
          __spv::MatrixUse U,
          __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
    __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
    __spirv_CooperativeMatrixLoadCheckedINTEL(T *Ptr,
                                              int32_t CoordX,
                                              int32_t CoordY,
                                              __spv::MatrixLayout Layout = L,
                                              uint32_t Height = 0,
                                              uint32_t Width = 0,
                                              std::size_t Stride = 0,
                                              int MemOperand = 0);

// Declaration of the int32_t/uint32_t-based overload of
// __spirv_CooperativeMatrixStoreCheckedINTEL (visible only when
// __SPIRV_USE_COOPERATIVE_MATRIX is defined).
//
// T is the pointee type of Ptr; Tp, S, R, C and U form the type of the matrix
// object passed in; L only supplies the default value of Layout.
//
// Argument order for this overload: Ptr, CoordX, CoordY, Object, then the
// defaulted Layout (= L), Height (= 0), Width (= 0), Stride (= 0) and
// MemOperand (= 0). The file also declares a size_t-based overload with a
// different argument order.
template <typename T, typename Tp, std::size_t R, std::size_t C,
          __spv::MatrixUse U,
          __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
    T *Ptr,
    int32_t CoordX,
    int32_t CoordY,
    __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
    __spv::MatrixLayout Layout = L,
    uint32_t Height = 0,
    uint32_t Width = 0,
    std::size_t Stride = 0,
    int MemOperand = 0);
#endif // __SPIRV_USE_COOPERATIVE_MATRIX

template <typename T>
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixPrefetchINTEL(
Expand Down
24 changes: 24 additions & 0 deletions sycl/include/CL/__spirv/spirv_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,34 @@ enum class MatrixLayout : uint32_t {

enum class MatrixUse : uint32_t { MatrixA = 0, MatrixB = 1, Accumulator = 2 };

#ifdef __SPIRV_USE_COOPERATIVE_MATRIX
// Operand flags for cooperative matrix instructions. Every non-zero
// enumerator is a distinct single bit, so the values can be represented
// together in one 32-bit mask.
// NOTE(review): presumably OR-ed together and passed as the `Operands`
// argument of __spirv_CooperativeMatrixMulAddKHR -- confirm at the call sites.
enum class MatrixOperands : uint32_t {
  // SPV_KHR_cooperative_matrix operands
  NoneKHR = 0,
  MatrixASignedComponentsKHR = 0x1,
  MatrixBSignedComponentsKHR = 0x2,
  MatrixCSignedComponentsKHR = 0x4,
  MatrixResultSignedComponentsKHR = 0x8,
  SaturatingAccumulationKHR = 0x10,
  // SPV_INTEL_joint_matrix operands
  MatrixAAndBTF32ComponentsINTEL = 0x20,
  MatrixAAndBBFloat16ComponentsINTEL = 0x40,
  MatrixCBFloat16ComponentsINTEL = 0x80,
  MatrixResultBFloat16ComponentsINTEL = 0x100
};
#endif // __SPIRV_USE_COOPERATIVE_MATRIX

#ifndef __SPIRV_USE_COOPERATIVE_MATRIX

// Forward declaration of the joint matrix type used when
// __SPIRV_USE_COOPERATIVE_MATRIX is NOT defined. Only a declaration is given
// here; no definition is visible in this part of the file.
// Parameter order: element type, R, C, layout, scope (Subgroup by default),
// use (MatrixA by default).
template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
          Scope::Flag S = Scope::Flag::Subgroup,
          MatrixUse U = MatrixUse::MatrixA>
struct __spirv_JointMatrixINTEL;
#else
// Forward declaration of the cooperative matrix type used when
// __SPIRV_USE_COOPERATIVE_MATRIX is defined. Only a declaration is given here;
// no definition is visible in this part of the file.
// Parameter order differs from __spirv_JointMatrixINTEL: element type, scope
// (Subgroup by default), R (1 by default), C (1 by default), use (MatrixA by
// default). There is no layout parameter.
template <typename T, Scope::Flag S = Scope::Flag::Subgroup,
          std::size_t R = 1, std::size_t C = 1,
          MatrixUse U = MatrixUse::MatrixA>
struct __spirv_CooperativeMatrixKHR;
#endif // __SPIRV_USE_COOPERATIVE_MATRIX

struct __spirv_TaskSequenceINTEL;

Expand Down
Loading
Loading