Skip to content

Commit 6f69f02

Browse files
committed
minor fix
1 parent eff7916 commit 6f69f02

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

testing/python/transform/test_tilelang_transform_layout_inference.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
1616
K = tvm.te.var("k")
1717

1818
def before():
19+
1920
@T.prim_func
2021
def main(B: T.Tensor((K, N), dtype),):
2122
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
@@ -35,9 +36,11 @@ def main(B: T.Tensor((K, N), dtype),):
3536
t // (block_N // vec_load_b), bx * block_N + t %
3637
(block_N // vec_load_b) * (block_N // vec_load_b) + vec],
3738
T.float16(0))
39+
3840
return tvm.IRModule({'main': main})
3941

4042
def after():
43+
4144
@T.prim_func
4245
def main(B: T.Tensor((K, N), dtype),):
4346
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
@@ -73,6 +76,7 @@ def main(B: T.Tensor((K, N), dtype),):
7376
t // (block_N // vec_load_b),
7477
bx * block_N + t % (block_N // vec_load_b) *
7578
(block_N // vec_load_b) + vec], T.float16(0))
79+
7680
return tvm.IRModule({'main': main})
7781

7882
with tvm.target.Target(auto_target):

testing/python/transform/test_tilelang_transform_lower_tile_op.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,18 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
1616
K = tvm.te.var("k")
1717

1818
def before():
19+
1920
@T.prim_func
2021
def main(B: T.Tensor((K, N), dtype),):
2122
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
2223
B_shared = T.alloc_shared((block_K, block_N), dtype)
2324
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
2425
T.copy(B[k * block_K, bx * block_N], B_shared)
26+
2527
return tvm.IRModule({'main': main})
2628

2729
def after():
30+
2831
@T.prim_func
2932
def main(B: T.Tensor((K, N), dtype),):
3033
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
@@ -60,6 +63,7 @@ def main(B: T.Tensor((K, N), dtype),):
6063
t // (block_N // vec_load_b),
6164
bx * block_N + t % (block_N // vec_load_b) *
6265
(block_N // vec_load_b) + vec], T.float16(0))
66+
6367
return tvm.IRModule({'main': main})
6468

6569
with tvm.transform.PassContext():

0 commit comments

Comments
 (0)