Skip to content

Commit

Permalink
[SYCL][Matrix] Use KHR cooperative matrix instructions instead of Int…
Browse files Browse the repository at this point in the history
…el's (intel#13817)

The usage is currently guarded by __SPIRV_USE_COOPERATIVE_MATRIX macro.

It's a split from intel#13316

---------

Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
  • Loading branch information
MrSidims authored Aug 15, 2024
1 parent 8407960 commit 15fbefc
Show file tree
Hide file tree
Showing 85 changed files with 2,107 additions and 2 deletions.
131 changes: 131 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

extern __DPCPP_SYCL_EXTERNAL float __spirv_RoundFToTF32INTEL(float a);

#ifndef __SPIRV_USE_COOPERATIVE_MATRIX
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
Expand Down Expand Up @@ -174,6 +175,136 @@ template <typename Ts, typename T, std::size_t R, std::size_t C,
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
Ts val, size_t i);
#else // __SPIRV_USE_COOPERATIVE_MATRIX
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
__spirv_CooperativeMatrixLoadKHR(T *Ptr, __spv::MatrixLayout Layout = L,
std::size_t Stride = 0,
int MemOperand = 0);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreKHR(
T *Ptr, __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
__spv::MatrixLayout Layout = L, std::size_t Stride = 0, int MemOperand = 0);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL size_t __spirv_CooperativeMatrixLengthKHR(
__spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> *);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
__spirv_CooperativeMatrixConstructCheckedINTEL(const T Value, size_t Height,
size_t Stride, size_t Width,
size_t CoordX,
size_t CoordY);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
__spirv_CooperativeMatrixLoadCheckedINTEL(T *Ptr, std::size_t Stride,
size_t Height, size_t Width,
size_t CoordX, size_t CoordY,
__spv::MatrixLayout Layout = L,
int MemOperand = 0);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
T *Ptr, __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
std::size_t Stride, size_t Height, size_t Width, size_t CoordX,
size_t CoordY, __spv::MatrixLayout Layout = L, int MemOperand = 0);

template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
__spv::MatrixUse UC,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_CooperativeMatrixKHR<TC, S, M, N, UC> *
__spirv_CooperativeMatrixMulAddKHR(
__spv::__spirv_CooperativeMatrixKHR<TA, S, M, K, UA> *A,
__spv::__spirv_CooperativeMatrixKHR<TB, S, K, N, UB> *B,
__spv::__spirv_CooperativeMatrixKHR<TC, S, M, N, UC> *C,
size_t Operands = 0);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
__spirv_CompositeConstruct(const T v);

// TODO: replace with __spirv_CooperativeMatrixGetElementCoordINTEL when ready
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
__spirv_JointMatrixGetElementCoordINTEL(
__spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> *, size_t i);

// AccessChain followed by load/store serves to extract/insert and element
// from/to the matrix
template <typename Ts, typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL Ts *
__spirv_AccessChain(__spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> **,
size_t i);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
__spirv_CooperativeMatrixConstructCheckedINTEL(int32_t CoordX,
int32_t CoordY,
uint32_t Height,
uint32_t Width,
const T Value);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
__spirv_CooperativeMatrixLoadCheckedINTEL(
T *Ptr, int32_t CoordX, int32_t CoordY, __spv::MatrixLayout Layout = L,
uint32_t Height = 0, uint32_t Width = 0, std::size_t Stride = 0,
int MemOperand = 0);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
T *Ptr, int32_t CoordX, int32_t CoordY,
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
__spv::MatrixLayout Layout = L, uint32_t Height = 0, uint32_t Width = 0,
std::size_t Stride = 0, int MemOperand = 0);
#endif // __SPIRV_USE_COOPERATIVE_MATRIX

template <typename T>
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixPrefetchINTEL(
Expand Down
24 changes: 24 additions & 0 deletions sycl/include/CL/__spirv/spirv_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,34 @@ enum class MatrixLayout : uint32_t {

enum class MatrixUse : uint32_t { MatrixA = 0, MatrixB = 1, Accumulator = 2 };

#ifdef __SPIRV_USE_COOPERATIVE_MATRIX
enum class MatrixOperands : uint32_t {
// SPV_KHR_cooperative_matrix operands
NoneKHR = 0,
MatrixASignedComponentsKHR = 0x1,
MatrixBSignedComponentsKHR = 0x2,
MatrixCSignedComponentsKHR = 0x4,
MatrixResultSignedComponentsKHR = 0x8,
SaturatingAccumulationKHR = 0x10,
// SPV_INTEL_joint_matrix operands
MatrixAAndBTF32ComponentsINTEL = 0x20,
MatrixAAndBBFloat16ComponentsINTEL = 0x40,
MatrixCBFloat16ComponentsINTEL = 0x80,
MatrixResultBFloat16ComponentsINTEL = 0x100
};
#endif // __SPIRV_USE_COOPERATIVE_MATRIX

#ifndef __SPIRV_USE_COOPERATIVE_MATRIX

template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
Scope::Flag S = Scope::Flag::Subgroup,
MatrixUse U = MatrixUse::MatrixA>
struct __spirv_JointMatrixINTEL;
#else
template <typename T, Scope::Flag S = Scope::Flag::Subgroup, std::size_t R = 1,
std::size_t C = 1, MatrixUse U = MatrixUse::MatrixA>
struct __spirv_CooperativeMatrixKHR;
#endif // __SPIRV_USE_COOPERATIVE_MATRIX

struct __spirv_TaskSequenceINTEL;

Expand Down
Loading

0 comments on commit 15fbefc

Please sign in to comment.