From fdc4c42a4908d8b3af6dcb1817389f335f673db0 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Fri, 5 Aug 2022 12:59:52 +0100 Subject: [PATCH 01/10] Allow joint_matrix to be loaded from const. Signed-off-by: JackAKirk --- .../ext/oneapi/matrix/matrix-tensorcore.hpp | 116 ++++++++++-------- 1 file changed, 64 insertions(+), 52 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index cf53bec8f943c..8bbd460873ebc 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -219,9 +219,13 @@ struct joint_matrix_load_impl< S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, multi_ptr src, size_t stride) { if constexpr (std::is_same::value || + std::is_same, uint16_t>::value || std::is_same< - T, sycl::ext::oneapi::experimental::bfloat16>::value) { - auto tileptr = reinterpret_cast(src.get()); + T, sycl::ext::oneapi::experimental::bfloat16>::value || + std::is_same< + std::remove_const_t, + sycl::ext::oneapi::experimental::bfloat16>::value) { + auto tileptr = reinterpret_cast(src.get()); auto destptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == @@ -246,8 +250,9 @@ struct joint_matrix_load_impl< __mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { - auto tileptr = reinterpret_cast(src.get()); + } else if constexpr (std::is_same::value || + std::is_same, uint8_t>::value) { + auto tileptr = reinterpret_cast(src.get()); auto destptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == @@ -272,8 +277,9 @@ struct joint_matrix_load_impl< __imma_m32n8k16_ld_b_u8(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { - auto tileptr = reinterpret_cast(src.get()); + } else if constexpr (std::is_same::value || + std::is_same, int8_t>::value) { + auto tileptr = reinterpret_cast(src.get()); auto destptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == @@ -298,8 +304,9 @@ struct joint_matrix_load_impl< __imma_m32n8k16_ld_b_s8(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { - auto tileptr = reinterpret_cast(src.get()); + } else if constexpr (std::is_same::value || + std::is_same, half>::value) { + auto tileptr = reinterpret_cast(src.get()); auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == @@ -331,7 +338,8 @@ struct joint_matrix_load_impl< get_layout_id()); } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value || + std::is_same, int32_t>::value) { auto destptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { __imma_m16n16k16_ld_c(destptr, src.get(), stride, @@ -343,7 +351,8 @@ struct joint_matrix_load_impl< __imma_m32n8k16_ld_c(destptr, src.get(), stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value || + std::is_same, float>::value) { if constexpr (std::is_same::value) { auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { @@ -359,7 +368,7 @@ struct joint_matrix_load_impl< } else if constexpr (std::is_same::value) { - auto tileptr = reinterpret_cast(src.get()); + auto tileptr = reinterpret_cast(src.get()); auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 8) { __mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride, @@ -369,7 +378,8 @@ struct joint_matrix_load_impl< get_layout_id()); } } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value || + std::is_same, double>::value) { auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { @@ -559,9 +569,9 @@ struct joint_matrix_mad_impl< D; if constexpr (M == 16 && N == 16 && K == 16) { if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); - auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); auto ptrD = reinterpret_cast(&D.wi_marray); if constexpr (std::is_same::value) { __imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC, @@ -571,17 +581,17 @@ struct joint_matrix_mad_impl< get_layout_pair_id(), 0); } } else if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); if constexpr (std::is_same::value) { __hmma_m16n16k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } else if constexpr (std::is_same::value) { __hmma_m16n16k16_mma_f16f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } } else if constexpr (std::is_same::value || @@ -589,16 +599,16 @@ struct joint_matrix_mad_impl< bfloat16>::value) { __mma_bf16_m16n16k16_mma_f32( reinterpret_cast(&D.wi_marray), - reinterpret_cast(&A.wi_marray), - reinterpret_cast(&B.wi_marray), - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } } else if constexpr (M == 8 && N == 32 && K == 16) { if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); - auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); auto ptrD = reinterpret_cast(&D.wi_marray); if constexpr (std::is_same::value) { __imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC, @@ -608,17 +618,17 @@ struct joint_matrix_mad_impl< get_layout_pair_id(), 0); } } else if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); if constexpr (std::is_same::value) { __hmma_m8n32k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } else if constexpr (std::is_same::value) { __hmma_m8n32k16_mma_f16f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } } else if constexpr (std::is_same::value || @@ -626,16 +636,16 @@ struct joint_matrix_mad_impl< bfloat16>::value) { __mma_bf16_m8n32k16_mma_f32( reinterpret_cast(&D.wi_marray), - reinterpret_cast(&A.wi_marray), - reinterpret_cast(&B.wi_marray), - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } } else if constexpr (M == 32 && N == 8 && K == 16) { if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); - auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); auto ptrD = reinterpret_cast(&D.wi_marray); if constexpr (std::is_same::value) { __imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC, @@ -649,22 +659,22 @@ struct joint_matrix_mad_impl< bfloat16>::value) { __mma_bf16_m32n8k16_mma_f32( reinterpret_cast(&D.wi_marray), - reinterpret_cast(&A.wi_marray), - reinterpret_cast(&B.wi_marray), - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } else if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); if constexpr (std::is_same::value) { __hmma_m32n8k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } else if constexpr (std::is_same::value) { __hmma_m32n8k16_mma_f16f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } } @@ -676,9 +686,9 @@ struct joint_matrix_mad_impl< get_layout_pair_id(), 0); } else if constexpr (std::is_same::value) { __dmma_m8n8k4_mma_f64(reinterpret_cast(&D.wi_marray), - reinterpret_cast(&A.wi_marray), - reinterpret_cast(&B.wi_marray), - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } return D; @@ -691,13 +701,15 @@ struct joint_matrix_mad_impl< namespace experimental { namespace matrix { -template ::value || - (std::is_same::value && - std::is_same::value), - bool> = true> +template < + typename Group, typename S, typename T, matrix_use Use, size_t NumRows, + size_t NumCols, matrix_layout Layout, access::address_space Space, + std::enable_if_t::value || + std::is_same>::value || + (std::is_same::value && + (std::is_same::value || + std::is_same, float>::value)), + bool> = true> void joint_matrix_load( Group sg, joint_matrix &res, multi_ptr src, size_t stride) { From 68d3150e708d8506095df24b761be0f24ef25d03 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Fri, 5 Aug 2022 14:17:41 +0100 Subject: [PATCH 02/10] removed duplicates. Signed-off-by: JackAKirk --- .../ext/oneapi/matrix/matrix-tensorcore.hpp | 30 +++++++------------ 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index 8bbd460873ebc..ee2a898f0f681 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -218,10 +218,7 @@ struct joint_matrix_load_impl< void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, multi_ptr src, size_t stride) { - if constexpr (std::is_same::value || - std::is_same, uint16_t>::value || - std::is_same< - T, sycl::ext::oneapi::experimental::bfloat16>::value || + if constexpr (std::is_same, uint16_t>::value || std::is_same< std::remove_const_t, sycl::ext::oneapi::experimental::bfloat16>::value) { @@ -250,8 +247,7 @@ struct joint_matrix_load_impl< __mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (std::is_same::value || - std::is_same, uint8_t>::value) { + } else if constexpr (std::is_same, uint8_t>::value) { auto tileptr = reinterpret_cast(src.get()); auto destptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { @@ -277,8 +273,7 @@ struct joint_matrix_load_impl< __imma_m32n8k16_ld_b_u8(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (std::is_same::value || - std::is_same, int8_t>::value) { + } else if constexpr (std::is_same, int8_t>::value) { auto tileptr = reinterpret_cast(src.get()); auto destptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { @@ -304,8 +299,7 @@ struct joint_matrix_load_impl< __imma_m32n8k16_ld_b_s8(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (std::is_same::value || - std::is_same, half>::value) { + } else if constexpr (std::is_same, half>::value) { auto tileptr = reinterpret_cast(src.get()); auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { @@ -338,8 +332,7 @@ struct joint_matrix_load_impl< get_layout_id()); } - } else if constexpr (std::is_same::value || - std::is_same, int32_t>::value) { + } else if constexpr (std::is_same, int32_t>::value) { auto destptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { __imma_m16n16k16_ld_c(destptr, src.get(), stride, @@ -351,8 +344,7 @@ struct joint_matrix_load_impl< __imma_m32n8k16_ld_c(destptr, src.get(), stride, get_layout_id()); } - } else if constexpr (std::is_same::value || - std::is_same, float>::value) { + } else if constexpr (std::is_same, float>::value) { if constexpr (std::is_same::value) { auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { @@ -378,8 +370,7 @@ struct joint_matrix_load_impl< get_layout_id()); } } - } else if constexpr (std::is_same::value || - std::is_same, double>::value) { + } else if constexpr (std::is_same, double>::value) { auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { @@ -704,11 +695,10 @@ namespace matrix { template < typename Group, typename S, typename T, matrix_use Use, size_t NumRows, size_t NumCols, matrix_layout Layout, access::address_space Space, - std::enable_if_t::value || - std::is_same>::value || + std::enable_if_t>::value || (std::is_same::value && - (std::is_same::value || - std::is_same, float>::value)), + + std::is_same, float>::value), bool> = true> void joint_matrix_load( Group sg, joint_matrix &res, From 494946444bd3cd8d819d33e46865c08ce46e4d5e Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Mon, 29 Aug 2022 16:12:54 +0100 Subject: [PATCH 03/10] Layout accumulator is specified at load/store. This is a move towards the future looking joint_matrix, joint_matrix_load, joint_matrix_store APIs. Signed-off-by: JackAKirk --- .../ext/oneapi/matrix/matrix-tensorcore.hpp | 510 ++++++++++-------- 1 file changed, 293 insertions(+), 217 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index fc240e4fa1818..20ee168db9c52 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -16,25 +16,28 @@ namespace oneapi { namespace experimental { namespace matrix { -enum class matrix_use { a, b, accumulator }; +enum class matrix_use { a, b, accumulator, unnecessary }; -enum class matrix_layout { row_major, col_major, packed_a, packed_b }; +enum class layout { row_major, col_major, packed_a, packed_b, none}; namespace precision { -class tf32 {}; +class tf32 { + tf32() = delete; +}; } // namespace precision -template + matrix_use Use = matrix_use::unnecessary, + layout Layout = layout::none, typename Group = sycl::sub_group, + typename Cond = void> struct joint_matrix; template class wi_data { marray &data; wi_data(marray &wi_data) : data(wi_data){}; - template + template friend struct joint_matrix; public: @@ -58,11 +61,11 @@ template class wi_data { }; #define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(type, use, M, N, size) \ - template \ + template \ struct joint_matrix< \ - type, matrix_use::use, M, N, Layout, sycl::sub_group, \ - typename std::enable_if_t> { \ + type, M, N, matrix_use::use, Layout, sycl::sub_group, \ + typename std::enable_if_t> { \ marray wi_marray; \ inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { \ return wi_data(wi_marray); \ @@ -74,54 +77,67 @@ __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 8, 16, 4) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 32, 16) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 8, 16, 16) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 32, 16) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, accumulator, 8, 32, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(float, accumulator, 8, 32, 8) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 8, 16, 4) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 32, 16) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 8, 16, 4) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 32, 16) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int32_t, accumulator, 8, 32, 8) // m32n8k16 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 32, 16, 16) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 8, 4) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 32, 16, 16) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 8, 16) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, accumulator, 32, 8, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(float, accumulator, 32, 8, 8) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 32, 16, 16) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 8, 4) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 32, 16, 16) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 8, 4) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int32_t, accumulator, 32, 8, 8) // m16n16k16 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 16, 16, 8) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 16, 8) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 16, 16, 16) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 16, 16) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, accumulator, 16, 16, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(float, accumulator, 16, 16, 8) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 16, 16, 8) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 16, 8) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 16, 16, 8) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 16, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int32_t, accumulator, 16, 16, 8) // m8n8k4 double only __SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, a, 8, 4, 1) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, b, 4, 8, 1) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, accumulator, 8, 8, 2) #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR +#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(type, M, N, size) \ + template <> \ + struct joint_matrix { \ + marray wi_marray; \ + inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { \ + return wi_data(wi_marray); \ + }; \ + }; + +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 8, 32, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 8, 32, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 8, 32, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 32, 8, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 32, 8, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 32, 8, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(double, 8, 8, 2) + +#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC + #define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision, use, M, N, type, \ size) \ - template \ + template \ struct joint_matrix< \ - precision, matrix_use::use, M, N, Layout, sycl::sub_group, \ - typename std::enable_if_t> { \ + precision, M, N, matrix_use::use, Layout, sycl::sub_group, \ + typename std::enable_if_t> { \ marray wi_marray; \ inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { \ return wi_data(wi_marray); \ @@ -133,33 +149,11 @@ __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision::tf32, b, 8, 16, float, 4) #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION -#define __SYCL_JOINT_MATRIX_OVERLOAD(type, use, M, N, frag_type, frag_size) \ - template \ - struct joint_matrix< \ - type, matrix_use::use, M, N, Layout, sycl::sub_group, \ - typename std::enable_if_t> { \ - frag_type wi_marray[frag_size]; \ - }; - -// bf16 data format uint16_t implementation is deprecated -// m8n32k16 -__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 8, 16, int32_t, 2) -__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 32, int32_t, 8) -// m32n8k16 -__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 32, 16, int32_t, 8) -__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 8, int32_t, 2) -// m16n16k16 -__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 16, 16, int32_t, 4) -__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 16, int32_t, 4) - -#undef __SYCL_JOINT_MATRIX_OVERLOAD - -template +template inline __SYCL_ALWAYS_INLINE void joint_matrix_fill(Group sg, - joint_matrix &res, + joint_matrix &res, const T2 v) { // We kept the unused "sg" in joint_matrix_fill to match the other DPC++ // functions @@ -177,46 +171,127 @@ joint_matrix_fill(Group sg, namespace detail { -template struct joint_matrix_load_impl { void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< - S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, + S, NumRows, NumCols, Use, Layout, sycl::sub_group> &res, multi_ptr src, size_t stride); }; -template +template constexpr int get_layout_id(); template <> -constexpr int get_layout_id< - sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() { +constexpr int +get_layout_id() { return 0; } template <> -constexpr int get_layout_id< - sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() { +constexpr int +get_layout_id() { return 1; } #if __cplusplus >= 201703L // if constexpr usage -template +struct joint_matrix_load_impl< + S, T, NumRows, NumCols, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::none, Space> { + template + void loadLayoutT( + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T, NumRows, NumCols, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::none, + sycl::sub_group> &res, + multi_ptr src, size_t stride) { + if constexpr (std::is_same::value) { + auto destptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + __imma_m16n16k16_ld_c(destptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __imma_m8n32k16_ld_c(destptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 8) { + __imma_m32n8k16_ld_c(destptr, src.get(), stride, + get_layout_id()); + } + } else if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 8) { + __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride, + get_layout_id()); + } + } + } else if constexpr (std::is_same::value) { + auto tileptr = reinterpret_cast(src.get()); + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 32 && NumCols == 8) { + __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 16) { + __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same::value) { + auto dstptr = + reinterpret_cast(&res.wi_marray); // todo remove line + + __dmma_m8n8k4_ld_c(dstptr, src.get(), stride, get_layout_id()); + } + }; + void + load(sycl::ext::oneapi::experimental::matrix::joint_matrix< + S, NumRows, NumCols, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::none, + sycl::sub_group> &res, + multi_ptr src, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { + switch (LayoutAcc) { + case sycl::ext::oneapi::experimental::matrix::layout::row_major: + loadLayoutT( + res, src, stride); + break; + case sycl::ext::oneapi::experimental::matrix::layout::col_major: + loadLayoutT( + res, src, stride); + break; + default: + assert(false && "Invalid layout specified!"); + } + } +}; + +template struct joint_matrix_load_impl< - S, T, Use, NumRows, NumCols, Layout, Space, - typename std::enable_if_t> { + S, T, NumRows, NumCols, Use, Layout, Space, + typename std::enable_if_t< + Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major || + Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major || + Layout == sycl::ext::oneapi::experimental::matrix::layout::col_major>> { void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< - S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, + S, NumRows, NumCols, Use, Layout, sycl::sub_group> &res, multi_ptr src, size_t stride) { if constexpr (std::is_same::value || std::is_same< @@ -310,10 +385,6 @@ struct joint_matrix_load_impl< matrix_use::b) { __hmma_m16n16k16_ld_b(dstptr, tileptr, stride, get_layout_id()); - } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: - matrix_use::accumulator) { - __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride, - get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 16) { __hmma_m8n32k16_ld_a(dstptr, tileptr, stride, get_layout_id()); @@ -323,51 +394,18 @@ struct joint_matrix_load_impl< __hmma_m32n8k16_ld_a(dstptr, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 8) { __hmma_m32n8k16_ld_b(dstptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 8) { - __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride, - get_layout_id()); - } else if constexpr (NumRows == 8 && NumCols == 32) { - __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride, - get_layout_id()); } - } else if constexpr (std::is_same::value) { - auto destptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 16) { - __imma_m16n16k16_ld_c(destptr, src.get(), stride, - get_layout_id()); - } else if constexpr (NumRows == 8 && NumCols == 32) { - __imma_m8n32k16_ld_c(destptr, src.get(), stride, - get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 8) { - __imma_m32n8k16_ld_c(destptr, src.get(), stride, - get_layout_id()); - } - } else if constexpr (std::is_same::value) { - if constexpr (std::is_same::value) { - auto dstptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 16) { - __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride, - get_layout_id()); - } else if constexpr (NumRows == 8 && NumCols == 32) { - __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride, - get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 8) { - __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride, - get_layout_id()); - } - } else if constexpr (std::is_same::value) { - auto tileptr = reinterpret_cast(src.get()); - auto dstptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 8) { - __mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride, - get_layout_id()); - } else if constexpr (NumRows == 8 && NumCols == 16) { - __mma_tf32_m16n16k8_ld_b(dstptr, tileptr, stride, - get_layout_id()); - } + } else if constexpr (std::is_same::value) { + auto tileptr = reinterpret_cast(src.get()); + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 8) { + __mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 16) { + __mma_tf32_m16n16k8_ld_b(dstptr, tileptr, stride, + get_layout_id()); } } else if constexpr (std::is_same::value) { auto dstptr = reinterpret_cast(&res.wi_marray); @@ -377,185 +415,190 @@ struct joint_matrix_load_impl< } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { __dmma_m8n8k4_ld_b(dstptr, src.get(), stride, get_layout_id()); - } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: - matrix_use::accumulator) { - __dmma_m8n8k4_ld_c(dstptr, src.get(), stride, get_layout_id()); } } } }; #endif // __cplusplus >= 201703L -template -struct joint_matrix_store_impl { - void - store(sycl::ext::oneapi::experimental::matrix::joint_matrix< - T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - NumRows, NumCols, Layout, sycl::sub_group> &src, - multi_ptr dst, size_t stride); -}; - #if __cplusplus >= 201703L // if constexpr usage template -struct joint_matrix_store_impl< - T, NumRows, NumCols, Layout, Space, - typename std::enable_if_t> { - void - store(sycl::ext::oneapi::experimental::matrix::joint_matrix< - T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - NumRows, NumCols, Layout, sycl::sub_group> &src, - multi_ptr dst, size_t stride) { +struct joint_matrix_store_impl { + template + void storeLayoutT( + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T, NumRows, NumCols, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::none, + sycl::sub_group> &src, + multi_ptr dst, size_t stride) { if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (std::is_same::value) { __hmma_m16n16k16_st_c_f32(dst.get(), reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); + stride, get_layout_id()); } else if constexpr (std::is_same::value) { __imma_m16n16k16_st_c_i32(dst.get(), reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); + stride, get_layout_id()); } else if constexpr (std::is_same::value) { __hmma_m16n16k16_st_c_f16(reinterpret_cast(dst.get()), reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); + stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 32) { if constexpr (std::is_same::value) { __hmma_m8n32k16_st_c_f32(dst.get(), reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); + stride, get_layout_id()); } else if constexpr (std::is_same::value) { __imma_m8n32k16_st_c_i32(dst.get(), reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); + stride, get_layout_id()); } else if constexpr (std::is_same::value) { __hmma_m8n32k16_st_c_f16(reinterpret_cast(dst.get()), reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); + stride, get_layout_id()); } } else if constexpr (NumRows == 32 && NumCols == 8) { if constexpr (std::is_same::value) { __hmma_m32n8k16_st_c_f32(dst.get(), reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); + stride, get_layout_id()); } else if constexpr (std::is_same::value) { __imma_m32n8k16_st_c_i32(dst.get(), reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); + stride, get_layout_id()); } else if constexpr (std::is_same::value) { __hmma_m32n8k16_st_c_f16(reinterpret_cast(dst.get()), reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); + stride, get_layout_id()); } } else if constexpr (std::is_same::value) { __dmma_m8n8k4_st_c_f64(dst.get(), reinterpret_cast(&src.wi_marray), stride, - get_layout_id()); + get_layout_id()); + } + } + void + store(sycl::ext::oneapi::experimental::matrix::joint_matrix< + T, NumRows, NumCols, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::none, + sycl::sub_group> &src, + multi_ptr dst, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { + switch (LayoutAcc) { + case sycl::ext::oneapi::experimental::matrix::layout::row_major: + storeLayoutT( + src, dst, stride); + break; + case sycl::ext::oneapi::experimental::matrix::layout::col_major: + storeLayoutT( + src, dst, stride); + break; + default: + assert(false && "Invalid layout specified!"); } } }; #endif // __cplusplus >= 201703L template struct joint_matrix_mad_impl { sycl::ext::oneapi::experimental::matrix::joint_matrix< - T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, - N, LayoutC, sycl::sub_group> + T2, M, N, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::none, sycl::sub_group> mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< - T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, + T1, M, K, sycl::ext::oneapi::experimental::matrix::matrix_use::a, LayoutA, sycl::sub_group> A, sycl::ext::oneapi::experimental::matrix::joint_matrix< - T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, + T1, K, N, sycl::ext::oneapi::experimental::matrix::matrix_use::b, LayoutB, sycl::sub_group> B, sycl::ext::oneapi::experimental::matrix::joint_matrix< - T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - M, N, LayoutC, sycl::sub_group> + T2, M, N, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::none, + sycl::sub_group> C); }; -template +template constexpr int get_layout_pair_id(); template <> constexpr int get_layout_pair_id< - sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major, - sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() { + sycl::ext::oneapi::experimental::matrix::layout::row_major, + sycl::ext::oneapi::experimental::matrix::layout::row_major>() { return 0; } template <> constexpr int get_layout_pair_id< - sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major, - sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() { + sycl::ext::oneapi::experimental::matrix::layout::row_major, + sycl::ext::oneapi::experimental::matrix::layout::col_major>() { return 1; } template <> constexpr int get_layout_pair_id< - sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major, - sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() { + sycl::ext::oneapi::experimental::matrix::layout::col_major, + sycl::ext::oneapi::experimental::matrix::layout::row_major>() { return 2; } template <> constexpr int get_layout_pair_id< - sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major, - sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() { + sycl::ext::oneapi::experimental::matrix::layout::col_major, + sycl::ext::oneapi::experimental::matrix::layout::col_major>() { return 3; } - +// layout C unnecessary so long as not constructible as any other type!! #if __cplusplus >= 201703L // if constexpr usage template + sycl::ext::oneapi::experimental::matrix::layout LayoutA, + sycl::ext::oneapi::experimental::matrix::layout LayoutB> struct joint_matrix_mad_impl< - T1, T2, M, K, N, LayoutA, LayoutB, LayoutC, + T1, T2, M, K, N, LayoutA, LayoutB, typename std::enable_if_t< - (LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout:: - row_major || - LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout:: - col_major) && - (LayoutB == sycl::ext::oneapi::experimental::matrix::matrix_layout:: - row_major || - LayoutB == sycl::ext::oneapi::experimental::matrix::matrix_layout:: - col_major) && - (LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout:: - row_major || - LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout:: - col_major)>> { + (LayoutA == + sycl::ext::oneapi::experimental::matrix::layout::row_major || + LayoutA == + sycl::ext::oneapi::experimental::matrix::layout::col_major) && + (LayoutB == + sycl::ext::oneapi::experimental::matrix::layout::row_major || + LayoutB == + sycl::ext::oneapi::experimental::matrix::layout::col_major)>> { sycl::ext::oneapi::experimental::matrix::joint_matrix< - T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, - N, LayoutC, sycl::sub_group> + T2, M, N, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::none, sycl::sub_group> mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< - T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, + T1, M, K, sycl::ext::oneapi::experimental::matrix::matrix_use::a, LayoutA, sycl::sub_group> A, sycl::ext::oneapi::experimental::matrix::joint_matrix< - T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, + T1, K, N, sycl::ext::oneapi::experimental::matrix::matrix_use::b, LayoutB, sycl::sub_group> B, sycl::ext::oneapi::experimental::matrix::joint_matrix< - T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - M, N, LayoutC, sycl::sub_group> + T2, M, N, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::none, + sycl::sub_group> C) { sycl::ext::oneapi::experimental::matrix::joint_matrix< - T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, - N, LayoutC, sycl::sub_group> + T2, M, N, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::none, sycl::sub_group> D; if constexpr (M == 16 && N == 16 && K == 16) { if constexpr (std::is_same::value) { @@ -691,20 +734,48 @@ struct joint_matrix_mad_impl< namespace experimental { namespace matrix { -template ::value || + (std::is_same::value && + std::is_same::value), + bool> = true> +void joint_matrix_load( + Group sg, joint_matrix &res, + multi_ptr src, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + sycl::ext::oneapi::detail::joint_matrix_load_impl< + S, T, NumRows, NumCols, Use, + sycl::ext::oneapi::experimental::matrix::layout::none, Space>{} + .load(res, src, stride, LayoutAcc); +#else + std::ignore = sg; + std::ignore = res; + std::ignore = src; + std::ignore = stride; + throw runtime_error( + "When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_load is " + "only supported by CUDA devices", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + +template ::value || (std::is_same::value && std::is_same::value), bool> = true> void joint_matrix_load( - Group sg, joint_matrix &res, + Group sg, joint_matrix &res, multi_ptr src, size_t stride) { #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - sycl::ext::oneapi::detail::joint_matrix_load_impl{} - .load(res, src, stride); + sycl::ext::oneapi::detail::joint_matrix_load_impl{} + .load(res, src, stride); #else std::ignore = sg; std::ignore = res; @@ -718,15 +789,18 @@ void joint_matrix_load( } template -void joint_matrix_store(Group sg, - joint_matrix &src, - multi_ptr dst, size_t stride) { + access::address_space Space> +void joint_matrix_store( + Group sg, + joint_matrix + &src, + multi_ptr dst, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) sycl::ext::oneapi::detail::joint_matrix_store_impl{} - .store(src, dst, stride); + Space>{} + .store(src, dst, stride, LayoutAcc); #else std::ignore = sg; std::ignore = src; @@ -740,16 +814,18 @@ void joint_matrix_store(Group sg, } template -joint_matrix + std::size_t K, std::size_t N, layout LayoutA, layout LayoutB> +joint_matrix joint_matrix_mad( - Group sg, joint_matrix A, - joint_matrix B, - joint_matrix C) { + Group sg, joint_matrix A, + joint_matrix B, + joint_matrix + C) { #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - return sycl::ext::oneapi::detail::joint_matrix_mad_impl< - T1, T2, M, K, N, LayoutA, LayoutB, LayoutC>{} + return sycl::ext::oneapi::detail::joint_matrix_mad_impl{} .mad(A, B, C); #else std::ignore = sg; From 8c0991030e30bfe71f0d2ad1b8eedb1c961e9cd9 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Tue, 30 Aug 2022 17:04:02 +0100 Subject: [PATCH 04/10] joint_matrix_mad takes D matrix as argument. Also updated the impl functions used in the CUDA backend (Some of these functions may be also used in the HIP AMD case when that is implemented, since the interfaces will match). Signed-off-by: JackAKirk --- .../ext/oneapi/matrix/matrix-tensorcore.hpp | 314 +++++++++--------- 1 file changed, 155 insertions(+), 159 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index 20ee168db9c52..2cf4d4265ad4b 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -16,20 +16,19 @@ namespace oneapi { namespace experimental { namespace matrix { -enum class matrix_use { a, b, accumulator, unnecessary }; +enum class matrix_use { a, b, accumulator }; -enum class layout { row_major, col_major, packed_a, packed_b, none}; +enum class layout { row_major, col_major, packed_a, packed_b, unused }; namespace precision { class tf32 { tf32() = delete; }; } // namespace precision - +// TODO: how are the default params for Rows/Cols used in Intel backend? template struct joint_matrix; @@ -110,7 +109,7 @@ __SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, b, 4, 8, 1) #define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(type, M, N, size) \ template <> \ - struct joint_matrix { \ marray wi_marray; \ inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { \ @@ -171,15 +170,17 @@ joint_matrix_fill(Group sg, namespace detail { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) template -struct joint_matrix_load_impl { +struct load_multiplicand { void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< S, NumRows, NumCols, Use, Layout, sycl::sub_group> &res, multi_ptr src, size_t stride); }; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) template constexpr int get_layout_id(); @@ -197,97 +198,96 @@ get_layout_id() { } #if __cplusplus >= 201703L // if constexpr usage -template -struct joint_matrix_load_impl< - S, T, NumRows, NumCols, - sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - sycl::ext::oneapi::experimental::matrix::layout::none, Space> { - template - void loadLayoutT( - sycl::ext::oneapi::experimental::matrix::joint_matrix< - T, NumRows, NumCols, - sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - sycl::ext::oneapi::experimental::matrix::layout::none, - sycl::sub_group> &res, - multi_ptr src, size_t stride) { - if constexpr (std::is_same::value) { - auto destptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 16) { - __imma_m16n16k16_ld_c(destptr, src.get(), stride, - get_layout_id()); - } else if constexpr (NumRows == 8 && NumCols == 32) { - __imma_m8n32k16_ld_c(destptr, src.get(), stride, - get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 8) { - __imma_m32n8k16_ld_c(destptr, src.get(), stride, - get_layout_id()); - } - } else if constexpr (std::is_same::value) { - if constexpr (std::is_same::value) { - auto dstptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 16) { - __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride, - get_layout_id()); - } else if constexpr (NumRows == 8 && NumCols == 32) { - __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride, - get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 8) { - __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride, - get_layout_id()); - } - } - } else if constexpr (std::is_same::value) { - auto tileptr = reinterpret_cast(src.get()); - auto dstptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 32 && NumCols == 8) { - __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride, - get_layout_id()); - } else if constexpr (NumRows == 8 && NumCols == 32) { - __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride, - get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 16) { - __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride, - get_layout_id()); - } - } else if constexpr (std::is_same::value) { - auto dstptr = - reinterpret_cast(&res.wi_marray); // todo remove line - - __dmma_m8n8k4_ld_c(dstptr, src.get(), stride, get_layout_id()); +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +template +void load_accumulator_layoutT( + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T, NumRows, NumCols, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::sub_group> &res, + multi_ptr src, size_t stride) { + if constexpr (std::is_same::value) { + auto destptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + __imma_m16n16k16_ld_c(destptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __imma_m8n32k16_ld_c(destptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 8) { + __imma_m32n8k16_ld_c(destptr, src.get(), stride, + get_layout_id()); } - }; - void - load(sycl::ext::oneapi::experimental::matrix::joint_matrix< - S, NumRows, NumCols, - sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - sycl::ext::oneapi::experimental::matrix::layout::none, - sycl::sub_group> &res, - multi_ptr src, size_t stride, - sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { - switch (LayoutAcc) { - case sycl::ext::oneapi::experimental::matrix::layout::row_major: - loadLayoutT( - res, src, stride); - break; - case sycl::ext::oneapi::experimental::matrix::layout::col_major: - loadLayoutT( - res, src, stride); - break; - default: - assert(false && "Invalid layout specified!"); + } else if constexpr (std::is_same::value) { + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 8) { + __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride, + get_layout_id()); } + } else if constexpr (std::is_same::value) { + auto tileptr = reinterpret_cast(src.get()); + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 32 && NumCols == 8) { + __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 16) { + __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same::value) { + __dmma_m8n8k4_ld_c(reinterpret_cast(&res.wi_marray), src.get(), + stride, get_layout_id()); } }; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +template +void load_accumulator( + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T, NumRows, NumCols, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::sub_group> &res, + multi_ptr src, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { + switch (LayoutAcc) { + case sycl::ext::oneapi::experimental::matrix::layout::row_major: + load_accumulator_layoutT< + sycl::ext::oneapi::experimental::matrix::layout::row_major>(res, src, + stride); + break; + case sycl::ext::oneapi::experimental::matrix::layout::col_major: + load_accumulator_layoutT< + sycl::ext::oneapi::experimental::matrix::layout::col_major>(res, src, + stride); + break; + default: + assert(false && "Invalid layout specified!"); + } +} +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) template -struct joint_matrix_load_impl< +struct load_multiplicand< S, T, NumRows, NumCols, Use, Layout, Space, typename std::enable_if_t< - Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major || Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major || Layout == sycl::ext::oneapi::experimental::matrix::layout::col_major>> { void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< @@ -419,9 +419,11 @@ struct joint_matrix_load_impl< } } }; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) #endif // __cplusplus >= 201703L #if __cplusplus >= 201703L // if constexpr usage +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) template struct joint_matrix_store_impl { @@ -430,7 +432,7 @@ struct joint_matrix_store_impl { sycl::ext::oneapi::experimental::matrix::joint_matrix< T, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - sycl::ext::oneapi::experimental::matrix::layout::none, + sycl::ext::oneapi::experimental::matrix::layout::unused, sycl::sub_group> &src, multi_ptr dst, size_t stride) { if constexpr (NumRows == 16 && NumCols == 16) { @@ -485,7 +487,7 @@ struct joint_matrix_store_impl { store(sycl::ext::oneapi::experimental::matrix::joint_matrix< T, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - sycl::ext::oneapi::experimental::matrix::layout::none, + sycl::ext::oneapi::experimental::matrix::layout::unused, sycl::sub_group> &src, multi_ptr dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { @@ -503,32 +505,33 @@ struct joint_matrix_store_impl { } } }; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) #endif // __cplusplus >= 201703L +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) template struct joint_matrix_mad_impl { - sycl::ext::oneapi::experimental::matrix::joint_matrix< - T2, M, N, - sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - sycl::ext::oneapi::experimental::matrix::layout::none, sycl::sub_group> - mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< - T1, M, K, sycl::ext::oneapi::experimental::matrix::matrix_use::a, - LayoutA, sycl::sub_group> - A, - sycl::ext::oneapi::experimental::matrix::joint_matrix< - T1, K, N, sycl::ext::oneapi::experimental::matrix::matrix_use::b, - LayoutB, sycl::sub_group> - B, - sycl::ext::oneapi::experimental::matrix::joint_matrix< - T2, M, N, - sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - sycl::ext::oneapi::experimental::matrix::layout::none, - sycl::sub_group> - C); + void mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< + T2, M, N, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::sub_group> &D, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T1, M, K, sycl::ext::oneapi::experimental::matrix::matrix_use::a, + LayoutA, sycl::sub_group> &A, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T1, K, N, sycl::ext::oneapi::experimental::matrix::matrix_use::b, + LayoutB, sycl::sub_group> &B, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T2, M, N, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::sub_group> &C); }; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) template @@ -561,8 +564,9 @@ constexpr int get_layout_pair_id< sycl::ext::oneapi::experimental::matrix::layout::col_major>() { return 3; } -// layout C unnecessary so long as not constructible as any other type!! + #if __cplusplus >= 201703L // if constexpr usage +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) template @@ -577,29 +581,22 @@ struct joint_matrix_mad_impl< sycl::ext::oneapi::experimental::matrix::layout::row_major || LayoutB == sycl::ext::oneapi::experimental::matrix::layout::col_major)>> { - sycl::ext::oneapi::experimental::matrix::joint_matrix< - T2, M, N, - sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - sycl::ext::oneapi::experimental::matrix::layout::none, sycl::sub_group> - mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< - T1, M, K, sycl::ext::oneapi::experimental::matrix::matrix_use::a, - LayoutA, sycl::sub_group> - A, - sycl::ext::oneapi::experimental::matrix::joint_matrix< - T1, K, N, sycl::ext::oneapi::experimental::matrix::matrix_use::b, - LayoutB, sycl::sub_group> - B, - sycl::ext::oneapi::experimental::matrix::joint_matrix< - T2, M, N, - sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - sycl::ext::oneapi::experimental::matrix::layout::none, - sycl::sub_group> - C) { - sycl::ext::oneapi::experimental::matrix::joint_matrix< - T2, M, N, - sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - sycl::ext::oneapi::experimental::matrix::layout::none, sycl::sub_group> - D; + void mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< + T2, M, N, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::sub_group> &D, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T1, M, K, sycl::ext::oneapi::experimental::matrix::matrix_use::a, + LayoutA, sycl::sub_group> &A, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T1, K, N, sycl::ext::oneapi::experimental::matrix::matrix_use::b, + LayoutB, sycl::sub_group> &B, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T2, M, N, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::sub_group> &C) { if constexpr (M == 16 && N == 16 && K == 16) { if constexpr (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); @@ -724,32 +721,29 @@ struct joint_matrix_mad_impl< reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } - return D; } }; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) #endif // __cplusplus >= 201703L } // namespace detail namespace experimental { namespace matrix { - +// TODO Two typenames (S and T) not required in CUDA backend but included for +// Intel backend requirement. template ::value || - (std::is_same::value && - std::is_same::value), - bool> = true> + size_t NumCols, matrix_use Use, access::address_space Space, + std::enable_if_t::value, bool> = true> void joint_matrix_load( - Group sg, joint_matrix &res, + Group sg, + joint_matrix + &res, multi_ptr src, size_t stride, sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - sycl::ext::oneapi::detail::joint_matrix_load_impl< - S, T, NumRows, NumCols, Use, - sycl::ext::oneapi::experimental::matrix::layout::none, Space>{} - .load(res, src, stride, LayoutAcc); + sycl::ext::oneapi::detail::load_accumulator(res, src, stride, LayoutAcc); #else std::ignore = sg; std::ignore = res; @@ -773,9 +767,9 @@ void joint_matrix_load( Group sg, joint_matrix &res, multi_ptr src, size_t stride) { #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - sycl::ext::oneapi::detail::joint_matrix_load_impl{} - .load(res, src, stride); + sycl::ext::oneapi::detail::load_multiplicand{} + .load(res, src, stride); #else std::ignore = sg; std::ignore = res; @@ -793,7 +787,7 @@ template + sycl::ext::oneapi::experimental::matrix::layout::unused, Group> &src, multi_ptr dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { @@ -815,18 +809,20 @@ void joint_matrix_store( template -joint_matrix -joint_matrix_mad( - Group sg, joint_matrix A, - joint_matrix B, +void joint_matrix_mad( + Group sg, + joint_matrix + &D, + joint_matrix &A, + joint_matrix &B, joint_matrix - C) { + sycl::ext::oneapi::experimental::matrix::layout::unused, Group> + &C) { #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - return sycl::ext::oneapi::detail::joint_matrix_mad_impl{} - .mad(A, B, C); + sycl::ext::oneapi::detail::joint_matrix_mad_impl{} + .mad(D, A, B, C); #else std::ignore = sg; std::ignore = A; From e55e5f0bfe5175d41f61f4b912243a988104e2f7 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Thu, 1 Sep 2022 10:19:41 +0100 Subject: [PATCH 05/10] Add new mma cases enabled by joint_matrix_mad. This is for illustrative purposes: to show the advantage of the proposed change in the joint_matrix_mad interface. Signed-off-by: JackAKirk --- .../ext/oneapi/matrix/matrix-tensorcore.hpp | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index 2cf4d4265ad4b..551bc914cddb5 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -509,13 +509,13 @@ struct joint_matrix_store_impl { #endif // __cplusplus >= 201703L #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) -template struct joint_matrix_mad_impl { void mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< - T2, M, N, + T3, M, N, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, sycl::ext::oneapi::experimental::matrix::layout::unused, sycl::sub_group> &D, @@ -567,11 +567,11 @@ constexpr int get_layout_pair_id< #if __cplusplus >= 201703L // if constexpr usage #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) -template struct joint_matrix_mad_impl< - T1, T2, M, K, N, LayoutA, LayoutB, + T1, T2, T3, M, K, N, LayoutA, LayoutB, typename std::enable_if_t< (LayoutA == sycl::ext::oneapi::experimental::matrix::layout::row_major || @@ -582,7 +582,7 @@ struct joint_matrix_mad_impl< LayoutB == sycl::ext::oneapi::experimental::matrix::layout::col_major)>> { void mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< - T2, M, N, + T3, M, N, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, sycl::ext::oneapi::experimental::matrix::layout::unused, sycl::sub_group> &D, @@ -614,15 +614,29 @@ struct joint_matrix_mad_impl< auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { __hmma_m16n16k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); + } else { + __hmma_m16n16k16_mma_f16f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } } else if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { + __hmma_m16n16k16_mma_f32f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else { __hmma_m16n16k16_mma_f16f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); + } } } else if constexpr (std::is_same::value || std::is_same void joint_matrix_mad( Group sg, - joint_matrix &D, joint_matrix &A, @@ -820,7 +834,7 @@ void joint_matrix_mad( sycl::ext::oneapi::experimental::matrix::layout::unused, Group> &C) { #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - sycl::ext::oneapi::detail::joint_matrix_mad_impl{} .mad(D, A, B, C); #else From a8810552225b4cdb479650011766e9de8f5474b6 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Thu, 1 Sep 2022 12:42:04 +0100 Subject: [PATCH 06/10] packed_a, packed_b -> packed Signed-off-by: JackAKirk --- sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index 551bc914cddb5..9c99ae56200d9 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -18,7 +18,7 @@ namespace matrix { enum class matrix_use { a, b, accumulator }; -enum class layout { row_major, col_major, packed_a, packed_b, unused }; +enum class layout { row_major, col_major, packed, unused }; namespace precision { class tf32 { From 5b844349e2f752608fbe1f9905e6c543f2d9d026 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Wed, 14 Sep 2022 16:08:35 +0100 Subject: [PATCH 07/10] Made interface compatible with intel backend. Signed-off-by: JackAKirk --- .../sycl/ext/oneapi/matrix/matrix-tensorcore.hpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index 9c99ae56200d9..d71766ceb72d3 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -822,7 +822,7 @@ void joint_matrix_store( } template + std::size_t K, std::size_t N, layout LayoutA = sycl::ext::oneapi::experimental::matrix::layout::unused, layout LayoutB = sycl::ext::oneapi::experimental::matrix::layout::unused> void joint_matrix_mad( Group sg, joint_matrix &C) { -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) sycl::ext::oneapi::detail::joint_matrix_mad_impl{} .mad(D, A, B, C); +#elif defined(__AMDGCN__) +//rocM wmma joint_matrix_mad_impl +#elif defined(__SPIR__) +//intel joint_matrix_mad_impl +#endif // defined(__NVPTX__) #else std::ignore = sg; std::ignore = A; std::ignore = B; std::ignore = C; - throw runtime_error("When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_mad is " - "only supported by CUDA devices", + throw runtime_error("joint_matrix_mad is " + "not supported on HOST", PI_ERROR_INVALID_DEVICE); -#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +#endif // defined(__SYCL_DEVICE_ONLY__) } // This function rounds the bottom 13 bits up or down, and then zeros out the From ccdb544cef1cd50b48dd6d4808d5dec582d55eb4 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Fri, 7 Oct 2022 06:32:51 -0700 Subject: [PATCH 08/10] added unified header, moved nvptx specific impl. Signed-off-by: JackAKirk --- .../sycl/ext/oneapi/matrix/joint-matrix.hpp | 40 + .../ext/oneapi/matrix/matrix-tensor-cores.hpp | 713 ++++++++++++++++++ .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 173 +++++ .../include/sycl/ext/oneapi/matrix/matrix.hpp | 4 + 4 files changed, 930 insertions(+) create mode 100644 sycl/include/sycl/ext/oneapi/matrix/joint-matrix.hpp create mode 100644 sycl/include/sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp create mode 100644 sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp diff --git a/sycl/include/sycl/ext/oneapi/matrix/joint-matrix.hpp b/sycl/include/sycl/ext/oneapi/matrix/joint-matrix.hpp new file mode 100644 index 0000000000000..c93a6789064d4 --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/matrix/joint-matrix.hpp @@ -0,0 +1,40 @@ +//===---- joint-matrix.hpp - SYCL matrix extension joint_matrix ----*- C++ -*---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ===--------------------------------------------------------------------=== // + +namespace sycl { +__SYCL_INLINE_VER_NAMESPACE(_V1) { +namespace ext { +namespace oneapi { +namespace experimental { +namespace matrix { + +enum class matrix_use { a, b, accumulator }; + +enum class layout { row_major, col_major, packed, unused }; + +namespace precision { +class tf32 { + tf32() = delete; +}; +} // namespace precision + +//TODO forward declare jm or?? + +// TODO: how are the default params for Rows/Cols used in Intel backend? +template +struct joint_matrix; + +} // namespace matrix +} // namespace experimental +} // namespace oneapi +} // namespace ext +} // __SYCL_INLINE_VER_NAMESPACE(_V1) +} // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp new file mode 100644 index 0000000000000..88c22af666de9 --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp @@ -0,0 +1,713 @@ + +//===---- matrix-tensor-cores.hpp - SYCL tensor cores matrix ----*- C++ -*---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ===--------------------------------------------------------------------=== // + +#pragma once +#include +#include + +namespace sycl { +__SYCL_INLINE_VER_NAMESPACE(_V1) { +namespace ext { +namespace oneapi { +namespace experimental { +namespace matrix { +//TODO ifdef this stuff! + +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +template class wi_data { + marray &data; + wi_data(marray &wi_data) : data(wi_data){}; + template + friend struct joint_matrix; + +public: + size_t length() { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + return data.size(); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + }; + + type &operator[](size_t i) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + return data[i]; +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + }; +}; + +#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(type, use, M, N, size) \ + template \ + struct joint_matrix< \ + type, matrix_use::use, M, N, Layout, sycl::sub_group, \ + typename std::enable_if_t> { \ + marray wi_marray; \ + inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { \ + return wi_data(wi_marray); \ + }; \ + }; + +// m8n32k16 +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 8, 16, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 32, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 8, 16, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 32, 16) + +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 8, 16, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 32, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 8, 16, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 32, 16) +// m32n8k16 +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 32, 16, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 8, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 32, 16, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 8, 16) + +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 32, 16, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 8, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 32, 16, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 8, 4) +// m16n16k16 +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 16, 16, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 16, 16) + +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 16, 8) +// m8n8k4 double only +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, a, 8, 4, 1) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, b, 4, 8, 1) + +#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR + +#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(type, M, N, size) \ + template <> \ + struct joint_matrix { \ + marray wi_marray; \ + inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { \ + return wi_data(wi_marray); \ + }; \ + }; + +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 8, 32, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 8, 32, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 8, 32, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 32, 8, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 32, 8, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 32, 8, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 16, 16, 8) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(double, 8, 8, 2) + +#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC + +#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision, use, M, N, type, \ + size) \ + template \ + struct joint_matrix< \ + precision, matrix_use::use, M, N, Layout, sycl::sub_group, \ + typename std::enable_if_t> { \ + marray wi_marray; \ + inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { \ +return wi_data(wi_marray); \ + }; \ + }; +// m16n16k8 tf32 only +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision::tf32, a, 16, 8, float, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision::tf32, b, 8, 16, float, 4) + +#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION +} // namespace matrix +} // namespace experimental + + +namespace detail { + +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +template +struct load_multiplicand { + void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< + S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, + multi_ptr src, size_t stride); +}; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + +template +constexpr int get_layout_id(); + +template <> +constexpr int +get_layout_id() { + return 0; +} + +template <> +constexpr int +get_layout_id() { + return 1; +} + +#if __cplusplus >= 201703L // if constexpr usage +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +template +void load_accumulator_layoutT( + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, NumRows, NumCols, + sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::sub_group> &res, + multi_ptr src, size_t stride) { + if constexpr (std::is_same, int32_t>::value) { + auto destptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + __imma_m16n16k16_ld_c(destptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __imma_m8n32k16_ld_c(destptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 8) { + __imma_m32n8k16_ld_c(destptr, src.get(), stride, + get_layout_id()); + } + } else if constexpr (std::is_same, float>::value) { + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 8) { + __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride, + get_layout_id()); + } + } else if constexpr (std::is_same::value) { + auto tileptr = reinterpret_cast(src.get()); + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 32 && NumCols == 8) { + __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 16) { + __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same::value) { + __dmma_m8n8k4_ld_c(reinterpret_cast(&res.wi_marray), src.get(), + stride, get_layout_id()); + } +}; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +template +void load_accumulator( + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, NumRows, NumCols, + sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::sub_group> &res, + multi_ptr src, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { + switch (LayoutAcc) { + case sycl::ext::oneapi::experimental::matrix::layout::row_major: + load_accumulator_layoutT< + sycl::ext::oneapi::experimental::matrix::layout::row_major>(res, src, + stride); + break; + case sycl::ext::oneapi::experimental::matrix::layout::col_major: + load_accumulator_layoutT< + sycl::ext::oneapi::experimental::matrix::layout::col_major>(res, src, + stride); + break; + default: + assert(false && "Invalid layout specified!"); + } +} +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +template +struct load_multiplicand< + S, T, NumRows, NumCols, Use, Layout, Space, + typename std::enable_if_t< + Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major || + Layout == sycl::ext::oneapi::experimental::matrix::layout::col_major>> { + void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< + S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, + multi_ptr src, size_t stride) { + if constexpr (std::is_same, uint16_t>::value || + std::is_same< + std::remove_const_t, + sycl::ext::oneapi::experimental::bfloat16>::value) { + auto tileptr = reinterpret_cast(src.get()); + auto destptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::matrix_use::a) { + __mma_bf16_m16n16k16_ld_a(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::b) { + __mma_bf16_m16n16k16_ld_b(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __mma_bf16_m8n32k16_ld_a(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 32) { + __mma_bf16_m8n32k16_ld_b(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 16) { + __mma_bf16_m32n8k16_ld_a(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 8) { + __mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same, uint8_t>::value) { + auto tileptr = reinterpret_cast(src.get()); + auto destptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::matrix_use::a) { + __imma_m16n16k16_ld_a_u8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::b) { + __imma_m16n16k16_ld_b_u8(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __imma_m8n32k16_ld_a_u8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 32) { + __imma_m8n32k16_ld_b_u8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 16) { + __imma_m32n8k16_ld_a_u8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 8) { + __imma_m32n8k16_ld_b_u8(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same, int8_t>::value) { + auto tileptr = reinterpret_cast(src.get()); + auto destptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::matrix_use::a) { + __imma_m16n16k16_ld_a_s8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::b) { + __imma_m16n16k16_ld_b_s8(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __imma_m8n32k16_ld_a_s8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 32) { + __imma_m8n32k16_ld_b_s8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 16) { + __imma_m32n8k16_ld_a_s8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 8) { + __imma_m32n8k16_ld_b_s8(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same, half>::value) { + auto tileptr = reinterpret_cast(src.get()); + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::matrix_use::a) { + __hmma_m16n16k16_ld_a(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::b) { + __hmma_m16n16k16_ld_b(dstptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __hmma_m8n32k16_ld_a(dstptr, tileptr, stride, get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 32) { + __hmma_m8n32k16_ld_b(dstptr, tileptr, stride, get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 16) { + __hmma_m32n8k16_ld_a(dstptr, tileptr, stride, get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 8) { + __hmma_m32n8k16_ld_b(dstptr, tileptr, stride, get_layout_id()); + } + + } else if constexpr (std::is_same::value) { + auto tileptr = reinterpret_cast(src.get()); + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 8) { + __mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 16) { + __mma_tf32_m16n16k8_ld_b(dstptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same, double>::value) { + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::matrix_use::a) { + __dmma_m8n8k4_ld_a(dstptr, src.get(), stride, get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::b) { + __dmma_m8n8k4_ld_b(dstptr, src.get(), stride, get_layout_id()); + } + } + } +}; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +#endif // __cplusplus >= 201703L + +#if __cplusplus >= 201703L // if constexpr usage +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +template +struct joint_matrix_store_impl { + template + void storeLayoutT( + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, NumRows, NumCols, + sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::sub_group> &src, + multi_ptr dst, size_t stride) { + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (std::is_same::value) { + __hmma_m16n16k16_st_c_f32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same::value) { + __imma_m16n16k16_st_c_i32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same::value) { + __hmma_m16n16k16_st_c_f16(reinterpret_cast(dst.get()), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 32) { + if constexpr (std::is_same::value) { + __hmma_m8n32k16_st_c_f32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same::value) { + __imma_m8n32k16_st_c_i32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same::value) { + __hmma_m8n32k16_st_c_f16(reinterpret_cast(dst.get()), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } + } else if constexpr (NumRows == 32 && NumCols == 8) { + if constexpr (std::is_same::value) { + __hmma_m32n8k16_st_c_f32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same::value) { + __imma_m32n8k16_st_c_i32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same::value) { + __hmma_m32n8k16_st_c_f16(reinterpret_cast(dst.get()), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } + } else if constexpr (std::is_same::value) { + __dmma_m8n8k4_st_c_f64(dst.get(), + reinterpret_cast(&src.wi_marray), stride, + get_layout_id()); + } + } + void + store(sycl::ext::oneapi::experimental::matrix::joint_matrix< + T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, NumRows, NumCols, + sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::sub_group> &src, + multi_ptr dst, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { + switch (LayoutAcc) { + case sycl::ext::oneapi::experimental::matrix::layout::row_major: + storeLayoutT( + src, dst, stride); + break; + case sycl::ext::oneapi::experimental::matrix::layout::col_major: + storeLayoutT( + src, dst, stride); + break; + default: + assert(false && "Invalid layout specified!"); + } + } +}; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +#endif // __cplusplus >= 201703L + +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +template +struct joint_matrix_mad_impl { + void mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< + T3, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::sub_group> &D, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, + LayoutA, sycl::sub_group> &A, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, + LayoutB, sycl::sub_group> &B, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::sub_group> &C); +}; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + +template +constexpr int get_layout_pair_id(); + +template <> +constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::layout::row_major, + sycl::ext::oneapi::experimental::matrix::layout::row_major>() { + return 0; +} + +template <> +constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::layout::row_major, + sycl::ext::oneapi::experimental::matrix::layout::col_major>() { + return 1; +} + +template <> +constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::layout::col_major, + sycl::ext::oneapi::experimental::matrix::layout::row_major>() { + return 2; +} + +template <> +constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::layout::col_major, + sycl::ext::oneapi::experimental::matrix::layout::col_major>() { + return 3; +} + +#if __cplusplus >= 201703L // if constexpr usage +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +template +struct joint_matrix_mad_impl< + T1, T2, T3, M, K, N, LayoutA, LayoutB, + typename std::enable_if_t< + (LayoutA == + sycl::ext::oneapi::experimental::matrix::layout::row_major || + LayoutA == + sycl::ext::oneapi::experimental::matrix::layout::col_major) && + (LayoutB == + sycl::ext::oneapi::experimental::matrix::layout::row_major || + LayoutB == + sycl::ext::oneapi::experimental::matrix::layout::col_major)>> { + void mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< + T3, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::sub_group> &D, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, + LayoutA, sycl::sub_group> &A, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, + LayoutB, sycl::sub_group> &B, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::sub_group> &C) { + if constexpr (M == 16 && N == 16 && K == 16) { + if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrD = reinterpret_cast(&D.wi_marray); + if constexpr (std::is_same::value) { + __imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __imma_m16n16k16_mma_u8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } + } else if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { + __hmma_m16n16k16_mma_f32f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else { + __hmma_m16n16k16_mma_f16f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } + } else if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { + __hmma_m16n16k16_mma_f32f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else { + __hmma_m16n16k16_mma_f16f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } + } + } else if constexpr (std::is_same::value || + std::is_same::value) { + __mma_bf16_m16n16k16_mma_f32( + reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } + } else if constexpr (M == 8 && N == 32 && K == 16) { + if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrD = reinterpret_cast(&D.wi_marray); + if constexpr (std::is_same::value) { + __imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __imma_m8n32k16_mma_u8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } + } else if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + if constexpr (std::is_same::value) { + __hmma_m8n32k16_mma_f32f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __hmma_m8n32k16_mma_f16f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } + } else if constexpr (std::is_same::value || + std::is_same::value) { + __mma_bf16_m8n32k16_mma_f32( + reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } + } else if constexpr (M == 32 && N == 8 && K == 16) { + if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrD = reinterpret_cast(&D.wi_marray); + if constexpr (std::is_same::value) { + __imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __imma_m32n8k16_mma_u8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } + } else if constexpr (std::is_same::value || + std::is_same::value) { + __mma_bf16_m32n8k16_mma_f32( + reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + if constexpr (std::is_same::value) { + __hmma_m32n8k16_mma_f32f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __hmma_m32n8k16_mma_f16f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } + } + } else if constexpr (M == 16 && N == 16 && K == 8) { + __mma_tf32_m16n16k8_mma_f32(reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __dmma_m8n8k4_mma_f64(reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } + } +}; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +#endif // __cplusplus >= 201703L + +} // namespace detail +} // namespace oneapi +} // namespace ext +} // __SYCL_INLINE_VER_NAMESPACE(_V1) +} // namespace sycl + diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp new file mode 100644 index 0000000000000..f236e12198986 --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -0,0 +1,173 @@ +//===---- matrix-unified.hpp - SYCL matrix extension ----*- C++ -*---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ===--------------------------------------------------------------------=== // + +//todo is bfloat16 necessary? +#pragma once +//#include +//#include +#include + +namespace sycl { +__SYCL_INLINE_VER_NAMESPACE(_V1) { +namespace ext { +namespace oneapi { +namespace experimental { +namespace matrix { + + +template +inline __SYCL_ALWAYS_INLINE void +joint_matrix_fill(Group sg, + joint_matrix &res, + const T2 v) { + // We kept the unused "sg" in joint_matrix_fill to match the other DPC++ + // functions + std::ignore = sg; +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + res.wi_marray = v; +#else + std::ignore = res; + std::ignore = v; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + +// TODO Two typenames (S and T) not required in CUDA backend but included for +// Intel backend requirement. +template ::value, bool> = true> +void joint_matrix_load( + Group sg, + joint_matrix + &res, + multi_ptr src, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + sycl::ext::oneapi::detail::load_accumulator(res, src, stride, LayoutAcc); +#else + std::ignore = sg; + std::ignore = res; + std::ignore = src; + std::ignore = stride; + throw runtime_error( + "When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_load is " + "only supported by CUDA devices", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + + +template < + typename Group, typename S, typename T, matrix_use Use, size_t NumRows, + size_t NumCols, matrix::layout Layout, access::address_space Space, + std::enable_if_t>::value || + (std::is_same::value && + + std::is_same, float>::value), + bool> = true> +void joint_matrix_load( + Group sg, joint_matrix &res, + multi_ptr src, size_t stride) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + sycl::ext::oneapi::detail::load_multiplicand{} + .load(res, src, stride); +#else + std::ignore = sg; + std::ignore = res; + std::ignore = src; + std::ignore = stride; + throw runtime_error( + "When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_load is " + "only supported by CUDA devices", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + +template +void joint_matrix_store( + Group sg, + joint_matrix + &src, + multi_ptr dst, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + sycl::ext::oneapi::detail::joint_matrix_store_impl{} + .store(src, dst, stride, LayoutAcc); +#else + std::ignore = sg; + std::ignore = src; + std::ignore = dst; + std::ignore = stride; + throw runtime_error( + "When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_store is " + "only supported by CUDA devices", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + +template +void joint_matrix_mad( + Group sg, + joint_matrix + &D, + joint_matrix &A, + joint_matrix &B, + joint_matrix + &C) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + sycl::ext::oneapi::detail::joint_matrix_mad_impl{} + .mad(D, A, B, C); +#elif defined(__AMDGCN__) +//rocM wmma joint_matrix_mad_impl +#elif defined(__SPIR__) +//intel joint_matrix_mad_impl +#endif // defined(__NVPTX__) +#else + std::ignore = sg; + std::ignore = A; + std::ignore = B; + std::ignore = C; + throw runtime_error("joint_matrix_mad is " + "not supported on HOST", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +// This function rounds the bottom 13 bits up or down, and then zeros out the +// bottom bits +float round_to_tf32(float a) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + int32_t tmp_int = __nvvm_f2tf32_rna(a); + return __nvvm_bitcast_i2f(tmp_int); +#else + uint32_t tmp_uint = reinterpret_cast(a); + tmp_uint += 0x1000u; + tmp_uint &= 0xFFFFE000u; + float ret = reinterpret_cast(tmp_uint); + return ret; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} + +} // namespace matrix +} // namespace experimental +} // namespace oneapi +} // namespace ext +} // __SYCL_INLINE_VER_NAMESPACE(_V1) +} // namespace sycl + diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp index ecfad58259cc2..c9e5d46764260 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp @@ -29,3 +29,7 @@ #if (SYCL_EXT_ONEAPI_MATRIX == 3) #include #endif +#if (SYCL_EXT_ONEAPI_MATRIX == 4) +#include +#endif + From 46e87a11a7063af00ac5d82933020751f0fdde7c Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Fri, 7 Oct 2022 22:04:05 +0100 Subject: [PATCH 09/10] (very) draft updated interfaces. Signed-off-by: JackAKirk --- .../sycl/ext/oneapi/matrix/joint-matrix.hpp | 4 +- .../ext/oneapi/matrix/matrix-tensor-cores.hpp | 112 ++++++++---------- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 40 ++++--- 3 files changed, 74 insertions(+), 82 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/joint-matrix.hpp b/sycl/include/sycl/ext/oneapi/matrix/joint-matrix.hpp index c93a6789064d4..a408579709752 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/joint-matrix.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/joint-matrix.hpp @@ -15,7 +15,7 @@ namespace matrix { enum class matrix_use { a, b, accumulator }; -enum class layout { row_major, col_major, packed, unused }; +enum class layout { row_major, col_major, packed, dynamic }; namespace precision { class tf32 { @@ -28,7 +28,7 @@ class tf32 { // TODO: how are the default params for Rows/Cols used in Intel backend? template struct joint_matrix; diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp index 8e8de92165572..cab158cfc7b68 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp @@ -17,7 +17,7 @@ namespace ext { namespace oneapi { namespace experimental { namespace matrix { -//TODO ifdef this stuff! + template class wi_data { marray &data; @@ -96,7 +96,7 @@ __SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, b, 4, 8, 1) #define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(type, M, N, size) \ template <> \ - struct joint_matrix { \ marray wi_marray; \ inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { \ @@ -150,7 +150,6 @@ struct load_multiplicand { S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, multi_ptr src, size_t stride); }; -#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) template constexpr int get_layout_id(); @@ -168,13 +167,12 @@ get_layout_id() { } #if __cplusplus >= 201703L // if constexpr usage -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) template void load_accumulator_layoutT( sycl::ext::oneapi::experimental::matrix::joint_matrix< T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, NumRows, NumCols, - sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::ext::oneapi::experimental::matrix::layout::dynamic, sycl::sub_group> &res, multi_ptr src, size_t stride) { if constexpr (std::is_same, int32_t>::value) { @@ -219,15 +217,14 @@ void load_accumulator_layoutT( stride, get_layout_id()); } }; -#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +#endif // __cplusplus >= 201703L -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) template void load_accumulator( sycl::ext::oneapi::experimental::matrix::joint_matrix< T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, NumRows, NumCols, - sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::ext::oneapi::experimental::matrix::layout::dynamic, sycl::sub_group> &res, multi_ptr src, size_t stride, sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { @@ -246,9 +243,8 @@ void load_accumulator( assert(false && "Invalid layout specified!"); } } -#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +#if __cplusplus >= 201703L // if constexpr usage template = 201703L #if __cplusplus >= 201703L // if constexpr usage -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) template struct joint_matrix_store_impl { @@ -401,7 +395,7 @@ struct joint_matrix_store_impl { void storeLayoutT( sycl::ext::oneapi::experimental::matrix::joint_matrix< T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, NumRows, NumCols, - sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::ext::oneapi::experimental::matrix::layout::dynamic, sycl::sub_group> &src, multi_ptr dst, size_t stride) { if constexpr (NumRows == 16 && NumCols == 16) { @@ -455,7 +449,7 @@ struct joint_matrix_store_impl { void store(sycl::ext::oneapi::experimental::matrix::joint_matrix< T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, NumRows, NumCols, - sycl::ext::oneapi::experimental::matrix::layout::unused, + sycl::ext::oneapi::experimental::matrix::layout::dynamic, sycl::sub_group> &src, multi_ptr dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { @@ -473,31 +467,28 @@ struct joint_matrix_store_impl { } } }; -#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) #endif // __cplusplus >= 201703L -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) -template struct joint_matrix_mad_impl { void mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< - T3, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, - sycl::ext::oneapi::experimental::matrix::layout::unused, + Td, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::dynamic, sycl::sub_group> &D, sycl::ext::oneapi::experimental::matrix::joint_matrix< - T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, + Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, LayoutA, sycl::sub_group> &A, sycl::ext::oneapi::experimental::matrix::joint_matrix< - T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, + Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, LayoutB, sycl::sub_group> &B, sycl::ext::oneapi::experimental::matrix::joint_matrix< - T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, - sycl::ext::oneapi::experimental::matrix::layout::unused, + Tc, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::dynamic, sycl::sub_group> &C); }; -#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) template @@ -532,12 +523,11 @@ constexpr int get_layout_pair_id< } #if __cplusplus >= 201703L // if constexpr usage -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) -template struct joint_matrix_mad_impl< - T1, T2, T3, M, K, N, LayoutA, LayoutB, + Tm, Tc, Td, M, K, N, LayoutA, LayoutB, typename std::enable_if_t< (LayoutA == sycl::ext::oneapi::experimental::matrix::layout::row_major || @@ -548,37 +538,37 @@ struct joint_matrix_mad_impl< LayoutB == sycl::ext::oneapi::experimental::matrix::layout::col_major)>> { void mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< - T3, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, - sycl::ext::oneapi::experimental::matrix::layout::unused, + Td, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::dynamic, sycl::sub_group> &D, sycl::ext::oneapi::experimental::matrix::joint_matrix< - T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, + Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, LayoutA, sycl::sub_group> &A, sycl::ext::oneapi::experimental::matrix::joint_matrix< - T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, + Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, LayoutB, sycl::sub_group> &B, sycl::ext::oneapi::experimental::matrix::joint_matrix< - T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, - sycl::ext::oneapi::experimental::matrix::layout::unused, + Tc, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::dynamic, sycl::sub_group> &C) { if constexpr (M == 16 && N == 16 && K == 16) { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); auto ptrC = reinterpret_cast(&C.wi_marray); auto ptrD = reinterpret_cast(&D.wi_marray); - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { __imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { __imma_m16n16k16_mma_u8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); - if constexpr (std::is_same::value) { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { __hmma_m16n16k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), @@ -589,8 +579,8 @@ struct joint_matrix_mad_impl< reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } - } else if constexpr (std::is_same::value) { - if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { __hmma_m16n16k16_mma_f32f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), @@ -602,8 +592,8 @@ struct joint_matrix_mad_impl< get_layout_pair_id(), 0); } } - } else if constexpr (std::is_same::value || - std::is_same::value || + std::is_same::value) { __mma_bf16_m16n16k16_mma_f32( reinterpret_cast(&D.wi_marray), @@ -613,34 +603,34 @@ struct joint_matrix_mad_impl< get_layout_pair_id(), 0); } } else if constexpr (M == 8 && N == 32 && K == 16) { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); auto ptrC = reinterpret_cast(&C.wi_marray); auto ptrD = reinterpret_cast(&D.wi_marray); - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { __imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { __imma_m8n32k16_mma_u8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { __hmma_m8n32k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { __hmma_m8n32k16_mma_f16f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } - } else if constexpr (std::is_same::value || - std::is_same::value || + std::is_same::value) { __mma_bf16_m8n32k16_mma_f32( reinterpret_cast(&D.wi_marray), @@ -650,20 +640,20 @@ struct joint_matrix_mad_impl< get_layout_pair_id(), 0); } } else if constexpr (M == 32 && N == 8 && K == 16) { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); auto ptrC = reinterpret_cast(&C.wi_marray); auto ptrD = reinterpret_cast(&D.wi_marray); - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { __imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { __imma_m32n8k16_mma_u8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); } - } else if constexpr (std::is_same::value || - std::is_same::value || + std::is_same::value) { __mma_bf16_m32n8k16_mma_f32( reinterpret_cast(&D.wi_marray), @@ -671,15 +661,15 @@ struct joint_matrix_mad_impl< reinterpret_cast(&B.wi_marray), reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { __hmma_m32n8k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { __hmma_m32n8k16_mma_f16f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), @@ -692,7 +682,7 @@ struct joint_matrix_mad_impl< reinterpret_cast(&B.wi_marray), reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { __dmma_m8n8k4_mma_f64(reinterpret_cast(&D.wi_marray), reinterpret_cast(&A.wi_marray), reinterpret_cast(&B.wi_marray), @@ -701,8 +691,8 @@ struct joint_matrix_mad_impl< } } }; -#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) #endif // __cplusplus >= 201703L +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } // namespace detail } // namespace oneapi diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index f236e12198986..0db8bfb4d6b6f 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -6,11 +6,12 @@ // // ===--------------------------------------------------------------------=== // -//todo is bfloat16 necessary? #pragma once -//#include -//#include +#if defined(CUDA_MATRIX) #include +#else +//#include +#endif // defined(CUDA_MATRIX) namespace sycl { __SYCL_INLINE_VER_NAMESPACE(_V1) { @@ -19,14 +20,13 @@ namespace oneapi { namespace experimental { namespace matrix { - template inline __SYCL_ALWAYS_INLINE void joint_matrix_fill(Group sg, joint_matrix &res, const T2 v) { - // We kept the unused "sg" in joint_matrix_fill to match the other DPC++ + // We kept the dynamic "sg" in joint_matrix_fill to match the other DPC++ // functions std::ignore = sg; #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) @@ -38,14 +38,14 @@ joint_matrix_fill(Group sg, } // TODO Two typenames (S and T) not required in CUDA backend but included for -// Intel backend requirement. +// Intel backend requirement. TODO: check if this is still the case!! template ::value, bool> = true> void joint_matrix_load( Group sg, joint_matrix + sycl::ext::oneapi::experimental::matrix::layout::dynamic, Group> &res, multi_ptr src, size_t stride, sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { @@ -96,7 +96,7 @@ template + sycl::ext::oneapi::experimental::matrix::layout::dynamic, Group> &src, multi_ptr dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { @@ -116,21 +116,23 @@ void joint_matrix_store( #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } -template -void joint_matrix_mad( +template +joint_matrix + joint_matrix_mad( Group sg, - joint_matrix - &D, - joint_matrix &A, - joint_matrix &B, - joint_matrix + joint_matrix &A, + joint_matrix &B, + joint_matrix &C) { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) - sycl::ext::oneapi::detail::joint_matrix_mad_impl + D; + sycl::ext::oneapi::detail::joint_matrix_mad_impl{} .mad(D, A, B, C); #elif defined(__AMDGCN__) From 766fd8c1ba5be4748932f046ebce3bc42bbd0e71 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Mon, 10 Oct 2022 13:32:19 -0700 Subject: [PATCH 10/10] cuda joint_matrix partial specializations in separate file. Signed-off-by: JackAKirk --- .../oneapi/matrix/joint-matrix-cuda-impl.hpp | 141 ++ .../sycl/ext/oneapi/matrix/joint-matrix.hpp | 51 +- .../ext/oneapi/matrix/matrix-tensor-cores.hpp | 1133 ++++++++--------- ...core.hpp => matrix-tensorcores-legacy.hpp} | 0 .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 302 +++-- .../include/sycl/ext/oneapi/matrix/matrix.hpp | 11 +- 6 files changed, 852 insertions(+), 786 deletions(-) create mode 100644 sycl/include/sycl/ext/oneapi/matrix/joint-matrix-cuda-impl.hpp rename sycl/include/sycl/ext/oneapi/matrix/{matrix-tensorcore.hpp => matrix-tensorcores-legacy.hpp} (100%) diff --git a/sycl/include/sycl/ext/oneapi/matrix/joint-matrix-cuda-impl.hpp b/sycl/include/sycl/ext/oneapi/matrix/joint-matrix-cuda-impl.hpp new file mode 100644 index 0000000000000..6d289014e1ada --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/matrix/joint-matrix-cuda-impl.hpp @@ -0,0 +1,141 @@ +// joint-matrix-cuda-impl.hpp - joint_matrix cuda specializations-*- C++ -*---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ===-----------------------------------------------------------------------=== // + +#include + +namespace sycl { +__SYCL_INLINE_VER_NAMESPACE(_V1) { + namespace ext { + namespace oneapi { + namespace experimental { + namespace matrix { + + template class wi_data { + marray &data; + wi_data(marray &wi_data) : data(wi_data){}; + template + friend struct joint_matrix; + + public: + size_t length() { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + return data.size(); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + }; + + type &operator[](size_t i) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + return data[i]; +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + }; + }; + +#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(type, use, M, N, size) \ + template \ + struct joint_matrix< \ + type, matrix_use::use, M, N, Layout, sycl::sub_group, \ + typename std::enable_if_t> { \ + marray wi_marray; \ + inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { \ + return wi_data(wi_marray); \ + }; \ + }; + + // m8n32k16 + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 8, 16, 4) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 32, 16) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 8, 16, 16) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 32, 16) + + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 8, 16, 4) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 32, 16) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 8, 16, 4) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 32, 16) + // m32n8k16 + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 32, 16, 16) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 8, 4) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 32, 16, 16) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 8, 16) + + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 32, 16, 16) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 8, 4) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 32, 16, 16) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 8, 4) + // m16n16k16 + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 16, 16, 8) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 16, 8) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 16, 16, 16) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 16, 16) + + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 16, 16, 8) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 16, 8) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 16, 16, 8) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 16, 8) + // m8n8k4 double only + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, a, 8, 4, 1) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, b, 4, 8, 1) + +#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR + +#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(type, M, N, size) \ + template <> \ + struct joint_matrix { \ + marray wi_marray; \ + inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { \ + return wi_data(wi_marray); \ + }; \ + }; + + __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 8, 32, 8) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 8, 32, 8) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 8, 32, 8) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 32, 8, 8) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 32, 8, 8) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 32, 8, 8) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 16, 16, 8) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 16, 16, 8) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 16, 16, 8) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(double, 8, 8, 2) + +#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC + +#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision, use, M, N, type, \ + size) \ + template \ + struct joint_matrix< \ + precision, matrix_use::use, M, N, Layout, sycl::sub_group, \ + typename std::enable_if_t> { \ + marray wi_marray; \ + inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { \ + return wi_data(wi_marray); \ + }; \ + }; + // m16n16k8 tf32 only + __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision::tf32, a, 16, 8, float, + 4) + __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision::tf32, b, 8, 16, float, + 4) + +#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION + + } // namespace matrix + } // namespace experimental + } // namespace oneapi + } // namespace ext +} // __SYCL_INLINE_VER_NAMESPACE(_V1) +} // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/matrix/joint-matrix.hpp b/sycl/include/sycl/ext/oneapi/matrix/joint-matrix.hpp index a408579709752..454db696d6e28 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/joint-matrix.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/joint-matrix.hpp @@ -1,40 +1,43 @@ -//===---- joint-matrix.hpp - SYCL matrix extension joint_matrix ----*- C++ -*---===// +//===---- joint-matrix.hpp - SYCL matrix extension joint_matrix ----*- C++ +//-*---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // // ===--------------------------------------------------------------------=== // +#ifndef JOINT_MATRIX +#define JOINT_MATRIX namespace sycl { __SYCL_INLINE_VER_NAMESPACE(_V1) { -namespace ext { -namespace oneapi { -namespace experimental { -namespace matrix { + namespace ext { + namespace oneapi { + namespace experimental { + namespace matrix { -enum class matrix_use { a, b, accumulator }; + enum class matrix_use { a, b, accumulator }; -enum class layout { row_major, col_major, packed, dynamic }; + enum class layout { row_major, col_major, packed, dynamic }; -namespace precision { -class tf32 { - tf32() = delete; -}; -} // namespace precision + namespace precision { + class tf32 { + tf32() = delete; + }; + } // namespace precision -//TODO forward declare jm or?? + // TODO: how are the default params for Rows/Cols used in Intel backend? + // TODO: could we use Cond to distinguish between Intel AMX and tensor cores + // joint_matrix definitions: e.g. Rows == dynamic_extent? + template + struct joint_matrix; -// TODO: how are the default params for Rows/Cols used in Intel backend? -template -struct joint_matrix; - -} // namespace matrix -} // namespace experimental -} // namespace oneapi -} // namespace ext + } // namespace matrix + } // namespace experimental + } // namespace oneapi + } // namespace ext } // __SYCL_INLINE_VER_NAMESPACE(_V1) } // namespace sycl +#endif diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp index cab158cfc7b68..80d0739189eb4 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp @@ -13,355 +13,243 @@ namespace sycl { __SYCL_INLINE_VER_NAMESPACE(_V1) { -namespace ext { -namespace oneapi { -namespace experimental { -namespace matrix { + namespace ext { + namespace oneapi { + namespace detail { -template class wi_data { - marray &data; - wi_data(marray &wi_data) : data(wi_data){}; - template - friend struct joint_matrix; - -public: - size_t length() { -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - return data.size(); -#else - throw runtime_error("joint matrix is not supported on host device.", - PI_ERROR_INVALID_DEVICE); -#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - }; - - type &operator[](size_t i) { #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - return data[i]; -#else - throw runtime_error("joint matrix is not supported on host device.", - PI_ERROR_INVALID_DEVICE); -#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - }; -}; - -#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(type, use, M, N, size) \ - template \ - struct joint_matrix< \ - type, matrix_use::use, M, N, Layout, sycl::sub_group, \ - typename std::enable_if_t> { \ - marray wi_marray; \ - inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { \ - return wi_data(wi_marray); \ - }; \ - }; - -// m8n32k16 -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 8, 16, 4) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 32, 16) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 8, 16, 16) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 32, 16) - -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 8, 16, 4) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 32, 16) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 8, 16, 4) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 32, 16) -// m32n8k16 -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 32, 16, 16) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 8, 4) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 32, 16, 16) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 8, 16) - -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 32, 16, 16) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 8, 4) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 32, 16, 16) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 8, 4) -// m16n16k16 -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 16, 16, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 16, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 16, 16, 16) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 16, 16) - -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 16, 16, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 16, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 16, 16, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 16, 8) -// m8n8k4 double only -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, a, 8, 4, 1) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, b, 4, 8, 1) - -#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR - -#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(type, M, N, size) \ - template <> \ - struct joint_matrix { \ - marray wi_marray; \ - inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { \ - return wi_data(wi_marray); \ - }; \ - }; - -__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 8, 32, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 8, 32, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 8, 32, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 32, 8, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 32, 8, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 32, 8, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 16, 16, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 16, 16, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 16, 16, 8) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(double, 8, 8, 2) - -#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC - -#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision, use, M, N, type, \ - size) \ - template \ - struct joint_matrix< \ - precision, matrix_use::use, M, N, Layout, sycl::sub_group, \ - typename std::enable_if_t> { \ - marray wi_marray; \ - inline __SYCL_ALWAYS_INLINE wi_data get_wi_data() { \ -return wi_data(wi_marray); \ - }; \ + template + struct load_multiplicand_cuda { + void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< + S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, + multi_ptr src, size_t stride); }; -// m16n16k8 tf32 only -__SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision::tf32, a, 16, 8, float, 4) -__SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision::tf32, b, 8, 16, float, 4) -#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION -} // namespace matrix -} // namespace experimental + template + constexpr int get_layout_id(); - -namespace detail { - -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) -template -struct load_multiplicand { - void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< - S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, - multi_ptr src, size_t stride); -}; - -template -constexpr int get_layout_id(); - -template <> -constexpr int -get_layout_id() { - return 0; -} - -template <> -constexpr int -get_layout_id() { - return 1; -} - -#if __cplusplus >= 201703L // if constexpr usage -template -void load_accumulator_layoutT( - sycl::ext::oneapi::experimental::matrix::joint_matrix< - T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, NumRows, NumCols, - sycl::ext::oneapi::experimental::matrix::layout::dynamic, - sycl::sub_group> &res, - multi_ptr src, size_t stride) { - if constexpr (std::is_same, int32_t>::value) { - auto destptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 16) { - __imma_m16n16k16_ld_c(destptr, src.get(), stride, - get_layout_id()); - } else if constexpr (NumRows == 8 && NumCols == 32) { - __imma_m8n32k16_ld_c(destptr, src.get(), stride, - get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 8) { - __imma_m32n8k16_ld_c(destptr, src.get(), stride, - get_layout_id()); - } - } else if constexpr (std::is_same, float>::value) { - auto dstptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 16) { - __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride, - get_layout_id()); - } else if constexpr (NumRows == 8 && NumCols == 32) { - __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride, - get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 8) { - __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride, - get_layout_id()); - } - } else if constexpr (std::is_same::value) { - auto tileptr = reinterpret_cast(src.get()); - auto dstptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 32 && NumCols == 8) { - __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride, - get_layout_id()); - } else if constexpr (NumRows == 8 && NumCols == 32) { - __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride, - get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 16) { - __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride, - get_layout_id()); - } - } else if constexpr (std::is_same::value) { - __dmma_m8n8k4_ld_c(reinterpret_cast(&res.wi_marray), src.get(), - stride, get_layout_id()); + template <> + constexpr int + get_layout_id() { + return 0; } -}; -#endif // __cplusplus >= 201703L -template -void load_accumulator( - sycl::ext::oneapi::experimental::matrix::joint_matrix< - T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, NumRows, NumCols, - sycl::ext::oneapi::experimental::matrix::layout::dynamic, - sycl::sub_group> &res, - multi_ptr src, size_t stride, - sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { - switch (LayoutAcc) { - case sycl::ext::oneapi::experimental::matrix::layout::row_major: - load_accumulator_layoutT< - sycl::ext::oneapi::experimental::matrix::layout::row_major>(res, src, - stride); - break; - case sycl::ext::oneapi::experimental::matrix::layout::col_major: - load_accumulator_layoutT< - sycl::ext::oneapi::experimental::matrix::layout::col_major>(res, src, - stride); - break; - default: - assert(false && "Invalid layout specified!"); + template <> + constexpr int + get_layout_id() { + return 1; } -} #if __cplusplus >= 201703L // if constexpr usage -template -struct load_multiplicand< - S, T, NumRows, NumCols, Use, Layout, Space, - typename std::enable_if_t< - Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major || - Layout == sycl::ext::oneapi::experimental::matrix::layout::col_major>> { - void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< - S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, - multi_ptr src, size_t stride) { - if constexpr (std::is_same, uint16_t>::value || - std::is_same< - std::remove_const_t, - sycl::ext::oneapi::experimental::bfloat16>::value) { - auto tileptr = reinterpret_cast(src.get()); + template + void load_accumulator_layoutT( + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + NumRows, NumCols, + sycl::ext::oneapi::experimental::matrix::layout::dynamic, + sycl::sub_group> &res, + multi_ptr src, size_t stride) { + if constexpr (std::is_same, int32_t>::value) { auto destptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { - if constexpr (Use == - sycl::ext::oneapi::experimental::matrix::matrix_use::a) { - __mma_bf16_m16n16k16_ld_a(destptr, tileptr, stride, - get_layout_id()); - } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: - matrix_use::b) { - __mma_bf16_m16n16k16_ld_b(destptr, tileptr, stride, - get_layout_id()); - } - } else if constexpr (NumRows == 8 && NumCols == 16) { - __mma_bf16_m8n32k16_ld_a(destptr, tileptr, stride, - get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 32) { - __mma_bf16_m8n32k16_ld_b(destptr, tileptr, stride, - get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 16) { - __mma_bf16_m32n8k16_ld_a(destptr, tileptr, stride, - get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 8) { - __mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride, - get_layout_id()); + __imma_m16n16k16_ld_c(destptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __imma_m8n32k16_ld_c(destptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 8) { + __imma_m32n8k16_ld_c(destptr, src.get(), stride, + get_layout_id()); } - } else if constexpr (std::is_same, uint8_t>::value) { - auto tileptr = reinterpret_cast(src.get()); - auto destptr = reinterpret_cast(&res.wi_marray); + } else if constexpr (std::is_same, float>::value) { + auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { - if constexpr (Use == - sycl::ext::oneapi::experimental::matrix::matrix_use::a) { - __imma_m16n16k16_ld_a_u8(destptr, tileptr, stride, + __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 8) { + __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride, + get_layout_id()); + } + } else if constexpr (std::is_same::value) { + auto tileptr = reinterpret_cast(src.get()); + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 32 && NumCols == 8) { + __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 16) { + __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same::value) { + __dmma_m8n8k4_ld_c(reinterpret_cast(&res.wi_marray), src.get(), + stride, get_layout_id()); + } + }; +#endif // __cplusplus >= 201703L + + template + void load_accumulator_cuda( + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + NumRows, NumCols, + sycl::ext::oneapi::experimental::matrix::layout::dynamic, + sycl::sub_group> &res, + multi_ptr src, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { + switch (LayoutAcc) { + case sycl::ext::oneapi::experimental::matrix::layout::row_major: + load_accumulator_layoutT< + sycl::ext::oneapi::experimental::matrix::layout::row_major>(res, src, + stride); + break; + case sycl::ext::oneapi::experimental::matrix::layout::col_major: + load_accumulator_layoutT< + sycl::ext::oneapi::experimental::matrix::layout::col_major>(res, src, + stride); + break; + default: + assert(false && "Invalid layout specified!"); + } + } + +#if __cplusplus >= 201703L // if constexpr usage + template + struct load_multiplicand_cuda< + S, T, NumRows, NumCols, Use, Layout, Space, + typename std::enable_if_t< + Layout == + sycl::ext::oneapi::experimental::matrix::layout::row_major || + Layout == + sycl::ext::oneapi::experimental::matrix::layout::col_major>> { + void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< + S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, + multi_ptr src, size_t stride) { + if constexpr (std::is_same, uint16_t>::value || + std::is_same< + std::remove_const_t, + sycl::ext::oneapi::experimental::bfloat16>::value) { + auto tileptr = reinterpret_cast(src.get()); + auto destptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::a) { + __mma_bf16_m16n16k16_ld_a(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::b) { + __mma_bf16_m16n16k16_ld_b(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __mma_bf16_m8n32k16_ld_a(destptr, tileptr, stride, get_layout_id()); - } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: - matrix_use::b) { - __imma_m16n16k16_ld_b_u8(destptr, tileptr, stride, + } else if constexpr (NumRows == 16 && NumCols == 32) { + __mma_bf16_m8n32k16_ld_b(destptr, tileptr, stride, get_layout_id()); - } - } else if constexpr (NumRows == 8 && NumCols == 16) { - __imma_m8n32k16_ld_a_u8(destptr, tileptr, stride, - get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 32) { - __imma_m8n32k16_ld_b_u8(destptr, tileptr, stride, - get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 16) { - __imma_m32n8k16_ld_a_u8(destptr, tileptr, stride, - get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 8) { - __imma_m32n8k16_ld_b_u8(destptr, tileptr, stride, - get_layout_id()); - } - } else if constexpr (std::is_same, int8_t>::value) { - auto tileptr = reinterpret_cast(src.get()); - auto destptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 16) { - if constexpr (Use == - sycl::ext::oneapi::experimental::matrix::matrix_use::a) { - __imma_m16n16k16_ld_a_s8(destptr, tileptr, stride, + } else if constexpr (NumRows == 32 && NumCols == 16) { + __mma_bf16_m32n8k16_ld_a(destptr, tileptr, stride, get_layout_id()); - } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: - matrix_use::b) { - __imma_m16n16k16_ld_b_s8(destptr, tileptr, stride, + } else if constexpr (NumRows == 16 && NumCols == 8) { + __mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (NumRows == 8 && NumCols == 16) { - __imma_m8n32k16_ld_a_s8(destptr, tileptr, stride, - get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 32) { - __imma_m8n32k16_ld_b_s8(destptr, tileptr, stride, - get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 16) { - __imma_m32n8k16_ld_a_s8(destptr, tileptr, stride, - get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 8) { - __imma_m32n8k16_ld_b_s8(destptr, tileptr, stride, - get_layout_id()); - } - } else if constexpr (std::is_same, half>::value) { - auto tileptr = reinterpret_cast(src.get()); - auto dstptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 16) { - if constexpr (Use == - sycl::ext::oneapi::experimental::matrix::matrix_use::a) { - __hmma_m16n16k16_ld_a(dstptr, tileptr, stride, - get_layout_id()); - } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: - matrix_use::b) { - __hmma_m16n16k16_ld_b(dstptr, tileptr, stride, - get_layout_id()); + } else if constexpr (std::is_same, + uint8_t>::value) { + auto tileptr = reinterpret_cast(src.get()); + auto destptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::a) { + __imma_m16n16k16_ld_a_u8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::b) { + __imma_m16n16k16_ld_b_u8(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __imma_m8n32k16_ld_a_u8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 32) { + __imma_m8n32k16_ld_b_u8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 16) { + __imma_m32n8k16_ld_a_u8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 8) { + __imma_m32n8k16_ld_b_u8(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same, + int8_t>::value) { + auto tileptr = reinterpret_cast(src.get()); + auto destptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::a) { + __imma_m16n16k16_ld_a_s8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::b) { + __imma_m16n16k16_ld_b_s8(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __imma_m8n32k16_ld_a_s8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 32) { + __imma_m8n32k16_ld_b_s8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 16) { + __imma_m32n8k16_ld_a_s8(destptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 8) { + __imma_m32n8k16_ld_b_s8(destptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same, half>::value) { + auto tileptr = reinterpret_cast(src.get()); + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::a) { + __hmma_m16n16k16_ld_a(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::b) { + __hmma_m16n16k16_ld_b(dstptr, tileptr, stride, + get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __hmma_m8n32k16_ld_a(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 32) { + __hmma_m8n32k16_ld_b(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 16) { + __hmma_m32n8k16_ld_a(dstptr, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 8) { + __hmma_m32n8k16_ld_b(dstptr, tileptr, stride, + get_layout_id()); } - } else if constexpr (NumRows == 8 && NumCols == 16) { - __hmma_m8n32k16_ld_a(dstptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 32) { - __hmma_m8n32k16_ld_b(dstptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 16) { - __hmma_m32n8k16_ld_a(dstptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 8) { - __hmma_m32n8k16_ld_b(dstptr, tileptr, stride, get_layout_id()); - } - } else if constexpr (std::is_same::value) { auto tileptr = reinterpret_cast(src.get()); @@ -373,330 +261,345 @@ struct load_multiplicand< __mma_tf32_m16n16k8_ld_b(dstptr, tileptr, stride, get_layout_id()); } - } else if constexpr (std::is_same, double>::value) { - auto dstptr = reinterpret_cast(&res.wi_marray); - if constexpr (Use == - sycl::ext::oneapi::experimental::matrix::matrix_use::a) { - __dmma_m8n8k4_ld_a(dstptr, src.get(), stride, get_layout_id()); - } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: - matrix_use::b) { - __dmma_m8n8k4_ld_b(dstptr, src.get(), stride, get_layout_id()); + } else if constexpr (std::is_same, + double>::value) { + auto dstptr = reinterpret_cast(&res.wi_marray); + if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::matrix_use::a) { + __dmma_m8n8k4_ld_a(dstptr, src.get(), stride, + get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::b) { + __dmma_m8n8k4_ld_b(dstptr, src.get(), stride, + get_layout_id()); + } } } - } -}; + }; #endif // __cplusplus >= 201703L #if __cplusplus >= 201703L // if constexpr usage -template -struct joint_matrix_store_impl { - template - void storeLayoutT( - sycl::ext::oneapi::experimental::matrix::joint_matrix< - T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, NumRows, NumCols, - sycl::ext::oneapi::experimental::matrix::layout::dynamic, - sycl::sub_group> &src, - multi_ptr dst, size_t stride) { - if constexpr (NumRows == 16 && NumCols == 16) { - if constexpr (std::is_same::value) { - __hmma_m16n16k16_st_c_f32(dst.get(), - reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); - } else if constexpr (std::is_same::value) { - __imma_m16n16k16_st_c_i32(dst.get(), - reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); - } else if constexpr (std::is_same::value) { - __hmma_m16n16k16_st_c_f16(reinterpret_cast(dst.get()), - reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); - } - } else if constexpr (NumRows == 8 && NumCols == 32) { - if constexpr (std::is_same::value) { - __hmma_m8n32k16_st_c_f32(dst.get(), - reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); - } else if constexpr (std::is_same::value) { - __imma_m8n32k16_st_c_i32(dst.get(), - reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); - } else if constexpr (std::is_same::value) { - __hmma_m8n32k16_st_c_f16(reinterpret_cast(dst.get()), - reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); - } - } else if constexpr (NumRows == 32 && NumCols == 8) { - if constexpr (std::is_same::value) { - __hmma_m32n8k16_st_c_f32(dst.get(), - reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); - } else if constexpr (std::is_same::value) { - __imma_m32n8k16_st_c_i32(dst.get(), - reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); - } else if constexpr (std::is_same::value) { - __hmma_m32n8k16_st_c_f16(reinterpret_cast(dst.get()), - reinterpret_cast(&src.wi_marray), - stride, get_layout_id()); + template + struct joint_matrix_store_cuda_impl { + template + void storeLayoutT( + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + NumRows, NumCols, + sycl::ext::oneapi::experimental::matrix::layout::dynamic, + sycl::sub_group> &src, + multi_ptr dst, size_t stride) { + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (std::is_same::value) { + __hmma_m16n16k16_st_c_f32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same::value) { + __imma_m16n16k16_st_c_i32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same::value) { + __hmma_m16n16k16_st_c_f16(reinterpret_cast(dst.get()), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 32) { + if constexpr (std::is_same::value) { + __hmma_m8n32k16_st_c_f32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same::value) { + __imma_m8n32k16_st_c_i32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same::value) { + __hmma_m8n32k16_st_c_f16(reinterpret_cast(dst.get()), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } + } else if constexpr (NumRows == 32 && NumCols == 8) { + if constexpr (std::is_same::value) { + __hmma_m32n8k16_st_c_f32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same::value) { + __imma_m32n8k16_st_c_i32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } else if constexpr (std::is_same::value) { + __hmma_m32n8k16_st_c_f16(reinterpret_cast(dst.get()), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); + } + } else if constexpr (std::is_same::value) { + __dmma_m8n8k4_st_c_f64(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { - __dmma_m8n8k4_st_c_f64(dst.get(), - reinterpret_cast(&src.wi_marray), stride, - get_layout_id()); } - } - void - store(sycl::ext::oneapi::experimental::matrix::joint_matrix< - T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, NumRows, NumCols, + void store( + sycl::ext::oneapi::experimental::matrix::joint_matrix< + T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic, sycl::sub_group> &src, multi_ptr dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { - switch (LayoutAcc) { - case sycl::ext::oneapi::experimental::matrix::layout::row_major: - storeLayoutT( - src, dst, stride); - break; - case sycl::ext::oneapi::experimental::matrix::layout::col_major: - storeLayoutT( - src, dst, stride); - break; - default: - assert(false && "Invalid layout specified!"); + switch (LayoutAcc) { + case sycl::ext::oneapi::experimental::matrix::layout::row_major: + storeLayoutT< + sycl::ext::oneapi::experimental::matrix::layout::row_major>( + src, dst, stride); + break; + case sycl::ext::oneapi::experimental::matrix::layout::col_major: + storeLayoutT< + sycl::ext::oneapi::experimental::matrix::layout::col_major>( + src, dst, stride); + break; + default: + assert(false && "Invalid layout specified!"); + } } - } -}; + }; #endif // __cplusplus >= 201703L -template -struct joint_matrix_mad_impl { - void mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< - Td, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, - sycl::ext::oneapi::experimental::matrix::layout::dynamic, - sycl::sub_group> &D, - sycl::ext::oneapi::experimental::matrix::joint_matrix< - Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, - LayoutA, sycl::sub_group> &A, - sycl::ext::oneapi::experimental::matrix::joint_matrix< - Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, - LayoutB, sycl::sub_group> &B, - sycl::ext::oneapi::experimental::matrix::joint_matrix< - Tc, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, - sycl::ext::oneapi::experimental::matrix::layout::dynamic, - sycl::sub_group> &C); -}; + template + struct joint_matrix_mad_cuda_impl { + void + mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< + Td, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, + N, sycl::ext::oneapi::experimental::matrix::layout::dynamic, + sycl::sub_group> &D, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, + LayoutA, sycl::sub_group> &A, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, + LayoutB, sycl::sub_group> &B, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + Tc, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, + N, sycl::ext::oneapi::experimental::matrix::layout::dynamic, + sycl::sub_group> &C); + }; -template -constexpr int get_layout_pair_id(); + template + constexpr int get_layout_pair_id(); -template <> -constexpr int get_layout_pair_id< - sycl::ext::oneapi::experimental::matrix::layout::row_major, - sycl::ext::oneapi::experimental::matrix::layout::row_major>() { - return 0; -} + template <> + constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::layout::row_major, + sycl::ext::oneapi::experimental::matrix::layout::row_major>() { + return 0; + } -template <> -constexpr int get_layout_pair_id< - sycl::ext::oneapi::experimental::matrix::layout::row_major, - sycl::ext::oneapi::experimental::matrix::layout::col_major>() { - return 1; -} + template <> + constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::layout::row_major, + sycl::ext::oneapi::experimental::matrix::layout::col_major>() { + return 1; + } -template <> -constexpr int get_layout_pair_id< - sycl::ext::oneapi::experimental::matrix::layout::col_major, - sycl::ext::oneapi::experimental::matrix::layout::row_major>() { - return 2; -} + template <> + constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::layout::col_major, + sycl::ext::oneapi::experimental::matrix::layout::row_major>() { + return 2; + } -template <> -constexpr int get_layout_pair_id< - sycl::ext::oneapi::experimental::matrix::layout::col_major, - sycl::ext::oneapi::experimental::matrix::layout::col_major>() { - return 3; -} + template <> + constexpr int get_layout_pair_id< + sycl::ext::oneapi::experimental::matrix::layout::col_major, + sycl::ext::oneapi::experimental::matrix::layout::col_major>() { + return 3; + } #if __cplusplus >= 201703L // if constexpr usage -template -struct joint_matrix_mad_impl< - Tm, Tc, Td, M, K, N, LayoutA, LayoutB, - typename std::enable_if_t< - (LayoutA == - sycl::ext::oneapi::experimental::matrix::layout::row_major || - LayoutA == - sycl::ext::oneapi::experimental::matrix::layout::col_major) && - (LayoutB == - sycl::ext::oneapi::experimental::matrix::layout::row_major || - LayoutB == - sycl::ext::oneapi::experimental::matrix::layout::col_major)>> { - void mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< - Td, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, - sycl::ext::oneapi::experimental::matrix::layout::dynamic, - sycl::sub_group> &D, - sycl::ext::oneapi::experimental::matrix::joint_matrix< - Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, - LayoutA, sycl::sub_group> &A, - sycl::ext::oneapi::experimental::matrix::joint_matrix< - Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, - LayoutB, sycl::sub_group> &B, - sycl::ext::oneapi::experimental::matrix::joint_matrix< - Tc, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, - sycl::ext::oneapi::experimental::matrix::layout::dynamic, - sycl::sub_group> &C) { - if constexpr (M == 16 && N == 16 && K == 16) { - if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); - auto ptrC = reinterpret_cast(&C.wi_marray); - auto ptrD = reinterpret_cast(&D.wi_marray); - if constexpr (std::is_same::value) { - __imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC, - get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { - __imma_m16n16k16_mma_u8(ptrD, ptrA, ptrB, ptrC, - get_layout_pair_id(), 0); - } - } else if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); - if constexpr (std::is_same::value) { - if constexpr (std::is_same::value) { - __hmma_m16n16k16_mma_f32f32( - reinterpret_cast(&D.wi_marray), ptrA, ptrB, + template + struct joint_matrix_mad_cuda_impl< + Tm, Tc, Td, M, K, N, LayoutA, LayoutB, + typename std::enable_if_t< + (LayoutA == + sycl::ext::oneapi::experimental::matrix::layout::row_major || + LayoutA == + sycl::ext::oneapi::experimental::matrix::layout::col_major) && + (LayoutB == + sycl::ext::oneapi::experimental::matrix::layout::row_major || + LayoutB == + sycl::ext::oneapi::experimental::matrix::layout::col_major)>> { + void + mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< + Td, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, + N, sycl::ext::oneapi::experimental::matrix::layout::dynamic, + sycl::sub_group> &D, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, + LayoutA, sycl::sub_group> &A, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, + LayoutB, sycl::sub_group> &B, + sycl::ext::oneapi::experimental::matrix::joint_matrix< + Tc, + sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, + N, sycl::ext::oneapi::experimental::matrix::layout::dynamic, + sycl::sub_group> &C) { + if constexpr (M == 16 && N == 16 && K == 16) { + if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrD = reinterpret_cast(&D.wi_marray); + if constexpr (std::is_same::value) { + __imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __imma_m16n16k16_mma_u8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } + } else if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { + __hmma_m16n16k16_mma_f32f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else { + __hmma_m16n16k16_mma_f16f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } + } else if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { + __hmma_m16n16k16_mma_f32f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else { + __hmma_m16n16k16_mma_f16f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } + } + } else if constexpr (std::is_same::value || + std::is_same::value) { + __mma_bf16_m16n16k16_mma_f32( + reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else { - __hmma_m16n16k16_mma_f16f32( - reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), - get_layout_pair_id(), 0); + } + } else if constexpr (M == 8 && N == 32 && K == 16) { + if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrD = reinterpret_cast(&D.wi_marray); + if constexpr (std::is_same::value) { + __imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __imma_m8n32k16_mma_u8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); } - } else if constexpr (std::is_same::value) { - if constexpr (std::is_same::value) { - __hmma_m16n16k16_mma_f32f16( - reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), - get_layout_pair_id(), 0); - } else { - __hmma_m16n16k16_mma_f16f16( - reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), - get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + if constexpr (std::is_same::value) { + __hmma_m8n32k16_mma_f32f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __hmma_m8n32k16_mma_f16f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); } - } - } else if constexpr (std::is_same::value || - std::is_same::value) { - __mma_bf16_m16n16k16_mma_f32( - reinterpret_cast(&D.wi_marray), - reinterpret_cast(&A.wi_marray), - reinterpret_cast(&B.wi_marray), - reinterpret_cast(&C.wi_marray), - get_layout_pair_id(), 0); - } - } else if constexpr (M == 8 && N == 32 && K == 16) { - if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); - auto ptrC = reinterpret_cast(&C.wi_marray); - auto ptrD = reinterpret_cast(&D.wi_marray); - if constexpr (std::is_same::value) { - __imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC, - get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { - __imma_m8n32k16_mma_u8(ptrD, ptrA, ptrB, ptrC, - get_layout_pair_id(), 0); - } - } else if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); - if constexpr (std::is_same::value) { - __hmma_m8n32k16_mma_f32f32( - reinterpret_cast(&D.wi_marray), ptrA, ptrB, + } else if constexpr (std::is_same::value || + std::is_same::value) { + __mma_bf16_m8n32k16_mma_f32( + reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { - __hmma_m8n32k16_mma_f16f16( - reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), - get_layout_pair_id(), 0); } - } else if constexpr (std::is_same::value || - std::is_same::value) { - __mma_bf16_m8n32k16_mma_f32( - reinterpret_cast(&D.wi_marray), - reinterpret_cast(&A.wi_marray), - reinterpret_cast(&B.wi_marray), - reinterpret_cast(&C.wi_marray), - get_layout_pair_id(), 0); - } - } else if constexpr (M == 32 && N == 8 && K == 16) { - if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); - auto ptrC = reinterpret_cast(&C.wi_marray); - auto ptrD = reinterpret_cast(&D.wi_marray); - if constexpr (std::is_same::value) { - __imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC, - get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { - __imma_m32n8k16_mma_u8(ptrD, ptrA, ptrB, ptrC, - get_layout_pair_id(), 0); - } - } else if constexpr (std::is_same::value || - std::is_same::value) { - __mma_bf16_m32n8k16_mma_f32( - reinterpret_cast(&D.wi_marray), - reinterpret_cast(&A.wi_marray), - reinterpret_cast(&B.wi_marray), - reinterpret_cast(&C.wi_marray), - get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { - auto ptrA = reinterpret_cast(&A.wi_marray); - auto ptrB = reinterpret_cast(&B.wi_marray); - if constexpr (std::is_same::value) { - __hmma_m32n8k16_mma_f32f32( - reinterpret_cast(&D.wi_marray), ptrA, ptrB, + } else if constexpr (M == 32 && N == 8 && K == 16) { + if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrD = reinterpret_cast(&D.wi_marray); + if constexpr (std::is_same::value) { + __imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __imma_m32n8k16_mma_u8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } + } else if constexpr (std::is_same::value || + std::is_same::value) { + __mma_bf16_m32n8k16_mma_f32( + reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { - __hmma_m32n8k16_mma_f16f16( - reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), - get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + if constexpr (std::is_same::value) { + __hmma_m32n8k16_mma_f32f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __hmma_m32n8k16_mma_f16f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } } + } else if constexpr (M == 16 && N == 16 && K == 8) { + __mma_tf32_m16n16k8_mma_f32(reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __dmma_m8n8k4_mma_f64(reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); } - } else if constexpr (M == 16 && N == 16 && K == 8) { - __mma_tf32_m16n16k8_mma_f32(reinterpret_cast(&D.wi_marray), - reinterpret_cast(&A.wi_marray), - reinterpret_cast(&B.wi_marray), - reinterpret_cast(&C.wi_marray), - get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { - __dmma_m8n8k4_mma_f64(reinterpret_cast(&D.wi_marray), - reinterpret_cast(&A.wi_marray), - reinterpret_cast(&B.wi_marray), - reinterpret_cast(&C.wi_marray), - get_layout_pair_id(), 0); } - } -}; + }; #endif // __cplusplus >= 201703L #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) -} // namespace detail -} // namespace oneapi -} // namespace ext + } // namespace detail + } // namespace oneapi + } // namespace ext } // __SYCL_INLINE_VER_NAMESPACE(_V1) } // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores-legacy.hpp similarity index 100% rename from sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp rename to sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores-legacy.hpp diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 0db8bfb4d6b6f..1310cbf92ca2d 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -1,4 +1,4 @@ -//===---- matrix-unified.hpp - SYCL matrix extension ----*- C++ -*---===// +//===------- matrix-unified.hpp - SYCL matrix extension ----*- C++ -*------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,169 +7,191 @@ // ===--------------------------------------------------------------------=== // #pragma once -#if defined(CUDA_MATRIX) +#include #include -#else -//#include -#endif // defined(CUDA_MATRIX) +// #include namespace sycl { __SYCL_INLINE_VER_NAMESPACE(_V1) { -namespace ext { -namespace oneapi { -namespace experimental { -namespace matrix { + namespace ext { + namespace oneapi { + namespace experimental { + namespace matrix { -template -inline __SYCL_ALWAYS_INLINE void -joint_matrix_fill(Group sg, - joint_matrix &res, - const T2 v) { - // We kept the dynamic "sg" in joint_matrix_fill to match the other DPC++ - // functions - std::ignore = sg; -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - res.wi_marray = v; + template + inline __SYCL_ALWAYS_INLINE void + joint_matrix_fill(Group sg, + joint_matrix &res, + const T2 v) { + // We kept the dynamic "sg" in joint_matrix_fill to match the other DPC++ + // functions + std::ignore = sg; +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + res.wi_marray = v; +#elif defined(__SPIR__) +// intel joint_matrix_fill_intel_impl +#endif // defined(__NVPTX__) #else - std::ignore = res; - std::ignore = v; -#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) -} + std::ignore = res; + std::ignore = v; + throw runtime_error("The matrix extension is only currently supported on " + "Intel and Nvidia devices", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) + } -// TODO Two typenames (S and T) not required in CUDA backend but included for -// Intel backend requirement. TODO: check if this is still the case!! -template ::value, bool> = true> -void joint_matrix_load( - Group sg, - joint_matrix - &res, - multi_ptr src, size_t stride, - sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - sycl::ext::oneapi::detail::load_accumulator(res, src, stride, LayoutAcc); + // TODO Two typenames (S and T) not required in CUDA backend but included for + // Intel backend requirement. TODO: check if this is still the case!! + template ::value, bool> = true> + void joint_matrix_load( + Group sg, + joint_matrix &res, + multi_ptr src, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + sycl::ext::oneapi::detail::load_accumulator_cuda(res, src, stride, + LayoutAcc); +#elif defined(__SPIR__) +// load_accumulator_intel +#endif // defined(__NVPTX__) #else - std::ignore = sg; - std::ignore = res; - std::ignore = src; - std::ignore = stride; - throw runtime_error( - "When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_load is " - "only supported by CUDA devices", - PI_ERROR_INVALID_DEVICE); -#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) -} - + std::ignore = sg; + std::ignore = res; + std::ignore = src; + std::ignore = stride; + throw runtime_error("The matrix extension is only currently supported on " + "Intel and Nvidia devices", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) + } -template < - typename Group, typename S, typename T, matrix_use Use, size_t NumRows, - size_t NumCols, matrix::layout Layout, access::address_space Space, - std::enable_if_t>::value || - (std::is_same::value && + template < + typename Group, typename S, typename T, matrix_use Use, size_t NumRows, + size_t NumCols, matrix::layout Layout, access::address_space Space, + std::enable_if_t>::value || + (std::is_same::value && - std::is_same, float>::value), - bool> = true> -void joint_matrix_load( - Group sg, joint_matrix &res, - multi_ptr src, size_t stride) { -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - sycl::ext::oneapi::detail::load_multiplicand{} - .load(res, src, stride); + std::is_same, float>::value), + bool> = true> + void + joint_matrix_load(Group sg, + joint_matrix &res, + multi_ptr src, size_t stride) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + sycl::ext::oneapi::detail::load_multiplicand_cuda{} + .load(res, src, stride); +#elif defined(__SPIR__) +// load_multiplicand_intel +#endif // defined(__NVPTX__) #else - std::ignore = sg; - std::ignore = res; - std::ignore = src; - std::ignore = stride; - throw runtime_error( - "When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_load is " - "only supported by CUDA devices", - PI_ERROR_INVALID_DEVICE); -#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) -} + std::ignore = sg; + std::ignore = res; + std::ignore = src; + std::ignore = stride; + throw runtime_error("The matrix extension is only currently supported on " + "Intel and Nvidia devices", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) + } -template -void joint_matrix_store( - Group sg, - joint_matrix - &src, - multi_ptr dst, size_t stride, - sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - sycl::ext::oneapi::detail::joint_matrix_store_impl{} - .store(src, dst, stride, LayoutAcc); + template + void joint_matrix_store( + Group sg, + joint_matrix &src, + multi_ptr dst, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + sycl::ext::oneapi::detail::joint_matrix_store_cuda_impl{} + .store(src, dst, stride, LayoutAcc); +#elif defined(__SPIR__) +// joint_matrix_store_intel_impl +#endif // defined(__NVPTX__) #else - std::ignore = sg; - std::ignore = src; - std::ignore = dst; - std::ignore = stride; - throw runtime_error( - "When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_store is " - "only supported by CUDA devices", - PI_ERROR_INVALID_DEVICE); -#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) -} + std::ignore = sg; + std::ignore = src; + std::ignore = dst; + std::ignore = stride; + throw runtime_error("The matrix extension is only currently supported on " + "Intel and Nvidia devices", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) + } -template -joint_matrix - joint_matrix_mad( - Group sg, - joint_matrix &A, - joint_matrix &B, - joint_matrix - &C) { -#if defined(__SYCL_DEVICE_ONLY__) -#if defined(__NVPTX__) + template < + typename Group, typename Ta, typename Tb, typename Tc, std::size_t M, + std::size_t K, std::size_t N, + layout LayoutA = sycl::ext::oneapi::experimental::matrix::layout::dynamic, + layout LayoutB = sycl::ext::oneapi::experimental::matrix::layout::dynamic> joint_matrix - D; - sycl::ext::oneapi::detail::joint_matrix_mad_impl{} - .mad(D, A, B, C); -#elif defined(__AMDGCN__) -//rocM wmma joint_matrix_mad_impl + sycl::ext::oneapi::experimental::matrix::layout::dynamic, Group> + joint_matrix_mad( + Group sg, joint_matrix &A, + joint_matrix &B, + joint_matrix &C) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + if constexpr (std::is_same::value) { + joint_matrix + D; + sycl::ext::oneapi::detail::joint_matrix_mad_cuda_impl{} + .mad(D, A, B, C); + return D; + } else { + assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad " + "requires that joint_matrix data types match"); + } #elif defined(__SPIR__) -//intel joint_matrix_mad_impl +// joint_matrix_mad_intel_impl #endif // defined(__NVPTX__) #else - std::ignore = sg; - std::ignore = A; - std::ignore = B; - std::ignore = C; - throw runtime_error("joint_matrix_mad is " - "not supported on HOST", - PI_ERROR_INVALID_DEVICE); + std::ignore = sg; + std::ignore = A; + std::ignore = B; + std::ignore = C; + throw runtime_error("The matrix extension is only currently supported on " + "Intel and Nvidia devices", + PI_ERROR_INVALID_DEVICE); #endif // defined(__SYCL_DEVICE_ONLY__) -} + } -// This function rounds the bottom 13 bits up or down, and then zeros out the -// bottom bits -float round_to_tf32(float a) { + // This function rounds the bottom 13 bits up or down, and then zeros out the + // bottom bits + float round_to_tf32(float a) { #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - int32_t tmp_int = __nvvm_f2tf32_rna(a); - return __nvvm_bitcast_i2f(tmp_int); + int32_t tmp_int = __nvvm_f2tf32_rna(a); + return __nvvm_bitcast_i2f(tmp_int); #else - uint32_t tmp_uint = reinterpret_cast(a); - tmp_uint += 0x1000u; - tmp_uint &= 0xFFFFE000u; - float ret = reinterpret_cast(tmp_uint); - return ret; + uint32_t tmp_uint = reinterpret_cast(a); + tmp_uint += 0x1000u; + tmp_uint &= 0xFFFFE000u; + float ret = reinterpret_cast(tmp_uint); + return ret; #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) -} + } -} // namespace matrix -} // namespace experimental -} // namespace oneapi -} // namespace ext + } // namespace matrix + } // namespace experimental + } // namespace oneapi + } // namespace ext } // __SYCL_INLINE_VER_NAMESPACE(_V1) } // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp index c9e5d46764260..53f2a7db132aa 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp @@ -21,15 +21,12 @@ #if (SYCL_EXT_ONEAPI_MATRIX == 1) #include #include -#endif -#if (SYCL_EXT_ONEAPI_MATRIX == 2) +#elif (SYCL_EXT_ONEAPI_MATRIX == 2) #include #include -#endif -#if (SYCL_EXT_ONEAPI_MATRIX == 3) -#include -#endif -#if (SYCL_EXT_ONEAPI_MATRIX == 4) +#elif (SYCL_EXT_ONEAPI_MATRIX == 3) +#include +#elif (SYCL_EXT_ONEAPI_MATRIX == 4) #include #endif