
Commit 1ab46ef

Refactor GEMM layout and testing for improved clarity and functionality
- Updated `gemm_layouts.cc` to enhance the layout generation logic for transposed and non-transposed GEMM operations.
- Renamed and modified functions in `test_tilelang_tilelibrary_gemm.py` to reflect changes in GEMM function signatures and improve test coverage.
- Introduced new GEMM operation combinations in `gemm/__init__.py` to support additional layouts and configurations.
- Enhanced layout inference in `mma_layout.py` and `mma_macro_generator.py` for better handling of shared memory layouts.

These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework.
1 parent e299b41 commit 1ab46ef

File tree

9 files changed (+326, -66 lines)


3rdparty/tvm

Submodule tvm updated from 1fc7578 to eddefbd

src/layout/gemm_layouts.cc

Lines changed: 2 additions & 4 deletions
@@ -205,16 +205,14 @@ Fragment makeGemmFragmentB(const int block_m, const int block_n,
   ICHECK(block_k % 16 == 0);
   if (transposed) {
     auto base_layout = makeGemmFragment8x8()->Repeat({1, 2}, false, false);
-    auto warp_layout = base_layout->Replicate(block_m / warp_m)
-                           ->Repeat({block_n / warp_n, 1}, true, false);
+    auto warp_layout = base_layout->Repeat({block_n / warp_n, 1}, true, false)->Replicate(block_m / warp_m);
     auto block_layout =
         warp_layout->Repeat({warp_n / 8, block_k / 16}, false, false);
     return block_layout;
   } else {
     auto base_layout =
         makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false);
-    auto warp_layout = base_layout->Replicate(block_m / warp_m)
-                           ->Repeat({1, block_n / warp_n}, true);
+    auto warp_layout = base_layout->Repeat({1, block_n / warp_n}, true)->Replicate(block_m / warp_m);
     auto block_layout =
         warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true);
     return block_layout;

src/op/gemm_py.cc

Lines changed: 12 additions & 2 deletions
@@ -12,6 +12,7 @@
 #include <tvm/tir/transform.h>
 
 #include "../target/utils.h"
+#include "tvm/ffi/string.h"
 
 namespace tvm {
 namespace tl {
@@ -224,9 +225,18 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
       M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);
 
   if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
-    auto stmt = Downcast<Stmt>(
+    auto prim_func = Downcast<PrimFunc>(
         (*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds, T.thread_var));
-    return stmt;
+    BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
+    ICHECK(prim_func->attrs.defined());
+    auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
+    ICHECK(global_symbol.defined());
+    auto block = block_realize->block;
+    {
+      BlockNode* n = block.CopyOnWrite();
+      n->name_hint = global_symbol.value();
+    }
+    return BlockRealize(block_realize->iter_values, block_realize->predicate, block);
   } else {
     LOG(FATAL) << "No lower function found for gemm_py";
   }
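Not part of the commit, but for context: the new branch above downcasts the result of `tl.gemm_py.lower` to a `PrimFunc`, reads its `global_symbol` attribute, and copies that name onto the enclosing block. A minimal Python sketch of the same attribute lookup (the `lowered_gemm` function and its body are hypothetical stand-ins, not the actual lowering output):

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def lowered_gemm(a: T.handle) -> None:
    # hypothetical stand-in for the PrimFunc returned by the Python lowering hook
    T.func_attr({"global_symbol": "gemm_py_kernel"})
    A = T.match_buffer(a, (8,), "float32")
    for i in range(8):
        with T.block("compute"):
            vi = T.axis.spatial(8, i)
            A[vi] = T.float32(0)


# The attribute the C++ code copies into the block's name_hint:
print(lowered_gemm.attrs["global_symbol"])  # -> "gemm_py_kernel"
```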

src/transform/inject_pipeline.cc

Lines changed: 6 additions & 0 deletions
@@ -675,6 +675,12 @@ class PipelineRewriter : public StmtExprMutator {
     }
     new_block = Downcast<Block>(Substitute(
         new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
+
+    Array<Array<BufferRegion>> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
+    BlockNode* n = new_block.CopyOnWrite();
+    n->reads = access[0];
+    n->writes = access[1];
+
     if (pipeline_info_[block].async) {
       auto &local_state = async_states_local[stage];
       local_state.producer_head = normalized_access_index;
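Not part of the commit: the added lines recompute the block's `reads`/`writes` regions after the loop-variable substitution. TVM exposes a similar analysis on the Python side as `tvm.tir.analysis.get_block_read_write_region`; a small illustrative sketch (the `copy_func` example is hypothetical):

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def copy_func(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16,), "float32")
    B = T.match_buffer(b, (16,), "float32")
    for i in range(16):
        with T.block("copy"):
            vi = T.axis.spatial(16, i)
            B[vi] = A[vi]


sch = tvm.tir.Schedule(copy_func)
block = sch.get_sref(sch.get_block("copy")).stmt

# Map each buffer's data var to the buffer itself, mirroring buffer_data_to_buffer_.
buffer_var_map = {
    buf.data: buf for buf in (copy_func.buffer_map[p] for p in copy_func.params)
}
reads, writes = tvm.tir.analysis.get_block_read_write_region(block, buffer_var_map)
print(reads)   # region read by the block (A[vi])
print(writes)  # region written by the block (B[vi])
```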

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py

Lines changed: 163 additions & 21 deletions
@@ -46,12 +46,13 @@ def main(
                 else:
                     T.copy(B[k * block_K, bx * block_N], B_shared)
                 T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B)
+                # T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
             T.copy(C_local, C[by * block_M, bx * block_N])
 
     return main
 
 
-def run_gemm(
+def run_gemm_ss(
     M,
     N,
     K,
@@ -106,13 +107,13 @@ def ref_program(A, B):
     profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
 
 
-def test_gemm():
+def test_gemm_ss():
     # More test case can be found in kernel/test_tilelang_kernel_gemm.py
     # GEMM tests for float16
-    run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
+    run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
+    run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
 
 
 
@@ -165,7 +166,6 @@ def main(
                     T.copy(B[k * block_K, bx * block_N], B_shared)
                 T.copy(A_shared, A_frag)
                 T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B)
-                # T.gemm(A_frag, B_shared, C_local, trans_A, trans_B)
             T.copy(C_local, C[by * block_M, bx * block_N])
 
     return main
@@ -228,8 +228,8 @@ def ref_program(A, B):
 
 def test_gemm_rs():
     # GEMM tests for float16
-    run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
-    run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 0)
+    run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 0)
 
 
 def matmul_sr(
@@ -280,7 +280,10 @@ def main(
                 else:
                     T.copy(B[k * block_K, bx * block_N], B_shared)
                 T.copy(B_shared, B_frag)
-                T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B)
+                # for i, j in T.Parallel(block_N, block_K):
+                #     B_frag[i, j] = B_shared[j, i]
+                # T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B)
+                T.gemm(A_shared, B_frag, C_local, trans_A, trans_B)
             T.copy(C_local, C[by * block_M, bx * block_N])
 
     return main
@@ -345,21 +348,160 @@ def test_gemm_sr():
     # GEMM tests for float16
     run_gemm_sr(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
     run_gemm_sr(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_sr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_sr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
+
+
+def matmul_rr(
+    M,
+    N,
+    K,
+    block_M,
+    block_N,
+    block_K,
+    trans_A,
+    trans_B,
+    in_dtype,
+    out_dtype,
+    accum_dtype,
+    num_stages,
+    threads,
+):
+    A_shape = (K, M) if trans_A else (M, K)
+    B_shape = (N, K) if trans_B else (K, N)
+    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
+    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
+    A_frag_shape = A_shared_shape
+    B_frag_shape = B_shared_shape
+
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+            A: T.Tensor(A_shape, in_dtype),
+            B: T.Tensor(B_shape, in_dtype),
+            C: T.Tensor((M, N), out_dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared")
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared")
+            A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
+            B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            T.clear(C_local)
+            T.annotate_layout({
+                A_shared: tilelang.layout.make_swizzled_layout(A_shared),
+                B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+            })
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                if trans_A:
+                    T.copy(A[k * block_K, by * block_M], A_shared)
+                else:
+                    T.copy(A[by * block_M, k * block_K], A_shared)
+                if trans_B:
+                    T.copy(B[bx * block_N, k * block_K], B_shared)
+                else:
+                    T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.copy(A_shared, A_frag)
+                T.copy(B_shared, B_frag)
+                T.gemm_v2(A_frag, B_frag, C_local, trans_A, trans_B)
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return main
+
+
+def run_gemm_rr(
+    M,
+    N,
+    K,
+    trans_A,
+    trans_B,
+    in_dtype,
+    out_dtype,
+    dtypeAccum,
+    block_M,
+    block_N,
+    block_K,
+    num_stages=3,
+    num_threads=128,
+):
+    program = matmul_rr(
+        M,
+        N,
+        K,
+        block_M,
+        block_N,
+        block_K,
+        trans_A,
+        trans_B,
+        in_dtype,
+        out_dtype,
+        dtypeAccum,
+        num_stages,
+        num_threads,
+    )
+
+    kernel = tilelang.compile(
+        program,
+        out_idx=[2],
+        pass_configs={
+            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        })
+    print(kernel.get_kernel_source())
+    profiler = kernel.get_profiler()
+
+    def ref_program(A, B):
+        import torch
+
+        if trans_A:
+            A = A.T
+        if trans_B:
+            B = B.T
+        C = torch.matmul(A.to(torch.float), B.to(torch.float))
+        C = C.to(torch.__getattribute__(out_dtype))
+        return C
+
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+
+
+def test_gemm_rr():
+    # GEMM tests for float16
+    run_gemm_rr(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rr(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2)
 
 
 if __name__ == "__main__":
     # tilelang.testing.main()
     tilelang.disable_cache()
-    tilelang.testing.set_random_seed(42)
-    # run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
+    # test_gemm_ss()
+    run_gemm_sr(128, 128, 128, False, False, "float16", "float16", "float16", 128, 128, 32, 2)
+    # tilelang.testing.set_random_seed(42)
+    # run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float16", 128, 128, 32, 1)
     # print("gemm fp16 nt ss done")
-    # run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 nn ss done")
-    # run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 tn ss done")
-    # run_gemm(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 tt ss done")
-    # run_gemm_rs(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128)
+    # exit()
+
+    # run_gemm_rs(128, 128, 32, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
     # print("gemm fp16 nt rs done")
-    run_gemm_rs(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128)
-    # run_gemm(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128)
+    # run_gemm_rs(128, 128, 32, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    # print("gemm fp16 nn rs done")
+    # run_gemm_rs(128, 128, 32, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    # print("gemm fp16 tn rs done")
+    # run_gemm_rs(128, 128, 32, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
+    # print("gemm fp16 tt rs done")
+
+    # run_gemm_rs(16, 16, 16, True, False, "float16", "float16", "float16", 16, 16, 16, 0, 32)
+
+    # run_gemm_rr(128, 128, 32, False, False, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
+    # print("gemm bf16 nn rr done")
+    # run_gemm_rr(128, 128, 32, False, True, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
+    # print("gemm bf16 nt rr done")
+    # run_gemm_rr(128, 128, 32, True, False, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
+    # print("gemm bf16 tn rr done")
+    # run_gemm_rr(128, 128, 32, True, True, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
+    # print("gemm bf16 tt rr done")

tilelang/engine/phase.py

Lines changed: 4 additions & 0 deletions
@@ -140,7 +140,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.IfStmtBinding()(mod)
     mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
    mod = tilelang.transform.PipelinePlanning()(mod)
+    print("after pipeline planning")
+    print(mod)
     mod = tilelang.transform.InjectSoftwarePipeline()(mod)
+    print("after inject software pipeline")
+    print(mod)
     mod = tilelang.transform.MergeIfStmt()(mod)
     if allow_fence_proxy(target=target):
         # in hopper device, wgmma is an async proxy

tilelang/intrinsics/mma_layout.py

Lines changed: 14 additions & 6 deletions
@@ -47,18 +47,26 @@ def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id):
 
 # sr represents spatial + reduction layout
 # the first axis is spatial while the second axis is reduction
-def shared_16x16_to_mma_32x8_layout_sr(i, j):
+# mma.sync matrix A layout, if wanna trans, please apply map_indices
+def shared_16x16_to_mma_a_32x8_layout(i, j):
     thread_id = 4 * (i % 8) + (j % 8) // 2
     return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
 
+def shared_16x16_to_mma_a_32x8_layout_trans(i, j):
+    return shared_16x16_to_mma_a_32x8_layout(j, i)
 
-def shared_16x16_to_mma_32x8_layout_rs(i, j):
-    thread_id = 4 * (j % 8) + (i % 8) // 2
-    return thread_id, 4 * (i // 8) + (j // 8) * 2 + (i % 2)
+# mma.sync matrix B layout, if wanna trans, please apply map_indices
+def shared_16x16_to_mma_b_32x8_layout(i, j):
+    thread_id = 4 * (i % 8) + (j % 8) // 2
+    return thread_id, 4 * (i // 8) + (j // 8) * 2 + (j % 2)
 
+def shared_16x16_to_mma_b_32x8_layout_trans(i, j):
+    return shared_16x16_to_mma_b_32x8_layout(j, i)
 
-shared_16x16_to_mma_32x8_layout = shared_16x16_to_mma_32x8_layout_sr
-shared_16x16_to_mma_32x8_layout_trans = shared_16x16_to_mma_32x8_layout_rs
+shared_16x16_to_mma_32x8_layout_sr_a = shared_16x16_to_mma_a_32x8_layout
+shared_16x16_to_mma_32x8_layout_sr_b = shared_16x16_to_mma_b_32x8_layout
+shared_16x16_to_mma_32x8_layout_rs_a = shared_16x16_to_mma_a_32x8_layout_trans
+shared_16x16_to_mma_32x8_layout_rs_b = shared_16x16_to_mma_b_32x8_layout_trans
 
 
 def shared_16x32_to_mma_32x16_layout(i, j):
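Not part of the commit: a quick self-contained check of the renamed index maps above. For a 16x16 tile, each of the A and B mappings should hit every (thread_id, local_id) pair of the 32-thread x 8-register MMA fragment exactly once; the function bodies below are copied verbatim from the diff, only the assertion loop is added for illustration.

```python
def shared_16x16_to_mma_a_32x8_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 8) // 2
    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)


def shared_16x16_to_mma_b_32x8_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 8) // 2
    return thread_id, 4 * (i // 8) + (j // 8) * 2 + (j % 2)


for name, fn in [("A", shared_16x16_to_mma_a_32x8_layout),
                 ("B", shared_16x16_to_mma_b_32x8_layout)]:
    seen = {fn(i, j) for i in range(16) for j in range(16)}
    assert seen == {(t, r) for t in range(32) for r in range(8)}
    print(f"matrix {name}: 16x16 tile maps onto 32 threads x 8 registers exactly once")
```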
