Skip to content

Commit 61d9016

Browse files
authored
[CI][Test] Add test cases for tilelang transform LayoutInference and LowerTileOp on loop tail split functionality (#29)
* [CI][Test] Add test cases for tilelang transform `LayoutInference` and `LowerTileOp` on loop tail split functionality * format * rename test script
1 parent 5a0e7fc commit 61d9016

File tree

2 files changed

+176
-0
lines changed

2 files changed

+176
-0
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from tilelang import tvm as tvm
4+
from tilelang.utils.target import determine_target
5+
import tilelang as tl
6+
import tilelang.language as T
7+
import tilelang.testing
8+
import pytest
9+
10+
# Resolve the compilation target once at import time; "auto" lets tilelang
# pick the locally available device target.
auto_target = tvm.target.Target(determine_target("auto"))
11+
12+
13+
@pytest.mark.parametrize("block_M, block_N, block_K, threads, vec_load_b, dtype", [
    (64, 64, 32, 128, 8, "float16"),
])
def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
    """Check LayoutInference's loop tail split on a guarded shared-memory copy.

    ``Before`` stores tiles of ``B`` into shared memory through a plain
    ``T.Parallel`` inner loop.  ``After`` is the expected result of the pass:
    the same store wrapped in an alignment test, taking a ``T.vectorized``
    loop when the flattened row start is divisible by ``vec_load_b`` and a
    ``T.serial`` loop otherwise.
    """
    # Symbolic (dynamic) extents: whether a row start is vector-aligned can
    # only be decided at runtime, which is what forces the tail split.
    N = tvm.te.var("n")
    K = tvm.te.var("k")

    @tvm.script.ir.ir_module
    class Before:

        @T.prim_func
        def main(B: T.Buffer((K, N), dtype),):
            with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
                B_shared = T.alloc_shared((block_K, block_N), dtype)
                thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                    t = thread_bindings
                    for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)):
                        # Each thread copies vec_load_b consecutive elements,
                        # predicated so out-of-range reads yield zero.
                        for vec in T.Parallel(vec_load_b):
                            B_shared[i * (threads * vec_load_b // block_N) + t //
                                     (block_N // vec_load_b), t % (block_N // vec_load_b) *
                                     (block_N // vec_load_b) + vec] = T.if_then_else(
                                         k * block_K + i * (threads * vec_load_b // block_N) + t //
                                         (block_N // vec_load_b) < K and bx * block_N + t %
                                         (block_N // vec_load_b) * (block_N // vec_load_b) < N,
                                         B[k * block_K + i * (threads * vec_load_b // block_N) +
                                           t // (block_N // vec_load_b), bx * block_N + t %
                                           (block_N // vec_load_b) * (block_N // vec_load_b) + vec],
                                         T.float16(0))

    @tvm.script.ir.ir_module
    class After:

        @T.prim_func
        def main(B: T.Buffer((K, N), dtype),):
            with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
                B_shared = T.alloc_shared((block_K, block_N), dtype)
                thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                    t = thread_bindings
                    for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)):
                        # Aligned path: the row's flat offset (row * N) is a
                        # multiple of the vector width, so the store vectorizes.
                        if (k * block_K + i * (threads * vec_load_b // block_N) + t //
                                (block_N // vec_load_b)) * N % vec_load_b == 0:
                            for vec in T.vectorized(vec_load_b):
                                B_shared[i * (threads * vec_load_b // block_N) + t //
                                         (block_N // vec_load_b), t % (block_N // vec_load_b) *
                                         (block_N // vec_load_b) + vec] = T.if_then_else(
                                             k * block_K + i *
                                             (threads * vec_load_b // block_N) + t //
                                             (block_N // vec_load_b) < K and bx * block_N + t %
                                             (block_N // vec_load_b) * (block_N // vec_load_b) < N,
                                             B[k * block_K + i * (threads * vec_load_b // block_N) +
                                               t // (block_N // vec_load_b),
                                               bx * block_N + t % (block_N // vec_load_b) *
                                               (block_N // vec_load_b) + vec], T.float16(0))
                        # Tail path: unaligned row start falls back to a
                        # serial, element-wise copy of the same predicated store.
                        else:
                            for vec in T.serial(vec_load_b):
                                B_shared[i * (threads * vec_load_b // block_N) + t //
                                         (block_N // vec_load_b), t % (block_N // vec_load_b) *
                                         (block_N // vec_load_b) + vec] = T.if_then_else(
                                             k * block_K + i *
                                             (threads * vec_load_b // block_N) + t //
                                             (block_N // vec_load_b) < K and bx * block_N + t %
                                             (block_N // vec_load_b) * (block_N // vec_load_b) < N,
                                             B[k * block_K + i * (threads * vec_load_b // block_N) +
                                               t // (block_N // vec_load_b),
                                               bx * block_N + t % (block_N // vec_load_b) *
                                               (block_N // vec_load_b) + vec], T.float16(0))

    # Run the pass under test on Before, then Simplify both sides so the
    # comparison is not sensitive to trivially reducible expressions.
    mod = tvm.tir.transform.BindTarget(auto_target)(Before)
    mod = tl.transform.LayoutInference()(mod)
    mod = tvm.tir.transform.Simplify()(mod)
    ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
    ref_mod = tvm.tir.transform.Simplify()(ref_mod)
    # Note(tzj): The structures are equal except one more "for" loop after the LayoutInference pass
    # This loop is "for vec in T.parallel(1)",
    # Since the loop var "vec" is never used in the loop body, it does not affect the correctness
    # NOTE(review): the boolean returned by structural_equal is discarded, so this
    # call only smoke-tests that both modules build and the passes run to
    # completion; the strict assert below stays disabled for the reason above.
    tvm.ir.structural_equal(mod, ref_mod)
    # tvm.ir.assert_structural_equal(mod, ref_mod)
92+
93+
94+
if __name__ == "__main__":
    # Allow running this file directly through tilelang's test entry point.
    tilelang.testing.main()
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from tilelang import tvm as tvm
4+
from tilelang.utils.target import determine_target
5+
import tilelang as tl
6+
import tilelang.language as T
7+
import tilelang.testing
8+
import pytest
9+
10+
# Resolve the compilation target once at import time; "auto" lets tilelang
# pick the locally available device target.
auto_target = tvm.target.Target(determine_target("auto"))
11+
12+
13+
@pytest.mark.parametrize("block_M, block_N, block_K, threads, vec_load_b, dtype", [
    (64, 64, 32, 128, 8, "float16"),
])
def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
    """Check LowerTileOp's loop tail split when lowering a ``T.copy`` tile op.

    ``Before`` copies a tile of ``B`` to shared memory with the high-level
    ``T.copy`` operator.  ``After`` is the expected lowering: an explicit
    per-thread store wrapped in an alignment test, taking a ``T.vectorized``
    loop when the flattened row start is divisible by ``vec_load_b`` and a
    ``T.serial`` loop otherwise.
    """
    # Symbolic (dynamic) extents: whether a row start is vector-aligned can
    # only be decided at runtime, which is what forces the tail split.
    N = tvm.te.var("n")
    K = tvm.te.var("k")

    @tvm.script.ir.ir_module
    class Before:

        @T.prim_func
        def main(B: T.Buffer((K, N), dtype),):
            with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
                B_shared = T.alloc_shared((block_K, block_N), dtype)
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                    # High-level tile copy to be expanded by LowerTileOp.
                    T.copy(B[k * block_K, bx * block_N], B_shared)

    @tvm.script.ir.ir_module
    class After:

        @T.prim_func
        def main(B: T.Buffer((K, N), dtype),):
            with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
                B_shared = T.alloc_shared((block_K, block_N), dtype)
                thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                    t = thread_bindings
                    for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)):
                        # Aligned path: the row's flat offset (row * N) is a
                        # multiple of the vector width, so the store vectorizes.
                        if (k * block_K + i * (threads * vec_load_b // block_N) + t //
                                (block_N // vec_load_b)) * N % vec_load_b == 0:
                            for vec in T.vectorized(vec_load_b):
                                B_shared[i * (threads * vec_load_b // block_N) + t //
                                         (block_N // vec_load_b), t % (block_N // vec_load_b) *
                                         (block_N // vec_load_b) + vec] = T.if_then_else(
                                             k * block_K + i *
                                             (threads * vec_load_b // block_N) + t //
                                             (block_N // vec_load_b) < K and bx * block_N + t %
                                             (block_N // vec_load_b) * (block_N // vec_load_b) < N,
                                             B[k * block_K + i * (threads * vec_load_b // block_N) +
                                               t // (block_N // vec_load_b),
                                               bx * block_N + t % (block_N // vec_load_b) *
                                               (block_N // vec_load_b) + vec], T.float16(0))
                        # Tail path: unaligned row start falls back to a
                        # serial, element-wise copy of the same predicated store.
                        else:
                            for vec in T.serial(vec_load_b):
                                B_shared[i * (threads * vec_load_b // block_N) + t //
                                         (block_N // vec_load_b), t % (block_N // vec_load_b) *
                                         (block_N // vec_load_b) + vec] = T.if_then_else(
                                             k * block_K + i *
                                             (threads * vec_load_b // block_N) + t //
                                             (block_N // vec_load_b) < K and bx * block_N + t %
                                             (block_N // vec_load_b) * (block_N // vec_load_b) < N,
                                             B[k * block_K + i * (threads * vec_load_b // block_N) +
                                               t // (block_N // vec_load_b),
                                               bx * block_N + t % (block_N // vec_load_b) *
                                               (block_N // vec_load_b) + vec], T.float16(0))

    # Run the pass under test on Before, then Simplify both sides so the
    # comparison is not sensitive to trivially reducible expressions.
    mod = tvm.tir.transform.BindTarget(auto_target)(Before)
    mod = tl.transform.LowerTileOp()(mod)
    mod = tvm.tir.transform.Simplify()(mod)
    ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
    ref_mod = tvm.tir.transform.Simplify()(ref_mod)
    # Note(tzj): The structures are equal except the argument in "T.reads" function.
    # The difference is just between the first index and the indices range, which is totally equivalent
    # NOTE(review): the boolean returned by structural_equal is discarded, so this
    # call only smoke-tests that both modules build and the passes run to
    # completion; the strict assert below stays disabled for the reason above.
    tvm.ir.structural_equal(mod, ref_mod)
    # tvm.ir.assert_structural_equal(mod, ref_mod)
78+
79+
80+
if __name__ == "__main__":
    # Allow running this file directly through tilelang's test entry point.
    tilelang.testing.main()

0 commit comments

Comments (0)