Use local complete block and local reduction block to identify compact dataflow (#10705)

* inint

* upd

* upd

* remove redundant print

* upd

* change the reads/writes region for argmin/val

* fix wrong push
yzh119 authored Mar 23, 2022
1 parent d37c1d2 commit 0ddaaa6
Showing 4 changed files with 173 additions and 38 deletions.
9 changes: 2 additions & 7 deletions src/tir/schedule/analysis.h
@@ -168,16 +168,11 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl
/*!
* \brief Check the subtree compact dataflow property. The scope root may have one or more subtrees
* rooted at its direct children, and this property requires all the blocks of the subtree
* that the specified sref is in to be complete block or reduction block.
* that the specified sref is in to be local complete block or local reduction block.
* \param self The schedule state
* \param subtree_root The sref of the subtree root to be checked
* \param scope_root_sref The scope root of the block
* \throw ScheduleError If the subtree that the sref is in doesn't satisfy the compact
* dataflow condition, i.e. a block in the subtree is neither complete block nor
* reduction block
*/
void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root,
const StmtSRef& scope_root_sref);
void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root);
/*!
* \brief Check if the block is an output block, i.e. the block writes to at least a buffer that is
* not allocated under the current scope
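CheckSubtreeCompactDataflow now takes only the subtree root and validates the property against the subtree itself rather than against the scope root. A minimal TVMScript sketch of the pattern this relaxation is meant to accept (it mirrors the scatter_compute test added below; the function name two_writers and the block names are illustrative only):

from tvm.script import tir as T


@T.prim_func
def two_writers(A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"]) -> None:
    # Buffer B has two writer blocks under the root scope, so neither block is a
    # (global) complete block. Each block is, however, the only writer of B within
    # its own loop subtree, i.e. a local complete block, so each subtree on its own
    # still has compact dataflow under the new definition.
    for i in T.serial(8):
        with T.block("upper"):
            vi = T.axis.spatial(16, 8 + i)
            B[vi] = A[vi - 8]
    for i in T.serial(8):
        with T.block("lower"):
            vi = T.axis.spatial(16, i)
            B[vi] = A[vi + 8]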
109 changes: 81 additions & 28 deletions src/tir/schedule/analysis/analysis.cc
@@ -143,25 +143,53 @@ ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) {
return std::move(visitor.result);
}

/*!
* \brief Check whether the given sref_a is higher than or equal to sref_b.
*/
void CheckSRefHigherOrEqual(const StmtSRef& sref_a, const StmtSRef& sref_b) {
const StmtSRefNode* p = sref_b.get();
for (; p != nullptr; p = p->parent) {
if (p == sref_a.get()) {
return;
}
}
CHECK(false) << "Expect StmtSRef " << sref_a << "to be higher than or equal to " << sref_b;
}

/*!
* \brief Check the dominant property of a block:
* the block is the only writer of its output, dominating the reader of its output buffers
* \param scope The block-scope of the block to be checked
* \param block_sref The block whose dominant property is to be checked
* \return A boolean indicating if the block is a dominant block
* the block is the only writer of its output, dominating the reader of its output buffers under the
* given root scope.
* \param self The schedule state.
* \param scope_root_sref The StmtSRef corresponding to the root scope.
* \param block_sref The block whose dominant property is to be checked.
* \return A boolean indicating if the block is a dominant block.
*/
bool IsDominantBlock(const BlockScope& scope, const StmtSRef& block_sref) {
bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref,
const StmtSRef& block_sref) {
std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
CheckSRefHigherOrEqual(scope_root_sref, block_sref);
const BlockNode* maybe_root_block = scope_root_sref->StmtAs<BlockNode>();
if (maybe_root_block) {
BlockScope scope = self->GetBlockScope(scope_root_sref);
buffer_writers = scope->buffer_writers;
} else {
// Collect all child blocks of root sub-tree, and merge their buffer writers.
Array<StmtSRef> child_block_srefs = GetChildBlockSRefOnSRefTree(self, scope_root_sref);
for (const StmtSRef& child_block_sref : child_block_srefs) {
BlockScope child_scope = self->GetBlockScope(child_block_sref);
for (const auto& it : child_scope->buffer_writers) {
buffer_writers.insert(it);
}
}
}
// Check whether the input block is the only writer of its outputs
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual>& buffer_writers =
scope->buffer_writers;
for (const BufferRegion& write_region : block->writes) {
ICHECK(buffer_writers.count(write_region->buffer))
<< "InternalError: buffer \"" << write_region->buffer->name
<< "\" does not exist in the current scope, when querying block:\n"
<< GetRef<Block>(block);
if (buffer_writers.at(write_region->buffer).size() != 1) {
return false;
if (buffer_writers.count(write_region->buffer)) {
if (buffer_writers.at(write_region->buffer).size() != 1) {
return false;
}
}
}
return true;
@@ -178,7 +206,6 @@ bool IsDominantBlock(const BlockScope& scope, const StmtSRef& block_sref) {
*/
int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref) {
BlockScope scope = self->GetBlockScope(scope_root_sref);
// Cond 1. All block vars are data parallel
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
for (const IterVar& iter_var : block->iter_vars) {
@@ -188,7 +215,7 @@ int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block
}
// Cond 2. Dominant: the block is the only writer of its output,
// dominating the reader of its output buffers
if (!IsDominantBlock(scope, block_sref)) {
if (!IsDominantBlock(self, scope_root_sref, block_sref)) {
return 2;
}
// Cond 3. No overlap between the buffers the block reads and writes
@@ -217,6 +244,18 @@ static const char* kReductionBlockDefinition = R"(Definition of a reduction bloc
4) Dominant: the block is the only writer of its output, dominating the reader of its output buffers
5) The reduction block vars are not used to index the output buffers)";

static const char* kLocalCompleteBlockDefinition = R"(Definition of a local complete block:
1) All block vars are data parallel
2) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
3) No overlap between the buffers the block reads and writes)";

static const char* kLocalReductionBlockDefinition = R"(Definition of a local reduction block:
1) The block has the `init` statement
2) All the block bindings are quasi-affine expressions
3) All block vars are either data parallel block vars or reduction block vars
4) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
5) The reduction block vars are not used to index the output buffers)";

bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref) {
return CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref) == 0;
@@ -260,7 +299,6 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
*/
int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref) {
BlockScope scope = self->GetBlockScope(scope_root_sref);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
// Cond 1. The block has the `init` statement.
if (!block->init.defined()) {
@@ -277,7 +315,7 @@ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& bloc
}
// Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its
// output buffers.
if (!IsDominantBlock(scope, block_sref)) {
if (!IsDominantBlock(self, scope_root_sref, block_sref)) {
return 4;
}
// Cond 5. The reduction block vars are not used to index the output buffers.
@@ -363,40 +401,55 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl
reduction_block_error_code);
}

void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root,
const StmtSRef& scope_root_sref) {
void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root) {
class NotCompactDataFlowError : public ScheduleError {
public:
explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block)
explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block,
int local_complete_block_code, int local_reduction_block_code)
: mod_(std::move(mod)),
subtree_root_(std::move(subtree_root)),
violate_block_(std::move(violate_block)) {
violate_block_(std::move(violate_block)),
local_complete_block_code_(local_complete_block_code),
local_reduction_block_code_(local_reduction_block_code) {
ICHECK(subtree_root_->IsInstance<BlockNode>() || subtree_root_->IsInstance<ForNode>());
}
String FastErrorString() const final {
return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, "
"because some of its child block on SRef tree is neither a complete block nor a "
"reduction block";
"because some of its child block on SRef tree is neither a local complete block nor a "
"local reduction block.";
}
String DetailRenderTemplate() const final {
return "The queried subtree root {0} in SRef tree does not have compact dataflow, because "
"its child block {1} on SRef tree is neither a complete block nor a reduction block";
std::ostringstream os;
os << "The queried subtree root {0} in SRef tree does not have compact dataflow, because "
"its child block {1} on SRef tree is neither a local complete block nor a local "
"reduction block.\n";
os << "It violates condition #" << local_complete_block_code_
<< " as a local complete block.\n";
os << kLocalCompleteBlockDefinition << "\n";
os << "It violates condition #" << local_reduction_block_code_
<< " as a local reduction block.\n";
os << kLocalReductionBlockDefinition << "\n";
return os.str();
}
IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {subtree_root_, violate_block_}; }

IRModule mod_;
Stmt subtree_root_;
Block violate_block_;
int local_complete_block_code_;
int local_reduction_block_code_;
};

Array<StmtSRef> child_block_srefs = GetChildBlockSRefOnSRefTree(self, subtree_root);
for (const StmtSRef& block_sref : child_block_srefs) {
if (!IsCompleteBlock(self, block_sref, scope_root_sref) &&
!IsReductionBlock(self, block_sref, scope_root_sref)) {
int local_complete_block_code = CheckCompleteBlockErrorCode(self, block_sref, subtree_root),
local_reduction_block_code = CheckReductionBlockErrorCode(self, block_sref, subtree_root);
if (local_complete_block_code != 0 && local_reduction_block_code != 0) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(subtree_root->stmt),
GetRef<Block>(block));
GetRef<Block>(block), local_complete_block_code,
local_reduction_block_code);
}
}
}
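The effect of the new local definitions shows up at the schedule level. A hedged usage sketch (assuming the decomposed_gemm prim_func from the test file below is in scope): the "init" and "update" blocks both write the intermediate buffer local, so "init" is not a complete block under the root scope, but it is the only writer of local inside its own ii/jj loop subtree, i.e. a local complete block, which is what now lets the vectorization go through:

from tvm import tir

# decomposed_gemm is the existing prim_func used by the tests below.
sch = tir.Schedule(decomposed_gemm, debug_mask="all")
init_block = sch.get_block("init")
_, _, ii, jj = sch.get_loops(init_block)  # loops above "init": i, j, ii, jj
sch.vectorize(jj)  # accepted because "init" is a local complete block in its subtree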
4 changes: 1 addition & 3 deletions src/tir/schedule/primitive/for_kind.cc
@@ -157,9 +157,7 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref
* parallelized/vectorized/bound.
*/
// Step 1. Check whether the subtree rooted from the `loop` in sref tree has compact data flow.
StmtSRef scope_root_sref = GetScopeRoot(self, loop_sref,
/*require_stage_pipeline=*/true);
CheckSubtreeCompactDataflow(self, loop_sref, scope_root_sref);
CheckSubtreeCompactDataflow(self, loop_sref);

// Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each
// underlying block.
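With the GetScopeRoot(..., require_stage_pipeline=true) call dropped, parallelization only requires the subtree rooted at the loop to have compact dataflow. A short usage sketch (assuming the scatter_compute prim_func added in the test file below), corresponding to the new test_scatter_parallelize test:

from tvm import tir

sch = tir.Schedule(scatter_compute, debug_mask="all")
(i_first,) = sch.get_loops(sch.get_block("first_half"))
(i_last,) = sch.get_loops(sch.get_block("last_half"))
# Each block is the only writer of B within its own loop subtree (a local complete
# block), so both loops can be parallelized even though B has two writers in the scope.
sch.parallel(i_first)
sch.parallel(i_last)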
89 changes: 89 additions & 0 deletions tests/python/unittest/test_tir_schedule_for_kind.py
@@ -330,6 +330,72 @@ def decomposed_gemm_after_vectorize(
C[vi, vj] = local[vi, vj]


@T.prim_func
def decomposed_gemm_parallelize_init(
A: T.Buffer[(16, 16), "float32"],
B: T.Buffer[(16, 16), "float32"],
C: T.Buffer[(16, 16), "float32"],
) -> None:
local = T.alloc_buffer([16, 16], dtype="float32")
for i, j in T.grid(4, 4):
for ii in T.serial(4):
for jj in T.vectorized(4):
with T.block("init"):
vi = T.axis.spatial(16, i * 4 + ii)
vj = T.axis.spatial(16, j * 4 + jj)
T.reads()
T.writes(local[vi, vj])
local[vi, vj] = 0
for k, ii, jj in T.grid(16, 4, 4):
with T.block("update"):
vi = T.axis.spatial(16, i * 4 + ii)
vj = T.axis.spatial(16, j * 4 + jj)
vk = T.axis.reduce(16, k)
T.reads(local[vi, vj], A[vi, vk], B[vj, vk])
T.writes(local[vi, vj])
local[vi, vj] = local[vi, vj] + A[vi, vk] * B[vj, vk]
for ii, jj in T.grid(4, 4):
with T.block("C"):
vi = T.axis.spatial(16, i * 4 + ii)
vj = T.axis.spatial(16, j * 4 + jj)
T.reads(local[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = local[vi, vj]


@T.prim_func
def scatter_compute(A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"]):
for i in T.grid(8):
with T.block("first_half"):
vi = T.axis.spatial(16, 8 + i)
B[vi] = A[vi - 8]

for i in T.grid(8):
with T.block("last_half"):
vi = T.axis.spatial(16, i)
B[vi] = A[vi + 8]


@T.prim_func
def scatter_compute_parallelize(
A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"]
) -> None:
# body
# with T.block("root")
for i in T.parallel(8):
with T.block("first_half"):
vi = T.axis.spatial(16, 8 + i)
T.reads(A[vi - 8])
T.writes(B[vi])
B[vi] = A[vi - 8]
for i in T.parallel(8):
with T.block("last_half"):
vi = T.axis.spatial(16, i)
T.reads(A[vi + 8])
T.writes(B[vi])
B[vi] = A[vi + 8]


# pylint: enable=no-member,invalid-name,unused-variable


@@ -468,5 +534,28 @@ def test_vectorize_after_decompose():
verify_trace_roundtrip(s, mod=decomposed_gemm)


def test_vectorize_init():
s = tir.Schedule(decomposed_gemm, debug_mask="all")
init_blk = s.get_block("init")
upd_blk = s.get_block("update")
_, _, ii_0, jj_0 = s.get_loops(init_blk)
_, _, k_1, ii_1, jj_1 = s.get_loops(upd_blk)
s.vectorize(jj_0)
tvm.ir.assert_structural_equal(s.mod["main"], decomposed_gemm_parallelize_init)
verify_trace_roundtrip(s, mod=decomposed_gemm)


def test_scatter_parallelize():
s = tir.Schedule(scatter_compute, debug_mask="all")
first = s.get_block("first_half")
last = s.get_block("last_half")
(i_0,) = s.get_loops(first)
(i_1,) = s.get_loops(last)
s.parallel(i_0)
s.parallel(i_1)
tvm.ir.assert_structural_equal(s.mod["main"], scatter_compute_parallelize)
verify_trace_roundtrip(s, mod=scatter_compute)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
