91 changes: 61 additions & 30 deletions src/tir/analysis/buffer_access_lca_detector.cc
@@ -117,10 +117,13 @@ class LCADetector : public StmtExprVisitor {

ancestor_scopes_.push_back(current_scope);

// For each accessed buffer of the block, update the buffer's lca to
// For each accessed buffer of the block
// If it accesses the opaque block iter vars, update the buffer's lca to
// the lowest inclusive stmt position, which should dominate all loops
// related to the accessed opaque block iter vars in buffer indices.
UpdateDominateScopeOfOpaqueIter(op);
// related to the accessed opaque block iter vars.
// If it is the reduction block's write buffer, update the buffer's lca to
// dominate all loops related to the reduction iter vars.
UpdateDominateScopeOfNonDataParIter(op);

// Update match_buffers
for (const MatchBufferRegion& match_buffer : block->match_buffers) {
@@ -132,43 +135,70 @@
ancestor_scopes_.pop_back();
}

void UpdateDominateScopeOfOpaqueIter(const BlockRealizeNode* block_realize) {
// map opaque iter var to the scope which dominate all loop carried dependencies.
std::unordered_map<const VarNode*, const ScopeInfo*> itervar_to_dom_scope;
void UpdateDominateScopeOfNonDataParIter(const BlockRealizeNode* block_realize) {
// map iter var to the scope which dominates all loop carried dependencies.
std::unordered_map<const VarNode*, const ScopeInfo*> opaque_var_scope;
// maintain the highest scope which dominates all reduce loop iters. null denotes a non-reduce block.
const ScopeInfo* highest_reduce_scope = nullptr;

// function to collect `itervar_to_dom_scope`, the result scope for each block
// iter var should be above all loop scopes the opaque iter var binding relates to.
auto do_collect_itervar_scope = [this, &itervar_to_dom_scope](const IterVar& itervar,
const PrimExpr& binding) {
PostOrderVisit(binding, [this, &itervar_to_dom_scope, &itervar](const ObjectRef& obj) {
auto do_collect_itervar_scope = [this](const IterVar& itervar,
const PrimExpr& binding) -> const ScopeInfo* {
const ScopeInfo* highest_scope = nullptr;
PostOrderVisit(binding, [this, &itervar, &highest_scope](const ObjectRef& obj) {
if (const VarNode* loop_var = obj.as<VarNode>()) {
auto it = loop_scope_map_.find(loop_var);
if (it == loop_scope_map_.end()) {
return;
}
const ScopeInfo* scope = it->second->parent_scope_info;
// find the highest loop scope the iter var binding relates to.
auto dom_scope_it = itervar_to_dom_scope.find(itervar->var.get());
if (dom_scope_it == itervar_to_dom_scope.end()) {
itervar_to_dom_scope.insert(dom_scope_it, {itervar->var.get(), scope});
} else if (scope->depth < dom_scope_it->second->depth) {
dom_scope_it->second = scope;
if (highest_scope == nullptr) {
highest_scope = scope;
} else if (scope->depth < highest_scope->depth) {
highest_scope = scope;
}
}
});
return highest_scope;
};

// collect the dominating scope of each non-data-parallel block iter.
// for reduction iters, maintain the highest dominating scope over all reduce iters.
// for other iter types, maintain a dict entry for each individual iter.
const Block& block = block_realize->block;
bool is_reduce_block = false;
for (size_t i = 0; i < block_realize->iter_values.size(); ++i) {
const IterVar& iter_var = block->iter_vars[i];
if (iter_var->iter_type != IterVarType::kDataPar) {
const auto* scope = do_collect_itervar_scope(iter_var, block_realize->iter_values[i]);
if (scope == nullptr) continue;
if (iter_var->iter_type == IterVarType::kCommReduce) {
is_reduce_block = true;
if (highest_reduce_scope == nullptr || scope->depth < highest_reduce_scope->depth) {
highest_reduce_scope = scope;
}
} else {
opaque_var_scope[iter_var->var.get()] = scope;
for (const auto& write : block->writes) {
UpdateBufferLCA(write->buffer.get(), scope);
}
}
}
}

// function to update lca scope of the buffer with loop carried dependent buffer accesses.
// the result scope should be above all loop scopes the accessed opaque block iter vars
// relate to, which is recorded in `itervar_to_dom_scope`.
auto do_update = [this, &itervar_to_dom_scope](const BufferRegion& region) {
auto do_update = [this, &opaque_var_scope, highest_reduce_scope](const BufferRegion& region,
bool is_reduce_write = false) {
const Buffer& buffer = region->buffer;
const ScopeInfo* scope = ancestor_scopes_.back();

auto handle_itervar = [&itervar_to_dom_scope, &scope](const ObjectRef& obj) {
auto handle_itervar = [&opaque_var_scope, &scope](const ObjectRef& obj) {
if (const VarNode* iter_var = obj.as<VarNode>()) {
auto dom_scope_it = itervar_to_dom_scope.find(iter_var);
if (dom_scope_it == itervar_to_dom_scope.end()) {
auto dom_scope_it = opaque_var_scope.find(iter_var);
if (dom_scope_it == opaque_var_scope.end()) {
return;
}
// find the highest loop scope the accessed buffer index has
@@ -184,24 +214,25 @@
PostOrderVisit(range->min, handle_itervar);
PostOrderVisit(range->min + range->extent - 1, handle_itervar);
}

// for a reduction output buffer, raise the scope to `highest_reduce_scope` if it is deeper.
if (is_reduce_write && highest_reduce_scope != nullptr &&
scope->depth > highest_reduce_scope->depth) {
scope = highest_reduce_scope;
}
UpdateBufferLCA(buffer.get(), scope);
};

// do collect and update
const Block& block = block_realize->block;
for (size_t i = 0; i < block_realize->iter_values.size(); ++i) {
const IterVar& iter_var = block->iter_vars[i];
if (iter_var->iter_type != IterVarType::kDataPar &&
iter_var->iter_type != IterVarType::kCommReduce) {
do_collect_itervar_scope(iter_var, block_realize->iter_values[i]);
}
}
if (!itervar_to_dom_scope.empty()) {
if (!opaque_var_scope.empty()) {
for (const auto& read : block->reads) {
do_update(read);
}
for (const auto& write : block->writes) {
do_update(write);
do_update(write, /*is_reduce_write=*/is_reduce_block);
}
} else if (is_reduce_block && highest_reduce_scope != nullptr) {
for (const auto& write : block->writes) {
do_update(write, /*is_reduce_write=*/true);
}
}
}
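To illustrate the new rule outside of the diff: for a reduction block whose reduce axis is split across several loops, the detector is now expected to report an LCA for the block's write buffer that sits above all of the reduce-related loops, rather than the block itself. Below is a minimal sketch (not part of this PR; buffer names and extents are illustrative) that queries the analysis through the Python `detect_buffer_access_lca` API.

```python
# Minimal sketch; buffer names and extents are illustrative, not taken from this PR.
import tvm
from tvm.script import tir as T


@T.prim_func
def split_reduce(x: T.Buffer((128, 128), "float32"), y: T.Buffer((128,), "float32")):
    # Reduce axis of length 128 split into two loops (k0, k1).
    for i, k0, k1 in T.grid(128, 16, 8):
        with T.block("reduce"):
            vi = T.axis.spatial(128, i)
            vk = T.axis.reduce(128, k0 * 8 + k1)
            with T.init():
                y[vi] = T.float32(0)
            y[vi] = y[vi] + x[vi, vk]


lca = tvm.tir.analysis.detect_buffer_access_lca(split_reduce)
y_buf = split_reduce.buffer_map[split_reduce.params[1]]
# With this change, the LCA of the write buffer `y` is expected to be the loop
# over `i`, i.e. a statement dominating both reduce loops, rather than the
# reduction block itself.
print(lca[y_buf])
```

This mirrors the updated assertion in the analysis test below, where `lca[B]` moves from the reduction block to the loop that dominates the reduction loops.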
@@ -116,9 +116,10 @@ def test_buffer_load_store():
root_block = func.body.block
assert lca[A] == func.body.block

# LCA of Buffer B is reduction block
reduce_block = root_block.body[1].body.body.body.block
assert lca[B] == reduce_block
# LCA of Buffer B is the loop that dominates all reduction loops
reduce_dom_loop = root_block.body[1].body
reduce_block = reduce_dom_loop.body.body.block
assert lca[B] == reduce_dom_loop

# LCA of Buffer C is the second loop kk
loop_jj = reduce_block.body.body
@@ -402,5 +402,55 @@ def before(dlpack_handle: T.handle, axis: T.int64) -> T.int64:
_check(before, after)


def test_reduce_buffer_dominate_reduce_loops():
"""Reduction write buffer allocation should dominate all reduce loops"""

@T.prim_func
def before(x: T.Buffer((256, 256, 256), "float32"), x_red: T.Buffer((256, 256), "float32")):
x_red_ = T.alloc_buffer((256, 256))
for ax0_0, k1_0, ax1_0 in T.grid(4, 4, 4):
for ax0_1, k1_1, ax1_1 in T.grid(64, 64, 64):
with T.block("x_red"):
v_ax0 = T.axis.spatial(256, ax0_0 * 64 + ax0_1)
v_ax1 = T.axis.spatial(256, ax1_0 * 64 + ax1_1)
v_k1 = T.axis.reduce(256, k1_0 * 64 + k1_1)
if v_k1 == 0:
x_red_[v_ax0, v_ax1] = T.float32(0.0)
x_red_[v_ax0, v_ax1] = x_red_[v_ax0, v_ax1] + x[v_ax0, v_k1, v_ax1]
for ax0, ax1 in T.grid(64, 64):
with T.block("x_red_"):
v0 = T.axis.spatial(256, ax0_0 * 64 + ax0)
v1 = T.axis.spatial(256, ax1_0 * 64 + ax1)
x_red[v0, v1] = x_red_[v0, v1]

@T.prim_func
def after(x: T.Buffer((256, 256, 256), "float32"), x_red: T.Buffer((256, 256), "float32")):
for ax0_0 in range(4):
with T.block(""):
T.reads(x[ax0_0 * 64 : ax0_0 * 64 + 64, 0:256, 0:256])
T.writes(x_red[ax0_0 * 64 : ax0_0 * 64 + 64, 0:256])
x_red_ = T.alloc_buffer((256, 256))
for k1_0, ax1_0 in T.grid(4, 4):
for ax0_1, k1_1, ax1_1 in T.grid(64, 64, 64):
with T.block("x_red"):
v_ax0 = T.axis.spatial(256, ax0_0 * 64 + ax0_1)
v_ax1 = T.axis.spatial(256, ax1_0 * 64 + ax1_1)
v_k1 = T.axis.reduce(256, k1_0 * 64 + k1_1)
T.reads(x_red_[v_ax0, v_ax1], x[v_ax0, v_k1, v_ax1])
T.writes(x_red_[v_ax0, v_ax1])
if v_k1 == 0:
x_red_[v_ax0, v_ax1] = T.float32(0.0)
x_red_[v_ax0, v_ax1] = x_red_[v_ax0, v_ax1] + x[v_ax0, v_k1, v_ax1]
for ax0, ax1 in T.grid(64, 64):
with T.block("x_red_"):
v0 = T.axis.spatial(256, ax0_0 * 64 + ax0)
v1 = T.axis.spatial(256, ax1_0 * 64 + ax1)
T.reads(x_red_[v0, v1])
T.writes(x_red[v0, v1])
x_red[v0, v1] = x_red_[v0, v1]

_check(before, after)
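For context, `_check` is a helper defined elsewhere in this test file and is not shown in the diff. A plausible sketch of what it does, assuming the pass under test is `tvm.tir.transform.PlanAndUpdateBufferAllocationLocation` (the pass that re-plans `T.alloc_buffer` locations from the detected LCAs; this is an assumption, not confirmed by the diff itself):

```python
# Hypothetical sketch of the `_check` helper; the real helper may differ.
import tvm


def _check(original, transformed):
    # Build an IRModule around the input PrimFunc and run the
    # buffer-allocation planning pass.
    func = original.with_attr("global_symbol", "main")
    mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(
        tvm.IRModule.from_expr(func)
    )
    # The result should match the expected PrimFunc structurally.
    expected = transformed.with_attr("global_symbol", "main")
    tvm.ir.assert_structural_equal(mod["main"], expected)
```

With the detector change above, the intermediate buffer `x_red_` is allocated under the `ax0_0` loop, which dominates both split reduce loops `k1_0` and `k1_1`, instead of under a deeper scope that would break the reduction across `k1_0` iterations.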


if __name__ == "__main__":
tvm.testing.main()