Skip to content

Commit

Permalink
[TensorIR] Fix parser autocompletion mode (#7737)
Browse files Browse the repository at this point in the history
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
  • Loading branch information
Hzfengsy and MasterJH5574 authored Mar 24, 2021
1 parent cfe2e28 commit 3ba5868
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 10 deletions.
26 changes: 20 additions & 6 deletions src/tir/ir/script/script_complete.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,15 @@ namespace tir {
/*! \brief Generate surrounding loops automatically */
class ScriptCompleter : public StmtMutator {
public:
explicit ScriptCompleter(Map<Var, Buffer>* buffer_var_map) : buffer_var_map_(buffer_var_map) {}
explicit ScriptCompleter(Map<Var, Buffer>* buffer_var_map, bool contain_root)
: buffer_var_map_(buffer_var_map), contain_root_(contain_root) {}
/*! \brief Whether the stmt contains at least one block. */
bool contains_block = false;

private:
Map<Var, Buffer>* buffer_var_map_;
bool contain_root_;
bool visited_root_ = false;
Stmt VisitStmt_(const BlockRealizeNode* op) override {
contains_block = true;
Stmt body = StmtMutator::VisitStmt_(op);
Expand All @@ -62,6 +65,8 @@ class ScriptCompleter : public StmtMutator {
}

Stmt VisitStmt_(const BlockNode* op) override {
bool is_root_block = contain_root_ && !visited_root_;
visited_root_ = true;
// Buffers allocated in the block can be accessed by its body.
for (const auto& alloc_buffer : op->alloc_buffers) {
buffer_var_map_->Set(alloc_buffer->data, alloc_buffer);
Expand All @@ -71,7 +76,15 @@ class ScriptCompleter : public StmtMutator {
for (const auto& alloc_buffer : op->alloc_buffers) {
buffer_var_map_->erase(alloc_buffer->data);
}
// ignore the root block and blocks which already have reads/writes regions
if (block->reads.empty() || block->writes.empty()) {
if (op->iter_vars.empty()) {
// non-root opaque block is not allowed
CHECK(is_root_block)
<< "ValueError: Can not auto detect buffer access region for an opaque block. Please "
"annotate the access region manually.";
return std::move(block);
}
auto access_region = GetBlockAccessRegion(block, *buffer_var_map_);
const Array<BufferRegion>& reads = access_region[0];
const Array<BufferRegion>& writes = access_region[1];
Expand All @@ -80,8 +93,8 @@ class ScriptCompleter : public StmtMutator {
<< "ValueError: Can not auto detect buffer access region from tir.Load, tir.Store or "
"direct access by buffer data. Please annotation the access region manually";
auto n = CopyOnWrite(block.operator->());
if (!n->reads.defined()) n->reads = reads;
if (!n->writes.defined()) n->writes = writes;
if (n->reads.empty()) n->reads = reads;
if (n->writes.empty()) n->writes = writes;
return Block(n);
} else {
return std::move(block);
Expand All @@ -98,12 +111,13 @@ PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates) {
for (const auto& alloc : root_allocates) {
buffer_var_map.Set(alloc->data, alloc);
}
ScriptCompleter script_completer(&buffer_var_map);
bool contain_root = root_allocates.empty() && func->body->IsInstance<BlockRealizeNode>() &&
Downcast<BlockRealize>(func->body)->block->iter_vars.empty();
ScriptCompleter script_completer(&buffer_var_map, contain_root);
// generate surrounding loops automatically
Stmt res = script_completer(func->body);
// generate root block automatically
if (script_completer.contains_block &&
(!res->IsInstance<BlockRealizeNode>() || !root_allocates.empty())) {
if (script_completer.contains_block && !contain_root) {
res = Block({}, {}, {}, "root", res, NullOpt, root_allocates);
res = BlockRealize({}, Bool(true), Downcast<Block>(res));
}
Expand Down
174 changes: 174 additions & 0 deletions tests/python/unittest/test_tvmscript_complete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm import tir
from tvm.ir import Range
from tvm.script import ty, from_source
from tvm.ir.diagnostics import override_renderer


# TVMScript fixture: a 128x128 matmul-style reduction written as one bare block.
# No surrounding loops and no reads/writes annotations are written here; the
# test below (test_complete_matmul) checks the regions found on the parsed block.
@tvm.script.tir
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])

    # Two spatial axes (vi, vj) and one reduction axis (vk), each of extent 128.
    with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
        # Init statement of the block: zero the accumulator element.
        with tir.init():
            C[vi, vj] = tir.float32(0)
        # Note that B is indexed as [vj, vk].
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


# TVMScript fixture: tiled matmul with explicit outer loops and two blocks
# ("init" and "update"), neither of which carries reads/writes annotations.
# Each block touches a 4x4 tile of C addressed as [v * 4 + offset].
@tvm.script.tir
def matmul_original(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])

    for i, j in tir.grid(32, 32):
        # Zero one 4x4 tile of C; this block only writes.
        with tir.block([32, 32], "init") as [vi, vj]:
            for ii, jj in tir.grid(4, 4):
                C[vi * 4 + ii, vj * 4 + jj] = tir.float32(0)

        for k in range(0, 32):
            # Accumulate a 4x4x4 tile product into the same tile of C.
            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
                for ii, jj, kk in tir.grid(4, 4, 4):
                    C[vi * 4 + ii, vj * 4 + jj] = (
                        C[vi * 4 + ii, vj * 4 + jj]
                        + A[vi * 4 + ii, vk * 4 + kk] * B[vj * 4 + jj, vk * 4 + kk]
                    )


# TVMScript fixture: two element-wise blocks nested inside an explicitly
# written outer block that has no iter vars (`tir.block([])`).
# The inner blocks have no reads/writes annotations.
@tvm.script.tir
def elementwise_with_root(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])

    # Outer block without iter vars, written by hand.
    with tir.block([]) as []:
        # B = A + 1
        with tir.block([128, 128]) as [vi, vj]:
            B[vi, vj] = A[vi, vj] + tir.float32(1)

        # C = B + 1
        with tir.block([128, 128]) as [vi, vj]:
            C[vi, vj] = B[vi, vj] + tir.float32(1)


# Negative fixture: contains a nested block without iter vars and without
# reads/writes annotations. It carries no @tvm.script.tir decorator here;
# test_complete_opaque_block_error passes it to `from_source` and expects a
# DiagnosticError. NOTE(review): presumably left undecorated so that parsing
# happens inside the test rather than at import time -- confirm.
def func_with_opaque_block(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])

    # Outer block without iter vars.
    with tir.block([]) as []:
        # Inner block, also without iter vars and without access annotations.
        with tir.block([]) as []:
            B[0, 0] = A[0, 0] + tir.float32(1)

        with tir.block([128, 128]) as [vi, vj]:
            C[vi, vj] = B[vi, vj] + tir.float32(1)


def test_complete_matmul():
    """Check the reads/writes regions found on the "update" block of `matmul`."""
    prim_func = matmul
    buf_a, buf_b, buf_c = (prim_func.buffer_map[param] for param in prim_func.params)

    def point_region(buffer, *mins):
        # Region of extent 1 along every dimension, starting at the given mins.
        return tir.BufferRegion(buffer, [Range.from_min_extent(lo, 1) for lo in mins])

    # Walk from the function body down to the innermost block
    # (presumably: outer block, then three nested statements, then the realize).
    update_block = prim_func.body.block.body.body.body.body.block
    assert isinstance(update_block, tvm.tir.Block)
    vi, vj, vk = (iter_var.var for iter_var in update_block.iter_vars)

    region_a = point_region(buf_a, vi, vk)
    region_b = point_region(buf_b, vj, vk)
    region_c = point_region(buf_c, vi, vj)

    tvm.ir.assert_structural_equal(update_block.reads, [region_c, region_a, region_b])
    tvm.ir.assert_structural_equal(update_block.writes, [region_c])


def test_complete_matmul_original():
    """Check the reads/writes regions on both blocks of `matmul_original`."""
    prim_func = matmul_original
    buf_a, buf_b, buf_c = (prim_func.buffer_map[param] for param in prim_func.params)

    def tile_region(buffer, *block_vars):
        # 4-wide window starting at `var * 4` along every dimension.
        return tir.BufferRegion(buffer, [Range.from_min_extent(v * 4, 4) for v in block_vars])

    # First statement under the i/j loops: the "init" block.
    init_block = prim_func.body.block.body.body.body[0].block
    assert isinstance(init_block, tvm.tir.Block)
    vi, vj = (iter_var.var for iter_var in init_block.iter_vars)
    init_region_c = tile_region(buf_c, vi, vj)
    tvm.ir.assert_structural_equal(init_block.reads, [])
    tvm.ir.assert_structural_equal(init_block.writes, [init_region_c])

    # Second statement under the i/j loops: the k loop wrapping the "update" block.
    update_block = prim_func.body.block.body.body.body[1].body.block
    assert isinstance(update_block, tvm.tir.Block)
    vi, vj, vk = (iter_var.var for iter_var in update_block.iter_vars)
    region_a = tile_region(buf_a, vi, vk)
    region_b = tile_region(buf_b, vj, vk)
    region_c = tile_region(buf_c, vi, vj)
    tvm.ir.assert_structural_equal(update_block.reads, [region_c, region_a, region_b])
    tvm.ir.assert_structural_equal(update_block.writes, [region_c])


def test_complete_with_root():
    """Check the reads/writes regions on the two inner blocks of `elementwise_with_root`."""
    prim_func = elementwise_with_root
    buf_a, buf_b, buf_c = (prim_func.buffer_map[param] for param in prim_func.params)

    def point_region(buffer, *mins):
        # Region of extent 1 along every dimension, starting at the given mins.
        return tir.BufferRegion(buffer, [Range.from_min_extent(lo, 1) for lo in mins])

    # First inner block: reads A, writes B.
    first_block = prim_func.body.block.body[0].body.body.block
    assert isinstance(first_block, tvm.tir.Block)
    vi, vj = (iter_var.var for iter_var in first_block.iter_vars)

    tvm.ir.assert_structural_equal(first_block.reads, [point_region(buf_a, vi, vj)])
    tvm.ir.assert_structural_equal(first_block.writes, [point_region(buf_b, vi, vj)])

    # Second inner block: reads B, writes C.
    second_block = prim_func.body.block.body[1].body.body.block
    assert isinstance(second_block, tvm.tir.Block)
    vi, vj = (iter_var.var for iter_var in second_block.iter_vars)

    tvm.ir.assert_structural_equal(second_block.reads, [point_region(buf_b, vi, vj)])
    tvm.ir.assert_structural_equal(second_block.writes, [point_region(buf_c, vi, vj)])


def test_complete_opaque_block_error():
    """Parsing `func_with_opaque_block` must raise a DiagnosticError."""

    def _silent_renderer(diagnostic):
        # Render nothing for reported diagnostics.
        return None

    override_renderer(_silent_renderer)

    diagnostic_raised = False
    try:
        from_source(func_with_opaque_block)
    except tvm.error.DiagnosticError:
        diagnostic_raised = True
    assert diagnostic_raised


if __name__ == "__main__":
    # Run every test in declaration order when executed as a script.
    for _run_test in (
        test_complete_matmul,
        test_complete_matmul_original,
        test_complete_with_root,
        test_complete_opaque_block_error,
    ):
        _run_test()
11 changes: 7 additions & 4 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2819,10 +2819,13 @@ def test_block_elements():
tvm.ir.assert_structural_equal(func, rt_func)

assert isinstance(rt_func.body.block, tir.stmt.Block)
assert isinstance(rt_func.body.block.body, tir.stmt.BufferStore)
assert isinstance(rt_func.body.block.init, tir.stmt.BufferStore)
assert len(rt_func.body.block.annotations) == 1
assert rt_func.body.block.annotations["attr_key"] == "attr_value"
assert isinstance(rt_func.body.block.body, tir.stmt.BlockRealize)
assert isinstance(rt_func.body.block.body.block, tir.stmt.Block)
block = rt_func.body.block.body.block
assert isinstance(block.body, tir.stmt.BufferStore)
assert isinstance(block.init, tir.stmt.BufferStore)
assert len(block.annotations) == 1
assert block.annotations["attr_key"] == "attr_value"


if __name__ == "__main__":
Expand Down

0 comments on commit 3ba5868

Please sign in to comment.