From d2db69f8f88b416f97abfde198f088a2c8a3e3d4 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Wed, 17 May 2023 06:45:24 -0700 Subject: [PATCH 1/2] Fix non decorated address space for cuda joint_matrix. Signed-off-by: JackAKirk --- .../ext/oneapi/matrix/matrix-tensorcores.hpp | 132 +++++++++--------- 1 file changed, 69 insertions(+), 63 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index 018c66cf5213d..f4a941f5a4261 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -149,43 +149,45 @@ void load_accumulator_layoutT( S, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res, multi_ptr src, size_t stride) { + using DecorT = typename sycl::detail::DecoratedType::type; + DecorT *srcPtr = sycl::detail::getDecorated(src); if constexpr (std::is_same_v) { - auto destptr = reinterpret_cast(&res.wi_marray); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { - __imma_m16n16k16_ld_c(destptr, src.get(), stride, + __imma_m16n16k16_ld_c(dstPtr, srcPtr, stride, get_layout_id()); } else if constexpr (NumRows == 8 && NumCols == 32) { - __imma_m8n32k16_ld_c(destptr, src.get(), stride, get_layout_id()); + __imma_m8n32k16_ld_c(dstPtr, srcPtr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 8) { - __imma_m32n8k16_ld_c(destptr, src.get(), stride, get_layout_id()); + __imma_m32n8k16_ld_c(dstPtr, srcPtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - auto dstptr = reinterpret_cast(&res.wi_marray); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { - __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride, + __hmma_m16n16k16_ld_c_f32(dstPtr, srcPtr, stride, get_layout_id()); } else if constexpr (NumRows == 8 && NumCols == 32) { - __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride, + __hmma_m8n32k16_ld_c_f32(dstPtr, srcPtr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 8) { - __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride, + __hmma_m32n8k16_ld_c_f32(dstPtr, srcPtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - auto tileptr = reinterpret_cast(src.get()); - auto dstptr = reinterpret_cast(&res.wi_marray); + auto tilePtr = reinterpret_cast(srcPtr); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 32 && NumCols == 8) { - __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride, + __hmma_m32n8k16_ld_c_f16(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 8 && NumCols == 32) { - __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride, + __hmma_m8n32k16_ld_c_f16(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 16) { - __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride, + __hmma_m16n16k16_ld_c_f16(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - __dmma_m8n8k4_ld_c(reinterpret_cast(&res.wi_marray), src.get(), + __dmma_m8n8k4_ld_c(reinterpret_cast(&res.wi_marray), srcPtr, stride, get_layout_id()); } }; @@ -227,119 +229,121 @@ template < void load_multiplicand_cuda( joint_matrix_cuda &res, multi_ptr src, size_t stride) { + using DecorT = typename sycl::detail::DecoratedType::type; + DecorT *srcPtr = sycl::detail::getDecorated(src); if constexpr (std::is_same_v) { - auto tileptr = reinterpret_cast(src.get()); - auto destptr = reinterpret_cast(&res.wi_marray); + auto tilePtr = reinterpret_cast(srcPtr); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::a) { - __mma_bf16_m16n16k16_ld_a(destptr, tileptr, stride, + __mma_bf16_m16n16k16_ld_a(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::b) { - __mma_bf16_m16n16k16_ld_b(destptr, tileptr, stride, + __mma_bf16_m16n16k16_ld_b(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 16) { - __mma_bf16_m8n32k16_ld_a(destptr, tileptr, stride, + __mma_bf16_m8n32k16_ld_a(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 32) { - __mma_bf16_m8n32k16_ld_b(destptr, tileptr, stride, + __mma_bf16_m8n32k16_ld_b(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 16) { - __mma_bf16_m32n8k16_ld_a(destptr, tileptr, stride, + __mma_bf16_m32n8k16_ld_a(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 8) { - __mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride, + __mma_bf16_m32n8k16_ld_b(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - auto tileptr = reinterpret_cast(src.get()); - auto destptr = reinterpret_cast(&res.wi_marray); + auto tilePtr = reinterpret_cast(srcPtr); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::a) { - __imma_m16n16k16_ld_a_u8(destptr, tileptr, stride, + __imma_m16n16k16_ld_a_u8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::b) { - __imma_m16n16k16_ld_b_u8(destptr, tileptr, stride, + __imma_m16n16k16_ld_b_u8(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 16) { - __imma_m8n32k16_ld_a_u8(destptr, tileptr, stride, + __imma_m8n32k16_ld_a_u8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 32) { - __imma_m8n32k16_ld_b_u8(destptr, tileptr, stride, + __imma_m8n32k16_ld_b_u8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 16) { - __imma_m32n8k16_ld_a_u8(destptr, tileptr, stride, + __imma_m32n8k16_ld_a_u8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 8) { - __imma_m32n8k16_ld_b_u8(destptr, tileptr, stride, + __imma_m32n8k16_ld_b_u8(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - auto tileptr = reinterpret_cast(src.get()); - auto destptr = reinterpret_cast(&res.wi_marray); + auto tilePtr = reinterpret_cast(srcPtr); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::a) { - __imma_m16n16k16_ld_a_s8(destptr, tileptr, stride, + __imma_m16n16k16_ld_a_s8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::b) { - __imma_m16n16k16_ld_b_s8(destptr, tileptr, stride, + __imma_m16n16k16_ld_b_s8(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 16) { - __imma_m8n32k16_ld_a_s8(destptr, tileptr, stride, + __imma_m8n32k16_ld_a_s8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 32) { - __imma_m8n32k16_ld_b_s8(destptr, tileptr, stride, + __imma_m8n32k16_ld_b_s8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 16) { - __imma_m32n8k16_ld_a_s8(destptr, tileptr, stride, + __imma_m32n8k16_ld_a_s8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 8) { - __imma_m32n8k16_ld_b_s8(destptr, tileptr, stride, + __imma_m32n8k16_ld_b_s8(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - auto tileptr = reinterpret_cast(src.get()); - auto dstptr = reinterpret_cast(&res.wi_marray); + auto tilePtr = reinterpret_cast(srcPtr); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::a) { - __hmma_m16n16k16_ld_a(dstptr, tileptr, stride, get_layout_id()); + __hmma_m16n16k16_ld_a(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::b) { - __hmma_m16n16k16_ld_b(dstptr, tileptr, stride, get_layout_id()); + __hmma_m16n16k16_ld_b(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 16) { - __hmma_m8n32k16_ld_a(dstptr, tileptr, stride, get_layout_id()); + __hmma_m8n32k16_ld_a(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 32) { - __hmma_m8n32k16_ld_b(dstptr, tileptr, stride, get_layout_id()); + __hmma_m8n32k16_ld_b(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 16) { - __hmma_m32n8k16_ld_a(dstptr, tileptr, stride, get_layout_id()); + __hmma_m32n8k16_ld_a(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 8) { - __hmma_m32n8k16_ld_b(dstptr, tileptr, stride, get_layout_id()); + __hmma_m32n8k16_ld_b(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - auto tileptr = reinterpret_cast(src.get()); - auto dstptr = reinterpret_cast(&res.wi_marray); + auto tilePtr = reinterpret_cast(srcPtr); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 8) { - __mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride, + __mma_tf32_m16n16k8_ld_a(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 8 && NumCols == 16) { - __mma_tf32_m16n16k8_ld_b(dstptr, tileptr, stride, + __mma_tf32_m16n16k8_ld_b(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - auto dstptr = reinterpret_cast(&res.wi_marray); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::a) { - __dmma_m8n8k4_ld_a(dstptr, src.get(), stride, get_layout_id()); + __dmma_m8n8k4_ld_a(dstPtr, srcPtr, stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::b) { - __dmma_m8n8k4_ld_b(dstptr, src.get(), stride, get_layout_id()); + __dmma_m8n8k4_ld_b(dstPtr, srcPtr, stride, get_layout_id()); } } } @@ -352,50 +356,52 @@ void store_layoutT( T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, multi_ptr dst, size_t stride) { + using DecorT = typename sycl::detail::DecoratedType::type; + DecorT *dstPtr = sycl::detail::getDecorated(dst); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (std::is_same_v) { - __hmma_m16n16k16_st_c_f32(dst.get(), + __hmma_m16n16k16_st_c_f32(dstPtr, reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } else if constexpr (std::is_same_v) { - __imma_m16n16k16_st_c_i32(dst.get(), + __imma_m16n16k16_st_c_i32(dstPtr, reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } else if constexpr (std::is_same_v) { - __hmma_m16n16k16_st_c_f16(reinterpret_cast(dst.get()), + __hmma_m16n16k16_st_c_f16(reinterpret_cast(dstPtr), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 32) { if constexpr (std::is_same_v) { - __hmma_m8n32k16_st_c_f32(dst.get(), + __hmma_m8n32k16_st_c_f32(dstPtr, reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } else if constexpr (std::is_same_v) { - __imma_m8n32k16_st_c_i32(dst.get(), + __imma_m8n32k16_st_c_i32(dstPtr, reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } else if constexpr (std::is_same_v) { - __hmma_m8n32k16_st_c_f16(reinterpret_cast(dst.get()), + __hmma_m8n32k16_st_c_f16(reinterpret_cast(dstPtr), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } } else if constexpr (NumRows == 32 && NumCols == 8) { if constexpr (std::is_same_v) { - __hmma_m32n8k16_st_c_f32(dst.get(), + __hmma_m32n8k16_st_c_f32(dstPtr, reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } else if constexpr (std::is_same_v) { - __imma_m32n8k16_st_c_i32(dst.get(), + __imma_m32n8k16_st_c_i32(dstPtr, reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } else if constexpr (std::is_same_v) { - __hmma_m32n8k16_st_c_f16(reinterpret_cast(dst.get()), + __hmma_m32n8k16_st_c_f16(reinterpret_cast(dstPtr), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - __dmma_m8n8k4_st_c_f64(dst.get(), + __dmma_m8n8k4_st_c_f64(dstPtr, reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } From 797596a19fc5a2192bd700564880160925f98fcc Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Wed, 17 May 2023 07:22:07 -0700 Subject: [PATCH 2/2] Added missing header. Signed-off-by: JackAKirk --- sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index f4a941f5a4261..23cb99f064619 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -8,6 +8,7 @@ // ===-------------------------------------------------------------------=== // #pragma once +#include "utils.hpp" #include "matrix-unified-utils.hpp" #include