Skip to content

Commit 0a9d50f

Browse files
committed
Enhance test for InjectSoftwarePipeline by adding LowerOpaqueBlock transformation and updating expected function signature to use match_buffer for better clarity.
1 parent c5b1a10 commit 0a9d50f

File tree

1 file changed

+13
-27
lines changed

1 file changed

+13
-27
lines changed

testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ def _check(original, transformed):
99
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
1010
mod = tl.transform.InjectSoftwarePipeline()(mod)
1111
mod = tl.transform.Simplify()(mod)
12+
mod = tl.transform.LowerOpaqueBlock()(mod)
1213
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
1314
True)
1415

@@ -39,35 +40,20 @@ def before(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")):
3940
C[tx, i] = B[tx, 0] + T.float32(1)
4041

4142
@T.prim_func
42-
def expected(A: T.Buffer((16, 1), "float32"), C: T.Buffer((16, 1), "float32")):
43-
for tx in T.thread_binding(16, thread="threadIdx.x"):
44-
with T.block():
45-
T.reads(A[tx, 0])
46-
T.writes(C[tx, 0])
47-
B = T.alloc_buffer((2, 16, 1), scope="shared")
48-
with T.block():
49-
T.reads(A[tx, 0])
50-
T.writes(B[0, tx, 0])
51-
B[0, tx, 0] = A[tx, 0] * T.float32(2.0)
52-
with T.block():
53-
T.reads(A[tx, 1:1], B[0:2, tx, 0])
54-
T.writes(B[1:1, tx, 0], C[tx, 0:0])
55-
for i in range(0):
56-
with T.block():
57-
T.reads(A[tx, i + 1])
58-
T.writes(B[i + 1, tx, 0])
59-
B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2.0)
60-
with T.block():
61-
T.reads(B[i, tx, 0])
62-
T.writes(C[tx, i])
63-
C[tx, i] = B[i, tx, 0] + T.float32(1.0)
64-
with T.block():
65-
T.reads(B[0, tx, 0])
66-
T.writes(C[tx, 0])
67-
C[tx, 0] = B[0, tx, 0] + T.float32(1.0)
43+
def expected(A_handle: T.handle, C_handle: T.handle):
44+
A = T.match_buffer(A_handle, (16, 1), strides=(1, 1))
45+
C = T.match_buffer(C_handle, (16, 1), strides=(1, 1))
46+
tx = T.launch_thread("threadIdx.x", 16)
47+
B = T.decl_buffer((2, 16, 1), scope="shared")
48+
B[0, tx, 0] = A[tx, 0] * T.float32(2.0)
49+
for i in range(0):
50+
B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2.0)
51+
C[tx, i] = B[i, tx, 0] + T.float32(1.0)
52+
C[tx, 0] = B[0, tx, 0] + T.float32(1.0)
6853

6954
_check(before, expected)
7055

7156

7257
if __name__ == "__main__":
73-
tilelang.testing.main()
58+
# tilelang.testing.main()
59+
test_trival_pipeline()

0 commit comments

Comments
 (0)