From 8db7dd5eded5f266361bc11058cf7da4dd943e6b Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Tue, 5 Jul 2022 18:02:35 +0100 Subject: [PATCH 1/7] Add exp builtins header to sycl.hpp and c++17 guards. Signed-off-by: JackAKirk --- sycl/include/CL/sycl.hpp | 1 + .../sycl/ext/oneapi/experimental/builtins.hpp | 29 +++++++++++++++---- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/sycl/include/CL/sycl.hpp b/sycl/include/CL/sycl.hpp index 10dcf00b94d22..9b517ced4b659 100644 --- a/sycl/include/CL/sycl.hpp +++ b/sycl/include/CL/sycl.hpp @@ -61,6 +61,7 @@ #include #endif #include +#include #include #include #include diff --git a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp index 9e93e688f181a..6fd371e1d7e4c 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp @@ -143,8 +143,12 @@ sycl::marray fabs(sycl::marray x) { auto partial_res = __clc_fabs(detail::to_uint32_t(x, i * 2)); std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); } - - if constexpr (N % 2) { +#if __cplusplus >= 201703L + if constexpr (N % 2) +#else + if (N % 2) +#endif // __cplusplus >= 201703L + { res[N - 1] = bfloat16::from_bits(__clc_fabs(x[N - 1].raw())); } return res; @@ -179,7 +183,12 @@ sycl::marray fmin(sycl::marray x, std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); } - if constexpr (N % 2) { +#if __cplusplus >= 201703L + if constexpr (N % 2) +#else + if (N % 2) +#endif // __cplusplus >= 201703L + { res[N - 1] = bfloat16::from_bits(__clc_fmin(x[N - 1].raw(), y[N - 1].raw())); } @@ -217,7 +226,12 @@ sycl::marray fmax(sycl::marray x, std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); } - if constexpr (N % 2) { +#if __cplusplus >= 201703L + if constexpr (N % 2) +#else + if (N % 2) +#endif // __cplusplus >= 201703L + { res[N - 1] = bfloat16::from_bits(__clc_fmax(x[N - 1].raw(), y[N - 1].raw())); } @@ -257,7 +271,12 @@ sycl::marray fma(sycl::marray x, std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); } - if constexpr (N % 2) { +#if __cplusplus >= 201703L + if constexpr (N % 2) +#else + if (N % 2) +#endif // __cplusplus >= 201703L + { res[N - 1] = bfloat16::from_bits( __clc_fma(x[N - 1].raw(), y[N - 1].raw(), z[N - 1].raw())); } From 4a7b44e6967d39580c52ba62b402fbfe754cb2c8 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Tue, 5 Jul 2022 18:05:14 +0100 Subject: [PATCH 2/7] remove exp builtins header include from tests. Signed-off-by: JackAKirk --- sycl/test/basic_tests/built-ins.cpp | 1 - sycl/test/extensions/experimental-printf.cpp | 1 - 2 files changed, 2 deletions(-) diff --git a/sycl/test/basic_tests/built-ins.cpp b/sycl/test/basic_tests/built-ins.cpp index 720313c2c5ade..b3d3d6437963d 100644 --- a/sycl/test/basic_tests/built-ins.cpp +++ b/sycl/test/basic_tests/built-ins.cpp @@ -7,7 +7,6 @@ // Hits an assertion with AMD: // XFAIL: hip_amd -#include #include #include diff --git a/sycl/test/extensions/experimental-printf.cpp b/sycl/test/extensions/experimental-printf.cpp index d64e78ee4883e..4c9269e54495a 100644 --- a/sycl/test/extensions/experimental-printf.cpp +++ b/sycl/test/extensions/experimental-printf.cpp @@ -16,7 +16,6 @@ // CHECK: Constant [[#TYPE]] [[#CONST:]] // CHECK: ExtInst [[#]] [[#]] [[#]] printf [[#]] [[#CONST]] -#include #include #ifdef __SYCL_DEVICE_ONLY__ From dd5bc26d18f157b945578221dbce6822420809f4 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Mon, 11 Jul 2022 10:57:27 +0100 Subject: [PATCH 3/7] Removed `if constexpr` usage. Signed-off-by: JackAKirk --- .../sycl/ext/oneapi/experimental/builtins.hpp | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp index 6fd371e1d7e4c..3c00cd4eb5698 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp @@ -143,11 +143,8 @@ sycl::marray fabs(sycl::marray x) { auto partial_res = __clc_fabs(detail::to_uint32_t(x, i * 2)); std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); } -#if __cplusplus >= 201703L - if constexpr (N % 2) -#else + if (N % 2) -#endif // __cplusplus >= 201703L { res[N - 1] = bfloat16::from_bits(__clc_fabs(x[N - 1].raw())); } @@ -183,11 +180,8 @@ sycl::marray fmin(sycl::marray x, std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); } -#if __cplusplus >= 201703L - if constexpr (N % 2) -#else + if (N % 2) -#endif // __cplusplus >= 201703L { res[N - 1] = bfloat16::from_bits(__clc_fmin(x[N - 1].raw(), y[N - 1].raw())); @@ -226,11 +220,7 @@ sycl::marray fmax(sycl::marray x, std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); } -#if __cplusplus >= 201703L - if constexpr (N % 2) -#else if (N % 2) -#endif // __cplusplus >= 201703L { res[N - 1] = bfloat16::from_bits(__clc_fmax(x[N - 1].raw(), y[N - 1].raw())); @@ -271,11 +261,7 @@ sycl::marray fma(sycl::marray x, std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); } -#if __cplusplus >= 201703L - if constexpr (N % 2) -#else if (N % 2) -#endif // __cplusplus >= 201703L { res[N - 1] = bfloat16::from_bits( __clc_fma(x[N - 1].raw(), y[N - 1].raw(), z[N - 1].raw())); From 1a9f631a907c6d10585993eb83205e2e36b74dcb Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Mon, 11 Jul 2022 11:01:49 +0100 Subject: [PATCH 4/7] Format. Signed-off-by: JackAKirk --- .../sycl/ext/oneapi/experimental/builtins.hpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp index 3c00cd4eb5698..5752bedac17e9 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp @@ -144,8 +144,7 @@ sycl::marray fabs(sycl::marray x) { std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); } - if (N % 2) - { + if (N % 2) { res[N - 1] = bfloat16::from_bits(__clc_fabs(x[N - 1].raw())); } return res; @@ -181,8 +180,7 @@ sycl::marray fmin(sycl::marray x, } - if (N % 2) - { + if (N % 2) { res[N - 1] = bfloat16::from_bits(__clc_fmin(x[N - 1].raw(), y[N - 1].raw())); } @@ -220,8 +218,7 @@ sycl::marray fmax(sycl::marray x, std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); } - if (N % 2) - { + if (N % 2) { res[N - 1] = bfloat16::from_bits(__clc_fmax(x[N - 1].raw(), y[N - 1].raw())); } @@ -261,8 +258,7 @@ sycl::marray fma(sycl::marray x, std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); } - if (N % 2) - { + if (N % 2) { res[N - 1] = bfloat16::from_bits( __clc_fma(x[N - 1].raw(), y[N - 1].raw(), z[N - 1].raw())); } From 2c6814f193d5360b6b17b66cba192c8fbdae0658 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Mon, 11 Jul 2022 11:09:06 +0100 Subject: [PATCH 5/7] Format. Signed-off-by: JackAKirk --- sycl/include/sycl/ext/oneapi/experimental/builtins.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp index 5752bedac17e9..5fdd8dc6e015e 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp @@ -179,7 +179,6 @@ sycl::marray fmin(sycl::marray x, std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t)); } - if (N % 2) { res[N - 1] = bfloat16::from_bits(__clc_fmin(x[N - 1].raw(), y[N - 1].raw())); From b417939cd24250cb40a9f4c636aa75be78756ed8 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Mon, 11 Jul 2022 11:57:42 +0100 Subject: [PATCH 6/7] Remove c++17 usage from exper headers. Signed-off-by: JackAKirk --- .../sycl/ext/oneapi/experimental/builtins.hpp | 10 +- .../ext/oneapi/matrix/matrix-tensorcore.hpp | 194 +++++++++--------- 2 files changed, 108 insertions(+), 96 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp index 5fdd8dc6e015e..57c1b0c58fc2c 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp @@ -27,7 +27,10 @@ #endif __SYCL_INLINE_NAMESPACE(cl) { -namespace sycl::ext::oneapi::experimental { +namespace sycl { +namespace ext { +namespace oneapi { +namespace experimental { namespace detail { template uint32_t to_uint32_t(sycl::marray x, size_t start) { @@ -271,7 +274,10 @@ sycl::marray fma(sycl::marray x, #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } -} // namespace sycl::ext::oneapi::experimental +} // namespace experimental +} // namespace oneapi +} // namespace ext +} // namespace sycl } // __SYCL_INLINE_NAMESPACE(cl) #undef __SYCL_CONSTANT_AS diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index 6de66d0f8590c..c8417204ec9b5 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -10,8 +10,11 @@ #include __SYCL_INLINE_NAMESPACE(cl) { -namespace sycl::ext::oneapi { -namespace experimental::matrix { +namespace sycl { +namespace ext { +namespace oneapi { +namespace experimental { +namespace matrix { enum class matrix_use { a, b, accumulator }; @@ -213,166 +216,166 @@ struct joint_matrix_load_impl< void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, multi_ptr src, size_t stride) { - if constexpr (std::is_same::value || + if (std::is_same::value || std::is_same< T, sycl::ext::oneapi::experimental::bfloat16>::value) { auto tileptr = reinterpret_cast(src.get()); auto destptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 16) { - if constexpr (Use == + if (NumRows == 16 && NumCols == 16) { + if (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { __mma_bf16_m16n16k16_ld_a(destptr, tileptr, stride, get_layout_id()); - } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + } else if (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { __mma_bf16_m16n16k16_ld_b(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (NumRows == 8 && NumCols == 16) { + } else if (NumRows == 8 && NumCols == 16) { __mma_bf16_m8n32k16_ld_a(destptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 32) { + } else if (NumRows == 16 && NumCols == 32) { __mma_bf16_m8n32k16_ld_b(destptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 16) { + } else if (NumRows == 32 && NumCols == 16) { __mma_bf16_m32n8k16_ld_a(destptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 8) { + } else if (NumRows == 16 && NumCols == 8) { __mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { auto tileptr = reinterpret_cast(src.get()); auto destptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 16) { - if constexpr (Use == + if (NumRows == 16 && NumCols == 16) { + if (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { __imma_m16n16k16_ld_a_u8(destptr, tileptr, stride, get_layout_id()); - } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + } else if (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { __imma_m16n16k16_ld_b_u8(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (NumRows == 8 && NumCols == 16) { + } else if (NumRows == 8 && NumCols == 16) { __imma_m8n32k16_ld_a_u8(destptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 32) { + } else if (NumRows == 16 && NumCols == 32) { __imma_m8n32k16_ld_b_u8(destptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 16) { + } else if (NumRows == 32 && NumCols == 16) { __imma_m32n8k16_ld_a_u8(destptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 8) { + } else if (NumRows == 16 && NumCols == 8) { __imma_m32n8k16_ld_b_u8(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { auto tileptr = reinterpret_cast(src.get()); auto destptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 16) { - if constexpr (Use == + if (NumRows == 16 && NumCols == 16) { + if (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { __imma_m16n16k16_ld_a_s8(destptr, tileptr, stride, get_layout_id()); - } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + } else if (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { __imma_m16n16k16_ld_b_s8(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (NumRows == 8 && NumCols == 16) { + } else if (NumRows == 8 && NumCols == 16) { __imma_m8n32k16_ld_a_s8(destptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 32) { + } else if (NumRows == 16 && NumCols == 32) { __imma_m8n32k16_ld_b_s8(destptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 16) { + } else if (NumRows == 32 && NumCols == 16) { __imma_m32n8k16_ld_a_s8(destptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 8) { + } else if (NumRows == 16 && NumCols == 8) { __imma_m32n8k16_ld_b_s8(destptr, tileptr, stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { auto tileptr = reinterpret_cast(src.get()); auto dstptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 16) { - if constexpr (Use == + if (NumRows == 16 && NumCols == 16) { + if (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { __hmma_m16n16k16_ld_a(dstptr, tileptr, stride, get_layout_id()); - } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + } else if (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { __hmma_m16n16k16_ld_b(dstptr, tileptr, stride, get_layout_id()); - } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + } else if (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::accumulator) { __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride, get_layout_id()); } - } else if constexpr (NumRows == 8 && NumCols == 16) { + } else if (NumRows == 8 && NumCols == 16) { __hmma_m8n32k16_ld_a(dstptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 32) { + } else if (NumRows == 16 && NumCols == 32) { __hmma_m8n32k16_ld_b(dstptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 16) { + } else if (NumRows == 32 && NumCols == 16) { __hmma_m32n8k16_ld_a(dstptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 16 && NumCols == 8) { + } else if (NumRows == 16 && NumCols == 8) { __hmma_m32n8k16_ld_b(dstptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 8) { + } else if (NumRows == 32 && NumCols == 8) { __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 8 && NumCols == 32) { + } else if (NumRows == 8 && NumCols == 32) { __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { auto destptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 16) { + if (NumRows == 16 && NumCols == 16) { __imma_m16n16k16_ld_c(destptr, src.get(), stride, get_layout_id()); - } else if constexpr (NumRows == 8 && NumCols == 32) { + } else if (NumRows == 8 && NumCols == 32) { __imma_m8n32k16_ld_c(destptr, src.get(), stride, get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 8) { + } else if (NumRows == 32 && NumCols == 8) { __imma_m32n8k16_ld_c(destptr, src.get(), stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { - if constexpr (std::is_same::value) { + } else if (std::is_same::value) { + if (std::is_same::value) { auto dstptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 16) { + if (NumRows == 16 && NumCols == 16) { __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride, get_layout_id()); - } else if constexpr (NumRows == 8 && NumCols == 32) { + } else if (NumRows == 8 && NumCols == 32) { __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride, get_layout_id()); - } else if constexpr (NumRows == 32 && NumCols == 8) { + } else if (NumRows == 32 && NumCols == 8) { __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { auto tileptr = reinterpret_cast(src.get()); auto dstptr = reinterpret_cast(&res.wi_marray); - if constexpr (NumRows == 16 && NumCols == 8) { + if (NumRows == 16 && NumCols == 8) { __mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride, get_layout_id()); - } else if constexpr (NumRows == 8 && NumCols == 16) { + } else if (NumRows == 8 && NumCols == 16) { __mma_tf32_m16n16k8_ld_b(dstptr, tileptr, stride, get_layout_id()); } } - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { auto dstptr = reinterpret_cast(&res.wi_marray); - if constexpr (Use == + if (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { __dmma_m8n8k4_ld_a(dstptr, src.get(), stride, get_layout_id()); - } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + } else if (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { __dmma_m8n8k4_ld_b(dstptr, src.get(), stride, get_layout_id()); - } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + } else if (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::accumulator) { __dmma_m8n8k4_ld_c(dstptr, src.get(), stride, get_layout_id()); } @@ -405,49 +408,49 @@ struct joint_matrix_store_impl< T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, NumRows, NumCols, Layout, sycl::sub_group> &src, multi_ptr dst, size_t stride) { - if constexpr (NumRows == 16 && NumCols == 16) { - if constexpr (std::is_same::value) { + if (NumRows == 16 && NumCols == 16) { + if (std::is_same::value) { __hmma_m16n16k16_st_c_f32(dst.get(), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { __imma_m16n16k16_st_c_i32(dst.get(), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { __hmma_m16n16k16_st_c_f16(reinterpret_cast(dst.get()), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } - } else if constexpr (NumRows == 8 && NumCols == 32) { - if constexpr (std::is_same::value) { + } else if (NumRows == 8 && NumCols == 32) { + if (std::is_same::value) { __hmma_m8n32k16_st_c_f32(dst.get(), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { __imma_m8n32k16_st_c_i32(dst.get(), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { __hmma_m8n32k16_st_c_f16(reinterpret_cast(dst.get()), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } - } else if constexpr (NumRows == 32 && NumCols == 8) { - if constexpr (std::is_same::value) { + } else if (NumRows == 32 && NumCols == 8) { + if (std::is_same::value) { __hmma_m32n8k16_st_c_f32(dst.get(), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { __imma_m32n8k16_st_c_i32(dst.get(), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { __hmma_m32n8k16_st_c_f16(reinterpret_cast(dst.get()), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { __dmma_m8n8k4_st_c_f64(dst.get(), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); @@ -548,34 +551,34 @@ struct joint_matrix_mad_impl< T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N, LayoutC, sycl::sub_group> D; - if constexpr (M == 16 && N == 16 && K == 16) { - if constexpr (std::is_same::value) { + if (M == 16 && N == 16 && K == 16) { + if (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); auto ptrC = reinterpret_cast(&C.wi_marray); auto ptrD = reinterpret_cast(&D.wi_marray); - if constexpr (std::is_same::value) { + if (std::is_same::value) { __imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { __imma_m16n16k16_mma_u8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); } - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); - if constexpr (std::is_same::value) { + if (std::is_same::value) { __hmma_m16n16k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { __hmma_m16n16k16_mma_f16f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } - } else if constexpr (std::is_same::value || + } else if (std::is_same::value || std::is_same::value) { __mma_bf16_m16n16k16_mma_f32( @@ -585,34 +588,34 @@ struct joint_matrix_mad_impl< reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } - } else if constexpr (M == 8 && N == 32 && K == 16) { - if constexpr (std::is_same::value) { + } else if (M == 8 && N == 32 && K == 16) { + if (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); auto ptrC = reinterpret_cast(&C.wi_marray); auto ptrD = reinterpret_cast(&D.wi_marray); - if constexpr (std::is_same::value) { + if (std::is_same::value) { __imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { __imma_m8n32k16_mma_u8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); } - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); - if constexpr (std::is_same::value) { + if (std::is_same::value) { __hmma_m8n32k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { __hmma_m8n32k16_mma_f16f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } - } else if constexpr (std::is_same::value || + } else if (std::is_same::value || std::is_same::value) { __mma_bf16_m8n32k16_mma_f32( @@ -622,20 +625,20 @@ struct joint_matrix_mad_impl< reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } - } else if constexpr (M == 32 && N == 8 && K == 16) { - if constexpr (std::is_same::value) { + } else if (M == 32 && N == 8 && K == 16) { + if (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); auto ptrC = reinterpret_cast(&C.wi_marray); auto ptrD = reinterpret_cast(&D.wi_marray); - if constexpr (std::is_same::value) { + if (std::is_same::value) { __imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { __imma_m32n8k16_mma_u8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); } - } else if constexpr (std::is_same::value || + } else if (std::is_same::value || std::is_same::value) { __mma_bf16_m32n8k16_mma_f32( @@ -644,28 +647,28 @@ struct joint_matrix_mad_impl< reinterpret_cast(&B.wi_marray), reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); - if constexpr (std::is_same::value) { + if (std::is_same::value) { __hmma_m32n8k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { __hmma_m32n8k16_mma_f16f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } } - } else if constexpr (M == 16 && N == 16 && K == 8) { + } else if (M == 16 && N == 16 && K == 8) { __mma_tf32_m16n16k8_mma_f32(reinterpret_cast(&D.wi_marray), reinterpret_cast(&A.wi_marray), reinterpret_cast(&B.wi_marray), reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if (std::is_same::value) { __dmma_m8n8k4_mma_f64(reinterpret_cast(&D.wi_marray), reinterpret_cast(&A.wi_marray), reinterpret_cast(&B.wi_marray), @@ -766,6 +769,9 @@ float round_to_tf32(float a) { #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } -} // namespace experimental::matrix -} // namespace sycl::ext::oneapi +} // namespace matrix +} // namespace experimental +} // namespace oneapi +} // namespace ext +} // namespace sycl } // __SYCL_INLINE_NAMESPACE(cl) From e924e1e64d135a0962aa6da770fd93c0853bec27 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Mon, 11 Jul 2022 13:00:45 +0100 Subject: [PATCH 7/7] if constexpr is required for joint_matrix* functions. Signed-off-by: JackAKirk --- .../ext/oneapi/matrix/matrix-tensorcore.hpp | 192 +++++++++--------- 1 file changed, 100 insertions(+), 92 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index c8417204ec9b5..cf53bec8f943c 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -172,7 +172,8 @@ joint_matrix_fill(Group sg, #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } -} // namespace experimental::matrix +} // namespace matrix +} // namespace experimental namespace detail { @@ -202,6 +203,7 @@ constexpr int get_layout_id< return 1; } +#if __cplusplus >= 201703L // if constexpr usage template &res, multi_ptr src, size_t stride) { - if (std::is_same::value || + if constexpr (std::is_same::value || std::is_same< T, sycl::ext::oneapi::experimental::bfloat16>::value) { auto tileptr = reinterpret_cast(src.get()); auto destptr = reinterpret_cast(&res.wi_marray); - if (NumRows == 16 && NumCols == 16) { - if (Use == + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { __mma_bf16_m16n16k16_ld_a(destptr, tileptr, stride, get_layout_id()); - } else if (Use == sycl::ext::oneapi::experimental::matrix:: + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { __mma_bf16_m16n16k16_ld_b(destptr, tileptr, stride, get_layout_id()); } - } else if (NumRows == 8 && NumCols == 16) { + } else if constexpr (NumRows == 8 && NumCols == 16) { __mma_bf16_m8n32k16_ld_a(destptr, tileptr, stride, get_layout_id()); - } else if (NumRows == 16 && NumCols == 32) { + } else if constexpr (NumRows == 16 && NumCols == 32) { __mma_bf16_m8n32k16_ld_b(destptr, tileptr, stride, get_layout_id()); - } else if (NumRows == 32 && NumCols == 16) { + } else if constexpr (NumRows == 32 && NumCols == 16) { __mma_bf16_m32n8k16_ld_a(destptr, tileptr, stride, get_layout_id()); - } else if (NumRows == 16 && NumCols == 8) { + } else if constexpr (NumRows == 16 && NumCols == 8) { __mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride, get_layout_id()); } - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto tileptr = reinterpret_cast(src.get()); auto destptr = reinterpret_cast(&res.wi_marray); - if (NumRows == 16 && NumCols == 16) { - if (Use == + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { __imma_m16n16k16_ld_a_u8(destptr, tileptr, stride, get_layout_id()); - } else if (Use == sycl::ext::oneapi::experimental::matrix:: + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { __imma_m16n16k16_ld_b_u8(destptr, tileptr, stride, get_layout_id()); } - } else if (NumRows == 8 && NumCols == 16) { + } else if constexpr (NumRows == 8 && NumCols == 16) { __imma_m8n32k16_ld_a_u8(destptr, tileptr, stride, get_layout_id()); - } else if (NumRows == 16 && NumCols == 32) { + } else if constexpr (NumRows == 16 && NumCols == 32) { __imma_m8n32k16_ld_b_u8(destptr, tileptr, stride, get_layout_id()); - } else if (NumRows == 32 && NumCols == 16) { + } else if constexpr (NumRows == 32 && NumCols == 16) { __imma_m32n8k16_ld_a_u8(destptr, tileptr, stride, get_layout_id()); - } else if (NumRows == 16 && NumCols == 8) { + } else if constexpr (NumRows == 16 && NumCols == 8) { __imma_m32n8k16_ld_b_u8(destptr, tileptr, stride, get_layout_id()); } - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto tileptr = reinterpret_cast(src.get()); auto destptr = reinterpret_cast(&res.wi_marray); - if (NumRows == 16 && NumCols == 16) { - if (Use == + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { __imma_m16n16k16_ld_a_s8(destptr, tileptr, stride, get_layout_id()); - } else if (Use == sycl::ext::oneapi::experimental::matrix:: + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { __imma_m16n16k16_ld_b_s8(destptr, tileptr, stride, get_layout_id()); } - } else if (NumRows == 8 && NumCols == 16) { + } else if constexpr (NumRows == 8 && NumCols == 16) { __imma_m8n32k16_ld_a_s8(destptr, tileptr, stride, get_layout_id()); - } else if (NumRows == 16 && NumCols == 32) { + } else if constexpr (NumRows == 16 && NumCols == 32) { __imma_m8n32k16_ld_b_s8(destptr, tileptr, stride, get_layout_id()); - } else if (NumRows == 32 && NumCols == 16) { + } else if constexpr (NumRows == 32 && NumCols == 16) { __imma_m32n8k16_ld_a_s8(destptr, tileptr, stride, get_layout_id()); - } else if (NumRows == 16 && NumCols == 8) { + } else if constexpr (NumRows == 16 && NumCols == 8) { __imma_m32n8k16_ld_b_s8(destptr, tileptr, stride, get_layout_id()); } - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto tileptr = reinterpret_cast(src.get()); auto dstptr = reinterpret_cast(&res.wi_marray); - if (NumRows == 16 && NumCols == 16) { - if (Use == + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { __hmma_m16n16k16_ld_a(dstptr, tileptr, stride, get_layout_id()); - } else if (Use == sycl::ext::oneapi::experimental::matrix:: + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { __hmma_m16n16k16_ld_b(dstptr, tileptr, stride, get_layout_id()); - } else if (Use == sycl::ext::oneapi::experimental::matrix:: + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::accumulator) { __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride, get_layout_id()); } - } else if (NumRows == 8 && NumCols == 16) { + } else if constexpr (NumRows == 8 && NumCols == 16) { __hmma_m8n32k16_ld_a(dstptr, tileptr, stride, get_layout_id()); - } else if (NumRows == 16 && NumCols == 32) { + } else if constexpr (NumRows == 16 && NumCols == 32) { __hmma_m8n32k16_ld_b(dstptr, tileptr, stride, get_layout_id()); - } else if (NumRows == 32 && NumCols == 16) { + } else if constexpr (NumRows == 32 && NumCols == 16) { __hmma_m32n8k16_ld_a(dstptr, tileptr, stride, get_layout_id()); - } else if (NumRows == 16 && NumCols == 8) { + } else if constexpr (NumRows == 16 && NumCols == 8) { __hmma_m32n8k16_ld_b(dstptr, tileptr, stride, get_layout_id()); - } else if (NumRows == 32 && NumCols == 8) { + } else if constexpr (NumRows == 32 && NumCols == 8) { __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride, get_layout_id()); - } else if (NumRows == 8 && NumCols == 32) { + } else if constexpr (NumRows == 8 && NumCols == 32) { __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride, get_layout_id()); } - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto destptr = reinterpret_cast(&res.wi_marray); - if (NumRows == 16 && NumCols == 16) { + if constexpr (NumRows == 16 && NumCols == 16) { __imma_m16n16k16_ld_c(destptr, src.get(), stride, get_layout_id()); - } else if (NumRows == 8 && NumCols == 32) { + } else if constexpr (NumRows == 8 && NumCols == 32) { __imma_m8n32k16_ld_c(destptr, src.get(), stride, get_layout_id()); - } else if (NumRows == 32 && NumCols == 8) { + } else if constexpr (NumRows == 32 && NumCols == 8) { __imma_m32n8k16_ld_c(destptr, src.get(), stride, get_layout_id()); } - } else if (std::is_same::value) { - if (std::is_same::value) { + } else if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto dstptr = reinterpret_cast(&res.wi_marray); - if (NumRows == 16 && NumCols == 16) { + if constexpr (NumRows == 16 && NumCols == 16) { __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride, get_layout_id()); - } else if (NumRows == 8 && NumCols == 32) { + } else if constexpr (NumRows == 8 && NumCols == 32) { __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride, get_layout_id()); - } else if (NumRows == 32 && NumCols == 8) { + } else if constexpr (NumRows == 32 && NumCols == 8) { __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride, get_layout_id()); } - } else if (std::is_same::value) { auto tileptr = reinterpret_cast(src.get()); auto dstptr = reinterpret_cast(&res.wi_marray); - if (NumRows == 16 && NumCols == 8) { + if constexpr (NumRows == 16 && NumCols == 8) { __mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride, get_layout_id()); - } else if (NumRows == 8 && NumCols == 16) { + } else if constexpr (NumRows == 8 && NumCols == 16) { __mma_tf32_m16n16k8_ld_b(dstptr, tileptr, stride, get_layout_id()); } } - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto dstptr = reinterpret_cast(&res.wi_marray); - if (Use == + if constexpr (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { __dmma_m8n8k4_ld_a(dstptr, src.get(), stride, get_layout_id()); - } else if (Use == sycl::ext::oneapi::experimental::matrix:: + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { __dmma_m8n8k4_ld_b(dstptr, src.get(), stride, get_layout_id()); - } else if (Use == sycl::ext::oneapi::experimental::matrix:: + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::accumulator) { __dmma_m8n8k4_ld_c(dstptr, src.get(), stride, get_layout_id()); } } } }; +#endif // __cplusplus >= 201703L template dst, size_t stride); }; +#if __cplusplus >= 201703L // if constexpr usage template @@ -408,55 +412,56 @@ struct joint_matrix_store_impl< T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, NumRows, NumCols, Layout, sycl::sub_group> &src, multi_ptr dst, size_t stride) { - if (NumRows == 16 && NumCols == 16) { - if (std::is_same::value) { + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (std::is_same::value) { __hmma_m16n16k16_st_c_f32(dst.get(), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { __imma_m16n16k16_st_c_i32(dst.get(), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { __hmma_m16n16k16_st_c_f16(reinterpret_cast(dst.get()), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } - } else if (NumRows == 8 && NumCols == 32) { - if (std::is_same::value) { + } else if constexpr (NumRows == 8 && NumCols == 32) { + if constexpr (std::is_same::value) { __hmma_m8n32k16_st_c_f32(dst.get(), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { __imma_m8n32k16_st_c_i32(dst.get(), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { __hmma_m8n32k16_st_c_f16(reinterpret_cast(dst.get()), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } - } else if (NumRows == 32 && NumCols == 8) { - if (std::is_same::value) { + } else if constexpr (NumRows == 32 && NumCols == 8) { + if constexpr (std::is_same::value) { __hmma_m32n8k16_st_c_f32(dst.get(), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { __imma_m32n8k16_st_c_i32(dst.get(), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { __hmma_m32n8k16_st_c_f16(reinterpret_cast(dst.get()), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { __dmma_m8n8k4_st_c_f64(dst.get(), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } } }; +#endif // __cplusplus >= 201703L template = 201703L // if constexpr usage template D; - if (M == 16 && N == 16 && K == 16) { - if (std::is_same::value) { + if constexpr (M == 16 && N == 16 && K == 16) { + if constexpr (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); auto ptrC = reinterpret_cast(&C.wi_marray); auto ptrD = reinterpret_cast(&D.wi_marray); - if (std::is_same::value) { + if constexpr (std::is_same::value) { __imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { __imma_m16n16k16_mma_u8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); } - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); - if (std::is_same::value) { + if constexpr (std::is_same::value) { __hmma_m16n16k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { __hmma_m16n16k16_mma_f16f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } - } else if (std::is_same::value || + } else if constexpr (std::is_same::value || std::is_same::value) { __mma_bf16_m16n16k16_mma_f32( @@ -588,34 +594,34 @@ struct joint_matrix_mad_impl< reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } - } else if (M == 8 && N == 32 && K == 16) { - if (std::is_same::value) { + } else if constexpr (M == 8 && N == 32 && K == 16) { + if constexpr (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); auto ptrC = reinterpret_cast(&C.wi_marray); auto ptrD = reinterpret_cast(&D.wi_marray); - if (std::is_same::value) { + if constexpr (std::is_same::value) { __imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { __imma_m8n32k16_mma_u8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); } - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); - if (std::is_same::value) { + if constexpr (std::is_same::value) { __hmma_m8n32k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { __hmma_m8n32k16_mma_f16f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } - } else if (std::is_same::value || + } else if constexpr (std::is_same::value || std::is_same::value) { __mma_bf16_m8n32k16_mma_f32( @@ -625,20 +631,20 @@ struct joint_matrix_mad_impl< reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } - } else if (M == 32 && N == 8 && K == 16) { - if (std::is_same::value) { + } else if constexpr (M == 32 && N == 8 && K == 16) { + if constexpr (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); auto ptrC = reinterpret_cast(&C.wi_marray); auto ptrD = reinterpret_cast(&D.wi_marray); - if (std::is_same::value) { + if constexpr (std::is_same::value) { __imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { __imma_m32n8k16_mma_u8(ptrD, ptrA, ptrB, ptrC, get_layout_pair_id(), 0); } - } else if (std::is_same::value || + } else if constexpr (std::is_same::value || std::is_same::value) { __mma_bf16_m32n8k16_mma_f32( @@ -647,28 +653,28 @@ struct joint_matrix_mad_impl< reinterpret_cast(&B.wi_marray), reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); - if (std::is_same::value) { + if constexpr (std::is_same::value) { __hmma_m32n8k16_mma_f32f32( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { __hmma_m32n8k16_mma_f16f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } } - } else if (M == 16 && N == 16 && K == 8) { + } else if constexpr (M == 16 && N == 16 && K == 8) { __mma_tf32_m16n16k8_mma_f32(reinterpret_cast(&D.wi_marray), reinterpret_cast(&A.wi_marray), reinterpret_cast(&B.wi_marray), reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { __dmma_m8n8k4_mma_f64(reinterpret_cast(&D.wi_marray), reinterpret_cast(&A.wi_marray), reinterpret_cast(&B.wi_marray), @@ -678,10 +684,12 @@ struct joint_matrix_mad_impl< return D; } }; +#endif // __cplusplus >= 201703L } // namespace detail -namespace experimental::matrix { +namespace experimental { +namespace matrix { template