#include "cuda_fp8.h"
#include "intrin.h"

namespace cute::tl_mma {

// Primary template: no definition on purpose. A specialization is registered
// for each supported (A_type, B_type, C_type) triple by the TL_DISPATCH_MMA*
// macros, selected per target architecture further down in this file.
// num_warp_m/num_warp_n give the warp tiling; N is the tile extent along N.
template <typename A_type, typename B_type, typename C_type, int num_warp_m,
          int num_warp_n, int N>
struct DispatchInstruction;

// Shorthand for CuTe's wildcard mode used in Tile<> shapes.
using _X = Underscore;

} // namespace cute::tl_mma
// Registers a DispatchInstruction specialization mapping an (A, B, C) type
// triple to a concrete, non-template MMA atom. MMA_Group tiles the N mode,
// capped by std::min at the requested tile extent N so a small tile is not
// over-partitioned across warps.
#define TL_DISPATCH_MMA(A_type, B_type, C_type, MMA_instr)                     \
  namespace cute::tl_mma {                                                     \
  template <int num_warp_m, int num_warp_n, int N>                             \
  struct DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n,   \
                             N> {                                              \
    using MMA = MMA_Atom<MMA_instr>;                                           \
    using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;         \
  };                                                                           \
  }

// Same as TL_DISPATCH_MMA, but for MMA instructions that are themselves
// class templates parameterized on the operand/accumulator types
// (e.g. SM120_16x8x32_TN<A, B, C>).
#define TL_DISPATCH_MMA_TEMPLATE(A_type, B_type, C_type, MMA_instr)            \
  namespace cute::tl_mma {                                                     \
  template <int num_warp_m, int num_warp_n, int N>                             \
  struct DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n,   \
                             N> {                                              \
    using MMA = MMA_Atom<MMA_instr<A_type, B_type, C_type>>;                   \
    using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;         \
  };                                                                           \
  }
42+
// Per-architecture instruction tables. __CUDA_ARCH_LIST__ (set by nvcc from
// the -gencode/-arch targets) selects the newest tier available; each tier
// registers its native MMA atoms and reuses SM80-era atoms for the rest.
// NOTE(review): only the highest matching tier is instantiated per
// compilation pass -- confirm this matches how __CUDA_ARCH_LIST__ is passed.
#ifdef __CUDA_ARCH_LIST__
#if __CUDA_ARCH_LIST__ >= 1200
#include "cuda_fp8.h"
#include <cute/arch/mma_sm120.hpp>
#include <cute/arch/mma_sm80.hpp>
TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp8_e4_t, float, SM120_16x8x32_TN)
TL_DISPATCH_MMA_TEMPLATE(fp8_e5_t, fp8_e5_t, float, SM120_16x8x32_TN)
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#elif __CUDA_ARCH_LIST__ >= 1000
#include "cuda_fp8.h"
#include <cute/arch/mma_sm100.hpp>
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#elif __CUDA_ARCH_LIST__ >= 900
#include "cuda_fp8.h"
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#elif __CUDA_ARCH_LIST__ >= 890
#include "cuda_fp8.h"
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#elif __CUDA_ARCH_LIST__ >= 800
#include <cute/arch/mma_sm80.hpp>
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#elif __CUDA_ARCH_LIST__ >= 750
// Turing: fp16 inputs with fp32 accumulation only.
TL_DISPATCH_MMA(half_t, half_t, float, SM75_16x8x8_F32F16F16F32_TN)
#endif
#endif
// The registration macros are file-local; do not leak them to includers.
#undef TL_DISPATCH_MMA
#undef TL_DISPATCH_MMA_TEMPLATE
107+
108+ namespace cute ::tl_mma {
89109
90110template <int N, int num_warp_n, bool transpose> struct SelectCopy {
91111 static constexpr int remainder = (N / num_warp_n) % 16 ;
@@ -334,13 +354,13 @@ class GemmTensorOp {
334354 make_tensor (make_rmem_ptr (reinterpret_cast <C_type *>(pC)),
335355 partition_shape_C (tiled_mma, Shape<Int<M>, Int<N>>{}));
336356
337- if constexpr (clear_accum) {
338- clear (acc);
339- }
340357 // when layout is KxN and n_warp is 1, there seem to be a bug, use this as a
341358 // workaround
342359 auto tCrA_view = make_tensor (tCrA.data (), remove_swizzle (tCrA.layout ()));
343360 auto tCrB_view = make_tensor (tCrB.data (), remove_swizzle (tCrB.layout ()));
361+ if constexpr (clear_accum) {
362+ clear (acc);
363+ }
344364 CUTE_UNROLL
345365 for (int k = 0 ; k < size<2 >(tCrA); ++k) {
346366 copy (tiled_copy_A, tCsA (_, _, k), tCrA_copy_view (_, _, k));
@@ -371,10 +391,10 @@ class GemmTensorOp {
371391 Tensor tCrA =
372392 make_tensor (make_rmem_ptr (reinterpret_cast <A_type *>(pA)),
373393 partition_shape_A (tiled_mma, Shape<Int<M>, Int<K>>{}));
394+ auto tCrB_view = make_tensor (tCrB.data (), remove_swizzle (tCrB.layout ()));
374395 if constexpr (clear_accum) {
375396 clear (acc);
376397 }
377- auto tCrB_view = make_tensor (tCrB.data (), remove_swizzle (tCrB.layout ()));
378398 copy (tiled_copy_B, tCsB (_, _, 0 ), tCrB_copy_view (_, _, 0 ));
379399 CUTE_UNROLL
380400 for (int k = 0 ; k < size<2 >(tCrA); ++k) {
@@ -407,10 +427,10 @@ class GemmTensorOp {
407427 Tensor tCrB =
408428 make_tensor (make_rmem_ptr (reinterpret_cast <B_type *>(pB)),
409429 partition_shape_B (tiled_mma, Shape<Int<N>, Int<K>>{}));
430+ auto tCrA_view = make_tensor (tCrA.data (), remove_swizzle (tCrA.layout ()));
410431 if constexpr (clear_accum) {
411432 clear (acc);
412433 }
413- auto tCrA_view = make_tensor (tCrA.data (), remove_swizzle (tCrA.layout ()));
414434 copy (tiled_copy_A, tCsA (_, _, 0 ), tCrA_copy_view (_, _, 0 ));
415435 CUTE_UNROLL
416436 for (int k = 0 ; k < size<2 >(tCrA); ++k) {
@@ -422,15 +442,16 @@ class GemmTensorOp {
422442 }
423443};
424444
425- } // namespace cute
445+ } // namespace cute::tl_mma
426446
427- namespace tl {
447+ namespace tl ::tl_mma {
428448
// GEMM entry point forwarding to GemmTensorOp::body(). The "ss" suffix
// presumably denotes both A and B operands resident in shared memory --
// confirm against GemmTensorOp::body's copy pipeline.
//   pA/pB:  operand base pointers; accum: accumulator fragment, cleared
//           first when clear_accum is set.
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
          int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, lda, ldb, offset_a,
                                 offset_b, A_type, B_type, C_type>;
  MMA::body(pA, pB, accum);
}
// GEMM entry point forwarding to GemmTensorOp::body_rs(). The "rs" suffix
// presumably denotes A in registers, B in shared memory -- confirm against
// GemmTensorOp::body_rs. Same template parameters as gemm_ss.
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
          int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, lda, ldb, offset_a,
                                 offset_b, A_type, B_type, C_type>;
  MMA::body_rs(pA, pB, accum);
}
// GEMM entry point forwarding to GemmTensorOp::body_sr(). The "sr" suffix
// presumably denotes A in shared memory, B in registers -- confirm against
// GemmTensorOp::body_sr. Same template parameters as gemm_ss.
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
          int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, lda, ldb, offset_a,
                                 offset_b, A_type, B_type, C_type>;
  MMA::body_sr(pA, pB, accum);
}
458481
459- } // namespace tl
482+ } // namespace tl::tl_mma
0 commit comments