Commit e299b41

Enhance CUDA code generation and testing for GEMM operations
- Added indentation printing in `codegen_cuda.cc` for improved assembly code formatting.
- Updated `test_tilelang_tilelibrary_gemm.py` to include additional GEMM test cases and shared memory allocation with an explicit scope.
- Introduced new `matmul_sr` and `run_gemm_sr` functions for GEMM operations with shared and fragment memory layouts.
- Refactored layout inference in `mma_macro_generator.py` to improve clarity and correctness in shared memory handling.
- Enhanced `gemm/__init__.py` to support new GEMM operation combinations and layout inference logic.

These changes improve the clarity, functionality, and test coverage of GEMM operations in the TileLang framework.
1 parent: 52800a5

4 files changed: +258 additions, -64 deletions

src/target/codegen_cuda.cc

Lines changed: 12 additions & 5 deletions
@@ -1259,7 +1259,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     std::string asm_code = PrintMMAAssembly(
         shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
         b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
-
+    this->PrintIndent();
     this->stream << asm_code;
   } else if (op->op.same_as(builtin::ptx_mma_sp())) {
     // arg 0: shape: mXnXkX
@@ -1295,6 +1295,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     std::string metadata_offset = this->PrintExpr(op->args[13]);
     std::string sparse_selector = this->PrintExpr(op->args[14]);
     bool saturate = Downcast<Bool>(op->args[15])->value;
+    this->PrintIndent();
     std::string asm_code = PrintMMAAssembly(
         shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset,
         b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset,
@@ -1330,10 +1331,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
       os << "}\n";
     } else {
       std::string smem_elem_offset = this->PrintExpr(op->args[6]);
-      need_cast_smem_ptr_to_int_ = true;
-      this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr,
-                                              local_elem_offset, smem_ptr,
-                                              smem_elem_offset);
+      // need_cast_smem_ptr_to_int_ = true;
+      // this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr,
+      //                                         local_elem_offset, smem_ptr,
+      //                                         smem_elem_offset);
+      std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
+      if (trans == 1)
+        func_name += "_trans";
+      // this->stream << func_name << "(" << local_ptr << ", " << smem_ptr << ");\n";
+      this->PrintIndent();
+      this->stream << func_name << "(" << smem_ptr << " + " << smem_elem_offset << ", " << local_ptr << " + " << local_elem_offset << ");\n";
     }
   } else if (op->op.same_as(builtin::mma_store())) {
     int m = Downcast<Integer>(op->args[0])->value;
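In the `ptx_ldmatrix` branch, the inline assembly previously produced by `PrintLoadMatrixAssembly` is replaced with a call to a `tl::ptx_ldmatrix_x{num}` helper (with a `_trans` suffix for transposed loads), taking the shared and local pointers plus their element offsets. A minimal Python sketch of that call-string construction, for illustration only — the helper names and argument order come from the diff above, while the `emit_ldmatrix_call` function itself is hypothetical:

    # Illustrative sketch: mirrors the C++ branch above that builds the
    # "tl::ptx_ldmatrix_x{num}[_trans](smem + off, local + off);" call string.
    def emit_ldmatrix_call(num: int, trans: bool, smem_ptr: str, smem_elem_offset: str,
                           local_ptr: str, local_elem_offset: str) -> str:
        func_name = f"tl::ptx_ldmatrix_x{num}"
        if trans:
            func_name += "_trans"
        return (f"{func_name}({smem_ptr} + {smem_elem_offset}, "
                f"{local_ptr} + {local_elem_offset});\n")

    # Example: a transposed x4 load from shared memory into registers.
    print(emit_ldmatrix_call(4, True, "A_shared", "offset", "A_local", "0"))
    # -> tl::ptx_ldmatrix_x4_trans(A_shared + offset, A_local + 0);

Running the example prints `tl::ptx_ldmatrix_x4_trans(A_shared + offset, A_local + 0);`, which matches the shape of the line the codegen now streams out.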

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py

Lines changed: 142 additions & 10 deletions
@@ -1,3 +1,4 @@
+from asyncio import threads
 from tilelang import tvm as tvm
 import tilelang.testing
 
@@ -31,8 +32,8 @@ def main(
             C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
-            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
-            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared")
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared")
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             T.clear(C_local)
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
@@ -108,8 +109,11 @@ def ref_program(A, B):
 def test_gemm():
     # More test case can be found in kernel/test_tilelang_kernel_gemm.py
     # GEMM tests for float16
-    run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32,
-             2)  # f16f16f16_nn
+    run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
+    run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    run_gemm(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
+
 
 
 def matmul_rs(
@@ -142,23 +146,26 @@ def main(
             C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
-            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
-            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared")
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared")
             A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             T.clear(C_local)
+            T.annotate_layout({
+                A_shared: tilelang.layout.make_swizzled_layout(A_shared),
+            })
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 if trans_A:
                     T.copy(A[k * block_K, by * block_M], A_shared)
-                    T.copy(A_shared, A_frag)
                 else:
                     T.copy(A[by * block_M, k * block_K], A_shared)
-                    T.copy(A_shared, A_frag)
                 if trans_B:
                     T.copy(B[bx * block_N, k * block_K], B_shared)
                 else:
                     T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.copy(A_shared, A_frag)
                 T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B)
+                # T.gemm(A_frag, B_shared, C_local, trans_A, trans_B)
             T.copy(C_local, C[by * block_M, bx * block_N])
 
     return main
@@ -202,6 +209,7 @@ def run_gemm_rs(
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
+    print(kernel.get_kernel_source())
    profiler = kernel.get_profiler()
 
    def ref_program(A, B):
@@ -224,10 +232,134 @@ def test_gemm_rs():
     run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
 
 
+def matmul_sr(
+    M,
+    N,
+    K,
+    block_M,
+    block_N,
+    block_K,
+    trans_A,
+    trans_B,
+    in_dtype,
+    out_dtype,
+    accum_dtype,
+    num_stages,
+    threads,
+):
+    A_shape = (K, M) if trans_A else (M, K)
+    B_shape = (N, K) if trans_B else (K, N)
+    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
+    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
+    B_frag_shape = B_shared_shape
+
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+            A: T.Tensor(A_shape, in_dtype),
+            B: T.Tensor(B_shape, in_dtype),
+            C: T.Tensor((M, N), out_dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared")
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared")
+            B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            T.clear(C_local)
+            T.annotate_layout({
+                B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+            })
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                if trans_A:
+                    T.copy(A[k * block_K, by * block_M], A_shared)
+                else:
+                    T.copy(A[by * block_M, k * block_K], A_shared)
+                if trans_B:
+                    T.copy(B[bx * block_N, k * block_K], B_shared)
+                else:
+                    T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.copy(B_shared, B_frag)
+                T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B)
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return main
+
+
+def run_gemm_sr(
+    M,
+    N,
+    K,
+    trans_A,
+    trans_B,
+    in_dtype,
+    out_dtype,
+    dtypeAccum,
+    block_M,
+    block_N,
+    block_K,
+    num_stages=3,
+    num_threads=128,
+):
+    program = matmul_sr(
+        M,
+        N,
+        K,
+        block_M,
+        block_N,
+        block_K,
+        trans_A,
+        trans_B,
+        in_dtype,
+        out_dtype,
+        dtypeAccum,
+        num_stages,
+        num_threads,
+    )
+
+    kernel = tilelang.compile(
+        program,
+        out_idx=[2],
+        pass_configs={
+            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        })
+    print(kernel.get_kernel_source())
+    profiler = kernel.get_profiler()
+
+    def ref_program(A, B):
+        import torch
+
+        if trans_A:
+            A = A.T
+        if trans_B:
+            B = B.T
+        C = torch.matmul(A.to(torch.float), B.to(torch.float))
+        C = C.to(torch.__getattribute__(out_dtype))
+        return C
+
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+
+
+def test_gemm_sr():
+    # GEMM tests for float16
+    run_gemm_sr(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_sr(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
+
+
 if __name__ == "__main__":
     # tilelang.testing.main()
     tilelang.disable_cache()
+    tilelang.testing.set_random_seed(42)
     # run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
     # print("gemm fp16 nt ss done")
-    run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    print("gemm fp16 nn ss done")
+    # run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    # print("gemm fp16 nn ss done")
+    # run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    # print("gemm fp16 tn ss done")
+    # run_gemm(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
+    # print("gemm fp16 tt ss done")
+    # run_gemm_rs(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128)
+    # print("gemm fp16 nt rs done")
+    run_gemm_rs(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128)
+    # run_gemm(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128)
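Taken together, the test file now covers three operand placements for the tile-level GEMM: the original `run_gemm` path keeps both A and B in shared memory, `run_gemm_rs` stages A into a register fragment, and the new `run_gemm_sr` stages B into a register fragment. A purely illustrative summary of that naming convention, inferred from the kernels above — the `OPERAND_PLACEMENT` table and `describe` helper are hypothetical names, not TileLang API:

    # Hypothetical helper, for illustration only: maps the test-name suffix to
    # the operand placement used by the corresponding kernel above.
    OPERAND_PLACEMENT = {
        "ss": ("A_shared", "B_shared"),   # matmul:    both operands read from shared memory
        "rs": ("A_frag", "B_shared"),     # matmul_rs: A staged into a register fragment
        "sr": ("A_shared", "B_frag"),     # matmul_sr: B staged into a register fragment
    }

    def describe(suffix: str) -> str:
        a, b = OPERAND_PLACEMENT[suffix]
        return f"gemm_{suffix}: A from {a}, B from {b}"

    if __name__ == "__main__":
        for s in ("ss", "rs", "sr"):
            print(describe(s))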

tilelang/intrinsics/mma_macro_generator.py

Lines changed: 8 additions & 11 deletions
@@ -189,19 +189,20 @@ def _warp_ldmatrix_a(
            stride = A_shared_buf.shape[-1]
            tx, _, warp_m = self.extract_thread_binding(thread_binding)
            trans = self.a_transposed
-
+
            for i in T.serial(warp_rows):
+                # Assign A_shared_buf_elem
+                wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
+                A_shared_buf_elem = A_shared_buf[wk, wi] if a_transposed else A_shared_buf[wi, wk]
+
                T.ptx_ldmatrix(
                    a_dtype,
                    T.bool(trans),
                    4,
                    ".b16",
                    A_local_buf.data,
                    i * local_size_a,
-                    T.address_of(A_shared_buf[
-                        warp_m * warp_row_tiles + i * micro_size_x,
-                        rk * chunk + ki * micro_size_k,
-                    ]),
+                    T.address_of(A_shared_buf_elem),
                    get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed),
                )
 
@@ -232,16 +233,15 @@
        ):
            stride = B_shared_buf.shape[-1]
            tx, warp_n, _ = self.extract_thread_binding(thread_binding)
-            trans = not self.b_transposed
+            trans = not b_transposed
 
            for j in T.serial(warp_cols):
                # Assign B_shared_elem
                wi, wk = (
                    warp_n * warp_col_tiles + j * micro_size_y,
                    rk * chunk + ki * micro_size_k,
                )
-                B_shared_buf_elem = B_shared_buf[wi, wk] if self.b_transposed else B_shared_buf[wk,
-                                                                                                wi]
+                B_shared_buf_elem = B_shared_buf[wi, wk] if b_transposed else B_shared_buf[wk, wi]
 
                T.ptx_ldmatrix(
                    b_dtype,
@@ -470,9 +470,6 @@ def forward_index(i: int, j: int) -> int:
        block_fragment = warp_fragment.repeat([warp_s, chunk // micro_size_r],
                                              repeat_on_thread=False,
                                              lower_dim_first=False)
-        print(f"base_fragment: {base_fragment}")
-        print(f"warp_fragment: {warp_fragment}")
-        print(f"block_fragment: {block_fragment}")
        return block_fragment
 
    def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment:
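The ldmatrix refactor above folds the transpose handling into a single element selection: when an operand is stored transposed (K-major), the row and k indices swap before `T.address_of` is taken, as in `A_shared_buf[wk, wi]` versus `A_shared_buf[wi, wk]`. A standalone NumPy sketch of that indexing rule, using made-up buffer shapes and indices purely for illustration:

    import numpy as np

    def shared_elem(shared_buf: np.ndarray, wi: int, wk: int, transposed: bool):
        """Pick the tile's anchor element: a transposed buffer is K-major,
        so the (wi, wk) roles swap, matching A_shared_buf[wk, wi] in the diff."""
        return shared_buf[wk, wi] if transposed else shared_buf[wi, wk]

    # Tiny check: an (M, K) buffer and its transposed (K, M) copy address the same value.
    M, K = 8, 4
    A = np.arange(M * K).reshape(M, K)
    A_t = A.T.copy()                      # K-major storage of the same matrix
    wi, wk = 5, 2                         # row tile index, k tile index (arbitrary)
    assert shared_elem(A, wi, wk, False) == shared_elem(A_t, wi, wk, True)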
