@@ -46,12 +46,13 @@ def main(
                 else:
                     T.copy(B[k * block_K, bx * block_N], B_shared)
                 T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B)
+                # T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
                 T.copy(C_local, C[by * block_M, bx * block_N])

         return main


-def run_gemm(
+def run_gemm_ss(
     M,
     N,
     K,
@@ -106,13 +107,13 @@ def ref_program(A, B):
     profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


-def test_gemm():
+def test_gemm_ss():
     # More test cases can be found in kernel/test_tilelang_kernel_gemm.py
     # GEMM tests for float16
-    run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
+    run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
+    run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)


@@ -165,7 +166,6 @@ def main(
                     T.copy(B[k * block_K, bx * block_N], B_shared)
                 T.copy(A_shared, A_frag)
                 T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B)
-                # T.gemm(A_frag, B_shared, C_local, trans_A, trans_B)
                 T.copy(C_local, C[by * block_M, bx * block_N])

         return main
@@ -228,8 +228,8 @@ def ref_program(A, B):


 def test_gemm_rs():
     # GEMM tests for float16
-    run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
-    run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 0)
+    run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 0)


 def matmul_sr(
@@ -280,7 +280,10 @@ def main(
                 else:
                     T.copy(B[k * block_K, bx * block_N], B_shared)
                 T.copy(B_shared, B_frag)
-                T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B)
+                # for i, j in T.Parallel(block_N, block_K):
+                #     B_frag[i, j] = B_shared[j, i]
+                # T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B)
+                T.gemm(A_shared, B_frag, C_local, trans_A, trans_B)
                 T.copy(C_local, C[by * block_M, bx * block_N])

         return main
@@ -345,21 +348,160 @@ def test_gemm_sr():
     # GEMM tests for float16
     run_gemm_sr(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
     run_gemm_sr(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_sr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_sr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
+
+
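+# matmul_rr: register-register variant. Both the A and B tiles are staged from
+# shared memory into register fragments before the GEMM, unlike the ss/rs/sr
+# variants above, which feed at least one operand directly from shared memory.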
+def matmul_rr(
+    M,
+    N,
+    K,
+    block_M,
+    block_N,
+    block_K,
+    trans_A,
+    trans_B,
+    in_dtype,
+    out_dtype,
+    accum_dtype,
+    num_stages,
+    threads,
+):
+    A_shape = (K, M) if trans_A else (M, K)
+    B_shape = (N, K) if trans_B else (K, N)
+    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
+    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
+    A_frag_shape = A_shared_shape
+    B_frag_shape = B_shared_shape
+
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+            A: T.Tensor(A_shape, in_dtype),
+            B: T.Tensor(B_shape, in_dtype),
+            C: T.Tensor((M, N), out_dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared")
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared")
+            A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
+            B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            T.clear(C_local)
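+            # Swizzle the shared-memory layouts to avoid bank conflicts on the tile copies.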
+            T.annotate_layout({
+                A_shared: tilelang.layout.make_swizzled_layout(A_shared),
+                B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+            })
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                if trans_A:
+                    T.copy(A[k * block_K, by * block_M], A_shared)
+                else:
+                    T.copy(A[by * block_M, k * block_K], A_shared)
+                if trans_B:
+                    T.copy(B[bx * block_N, k * block_K], B_shared)
+                else:
+                    T.copy(B[k * block_K, bx * block_N], B_shared)
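+                # Stage both operands into register fragments (the "rr" path).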
+                T.copy(A_shared, A_frag)
+                T.copy(B_shared, B_frag)
+                T.gemm_v2(A_frag, B_frag, C_local, trans_A, trans_B)
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return main
+
+
+def run_gemm_rr(
+    M,
+    N,
+    K,
+    trans_A,
+    trans_B,
+    in_dtype,
+    out_dtype,
+    dtypeAccum,
+    block_M,
+    block_N,
+    block_K,
+    num_stages=3,
+    num_threads=128,
+):
+    program = matmul_rr(
+        M,
+        N,
+        K,
+        block_M,
+        block_N,
+        block_K,
+        trans_A,
+        trans_B,
+        in_dtype,
+        out_dtype,
+        dtypeAccum,
+        num_stages,
+        num_threads,
+    )
+
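+    # Disable TMA lowering and warp specialization for this test.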
+    kernel = tilelang.compile(
+        program,
+        out_idx=[2],
+        pass_configs={
+            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        })
+    print(kernel.get_kernel_source())
+    profiler = kernel.get_profiler()
+
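+    # Reference: torch.matmul in float32, then cast back to out_dtype.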
+    def ref_program(A, B):
+        import torch
+
+        if trans_A:
+            A = A.T
+        if trans_B:
+            B = B.T
+        C = torch.matmul(A.to(torch.float), B.to(torch.float))
+        C = C.to(torch.__getattribute__(out_dtype))
+        return C
+
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+
+
+def test_gemm_rr():
+    # GEMM tests for float16
+    run_gemm_rr(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rr(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
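+    # bfloat16 inputs with float32 accumulation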
+    run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2)


 if __name__ == "__main__":
     # tilelang.testing.main()
     tilelang.disable_cache()
-    tilelang.testing.set_random_seed(42)
-    # run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
+    # test_gemm_ss()
+    run_gemm_sr(128, 128, 128, False, False, "float16", "float16", "float16", 128, 128, 32, 2)
+    # tilelang.testing.set_random_seed(42)
+    # run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float16", 128, 128, 32, 1)
     # print("gemm fp16 nt ss done")
-    # run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 nn ss done")
-    # run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 tn ss done")
-    # run_gemm(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 tt ss done")
-    # run_gemm_rs(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128)
+    # exit()
+
+    # run_gemm_rs(128, 128, 32, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
     # print("gemm fp16 nt rs done")
-    run_gemm_rs(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128)
-    # run_gemm(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128)
+    # run_gemm_rs(128, 128, 32, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    # print("gemm fp16 nn rs done")
+    # run_gemm_rs(128, 128, 32, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    # print("gemm fp16 tn rs done")
+    # run_gemm_rs(128, 128, 32, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
+    # print("gemm fp16 tt rs done")
+
+    # run_gemm_rs(16, 16, 16, True, False, "float16", "float16", "float16", 16, 16, 16, 0, 32)
+
+    # run_gemm_rr(128, 128, 32, False, False, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
+    # print("gemm bf16 nn rr done")
+    # run_gemm_rr(128, 128, 32, False, True, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
+    # print("gemm bf16 nt rr done")
+    # run_gemm_rr(128, 128, 32, True, False, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
+    # print("gemm bf16 tn rr done")
+    # run_gemm_rr(128, 128, 32, True, True, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
+    # print("gemm bf16 tt rr done")
+
507+