|
13 | 13 |
|
14 | 14 | #include "../target/utils.h" |
15 | 15 | #include "tvm/ffi/string.h" |
16 | | -#include <vector> |
| 16 | +#include "tcgen5_meta.h" |
17 | 17 |
|
18 | 18 | namespace tvm { |
19 | 19 | namespace tl { |
20 | 20 |
|
21 | 21 | using namespace tir; |
22 | 22 |
|
/**
 * @brief Atom sizes along M, N and K reported by GetTCGEN5MMAMeta.
 *
 * Plain aggregate: it is brace-initialized as {atom_m, atom_n, atom_k}.
 */
struct TCGEN5MMAMeta {
  int atom_m;
  int atom_n;
  int atom_k;
};
26 | | - |
27 | | -// Return {is_success, meta} |
28 | | -static inline std::pair<bool, TCGEN5MMAMeta> |
29 | | -GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { |
30 | | -// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA. |
31 | | -#define FAIL \ |
32 | | - return { \ |
33 | | - false, TCGEN5MMAMeta { 0, 0, 0 } \ |
34 | | - } |
35 | | -#define SUCCESS(atom_m, atom_n, atom_k) \ |
36 | | - return { \ |
37 | | - true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \ |
38 | | - } |
39 | | - std::vector<int> ws_valid_atom_ns = {256, 128, 64}; |
40 | | - if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) && |
41 | | - (c_dtype.is_float() && c_dtype.bits() == 32)) { |
42 | | - if (K % 16 != 0) |
43 | | - FAIL; |
44 | | - if (M % 128 == 0) { |
45 | | - for (int atom_n = 256; atom_n >= 16; atom_n -= 16) |
46 | | - if (N % atom_n == 0) |
47 | | - SUCCESS(128, atom_n, 16); |
48 | | - FAIL; |
49 | | - } else if (M % 64 == 0) { |
50 | | - for (int atom_n : ws_valid_atom_ns) |
51 | | - if (N % atom_n == 0) |
52 | | - SUCCESS(64, atom_n, 16); |
53 | | - FAIL; |
54 | | - } else if (M % 32 == 0) { |
55 | | - for (int atom_n : ws_valid_atom_ns) |
56 | | - if (N % atom_n == 0) |
57 | | - SUCCESS(32, atom_n, 16); |
58 | | - FAIL; |
59 | | - } else { |
60 | | - FAIL; |
61 | | - } |
62 | | - } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) && |
63 | | - (c_dtype.is_float() && c_dtype.bits() == 32)) { |
64 | | - if (K % 32 != 0) |
65 | | - FAIL; |
66 | | - if (M % 128 == 0) { |
67 | | - for (int atom_n = 256; atom_n >= 16; atom_n -= 16) |
68 | | - if (N % atom_n == 0) |
69 | | - SUCCESS(128, atom_n, 32); |
70 | | - FAIL; |
71 | | - } else if (M % 64 == 0) { |
72 | | - for (int atom_n : ws_valid_atom_ns) |
73 | | - if (N % atom_n == 0) |
74 | | - SUCCESS(64, atom_n, 32); |
75 | | - FAIL; |
76 | | - } else if (M % 32 == 0) { |
77 | | - for (int atom_n : ws_valid_atom_ns) |
78 | | - if (N % atom_n == 0) |
79 | | - SUCCESS(32, atom_n, 32); |
80 | | - FAIL; |
81 | | - } else { |
82 | | - FAIL; |
83 | | - } |
84 | | - } |
85 | | - FAIL; |
86 | | -#undef FAIL |
87 | | -#undef SUCCESS |
88 | | -} |
89 | | - |
90 | 23 | /** |
91 | 24 | * @brief Construct a Gemm operator from serialized TL arguments and a buffer |
92 | 25 | * map. |
@@ -144,6 +77,20 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) { |
144 | 77 | if (args.size() > 15) { |
145 | 78 | node->wg_wait = args[15].as<IntImm>().value()->value; |
146 | 79 | } |
| 80 | + if (args.size() > 16) { |
| 81 | + node->mbarptr = args[16]; |
| 82 | + } else { |
| 83 | + node->mbarptr = IntImm(DataType::UInt(32), 0); |
| 84 | + } |
| 85 | + if (args.size() > 18) { |
| 86 | + node->C_coords = Array<PrimExpr>({args[17], args[18]}); |
| 87 | + } else if (args.size() > 17) { |
| 88 | + node->C_coords = |
| 89 | + Array<PrimExpr>({args[17], IntImm(DataType::Int(32), 0)}); |
| 90 | + } else { |
| 91 | + node->C_coords = Array<PrimExpr>( |
| 92 | + {IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), 0)}); |
| 93 | + } |
147 | 94 | data_ = std::move(node); |
148 | 95 | } |
149 | 96 |
|
@@ -378,5 +325,31 @@ TVM_FFI_STATIC_INIT_BLOCK({ |
378 | 325 | }); |
379 | 326 | }); |
380 | 327 |
|
| 328 | +TVM_FFI_STATIC_INIT_BLOCK({ |
| 329 | + namespace refl = tvm::ffi::reflection; |
| 330 | + refl::GlobalDef().def( |
| 331 | + "tl.get_tcgen5_mma_meta", |
| 332 | + [](int M, int N, int K, DataType ab_dtype, DataType c_dtype) { |
| 333 | + auto [success, meta] = GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype); |
| 334 | + Array<Integer> result; |
| 335 | + if (success) { |
| 336 | + result.push_back(Integer(meta.atom_m)); |
| 337 | + result.push_back(Integer(meta.atom_n)); |
| 338 | + result.push_back(Integer(meta.atom_k)); |
| 339 | + } |
| 340 | + return result; |
| 341 | + }); |
| 342 | + refl::GlobalDef().def( |
| 343 | + "tl.get_tcgen5_instr_desc", |
| 344 | + [](int atom_m, int atom_n, int atom_k, DataType ab_dtype, |
| 345 | + DataType c_dtype, bool a_is_k_major, bool b_is_k_major, |
| 346 | + int scale_in_a, int scale_in_b) { |
| 347 | + uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype, |
| 348 | + c_dtype, a_is_k_major, b_is_k_major, |
| 349 | + scale_in_a, scale_in_b); |
| 350 | + return Integer(static_cast<int64_t>(desc)); |
| 351 | + }); |
| 352 | +}); |
| 353 | + |
381 | 354 | } // namespace tl |
382 | 355 | } // namespace tvm |
0 commit comments