@@ -67,10 +67,26 @@ struct joint_matrix;
6767
6868} // namespace matrix
6969} // namespace experimental
70+
71+ namespace detail {
72+ // Differentiating between the "element type" and the "storage element type"
73+ template <typename T> struct jm_type_interpretation_helper_trait {
74+ using element_type = T;
75+ using storage_element_type = T;
76+ };
77+
78+ template <>
79+ struct jm_type_interpretation_helper_trait <
80+ sycl::ext::oneapi::experimental::matrix::precision::tf32> {
81+ using element_type = sycl::ext::oneapi::experimental::matrix::precision::tf32;
82+ using storage_element_type = float ;
83+ };
84+ } // namespace detail
7085} // namespace oneapi
7186
7287namespace intel ::experimental::matrix {
7388
89+ using namespace sycl ::ext::oneapi::experimental::matrix;
7490// Begin wi_element definition
7591
7692template <typename T, size_t NumRows, size_t NumCols,
@@ -84,6 +100,9 @@ class wi_element {
84100 std::size_t idx;
85101
86102public:
103+ using storage_element_type =
104+ typename oneapi::detail::jm_type_interpretation_helper_trait<
105+ T>::storage_element_type;
87106 wi_element (sycl::ext::oneapi::experimental::matrix::joint_matrix<
88107 Group, T, Use, NumRows, NumCols, Layout> &Mat,
89108 std::size_t i)
@@ -102,9 +121,15 @@ class wi_element {
102121#endif // __SYCL_DEVICE_ONLY__
103122 }
104123
105- operator T () {
124+ operator storage_element_type () {
106125#ifdef __SYCL_DEVICE_ONLY__
107- return __spirv_VectorExtractDynamic (M.spvm , idx);
126+ storage_element_type elem =
127+ __spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
128+ spv_matrix_use_traits<Use>::value,
129+ spv_matrix_layout_traits<Layout>::value,
130+ spv_scope_traits<Group>::value>(M.spvm ,
131+ idx);
132+ return elem;
108133#else
109134 throw runtime_error (" joint matrix is not supported on host device." ,
110135 PI_ERROR_INVALID_DEVICE);
@@ -113,7 +138,12 @@ class wi_element {
113138
114139 explicit operator bool () {
115140#ifdef __SYCL_DEVICE_ONLY__
116- return __spirv_VectorExtractDynamic (M.spvm , idx) != static_cast <T>(0 );
141+ return __spirv_VectorExtractDynamic<storage_element_type, T, NumRows,
142+ NumCols,
143+ spv_matrix_use_traits<Use>::value,
144+ spv_matrix_layout_traits<Layout>::value,
145+ spv_scope_traits<Group>::value>(
146+ M.spvm , idx) != static_cast <storage_element_type>(0 );
117147#else
118148 throw runtime_error (" joint matrix is not supported on host device." ,
119149 PI_ERROR_INVALID_DEVICE);
@@ -122,7 +152,8 @@ class wi_element {
122152
123153 template <typename T2> wi_element &operator =(const T2 &rhs) {
124154#ifdef __SYCL_DEVICE_ONLY__
125- M.spvm = __spirv_VectorInsertDynamic (M.spvm , static_cast <T>(rhs), idx);
155+ M.spvm = __spirv_VectorInsertDynamic (
156+ M.spvm , static_cast <storage_element_type>(rhs), idx);
126157 return *this ;
127158#else
128159 (void )rhs;
@@ -135,7 +166,13 @@ class wi_element {
135166 operator =(const wi_element<T, NumRows, NumCols, Use, Layout, Group> &rhs) {
136167#ifdef __SYCL_DEVICE_ONLY__
137168 M.spvm = __spirv_VectorInsertDynamic (
138- M.spvm , __spirv_VectorExtractDynamic (rhs.M .spvm , rhs.idx ), idx);
169+ M.spvm ,
170+ __spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
171+ spv_matrix_use_traits<Use>::value,
172+ spv_matrix_layout_traits<Layout>::value,
173+ spv_scope_traits<Group>::value>(rhs.M .spvm ,
174+ rhs.idx ),
175+ idx);
139176 return *this ;
140177#else
141178 (void )rhs;
@@ -149,8 +186,13 @@ class wi_element {
149186 template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
150187 M.spvm = __spirv_VectorInsertDynamic ( \
151188 M.spvm , \
152- static_cast <T>(__spirv_VectorExtractDynamic (M.spvm , idx) \
153- op static_cast <T>(rhs)), \
189+ static_cast <storage_element_type>( \
190+ __spirv_VectorExtractDynamic< \
191+ storage_element_type, T, NumRows, NumCols, \
192+ spv_matrix_use_traits<Use>::value, \
193+ spv_matrix_layout_traits<Layout>::value, \
194+ spv_scope_traits<Group>::value>(M.spvm , idx) \
195+ op static_cast <storage_element_type>(rhs)), \
154196 idx); \
155197 return *this ; \
156198 }
@@ -201,7 +243,11 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
201243
202244 operator sycl::ext::oneapi::bfloat16 () {
203245#ifdef __SYCL_DEVICE_ONLY__
204- return __spirv_VectorExtractDynamic (M.spvm , idx);
246+ return __spirv_VectorExtractDynamic<
247+ sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows,
248+ NumCols, spv_matrix_use_traits<Use>::value,
249+ spv_matrix_layout_traits<Layout>::value,
250+ spv_scope_traits<Group>::value>(M.spvm , idx);
205251#else
206252 throw runtime_error (" joint matrix is not supported on host device." ,
207253 PI_ERROR_INVALID_DEVICE);
@@ -210,8 +256,13 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
210256
211257 explicit operator bool () {
212258#ifdef __SYCL_DEVICE_ONLY__
213- return std::fabs (static_cast <float >(__spirv_VectorExtractDynamic (
214- M.spvm , idx))) >= std::numeric_limits<float >::epsilon ();
259+ return std::fabs (static_cast <float >(
260+ __spirv_VectorExtractDynamic<
261+ sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16,
262+ NumRows, NumCols, spv_matrix_use_traits<Use>::value,
263+ spv_matrix_layout_traits<Layout>::value,
264+ spv_scope_traits<Group>::value>(M.spvm , idx))) >=
265+ std::numeric_limits<float >::epsilon ();
215266#else
216267 throw runtime_error (" joint matrix is not supported on host device." ,
217268 PI_ERROR_INVALID_DEVICE);
@@ -233,7 +284,14 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
233284 NumCols, Use, Layout, Group> &rhs) {
234285#ifdef __SYCL_DEVICE_ONLY__
235286 M.spvm = __spirv_VectorInsertDynamic (
236- M.spvm , __spirv_VectorExtractDynamic (rhs.M .spvm , rhs.idx ), idx);
287+ M.spvm ,
288+ __spirv_VectorExtractDynamic<sycl::ext::oneapi::bfloat16,
289+ sycl::ext::oneapi::bfloat16, NumRows,
290+ NumCols, spv_matrix_use_traits<Use>::value,
291+ spv_matrix_layout_traits<Layout>::value,
292+ spv_scope_traits<Group>::value>(rhs.M .spvm ,
293+ rhs.idx ),
294+ idx);
237295 return *this ;
238296#else
239297 (void )rhs;
@@ -246,7 +304,13 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
246304#define OP (opassign, op ) \
247305 wi_element &operator opassign (const sycl::ext::oneapi::bfloat16 &rhs) { \
248306 M.spvm = __spirv_VectorInsertDynamic ( \
249- M.spvm , __spirv_VectorExtractDynamic (M.spvm , idx) op rhs, idx); \
307+ M.spvm , \
308+ __spirv_VectorExtractDynamic< \
309+ sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
310+ NumCols, spv_matrix_use_traits<Use>::value, \
311+ spv_matrix_layout_traits<Layout>::value, \
312+ spv_scope_traits<Group>::value>(M.spvm , idx) op rhs, \
313+ idx); \
250314 return *this ; \
251315 }
252316#else // __SYCL_DEVICE_ONLY__
@@ -269,13 +333,21 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
269333 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
270334 Layout, Group> &lhs, \
271335 const sycl::ext::oneapi::bfloat16 &rhs) { \
272- return __spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx ) op rhs; \
336+ return __spirv_VectorExtractDynamic< \
337+ sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
338+ NumCols, spv_matrix_use_traits<Use>::value, \
339+ spv_matrix_layout_traits<Layout>::value, \
340+ spv_scope_traits<Group>::value>(lhs.M .spvm , lhs.idx ) op rhs; \
273341 } \
274342 friend type operator op ( \
275343 const sycl::ext::oneapi::bfloat16 &lhs, \
276344 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
277345 Layout, Group> &rhs) { \
278- return __spirv_VectorExtractDynamic (rhs.M .spvm , rhs.idx ) op lhs; \
346+ return __spirv_VectorExtractDynamic< \
347+ sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
348+ NumCols, spv_matrix_use_traits<Use>::value, \
349+ spv_matrix_layout_traits<Layout>::value, \
350+ spv_scope_traits<Group>::value>(rhs.M .spvm , rhs.idx ) op lhs; \
279351 }
280352 OP (sycl::ext::oneapi::bfloat16, +)
281353 OP(sycl::ext::oneapi::bfloat16, -)
@@ -287,15 +359,25 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
287359 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
288360 Layout, Group> &lhs, \
289361 const sycl::ext::oneapi::bfloat16 &rhs) { \
290- return type{static_cast <float >(__spirv_VectorExtractDynamic ( \
291- lhs.M .spvm , lhs.idx )) op static_cast <float >(rhs)}; \
362+ return type{static_cast <float >( \
363+ __spirv_VectorExtractDynamic< \
364+ sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
365+ NumCols, spv_matrix_use_traits<Use>::value, \
366+ spv_matrix_layout_traits<Layout>::value, \
367+ spv_scope_traits<Group>::value>(lhs.M .spvm , lhs.idx )) \
368+ op static_cast <float >(rhs)}; \
292369 } \
293370 friend type operator op ( \
294371 const sycl::ext::oneapi::bfloat16 &lhs, \
295372 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
296373 Layout, Group> &rhs) { \
297- return type{static_cast <float >(__spirv_VectorExtractDynamic ( \
298- rhs.M .spvm , rhs.idx )) op static_cast <float >(lhs)}; \
374+ return type{static_cast <float >( \
375+ __spirv_VectorExtractDynamic< \
376+ sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
377+ NumCols, spv_matrix_use_traits<Use>::value, \
378+ spv_matrix_layout_traits<Layout>::value, \
379+ spv_scope_traits<Group>::value>(rhs.M .spvm , rhs.idx )) \
380+ op static_cast <float >(lhs)}; \
299381 }
300382 OP (bool , ==)
301383 OP(bool , !=)
@@ -386,7 +468,7 @@ get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix<
386468// End wi_data definition
387469
388470template <
389- typename Group, typename T,
471+ typename Group, typename T, typename Tp,
390472 sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,
391473 size_t NumCols, sycl::ext::oneapi::experimental::matrix::layout Layout,
392474 access::address_space Space, access::decorated IsDecorated,
@@ -396,7 +478,7 @@ template <
396478inline __SYCL_ALWAYS_INLINE void
397479joint_matrix_store (Group sg,
398480 sycl::ext::oneapi::experimental::matrix::joint_matrix<
399- Group, T , Use, NumRows, NumCols, Layout> &src,
481+ Group, Tp , Use, NumRows, NumCols, Layout> &src,
400482 multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
401483#if defined(__SYCL_DEVICE_ONLY__)
402484#if defined(__NVPTX__)
@@ -411,7 +493,7 @@ joint_matrix_store(Group sg,
411493#else
412494 // intel's impl
413495 T *Ptr = dst.get ();
414- __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
496+ __spirv_JointMatrixStoreINTEL<T, Tp, NumRows, NumCols,
415497 sycl::ext::oneapi::experimental::matrix::
416498 spv_matrix_use_traits<Use>::value,
417499 sycl::ext::oneapi::experimental::matrix::
0 commit comments