Skip to content
Closed
8 changes: 8 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ template <typename T, std::size_t R, std::size_t C,
extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
JOINT_MATRIX_INTEL(T, R, C, L, S, U) *);

template <typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
__spirv_JointMatrixGetElementCoordINTEL(JOINT_MATRIX_INTEL(T, R, C, L, S, U) *,
size_t i);

template <typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
Expand Down
44 changes: 44 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <CL/__spirv/spirv_ops.hpp>
#include <sycl/detail/defines_elementary.hpp>
#include <sycl/feature_test.hpp>
#include <tuple>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
Expand Down Expand Up @@ -256,6 +257,21 @@ class wi_element {
wi_element(joint_matrix<T, NumRows, NumCols, Use, Layout, Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}

std::tuple<uint32_t, uint32_t> get_coord() {
#ifdef __SYCL_DEVICE_ONLY__
__ocl_vec_t<uint32_t, 2> coord =
__spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
const uint32_t row = coord[0];
const uint32_t col = coord[1];
return std::make_tuple(row, col);
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

// Various Operations
operator T() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
Expand Down Expand Up @@ -339,6 +355,20 @@ class wi_element<uint16_t, NumRows, NumCols, Use, Layout, Group> {
wi_element(joint_matrix<uint16_t, NumRows, NumCols, Use, Layout, Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}

std::tuple<uint32_t, uint32_t> get_coord() {
#ifdef __SYCL_DEVICE_ONLY__
__ocl_vec_t<uint32_t, 2> coord =
__spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
const uint32_t row = coord[0];
const uint32_t col = coord[1];
return std::make_tuple(row, col);
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

operator uint16_t() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
Expand Down Expand Up @@ -489,6 +519,20 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
NumCols, Use, Layout, Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}

std::tuple<uint32_t, uint32_t> get_coord() {
#ifdef __SYCL_DEVICE_ONLY__
__ocl_vec_t<uint32_t, 2> coord =
__spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
const uint32_t row = coord[0];
const uint32_t col = coord[1];
return std::make_tuple(row, col);
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

operator sycl::ext::oneapi::experimental::bfloat16() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
Expand Down
117 changes: 117 additions & 0 deletions sycl/test/matrix/matrix-bfloat16-test-coord-basic.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// RUN: %clangxx -fsycl -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=2 %s -o %t.out
// RUN: %t.out
// XFAIL: *

// this code calculates the sum of rows into a global array of number of rows
// elements. First, partial reduction is computed inside each SG, then atomic
// add is used to reduce between SG leaders. The get_coord() API is used for
// retrieving the row

#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

#define SG_SZ 16

#define TN SG_SZ
#define TK 32

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
T *mat;

public:
T *get_data() { return mat; }
void set_data(T *data) { mat = data; }
big_matrix(T *data) : mat(data) {}
};

template <typename T, size_t M, size_t N>
void sum_rows_ref(
accessor<T, 2, access::mode::read, access::target::host_buffer> B,
accessor<int, 1, access::mode::read, access::target::host_buffer>
sum_rows) {
int sum_rows_ref[M] = {0};
for (size_t i = 0; i < M; i++) {
for (size_t j = 0; j < N; j++) {
sum_rows_ref[i] += B[i][j];
}
auto diff = sum_rows[i] - sum_rows_ref[i];
assert(std::fabs(static_cast<int>(diff)) <=
std::numeric_limits<int>::epsilon());
}
}

template <typename T, size_t M, size_t N>
void matrix_sum_rows(queue q, big_matrix<T, M, N> &B, nd_range<2> &r) {
buffer<int8_t, 2> bufB(B.get_data(), range<2>(M, N));
// size of vector is known because SG size of set by the user in this case
int sum_rows[M] = {0};
buffer<int> sum_rows_v(sum_rows, M); // there are total of tK/4 * 2, 16 rows
q.submit([&](handler &cgh) {
auto accB = bufB.get_access<access::mode::read_write>(cgh);

auto v = sum_rows_v.get_access<access::mode::atomic>(cgh);

cgh.parallel_for<class add_matrix>(
r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

ext::oneapi::sub_group sg = spmd_item.get_sub_group();

joint_matrix<T, TK, TN, use::b> sub_b(sg);

joint_matrix_load(sg, sub_b,
accB.get_pointer() + (global_idx * (TK / 4) * N) +
sg_starty / SG_SZ * TN * 4,
N, layout::packed_b);

int32_t sum_local_rows[M] = {0};
auto tBData = sub_b.get_wi_data();

// each WI calculates local sum of rows
for (int i = 0; i < tBData.length(); ++i) {
// row and col holds global co_ordinates of the matrix
auto [row, col] = tBData[i].get_coord();
sum_local_rows[row] += tBData[i];

sum_local_rows[row] =
reduce_over_group(sg, sum_local_rows[row], sycl::plus<>());
// only Groups leader perform the global reduction
if (global_idy % SG_SZ == 0) {
atomic_fetch_add(v[row], sum_local_rows[row]);
}
}
}); // parallel for
}).wait();
sum_rows_ref<T, M, N>(bufB.get_access<access::mode::read>(),
sum_rows_v.get_access<access::mode::read>());
}

static constexpr size_t MATRIX_K = TK / 4 * 2;
static constexpr size_t MATRIX_N = TN * 4 * 2;
int8_t B[MATRIX_K][MATRIX_N];

int main() {
big_matrix<int8_t, MATRIX_K, MATRIX_N> MB((int8_t *)&B);

size_t NDRangeK = MATRIX_K / (TK / 4);
size_t NDRangeN = (MATRIX_N / 4) / TN;
queue q;
nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ});

for (int i = 0; i < MATRIX_K; i++) {
for (int j = 0; j < MATRIX_N; j++) {
B[i][j] = i;
}
}

matrix_sum_rows<int8_t, MATRIX_K, MATRIX_N>(q, MB, r);

return 0;
}
Loading