@@ -31,7 +31,8 @@ def before():
3131 C_local = T .decl_buffer ((32 ,), scope = "local" )
3232 for i in T .unroll (16 ):
3333 C_local [i * 2 :i * 2 + 2 ] = T .Broadcast (T .float32 (0 ), 2 )
34- T .call_extern ("handle" , "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>" ,
34+ T .call_intrin ("handle" , tir .op .Op .get ("tl.tl_gemm" ),
35+ "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>" ,
3536 T .tvm_access_ptr (T .type_annotation ("float16" ), A_shared .data , 0 , 2048 , 1 ),
3637 T .tvm_access_ptr (T .type_annotation ("float16" ), B_shared .data , 0 , 2048 , 1 ),
3738 T .tvm_access_ptr (T .type_annotation ("float32" ), C_local .data , 0 , 32 , 3 ))
@@ -45,7 +46,8 @@ def after():
4546 for i in T .unroll (16 ):
4647 C_local [i * 2 :i * 2 + 2 ] = T .Broadcast (T .float32 (0 ), 2 )
4748 T .fence_proxy_async ()
48- T .call_extern ("handle" , "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>" ,
49+ T .call_intrin ("handle" , tir .op .Op .get ("tl.tl_gemm" ),
50+ "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>" ,
4951 T .tvm_access_ptr (T .type_annotation ("float16" ), A_shared .data , 0 , 2048 , 1 ),
5052 T .tvm_access_ptr (T .type_annotation ("float16" ), B_shared .data , 0 , 2048 , 1 ),
5153 T .tvm_access_ptr (T .type_annotation ("float32" ), C_local .data , 0 , 32 , 3 ))
@@ -169,7 +171,7 @@ def before():
169171 mod = tvm .IRModule .from_expr (before .with_attr ("global_symbol" , "main" ))
170172 mod = tvm .tir .transform .BindTarget (auto_target )(mod )
171173 mod = tl .transform .InjectFenceProxy ()(mod )
172-
174+ print ( mod )
173175 order = []
174176
175177 def visit (node ):
@@ -185,43 +187,5 @@ def visit(node):
185187 assert order .index ("tl.fence_proxy_async" ) < order .index ("tl.ptx_wgmma_ss" )
186188
187189
188- def test_wgmma_after_descriptor ():
189-
190- @T .prim_func
191- def before ():
192- with T .Kernel (1 ):
193- desc_a = T .decl_buffer ((1 ,), "uint64" , scope = "local.descriptor" )
194- desc_b = T .decl_buffer ((1 ,), "uint64" , scope = "local.descriptor" )
195- C_local = T .decl_buffer ((32 ,), "float16" , scope = "local" )
196- T .initialize_descriptor (desc_a , T .uint64 (0 ), 2 , 1 , 32 )
197- T .initialize_descriptor (desc_b , T .uint64 (0 ), 2 , 1 , 32 )
198- T .warpgroup_arrive ()
199- T .ptx_wgmma_ss ("float16" , "m64n64k16" , T .bool (True ), T .bool (True ), "fp16" , "fp16" ,
200- "fp16" , desc_a .data , T .int32 (0 ), desc_b .data , T .int32 (0 ), C_local .data ,
201- T .int32 (0 ), T .bool (True ), 1 , 1 )
202-
203- mod = tvm .IRModule .from_expr (before .with_attr ("global_symbol" , "main" ))
204- mod = tvm .tir .transform .BindTarget (auto_target )(mod )
205- mod = tl .transform .InjectFenceProxy ()(mod )
206-
207- fence_count = 0
208- order = []
209-
210- def visit (node ):
211- nonlocal fence_count
212- if isinstance (node , tir .Evaluate ):
213- call = node .value
214- if isinstance (call , tir .Call ):
215- name = getattr (call .op , "name" , "" )
216- order .append (name )
217- if name == "tl.fence_proxy_async" :
218- fence_count += 1
219-
220- tir .stmt_functor .post_order_visit (mod ["main" ].body , visit )
221- assert fence_count >= 1
222- assert "tl.warpgroup_arrive" in order
223- assert order .index ("tl.fence_proxy_async" ) < order .index ("tl.warpgroup_arrive" )
224-
225-
226190if __name__ == "__main__" :
227191 tilelang .testing .main ()
0 commit comments