Skip to content

Commit

Permalink
[TIR][LowerMatchBuffer] Fix lowering strides when source region has h…
Browse files Browse the repository at this point in the history
…igher dimension than the buffer
  • Loading branch information
vinx13 committed Sep 28, 2021
1 parent 4905a8c commit a15b990
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/tir/transforms/lower_match_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,25 +188,25 @@ class MatchBufferLower : public StmtExprMutator {
Load load = Downcast<Load>(source_buffer.vload(indices, source_buffer->dtype));
Bind(buffer->elem_offset, load->index, buffer->name + ".elem_offset");
CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0))
<< "The source elem_offset " << buffer->elem_offset
<< " does not satisfy the offset_factor " << buffer->offset_factor << ".";
<< "The source elem_offset " << load->index << " does not satisfy the offset_factor "
<< buffer->offset_factor << ".";
}

// Step 2.3. Check and update strides
// Check if target buffer strides are defined
ICHECK(source->region.size() >= buffer->shape.size());
size_t offset = source->region.size() - buffer->shape.size();
if (!buffer->strides.empty()) {
ICHECK_EQ(buffer->strides.size(), buffer->shape.size());
PrimExpr stride = make_const(DataType::Int(32), 1);
for (size_t i = buffer->shape.size(); i > 0; --i) {
const PrimExpr& shape = source_buffer->shape[i - 1];
const PrimExpr& shape = source_buffer->shape[i - 1 + offset];
Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1));
stride *= shape;
}
}

// Step 2.4. Check and update shape
ICHECK(source->region.size() >= buffer->shape.size());
size_t offset = source->region.size() - buffer->shape.size();
for (size_t i = 0; i < buffer->shape.size(); ++i) {
const Range& range = source->region[i + offset];
Bind(buffer->shape[i], range->extent, buffer->name + ".shape_" + std::to_string(i));
Expand Down
53 changes: 53 additions & 0 deletions tests/python/unittest/test_tir_lower_match_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,54 @@ def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None:
)


@tvm.script.tir
def high_dim_opaque_access(a: ty.handle) -> None:
A = tir.match_buffer(a, (16, 32, 64))
for i, j, k in tir.grid(16, 2, 4):
with tir.block([]):
As_0 = tir.var("int32")
As_1 = tir.var("int32")
tir.reads([])
tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
sub_A = tir.match_buffer(
A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
(16, 16),
strides=[As_0, As_1],
offset_factor=1,
)
tir.evaluate(
tir.intrin_test(
sub_A.data,
sub_A.elem_offset,
sub_A.strides[0],
sub_A.strides[1],
sub_A.shape[0],
sub_A.shape[1],
dtype="handle",
)
)


@tvm.script.tir
def transformed_high_dim_opaque_access(a: ty.handle) -> None:
A = tir.match_buffer(a, (16, 32, 64))
for i, j, k in tir.grid(16, 2, 4):
with tir.block([]):
tir.reads([])
tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
tir.evaluate(
tir.intrin_test(
A.data,
i * 2048 + j * 1024 + k * 16,
64,
1,
16,
16,
dtype="handle",
)
)


@tvm.script.tir
def recursive_match(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (64, 64, 64))
Expand Down Expand Up @@ -419,6 +467,10 @@ def test_opaque_access():
_check(opaque_access, transformed_opaque_access)


def test_high_dim_opaque_access():
_check(high_dim_opaque_access, transformed_high_dim_opaque_access)


def test_recursive_match():
_check(recursive_match, transformed_recursive_match)

Expand Down Expand Up @@ -447,6 +499,7 @@ def test_fail_match_func_param():
if __name__ == "__main__":
test_buffer_load_store()
test_opaque_access()
test_high_dim_opaque_access()
test_recursive_match()
test_symbolic_match()
test_rank0_buffer()
Expand Down

0 comments on commit a15b990

Please sign in to comment.