@@ -151,9 +151,9 @@ struct joint_matrix_load_impl<
151
151
void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
152
152
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
153
153
multi_ptr<T, Space> src, size_t stride) {
154
- if constexpr (sycl::detail::is_same_v <T, uint16_t > ||
155
- sycl::detail::is_same_v <
156
- T, sycl::ext::oneapi::experimental::bfloat16>) {
154
+ if constexpr (std::is_same <T, uint16_t >::value ||
155
+ std::is_same <
156
+ T, sycl::ext::oneapi::experimental::bfloat16>::value ) {
157
157
auto tileptr = reinterpret_cast <int32_t const *>(src.get ());
158
158
auto destptr = reinterpret_cast <int32_t *>(&res.data );
159
159
if constexpr (NumRows == 16 && NumCols == 16 ) {
@@ -179,7 +179,7 @@ struct joint_matrix_load_impl<
179
179
__mma_bf16_m32n8k16_ld_b (destptr, tileptr, stride,
180
180
get_layout_id<Layout>());
181
181
}
182
- } else if constexpr (sycl::detail::is_same_v <T, uint8_t >) {
182
+ } else if constexpr (std::is_same <T, uint8_t >::value ) {
183
183
auto tileptr = reinterpret_cast <int32_t const *>(src.get ());
184
184
auto destptr = reinterpret_cast <int32_t *>(&res.data );
185
185
if constexpr (NumRows == 16 && NumCols == 16 ) {
@@ -205,7 +205,7 @@ struct joint_matrix_load_impl<
205
205
__imma_m32n8k16_ld_b_u8 (destptr, tileptr, stride,
206
206
get_layout_id<Layout>());
207
207
}
208
- } else if constexpr (sycl::detail::is_same_v <T, int8_t >) {
208
+ } else if constexpr (std::is_same <T, int8_t >::value ) {
209
209
auto tileptr = reinterpret_cast <int32_t const *>(src.get ());
210
210
auto destptr = reinterpret_cast <int32_t *>(&res.data );
211
211
if constexpr (NumRows == 16 && NumCols == 16 ) {
@@ -231,7 +231,7 @@ struct joint_matrix_load_impl<
231
231
__imma_m32n8k16_ld_b_s8 (destptr, tileptr, stride,
232
232
get_layout_id<Layout>());
233
233
}
234
- } else if constexpr (sycl::detail::is_same_v <T, half>) {
234
+ } else if constexpr (std::is_same <T, half>::value ) {
235
235
auto tileptr = reinterpret_cast <int32_t const *>(src.get ());
236
236
auto dstptr = reinterpret_cast <int32_t *>(&res.data );
237
237
if constexpr (NumRows == 16 && NumCols == 16 ) {
@@ -264,7 +264,7 @@ struct joint_matrix_load_impl<
264
264
get_layout_id<Layout>());
265
265
}
266
266
267
- } else if constexpr (sycl::detail::is_same_v <T, int32_t >) {
267
+ } else if constexpr (std::is_same <T, int32_t >::value ) {
268
268
auto destptr = reinterpret_cast <int32_t *>(&res.data );
269
269
if constexpr (NumRows == 16 && NumCols == 16 ) {
270
270
__imma_m16n16k16_ld_c (destptr, src.get (), stride,
@@ -278,14 +278,15 @@ struct joint_matrix_load_impl<
278
278
}
279
279
} else if constexpr (std::is_same<T, float >::value) {
280
280
if (std::is_same<S, float >::value) {
281
+ auto destptr = reinterpret_cast <float *>(&res.data );
281
282
if constexpr (NumRows == 16 && NumCols == 16 ) {
282
- __hmma_m16n16k16_ld_c_f32 (res. data , src.get (), stride,
283
+ __hmma_m16n16k16_ld_c_f32 (destptr , src.get (), stride,
283
284
get_layout_id<Layout>());
284
285
} else if constexpr (NumRows == 8 && NumCols == 32 ) {
285
- __hmma_m8n32k16_ld_c_f32 (res. data , src.get (), stride,
286
+ __hmma_m8n32k16_ld_c_f32 (destptr , src.get (), stride,
286
287
get_layout_id<Layout>());
287
288
} else if constexpr (NumRows == 32 && NumCols == 8 ) {
288
- __hmma_m32n8k16_ld_c_f32 (res. data , src.get (), stride,
289
+ __hmma_m32n8k16_ld_c_f32 (destptr , src.get (), stride,
289
290
get_layout_id<Layout>());
290
291
}
291
292
} else if (std::is_same<S, sycl::ext::oneapi::experimental::matrix::
@@ -299,7 +300,7 @@ struct joint_matrix_load_impl<
299
300
tileptr, stride, get_layout_id<Layout>());
300
301
}
301
302
}
302
- } else if constexpr (sycl::detail::is_same_v <T, double >) {
303
+ } else if constexpr (std::is_same <T, double >::value ) {
303
304
auto dstptr = reinterpret_cast <double *>(&res.data );
304
305
if constexpr (Use ==
305
306
sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
@@ -341,51 +342,51 @@ struct joint_matrix_store_impl<
341
342
NumRows, NumCols, Layout, sycl::sub_group> &src,
342
343
multi_ptr<T, Space> dst, size_t stride) {
343
344
if (NumRows == 16 && NumCols == 16 ) {
344
- if constexpr (sycl::detail::is_same_v <T, float >) {
345
+ if constexpr (std::is_same <T, float >::value ) {
345
346
auto srcptr = reinterpret_cast <float *>(&src.data );
346
347
__hmma_m16n16k16_st_c_f32 (dst.get (), srcptr, stride,
347
348
get_layout_id<Layout>());
348
- } else if constexpr (sycl::detail::is_same_v <T, int32_t >) {
349
+ } else if constexpr (std::is_same <T, int32_t >::value ) {
349
350
auto srcptr = reinterpret_cast <int32_t *>(&src.data );
350
351
__imma_m16n16k16_st_c_i32 (dst.get (), srcptr, stride,
351
352
get_layout_id<Layout>());
352
- } else if constexpr (sycl::detail::is_same_v <T, half>) {
353
+ } else if constexpr (std::is_same <T, half>::value ) {
353
354
auto tileptr = reinterpret_cast <int32_t *>(dst.get ());
354
355
auto srcptr = reinterpret_cast <int32_t *>(&src.data );
355
356
__hmma_m16n16k16_st_c_f16 (tileptr, srcptr, stride,
356
357
get_layout_id<Layout>());
357
358
}
358
359
} else if (NumRows == 8 && NumCols == 32 ) {
359
- if constexpr (sycl::detail::is_same_v <T, float >) {
360
+ if constexpr (std::is_same <T, float >::value ) {
360
361
auto srcptr = reinterpret_cast <float *>(&src.data );
361
362
__hmma_m8n32k16_st_c_f32 (dst.get (), srcptr, stride,
362
363
get_layout_id<Layout>());
363
- } else if constexpr (sycl::detail::is_same_v <T, int32_t >) {
364
+ } else if constexpr (std::is_same <T, int32_t >::value ) {
364
365
auto srcptr = reinterpret_cast <int32_t *>(&src.data );
365
366
__imma_m8n32k16_st_c_i32 (dst.get (), srcptr, stride,
366
367
get_layout_id<Layout>());
367
- } else if constexpr (sycl::detail::is_same_v <T, half>) {
368
+ } else if constexpr (std::is_same <T, half>::value ) {
368
369
auto tileptr = reinterpret_cast <int32_t *>(dst.get ());
369
370
auto srcptr = reinterpret_cast <int32_t *>(&src.data );
370
371
__hmma_m8n32k16_st_c_f16 (tileptr, srcptr, stride,
371
372
get_layout_id<Layout>());
372
373
}
373
374
} else if (NumRows == 32 && NumCols == 8 ) {
374
- if constexpr (sycl::detail::is_same_v <T, float >) {
375
+ if constexpr (std::is_same <T, float >::value ) {
375
376
auto srcptr = reinterpret_cast <float *>(&src.data );
376
377
__hmma_m32n8k16_st_c_f32 (dst.get (), srcptr, stride,
377
378
get_layout_id<Layout>());
378
- } else if constexpr (sycl::detail::is_same_v <T, int32_t >) {
379
+ } else if constexpr (std::is_same <T, int32_t >::value ) {
379
380
auto srcptr = reinterpret_cast <int32_t *>(&src.data );
380
381
__imma_m32n8k16_st_c_i32 (dst.get (), srcptr, stride,
381
382
get_layout_id<Layout>());
382
- } else if constexpr (sycl::detail::is_same_v <T, half>) {
383
+ } else if constexpr (std::is_same <T, half>::value ) {
383
384
auto tileptr = reinterpret_cast <int32_t *>(dst.get ());
384
385
auto srcptr = reinterpret_cast <int32_t *>(&src.data );
385
386
__hmma_m32n8k16_st_c_f16 (tileptr, srcptr, stride,
386
387
get_layout_id<Layout>());
387
388
}
388
- } else if constexpr (sycl::detail::is_same_v <T, double >) {
389
+ } else if constexpr (std::is_same <T, double >::value ) {
389
390
auto srcptr = reinterpret_cast <double *>(&src.data );
390
391
__dmma_m8n8k4_st_c_f64 (dst.get (), srcptr, stride,
391
392
get_layout_id<Layout>());
@@ -487,19 +488,19 @@ struct joint_matrix_mad_impl<
487
488
N, LayoutC, sycl::sub_group>
488
489
D;
489
490
if constexpr (M == 16 && N == 16 && K == 16 ) {
490
- if constexpr (sycl::detail::is_same_v <T2, int32_t >) {
491
+ if constexpr (std::is_same <T2, int32_t >::value ) {
491
492
auto ptrA = reinterpret_cast <int32_t const *>(&A.data );
492
493
auto ptrB = reinterpret_cast <int32_t const *>(&B.data );
493
494
auto ptrC = reinterpret_cast <int32_t const *>(&C.data );
494
495
auto ptrD = reinterpret_cast <int32_t *>(&D.data );
495
- if constexpr (sycl::detail::is_same_v <T1, int8_t >) {
496
+ if constexpr (std::is_same <T1, int8_t >::value ) {
496
497
__imma_m16n16k16_mma_s8 (ptrD, ptrA, ptrB, ptrC,
497
498
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
498
- } else if constexpr (sycl::detail::is_same_v <T1, uint8_t >) {
499
+ } else if constexpr (std::is_same <T1, uint8_t >::value ) {
499
500
__imma_m16n16k16_mma_u8 (ptrD, ptrA, ptrB, ptrC,
500
501
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
501
502
}
502
- } else if constexpr (sycl::detail::is_same_v <T1, half>) {
503
+ } else if constexpr (std::is_same <T1, half>::value ) {
503
504
auto ptrA = reinterpret_cast <int32_t const *>(&A.data );
504
505
auto ptrB = reinterpret_cast <int32_t const *>(&B.data );
505
506
if constexpr (std::is_same<T2, float >::value) {
@@ -508,16 +509,16 @@ struct joint_matrix_mad_impl<
508
509
__hmma_m16n16k16_mma_f32f32 (ptrD, ptrA, ptrB, ptrC,
509
510
get_layout_pair_id<LayoutA, LayoutB>(),
510
511
0 );
511
- } else if constexpr (sycl::detail::is_same_v <T2, half>) {
512
+ } else if constexpr (std::is_same <T2, half>::value ) {
512
513
auto ptrC = reinterpret_cast <int32_t const *>(&C.data );
513
514
auto ptrD = reinterpret_cast <int32_t *>(&D.data );
514
515
__hmma_m16n16k16_mma_f16f16 (ptrD, ptrA, ptrB, ptrC,
515
516
get_layout_pair_id<LayoutA, LayoutB>(),
516
517
0 );
517
518
}
518
- } else if constexpr (sycl::detail::is_same_v <T1, uint16_t > ||
519
- sycl::detail::is_same_v<
520
- T1, sycl::ext::oneapi::experimental:: bfloat16>) {
519
+ } else if constexpr (std::is_same <T1, uint16_t >::value ||
520
+ std::is_same<T1, sycl::ext::oneapi::experimental::
521
+ bfloat16>::value ) {
521
522
auto ptrA = reinterpret_cast <int32_t const *>(&A.data );
522
523
auto ptrB = reinterpret_cast <int32_t const *>(&B.data );
523
524
auto ptrC = reinterpret_cast <float const *>(&C.data );
@@ -526,35 +527,35 @@ struct joint_matrix_mad_impl<
526
527
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
527
528
}
528
529
} else if constexpr (M == 8 && N == 32 && K == 16 ) {
529
- if constexpr (sycl::detail::is_same_v <T2, int32_t >) {
530
+ if constexpr (std::is_same <T2, int32_t >::value ) {
530
531
auto ptrA = reinterpret_cast <int32_t const *>(&A.data );
531
532
auto ptrB = reinterpret_cast <int32_t const *>(&B.data );
532
533
auto ptrC = reinterpret_cast <int32_t const *>(&C.data );
533
534
auto ptrD = reinterpret_cast <int32_t *>(&D.data );
534
- if constexpr (sycl::detail::is_same_v <T1, int8_t >) {
535
+ if constexpr (std::is_same <T1, int8_t >::value ) {
535
536
__imma_m8n32k16_mma_s8 (ptrD, ptrA, ptrB, ptrC,
536
537
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
537
- } else if constexpr (sycl::detail::is_same_v <T1, uint8_t >) {
538
+ } else if constexpr (std::is_same <T1, uint8_t >::value ) {
538
539
__imma_m8n32k16_mma_u8 (ptrD, ptrA, ptrB, ptrC,
539
540
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
540
541
}
541
- } else if constexpr (sycl::detail::is_same_v <T1, half>) {
542
+ } else if constexpr (std::is_same <T1, half>::value ) {
542
543
auto ptrA = reinterpret_cast <int32_t const *>(&A.data );
543
544
auto ptrB = reinterpret_cast <int32_t const *>(&B.data );
544
- if constexpr (sycl::detail::is_same_v <T2, float >) {
545
+ if constexpr (std::is_same <T2, float >::value ) {
545
546
auto ptrC = reinterpret_cast <float const *>(&C.data );
546
547
auto ptrD = reinterpret_cast <float *>(&D.data );
547
548
__hmma_m8n32k16_mma_f32f32 (ptrD, ptrA, ptrB, ptrC,
548
549
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
549
- } else if constexpr (sycl::detail::is_same_v <T2, half>) {
550
+ } else if constexpr (std::is_same <T2, half>::value ) {
550
551
auto ptrC = reinterpret_cast <int32_t const *>(&C.data );
551
552
auto ptrD = reinterpret_cast <int32_t *>(&D.data );
552
553
__hmma_m8n32k16_mma_f16f16 (ptrD, ptrA, ptrB, ptrC,
553
554
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
554
555
}
555
- } else if constexpr (sycl::detail::is_same_v <T1, uint16_t > ||
556
- sycl::detail::is_same_v<
557
- T1, sycl::ext::oneapi::experimental:: bfloat16>) {
556
+ } else if constexpr (std::is_same <T1, uint16_t >::value ||
557
+ std::is_same<T1, sycl::ext::oneapi::experimental::
558
+ bfloat16>::value ) {
558
559
auto ptrA = reinterpret_cast <int32_t const *>(&A.data );
559
560
auto ptrB = reinterpret_cast <int32_t const *>(&B.data );
560
561
auto ptrC = reinterpret_cast <float const *>(&C.data );
@@ -563,47 +564,49 @@ struct joint_matrix_mad_impl<
563
564
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
564
565
}
565
566
} else if constexpr (M == 32 && N == 8 && K == 16 ) {
566
- if constexpr (sycl::detail::is_same_v <T2, int32_t >) {
567
+ if constexpr (std::is_same <T2, int32_t >::value ) {
567
568
auto ptrA = reinterpret_cast <int32_t const *>(&A.data );
568
569
auto ptrB = reinterpret_cast <int32_t const *>(&B.data );
569
570
auto ptrC = reinterpret_cast <int32_t const *>(&C.data );
570
571
auto ptrD = reinterpret_cast <int32_t *>(&D.data );
571
- if constexpr (sycl::detail::is_same_v <T1, int8_t >) {
572
+ if constexpr (std::is_same <T1, int8_t >::value ) {
572
573
__imma_m32n8k16_mma_s8 (ptrD, ptrA, ptrB, ptrC,
573
574
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
574
- } else if constexpr (sycl::detail::is_same_v <T1, uint8_t >) {
575
+ } else if constexpr (std::is_same <T1, uint8_t >::value ) {
575
576
__imma_m32n8k16_mma_u8 (ptrD, ptrA, ptrB, ptrC,
576
577
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
577
578
}
578
- } else if constexpr (sycl::detail::is_same_v <T1, uint16_t > ||
579
- sycl::detail::is_same_v<
580
- T1, sycl::ext::oneapi::experimental:: bfloat16>) {
579
+ } else if constexpr (std::is_same <T1, uint16_t >::value ||
580
+ std::is_same<T1, sycl::ext::oneapi::experimental::
581
+ bfloat16>::value ) {
581
582
auto ptrA = reinterpret_cast <int32_t const *>(&A.data );
582
583
auto ptrB = reinterpret_cast <int32_t const *>(&B.data );
583
584
auto ptrC = reinterpret_cast <float const *>(&C.data );
584
585
auto ptrD = reinterpret_cast <float *>(&D.data );
585
586
__mma_bf16_m32n8k16_mma_f32 (ptrD, ptrA, ptrB, ptrC,
586
587
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
587
- } else if constexpr (sycl::detail::is_same_v <T1, half>) {
588
+ } else if constexpr (std::is_same <T1, half>::value ) {
588
589
auto ptrA = reinterpret_cast <int32_t const *>(&A.data );
589
590
auto ptrB = reinterpret_cast <int32_t const *>(&B.data );
590
- if constexpr (sycl::detail::is_same_v <T2, float >) {
591
+ if constexpr (std::is_same <T2, float >::value ) {
591
592
auto ptrC = reinterpret_cast <float const *>(&C.data );
592
593
auto ptrD = reinterpret_cast <float *>(&D.data );
593
594
__hmma_m32n8k16_mma_f32f32 (ptrD, ptrA, ptrB, ptrC,
594
595
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
595
- } else if constexpr (sycl::detail::is_same_v <T2, half>) {
596
+ } else if constexpr (std::is_same <T2, half>::value ) {
596
597
auto ptrC = reinterpret_cast <int32_t const *>(&C.data );
597
598
auto ptrD = reinterpret_cast <int32_t *>(&D.data );
598
599
__hmma_m32n8k16_mma_f16f16 (ptrD, ptrA, ptrB, ptrC,
599
600
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
600
601
}
601
602
}
602
603
} else if constexpr (M == 16 && N == 16 && K == 8 ) {
603
- __mma_tf32_m16n16k8_mma_f32 (D.data , reinterpret_cast <int32_t *>(A.data ),
604
- reinterpret_cast <int32_t *>(B.data ), C.data ,
604
+ __mma_tf32_m16n16k8_mma_f32 (reinterpret_cast <float *>(&D.data ),
605
+ reinterpret_cast <int32_t *>(A.data ),
606
+ reinterpret_cast <int32_t *>(B.data ),
607
+ reinterpret_cast <float *>(&C.data ),
605
608
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
606
- } else if constexpr (sycl::detail::is_same_v <T1, double >) {
609
+ } else if constexpr (std::is_same <T1, double >::value ) {
607
610
auto ptrA = reinterpret_cast <double const *>(&A.data );
608
611
auto ptrB = reinterpret_cast <double const *>(&B.data );
609
612
auto ptrC = reinterpret_cast <double const *>(&C.data );
0 commit comments