diff --git a/sycl/include/sycl/ext/oneapi/matrix/joint-matrix-cuda-impl.hpp b/sycl/include/sycl/ext/oneapi/matrix/joint-matrix-cuda-impl.hpp
new file mode 100644
index 0000000000000..6d289014e1ada
--- /dev/null
+++ b/sycl/include/sycl/ext/oneapi/matrix/joint-matrix-cuda-impl.hpp
@@ -0,0 +1,141 @@
+//===- joint-matrix-cuda-impl.hpp - joint_matrix cuda specializations -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <sycl/ext/oneapi/matrix/joint-matrix.hpp>
+
+namespace sycl {
+__SYCL_INLINE_VER_NAMESPACE(_V1) {
+namespace ext {
+namespace oneapi {
+namespace experimental {
+namespace matrix {
+
+template <typename type, size_t size> class wi_data {
+  marray<type, size> &data;
+  wi_data(marray<type, size> &wi_data) : data(wi_data){};
+  template <typename T, matrix_use Use, size_t Rows, size_t Cols,
+            layout Layout, typename Group, typename Cond>
+  friend struct joint_matrix;
+
+public:
+  size_t length() {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+    return data.size();
+#else
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  };
+
+  type &operator[](size_t i) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+    return data[i];
+#else
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  };
+};
+
+#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(type, use, M, N, size)               \
+  template <layout Layout>                                                    \
+  struct joint_matrix<                                                        \
+      type, matrix_use::use, M, N, Layout, sycl::sub_group,                   \
+      typename std::enable_if_t<Layout == layout::row_major ||                \
+                                Layout == layout::col_major>> {               \
+    marray<type, size> wi_marray;                                             \
+    inline __SYCL_ALWAYS_INLINE wi_data<type, size> get_wi_data() {           \
+      return wi_data<type, size>(wi_marray);                                  \
+    };                                                                        \
+  };
+
+// m8n32k16
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 8, 16, 4)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 32, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 8, 16, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 32, 16)
+
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 8, 16, 4)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 32, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 8, 16, 4)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 32, 16)
+// m32n8k16
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 32, 16, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 8, 4)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 32, 16, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 8, 16)
+
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 32, 16, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 8, 4)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 32, 16, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 8, 4)
+// m16n16k16
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 16, 16, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 16, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 16, 16, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 16, 16)
+
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 16, 16, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 16, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 16, 16, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 16, 8)
+// m8n8k4 double only
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, a, 8, 4, 1)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, b, 4, 8, 1)
+
+#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR
+
+#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(type, M, N, size)                \
+  template <>                                                                 \
+  struct joint_matrix<type, matrix_use::accumulator, M, N, layout::dynamic,   \
+                      sycl::sub_group> {                                      \
+    marray<type, size> wi_marray;                                             \
+    inline __SYCL_ALWAYS_INLINE wi_data<type, size> get_wi_data() {           \
+      return wi_data<type, size>(wi_marray);                                  \
+    };                                                                        \
+  };
+
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 8, 32, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 8, 32, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 8, 32, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 32, 8, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 32, 8, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 32, 8, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(half, 16, 16, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 16, 16, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 16, 16, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(double, 8, 8, 2)
+
+#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC
+
+#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision, use, M, N,      \
+                                                   type, size)                \
+  template <layout Layout>                                                    \
+  struct joint_matrix<                                                        \
+      precision, matrix_use::use, M, N, Layout, sycl::sub_group,              \
+      typename std::enable_if_t<Layout == layout::row_major ||                \
+                                Layout == layout::col_major>> {               \
+    marray<type, size> wi_marray;                                             \
+    inline __SYCL_ALWAYS_INLINE wi_data<type, size> get_wi_data() {           \
+      return wi_data<type, size>(wi_marray);                                  \
+    };                                                                        \
+  };
+// m16n16k8 tf32 only
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision::tf32, a, 16, 8, float, 4)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision::tf32, b, 8, 16, float, 4)
+
+#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION
+
+} // namespace matrix
+} // namespace experimental
+} // namespace oneapi
+} // namespace ext
+} // __SYCL_INLINE_VER_NAMESPACE(_V1)
+} // namespace sycl
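For illustration (not part of the patch): the wi_data view is what device code uses for element-wise work on a fragment. A sketch, assuming a 16x16 float accumulator C of the kind specialized above is in scope inside a kernel:

  auto wi = C.get_wi_data();
  for (size_t i = 0; i < wi.length(); ++i)
    wi[i] = wi[i] * 0.5f; // length() is the per-work-item fragment size, not 16*16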
diff --git a/sycl/include/sycl/ext/oneapi/matrix/joint-matrix.hpp b/sycl/include/sycl/ext/oneapi/matrix/joint-matrix.hpp
new file mode 100644
index 0000000000000..454db696d6e28
--- /dev/null
+++ b/sycl/include/sycl/ext/oneapi/matrix/joint-matrix.hpp
@@ -0,0 +1,43 @@
+//===---- joint-matrix.hpp - SYCL matrix extension joint_matrix -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef JOINT_MATRIX
+#define JOINT_MATRIX
+
+namespace sycl {
+__SYCL_INLINE_VER_NAMESPACE(_V1) {
+namespace ext {
+namespace oneapi {
+namespace experimental {
+namespace matrix {
+
+enum class matrix_use { a, b, accumulator };
+
+enum class layout { row_major, col_major, packed, dynamic };
+
+namespace precision {
+class tf32 {
+  tf32() = delete;
+};
+} // namespace precision
+
+// TODO: how are the default params for Rows/Cols used in Intel backend?
+// TODO: could we use Cond to distinguish between Intel AMX and tensor cores
+// joint_matrix definitions: e.g. Rows == dynamic_extent?
+template <typename T, matrix_use Use, size_t Rows = sycl::dynamic_extent,
+          size_t Cols = sycl::dynamic_extent, layout Layout = layout::dynamic,
+          typename Group = sycl::sub_group, typename Cond = void>
+struct joint_matrix;
+
+} // namespace matrix
+} // namespace experimental
+} // namespace oneapi
+} // namespace ext
+} // __SYCL_INLINE_VER_NAMESPACE(_V1)
+} // namespace sycl
+#endif
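Taken together with the CUDA specializations above, user code in a kernel body declares one sub-group's tiles like this (a usage sketch against one of the supported m16n16k16 shapes; not part of the patch):

  using namespace sycl::ext::oneapi::experimental::matrix;

  joint_matrix<half, matrix_use::a, 16, 16, layout::row_major> A;
  joint_matrix<half, matrix_use::b, 16, 16, layout::col_major> B;
  // Accumulators keep the default layout::dynamic; their memory layout is
  // supplied at load/store time instead.
  joint_matrix<float, matrix_use::accumulator, 16, 16> C;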
diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp
new file mode 100644
index 0000000000000..80d0739189eb4
--- /dev/null
+++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp
@@ -0,0 +1,605 @@
+//===---- matrix-tensor-cores.hpp - SYCL tensor cores matrix ---*- C++ -*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+#include <sycl/ext/oneapi/bfloat16.hpp>
+#include <sycl/ext/oneapi/matrix/joint-matrix-cuda-impl.hpp>
+
+namespace sycl {
+__SYCL_INLINE_VER_NAMESPACE(_V1) {
+namespace ext {
+namespace oneapi {
+namespace detail {
+
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+template <typename S, typename T, size_t NumRows, size_t NumCols,
+          sycl::ext::oneapi::experimental::matrix::matrix_use Use,
+          sycl::ext::oneapi::experimental::matrix::layout Layout,
+          access::address_space Space, typename Cond = void>
+struct load_multiplicand_cuda {
+  void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
+                S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
+            multi_ptr<T, Space> src, size_t stride);
+};
+
+template <sycl::ext::oneapi::experimental::matrix::layout Layout>
+constexpr int get_layout_id();
+
+template <>
+constexpr int
+get_layout_id<sycl::ext::oneapi::experimental::matrix::layout::row_major>() {
+  return 0;
+}
+
+template <>
+constexpr int
+get_layout_id<sycl::ext::oneapi::experimental::matrix::layout::col_major>() {
+  return 1;
+}
+
+#if __cplusplus >= 201703L // if constexpr usage
+template <sycl::ext::oneapi::experimental::matrix::layout Layout, typename T,
+          size_t NumRows, size_t NumCols, access::address_space Space>
+void load_accumulator_layoutT(
+    sycl::ext::oneapi::experimental::matrix::joint_matrix<
+        T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
+        NumRows, NumCols,
+        sycl::ext::oneapi::experimental::matrix::layout::dynamic,
+        sycl::sub_group> &res,
+    multi_ptr<T, Space> src, size_t stride) {
+  if constexpr (std::is_same<std::remove_const_t<T>, int32_t>::value) {
+    auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
+    if constexpr (NumRows == 16 && NumCols == 16) {
+      __imma_m16n16k16_ld_c(destptr, src.get(), stride,
+                            get_layout_id<Layout>());
+    } else if constexpr (NumRows == 8 && NumCols == 32) {
+      __imma_m8n32k16_ld_c(destptr, src.get(), stride,
+                           get_layout_id<Layout>());
+    } else if constexpr (NumRows == 32 && NumCols == 8) {
+      __imma_m32n8k16_ld_c(destptr, src.get(), stride,
+                           get_layout_id<Layout>());
+    }
+  } else if constexpr (std::is_same<std::remove_const_t<T>, float>::value) {
+    auto dstptr = reinterpret_cast<float *>(&res.wi_marray);
+    if constexpr (NumRows == 16 && NumCols == 16) {
+      __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride,
+                                get_layout_id<Layout>());
+    } else if constexpr (NumRows == 8 && NumCols == 32) {
+      __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride,
+                               get_layout_id<Layout>());
+    } else if constexpr (NumRows == 32 && NumCols == 8) {
+      __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride,
+                               get_layout_id<Layout>());
+    }
+  } else if constexpr (std::is_same<std::remove_const_t<T>, half>::value) {
+    auto tileptr = reinterpret_cast<const int32_t *>(src.get());
+    auto dstptr = reinterpret_cast<int32_t *>(&res.wi_marray);
+    if constexpr (NumRows == 32 && NumCols == 8) {
+      __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride,
+                               get_layout_id<Layout>());
+    } else if constexpr (NumRows == 8 && NumCols == 32) {
+      __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride,
+                               get_layout_id<Layout>());
+    } else if constexpr (NumRows == 16 && NumCols == 16) {
+      __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride,
+                                get_layout_id<Layout>());
+    }
+  } else if constexpr (std::is_same<std::remove_const_t<T>, double>::value) {
+    __dmma_m8n8k4_ld_c(reinterpret_cast<double *>(&res.wi_marray), src.get(),
+                       stride, get_layout_id<Layout>());
+  }
+}
+#endif // __cplusplus >= 201703L
+
+template <typename T, size_t NumRows, size_t NumCols,
+          access::address_space Space>
+void load_accumulator_cuda(
+    sycl::ext::oneapi::experimental::matrix::joint_matrix<
+        T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
+        NumRows, NumCols,
+        sycl::ext::oneapi::experimental::matrix::layout::dynamic,
+        sycl::sub_group> &res,
+    multi_ptr<T, Space> src, size_t stride,
+    sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) {
+  switch (LayoutAcc) {
+  case sycl::ext::oneapi::experimental::matrix::layout::row_major:
+    load_accumulator_layoutT<
+        sycl::ext::oneapi::experimental::matrix::layout::row_major>(res, src,
+                                                                    stride);
+    break;
+  case sycl::ext::oneapi::experimental::matrix::layout::col_major:
+    load_accumulator_layoutT<
+        sycl::ext::oneapi::experimental::matrix::layout::col_major>(res, src,
+                                                                    stride);
+    break;
+  default:
+    assert(false && "Invalid layout specified!");
+  }
+}
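load_accumulator_cuda is the bridge from the runtime layout argument to the compile-time Layout parameter that the builtins need. The same pattern in isolation (a self-contained sketch, not part of the patch):

  #include <cassert>

  enum class layout { row_major, col_major, packed, dynamic };

  template <layout L> void load_impl() { /* uses L as a constant expression */ }

  void load(layout acc_layout) { // runtime value...
    switch (acc_layout) {        // ...mapped once to a template argument
    case layout::row_major: load_impl<layout::row_major>(); break;
    case layout::col_major: load_impl<layout::col_major>(); break;
    default: assert(false && "Invalid layout specified!");
    }
  }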
+
+#if __cplusplus >= 201703L // if constexpr usage
+template <typename S, typename T, size_t NumRows, size_t NumCols,
+          sycl::ext::oneapi::experimental::matrix::matrix_use Use,
+          sycl::ext::oneapi::experimental::matrix::layout Layout,
+          access::address_space Space>
+struct load_multiplicand_cuda<
+    S, T, NumRows, NumCols, Use, Layout, Space,
+    typename std::enable_if_t<
+        Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major ||
+        Layout ==
+            sycl::ext::oneapi::experimental::matrix::layout::col_major>> {
+  void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
+                S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
+            multi_ptr<T, Space> src, size_t stride) {
+    if constexpr (std::is_same<std::remove_const_t<T>, uint16_t>::value ||
+                  std::is_same<
+                      std::remove_const_t<T>,
+                      sycl::ext::oneapi::experimental::bfloat16>::value) {
+      auto tileptr = reinterpret_cast<const int32_t *>(src.get());
+      auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
+      if constexpr (NumRows == 16 && NumCols == 16) {
+        if constexpr (Use ==
+                      sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
+          __mma_bf16_m16n16k16_ld_a(destptr, tileptr, stride,
+                                    get_layout_id<Layout>());
+        } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
+                                        matrix_use::b) {
+          __mma_bf16_m16n16k16_ld_b(destptr, tileptr, stride,
+                                    get_layout_id<Layout>());
+        }
+      } else if constexpr (NumRows == 8 && NumCols == 16) {
+        __mma_bf16_m8n32k16_ld_a(destptr, tileptr, stride,
+                                 get_layout_id<Layout>());
+      } else if constexpr (NumRows == 16 && NumCols == 32) {
+        __mma_bf16_m8n32k16_ld_b(destptr, tileptr, stride,
+                                 get_layout_id<Layout>());
+      } else if constexpr (NumRows == 32 && NumCols == 16) {
+        __mma_bf16_m32n8k16_ld_a(destptr, tileptr, stride,
+                                 get_layout_id<Layout>());
+      } else if constexpr (NumRows == 16 && NumCols == 8) {
+        __mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride,
+                                 get_layout_id<Layout>());
+      }
+    } else if constexpr (std::is_same<std::remove_const_t<T>,
+                                      uint8_t>::value) {
+      auto tileptr = reinterpret_cast<const int32_t *>(src.get());
+      auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
+      if constexpr (NumRows == 16 && NumCols == 16) {
+        if constexpr (Use ==
+                      sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
+          __imma_m16n16k16_ld_a_u8(destptr, tileptr, stride,
+                                   get_layout_id<Layout>());
+        } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
+                                        matrix_use::b) {
+          __imma_m16n16k16_ld_b_u8(destptr, tileptr, stride,
+                                   get_layout_id<Layout>());
+        }
+      } else if constexpr (NumRows == 8 && NumCols == 16) {
+        __imma_m8n32k16_ld_a_u8(destptr, tileptr, stride,
+                                get_layout_id<Layout>());
+      } else if constexpr (NumRows == 16 && NumCols == 32) {
+        __imma_m8n32k16_ld_b_u8(destptr, tileptr, stride,
+                                get_layout_id<Layout>());
+      } else if constexpr (NumRows == 32 && NumCols == 16) {
+        __imma_m32n8k16_ld_a_u8(destptr, tileptr, stride,
+                                get_layout_id<Layout>());
+      } else if constexpr (NumRows == 16 && NumCols == 8) {
+        __imma_m32n8k16_ld_b_u8(destptr, tileptr, stride,
+                                get_layout_id<Layout>());
+      }
+    } else if constexpr (std::is_same<std::remove_const_t<T>,
+                                      int8_t>::value) {
+      auto tileptr = reinterpret_cast<const int32_t *>(src.get());
+      auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
+      if constexpr (NumRows == 16 && NumCols == 16) {
+        if constexpr (Use ==
+                      sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
+          __imma_m16n16k16_ld_a_s8(destptr, tileptr, stride,
+                                   get_layout_id<Layout>());
+        } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
+                                        matrix_use::b) {
+          __imma_m16n16k16_ld_b_s8(destptr, tileptr, stride,
+                                   get_layout_id<Layout>());
+        }
+      } else if constexpr (NumRows == 8 && NumCols == 16) {
+        __imma_m8n32k16_ld_a_s8(destptr, tileptr, stride,
+                                get_layout_id<Layout>());
+      } else if constexpr (NumRows == 16 && NumCols == 32) {
+        __imma_m8n32k16_ld_b_s8(destptr, tileptr, stride,
+                                get_layout_id<Layout>());
+      } else if constexpr (NumRows == 32 && NumCols == 16) {
+        __imma_m32n8k16_ld_a_s8(destptr, tileptr, stride,
+                                get_layout_id<Layout>());
+      } else if constexpr (NumRows == 16 && NumCols == 8) {
+        __imma_m32n8k16_ld_b_s8(destptr, tileptr, stride,
+                                get_layout_id<Layout>());
+      }
+    } else if constexpr (std::is_same<std::remove_const_t<T>, half>::value) {
+      auto tileptr = reinterpret_cast<const int32_t *>(src.get());
+      auto dstptr = reinterpret_cast<int32_t *>(&res.wi_marray);
+      if constexpr (NumRows == 16 && NumCols == 16) {
+        if constexpr (Use ==
+                      sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
+          __hmma_m16n16k16_ld_a(dstptr, tileptr, stride,
+                                get_layout_id<Layout>());
+        } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
+                                        matrix_use::b) {
+          __hmma_m16n16k16_ld_b(dstptr, tileptr, stride,
+                                get_layout_id<Layout>());
+        }
+      } else if constexpr (NumRows == 8 && NumCols == 16) {
+        __hmma_m8n32k16_ld_a(dstptr, tileptr, stride, get_layout_id<Layout>());
+      } else if constexpr (NumRows == 16 && NumCols == 32) {
+        __hmma_m8n32k16_ld_b(dstptr, tileptr, stride, get_layout_id<Layout>());
+      } else if constexpr (NumRows == 32 && NumCols == 16) {
+        __hmma_m32n8k16_ld_a(dstptr, tileptr, stride, get_layout_id<Layout>());
+      } else if constexpr (NumRows == 16 && NumCols == 8) {
+        __hmma_m32n8k16_ld_b(dstptr, tileptr, stride, get_layout_id<Layout>());
+      }
+    } else if constexpr (std::is_same<
+                             S, sycl::ext::oneapi::experimental::matrix::
+                                    precision::tf32>::value) {
+      auto tileptr = reinterpret_cast<const int32_t *>(src.get());
+      auto dstptr = reinterpret_cast<int32_t *>(&res.wi_marray);
+      if constexpr (NumRows == 16 && NumCols == 8) {
+        __mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride,
+                                 get_layout_id<Layout>());
+      } else if constexpr (NumRows == 8 && NumCols == 16) {
+        __mma_tf32_m16n16k8_ld_b(dstptr, tileptr, stride,
+                                 get_layout_id<Layout>());
+      }
+    } else if constexpr (std::is_same<std::remove_const_t<T>,
+                                      double>::value) {
+      auto dstptr = reinterpret_cast<double *>(&res.wi_marray);
+      if constexpr (Use ==
+                    sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
+        __dmma_m8n8k4_ld_a(dstptr, src.get(), stride,
+                           get_layout_id<Layout>());
+      } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
+                                      matrix_use::b) {
+        __dmma_m8n8k4_ld_b(dstptr, src.get(), stride,
+                           get_layout_id<Layout>());
+      }
+    }
+  }
+};
+#endif // __cplusplus >= 201703L
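Note the division of labour between S and T in this specialization: S is the joint_matrix element type and T the in-memory type. They coincide except for tf32, where S is the tag precision::tf32 and T is float. Declaration sketch (not part of the patch):

  using namespace sycl::ext::oneapi::experimental::matrix;

  // m16n16k8 tf32 multiplicands: the fragment storage is marray<float, 4>
  // (see joint-matrix-cuda-impl.hpp), and loads take float pointers.
  joint_matrix<precision::tf32, matrix_use::a, 16, 8, layout::row_major> A;
  joint_matrix<precision::tf32, matrix_use::b, 8, 16, layout::row_major> B;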
+
+#if __cplusplus >= 201703L // if constexpr usage
+template <typename T, size_t NumRows, size_t NumCols,
+          access::address_space Space>
+struct joint_matrix_store_cuda_impl {
+  template <sycl::ext::oneapi::experimental::matrix::layout Layout>
+  void storeLayoutT(
+      sycl::ext::oneapi::experimental::matrix::joint_matrix<
+          T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
+          NumRows, NumCols,
+          sycl::ext::oneapi::experimental::matrix::layout::dynamic,
+          sycl::sub_group> &src,
+      multi_ptr<T, Space> dst, size_t stride) {
+    if constexpr (NumRows == 16 && NumCols == 16) {
+      if constexpr (std::is_same<T, float>::value) {
+        __hmma_m16n16k16_st_c_f32(dst.get(),
+                                  reinterpret_cast<float *>(&src.wi_marray),
+                                  stride, get_layout_id<Layout>());
+      } else if constexpr (std::is_same<T, int32_t>::value) {
+        __imma_m16n16k16_st_c_i32(dst.get(),
+                                  reinterpret_cast<int32_t *>(&src.wi_marray),
+                                  stride, get_layout_id<Layout>());
+      } else if constexpr (std::is_same<T, half>::value) {
+        __hmma_m16n16k16_st_c_f16(reinterpret_cast<int32_t *>(dst.get()),
+                                  reinterpret_cast<int32_t *>(&src.wi_marray),
+                                  stride, get_layout_id<Layout>());
+      }
+    } else if constexpr (NumRows == 8 && NumCols == 32) {
+      if constexpr (std::is_same<T, float>::value) {
+        __hmma_m8n32k16_st_c_f32(dst.get(),
+                                 reinterpret_cast<float *>(&src.wi_marray),
+                                 stride, get_layout_id<Layout>());
+      } else if constexpr (std::is_same<T, int32_t>::value) {
+        __imma_m8n32k16_st_c_i32(dst.get(),
+                                 reinterpret_cast<int32_t *>(&src.wi_marray),
+                                 stride, get_layout_id<Layout>());
+      } else if constexpr (std::is_same<T, half>::value) {
+        __hmma_m8n32k16_st_c_f16(reinterpret_cast<int32_t *>(dst.get()),
+                                 reinterpret_cast<int32_t *>(&src.wi_marray),
+                                 stride, get_layout_id<Layout>());
+      }
+    } else if constexpr (NumRows == 32 && NumCols == 8) {
+      if constexpr (std::is_same<T, float>::value) {
+        __hmma_m32n8k16_st_c_f32(dst.get(),
+                                 reinterpret_cast<float *>(&src.wi_marray),
+                                 stride, get_layout_id<Layout>());
+      } else if constexpr (std::is_same<T, int32_t>::value) {
+        __imma_m32n8k16_st_c_i32(dst.get(),
+                                 reinterpret_cast<int32_t *>(&src.wi_marray),
+                                 stride, get_layout_id<Layout>());
+      } else if constexpr (std::is_same<T, half>::value) {
+        __hmma_m32n8k16_st_c_f16(reinterpret_cast<int32_t *>(dst.get()),
+                                 reinterpret_cast<int32_t *>(&src.wi_marray),
+                                 stride, get_layout_id<Layout>());
+      }
+    } else if constexpr (std::is_same<T, double>::value) {
+      __dmma_m8n8k4_st_c_f64(dst.get(),
+                             reinterpret_cast<double *>(&src.wi_marray),
+                             stride, get_layout_id<Layout>());
+    }
+  }
+  void store(
+      sycl::ext::oneapi::experimental::matrix::joint_matrix<
+          T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
+          NumRows, NumCols,
+          sycl::ext::oneapi::experimental::matrix::layout::dynamic,
+          sycl::sub_group> &src,
+      multi_ptr<T, Space> dst, size_t stride,
+      sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) {
+    switch (LayoutAcc) {
+    case sycl::ext::oneapi::experimental::matrix::layout::row_major:
+      storeLayoutT<
+          sycl::ext::oneapi::experimental::matrix::layout::row_major>(
+          src, dst, stride);
+      break;
+    case sycl::ext::oneapi::experimental::matrix::layout::col_major:
+      storeLayoutT<
+          sycl::ext::oneapi::experimental::matrix::layout::col_major>(
+          src, dst, stride);
+      break;
+    default:
+      assert(false && "Invalid layout specified!");
+    }
+  }
+};
+#endif // __cplusplus >= 201703L
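From the user's side this implementation is reached through the joint_matrix_store entry point in matrix-unified.hpp further down; the accumulator's memory layout is a runtime argument rather than part of its type. Call sketch, where C is a 16x16 float accumulator, out is a multi_ptr<float> tile base, and ld the leading dimension (all names illustrative):

  joint_matrix_store(sg, C, out, ld, layout::row_major);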
+
+template <typename Tm, typename Tc, typename Td, std::size_t M, std::size_t K,
+          std::size_t N,
+          sycl::ext::oneapi::experimental::matrix::layout LayoutA,
+          sycl::ext::oneapi::experimental::matrix::layout LayoutB,
+          typename Cond = void>
+struct joint_matrix_mad_cuda_impl {
+  void
+  mad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
+          Td, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
+          M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
+          sycl::sub_group> &D,
+      sycl::ext::oneapi::experimental::matrix::joint_matrix<
+          Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
+          LayoutA, sycl::sub_group> &A,
+      sycl::ext::oneapi::experimental::matrix::joint_matrix<
+          Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
+          LayoutB, sycl::sub_group> &B,
+      sycl::ext::oneapi::experimental::matrix::joint_matrix<
+          Tc, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
+          M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
+          sycl::sub_group> &C);
+};
+
+template <sycl::ext::oneapi::experimental::matrix::layout LayoutA,
+          sycl::ext::oneapi::experimental::matrix::layout LayoutB>
+constexpr int get_layout_pair_id();
+
+template <>
+constexpr int get_layout_pair_id<
+    sycl::ext::oneapi::experimental::matrix::layout::row_major,
+    sycl::ext::oneapi::experimental::matrix::layout::row_major>() {
+  return 0;
+}
+
+template <>
+constexpr int get_layout_pair_id<
+    sycl::ext::oneapi::experimental::matrix::layout::row_major,
+    sycl::ext::oneapi::experimental::matrix::layout::col_major>() {
+  return 1;
+}
+
+template <>
+constexpr int get_layout_pair_id<
+    sycl::ext::oneapi::experimental::matrix::layout::col_major,
+    sycl::ext::oneapi::experimental::matrix::layout::row_major>() {
+  return 2;
+}
+
+template <>
+constexpr int get_layout_pair_id<
+    sycl::ext::oneapi::experimental::matrix::layout::col_major,
+    sycl::ext::oneapi::experimental::matrix::layout::col_major>() {
+  return 3;
+}
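The pair encoding follows the integer convention of the NVPTX wmma builtins. A compile-time check of the mapping (a test sketch; only well-formed in NVPTX device compilation, where these helpers are defined):

  using sycl::ext::oneapi::detail::get_layout_pair_id;
  using sycl::ext::oneapi::experimental::matrix::layout;

  static_assert(get_layout_pair_id<layout::row_major, layout::row_major>() == 0);
  static_assert(get_layout_pair_id<layout::row_major, layout::col_major>() == 1);
  static_assert(get_layout_pair_id<layout::col_major, layout::row_major>() == 2);
  static_assert(get_layout_pair_id<layout::col_major, layout::col_major>() == 3);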
+
+#if __cplusplus >= 201703L // if constexpr usage
+template <typename Tm, typename Tc, typename Td, std::size_t M, std::size_t K,
+          std::size_t N,
+          sycl::ext::oneapi::experimental::matrix::layout LayoutA,
+          sycl::ext::oneapi::experimental::matrix::layout LayoutB>
+struct joint_matrix_mad_cuda_impl<
+    Tm, Tc, Td, M, K, N, LayoutA, LayoutB,
+    typename std::enable_if_t<
+        (LayoutA ==
+             sycl::ext::oneapi::experimental::matrix::layout::row_major ||
+         LayoutA ==
+             sycl::ext::oneapi::experimental::matrix::layout::col_major) &&
+        (LayoutB ==
+             sycl::ext::oneapi::experimental::matrix::layout::row_major ||
+         LayoutB ==
+             sycl::ext::oneapi::experimental::matrix::layout::col_major)>> {
+  void
+  mad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
+          Td, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
+          M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
+          sycl::sub_group> &D,
+      sycl::ext::oneapi::experimental::matrix::joint_matrix<
+          Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
+          LayoutA, sycl::sub_group> &A,
+      sycl::ext::oneapi::experimental::matrix::joint_matrix<
+          Tm, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
+          LayoutB, sycl::sub_group> &B,
+      sycl::ext::oneapi::experimental::matrix::joint_matrix<
+          Tc, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
+          M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
+          sycl::sub_group> &C) {
+    if constexpr (M == 16 && N == 16 && K == 16) {
+      if constexpr (std::is_same<Tc, int32_t>::value) {
+        auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
+        auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
+        auto ptrC = reinterpret_cast<const int32_t *>(&C.wi_marray);
+        auto ptrD = reinterpret_cast<int32_t *>(&D.wi_marray);
+        if constexpr (std::is_same<Tm, int8_t>::value) {
+          __imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
+                                  get_layout_pair_id<LayoutA, LayoutB>(), 0);
+        } else if constexpr (std::is_same<Tm, uint8_t>::value) {
+          __imma_m16n16k16_mma_u8(ptrD, ptrA, ptrB, ptrC,
+                                  get_layout_pair_id<LayoutA, LayoutB>(), 0);
+        }
+      } else if constexpr (std::is_same<Tm, half>::value) {
+        auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
+        auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
+        if constexpr (std::is_same<Tc, float>::value) {
+          if constexpr (std::is_same<Td, float>::value) {
+            __hmma_m16n16k16_mma_f32f32(
+                reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
+                reinterpret_cast<const float *>(&C.wi_marray),
+                get_layout_pair_id<LayoutA, LayoutB>(), 0);
+          } else {
+            __hmma_m16n16k16_mma_f16f32(
+                reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
+                reinterpret_cast<const float *>(&C.wi_marray),
+                get_layout_pair_id<LayoutA, LayoutB>(), 0);
+          }
+        } else if constexpr (std::is_same<Tc, half>::value) {
+          if constexpr (std::is_same<Td, float>::value) {
+            __hmma_m16n16k16_mma_f32f16(
+                reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
+                reinterpret_cast<const int32_t *>(&C.wi_marray),
+                get_layout_pair_id<LayoutA, LayoutB>(), 0);
+          } else {
+            __hmma_m16n16k16_mma_f16f16(
+                reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
+                reinterpret_cast<const int32_t *>(&C.wi_marray),
+                get_layout_pair_id<LayoutA, LayoutB>(), 0);
+          }
+        }
+      } else if constexpr (std::is_same<Tm, uint16_t>::value ||
+                           std::is_same<
+                               Tm,
+                               sycl::ext::oneapi::experimental::bfloat16>::
+                               value) {
+        __mma_bf16_m16n16k16_mma_f32(
+            reinterpret_cast<float *>(&D.wi_marray),
+            reinterpret_cast<const int32_t *>(&A.wi_marray),
+            reinterpret_cast<const int32_t *>(&B.wi_marray),
+            reinterpret_cast<const float *>(&C.wi_marray),
+            get_layout_pair_id<LayoutA, LayoutB>(), 0);
+      }
+    } else if constexpr (M == 8 && N == 32 && K == 16) {
+      if constexpr (std::is_same<Tc, int32_t>::value) {
+        auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
+        auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
+        auto ptrC = reinterpret_cast<const int32_t *>(&C.wi_marray);
+        auto ptrD = reinterpret_cast<int32_t *>(&D.wi_marray);
+        if constexpr (std::is_same<Tm, int8_t>::value) {
+          __imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
+                                 get_layout_pair_id<LayoutA, LayoutB>(), 0);
+        } else if constexpr (std::is_same<Tm, uint8_t>::value) {
+          __imma_m8n32k16_mma_u8(ptrD, ptrA, ptrB, ptrC,
+                                 get_layout_pair_id<LayoutA, LayoutB>(), 0);
+        }
+      } else if constexpr (std::is_same<Tm, half>::value) {
+        auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
+        auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
+        if constexpr (std::is_same<Tc, float>::value) {
+          __hmma_m8n32k16_mma_f32f32(
+              reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
+              reinterpret_cast<const float *>(&C.wi_marray),
+              get_layout_pair_id<LayoutA, LayoutB>(), 0);
+        } else if constexpr (std::is_same<Tc, half>::value) {
+          __hmma_m8n32k16_mma_f16f16(
+              reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
+              reinterpret_cast<const int32_t *>(&C.wi_marray),
+              get_layout_pair_id<LayoutA, LayoutB>(), 0);
+        }
+      } else if constexpr (std::is_same<Tm, uint16_t>::value ||
+                           std::is_same<
+                               Tm,
+                               sycl::ext::oneapi::experimental::bfloat16>::
+                               value) {
+        __mma_bf16_m8n32k16_mma_f32(
+            reinterpret_cast<float *>(&D.wi_marray),
+            reinterpret_cast<const int32_t *>(&A.wi_marray),
+            reinterpret_cast<const int32_t *>(&B.wi_marray),
+            reinterpret_cast<const float *>(&C.wi_marray),
+            get_layout_pair_id<LayoutA, LayoutB>(), 0);
+      }
+    } else if constexpr (M == 32 && N == 8 && K == 16) {
+      if constexpr (std::is_same<Tc, int32_t>::value) {
+        auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
+        auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
+        auto ptrC = reinterpret_cast<const int32_t *>(&C.wi_marray);
+        auto ptrD = reinterpret_cast<int32_t *>(&D.wi_marray);
+        if constexpr (std::is_same<Tm, int8_t>::value) {
+          __imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
+                                 get_layout_pair_id<LayoutA, LayoutB>(), 0);
+        } else if constexpr (std::is_same<Tm, uint8_t>::value) {
+          __imma_m32n8k16_mma_u8(ptrD, ptrA, ptrB, ptrC,
+                                 get_layout_pair_id<LayoutA, LayoutB>(), 0);
+        }
+      } else if constexpr (std::is_same<Tm, uint16_t>::value ||
+                           std::is_same<
+                               Tm,
+                               sycl::ext::oneapi::experimental::bfloat16>::
+                               value) {
+        __mma_bf16_m32n8k16_mma_f32(
+            reinterpret_cast<float *>(&D.wi_marray),
+            reinterpret_cast<const int32_t *>(&A.wi_marray),
+            reinterpret_cast<const int32_t *>(&B.wi_marray),
+            reinterpret_cast<const float *>(&C.wi_marray),
+            get_layout_pair_id<LayoutA, LayoutB>(), 0);
+      } else if constexpr (std::is_same<Tm, half>::value) {
+        auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
+        auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
+        if constexpr (std::is_same<Tc, float>::value) {
+          __hmma_m32n8k16_mma_f32f32(
+              reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
+              reinterpret_cast<const float *>(&C.wi_marray),
+              get_layout_pair_id<LayoutA, LayoutB>(), 0);
+        } else if constexpr (std::is_same<Tc, half>::value) {
+          __hmma_m32n8k16_mma_f16f16(
+              reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
+              reinterpret_cast<const int32_t *>(&C.wi_marray),
+              get_layout_pair_id<LayoutA, LayoutB>(), 0);
+        }
+      }
+    } else if constexpr (M == 16 && N == 16 && K == 8) {
+      __mma_tf32_m16n16k8_mma_f32(
+          reinterpret_cast<float *>(&D.wi_marray),
+          reinterpret_cast<const int32_t *>(&A.wi_marray),
+          reinterpret_cast<const int32_t *>(&B.wi_marray),
+          reinterpret_cast<const float *>(&C.wi_marray),
+          get_layout_pair_id<LayoutA, LayoutB>(), 0);
+    } else if constexpr (std::is_same<Tm, double>::value) {
+      __dmma_m8n8k4_mma_f64(reinterpret_cast<double *>(&D.wi_marray),
+                            reinterpret_cast<const double *>(&A.wi_marray),
+                            reinterpret_cast<const double *>(&B.wi_marray),
+                            reinterpret_cast<const double *>(&C.wi_marray),
+                            get_layout_pair_id<LayoutA, LayoutB>(), 0);
+    }
+  }
+};
+#endif // __cplusplus >= 201703L
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+
+} // namespace detail
+} // namespace oneapi
+} // namespace ext
+} // __SYCL_INLINE_VER_NAMESPACE(_V1)
+} // namespace sycl
+
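Tracing one path through the dispatch above: for M = N = K = 16 with Tm = half, Tc = Td = float, and row-major A and B, mad() collapses to a single builtin call (an illustration of the code above, not additional API; ptrD/ptrA/ptrB/ptrC stand for the reinterpreted wi_marray storage of D, A, B, C):

  // The first 0 is get_layout_pair_id<row_major, row_major>().
  __hmma_m16n16k16_mma_f32f32(ptrD, ptrA, ptrB, ptrC, 0, 0);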
diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores-legacy.hpp
similarity index 99%
rename from sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
rename to sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores-legacy.hpp
index a0796291930e1..c8fee4b8cfe51 100644
--- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
+++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores-legacy.hpp
@@ -786,3 +786,4 @@ inline __SYCL_ALWAYS_INLINE float round_to_tf32(float a) {
 } // namespace ext
 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
 } // namespace sycl
+
diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp
new file mode 100644
index 0000000000000..1310cbf92ca2d
--- /dev/null
+++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp
@@ -0,0 +1,197 @@
+//===------- matrix-unified.hpp - SYCL matrix extension ------*- C++ -*----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+#include <sycl/ext/oneapi/matrix/joint-matrix.hpp>
+#include <sycl/ext/oneapi/matrix/matrix-tensor-cores.hpp>
+// #include
+
+namespace sycl {
+__SYCL_INLINE_VER_NAMESPACE(_V1) {
+namespace ext {
+namespace oneapi {
+namespace experimental {
+namespace matrix {
+
+template <typename Group, typename T, typename T2, matrix_use Use,
+          size_t NumRows, size_t NumCols, layout Layout>
+inline __SYCL_ALWAYS_INLINE void
+joint_matrix_fill(Group sg,
+                  joint_matrix<T, Use, NumRows, NumCols, Layout, Group> &res,
+                  const T2 v) {
+  // We kept the dynamic "sg" in joint_matrix_fill to match the other DPC++
+  // functions
+  std::ignore = sg;
+#if defined(__SYCL_DEVICE_ONLY__)
+#if defined(__NVPTX__)
+  res.wi_marray = v;
+#elif defined(__SPIR__)
+  // intel joint_matrix_fill_intel_impl
+#endif // defined(__NVPTX__)
+#else
+  std::ignore = res;
+  std::ignore = v;
+  throw runtime_error("The matrix extension is only currently supported on "
+                      "Intel and Nvidia devices",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__)
+}
+
+// TODO: Two typenames (S and T) are not required in the CUDA backend but are
+// included for an Intel backend requirement. TODO: check if this is still the
+// case!!
+template <typename Group, typename S, typename T, size_t NumRows,
+          size_t NumCols, access::address_space Space,
+          std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
+                           bool> = true>
+void joint_matrix_load(
+    Group sg,
+    joint_matrix<S, matrix_use::accumulator, NumRows, NumCols,
+                 layout::dynamic, Group> &res,
+    multi_ptr<T, Space> src, size_t stride,
+    sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) {
+#if defined(__SYCL_DEVICE_ONLY__)
+#if defined(__NVPTX__)
+  sycl::ext::oneapi::detail::load_accumulator_cuda(res, src, stride,
+                                                   LayoutAcc);
+#elif defined(__SPIR__)
+  // load_accumulator_intel
+#endif // defined(__NVPTX__)
+#else
+  std::ignore = sg;
+  std::ignore = res;
+  std::ignore = src;
+  std::ignore = stride;
+  std::ignore = LayoutAcc;
+  throw runtime_error("The matrix extension is only currently supported on "
+                      "Intel and Nvidia devices",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__)
+}
+
+template <
+    typename Group, typename S, typename T, matrix_use Use, size_t NumRows,
+    size_t NumCols, matrix::layout Layout, access::address_space Space,
+    std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
+                         (std::is_same<S, precision::tf32>::value &&
+                          std::is_same<std::remove_const_t<T>, float>::value),
+                     bool> = true>
+void joint_matrix_load(
+    Group sg, joint_matrix<S, Use, NumRows, NumCols, Layout, Group> &res,
+    multi_ptr<T, Space> src, size_t stride) {
+#if defined(__SYCL_DEVICE_ONLY__)
+#if defined(__NVPTX__)
+  sycl::ext::oneapi::detail::load_multiplicand_cuda<S, T, NumRows, NumCols,
+                                                    Use, Layout, Space>{}
+      .load(res, src, stride);
+#elif defined(__SPIR__)
+  // load_multiplicand_intel
+#endif // defined(__NVPTX__)
+#else
+  std::ignore = sg;
+  std::ignore = res;
+  std::ignore = src;
+  std::ignore = stride;
+  throw runtime_error("The matrix extension is only currently supported on "
+                      "Intel and Nvidia devices",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__)
+}
+
+template <typename Group, typename T, size_t NumRows, size_t NumCols,
+          access::address_space Space>
+void joint_matrix_store(
+    Group sg,
+    joint_matrix<T, matrix_use::accumulator, NumRows, NumCols,
+                 layout::dynamic, Group> &src,
+    multi_ptr<T, Space> dst, size_t stride,
+    sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) {
+#if defined(__SYCL_DEVICE_ONLY__)
+#if defined(__NVPTX__)
+  sycl::ext::oneapi::detail::joint_matrix_store_cuda_impl<T, NumRows, NumCols,
+                                                          Space>{}
+      .store(src, dst, stride, LayoutAcc);
+#elif defined(__SPIR__)
+  // joint_matrix_store_intel_impl
+#endif // defined(__NVPTX__)
+#else
+  std::ignore = sg;
+  std::ignore = src;
+  std::ignore = dst;
+  std::ignore = stride;
+  std::ignore = LayoutAcc;
+  throw runtime_error("The matrix extension is only currently supported on "
+                      "Intel and Nvidia devices",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__)
+}
+
+template <
+    typename Group, typename Ta, typename Tb, typename Tc, std::size_t M,
+    std::size_t K, std::size_t N,
+    layout LayoutA = sycl::ext::oneapi::experimental::matrix::layout::dynamic,
+    layout LayoutB = sycl::ext::oneapi::experimental::matrix::layout::dynamic>
+joint_matrix<Tc, matrix_use::accumulator, M, N, layout::dynamic, Group>
+joint_matrix_mad(
+    Group sg, joint_matrix<Ta, matrix_use::a, M, K, LayoutA, Group> &A,
+    joint_matrix<Tb, matrix_use::b, K, N, LayoutB, Group> &B,
+    joint_matrix<Tc, matrix_use::accumulator, M, N, layout::dynamic, Group>
+        &C) {
+#if defined(__SYCL_DEVICE_ONLY__)
+#if defined(__NVPTX__)
+  if constexpr (std::is_same<Ta, Tb>::value) {
+    joint_matrix<Tc, matrix_use::accumulator, M, N, layout::dynamic, Group> D;
+    sycl::ext::oneapi::detail::joint_matrix_mad_cuda_impl<Ta, Tc, Tc, M, K, N,
+                                                          LayoutA, LayoutB>{}
+        .mad(D, A, B, C);
+    return D;
+  } else {
+    assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad "
+                    "requires that joint_matrix data types match");
+  }
+#elif defined(__SPIR__)
+  // joint_matrix_mad_intel_impl
+#endif // defined(__NVPTX__)
+#else
+  std::ignore = sg;
+  std::ignore = A;
+  std::ignore = B;
+  std::ignore = C;
+  throw runtime_error("The matrix extension is only currently supported on "
+                      "Intel and Nvidia devices",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__)
+}
+
+// This function rounds the bottom 13 bits up or down, and then zeros out the
+// bottom bits
+inline __SYCL_ALWAYS_INLINE float round_to_tf32(float a) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  int32_t tmp_int = __nvvm_f2tf32_rna(a);
+  return __nvvm_bitcast_i2f(tmp_int);
+#else
+  uint32_t tmp_uint = reinterpret_cast<uint32_t &>(a);
+  tmp_uint += 0x1000u;
+  tmp_uint &= 0xFFFFE000u;
+  float ret = reinterpret_cast<float &>(tmp_uint);
+  return ret;
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+} // namespace matrix
+} // namespace experimental
+} // namespace oneapi
+} // namespace ext
+} // __SYCL_INLINE_VER_NAMESPACE(_V1)
+} // namespace sycl
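A minimal end-to-end sketch of the unified front end on the CUDA backend (selected via the matrix.hpp change at the end of this patch, i.e. compiled with -DSYCL_EXT_ONEAPI_MATRIX=4 and -fsycl-targets=nvptx64-nvidia-cuda; the queue and buffer names are illustrative, and one 32-work-item sub-group computes a single 16x16 tile):

  #include <sycl/sycl.hpp>
  using namespace sycl;
  using namespace sycl::ext::oneapi::experimental::matrix;

  void tile_mma(queue &q, buffer<half, 1> &bufA, buffer<half, 1> &bufB,
                buffer<float, 1> &bufC) {
    q.submit([&](handler &cgh) {
      accessor accA{bufA, cgh, read_only};
      accessor accB{bufB, cgh, read_only};
      accessor accC{bufC, cgh, read_write};
      cgh.parallel_for(nd_range<2>({1, 32}, {1, 32}), [=](nd_item<2> item) {
        sub_group sg = item.get_sub_group();
        joint_matrix<half, matrix_use::a, 16, 16, layout::row_major> A;
        joint_matrix<half, matrix_use::b, 16, 16, layout::row_major> B;
        joint_matrix<float, matrix_use::accumulator, 16, 16> C;
        joint_matrix_fill(sg, C, 0.0f);
        joint_matrix_load(sg, A, accA.get_pointer(), 16);
        joint_matrix_load(sg, B, accB.get_pointer(), 16);
        C = joint_matrix_mad(sg, A, B, C);
        joint_matrix_store(sg, C, accC.get_pointer(), 16, layout::row_major);
      });
    });
  }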
diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp
index ecfad58259cc2..53f2a7db132aa 100644
--- a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp
+++ b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp
@@ -21,11 +21,12 @@
 #if (SYCL_EXT_ONEAPI_MATRIX == 1)
 #include <sycl/ext/oneapi/matrix/matrix-jit.hpp>
 #include <sycl/ext/oneapi/matrix/static-query.hpp>
-#endif
-#if (SYCL_EXT_ONEAPI_MATRIX == 2)
+#elif (SYCL_EXT_ONEAPI_MATRIX == 2)
 #include <sycl/ext/oneapi/matrix/matrix-jit-use.hpp>
 #include <sycl/ext/oneapi/matrix/static-query-use.hpp>
+#elif (SYCL_EXT_ONEAPI_MATRIX == 3)
+#include <sycl/ext/oneapi/matrix/matrix-tensorcores-legacy.hpp>
+#elif (SYCL_EXT_ONEAPI_MATRIX == 4)
+#include <sycl/ext/oneapi/matrix/matrix-unified.hpp>
 #endif
-#if (SYCL_EXT_ONEAPI_MATRIX == 3)
-#include <sycl/ext/oneapi/matrix/matrix-tensorcore.hpp>
-#endif
+
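The host fallback in round_to_tf32 above implements round-to-nearest on the 13 mantissa bits that tf32 drops. The same arithmetic as a standalone, runnable illustration (memcpy replaces the header's reinterpret_cast to sidestep strict aliasing; not part of the patch):

  #include <cstdint>
  #include <cstring>
  #include <iostream>

  float host_round_to_tf32(float a) {
    uint32_t bits;
    std::memcpy(&bits, &a, sizeof(bits));
    bits += 0x1000u;     // round: add half of the dropped range (1 << 12)
    bits &= 0xFFFFE000u; // truncate: zero the bottom 13 mantissa bits
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
  }

  int main() {
    // 1.0f + 2^-12 sits below the rounding midpoint and snaps back to 1.0f.
    std::cout << host_round_to_tf32(1.000244140625f) << "\n"; // prints 1
  }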