diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index 018c66cf5213d..23cb99f064619 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -8,6 +8,7 @@ // ===-------------------------------------------------------------------=== // #pragma once +#include "utils.hpp" #include "matrix-unified-utils.hpp" #include @@ -149,43 +150,45 @@ void load_accumulator_layoutT( S, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res, multi_ptr src, size_t stride) { + using DecorT = typename sycl::detail::DecoratedType::type; + DecorT *srcPtr = sycl::detail::getDecorated(src); if constexpr (std::is_same_v) { - auto destptr = reinterpret_cast(&res.wi_marray); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { - __imma_m16n16k16_ld_c(destptr, src.get(), stride, + __imma_m16n16k16_ld_c(dstPtr, srcPtr, stride, get_layout_id()); } else if constexpr (NumRows == 8 && NumCols == 32) { - __imma_m8n32k16_ld_c(destptr, src.get(), stride, get_layout_id()); + __imma_m8n32k16_ld_c(dstPtr, srcPtr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 8) { - __imma_m32n8k16_ld_c(destptr, src.get(), stride, get_layout_id()); + __imma_m32n8k16_ld_c(dstPtr, srcPtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - auto dstptr = reinterpret_cast(&res.wi_marray); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { - __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride, + __hmma_m16n16k16_ld_c_f32(dstPtr, srcPtr, stride, get_layout_id()); } else if constexpr (NumRows == 8 && NumCols == 32) { - __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride, + __hmma_m8n32k16_ld_c_f32(dstPtr, srcPtr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 8) { - __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride, + __hmma_m32n8k16_ld_c_f32(dstPtr, srcPtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - auto tileptr = reinterpret_cast(src.get()); - auto dstptr = reinterpret_cast(&res.wi_marray); + auto tilePtr = reinterpret_cast(srcPtr); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 32 && NumCols == 8) { - __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride, + __hmma_m32n8k16_ld_c_f16(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 8 && NumCols == 32) { - __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride, + __hmma_m8n32k16_ld_c_f16(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 16) { - __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride, + __hmma_m16n16k16_ld_c_f16(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - __dmma_m8n8k4_ld_c(reinterpret_cast(&res.wi_marray), src.get(), + __dmma_m8n8k4_ld_c(reinterpret_cast(&res.wi_marray), srcPtr, stride, get_layout_id()); } }; @@ -227,119 +230,121 @@ template < void load_multiplicand_cuda( joint_matrix_cuda &res, multi_ptr src, size_t stride) { + using DecorT = typename sycl::detail::DecoratedType::type; + DecorT *srcPtr = sycl::detail::getDecorated(src); if constexpr (std::is_same_v) { - auto tileptr = reinterpret_cast(src.get()); - auto destptr = reinterpret_cast(&res.wi_marray); + auto tilePtr = reinterpret_cast(srcPtr); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::a) { - __mma_bf16_m16n16k16_ld_a(destptr, tileptr, stride, + __mma_bf16_m16n16k16_ld_a(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::b) { - __mma_bf16_m16n16k16_ld_b(destptr, tileptr, stride, + __mma_bf16_m16n16k16_ld_b(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 16) { - __mma_bf16_m8n32k16_ld_a(destptr, tileptr, stride, + __mma_bf16_m8n32k16_ld_a(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 32) { - __mma_bf16_m8n32k16_ld_b(destptr, tileptr, stride, + __mma_bf16_m8n32k16_ld_b(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 16) { - __mma_bf16_m32n8k16_ld_a(destptr, tileptr, stride, + __mma_bf16_m32n8k16_ld_a(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 8) { - __mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride, + __mma_bf16_m32n8k16_ld_b(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - auto tileptr = reinterpret_cast(src.get()); - auto destptr = reinterpret_cast(&res.wi_marray); + auto tilePtr = reinterpret_cast(srcPtr); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::a) { - __imma_m16n16k16_ld_a_u8(destptr, tileptr, stride, + __imma_m16n16k16_ld_a_u8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::b) { - __imma_m16n16k16_ld_b_u8(destptr, tileptr, stride, + __imma_m16n16k16_ld_b_u8(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 16) { - __imma_m8n32k16_ld_a_u8(destptr, tileptr, stride, + __imma_m8n32k16_ld_a_u8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 32) { - __imma_m8n32k16_ld_b_u8(destptr, tileptr, stride, + __imma_m8n32k16_ld_b_u8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 16) { - __imma_m32n8k16_ld_a_u8(destptr, tileptr, stride, + __imma_m32n8k16_ld_a_u8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 8) { - __imma_m32n8k16_ld_b_u8(destptr, tileptr, stride, + __imma_m32n8k16_ld_b_u8(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - auto tileptr = reinterpret_cast(src.get()); - auto destptr = reinterpret_cast(&res.wi_marray); + auto tilePtr = reinterpret_cast(srcPtr); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::a) { - __imma_m16n16k16_ld_a_s8(destptr, tileptr, stride, + __imma_m16n16k16_ld_a_s8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::b) { - __imma_m16n16k16_ld_b_s8(destptr, tileptr, stride, + __imma_m16n16k16_ld_b_s8(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 16) { - __imma_m8n32k16_ld_a_s8(destptr, tileptr, stride, + __imma_m8n32k16_ld_a_s8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 32) { - __imma_m8n32k16_ld_b_s8(destptr, tileptr, stride, + __imma_m8n32k16_ld_b_s8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 16) { - __imma_m32n8k16_ld_a_s8(destptr, tileptr, stride, + __imma_m32n8k16_ld_a_s8(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 8) { - __imma_m32n8k16_ld_b_s8(destptr, tileptr, stride, + __imma_m32n8k16_ld_b_s8(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - auto tileptr = reinterpret_cast(src.get()); - auto dstptr = reinterpret_cast(&res.wi_marray); + auto tilePtr = reinterpret_cast(srcPtr); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::a) { - __hmma_m16n16k16_ld_a(dstptr, tileptr, stride, get_layout_id()); + __hmma_m16n16k16_ld_a(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::b) { - __hmma_m16n16k16_ld_b(dstptr, tileptr, stride, get_layout_id()); + __hmma_m16n16k16_ld_b(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 16) { - __hmma_m8n32k16_ld_a(dstptr, tileptr, stride, get_layout_id()); + __hmma_m8n32k16_ld_a(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 32) { - __hmma_m8n32k16_ld_b(dstptr, tileptr, stride, get_layout_id()); + __hmma_m8n32k16_ld_b(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 16) { - __hmma_m32n8k16_ld_a(dstptr, tileptr, stride, get_layout_id()); + __hmma_m32n8k16_ld_a(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 8) { - __hmma_m32n8k16_ld_b(dstptr, tileptr, stride, get_layout_id()); + __hmma_m32n8k16_ld_b(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - auto tileptr = reinterpret_cast(src.get()); - auto dstptr = reinterpret_cast(&res.wi_marray); + auto tilePtr = reinterpret_cast(srcPtr); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 8) { - __mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride, + __mma_tf32_m16n16k8_ld_a(dstPtr, tilePtr, stride, get_layout_id()); } else if constexpr (NumRows == 8 && NumCols == 16) { - __mma_tf32_m16n16k8_ld_b(dstptr, tileptr, stride, + __mma_tf32_m16n16k8_ld_b(dstPtr, tilePtr, stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - auto dstptr = reinterpret_cast(&res.wi_marray); + auto dstPtr = reinterpret_cast(&res.wi_marray); if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::a) { - __dmma_m8n8k4_ld_a(dstptr, src.get(), stride, get_layout_id()); + __dmma_m8n8k4_ld_a(dstPtr, srcPtr, stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::use::b) { - __dmma_m8n8k4_ld_b(dstptr, src.get(), stride, get_layout_id()); + __dmma_m8n8k4_ld_b(dstPtr, srcPtr, stride, get_layout_id()); } } } @@ -352,50 +357,52 @@ void store_layoutT( T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, multi_ptr dst, size_t stride) { + using DecorT = typename sycl::detail::DecoratedType::type; + DecorT *dstPtr = sycl::detail::getDecorated(dst); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (std::is_same_v) { - __hmma_m16n16k16_st_c_f32(dst.get(), + __hmma_m16n16k16_st_c_f32(dstPtr, reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } else if constexpr (std::is_same_v) { - __imma_m16n16k16_st_c_i32(dst.get(), + __imma_m16n16k16_st_c_i32(dstPtr, reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } else if constexpr (std::is_same_v) { - __hmma_m16n16k16_st_c_f16(reinterpret_cast(dst.get()), + __hmma_m16n16k16_st_c_f16(reinterpret_cast(dstPtr), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 32) { if constexpr (std::is_same_v) { - __hmma_m8n32k16_st_c_f32(dst.get(), + __hmma_m8n32k16_st_c_f32(dstPtr, reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } else if constexpr (std::is_same_v) { - __imma_m8n32k16_st_c_i32(dst.get(), + __imma_m8n32k16_st_c_i32(dstPtr, reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } else if constexpr (std::is_same_v) { - __hmma_m8n32k16_st_c_f16(reinterpret_cast(dst.get()), + __hmma_m8n32k16_st_c_f16(reinterpret_cast(dstPtr), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } } else if constexpr (NumRows == 32 && NumCols == 8) { if constexpr (std::is_same_v) { - __hmma_m32n8k16_st_c_f32(dst.get(), + __hmma_m32n8k16_st_c_f32(dstPtr, reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } else if constexpr (std::is_same_v) { - __imma_m32n8k16_st_c_i32(dst.get(), + __imma_m32n8k16_st_c_i32(dstPtr, reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } else if constexpr (std::is_same_v) { - __hmma_m32n8k16_st_c_f16(reinterpret_cast(dst.get()), + __hmma_m32n8k16_st_c_f16(reinterpret_cast(dstPtr), reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } } else if constexpr (std::is_same_v) { - __dmma_m8n8k4_st_c_f64(dst.get(), + __dmma_m8n8k4_st_c_f64(dstPtr, reinterpret_cast(&src.wi_marray), stride, get_layout_id()); }