@@ -1,3 +1,4 @@
+from asyncio import threads
 from tilelang import tvm as tvm
 import tilelang.testing
 
@@ -31,8 +32,8 @@ def main(
             C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
-            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
-            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared")
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared")
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             T.clear(C_local)
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
@@ -108,8 +109,11 @@ def ref_program(A, B):
 def test_gemm():
     # More test case can be found in kernel/test_tilelang_kernel_gemm.py
     # GEMM tests for float16
-    run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32,
-             2)  # f16f16f16_nn
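+    # Cover all four (trans_A, trans_B) layout combinations with num_stages=0 (no pipelining)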
+    run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
+    run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    run_gemm(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
+
 
 
 def matmul_rs(
@@ -142,23 +146,26 @@ def main(
             C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
-            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
-            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared")
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared")
             A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             T.clear(C_local)
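+            # Request a swizzled layout for A_shared (assumed here to avoid shared-memory
+            # bank conflicts on the A_shared -> A_frag copy below)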
+            T.annotate_layout({
+                A_shared: tilelang.layout.make_swizzled_layout(A_shared),
+            })
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 if trans_A:
                     T.copy(A[k * block_K, by * block_M], A_shared)
-                    T.copy(A_shared, A_frag)
                 else:
                     T.copy(A[by * block_M, k * block_K], A_shared)
-                    T.copy(A_shared, A_frag)
                 if trans_B:
                     T.copy(B[bx * block_N, k * block_K], B_shared)
                 else:
                     T.copy(B[k * block_K, bx * block_N], B_shared)
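+                # Copy the A tile from shared memory into registers once per K step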
+                T.copy(A_shared, A_frag)
                 T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B)
+                # T.gemm(A_frag, B_shared, C_local, trans_A, trans_B)
             T.copy(C_local, C[by * block_M, bx * block_N])
 
     return main
@@ -202,6 +209,7 @@ def run_gemm_rs(
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
         })
+    print(kernel.get_kernel_source())
     profiler = kernel.get_profiler()
 
     def ref_program(A, B):
@@ -224,10 +232,134 @@ def test_gemm_rs():
     run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
 
 
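+# Shared-register ("sr") GEMM variant: the A operand is read by the MMA directly
+# from shared memory, while the B operand is first staged in a register fragment.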
+def matmul_sr(
+    M,
+    N,
+    K,
+    block_M,
+    block_N,
+    block_K,
+    trans_A,
+    trans_B,
+    in_dtype,
+    out_dtype,
+    accum_dtype,
+    num_stages,
+    threads,
+):
+    A_shape = (K, M) if trans_A else (M, K)
+    B_shape = (N, K) if trans_B else (K, N)
+    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
+    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
+    B_frag_shape = B_shared_shape
+
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+            A: T.Tensor(A_shape, in_dtype),
+            B: T.Tensor(B_shape, in_dtype),
+            C: T.Tensor((M, N), out_dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared")
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared")
+            B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            T.clear(C_local)
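+            # Request a swizzled layout for B_shared (assumed to keep the
+            # B_shared -> B_frag copy free of bank conflicts)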
+            T.annotate_layout({
+                B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+            })
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                if trans_A:
+                    T.copy(A[k * block_K, by * block_M], A_shared)
+                else:
+                    T.copy(A[by * block_M, k * block_K], A_shared)
+                if trans_B:
+                    T.copy(B[bx * block_N, k * block_K], B_shared)
+                else:
+                    T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.copy(B_shared, B_frag)
+                T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B)
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return main
+
+
+def run_gemm_sr(
+    M,
+    N,
+    K,
+    trans_A,
+    trans_B,
+    in_dtype,
+    out_dtype,
+    dtypeAccum,
+    block_M,
+    block_N,
+    block_K,
+    num_stages=3,
+    num_threads=128,
+):
+    program = matmul_sr(
+        M,
+        N,
+        K,
+        block_M,
+        block_N,
+        block_K,
+        trans_A,
+        trans_B,
+        in_dtype,
+        out_dtype,
+        dtypeAccum,
+        num_stages,
+        num_threads,
+    )
+
+    kernel = tilelang.compile(
+        program,
+        out_idx=[2],
+        pass_configs={
+            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        })
+    print(kernel.get_kernel_source())
+    profiler = kernel.get_profiler()
+
+    def ref_program(A, B):
+        import torch
+
+        if trans_A:
+            A = A.T
+        if trans_B:
+            B = B.T
+        C = torch.matmul(A.to(torch.float), B.to(torch.float))
+        C = C.to(torch.__getattribute__(out_dtype))
+        return C
+
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+
+
+def test_gemm_sr():
+    # GEMM tests for float16
+    run_gemm_sr(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_sr(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
+
+
 if __name__ == "__main__":
     # tilelang.testing.main()
     tilelang.disable_cache()
+    tilelang.testing.set_random_seed(42)
     # run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
     # print("gemm fp16 nt ss done")
-    run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    print("gemm fp16 nn ss done")
+    # run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    # print("gemm fp16 nn ss done")
+    # run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
+    # print("gemm fp16 tn ss done")
+    # run_gemm(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
+    # print("gemm fp16 tt ss done")
+    # run_gemm_rs(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128)
+    # print("gemm fp16 nt rs done")
+    run_gemm_rs(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128)
+    # run_gemm(64, 64, 32, False, True, "float16", "float16", "float16", 64, 64, 32, 0, 128)