@@ -18,6 +18,10 @@ enum class matrix_use { a, b, accumulator };
18
18
19
19
enum class matrix_layout { row_major, col_major, packed_a, packed_b };
20
20
21
+ namespace precision {
22
+ class tf32 {};
23
+ } // namespace precision
24
+
21
25
template <typename T, matrix_use Use, size_t Rows = sycl::dynamic_extent,
22
26
size_t Cols = sycl::dynamic_extent,
23
27
matrix_layout Layout = matrix_layout::row_major,
@@ -81,18 +85,23 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2)
81
85
__SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , b, 16 , 16 , int32_t , 2 )
82
86
__SYCL_JOINT_MATRIX_OVERLOAD (int32_t , accumulator, 16 , 16 , int32_t , 8 )
83
87
88
+ // m16n16k8 tf32
89
+ __SYCL_JOINT_MATRIX_OVERLOAD (precision::tf32, a, 16 , 8 , float , 4 )
90
+ __SYCL_JOINT_MATRIX_OVERLOAD (precision::tf32, b, 8 , 16 , float , 4 )
91
+
84
92
#undef __SYCL_JOINT_MATRIX_OVERLOAD
85
93
} // namespace experimental::matrix
86
94
87
95
namespace detail {
88
96
89
- template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
97
+ template <typename S, typename T,
98
+ sycl::ext::oneapi::experimental::matrix::matrix_use Use,
90
99
size_t NumRows, size_t NumCols,
91
100
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
92
101
access::address_space Space, typename Cond = void >
93
102
struct joint_matrix_load_impl {
94
103
void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
95
- T , Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
104
+ S , Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
96
105
multi_ptr<T, Space> src, size_t stride);
97
106
};
98
107
@@ -111,18 +120,19 @@ constexpr int get_layout_id<
111
120
return 1 ;
112
121
}
113
122
114
- template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
123
+ template <typename S, typename T,
124
+ sycl::ext::oneapi::experimental::matrix::matrix_use Use,
115
125
size_t NumRows, size_t NumCols,
116
126
sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
117
127
access::address_space Space>
118
128
struct joint_matrix_load_impl <
119
- T, Use, NumRows, NumCols, Layout, Space,
129
+ S, T, Use, NumRows, NumCols, Layout, Space,
120
130
typename std::enable_if_t <Layout == sycl::ext::oneapi::experimental::
121
131
matrix::matrix_layout::row_major ||
122
132
Layout == sycl::ext::oneapi::experimental::
123
133
matrix::matrix_layout::col_major>> {
124
134
void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
125
- T , Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
135
+ S , Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
126
136
multi_ptr<T, Space> src, size_t stride) {
127
137
if constexpr (std::is_same<T, uint16_t >::value) {
128
138
int32_t *tileptr = reinterpret_cast <int32_t *>(src.get ());
@@ -247,15 +257,27 @@ struct joint_matrix_load_impl<
247
257
get_layout_id<Layout>());
248
258
}
249
259
} else if constexpr (std::is_same<T, float >::value) {
250
- if constexpr (NumRows == 16 && NumCols == 16 ) {
251
- __hmma_m16n16k16_ld_c_f32 (res.data , src.get (), stride,
252
- get_layout_id<Layout>());
253
- } else if constexpr (NumRows == 8 && NumCols == 32 ) {
254
- __hmma_m8n32k16_ld_c_f32 (res.data , src.get (), stride,
255
- get_layout_id<Layout>());
256
- } else if constexpr (NumRows == 32 && NumCols == 8 ) {
257
- __hmma_m32n8k16_ld_c_f32 (res.data , src.get (), stride,
258
- get_layout_id<Layout>());
260
+ if (std::is_same<S, float >::value) {
261
+ if constexpr (NumRows == 16 && NumCols == 16 ) {
262
+ __hmma_m16n16k16_ld_c_f32 (res.data , src.get (), stride,
263
+ get_layout_id<Layout>());
264
+ } else if constexpr (NumRows == 8 && NumCols == 32 ) {
265
+ __hmma_m8n32k16_ld_c_f32 (res.data , src.get (), stride,
266
+ get_layout_id<Layout>());
267
+ } else if constexpr (NumRows == 32 && NumCols == 8 ) {
268
+ __hmma_m32n8k16_ld_c_f32 (res.data , src.get (), stride,
269
+ get_layout_id<Layout>());
270
+ }
271
+ } else if (std::is_same<S, sycl::ext::oneapi::experimental::matrix::
272
+ precision::tf32>::value) {
273
+ int32_t *tileptr = reinterpret_cast <int32_t *>(src.get ());
274
+ if constexpr (NumRows == 16 && NumCols == 8 ) {
275
+ __mma_tf32_m16n16k8_ld_a (reinterpret_cast <int32_t *>(res.data ),
276
+ tileptr, stride, get_layout_id<Layout>());
277
+ } else if constexpr (NumRows == 8 && NumCols == 16 ) {
278
+ __mma_tf32_m16n16k8_ld_b (reinterpret_cast <int32_t *>(res.data ),
279
+ tileptr, stride, get_layout_id<Layout>());
280
+ }
259
281
}
260
282
} else if constexpr (std::is_same<T, double >::value) {
261
283
if constexpr (Use ==
@@ -495,6 +517,10 @@ struct joint_matrix_mad_impl<
495
517
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
496
518
}
497
519
}
520
+ } else if constexpr (M == 16 && N == 16 && K == 8 ) {
521
+ __mma_tf32_m16n16k8_mma_f32 (D.data , reinterpret_cast <int32_t *>(A.data ),
522
+ reinterpret_cast <int32_t *>(B.data ), C.data ,
523
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
498
524
} else if constexpr (std::is_same<T1, double >::value) {
499
525
__dmma_m8n8k4_mma_f64 (D.data , A.data , B.data , C.data ,
500
526
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
@@ -507,13 +533,18 @@ struct joint_matrix_mad_impl<
507
533
508
534
namespace experimental ::matrix {
509
535
510
- template <typename Group, typename T, matrix_use Use, size_t NumRows,
511
- size_t NumCols, matrix_layout Layout, access::address_space Space>
536
+ template <typename Group, typename S, typename T, matrix_use Use,
537
+ size_t NumRows, size_t NumCols, matrix_layout Layout,
538
+ access::address_space Space,
539
+ std::enable_if_t <std::is_same<S, T>::value ||
540
+ (std::is_same<S, precision::tf32>::value &&
541
+ std::is_same<T, float >::value),
542
+ bool > = true >
512
543
void joint_matrix_load (
513
- Group sg, joint_matrix<T , Use, NumRows, NumCols, Layout, Group> &res,
544
+ Group sg, joint_matrix<S , Use, NumRows, NumCols, Layout, Group> &res,
514
545
multi_ptr<T, Space> src, size_t stride) {
515
546
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
516
- sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
547
+ sycl::ext::oneapi::detail::joint_matrix_load_impl<S, T, Use, NumRows, NumCols,
517
548
Layout, Space>{}
518
549
.load (res, src, stride);
519
550
#else
@@ -573,6 +604,21 @@ joint_matrix_mad(
573
604
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
574
605
}
575
606
607
+ // This function rounds the bottom 13 bits up or down, and then zeros out the
608
+ // bottom bits
609
+ float round_to_tf32 (float a) {
610
+ #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
611
+ int32_t tmp_int = __nvvm_f2tf32_rna (a);
612
+ return __nvvm_bitcast_i2f (tmp_int);
613
+ #else
614
+ uint32_t tmp_uint = reinterpret_cast <uint32_t &>(a);
615
+ tmp_uint += 0x1000u ;
616
+ tmp_uint &= 0xFFFFE000u ;
617
+ float ret = reinterpret_cast <float &>(tmp_uint);
618
+ return ret;
619
+ #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
620
+ }
621
+
576
622
} // namespace experimental::matrix
577
623
} // namespace oneapi
578
624
} // namespace ext
0 commit comments