Skip to content

Commit

Permalink
Fix non-zero extent of access_ptr out of range (apache#1937) (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoxiong-1 authored and Wei Chen committed Feb 19, 2019
1 parent 93db916 commit 3ec1191
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/lang/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -357,9 +357,9 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr
} else if (self->strides.size() == self->shape.size()) {
int highest_dim = 0;
extent = arith::ComputeExpr<ir::Mul>(
self->strides[highest_dim], self->shape[highest_dim]);
self->strides[highest_dim], self->shape[highest_dim]) - offset;
} else {
extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr());
extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr()) - offset;
}
Expr elem_offset = self->elem_offset + offset;
if (content_lanes > 1) {
Expand Down
13 changes: 13 additions & 0 deletions tests/python/unittest/test_lang_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ def test_buffer_access_ptr_offset():
assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v))
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE

def test_buffer_access_ptr_extent():
m = tvm.var('m')
n = tvm.var('n')
Ab = tvm.decl_buffer((m, n), tvm.float32)
aptr = Ab.access_ptr("rw")
assert tvm.ir_pass.Equal(aptr.args[3], m * n)
aptr = Ab.access_ptr("rw", offset=100)
assert tvm.ir_pass.Equal(aptr.args[3], m * n - 100)
Ab = tvm.decl_buffer((m, n), tvm.float32, strides=[n + 1 , 1])
aptr = Ab.access_ptr("rw", offset=100)
assert tvm.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m - 100)

def test_buffer_vload():
m = tvm.var('m')
n = tvm.var('n')
Expand Down Expand Up @@ -84,5 +96,6 @@ def assert_simplified_equal(index_simplified, index_direct):
test_buffer()
test_buffer_access_ptr()
test_buffer_access_ptr_offset()
test_buffer_access_ptr_extent()
test_buffer_vload()
test_buffer_index_merge_mult_mod()

0 comments on commit 3ec1191

Please sign in to comment.