11
22// REQUIRES: cuda
3- // Temp xfail: test was merged early.
4- // XFAIL: cuda
53// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out
64// RUN: %t.out
75//
1412#include < sycl/sycl.hpp>
1513
1614using namespace sycl ;
17- using namespace sycl ::ext::oneapi::experimental ;
15+ using namespace sycl ::ext::oneapi;
1816using namespace sycl ::ext::oneapi::experimental::matrix;
1917constexpr float bf16_eps = 0.00390625 ;
2018
@@ -146,9 +144,11 @@ void test(queue &q) {
146144 // column id of current submatrix of BIG C matrix
147145 const auto n = item.get_group ().get_group_id ()[1 ];
148146
149- joint_matrix<T3, use::a, M, K, layout::row_major> sub_a;
150- joint_matrix<T3, use::b, K, N, layout::row_major> sub_b;
151- joint_matrix<std::remove_const_t <T2>, use::accumulator, M, N> sub_c;
147+ joint_matrix<sub_group, T3, use::a, M, K, layout::row_major> sub_a;
148+ joint_matrix<sub_group, T3, use::b, K, N, layout::row_major> sub_b;
149+ joint_matrix<sub_group, std::remove_const_t <T2>, use::accumulator,
150+ M, N>
151+ sub_c;
152152
153153 joint_matrix_load (sg, sub_c,
154154 accC.get_pointer () + (m * M) * Big_N + n * N,
@@ -165,11 +165,13 @@ void test(queue &q) {
165165
166166 // round values to correct precision if using tf32
167167 if constexpr (std::is_same<T3, precision::tf32>::value) {
168- auto wi_size = sub_a. wi_marray . size ();
169- assert (wi_size == sub_b. wi_marray . size ());
168+ auto wi_size = get_wi_data (sg, sub_a). length ();
169+ assert (wi_size == get_wi_data (sg, sub_b). length ());
170170 for (auto i = 0 ; i < wi_size; ++i) {
171- sub_a.wi_marray [i] = round_to_tf32 (sub_a.wi_marray [i]);
172- sub_b.wi_marray [i] = round_to_tf32 (sub_b.wi_marray [i]);
171+ get_wi_data (sg, sub_a)[i] =
172+ round_to_tf32 (get_wi_data (sg, sub_a)[i]);
173+ get_wi_data (sg, sub_b)[i] =
174+ round_to_tf32 (get_wi_data (sg, sub_b)[i]);
173175 }
174176 }
175177
0 commit comments