Skip to content

Commit 2340b33

Browse files
authored
[SYCL][ext][CUDA] Use float as storage type for tf32 joint matrix (#5870)
Changing joint_matrix impl to use float as storage type instead of uint32_t for tf32. Test: intel/llvm-test-suite#963
1 parent 8c4d9a5 commit 2340b33

File tree

2 files changed

+205
-18
lines changed

2 files changed

+205
-18
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp

Lines changed: 64 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ enum class matrix_use { a, b, accumulator };
1818

1919
enum class matrix_layout { row_major, col_major, packed_a, packed_b };
2020

21+
namespace precision {
22+
class tf32 {};
23+
} // namespace precision
24+
2125
template <typename T, matrix_use Use, size_t Rows = sycl::dynamic_extent,
2226
size_t Cols = sycl::dynamic_extent,
2327
matrix_layout Layout = matrix_layout::row_major,
@@ -81,18 +85,23 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2)
8185
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2)
8286
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8)
8387

88+
// m16n16k8 tf32
89+
__SYCL_JOINT_MATRIX_OVERLOAD(precision::tf32, a, 16, 8, float, 4)
90+
__SYCL_JOINT_MATRIX_OVERLOAD(precision::tf32, b, 8, 16, float, 4)
91+
8492
#undef __SYCL_JOINT_MATRIX_OVERLOAD
8593
} // namespace experimental::matrix
8694

8795
namespace detail {
8896

89-
template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
97+
template <typename S, typename T,
98+
sycl::ext::oneapi::experimental::matrix::matrix_use Use,
9099
size_t NumRows, size_t NumCols,
91100
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
92101
access::address_space Space, typename Cond = void>
93102
struct joint_matrix_load_impl {
94103
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
95-
T, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
104+
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
96105
multi_ptr<T, Space> src, size_t stride);
97106
};
98107

@@ -111,18 +120,19 @@ constexpr int get_layout_id<
111120
return 1;
112121
}
113122

114-
template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
123+
template <typename S, typename T,
124+
sycl::ext::oneapi::experimental::matrix::matrix_use Use,
115125
size_t NumRows, size_t NumCols,
116126
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
117127
access::address_space Space>
118128
struct joint_matrix_load_impl<
119-
T, Use, NumRows, NumCols, Layout, Space,
129+
S, T, Use, NumRows, NumCols, Layout, Space,
120130
typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
121131
matrix::matrix_layout::row_major ||
122132
Layout == sycl::ext::oneapi::experimental::
123133
matrix::matrix_layout::col_major>> {
124134
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
125-
T, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
135+
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
126136
multi_ptr<T, Space> src, size_t stride) {
127137
if constexpr (std::is_same<T, uint16_t>::value) {
128138
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
@@ -247,15 +257,27 @@ struct joint_matrix_load_impl<
247257
get_layout_id<Layout>());
248258
}
249259
} else if constexpr (std::is_same<T, float>::value) {
250-
if constexpr (NumRows == 16 && NumCols == 16) {
251-
__hmma_m16n16k16_ld_c_f32(res.data, src.get(), stride,
252-
get_layout_id<Layout>());
253-
} else if constexpr (NumRows == 8 && NumCols == 32) {
254-
__hmma_m8n32k16_ld_c_f32(res.data, src.get(), stride,
255-
get_layout_id<Layout>());
256-
} else if constexpr (NumRows == 32 && NumCols == 8) {
257-
__hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
258-
get_layout_id<Layout>());
260+
if (std::is_same<S, float>::value) {
261+
if constexpr (NumRows == 16 && NumCols == 16) {
262+
__hmma_m16n16k16_ld_c_f32(res.data, src.get(), stride,
263+
get_layout_id<Layout>());
264+
} else if constexpr (NumRows == 8 && NumCols == 32) {
265+
__hmma_m8n32k16_ld_c_f32(res.data, src.get(), stride,
266+
get_layout_id<Layout>());
267+
} else if constexpr (NumRows == 32 && NumCols == 8) {
268+
__hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
269+
get_layout_id<Layout>());
270+
}
271+
} else if (std::is_same<S, sycl::ext::oneapi::experimental::matrix::
272+
precision::tf32>::value) {
273+
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
274+
if constexpr (NumRows == 16 && NumCols == 8) {
275+
__mma_tf32_m16n16k8_ld_a(reinterpret_cast<int32_t *>(res.data),
276+
tileptr, stride, get_layout_id<Layout>());
277+
} else if constexpr (NumRows == 8 && NumCols == 16) {
278+
__mma_tf32_m16n16k8_ld_b(reinterpret_cast<int32_t *>(res.data),
279+
tileptr, stride, get_layout_id<Layout>());
280+
}
259281
}
260282
} else if constexpr (std::is_same<T, double>::value) {
261283
if constexpr (Use ==
@@ -495,6 +517,10 @@ struct joint_matrix_mad_impl<
495517
get_layout_pair_id<LayoutA, LayoutB>(), 0);
496518
}
497519
}
520+
} else if constexpr (M == 16 && N == 16 && K == 8) {
521+
__mma_tf32_m16n16k8_mma_f32(D.data, reinterpret_cast<int32_t *>(A.data),
522+
reinterpret_cast<int32_t *>(B.data), C.data,
523+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
498524
} else if constexpr (std::is_same<T1, double>::value) {
499525
__dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data,
500526
get_layout_pair_id<LayoutA, LayoutB>(), 0);
@@ -507,13 +533,18 @@ struct joint_matrix_mad_impl<
507533

508534
namespace experimental::matrix {
509535

510-
template <typename Group, typename T, matrix_use Use, size_t NumRows,
511-
size_t NumCols, matrix_layout Layout, access::address_space Space>
536+
template <typename Group, typename S, typename T, matrix_use Use,
537+
size_t NumRows, size_t NumCols, matrix_layout Layout,
538+
access::address_space Space,
539+
std::enable_if_t<std::is_same<S, T>::value ||
540+
(std::is_same<S, precision::tf32>::value &&
541+
std::is_same<T, float>::value),
542+
bool> = true>
512543
void joint_matrix_load(
513-
Group sg, joint_matrix<T, Use, NumRows, NumCols, Layout, Group> &res,
544+
Group sg, joint_matrix<S, Use, NumRows, NumCols, Layout, Group> &res,
514545
multi_ptr<T, Space> src, size_t stride) {
515546
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
516-
sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
547+
sycl::ext::oneapi::detail::joint_matrix_load_impl<S, T, Use, NumRows, NumCols,
517548
Layout, Space>{}
518549
.load(res, src, stride);
519550
#else
@@ -573,6 +604,21 @@ joint_matrix_mad(
573604
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
574605
}
575606

607+
// This function rounds the bottom 13 bits up or down, and then zeros out the
// bottom bits, so the result keeps only the 10 explicit mantissa bits of
// tf32 precision while still being stored in a float.
//
// Declared inline: this is a non-template free function defined in a header,
// so without `inline` it would violate the ODR and cause multiple-definition
// link errors when the header is included from more than one translation
// unit.
inline float round_to_tf32(float a) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
  // On the device, use the native round-to-nearest float->tf32 conversion
  // intrinsic, then bitcast the i32 result back to a float carrier.
  int32_t tmp_int = __nvvm_f2tf32_rna(a);
  return __nvvm_bitcast_i2f(tmp_int);
#else
  // Host emulation: add half of the discarded range (1 << 12) to round to
  // nearest, then mask away the low 13 mantissa bits (23 - 13 = 10 bits kept).
  // NOTE(review): reinterpret_cast between float and uint32_t lvalues is
  // formally a strict-aliasing violation; consider memcpy / sycl bit_cast —
  // confirm against the project's punning convention.
  uint32_t tmp_uint = reinterpret_cast<uint32_t &>(a);
  tmp_uint += 0x1000u;
  tmp_uint &= 0xFFFFE000u;
  float ret = reinterpret_cast<float &>(tmp_uint);
  return ret;
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}
621+
576622
} // namespace experimental::matrix
577623
} // namespace oneapi
578624
} // namespace ext
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// REQUIRES: cuda
2+
3+
// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s
4+
5+
// IMPORTANT: before updating sm version support beyond sm_86 read the following
6+
// NOTE!
7+
8+
// NOTE: Technically the 'wrong' ptx instruction is called by
9+
// joint_matrix_load/joint_matrix_store in this case: notice that the load and
10+
// store instructions use shape m16n16k16, rather than the correct shape
11+
// m16n16k8. The 'wrong' ptx instruction is used because it returns the correct
12+
// SASS instructions for all existing supported sm versions: sm_80 and sm_86.
13+
// This ptx instruction redundancy stems from the ptx naming convention for
// the mnk shape triple; however, we cannot know a priori that future sm
// versions will behave in the same way and that this redundancy will
// continue as future architectures are released. This should be
17+
// validated before supporting any sm versions beyond sm_86. The reason that we
18+
// choose to use the m16n16k16 instruction is that it allows the significant
19+
// advantage of being able to use a portable interface across Intel and Nvidia
20+
// backends.
21+
22+
#include <CL/sycl.hpp>
23+
24+
using namespace sycl;
25+
using namespace sycl::ext::oneapi::experimental::matrix;
26+
27+
// M, N, K define the sizes of dimensions of the three matrix types (a, b,
// accumulator) used per subgroup operation: a is M x K, b is K x N, and the
// accumulator is M x N.
constexpr int M = 16; // number of rows of a,
                      // number of rows of accumulator.
constexpr int N = 16; // number of cols of b,
                      // number of cols of accumulator.
constexpr int K = 8;  // number of cols of a/number of rows of b.

// float is used in this test as the storage type for tf32
float A[M * K]; // input matrix a (consumed as tf32 after rounding)
float B[K * N]; // input matrix b (consumed as tf32 after rounding)
float C[M * N]; // accumulator input
float D[M * N]; // result of D = A * B + C
40+
41+
// Device-compile-only FileCheck test: builds two kernels (row-major and
// col-major variants) that exercise the tf32 joint_matrix load / round /
// mad / store path, and checks the NVVM intrinsics emitted in the LLVM IR.
// The //CHECK: lines are FileCheck directives and must stay in statement
// order — do not reorder the calls they precede.
int main() {

  buffer<float, 1> bufA(A, range<1>(M * K)); // will be used as tf32
  buffer<float, 1> bufB(B, range<1>(K * N)); // will be used as tf32
  buffer<float, 1> bufC(C, range<1>(M * N));
  buffer<float, 1> bufD(D, range<1>(M * N));

  queue q;

  // First kernel: row-major a/b/accumulator.
  q.submit([&](handler &cgh) {
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    // One subgroup of 32 work-items (a single warp) performs the whole
    // M x N x K matrix-multiply-accumulate cooperatively.
    cgh.parallel_for<class row_row>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          // a and b use precision::tf32 as the element type; float is the
          // storage type underneath (see the header change under test).
          joint_matrix<precision::tf32, matrix_use::a, M, K,
                       matrix_layout::row_major>
              sub_a;

          joint_matrix<precision::tf32, matrix_use::b, K, N,
                       matrix_layout::row_major>
              sub_b;

          joint_matrix<float, matrix_use::accumulator, M, N,
                       matrix_layout::row_major>
              sub_c;

          //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
          //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.row.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), N);

          // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
          // Round a, b to tf32
          // (4 is the per-work-item fragment size for the tf32 m16n16k8
          // shapes declared in the header.)
          for (auto i = 0; i < 4; ++i)
            sub_a.data[i] = round_to_tf32(sub_a.data[i]);

          for (auto i = 0; i < 4; ++i)
            sub_b.data[i] = round_to_tf32(sub_b.data[i]);

          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 %{{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          //CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
        });
  });

  // Second kernel: identical flow with col-major layouts, pinning the
  // .col.stride / mma.col.col intrinsic variants.
  q.submit([&](handler &cgh) {
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    cgh.parallel_for<class col_col>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<precision::tf32, matrix_use::a, M, K,
                       matrix_layout::col_major>
              sub_a;

          joint_matrix<precision::tf32, matrix_use::b, K, N,
                       matrix_layout::col_major>
              sub_b;

          joint_matrix<float, matrix_use::accumulator, M, N,
                       matrix_layout::col_major>
              sub_c;

          //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
          //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* {{.*}}, i32 {{.*}}) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), N);

          // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
          // Round a, b to tf32
          for (auto i = 0; i < 4; ++i)
            sub_a.data[i] = round_to_tf32(sub_a.data[i]);

          for (auto i = 0; i < 4; ++i)
            sub_b.data[i] = round_to_tf32(sub_b.data[i]);

          //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          //CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
        });
  });

  return 0;
};

0 commit comments

Comments
 (0)