diff --git a/SYCL/Matrix/joint_matrix_tensorcore.cpp b/SYCL/Matrix/joint_matrix_tensorcore.cpp index 8c827d1c11..a489aeb0ca 100644 --- a/SYCL/Matrix/joint_matrix_tensorcore.cpp +++ b/SYCL/Matrix/joint_matrix_tensorcore.cpp @@ -74,7 +74,7 @@ T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) { } template + size_t Sub_Tiles_N, size_t M, size_t K, size_t N, typename T3 = T1> void test() { constexpr auto Big_M = @@ -131,19 +131,19 @@ void test() { range<2> GlobalRange = {Sub_Tiles_M, Sub_Tiles_N * N_THREADS_PER_MATRIX_OP}; cgh.parallel_for>( - nd_range<2>(GlobalRange, LocalRange), [= - ](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + nd_range<2>(GlobalRange, LocalRange), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); const auto m = - item.get_group() - .get_id()[0]; // row id of current submatrix of BIG C matrix + item.get_group().get_group_id()[0]; // row id of current submatrix + // of BIG C matrix const auto n = - item.get_group().get_id()[1]; // column id of current - // submatrix of BIG C matrix + item.get_group().get_group_id()[1]; // column id of current + // submatrix of BIG C matrix - joint_matrix sub_a; + joint_matrix sub_a; - joint_matrix sub_b; + joint_matrix sub_b; joint_matrix @@ -163,6 +163,14 @@ void test() { accB.get_pointer() + (k * K * Big_N) + (n * N), Big_N); + // Convert values if using tf32 + if constexpr (std::is_same::value) { + for (auto i = 0; i < 4; ++i) { + sub_a.data[i] = round_to_tf32(sub_a.data[i]); + sub_b.data[i] = round_to_tf32(sub_b.data[i]); + } + } + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( @@ -182,7 +190,6 @@ void test() { }; int main() { - // A/B half, Accumulator float test(); test(); @@ -208,5 +215,9 @@ int main() { test(); test(); + // A/B tf32 + test(); + return 0; };