diff --git a/sycl/include/CL/sycl.hpp b/sycl/include/CL/sycl.hpp
index 383725dcc4b88..bd4a62cbdc947 100644
--- a/sycl/include/CL/sycl.hpp
+++ b/sycl/include/CL/sycl.hpp
@@ -13,6 +13,7 @@
 #include
 #include
 #include
+#include <CL/sycl/ONEAPI/intel_matrix/matrix.hpp>
 #include
 #include
 #include
diff --git a/sycl/include/CL/sycl/ONEAPI/intel_matrix/matrix-amx.hpp b/sycl/include/CL/sycl/ONEAPI/intel_matrix/matrix-amx.hpp
new file mode 100644
index 0000000000000..73a99876e9564
--- /dev/null
+++ b/sycl/include/CL/sycl/ONEAPI/intel_matrix/matrix-amx.hpp
@@ -0,0 +1,446 @@
+//===-------------- matrix-amx.hpp - SYCL matrix --------------*- C++ -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// ===--------------------------------------------------------------------=== //
+///
+/// We provide new interfaces for matrix multiply in this patch:
+/// 1. A new class called joint_matrix is introduced, and the user needs to
+/// specify the type of the elements, the sizes, and the memory layout.
+///
+/// 2. joint_matrix_load is used for loading data from main memory to AMX
+/// tiles or the kernel's local memory.
+///
+/// 3. joint_matrix_store is used for storing data from AMX tiles or the
+/// kernel's local memory back to main memory.
+///
+/// 4. joint_matrix_mad is used for the matrix multiply and add function.
+/// It performs the multiply operation on the matrices A and B, accumulates
+/// the result with C, and returns the result.
+///
+/// The following operation can be realized with these interfaces:
+/// C = A*B+C
+/// 1. All cases where A(int8, any-size, row_major), B(int8, any-size,
+/// packed_b), C(int32, any-size, row_major)
+/// 2. All cases where A(bf16, any-size, row_major), B(bf16, any-size,
+/// packed_b), C(float, any-size, row_major)
+///
+///
+// ===--------------------------------------------------------------------=== //
+
+#pragma once
+
+#include
+#include
+
+__SYCL_INLINE_NAMESPACE(cl) {
+namespace sycl {
+namespace ext {
+namespace intel {
+namespace detail {
+template class submatrix {
+public:
+  _tile1024i tile;
+  short rows, cols;
+};
+
+constexpr size_t dynamic_extent = std::numeric_limits<size_t>::max();
+
+template struct elems_per_dword {
+  static constexpr size_t value = 1;
+};
+
+#define ELEMS_PER_DWORD(TYPE, NUM)                                             \
+  template <> struct elems_per_dword<TYPE> {                                   \
+    static constexpr size_t value = NUM;                                       \
+  };
+
+ELEMS_PER_DWORD(int8_t, 4)
+ELEMS_PER_DWORD(unsigned short, 2)
+
+} // namespace detail
+
+namespace matrix {
+using namespace cl::sycl;
+using namespace cl::sycl::ONEAPI;
+
+#ifdef __SYCL_DEVICE_ONLY__
+SYCL_EXTERNAL extern "C" _tile1024i
+_tileloadd64_internal(short row, short col, char *buf, size_t stride);
+SYCL_EXTERNAL extern "C" _tile1024i
+_tdpbssd_internal(unsigned short m, unsigned short n, unsigned short k,
+                  _tile1024i dst, _tile1024i src1, _tile1024i src2);
+SYCL_EXTERNAL extern "C" _tile1024i
+_tdpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k,
+                    _tile1024i dst, _tile1024i src1, _tile1024i src2);
+SYCL_EXTERNAL extern "C" void _tilestored64_internal(short row, short col,
+                                                     char *buf, size_t stride,
+                                                     _tile1024i tile);
+static _tile1024i tileloadd64_internal(short row, short col, char *buf,
+                                       size_t stride) {
+  return _tileloadd64_internal(row, col, buf, stride);
+}
+static _tile1024i tdpbssd_internal(unsigned short m, unsigned short n,
+                                   unsigned short k, _tile1024i dst,
+                                   _tile1024i src1, _tile1024i src2) {
+  return _tdpbssd_internal(m, n, k, dst, src1, src2);
+}
+static _tile1024i tdpbf16ps_internal(unsigned short m, unsigned short n,
+                                     unsigned short k, _tile1024i dst,
+                                     _tile1024i src1, _tile1024i src2) {
+  return _tdpbf16ps_internal(m, n, k, dst, src1, src2);
+}
+static void tilestored64_internal(short row, short col, char *buf,
+                                  size_t stride, _tile1024i tile) {
+  return _tilestored64_internal(row, col, buf, stride, tile);
+}
+#else
+static _tile1024i tileloadd64_internal(short row, short col, char *buf,
+                                       size_t stride) {
+  return __builtin_ia32_tileloadd64_internal(row, col, buf, stride);
+}
+static _tile1024i tdpbssd_internal(unsigned short m, unsigned short n,
+                                   unsigned short k, _tile1024i dst,
+                                   _tile1024i src1, _tile1024i src2) {
+  return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2);
+}
+static _tile1024i tdpbf16ps_internal(unsigned short m, unsigned short n,
+                                     unsigned short k, _tile1024i dst,
+                                     _tile1024i src1, _tile1024i src2) {
+  return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2);
+}
+static void tilestored64_internal(short row, short col, char *buf,
+                                  size_t stride, _tile1024i tile) {
+  __builtin_ia32_tilestored64_internal(row, col, buf, stride, tile);
+}
+#endif
+
+enum class matrix_layout { row_major, col_major, packed_a, packed_b };
+
+inline constexpr size_t tile_size = 16;
+
+template
+struct joint_matrix {
+  joint_matrix(Group sg) {}
+  joint_matrix(Group sg, size_t Size) {
+    static_assert((NumRows != detail::dynamic_extent &&
+                   NumCols != detail::dynamic_extent),
+                  "AMX implementation does not support dynamic allocation");
+  }
+  joint_matrix(Group sg, size_t Rows, size_t Cols) {
+    static_assert((NumRows != detail::dynamic_extent &&
+                   NumCols != detail::dynamic_extent),
+                  "AMX implementation does not support dynamic allocation");
+  }
+};
+
+// This template specialization handles cases where the matrix cannot be
+// accommodated by a single tile. In this case, we create raw_storage for the
+// matrix, and its size is a multiple of (TILE*TILE*4).
+template
+struct joint_matrix<
+    Group, T, NumRows, NumCols, Layout,
+    typename std::enable_if::type> {
+public:
+  // trows: number of tiles along the row dimension.
+  // If T=int8 and NumRows==33, trows should be 3 = (33+15)/16.
+  static constexpr size_t trows = (NumRows + tile_size - 1) / tile_size;
+  // tcols: number of tiles along the column dimension.
+  static constexpr size_t tcols =
+      (NumCols * sizeof(T) / 4 + tile_size - 1) / tile_size;
+  // If T=int8, NumRows==33, NumCols==33*4 and tile_size==16, then the size of
+  // raw_storage should be 48*48*4.
+  // FIXME: Greedy register allocation for tiles seems to have some
+  // limitations, and currently we do tileload for (16,16*4) instead of
+  // varying shapes, so raw_storage's size is a multiple of (16*16*4).
+  static constexpr size_t size = trows * tcols * tile_size * tile_size * 4;
+  // stride is expressed in units of T instead of int8.
+  static constexpr size_t stride = tcols * tile_size * 4 / sizeof(T);
+  int8_t raw_storage[size];
+  static constexpr bool isSmall = false;
+
+public:
+  matrix_layout layout;
+  // The constructor zero-pads matrices whose sizes do not fit exactly into
+  // tiles.
+ joint_matrix(Group sg) { memset(raw_storage, 0x00, size); } +}; + +// This template specialization handles cases where matrix can be put into a +// tile and users specify layout is packed_a or packed_b +template +struct joint_matrix< + Group, T, NumRows, NumCols, Layout, + typename std::enable_if<(NumRows <= tile_size) && + (NumCols * sizeof(T) / 4 <= tile_size)>::type> { +public: + static constexpr size_t trows = (NumRows + tile_size - 1) / tile_size; + // tcols: Num of tiles in column. + static constexpr size_t tcols = + (NumCols * sizeof(T) / 4 + tile_size - 1) / tile_size; + static constexpr size_t size = trows * tcols * tile_size * tile_size * 4; + // stride is aligned to T instead of int8 + static constexpr size_t stride = tcols * tile_size * 4 / sizeof(T); + _tile1024i tile; + static constexpr bool isSmall = true; + matrix_layout layout; + // We do zero-padding for matrix whose size is not fitted into tiles in ctor. + joint_matrix(Group sg) {} +}; + +} // namespace matrix + +namespace detail { + +template +inline __SYCL_ALWAYS_INLINE static + typename std::enable_if<(NumRows > matrix::tile_size) || + (NumCols * sizeof(T) / 4 > matrix::tile_size), + void>::type + submatrix_load(detail::submatrix &sub_m, + matrix::joint_matrix jm, + uint32_t row, uint32_t col, size_t stride, + matrix::matrix_layout layout, bool shouldreload) { + uint32_t offset = (row * stride + col); + T *ptr = reinterpret_cast(jm.raw_storage); + ptr += offset; + stride *= sizeof(T); + sub_m.rows = matrix::tile_size; + sub_m.cols = matrix::tile_size * 4; + sub_m.tile = matrix::tileloadd64_internal( + sub_m.rows, sub_m.cols, reinterpret_cast(ptr), stride); +} + +template +inline __SYCL_ALWAYS_INLINE static + typename std::enable_if<(NumRows <= matrix::tile_size) && + (NumCols * sizeof(T) / 4 <= matrix::tile_size), + void>::type + submatrix_load(detail::submatrix &sub_m, + matrix::joint_matrix &jm, + uint32_t row, uint32_t col, size_t stride, + matrix::matrix_layout layout, bool shouldreload) { + if (shouldreload) { + // Force sub_m.tile's shape to be matrix::tile_size * matrix::tile_size * 4 + int8_t NewjmC[matrix::tile_size * matrix::tile_size * 4]; + matrix::tilestored64_internal(NumRows, NumCols * sizeof(T), + reinterpret_cast(NewjmC), + matrix::tile_size * 4, jm.tile); + sub_m.rows = matrix::tile_size; + sub_m.cols = matrix::tile_size * 4; + sub_m.tile = matrix::tileloadd64_internal(sub_m.rows, sub_m.cols, + reinterpret_cast(NewjmC), + matrix::tile_size * 4); + return; + } + sub_m.rows = NumRows; + sub_m.cols = NumCols * sizeof(T); + sub_m.tile = jm.tile; +} + +// This handles cases where T1 is int8, T2 is int32. +inline __SYCL_ALWAYS_INLINE static void +submatrix_mad(detail::submatrix &sub_ma, + detail::submatrix &sub_mb, + detail::submatrix &sub_mc) { + sub_mc.tile = matrix::tdpbssd_internal(sub_mc.rows, sub_mc.cols, sub_ma.cols, + sub_mc.tile, sub_ma.tile, sub_mb.tile); +} + +// This handles cases where T1 is int16(bfloat16), T2 is float. 
+inline __SYCL_ALWAYS_INLINE static void +submatrix_mad(detail::submatrix &sub_ma, + detail::submatrix &sub_mb, + detail::submatrix &sub_mc) { + sub_mc.tile = + matrix::tdpbf16ps_internal(sub_mc.rows, sub_mc.cols, sub_ma.cols, + sub_mc.tile, sub_ma.tile, sub_mb.tile); +} + +template +inline __SYCL_ALWAYS_INLINE static + typename std::enable_if<(NumRows > matrix::tile_size) || + (NumCols * sizeof(T) / 4 > matrix::tile_size), + void>::type + submatrix_store(detail::submatrix &sub_m, + matrix::joint_matrix &jm, + uint32_t row, uint32_t col, size_t stride, + matrix::matrix_layout layout, bool shouldreload) { + uint32_t offset = (row * stride + col); + T *ptr = reinterpret_cast(jm.raw_storage); + ptr += offset; + stride *= sizeof(T); + matrix::tilestored64_internal(sub_m.rows, sub_m.cols, + reinterpret_cast(ptr), stride, + sub_m.tile); +} + +template +inline __SYCL_ALWAYS_INLINE static + typename std::enable_if<(NumRows <= matrix::tile_size) && + (NumCols * sizeof(T) / 4 <= matrix::tile_size), + void>::type + submatrix_store(detail::submatrix &sub_m, + matrix::joint_matrix &jm, + uint32_t row, uint32_t col, size_t stride, + matrix::matrix_layout layout, bool shouldreload) { + if (shouldreload) { + int8_t NewjmC[matrix::tile_size * matrix::tile_size * 4]; + matrix::tilestored64_internal(matrix::tile_size, matrix::tile_size * 4, + reinterpret_cast(NewjmC), + matrix::tile_size * 4, sub_m.tile); + jm.tile = matrix::tileloadd64_internal(NumRows, NumCols * sizeof(T), + reinterpret_cast(NewjmC), + matrix::tile_size * 4); + return; + } + jm.tile = sub_m.tile; +} + +} // namespace detail + +namespace matrix { + +// This handles cases where matrix can't be accommodated by a tile +template +inline __SYCL_ALWAYS_INLINE typename std::enable_if< + (NumRows > tile_size) || (NumCols * sizeof(T) / 4 > tile_size), void>::type +joint_matrix_load(Group sg, + joint_matrix &jm, + multi_ptr src, size_t stride, + matrix_layout layout) { + T *mem = src.get(); + // memcpy from mem to jm.raw_storage + for (int i = 0; i < NumRows; ++i) { + char *srcptr = reinterpret_cast(mem) + i * stride * sizeof(T); + char *dstptr = + reinterpret_cast(jm.raw_storage) + i * jm.stride * sizeof(T); + // TODO: we may reformat layout. + memcpy(dstptr, srcptr, NumCols * sizeof(T)); + } + jm.layout = layout; +} + +// This handles cases where matrix can be put into a tile +template +inline __SYCL_ALWAYS_INLINE + typename std::enable_if<(NumRows <= tile_size) && + (NumCols * sizeof(T) / 4 <= tile_size), + void>::type + joint_matrix_load(Group sg, + joint_matrix &jm, + multi_ptr src, size_t stride, + matrix_layout layout) { + T *mem = src.get(); + // tileload happens! + jm.tile = + tileloadd64_internal(NumRows, NumCols * sizeof(T), + reinterpret_cast(mem), stride * sizeof(T)); + jm.layout = layout; +} + +// This handles cases where matrix can't be accommodated by a tile +template +inline __SYCL_ALWAYS_INLINE typename std::enable_if< + (NumRows > tile_size) || (NumCols * sizeof(T) / 4 > tile_size), void>::type +joint_matrix_store(Group sg, + joint_matrix &jm, + multi_ptr dst, size_t stride, + matrix_layout layout) { + T *mem = dst.get(); + for (int i = 0; i < NumRows; ++i) { + char *dstptr = reinterpret_cast(mem) + i * stride * sizeof(T); + char *srcptr = + reinterpret_cast(jm.raw_storage) + i * jm.stride * sizeof(T); + // TODO: we may reformat layout. 
+    memcpy(dstptr, srcptr, NumCols * sizeof(T));
+  }
+  return;
+}
+
+// This handles cases where the matrix can be put into a tile
+template
+inline __SYCL_ALWAYS_INLINE
+    typename std::enable_if<(NumRows <= tile_size) &&
+                                (NumCols * sizeof(T) / 4 <= tile_size),
+                            void>::type
+    joint_matrix_store(Group sg,
+                       joint_matrix &jm,
+                       multi_ptr dst, size_t stride,
+                       matrix_layout layout) {
+  T *mem = dst.get();
+  // tilestore happens!
+  tilestored64_internal(NumRows, NumCols * sizeof(T),
+                        reinterpret_cast<char *>(mem), stride * sizeof(T),
+                        jm.tile);
+  return;
+}
+
+template
+inline __SYCL_ALWAYS_INLINE typename std::enable_if<
+    ((std::is_same::value && std::is_same::value) ||
+     (std::is_same::value &&
+      std::is_same::value)) &&
+        (LayoutA == matrix_layout::row_major) &&
+        (LayoutB == matrix_layout::packed_b) &&
+        (LayoutC == matrix_layout::row_major),
+    void>::type
+joint_matrix_mad(Group sg,
+                 joint_matrix &jmA,
+                 joint_matrix &jmB,
+                 joint_matrix &jmC) {
+  constexpr size_t epd = detail::elems_per_dword::value;
+  // If A is large and C is small, joint_matrix_load does a memcpy for A and a
+  // tileload for C whose shape is not tile_size*tile_size*4. In
+  // joint_matrix_mad we do a tileload for A whose shape is
+  // tile_size*tile_size*4, so we need to reshape C before we do dpbssd.
+  bool Cshouldreload = jmC.isSmall && !jmA.isSmall && !jmB.isSmall;
+  bool Ashouldreload = jmA.isSmall && !jmB.isSmall;
+  bool Bshouldreload = jmB.isSmall && !jmA.isSmall;
+
+  for (int m = 0; m < jmC.trows; ++m) {
+    for (int n = 0; n < jmC.tcols; ++n) {
+      detail::submatrix sub_c;
+
+      // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64
+      submatrix_load(sub_c, jmC, m * tile_size, n * tile_size, jmC.stride,
+                     matrix_layout::row_major, Cshouldreload);
+      for (int k = 0; k < jmA.tcols; ++k) { // K->int8_t
+        detail::submatrix sub_a;
+        detail::submatrix sub_b;
+        submatrix_load(sub_a, jmA, m * tile_size, k * tile_size * epd,
+                       jmA.stride, matrix_layout::packed_a, Ashouldreload);
+        // Assume the data is already in VNNI format.
+        submatrix_load(sub_b, jmB, k * tile_size, n * tile_size * epd,
+                       jmB.stride, matrix_layout::packed_b, Bshouldreload);
+        submatrix_mad(sub_a, sub_b, sub_c);
+      }
+      submatrix_store(sub_c, jmC, m * tile_size, n * tile_size, jmC.stride,
+                      matrix_layout::row_major, Cshouldreload);
+    }
+  }
+  return;
+}
+
+} // namespace matrix
+} // namespace intel
+} // namespace ext
+} // namespace sycl
+} // __SYCL_INLINE_NAMESPACE(cl)
diff --git a/sycl/include/CL/sycl/ONEAPI/intel_matrix/matrix.hpp b/sycl/include/CL/sycl/ONEAPI/intel_matrix/matrix.hpp
new file mode 100644
index 0000000000000..eff1b805cdb9c
--- /dev/null
+++ b/sycl/include/CL/sycl/ONEAPI/intel_matrix/matrix.hpp
@@ -0,0 +1,19 @@
+//==------------------ matrix.hpp - SYCL matrix ----------------*- C++ -*---==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// ===--------------------------------------------------------------------=== //
+/// Currently, this is the compilation command line needed to target the AMX
+/// unit of a Sapphire Rapids CPU: clang++ -fsycl -march=sapphirerapids
+/// -fsycl-targets="spir64_x86_64-unknown-linux-sycldevice" -O2 main.cpp
+///
+///
+// ===--------------------------------------------------------------------=== //
+
+#pragma once
+
+#if defined(__AMXTILE__) && defined(__AMXINT8__) && defined(__AMXBF16__)
+#include <CL/sycl/ONEAPI/intel_matrix/matrix-amx.hpp>
+#endif
diff --git a/sycl/test/on-device/extensions/matrix-amx-bf16-test.cpp b/sycl/test/on-device/extensions/matrix-amx-bf16-test.cpp
new file mode 100644
index 0000000000000..e3f94ab1273c1
--- /dev/null
+++ b/sycl/test/on-device/extensions/matrix-amx-bf16-test.cpp
@@ -0,0 +1,184 @@
+// RUN: %clangxx -march=sapphirerapids -fsycl -O2 %s -o %t.out
+#include <CL/sycl.hpp>
+#include <iostream>
+
+using namespace cl::sycl;
+using namespace cl::sycl::intel;
+using namespace cl::sycl::ext::intel::matrix;
+
+#define TILE_SZ 16
+#define TM (3 * TILE_SZ-1)
+#define TN (3 * TILE_SZ-1)
+#define TK (9 * TILE_SZ+2)
+
+template struct big_matrix{
+public:
+  T *mat;
+
+public:
+  T *get_data() { return mat; }
+  void set_data(T *data) { mat = data; }
+  big_matrix(T *data) : mat(data) {
+  }
+};
+
+template
+void matrix_multiply(big_matrix &C, big_matrix &A, big_matrix &B) {
+  size_t M = NUM_ROWS_C;
+  size_t N = NUM_COLS_C;
+  size_t K = NUM_COLS_A;
+  // B => K/2 x N*2, A => M x K, C => M, N
+  // stride should be X's cols, e.g., B's stride = N*2
+  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2);
+  size_t NDRangeM = M / TM;
+  size_t NDRangeN = N / TN;
+  buffer bufA(A.get_data(), range<2>(M, K));
+  buffer bufB(B.get_data(), range<2>(K, N));
+  buffer bufC((float*)C.get_data(), range<2>(M, N));
+
+  queue q;
+  q.submit([&](handler &cgh) {
+     auto accC = bufC.get_access(cgh);
+     auto accA = bufA.get_access(cgh);
+     auto accB = bufB.get_access(cgh);
+
+     cgh.parallel_for(
+         nd_range<2>({NDRangeM, NDRangeN}, {1, 1}),
+         [accA, accB, accC, M, N, K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(1)]]
+
+         {
+           // The submatrix API has to be accessed by all the work-items in a
+           // subgroup; these functions will be called once by the subgroup,
+           // with no code divergence between the work-items.
+           const auto global_idx = spmd_item.get_global_id(0);
+           const auto global_idy = spmd_item.get_global_id(1);
+           const auto sg_startx = global_idx;
+           const auto sg_starty = global_idy;
+
+           ONEAPI::sub_group sg = spmd_item.get_sub_group();
+           joint_matrix sub_a(sg);
+           // For B, since the current implementation does not support a
+           // non-packed layout, users need to specify the updated VNNI sizes
+           // along with the packed_b layout. By default, the layout is
+           // row_major and the size is (TK, TN).
+           joint_matrix sub_b(sg);
+           joint_matrix sub_c(sg);
+
+           // Only the leader performs the AMX computation.
+           if (spmd_item.get_local_id(1) % TILE_SZ)
+             return;
+           // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64
+           // strideX = X's cols, so strideC = N, strideA = K, strideB = N*2
+           joint_matrix_load(sg, sub_c,
+                             accC.get_pointer() + (sg_startx * TM) * N +
+                                 sg_starty * TN,
+                             N, matrix_layout::row_major);
+           for (int k = 0; k < K / TK; k += 1) { // K->int8_t
+             joint_matrix_load(sg, sub_a,
+                               accA.get_pointer() + (sg_startx * TM) * K +
+                                   k * TK,
+                               K, matrix_layout::row_major);
+             // Assume the data is already in VNNI format.
+ joint_matrix_load(sg, sub_b, + accB.get_pointer() + + (k * TK / 2) * (N * 2) + sg_starty * TN * 2, + N * 2, matrix_layout::packed_b); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; +unsigned short A[MATRIX_M][MATRIX_K]; +unsigned short B[MATRIX_K / 2][MATRIX_N * 2]; +float C[MATRIX_M][MATRIX_N]; +float D[MATRIX_M][MATRIX_N]; + +float make_fp32(short x) +{ + unsigned int y = x; + y = y << 16; + float *res = reinterpret_cast(&y); + return *res; +} + +unsigned short make_bf16(float x) +{ + int *res = reinterpret_cast(&x); + *res = *res >> 16; + return (unsigned short)*res; +} + +void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N, int K) { + // tiling + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + short *va = (short *)(A_mem + m*K + k); + short *vb = (short *)(B_mem + k*N + n); + float acc = *((float*)(C_mem + m*N + n)); + // FIXME: Should we do reduce-add in another version? + for (int i = 0; i < 2; i++) { + acc += (make_fp32(va[i]) * make_fp32(vb[i])); + } + *((float*)(C_mem + m*N + n))= acc; + } + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = make_bf16(1.0f * (i+j)); + } + } + for (int i = 0; i < MATRIX_K / 2; i++) { + for (int j = 0; j < MATRIX_N * 2; j++) { + B[i][j] = make_bf16(2.0f*i + 3.0f*j); + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 1.0; + D[i][j] = 1.0; + } + } + + big_matrix MC((float *)&C); + big_matrix MD((float *)&D); + big_matrix MA((unsigned short *)&A); + big_matrix MB((unsigned short *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, + MATRIX_N, MATRIX_K / 2); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if (C[i][j] != D[i][j]) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) + std::cout << C[i][j] << ", "; + std::cout << "\n"; + } + std::cout << std::endl; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) + std::cout << D[i][j] << ", "; + std::cout << "\n"; + } +} diff --git a/sycl/test/on-device/extensions/matrix-amx-int8-test.cpp b/sycl/test/on-device/extensions/matrix-amx-int8-test.cpp new file mode 100644 index 0000000000000..62e6a90eb8ed8 --- /dev/null +++ b/sycl/test/on-device/extensions/matrix-amx-int8-test.cpp @@ -0,0 +1,169 @@ +// RUN: %clangxx -march=sapphirerapids -fsycl -O2 %s -o %t.out +#include +#include + +using namespace cl::sycl; +using namespace cl::sycl::intel; +using namespace cl::sycl::ext::intel::matrix; + +#define TILE_SZ 16 +#define TM (4 * TILE_SZ-4) +#define TN (4 * TILE_SZ-4) +#define TK (4 * TILE_SZ-16) + +template struct big_matrix{ +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) { + } +}; + +template +void matrix_multiply(big_matrix &C, big_matrix &A, big_matrix &B) { + size_t M = NUM_ROWS_C; + size_t N = NUM_COLS_C; + size_t K = NUM_COLS_A; + // B => K/4 x N*4, A => M x K, C => M, N + // stride should be X's 
cols, e.g., B's stride = N*4
+  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4);
+  size_t NDRangeM = M / TM;
+  size_t NDRangeN = N / TN;
+  buffer bufA(A.get_data(), range<2>(M, K));
+  buffer bufB(B.get_data(), range<2>(K, N));
+  buffer bufC(C.get_data(), range<2>(M, N));
+
+  queue q;
+  q.submit([&](handler &cgh) {
+     auto accC = bufC.get_access(cgh);
+     auto accA = bufA.get_access(cgh);
+     auto accB = bufB.get_access(cgh);
+
+     cgh.parallel_for(
+         nd_range<2>({NDRangeM, NDRangeN}, {1, 1}),
+         [accA, accB, accC, M, N, K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(1)]]
+
+         {
+           // The submatrix API has to be accessed by all the work-items in a
+           // subgroup; these functions will be called once by the subgroup,
+           // with no code divergence between the work-items.
+           const auto global_idx = spmd_item.get_global_id(0);
+           const auto global_idy = spmd_item.get_global_id(1);
+           const auto sg_startx = global_idx;
+           const auto sg_starty = global_idy;
+
+           ONEAPI::sub_group sg = spmd_item.get_sub_group();
+           joint_matrix sub_a(sg);
+           // For B, since the current implementation does not support a
+           // non-packed layout, users need to specify the updated VNNI sizes
+           // along with the packed_b layout. By default, the layout is
+           // row_major and the size is (TK, TN).
+           joint_matrix sub_b(sg);
+           joint_matrix sub_c(sg);
+
+           // Only the leader performs the AMX computation.
+           if (spmd_item.get_local_id(1) % TILE_SZ)
+             return;
+           // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64
+           // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4
+           joint_matrix_load(sg, sub_c,
+                             accC.get_pointer() + (sg_startx * TM) * N +
+                                 sg_starty * TN,
+                             N, matrix_layout::row_major);
+           for (int k = 0; k < K / TK; k += 1) { // K->int8_t
+             joint_matrix_load(sg, sub_a,
+                               accA.get_pointer() + (sg_startx * TM) * K +
+                                   k * TK,
+                               K, matrix_layout::packed_a);
+             // Assume the data is already in VNNI format.
+ joint_matrix_load(sg, sub_b, + accB.get_pointer() + + (k * TK / 4) * (N * 4) + sg_starty * TN * 4, + N * 4, matrix_layout::packed_b); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; +int8_t A[MATRIX_M][MATRIX_K]; +int8_t B[MATRIX_K / 4][MATRIX_N * 4]; +int32_t C[MATRIX_M][MATRIX_N]; +int32_t D[MATRIX_M][MATRIX_N]; + +void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M, + int N, int K) { + // tiling + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + char *va = (char *)(A_mem + m * K + k); + char *vb = (char *)(B_mem + k * N + n); + int acc = *(C_mem + m * N + n); + for (int i = 0; i < 4; i++) { + acc += (va[i] * vb[i]); + } + *(C_mem + m * N + n) = acc; + } + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = i+2*j; + } + } + for (int i = 0; i < MATRIX_K / 4; i++) { + for (int j = 0; j < MATRIX_N * 4; j++) { + B[i][j] = i+j; + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 1; + D[i][j] = 1; + } + } + + big_matrix MC((int32_t *)&C); + big_matrix MD((int32_t *)&D); + big_matrix MA((int8_t *)&A); + big_matrix MB((int8_t *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, + MATRIX_N, MATRIX_K / 4); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if (C[i][j] != D[i][j]) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) + std::cout << C[i][j] << ", "; + std::cout << "\n"; + } + std::cout << std::endl; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) + std::cout << D[i][j] << ", "; + std::cout << "\n"; + } +}
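Note (not part of the patch): both tests assume B has already been packed into the VNNI layout that matrix_layout::packed_b expects, but the patch does not show how that layout is produced on the host. The sketch below is a hypothetical helper; the int8 indexing is inferred from matrix_multiply_ref in matrix-amx-int8-test.cpp, and for bf16 the grouping factor would be 2 instead of 4.

#include <cstddef>
#include <cstdint>

// Packs a row-major K x N int8 matrix into the (K/4) x (N*4) VNNI layout
// consumed as the packed_b operand (K is assumed to be a multiple of 4).
void pack_b_vnni_int8(const int8_t *src, int8_t *dst, size_t K, size_t N) {
  for (size_t k = 0; k < K; ++k)
    for (size_t n = 0; n < N; ++n)
      // Source row k lands in packed row k/4, byte lane k%4 of column n.
      dst[(k / 4) * (N * 4) + n * 4 + (k % 4)] = src[k * N + n];
}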