Skip to content

Commit 21cc02c

Browse files
committed
Updating ptr cast
1 parent 438a9f2 commit 21cc02c

File tree

1 file changed

+53
-50
lines changed

1 file changed

+53
-50
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp

Lines changed: 53 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -151,9 +151,9 @@ struct joint_matrix_load_impl<
151151
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
152152
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
153153
multi_ptr<T, Space> src, size_t stride) {
154-
if constexpr (sycl::detail::is_same_v<T, uint16_t> ||
155-
sycl::detail::is_same_v<
156-
T, sycl::ext::oneapi::experimental::bfloat16>) {
154+
if constexpr (std::is_same<T, uint16_t>::value ||
155+
std::is_same<
156+
T, sycl::ext::oneapi::experimental::bfloat16>::value) {
157157
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
158158
auto destptr = reinterpret_cast<int32_t *>(&res.data);
159159
if constexpr (NumRows == 16 && NumCols == 16) {
@@ -179,7 +179,7 @@ struct joint_matrix_load_impl<
179179
__mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride,
180180
get_layout_id<Layout>());
181181
}
182-
} else if constexpr (sycl::detail::is_same_v<T, uint8_t>) {
182+
} else if constexpr (std::is_same<T, uint8_t>::value) {
183183
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
184184
auto destptr = reinterpret_cast<int32_t *>(&res.data);
185185
if constexpr (NumRows == 16 && NumCols == 16) {
@@ -205,7 +205,7 @@ struct joint_matrix_load_impl<
205205
__imma_m32n8k16_ld_b_u8(destptr, tileptr, stride,
206206
get_layout_id<Layout>());
207207
}
208-
} else if constexpr (sycl::detail::is_same_v<T, int8_t>) {
208+
} else if constexpr (std::is_same<T, int8_t>::value) {
209209
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
210210
auto destptr = reinterpret_cast<int32_t *>(&res.data);
211211
if constexpr (NumRows == 16 && NumCols == 16) {
@@ -231,7 +231,7 @@ struct joint_matrix_load_impl<
231231
__imma_m32n8k16_ld_b_s8(destptr, tileptr, stride,
232232
get_layout_id<Layout>());
233233
}
234-
} else if constexpr (sycl::detail::is_same_v<T, half>) {
234+
} else if constexpr (std::is_same<T, half>::value) {
235235
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
236236
auto dstptr = reinterpret_cast<int32_t *>(&res.data);
237237
if constexpr (NumRows == 16 && NumCols == 16) {
@@ -264,7 +264,7 @@ struct joint_matrix_load_impl<
264264
get_layout_id<Layout>());
265265
}
266266

267-
} else if constexpr (sycl::detail::is_same_v<T, int32_t>) {
267+
} else if constexpr (std::is_same<T, int32_t>::value) {
268268
auto destptr = reinterpret_cast<int32_t *>(&res.data);
269269
if constexpr (NumRows == 16 && NumCols == 16) {
270270
__imma_m16n16k16_ld_c(destptr, src.get(), stride,
@@ -278,14 +278,15 @@ struct joint_matrix_load_impl<
278278
}
279279
} else if constexpr (std::is_same<T, float>::value) {
280280
if (std::is_same<S, float>::value) {
281+
auto destptr = reinterpret_cast<float *>(&res.data);
281282
if constexpr (NumRows == 16 && NumCols == 16) {
282-
__hmma_m16n16k16_ld_c_f32(res.data, src.get(), stride,
283+
__hmma_m16n16k16_ld_c_f32(destptr, src.get(), stride,
283284
get_layout_id<Layout>());
284285
} else if constexpr (NumRows == 8 && NumCols == 32) {
285-
__hmma_m8n32k16_ld_c_f32(res.data, src.get(), stride,
286+
__hmma_m8n32k16_ld_c_f32(destptr, src.get(), stride,
286287
get_layout_id<Layout>());
287288
} else if constexpr (NumRows == 32 && NumCols == 8) {
288-
__hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
289+
__hmma_m32n8k16_ld_c_f32(destptr, src.get(), stride,
289290
get_layout_id<Layout>());
290291
}
291292
} else if (std::is_same<S, sycl::ext::oneapi::experimental::matrix::
@@ -299,7 +300,7 @@ struct joint_matrix_load_impl<
299300
tileptr, stride, get_layout_id<Layout>());
300301
}
301302
}
302-
} else if constexpr (sycl::detail::is_same_v<T, double>) {
303+
} else if constexpr (std::is_same<T, double>::value) {
303304
auto dstptr = reinterpret_cast<double *>(&res.data);
304305
if constexpr (Use ==
305306
sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
@@ -341,51 +342,51 @@ struct joint_matrix_store_impl<
341342
NumRows, NumCols, Layout, sycl::sub_group> &src,
342343
multi_ptr<T, Space> dst, size_t stride) {
343344
if (NumRows == 16 && NumCols == 16) {
344-
if constexpr (sycl::detail::is_same_v<T, float>) {
345+
if constexpr (std::is_same<T, float>::value) {
345346
auto srcptr = reinterpret_cast<float *>(&src.data);
346347
__hmma_m16n16k16_st_c_f32(dst.get(), srcptr, stride,
347348
get_layout_id<Layout>());
348-
} else if constexpr (sycl::detail::is_same_v<T, int32_t>) {
349+
} else if constexpr (std::is_same<T, int32_t>::value) {
349350
auto srcptr = reinterpret_cast<int32_t *>(&src.data);
350351
__imma_m16n16k16_st_c_i32(dst.get(), srcptr, stride,
351352
get_layout_id<Layout>());
352-
} else if constexpr (sycl::detail::is_same_v<T, half>) {
353+
} else if constexpr (std::is_same<T, half>::value) {
353354
auto tileptr = reinterpret_cast<int32_t *>(dst.get());
354355
auto srcptr = reinterpret_cast<int32_t *>(&src.data);
355356
__hmma_m16n16k16_st_c_f16(tileptr, srcptr, stride,
356357
get_layout_id<Layout>());
357358
}
358359
} else if (NumRows == 8 && NumCols == 32) {
359-
if constexpr (sycl::detail::is_same_v<T, float>) {
360+
if constexpr (std::is_same<T, float>::value) {
360361
auto srcptr = reinterpret_cast<float *>(&src.data);
361362
__hmma_m8n32k16_st_c_f32(dst.get(), srcptr, stride,
362363
get_layout_id<Layout>());
363-
} else if constexpr (sycl::detail::is_same_v<T, int32_t>) {
364+
} else if constexpr (std::is_same<T, int32_t>::value) {
364365
auto srcptr = reinterpret_cast<int32_t *>(&src.data);
365366
__imma_m8n32k16_st_c_i32(dst.get(), srcptr, stride,
366367
get_layout_id<Layout>());
367-
} else if constexpr (sycl::detail::is_same_v<T, half>) {
368+
} else if constexpr (std::is_same<T, half>::value) {
368369
auto tileptr = reinterpret_cast<int32_t *>(dst.get());
369370
auto srcptr = reinterpret_cast<int32_t *>(&src.data);
370371
__hmma_m8n32k16_st_c_f16(tileptr, srcptr, stride,
371372
get_layout_id<Layout>());
372373
}
373374
} else if (NumRows == 32 && NumCols == 8) {
374-
if constexpr (sycl::detail::is_same_v<T, float>) {
375+
if constexpr (std::is_same<T, float>::value) {
375376
auto srcptr = reinterpret_cast<float *>(&src.data);
376377
__hmma_m32n8k16_st_c_f32(dst.get(), srcptr, stride,
377378
get_layout_id<Layout>());
378-
} else if constexpr (sycl::detail::is_same_v<T, int32_t>) {
379+
} else if constexpr (std::is_same<T, int32_t>::value) {
379380
auto srcptr = reinterpret_cast<int32_t *>(&src.data);
380381
__imma_m32n8k16_st_c_i32(dst.get(), srcptr, stride,
381382
get_layout_id<Layout>());
382-
} else if constexpr (sycl::detail::is_same_v<T, half>) {
383+
} else if constexpr (std::is_same<T, half>::value) {
383384
auto tileptr = reinterpret_cast<int32_t *>(dst.get());
384385
auto srcptr = reinterpret_cast<int32_t *>(&src.data);
385386
__hmma_m32n8k16_st_c_f16(tileptr, srcptr, stride,
386387
get_layout_id<Layout>());
387388
}
388-
} else if constexpr (sycl::detail::is_same_v<T, double>) {
389+
} else if constexpr (std::is_same<T, double>::value) {
389390
auto srcptr = reinterpret_cast<double *>(&src.data);
390391
__dmma_m8n8k4_st_c_f64(dst.get(), srcptr, stride,
391392
get_layout_id<Layout>());
@@ -487,19 +488,19 @@ struct joint_matrix_mad_impl<
487488
N, LayoutC, sycl::sub_group>
488489
D;
489490
if constexpr (M == 16 && N == 16 && K == 16) {
490-
if constexpr (sycl::detail::is_same_v<T2, int32_t>) {
491+
if constexpr (std::is_same<T2, int32_t>::value) {
491492
auto ptrA = reinterpret_cast<int32_t const *>(&A.data);
492493
auto ptrB = reinterpret_cast<int32_t const *>(&B.data);
493494
auto ptrC = reinterpret_cast<int32_t const *>(&C.data);
494495
auto ptrD = reinterpret_cast<int32_t *>(&D.data);
495-
if constexpr (sycl::detail::is_same_v<T1, int8_t>) {
496+
if constexpr (std::is_same<T1, int8_t>::value) {
496497
__imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
497498
get_layout_pair_id<LayoutA, LayoutB>(), 0);
498-
} else if constexpr (sycl::detail::is_same_v<T1, uint8_t>) {
499+
} else if constexpr (std::is_same<T1, uint8_t>::value) {
499500
__imma_m16n16k16_mma_u8(ptrD, ptrA, ptrB, ptrC,
500501
get_layout_pair_id<LayoutA, LayoutB>(), 0);
501502
}
502-
} else if constexpr (sycl::detail::is_same_v<T1, half>) {
503+
} else if constexpr (std::is_same<T1, half>::value) {
503504
auto ptrA = reinterpret_cast<int32_t const *>(&A.data);
504505
auto ptrB = reinterpret_cast<int32_t const *>(&B.data);
505506
if constexpr (std::is_same<T2, float>::value) {
@@ -508,16 +509,16 @@ struct joint_matrix_mad_impl<
508509
__hmma_m16n16k16_mma_f32f32(ptrD, ptrA, ptrB, ptrC,
509510
get_layout_pair_id<LayoutA, LayoutB>(),
510511
0);
511-
} else if constexpr (sycl::detail::is_same_v<T2, half>) {
512+
} else if constexpr (std::is_same<T2, half>::value) {
512513
auto ptrC = reinterpret_cast<int32_t const *>(&C.data);
513514
auto ptrD = reinterpret_cast<int32_t *>(&D.data);
514515
__hmma_m16n16k16_mma_f16f16(ptrD, ptrA, ptrB, ptrC,
515516
get_layout_pair_id<LayoutA, LayoutB>(),
516517
0);
517518
}
518-
} else if constexpr (sycl::detail::is_same_v<T1, uint16_t> ||
519-
sycl::detail::is_same_v<
520-
T1, sycl::ext::oneapi::experimental::bfloat16>) {
519+
} else if constexpr (std::is_same<T1, uint16_t>::value ||
520+
std::is_same<T1, sycl::ext::oneapi::experimental::
521+
bfloat16>::value) {
521522
auto ptrA = reinterpret_cast<int32_t const *>(&A.data);
522523
auto ptrB = reinterpret_cast<int32_t const *>(&B.data);
523524
auto ptrC = reinterpret_cast<float const *>(&C.data);
@@ -526,35 +527,35 @@ struct joint_matrix_mad_impl<
526527
get_layout_pair_id<LayoutA, LayoutB>(), 0);
527528
}
528529
} else if constexpr (M == 8 && N == 32 && K == 16) {
529-
if constexpr (sycl::detail::is_same_v<T2, int32_t>) {
530+
if constexpr (std::is_same<T2, int32_t>::value) {
530531
auto ptrA = reinterpret_cast<int32_t const *>(&A.data);
531532
auto ptrB = reinterpret_cast<int32_t const *>(&B.data);
532533
auto ptrC = reinterpret_cast<int32_t const *>(&C.data);
533534
auto ptrD = reinterpret_cast<int32_t *>(&D.data);
534-
if constexpr (sycl::detail::is_same_v<T1, int8_t>) {
535+
if constexpr (std::is_same<T1, int8_t>::value) {
535536
__imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
536537
get_layout_pair_id<LayoutA, LayoutB>(), 0);
537-
} else if constexpr (sycl::detail::is_same_v<T1, uint8_t>) {
538+
} else if constexpr (std::is_same<T1, uint8_t>::value) {
538539
__imma_m8n32k16_mma_u8(ptrD, ptrA, ptrB, ptrC,
539540
get_layout_pair_id<LayoutA, LayoutB>(), 0);
540541
}
541-
} else if constexpr (sycl::detail::is_same_v<T1, half>) {
542+
} else if constexpr (std::is_same<T1, half>::value) {
542543
auto ptrA = reinterpret_cast<int32_t const *>(&A.data);
543544
auto ptrB = reinterpret_cast<int32_t const *>(&B.data);
544-
if constexpr (sycl::detail::is_same_v<T2, float>) {
545+
if constexpr (std::is_same<T2, float>::value) {
545546
auto ptrC = reinterpret_cast<float const *>(&C.data);
546547
auto ptrD = reinterpret_cast<float *>(&D.data);
547548
__hmma_m8n32k16_mma_f32f32(ptrD, ptrA, ptrB, ptrC,
548549
get_layout_pair_id<LayoutA, LayoutB>(), 0);
549-
} else if constexpr (sycl::detail::is_same_v<T2, half>) {
550+
} else if constexpr (std::is_same<T2, half>::value) {
550551
auto ptrC = reinterpret_cast<int32_t const *>(&C.data);
551552
auto ptrD = reinterpret_cast<int32_t *>(&D.data);
552553
__hmma_m8n32k16_mma_f16f16(ptrD, ptrA, ptrB, ptrC,
553554
get_layout_pair_id<LayoutA, LayoutB>(), 0);
554555
}
555-
} else if constexpr (sycl::detail::is_same_v<T1, uint16_t> ||
556-
sycl::detail::is_same_v<
557-
T1, sycl::ext::oneapi::experimental::bfloat16>) {
556+
} else if constexpr (std::is_same<T1, uint16_t>::value ||
557+
std::is_same<T1, sycl::ext::oneapi::experimental::
558+
bfloat16>::value) {
558559
auto ptrA = reinterpret_cast<int32_t const *>(&A.data);
559560
auto ptrB = reinterpret_cast<int32_t const *>(&B.data);
560561
auto ptrC = reinterpret_cast<float const *>(&C.data);
@@ -563,47 +564,49 @@ struct joint_matrix_mad_impl<
563564
get_layout_pair_id<LayoutA, LayoutB>(), 0);
564565
}
565566
} else if constexpr (M == 32 && N == 8 && K == 16) {
566-
if constexpr (sycl::detail::is_same_v<T2, int32_t>) {
567+
if constexpr (std::is_same<T2, int32_t>::value) {
567568
auto ptrA = reinterpret_cast<int32_t const *>(&A.data);
568569
auto ptrB = reinterpret_cast<int32_t const *>(&B.data);
569570
auto ptrC = reinterpret_cast<int32_t const *>(&C.data);
570571
auto ptrD = reinterpret_cast<int32_t *>(&D.data);
571-
if constexpr (sycl::detail::is_same_v<T1, int8_t>) {
572+
if constexpr (std::is_same<T1, int8_t>::value) {
572573
__imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
573574
get_layout_pair_id<LayoutA, LayoutB>(), 0);
574-
} else if constexpr (sycl::detail::is_same_v<T1, uint8_t>) {
575+
} else if constexpr (std::is_same<T1, uint8_t>::value) {
575576
__imma_m32n8k16_mma_u8(ptrD, ptrA, ptrB, ptrC,
576577
get_layout_pair_id<LayoutA, LayoutB>(), 0);
577578
}
578-
} else if constexpr (sycl::detail::is_same_v<T1, uint16_t> ||
579-
sycl::detail::is_same_v<
580-
T1, sycl::ext::oneapi::experimental::bfloat16>) {
579+
} else if constexpr (std::is_same<T1, uint16_t>::value ||
580+
std::is_same<T1, sycl::ext::oneapi::experimental::
581+
bfloat16>::value) {
581582
auto ptrA = reinterpret_cast<int32_t const *>(&A.data);
582583
auto ptrB = reinterpret_cast<int32_t const *>(&B.data);
583584
auto ptrC = reinterpret_cast<float const *>(&C.data);
584585
auto ptrD = reinterpret_cast<float *>(&D.data);
585586
__mma_bf16_m32n8k16_mma_f32(ptrD, ptrA, ptrB, ptrC,
586587
get_layout_pair_id<LayoutA, LayoutB>(), 0);
587-
} else if constexpr (sycl::detail::is_same_v<T1, half>) {
588+
} else if constexpr (std::is_same<T1, half>::value) {
588589
auto ptrA = reinterpret_cast<int32_t const *>(&A.data);
589590
auto ptrB = reinterpret_cast<int32_t const *>(&B.data);
590-
if constexpr (sycl::detail::is_same_v<T2, float>) {
591+
if constexpr (std::is_same<T2, float>::value) {
591592
auto ptrC = reinterpret_cast<float const *>(&C.data);
592593
auto ptrD = reinterpret_cast<float *>(&D.data);
593594
__hmma_m32n8k16_mma_f32f32(ptrD, ptrA, ptrB, ptrC,
594595
get_layout_pair_id<LayoutA, LayoutB>(), 0);
595-
} else if constexpr (sycl::detail::is_same_v<T2, half>) {
596+
} else if constexpr (std::is_same<T2, half>::value) {
596597
auto ptrC = reinterpret_cast<int32_t const *>(&C.data);
597598
auto ptrD = reinterpret_cast<int32_t *>(&D.data);
598599
__hmma_m32n8k16_mma_f16f16(ptrD, ptrA, ptrB, ptrC,
599600
get_layout_pair_id<LayoutA, LayoutB>(), 0);
600601
}
601602
}
602603
} else if constexpr (M == 16 && N == 16 && K == 8) {
603-
__mma_tf32_m16n16k8_mma_f32(D.data, reinterpret_cast<int32_t *>(A.data),
604-
reinterpret_cast<int32_t *>(B.data), C.data,
604+
__mma_tf32_m16n16k8_mma_f32(reinterpret_cast<float *>(&D.data),
605+
reinterpret_cast<int32_t *>(A.data),
606+
reinterpret_cast<int32_t *>(B.data),
607+
reinterpret_cast<float *>(&C.data),
605608
get_layout_pair_id<LayoutA, LayoutB>(), 0);
606-
} else if constexpr (sycl::detail::is_same_v<T1, double>) {
609+
} else if constexpr (std::is_same<T1, double>::value) {
607610
auto ptrA = reinterpret_cast<double const *>(&A.data);
608611
auto ptrB = reinterpret_cast<double const *>(&B.data);
609612
auto ptrC = reinterpret_cast<double const *>(&C.data);

0 commit comments

Comments (0)