Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[TIR] Added unit test for dynamic parameter in layout transform (apac…
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg authored and xinetzone committed Nov 25, 2022
1 parent abc2435 commit c66c68f
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions tests/python/unittest/test_tir_schedule_transform_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,5 +836,53 @@ def before(A: T.Buffer[14, "int32"]):
expected = tvm.tir.schedule.schedule.ScheduleError


class TestTransformLayoutWithVar(tvm.testing.CompareBeforeAfter):
"""Layout transform with dynamic parameter in transform"""

@pytest.fixture
def transform(self):
def transform(mod):
sch = tir.Schedule(mod)

n = sch.mod["main"].params[1]

sch.transform_layout(
"block",
"B",
lambda i: [i // n, i % n],
pad_value=0,
)
return sch.mod

return transform

def before(A: T.Buffer[16, "int32"], n: T.int32):
B = T.alloc_buffer(16, "int32")
for i in T.serial(16):
with T.block("block"):
vi = T.axis.remap("S", [i])
B[vi] = A[vi]

def expected(A: T.Buffer[16, "int32"], n: T.int32):
B = T.alloc_buffer([(-16 % n + 16) // n, n], dtype="int32")
for i, j in T.grid((-16 % n + 16) // n, n):
with T.block("block"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = T.if_then_else(
# Checks if the transform introduced padding
-16 % n != 0
and (
# If so, is vi in the last group (which may
# include padding).
(vj + vi * n) // n == 16 // n
# And is vj within the padding
and 16 % n <= (vj + vi * n) % n
),
0,
A[vj + vi * n],
dtype="int32",
)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit c66c68f

Please sign in to comment.