Skip to content

Commit

Permalink
[TIR] Add support for 0-dim buffer (#9224)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuanjing Shi authored Oct 14, 2021
1 parent 95a2031 commit 08018ea
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 7 deletions.
6 changes: 0 additions & 6 deletions src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,6 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
f_push_block_vars(compute_op->axis);
f_push_block_vars(compute_op->reduce_axis);

// If we have a rank 0 tensor then we manifest it as a rank 1 buffer with a single element.
if (compute_op->axis.size() == 0) {
iter_vars.push_back(IterVar(Range::FromMinExtent(0, 1), Var(), IterVarType::kDataPar));
bindings.push_back(Var());
}

// Step 2. Declare buffer and update op2buffers
Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global");
info->tensor2buffers[tensor] = buffer;
Expand Down
2 changes: 1 addition & 1 deletion src/tir/ir/script/script_complete.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates) {
// generate surrounding loops automatically
Stmt res = script_completer(func->body);
// generate root block automatically
if (script_completer.contains_block && !contain_root) {
if ((script_completer.contains_block || root_allocates.size()) && !contain_root) {
res = Block({}, {}, {}, "root", res, NullOpt, root_allocates);
res = BlockRealize({}, Bool(true), Downcast<Block>(res));
}
Expand Down
29 changes: 29 additions & 0 deletions tests/python/unittest/test_tvmscript_complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,39 @@ def test_complete_match_buffer():
tvm.ir.assert_structural_equal(match_buffer_func, expected_match_buffer_func)


@T.prim_func
def alloc_buffer_func(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [2, 2], dtype="float32")
B = T.match_buffer(b, [2, 2], dtype="float32")
C = T.alloc_buffer([2, 2], dtype="float32")
A[(0, 0)] = T.float32(2)
C[(0, 0)] = A[(0, 0)] + B[(0, 0)]
B[(0, 0)] = C[(0, 0)]


@T.prim_func
def expect_alloc_buffer_func(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [2, 2], dtype="float32", elem_offset=0, align=128, offset_factor=1)
B = T.match_buffer(b, [2, 2], dtype="float32", elem_offset=0, align=128, offset_factor=1)
with T.block([], "root"):
T.reads([])
T.writes([])
C = T.alloc_buffer([2, 2], dtype="float32", elem_offset=0, align=128, offset_factor=1)
A[(0, 0)] = T.float32(2)
C[(0, 0)] = A[(0, 0)] + B[(0, 0)]
B[(0, 0)] = C[(0, 0)]


def test_complete_alloc_buffer():
rt_func = tvm.script.from_source(alloc_buffer_func.script(show_meta=True))
tvm.ir.assert_structural_equal(alloc_buffer_func, expect_alloc_buffer_func)


if __name__ == "__main__":
test_complete_matmul()
test_complete_matmul_original()
test_complete_with_root()
test_complete_part_region()
test_complete_buffer_indices()
test_complete_match_buffer()
test_complete_alloc_buffer()
59 changes: 59 additions & 0 deletions tests/python/unittest/test_tvmscript_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,64 @@ def test_get_valid_counts_script_func():
_check_get_valid_counts_with_numpy(f, (1, 2500, 6), 0.0, 0, 1)


@T.prim_func
def alloc_zero_dim_buffer(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [], dtype="float32")
B = T.match_buffer(b, [], dtype="float32")
# body
# tir.with block("root")
C = T.alloc_buffer([], dtype="float32")
A[()] = T.float32(2)
C[()] = A[()] + B[()]
B[()] = C[()]


@T.prim_func
def alloc_zero_dim_buffer_block(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (), "float32")
B = T.match_buffer(b, (), "float32")
with T.block([], "root"):
T.reads([])
T.writes([])
C = T.alloc_buffer((), "float32")
A[()] = T.float32(2)
C[()] = A[()] + B[()]
B[()] = C[()]


def _check_alloc_zero_dim_buffer(f):
dtype = "float32"
ctx = tvm.cpu()

np_data = np.zeros(shape=()).astype(dtype)
np_out = np.zeros(shape=()).astype(dtype)
tvm_data = tvm.nd.array(np_data, ctx)
tvm_out = tvm.nd.array(np_out, ctx)

# np func exection
np_inter = np.array(1)
np_data[()] = 2.0
np_inter[()] = np_data[()] + np_out[()]
np_out[()] = np_inter[()]

# tvm func execution
f(tvm_data, tvm_out)
tvm.testing.assert_allclose(tvm_out.numpy(), np_out, rtol=1e-5)


def test_alloc_zero_dim_buffer_round_trip():
func = alloc_zero_dim_buffer
func_with_block = alloc_zero_dim_buffer_block
rt_func = tvm.script.from_source(func.script(show_meta=True))
rt_func_with_block = tvm.script.from_source(func_with_block.script(show_meta=True))
rt_mod = tvm.build(rt_func, "llvm")
rt_mod_with_block = tvm.build(rt_func_with_block, "llvm")
tvm.ir.assert_structural_equal(func, func_with_block)
tvm.ir.assert_structural_equal(rt_func, rt_func_with_block)
_check_alloc_zero_dim_buffer(rt_mod)
_check_alloc_zero_dim_buffer(rt_mod_with_block)


if __name__ == "__main__":
test_get_valid_counts_script_func()
test_alloc_zero_dim_buffer_round_trip()

0 comments on commit 08018ea

Please sign in to comment.