Commit 0d101c1

[WIP] support more dtypes for tcgen05 (#1229)
Support ld with pack for fp32 dtype; add dump; add template expansion; remove unused dtypes; change to rebased APIs.
1 parent bf90a5f commit 0d101c1

11 files changed: +976 −88 lines
Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
            mbar = T.alloc_barrier(1)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm_v2(
                    A_shared,
                    B_shared,
                    C_tmem,
                    trans_A,
                    trans_B,
                    mbar=mbar,
                    wg_wait=-1,
                    clear_accum=(k == 0),
                )
                T.mbarrier_wait_parity(mbar, k % 2)

            T.copy(C_tmem, C_local)
            T.copy(C_local, C_shared)

            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


def calc_diff(x, y):
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 64, 256, 32
trans_A, trans_B = False, True
num_stages = 2
threads = 256
for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]:
    for tvm_acc_dtype in ["float16", "float32"]:  # , torch.float16]:
        torch_fp8_dtype = map_torch_type(tvm_fp8_dtype)
        torch_acc_dtype = map_torch_type(tvm_acc_dtype)
        print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}")
        in_dtype, out_dtype, accum_dtype = tvm_fp8_dtype, tvm_acc_dtype, tvm_acc_dtype

        func = matmul(
            M,
            N,
            K,
            block_M,
            block_N,
            block_K,
            trans_A,
            trans_B,
            in_dtype,
            out_dtype,
            accum_dtype,
            num_stages,
            threads,
        )
        jit_kernel = tilelang.compile(
            func,
            out_idx=[2],
            target="cuda",
            pass_configs={
                tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
                tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
                tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True,
            },
        )
        # jit_kernel.export_ptx("./dump.ptx")
        # jit_kernel.export_sources("./dump.cu")

        a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype)
        b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype)

        c = jit_kernel(a, b)
        ref_c = (a.to(torch.half) @ b.T.to(torch.half)).float()
        c = c.float()
        diff = calc_diff(c, ref_c)
        # assert diff < 1e-3, f"{diff}"
        print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] diff = {diff}")

        profiler = jit_kernel.get_profiler()
        latency = profiler.do_bench()
        print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms")
        print(
            f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS"
        )
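For reference, calc_diff in the script above is a symmetric relative-error metric rather than a plain element-wise comparison:

    \mathrm{diff}(x, y) = 1 - \frac{2\,\langle x, y \rangle}{\lVert x \rVert_2^2 + \lVert y \rVert_2^2}

It evaluates to 0 exactly when x == y and stays small while the FP8 result differs from the fp16 reference only by rounding noise, which is what the (currently commented-out) 1e-3 assertion is meant to bound.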

src/op/copy.cc

Lines changed: 10 additions & 4 deletions
@@ -1117,16 +1117,20 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
   bool is_ld = false; // tcgen05.ld (tensor memory -> register)
   bool is_st = false; // tcgen05.st (register -> tensor memory)
   bool is_cp = false; // tcgen05.cp (shared memory -> tensor memory)
+  bool src_needs_pack =
+      16 == src->dtype.bits(); // if needs .pack::16b when is_ld
+  bool dst_needs_unpack =
+      16 == dst->dtype.bits(); // if needs .unpack::16b when is_st
+
   if (src.scope() == "shared.tmem" && dst.scope() == "local.fragment") {
     is_ld = true;
   } else if (src.scope() == "local.fragment" && dst.scope() == "shared.tmem") {
     is_st = true;
   } else if (src.scope() == "shared.dyn" && dst.scope() == "shared.tmem") {
     is_cp = true;
   } else {
-    ICHECK(0) << "Unsupported tensor memory copy: "
-              << "src scope = " << src.scope()
-              << ", dst scope = " << dst.scope();
+    ICHECK(0) << "Unsupported tensor memory copy: " << "src scope = "
+              << src.scope() << ", dst scope = " << dst.scope();
   }
   // Currently tcgen05.cp is not supported
   // TODO (mzw) Support tcgen05.cp
@@ -1246,8 +1250,10 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
                   : relative_wg_idx * (num_chunks_each_wg * meta.width);
       have_succeeded = true;
       Array<PrimExpr> args;
+      const char *bool_str = src_needs_pack ? "true" : "false";
       args.push_back(StringImm(meta.intrinsics_name + "<" +
-                               std::to_string(num_chunks_each_wg) + ">"));
+                               std::to_string(num_chunks_each_wg) + ", " +
+                               bool_str + ">"));
       args.push_back(
           BufferLoad(src, {(int)logical_row_min,
                            (int)logical_col_min})); // Will be translated later
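Net effect of this hunk: the intrinsic string emitted into the generated CUDA source gains a second template argument carrying the pack flag. A minimal standalone sketch of that string assembly (the intrinsic family name and chunk count below are illustrative assumptions, not values taken from the commit):

#include <iostream>
#include <string>

int main() {
  // Hypothetical stand-ins for what LowerTmemCopy has at this point.
  std::string intrinsics_name = "tl::tcgen05_ld_32dp32bNx"; // assumed name
  int num_chunks_each_wg = 4;                               // assumed count
  bool src_needs_pack = true; // 16-bit tmem element -> .pack::16b on tcgen05.ld

  // Mirrors the StringImm construction in the diff above.
  const char *bool_str = src_needs_pack ? "true" : "false";
  std::string call = intrinsics_name + "<" +
                     std::to_string(num_chunks_each_wg) + ", " + bool_str + ">";
  std::cout << call << "\n"; // prints: tl::tcgen05_ld_32dp32bNx<4, true>
}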

src/op/gemm_py.cc

Lines changed: 2 additions & 0 deletions
@@ -428,6 +428,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       result.push_back(Integer(meta.atom_m));
       result.push_back(Integer(meta.atom_n));
       result.push_back(Integer(meta.atom_k));
+      result.push_back(Integer(meta.enable_ws));
+      result.push_back(Integer(meta.enable_2cta));
     }
     return result;
   });

src/op/tcgen5_meta.h

Lines changed: 27 additions & 11 deletions
@@ -15,16 +15,19 @@ using runtime::DataType;
 
 struct TCGEN5MMAMeta {
   int atom_m, atom_n, atom_k;
+  bool enable_ws, enable_2cta;
 };
 
 inline std::pair<bool, TCGEN5MMAMeta>
 GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
   // TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
 #define FAIL \
-  return { false, TCGEN5MMAMeta{0, 0, 0} }
-#define SUCCESS(atom_m, atom_n, atom_k) \
   return { \
-    true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
+    false, TCGEN5MMAMeta { 0, 0, 0, false, false } \
+  }
+#define SUCCESS(atom_m, atom_n, atom_k, use_ws, use_2cta) \
+  return { \
+    true, TCGEN5MMAMeta { atom_m, atom_n, atom_k, use_ws, use_2cta } \
   }
   std::vector<int> ws_valid_atom_ns = {256, 128, 64};
   if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
@@ -34,39 +37,52 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
     if (M % 128 == 0) {
       for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
         if (N % atom_n == 0)
-          SUCCESS(128, atom_n, 16);
+          SUCCESS(128, atom_n, 16, false, false);
       FAIL;
     } else if (M % 64 == 0) {
       for (int atom_n : ws_valid_atom_ns)
         if (N % atom_n == 0)
-          SUCCESS(64, atom_n, 16);
+          SUCCESS(64, atom_n, 16, false, false);
       FAIL;
     } else if (M % 32 == 0) {
       for (int atom_n : ws_valid_atom_ns)
         if (N % atom_n == 0)
-          SUCCESS(32, atom_n, 16);
+          SUCCESS(32, atom_n, 16, false, false);
       FAIL;
     } else {
       FAIL;
     }
-  } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
-             (c_dtype.is_float() && c_dtype.bits() == 32)) {
+  } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e4m3() ||
+              ab_dtype.is_float8_e5m2() || ab_dtype.is_float8_e5m2fnuz() ||
+              ab_dtype.is_float6_e2m3fn() || ab_dtype.is_float6_e3m2fn() ||
+              ab_dtype.is_float4_e2m1fn()) &&
+             ((c_dtype.is_float() && c_dtype.bits() == 32) ||
+              (c_dtype.is_float16() && c_dtype.bits() == 16))) {
     if (K % 32 != 0)
       FAIL;
     if (M % 128 == 0) {
+      for (int atom_n : ws_valid_atom_ns)
+        if (N % atom_n == 0)
+          SUCCESS(128, atom_n, 32, true, false);
       for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
         if (N % atom_n == 0)
-          SUCCESS(128, atom_n, 32);
+          SUCCESS(128, atom_n, 32, false, true);
+      for (int atom_n = 256; atom_n >= 8; atom_n -= 8)
+        if (N % atom_n == 0)
+          SUCCESS(128, atom_n, 32, false, false);
       FAIL;
     } else if (M % 64 == 0) {
      for (int atom_n : ws_valid_atom_ns)
        if (N % atom_n == 0)
-          SUCCESS(64, atom_n, 32);
+          SUCCESS(64, atom_n, 32, true, false);
+      for (int atom_n = 256; atom_n >= 8; atom_n -= 8)
+        if (N % atom_n == 0)
+          SUCCESS(128, atom_n, 32, false, false);
       FAIL;
     } else if (M % 32 == 0) {
       for (int atom_n : ws_valid_atom_ns)
         if (N % atom_n == 0)
-          SUCCESS(32, atom_n, 32);
+          SUCCESS(32, atom_n, 32, true, false);
       FAIL;
     } else {
       FAIL;
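The interesting part of this hunk is the new search order for the narrow dtypes: warp-specialized atoms (atom_n in {256, 128, 64}) are tried first, then 2-CTA atoms in steps of 16, then plain single-CTA atoms in steps of 8. A standalone sketch of that priority for the fp8, M % 128 == 0 branch (the TVM DataType plumbing is stripped out; the struct and function names here are illustrative, not part of the header):

#include <cstdio>
#include <vector>

// Illustrative reduction of GetTCGEN5MMAMeta's fp8 / M % 128 == 0 branch:
// returns {atom_n, enable_ws, enable_2cta}, or atom_n = 0 when no atom fits.
// Mirrors the search order in src/op/tcgen5_meta.h above.
struct AtomChoice {
  int atom_n;
  bool enable_ws, enable_2cta;
};

AtomChoice pick_atom_n_fp8_m128(int N) {
  // 1) warp-specialized atoms
  for (int atom_n : std::vector<int>{256, 128, 64})
    if (N % atom_n == 0)
      return {atom_n, true, false};
  // 2) 2-CTA atoms, multiples of 16
  for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
    if (N % atom_n == 0)
      return {atom_n, false, true};
  // 3) single-CTA atoms, multiples of 8
  for (int atom_n = 256; atom_n >= 8; atom_n -= 8)
    if (N % atom_n == 0)
      return {atom_n, false, false};
  return {0, false, false}; // corresponds to the FAIL path
}

int main() {
  const int tests[] = {4096, 192, 136, 100};
  for (int N : tests) {
    AtomChoice c = pick_atom_n_fp8_m128(N);
    std::printf("N=%d -> atom_n=%d ws=%d 2cta=%d\n", N, c.atom_n,
                (int)c.enable_ws, (int)c.enable_2cta);
  }
}

Running it shows, for example, that N = 192 picks the 64-wide warp-specialized atom, while N = 100 hits the FAIL path because no multiple of 8 up to 256 divides it.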

src/tl_templates/cuda/copy_sm100.h

Lines changed: 25 additions & 10 deletions
@@ -51,6 +51,21 @@ __device__ __forceinline__ void st_global_256(fp8_e4_32_t *ptr,
                :
                : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
 }
+__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e5_32_t *ptr) {
+  ulonglong4 ret;
+  asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
+               : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
+               : "l"(ptr));
+  return ret;
+}
+
+__device__ __forceinline__ void st_global_256(fp8_e5_32_t *ptr,
+                                              fp8_e5_32_t &val8) {
+  ulonglong4 &val = *((ulonglong4 *)&val8);
+  asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
+               :
+               : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
+}
 
 __device__ __forceinline__ unsigned long long
 pack_bfloat16x4(const bfloat16_t x, const bfloat16_t y, const bfloat16_t z,
@@ -95,38 +110,38 @@ __device__ __forceinline__ void tcgen05_ld_core(uint32_t const &tmem_start_col,
   }
 }
 
-template <int N, typename dst_t>
+template <int N, bool pack16, typename dst_t>
 __device__ __forceinline__ void
 tcgen05_ld_32dp32bNx(uint32_t const &tmem_start_col,
                      uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
-  tcgen05_ld_core<tl::tmem_ld_32dp32bNx, 7, N>(tmem_start_col + tmem_col_offset,
-                                               dst_ptr);
+  tcgen05_ld_core<tl::tmem_ld_32dp32bNx<pack16>, 7, N>(
+      tmem_start_col + tmem_col_offset, dst_ptr);
   tl::fence_view_async_tmem_load();
 }
 
-template <int N, typename dst_t>
+template <int N, bool pack16, typename dst_t>
 __device__ __forceinline__ void
 tcgen05_ld_32dp64bNx(uint32_t const &tmem_start_col,
                      uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
-  tcgen05_ld_core<tl::tmem_ld_32dp64bNx, 7, N>(tmem_start_col + tmem_col_offset,
-                                               dst_ptr);
+  tcgen05_ld_core<tl::tmem_ld_32dp64bNx<pack16>, 7, N>(
+      tmem_start_col + tmem_col_offset, dst_ptr);
   tl::fence_view_async_tmem_load();
 }
 
-template <int N, typename dst_t>
+template <int N, bool pack16, typename dst_t>
 __device__ __forceinline__ void
 tcgen05_ld_32dp128bNx(uint32_t const &tmem_start_col,
                       uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
-  tcgen05_ld_core<tl::tmem_ld_32dp128bNx, 6, N>(
+  tcgen05_ld_core<tl::tmem_ld_32dp128bNx<pack16>, 6, N>(
       tmem_start_col + tmem_col_offset, dst_ptr);
   tl::fence_view_async_tmem_load();
 }
 
-template <int N, typename dst_t>
+template <int N, bool pack16, typename dst_t>
 __device__ __forceinline__ void
 tcgen05_ld_32dp256bNx(uint32_t const &tmem_start_col,
                       uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
-  tcgen05_ld_core<tl::tmem_ld_32dp256bNx, 5, N>(
+  tcgen05_ld_core<tl::tmem_ld_32dp256bNx<pack16>, 5, N>(
      tmem_start_col + tmem_col_offset, dst_ptr);
   tl::fence_view_async_tmem_load();
 }
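With this change, call sites pass the pack flag as the second template argument after the chunk count N that copy.cc already emits, and the destination element type is still deduced from the pointer. A hedged sketch of what a generated call might look like (the function names wrapping the calls, the half type, namespace qualification, and offsets are assumptions for illustration; the real call strings come from LowerTmemCopy above):

#include <cstdint>
#include <cuda_fp16.h>

// Illustrative device-side fragments only; assumes copy_sm100.h is included by
// the generated kernel and that the tmem column bookkeeping is set up elsewhere.
__device__ void load_accum_fp16(uint32_t tmem_start_col, __half *frag) {
  // 16-bit accumulator in tensor memory: two elements per 32-bit tmem word,
  // so the lowering selects pack16 = true (the .pack::16b form of tcgen05.ld).
  tcgen05_ld_32dp32bNx<4, /*pack16=*/true>(tmem_start_col,
                                           /*tmem_col_offset=*/0u, frag);
}

__device__ void load_accum_fp32(uint32_t tmem_start_col, float *frag) {
  // 32-bit accumulator: one element per tmem word, no packing needed.
  tcgen05_ld_32dp32bNx<4, /*pack16=*/false>(tmem_start_col,
                                            /*tmem_col_offset=*/0u, frag);
}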
