Skip to content

Commit 5a0e7fc

Browse files
authored
[CI][Test] Add test cases for tilelang transform PipelinePlanning (#44)
* [Doc] update installation.md and readme * solve conflicts * change readme * fix installation.rst * fix readme * fix installation * [fix] fix installation.rst * [Doc] fix installation.rst * [Doc] fix installation * [CI][Test] Add test cases for tilelang transform PipelinePlanning
1 parent cc5118a commit 5a0e7fc

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from tilelang import tvm as tvm
4+
import tilelang as tl
5+
from tilelang.utils.target import determine_target
6+
import tilelang.language as T
7+
import tilelang.testing
8+
9+
# Target resolved once at import time via determine_target("auto"); _check
# binds this same target to both the transformed and the expected module.
auto_target = tvm.target.Target(determine_target("auto"))
10+
11+
12+
def _check(original, transformed):
    """Assert that PipelinePlanning rewrites ``original`` into ``transformed``.

    ``original`` is wrapped in an IRModule under the global symbol ``"main"``,
    bound to ``auto_target``, and run through ``PipelinePlanning`` followed by
    ``Simplify``. ``transformed`` is wrapped and target-bound the same way but
    is NOT run through any pass. The two ``"main"`` functions are then compared
    structurally.

    Args:
        original: PrimFunc fed to the pass pipeline.
        transformed: PrimFunc holding the expected result.

    Raises:
        Whatever ``tvm.ir.assert_structural_equal`` raises on a mismatch.
    """
    # Actual: wrap -> bind target -> plan pipeline -> simplify.
    actual_mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main"))
    actual_mod = tvm.tir.transform.BindTarget(auto_target)(actual_mod)
    actual_mod = tl.transform.PipelinePlanning()(actual_mod)
    actual_mod = tl.transform.Simplify()(actual_mod)

    # Expected: only wrapped and target-bound so function attributes line up.
    expected_mod = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main"))
    expected_mod = tvm.tir.transform.BindTarget(auto_target)(expected_mod)

    # The two functions are built independently, so their free variables must
    # be mapped onto each other rather than compared by identity.
    tvm.ir.assert_structural_equal(
        actual_mod["main"], expected_mod["main"], map_free_vars=True)
21+
22+
23+
def test_simple_pipeline():
    """Check PipelinePlanning on a tiled GEMM with a 3-stage pipelined loop.

    ``before`` writes the reduction loop as ``T.Pipelined(32, num_stages=3)``.
    ``after`` is the expected result: the same loop as ``T.serial`` carrying
    ``software_pipeline_*`` annotations. The two kernels are otherwise
    identical, and ``_check`` performs the comparison.
    """

    @T.prim_func
    def before(
            A: T.Buffer((1024, 32), "float32"),
            B: T.Buffer((32, 1024), "float32"),
            C: T.Buffer((1024, 1024), "float32"),
    ):
        with T.Kernel(8, 8, threads=128) as (bx, by):
            A_shared = T.alloc_shared((128, 32), "float32")
            B_shared = T.alloc_shared((32, 128), "float32")
            C_local = T.alloc_fragment((128, 128), "float32")

            T.clear(C_local)

            # Loop under test: two tile copies feeding one gemm per iteration.
            for ko in T.Pipelined(32, num_stages=3):
                T.copy(A[by * 128, ko * 32], A_shared)
                T.copy(B[ko * 32, bx * 128], B_shared)

                T.gemm(A_shared, B_shared, C_local)

            T.copy(C_local, C[by * 128, bx * 128])

    @T.prim_func
    def after(
            A: T.Buffer((1024, 32), "float32"),
            B: T.Buffer((32, 1024), "float32"),
            C: T.Buffer((1024, 1024), "float32"),
    ):
        with T.Kernel(8, 8, threads=128) as (bx, by):
            A_shared = T.alloc_shared((128, 32), "float32")
            B_shared = T.alloc_shared((32, 128), "float32")
            C_local = T.alloc_fragment((128, 128), "float32")

            T.clear(C_local)

            # Expected annotations: one order/stage entry per statement in the
            # loop body (copy, copy, gemm).
            # NOTE(review): the meaning of the individual values is defined by
            # the pass, not by this file; confirm against PipelinePlanning.
            for ko in T.serial(
                    32,
                    annotations={
                        "software_pipeline_async_stages": [0],
                        "software_pipeline_order": [0, 1, 2],
                        "software_pipeline_stage": [3, 3, 3],
                    }):
                T.copy(A[by * 128, ko * 32], A_shared)
                T.copy(B[ko * 32, bx * 128], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            T.copy(C_local, C[by * 128, bx * 128])

    _check(before, after)
67+
68+
69+
if __name__ == "__main__":
    # Executed as a script: delegate to tilelang's test entry point.
    tilelang.testing.main()

0 commit comments

Comments
 (0)