Skip to content

Commit 50e3dde

Browse files
committed
Add test for T.assume
1 parent e2b10c5 commit 50e3dde

File tree

1 file changed

+93
-0
lines changed

1 file changed

+93
-0
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import tilelang
2+
import tilelang.language as T
3+
import tilelang.testing
4+
5+
6+
def test_assume_remove_boundary_check():
7+
@tilelang.jit
8+
def kernel_with_assume():
9+
N = T.dynamic('N')
10+
11+
@T.prim_func
12+
def main(
13+
A: T.Tensor((N, ), "float32"),
14+
l : T.int32,
15+
r : T.int32
16+
):
17+
with T.Kernel(1, threads=32) as _:
18+
for i in T.serial(r - l + 1):
19+
T.assume(l + i >= 0 and l + i < N)
20+
A[l + i] = 0
21+
22+
return main
23+
24+
jit_kernel = kernel_with_assume()
25+
source = jit_kernel.get_kernel_source()
26+
print(source)
27+
28+
assert not ("if (" in source)
29+
30+
31+
def test_assume_enable_vectorization():
32+
@tilelang.jit
33+
def kernel_vectorize(M):
34+
N = T.dynamic('N')
35+
vectorize_size = 4
36+
37+
@T.prim_func
38+
def main(
39+
A: T.Tensor((M, N), "float32"),
40+
B: T.Tensor((M, N), "float32"),
41+
):
42+
with T.Kernel(1, threads=32) as _:
43+
tid = T.get_thread_binding()
44+
45+
base_idx = tid * 4
46+
T.assume(N % vectorize_size == 0)
47+
48+
for i in T.vectorized(vectorize_size):
49+
T.assume(base_idx + i < N)
50+
B[tid, base_idx + i] = A[tid, base_idx + i]
51+
52+
return main
53+
54+
jit_kernel = kernel_vectorize(128)
55+
source = jit_kernel.get_kernel_source()
56+
print(source)
57+
58+
assert ("float4" in source) and not ("if (" in source)
59+
60+
61+
def test_assume_complex_indexing():
62+
@tilelang.jit
63+
def kernel_complex():
64+
M = T.dynamic('M')
65+
N = T.dynamic('N')
66+
67+
@T.prim_func
68+
def main(
69+
A: T.Tensor((M, N), "float32"),
70+
B: T.Tensor((M, N), "float32"),
71+
):
72+
with T.Kernel(1, threads=32) as _:
73+
tid = T.get_thread_binding()
74+
for j in T.serial(N):
75+
i_src = T.min(j + 233, tid + 2)
76+
j_src = j * T.ceildiv(j, i_src) * j - 1
77+
78+
T.assume(i_src >= 0 and i_src < M)
79+
T.assume(j_src >= 0 and j_src < N)
80+
81+
B[tid, j] = A[i_src, j_src]
82+
83+
return main
84+
85+
jit_kernel = kernel_complex()
86+
source = jit_kernel.get_kernel_source()
87+
print(source)
88+
89+
assert not ("if (" in source)
90+
91+
92+
if __name__ == '__main__':
93+
tilelang.testing.main()

0 commit comments

Comments
 (0)