Skip to content

Commit c3858df

Browse files
committed
test fix
1 parent: aad2051 · commit: c3858df

File tree

1 file changed

+5
-41
lines changed

1 file changed

+5
-41
lines changed

testing/python/transform/test_tilelang_transform_inject_fence_proxy.py

Lines changed: 5 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,8 @@ def before():
3131
C_local = T.decl_buffer((32,), scope="local")
3232
for i in T.unroll(16):
3333
C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
34-
T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
34+
T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"),
35+
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
3536
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
3637
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
3738
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
@@ -45,7 +46,8 @@ def after():
4546
for i in T.unroll(16):
4647
C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
4748
T.fence_proxy_async()
48-
T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
49+
T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"),
50+
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
4951
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
5052
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
5153
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
@@ -169,7 +171,7 @@ def before():
169171
mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
170172
mod = tvm.tir.transform.BindTarget(auto_target)(mod)
171173
mod = tl.transform.InjectFenceProxy()(mod)
172-
174+
print(mod)
173175
order = []
174176

175177
def visit(node):
@@ -185,43 +187,5 @@ def visit(node):
185187
assert order.index("tl.fence_proxy_async") < order.index("tl.ptx_wgmma_ss")
186188

187189

188-
def test_wgmma_after_descriptor():
189-
190-
@T.prim_func
191-
def before():
192-
with T.Kernel(1):
193-
desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor")
194-
desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor")
195-
C_local = T.decl_buffer((32,), "float16", scope="local")
196-
T.initialize_descriptor(desc_a, T.uint64(0), 2, 1, 32)
197-
T.initialize_descriptor(desc_b, T.uint64(0), 2, 1, 32)
198-
T.warpgroup_arrive()
199-
T.ptx_wgmma_ss("float16", "m64n64k16", T.bool(True), T.bool(True), "fp16", "fp16",
200-
"fp16", desc_a.data, T.int32(0), desc_b.data, T.int32(0), C_local.data,
201-
T.int32(0), T.bool(True), 1, 1)
202-
203-
mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
204-
mod = tvm.tir.transform.BindTarget(auto_target)(mod)
205-
mod = tl.transform.InjectFenceProxy()(mod)
206-
207-
fence_count = 0
208-
order = []
209-
210-
def visit(node):
211-
nonlocal fence_count
212-
if isinstance(node, tir.Evaluate):
213-
call = node.value
214-
if isinstance(call, tir.Call):
215-
name = getattr(call.op, "name", "")
216-
order.append(name)
217-
if name == "tl.fence_proxy_async":
218-
fence_count += 1
219-
220-
tir.stmt_functor.post_order_visit(mod["main"].body, visit)
221-
assert fence_count >= 1
222-
assert "tl.warpgroup_arrive" in order
223-
assert order.index("tl.fence_proxy_async") < order.index("tl.warpgroup_arrive")
224-
225-
226190
if __name__ == "__main__":
227191
tilelang.testing.main()

0 commit comments

Comments (0)