Commit

[TIR][MetaSchedule] Add regression test for layout_rewrite extent=1 (#12916)

* [TIR][MetaSchedule] Add regression test for layout_rewrite extent=1

Adds a regression test for using the `layout_rewrite` post-proc on a buffer with an extent of one in at least one dimension (issue #12852). The bug was resolved as part of the refactor in #12904, but no regression test was added at that time.

* Identified segfault and added test case
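
The scenario described above can also be sketched standalone. The following is a minimal sketch mirroring the TestExtentOne case added below; the imports, the IRModule construction, and the final print are assumptions not shown in this diff, while the TuneContext / RewriteLayout / Schedule usage follows the test's transform helper.

# Minimal sketch (mirroring TestExtentOne below): apply the RewriteLayout
# post-proc to a PrimFunc whose layout-free buffer has an extent-1 dimension.
# Imports and IRModule construction are assumed.
import tvm
from tvm.script import tir as T
from tvm.target import Target
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.postproc import RewriteLayout


@T.prim_func
def extent_one(A: T.Buffer[(16, 1), "float32"]) -> None:
    # Buffer 0 is layout-free, and its second dimension has extent 1.
    T.func_attr({"layout_free_buffers": [0]})
    for i, j in T.grid(16, 1):
        with T.block("block"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.evaluate(A[vi, vj])


mod = tvm.IRModule({"main": extent_one})
ctx = TuneContext(
    mod=mod,
    target=Target("cuda", host="llvm"),
    postprocs=[RewriteLayout()],
    task_name="test",
)
sch = tvm.tir.Schedule(mod, debug_mask="all")
sch.enter_postproc()
assert ctx.postprocs[0].apply(sch)  # previously failed for the extent-1 dimension
print(sch.mod.script())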
Lunderberg authored Sep 28, 2022
1 parent d1c9feb commit 17e4644
Showing 2 changed files with 110 additions and 48 deletions.
2 changes: 2 additions & 0 deletions src/meta_schedule/postproc/rewrite_layout.cc
@@ -56,6 +56,8 @@ class BufferReadPosCollector : public StmtExprVisitor {
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    CHECK(cur_realize_.defined()) << "BufferLoad occurred outside of any block";

    const Buffer& buffer = op->buffer;
    if (buffers_.count(buffer.get())) {
      Map<Var, PrimExpr> subst_map;
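For context, the kind of TIR that now trips this CHECK (rather than segfaulting) is a read of a layout-free buffer outside of any block; it is exercised by TestRewrittenBuffersMustOccurWithinBlock in the test file below. A minimal sketch, with the import line assumed:

# Sketch of the input guarded against above: a BufferLoad with no enclosing
# block, which RewriteLayout now rejects with the CHECK message.
from tvm.script import tir as T


@T.prim_func
def load_outside_block(A: T.Buffer[(16, 16), "float32"]) -> None:
    T.func_attr({"layout_free_buffers": [0]})
    for i, j in T.grid(16, 16):
        T.evaluate(A[i, j])  # BufferLoad occurs outside of any block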
156 changes: 108 additions & 48 deletions tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
@@ -38,54 +38,114 @@ def _create_context(mod, target) -> TuneContext:
    )


@T.prim_func
def tir_matmul(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
) -> None:
    T.func_attr({"layout_free_buffers": [1]})
    for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
        with T.block("matmul"):
            vi = T.axis.S(16, i0 * 4 + i1)
            vj = T.axis.S(16, j)
            vk = T.axis.R(16, k0 * 4 + k1)
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


@T.prim_func
def rewritten_tir_matmul(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
) -> None:
    T.func_attr({"layout_free_buffers": [1]})
    B_reindex = T.alloc_buffer([16, 4, 4], dtype="float32")
    for ax0, ax1 in T.grid(16, 16):
        with T.block("layout_rewrite"):
            i0, i1 = T.axis.remap("SS", [ax0, ax1])
            T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
            B_reindex[i1, i0 // 4, i0 % 4] = B[i0, i1]
    for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
        with T.block("matmul"):
            vi = T.axis.spatial(16, i0 * 4 + i1)
            vj = T.axis.spatial(16, j)
            vk = T.axis.reduce(16, k0 * 4 + k1)
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B_reindex[vj, vk // 4, vk % 4]


def test_layout_rewrite():
    target = _target()
    ctx = _create_context(tir_matmul, target)
    sch = tvm.tir.Schedule(tir_matmul, debug_mask="all")
    sch.enter_postproc()
    assert ctx.postprocs[0].apply(sch)
    tvm.ir.assert_structural_equal(sch.mod["main"], rewritten_tir_matmul)
class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
    def transform(self):
        def inner(mod):
            target = Target("cuda", host="llvm")
            ctx = TuneContext(
                mod=mod,
                target=target,
                postprocs=[
                    RewriteLayout(),
                ],
                task_name="test",
            )
            sch = tvm.tir.Schedule(mod, debug_mask="all")
            sch.enter_postproc()
            assert ctx.postprocs[0].apply(sch)
            return sch.mod

        return inner


class TestTIRMatmul(BaseBeforeAfter):
    """Main functionality test

    A new block should be inserted to transform the layout, with the
    compute block operating on the temporary transformed buffer.
    """

    def before(
        A: T.Buffer[(16, 16), "float32"],
        B: T.Buffer[(16, 16), "float32"],
        C: T.Buffer[(16, 16), "float32"],
    ) -> None:
        T.func_attr({"layout_free_buffers": [1]})
        for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
            with T.block("matmul"):
                vi = T.axis.S(16, i0 * 4 + i1)
                vj = T.axis.S(16, j)
                vk = T.axis.R(16, k0 * 4 + k1)
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    def expected(
        A: T.Buffer[(16, 16), "float32"],
        B: T.Buffer[(16, 16), "float32"],
        C: T.Buffer[(16, 16), "float32"],
    ) -> None:
        T.func_attr({"layout_free_buffers": [1]})
        B_reindex = T.alloc_buffer([16, 4, 4], dtype="float32")
        for ax0, ax1 in T.grid(16, 16):
            with T.block("layout_rewrite"):
                i0, i1 = T.axis.remap("SS", [ax0, ax1])
                T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
                B_reindex[i1, i0 // 4, i0 % 4] = B[i0, i1]
        for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
            with T.block("matmul"):
                vi = T.axis.spatial(16, i0 * 4 + i1)
                vj = T.axis.spatial(16, j)
                vk = T.axis.reduce(16, k0 * 4 + k1)
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B_reindex[vj, vk // 4, vk % 4]


class TestRewrittenBuffersMustOccurWithinBlock(BaseBeforeAfter):
    """Buffers must occur within a Block"""

    def before(
        A: T.Buffer[(16, 16), "float32"],
    ) -> None:
        T.func_attr({"layout_free_buffers": [0]})
        for i, j in T.grid(16, 16):
            T.evaluate(A[i, j])

    expected = tvm.TVMError


class TestExtentOne(BaseBeforeAfter):
    """Buffers with dimensions of extent 1 can be transformed

    Regression test for a previous bug, in which the removal of
    trivial variables resulted in an error in `IndexMap::Inverse`.
    """

    def before(
        A: T.Buffer[(16, 1), "float32"],
    ) -> None:
        T.func_attr({"layout_free_buffers": [0]})
        for i, j in T.grid(16, 1):
            with T.block("block"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.evaluate(A[vi, vj])

    def expected(A: T.Buffer[(16, 1), "float32"]):
        T.func_attr({"layout_free_buffers": [0]})

        A_global = T.alloc_buffer([16], dtype="float32")
        for ax0, ax1 in T.grid(16, 1):
            with T.block("A_global"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
                A_global[v0] = A[v0, v1]

        for i, j in T.grid(16, 1):
            with T.block("block"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.evaluate(A_global[vi])


if __name__ == "__main__":
    test_layout_rewrite()
    tvm.testing.main()
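
Usage note (not part of the diff): when this file is run directly, `tvm.testing.main()` executes all of the test classes above, so the new TestExtentOne case guards against the extent-1 layout-rewrite regression going forward.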
