
Commit b62a0b4

[Refactor] Use new namespace and enhance dispatch macros for mma (#801)
* Refactor CUDA GEMM operations to use a new namespace and enhance dispatch macros for mma
  - Moved GEMM-related dispatch instructions to the `cute::tl_mma` namespace for better organization.
  - Introduced `TL_DISPATCH_MMA` and `TL_DISPATCH_MMA_TEMPLATE` macros to streamline the definition of dispatch instructions for various data types and architectures.
  - Updated the handling of CUDA architecture checks to include additional support for newer architectures.
  - Improved clarity and maintainability of the code by restructuring the layout and organization of dispatch instructions.
  - Ensured consistent usage of tensor views and memory-clearing operations across the different GEMM implementations.

* Remove deprecated `DispatchInstruction` templates and the `tl_mma` namespace from the CUDA GEMM implementation. This cleanup enhances code clarity and maintainability by eliminating unused structures and streamlining the overall organization of the GEMM operations.
1 parent 5529363 commit b62a0b4
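For orientation before reading the diff: the sketch below shows roughly what a single `TL_DISPATCH_MMA` invocation expands to after preprocessing. The concrete instantiation (`half_t` inputs with a `float` accumulator on SM80) is just one of the combinations registered in this commit; the expansion mirrors the macro body introduced in the diff and is illustrative, not part of the change itself. `TL_DISPATCH_MMA_TEMPLATE` differs only in that it also forwards the element types into the instruction template, e.g. `SM120_16x8x32_TN<fp8_e4_t, fp8_e4_t, float>`.

// Approximate preprocessor expansion of
//   TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
// It reopens cute::tl_mma and specializes the DispatchInstruction template.
namespace cute::tl_mma {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  // MMA atom chosen for this (A, B, C) type combination and architecture.
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
  // Repeat the atom along N, capped at min(num_warp_n * 16, N).
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
} // namespace cute::tl_mma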

File tree

2 files changed, +98 -513 lines changed


src/tl_templates/cuda/gemm_mma.h

Lines changed: 98 additions & 75 deletions
@@ -11,81 +11,101 @@
 #include "cuda_fp8.h"
 #include "intrin.h"
 
-namespace cute {
+namespace cute::tl_mma {
 
 template <typename A_type, typename B_type, typename C_type, int num_warp_m,
           int num_warp_n, int N>
 struct DispatchInstruction;
 
 using _X = Underscore;
 
-#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
+} // namespace cute::tl_mma
+
+#define TL_DISPATCH_MMA(A_type, B_type, C_type, MMA_instr) \
+  namespace cute::tl_mma { \
+  template <int num_warp_m, int num_warp_n, int N> \
+  struct DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, \
+                             N> { \
+    using MMA = MMA_Atom<MMA_instr>; \
+    using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>; \
+  }; \
+  }
+#define TL_DISPATCH_MMA_TEMPLATE(A_type, B_type, C_type, MMA_instr) \
+  namespace cute::tl_mma { \
+  template <int num_warp_m, int num_warp_n, int N> \
+  struct DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, \
+                             N> { \
+    using MMA = MMA_Atom<MMA_instr<A_type, B_type, C_type>>; \
+    using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>; \
+  }; \
+  }
+
+#ifdef __CUDA_ARCH_LIST__
 #if __CUDA_ARCH_LIST__ >= 1200
-template <int num_warp_m, int num_warp_n, int N>
-struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
-                           N> {
-  using MMA = MMA_Atom<SM120_16x8x32_TN<fp8_e4_t, fp8_e4_t, float>>;
-  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
-};
-template <int num_warp_m, int num_warp_n, int N>
-struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
-                           N> {
-  using MMA = MMA_Atom<SM120_16x8x32_TN<fp8_e5_t, fp8_e5_t, float>>;
-  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
-};
+#include "cuda_fp8.h"
+#include <cute/arch/mma_sm120.hpp>
+#include <cute/arch/mma_sm80.hpp>
+TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp8_e4_t, float, SM120_16x8x32_TN)
+TL_DISPATCH_MMA_TEMPLATE(fp8_e5_t, fp8_e5_t, float, SM120_16x8x32_TN)
+TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
+TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
+TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
+TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
+TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
+TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
+#elif __CUDA_ARCH_LIST__ >= 1000
+#include "cuda_fp8.h"
+#include <cute/arch/mma_sm100.hpp>
+#include <cute/arch/mma_sm80.hpp>
+#include <cute/arch/mma_sm89.hpp>
+TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
+TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
+TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
+TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
+TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
+TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
+TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
+TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
+#elif __CUDA_ARCH_LIST__ >= 900
+#include "cuda_fp8.h"
+#include <cute/arch/mma_sm80.hpp>
+#include <cute/arch/mma_sm89.hpp>
+TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
+TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
+TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
+TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
+TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
+TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
+TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
+TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
 #elif __CUDA_ARCH_LIST__ >= 890
-template <int num_warp_m, int num_warp_n, int N>
-struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
-                           N> {
-  using MMA = MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>;
-  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
-};
-template <int num_warp_m, int num_warp_n, int N>
-struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
-                           N> {
-  using MMA = MMA_Atom<SM89_16x8x32_F32E5M2E5M2F32_TN>;
-  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
-};
+#include "cuda_fp8.h"
+#include <cute/arch/mma_sm80.hpp>
+#include <cute/arch/mma_sm89.hpp>
+TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
+TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
+TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
+TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
+TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
+TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
+TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
+TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
+#elif __CUDA_ARCH_LIST__ >= 800
+#include <cute/arch/mma_sm80.hpp>
+TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
+TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
+TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
+TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
+TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
+TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
+#elif __CUDA_ARCH_LIST__ >= 750
+TL_DISPATCH_MMA(half_t, half_t, float, SM75_16x8x8_F32F16F16F32_TN)
 #endif
-template <int num_warp_m, int num_warp_n, int N>
-struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
-  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
-  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
-};
-template <int num_warp_m, int num_warp_n, int N>
-struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
-  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
-  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
-};
-template <int num_warp_m, int num_warp_n, int N>
-struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
-                           num_warp_n, N> {
-  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
-  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
-};
-template <int num_warp_m, int num_warp_n, int N>
-struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
-                           num_warp_n, N> {
-  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
-  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
-};
-template <int num_warp_m, int num_warp_n, int N>
-struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
-  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
-  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
-};
-template <int num_warp_m, int num_warp_n, int N>
-struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
-  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
-  using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
-};
-#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
-template <int num_warp_m, int num_warp_n, int N>
-struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
-  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
-  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
-};
 #endif
+#undef TL_DISPATCH_MMA
+#undef TL_DISPATCH_MMA_TEMPLATE
+
+namespace cute::tl_mma {
 
 template <int N, int num_warp_n, bool transpose> struct SelectCopy {
   static constexpr int remainder = (N / num_warp_n) % 16;
@@ -334,13 +354,13 @@ class GemmTensorOp {
         make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                     partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
 
-    if constexpr (clear_accum) {
-      clear(acc);
-    }
     // when layout is KxN and n_warp is 1, there seem to be a bug, use this as a
     // workaround
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
+    if constexpr (clear_accum) {
+      clear(acc);
+    }
     CUTE_UNROLL
     for (int k = 0; k < size<2>(tCrA); ++k) {
       copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
@@ -371,10 +391,10 @@ class GemmTensorOp {
     Tensor tCrA =
         make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                     partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
+    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
     if constexpr (clear_accum) {
       clear(acc);
     }
-    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
     copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
     CUTE_UNROLL
     for (int k = 0; k < size<2>(tCrA); ++k) {
@@ -407,10 +427,10 @@ class GemmTensorOp {
     Tensor tCrB =
        make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
                    partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
+    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
     if constexpr (clear_accum) {
       clear(acc);
     }
-    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
     copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
     CUTE_UNROLL
     for (int k = 0; k < size<2>(tCrA); ++k) {
@@ -422,15 +442,16 @@ class GemmTensorOp {
   }
 };
 
-} // namespace cute
+} // namespace cute::tl_mma
 
-namespace tl {
+namespace tl::tl_mma {
 
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
          int offset_b, typename A_type, typename B_type, typename C_type>
 CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
-  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
+  using MMA =
+      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                  trans_B, clear_accum, lda, ldb, offset_a,
                                  offset_b, A_type, B_type, C_type>;
   MMA::body(pA, pB, accum);
@@ -440,7 +461,8 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
          int offset_b, typename A_type, typename B_type, typename C_type>
 CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
-  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
+  using MMA =
+      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                  trans_B, clear_accum, lda, ldb, offset_a,
                                  offset_b, A_type, B_type, C_type>;
   MMA::body_rs(pA, pB, accum);
@@ -450,10 +472,11 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
          int offset_b, typename A_type, typename B_type, typename C_type>
 CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
-  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
+  using MMA =
+      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                  trans_B, clear_accum, lda, ldb, offset_a,
                                  offset_b, A_type, B_type, C_type>;
  MMA::body_sr(pA, pB, accum);
 }
 
-} // namespace tl
+} // namespace tl::tl_mma
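To show how the relocated wrappers are reached, here is a minimal, hypothetical device-side call into `tl::tl_mma::gemm_ss`. The kernel name, tile sizes, warp counts, leading dimensions, and offsets below are illustrative assumptions, not values taken from this commit; in TileLang such a call is normally emitted by the code generator rather than written by hand. The element types are deduced from the pointer arguments, and the macros above supply the matching `DispatchInstruction<half_t, half_t, float, ...>` specialization.

// Hypothetical call site (illustrative only, not part of this commit):
// a 128x128x32 tile computed by a 2x2 warp arrangement, fp16 inputs with an
// fp32 accumulator, both operands staged in shared memory (the "ss" variant).
__device__ void example_tile_gemm(cute::half_t *smem_A, cute::half_t *smem_B,
                                  float *acc_frag) {
  tl::tl_mma::gemm_ss<128, 128, 32, /*num_warp_m=*/2, /*num_warp_n=*/2,
                      /*trans_A=*/false, /*trans_B=*/false,
                      /*clear_accum=*/true, /*lda=*/32, /*ldb=*/32,
                      /*offset_a=*/0, /*offset_b=*/0>(smem_A, smem_B, acc_frag);
}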
