Skip to content

Commit baf1f23

Browse files
authored
[CI][Test] Add test cases for tilelang transform InjectFenceProxy (#66)
* [Dev] Add FlashDecoding example * [CI][Test] Add test cases for tilelang kernel convolution * [CI][Test] Add test cases for tilelang kernel FlashAttention * Reduce the number of stages to ensure the shared memory allocation is valid * Temporarily remove the dim128 case * lint * update einops in requirements-dev.txt * update einops in requirements-test.txt * remove einops in requirements-dev.txt * [CI][Test] Add test cases for tilelang transform ClusterPlanning * [CI][Test] Add test cases for tilelang transform LowerHopperIntrin * [CI][Test] Add test cases for tilelang transform InjectFenceProxy
1 parent 560b1d8 commit baf1f23

File tree

3 files changed

+63
-1
lines changed

3 files changed

+63
-1
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from tilelang import tvm as tvm
4+
import tilelang as tl
5+
from tilelang.utils.target import determine_target
6+
import tilelang.language as T
7+
import tilelang.testing
8+
from tvm import tir
9+
10+
auto_target = tvm.target.Target(determine_target("auto"))
11+
12+
13+
def _check(original, transformed):
14+
func = original
15+
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
16+
mod = tvm.tir.transform.BindTarget(auto_target)(mod)
17+
mod = tl.transform.InjectFenceProxy()(mod)
18+
mod = tir.transform.LowerOpaqueBlock()(mod)
19+
transformed = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main"))
20+
transformed = tvm.tir.transform.BindTarget(auto_target)(transformed)
21+
transformed = tir.transform.LowerOpaqueBlock()(transformed)
22+
23+
tvm.ir.assert_structural_equal(mod["main"], transformed["main"], True)
24+
25+
26+
def test_lower_fence_proxy():
27+
28+
@T.prim_func
29+
def before():
30+
with T.Kernel(8):
31+
A_shared = T.decl_buffer((1, 8, 256), "float16", scope="shared.dyn")
32+
B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn")
33+
C_local = T.decl_buffer((32,), scope="local")
34+
for i in T.unroll(16):
35+
C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
36+
T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
37+
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
38+
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
39+
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
40+
41+
@T.prim_func
42+
def after():
43+
with T.Kernel(8):
44+
A_shared = T.decl_buffer((1, 8, 256), "float16", scope="shared.dyn")
45+
B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn")
46+
C_local = T.decl_buffer((32,), scope="local")
47+
for i in T.unroll(16):
48+
C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
49+
T.FenceProxyAsyncOp()
50+
T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
51+
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
52+
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
53+
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
54+
55+
_check(before, after)
56+
57+
58+
if __name__ == "__main__":
59+
test_lower_fence_proxy()

testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,3 @@ def after():
5656

5757
if __name__ == "__main__":
5858
tilelang.testing.main()
59-
test_lower_hopper_intrin_barrier()

tilelang/language/builtin.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,7 @@ def CreateTMADescriptorOp(*args):
1919

2020
def TMALoadOp(*args):
2121
return tir.call_intrin("handle", tir.op.Op.get("tl.TMALoadOp"), *args)
22+
23+
24+
def FenceProxyAsyncOp(*args):
25+
return tir.call_intrin("handle", tir.op.Op.get("tl.FenceProxyAsyncOp"), *args)

0 commit comments

Comments
 (0)