[BugFix][TIR] Fix Buffer LCA Detector (apache#12819)
Prior to this PR, the LCA detector for buffers in TIR did not take buffer memory scopes and the GPU thread hierarchy into consideration. A consequent issue was that when an intermediate buffer lives in global memory, TIR's lowering passes did not necessarily allocate it outside all `blockIdx` launch sites. As a result, a global intermediate buffer could be allocated under a GPU thread block, which is illegal.

This PR fixes the issue by making the LCA detector aware of buffer memory scopes and the GPU thread hierarchy. With this fix, global intermediate buffers are all allocated outside `blockIdx`.
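To make the failure mode concrete, here is a minimal sketch (not part of this commit; the function name, shapes, and schedule are illustrative) that builds the problematic pattern in TVMScript: an intermediate buffer in global memory whose accesses all sit under a single `blockIdx.x` launch, together with the LCA the fixed detector is expected to report.

from tvm import tir
from tvm.script import tir as T


@T.prim_func
def global_intermediate(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (32,), "float32")
    C = T.match_buffer(c, (32,), "float32")
    # Intermediate buffer in global memory (the default scope of alloc_buffer).
    B = T.alloc_buffer((32,), "float32")
    for bx in T.thread_binding(0, 2, thread="blockIdx.x"):
        for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
            with T.block("produce"):
                i = T.axis.spatial(32, bx * 16 + tx)
                B[i] = A[i] + T.float32(1)
            with T.block("consume"):
                i = T.axis.spatial(32, bx * 16 + tx)
                C[i] = B[i] * T.float32(2)


lca = tir.analysis.detect_buffer_access_lca(global_intermediate)
root_block = global_intermediate.body.block
B = root_block.alloc_buffers[0]
# Every access to B sits under the `threadIdx.x` loop, so the naive LCA is a
# scope inside the thread block. Because B is in global memory, the fixed
# detector widens its LCA to cover the `blockIdx.x` loop, letting lowering
# allocate B outside the GPU thread block.
assert lca[B] == root_block.body  # the loop bound to blockIdx.x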
MasterJH5574 authored and xinetzone committed Nov 25, 2022
1 parent 5f3f89b commit 62cee24
Showing 2 changed files with 70 additions and 1 deletion.
45 changes: 44 additions & 1 deletion src/tir/analysis/buffer_access_lca_detector.cc
@@ -25,14 +25,19 @@
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>

#include "../../runtime/thread_storage_scope.h"
#include "../../support/arena.h"

namespace tvm {
namespace tir {

/*!
 * \brief Detect the lowest common ancestor (LCA) position of Buffer access.
 * \note Only consider BlockNode and ForNode to be the LCA nodes.
 * \note
 * - Only consider BlockNode and ForNode to be the LCA nodes.
 * - In the LCA detector, we are aware of the buffer memory scopes and the CUDA hierarchy, so that
 * any buffer in global memory will have its buffer access LCA outside all launch sites of
 * `blockIdx`, in order to prevent conflicts between buffer memory scopes and the CUDA hierarchy.
 */
class LCADetector : public StmtExprVisitor {
 public:
@@ -51,6 +56,8 @@ class LCADetector : public StmtExprVisitor {
    detector.ancestor_scopes_.push_back(&root);

    detector(func->body);
    detector.UpdateWithBlockidx();

    // Prepare the return
    Map<Buffer, Optional<Stmt>> buffer_lca;
    for (const auto& kv : detector.buffer_lca_) {
@@ -82,6 +89,15 @@
    int n = ancestor_scopes_.size();
    const ScopeInfo* parent_scope = ancestor_scopes_.back();
    auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);

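    // Record loops that are bound to `blockIdx` (thread rank 0): the LCA of any buffer in global
    // memory must later be widened to cover these scopes.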
    if (op->thread_binding.defined()) {
      const runtime::ThreadScope& scope =
          runtime::ThreadScope::Create(op->thread_binding.value()->thread_tag);
      if (scope.rank == 0) {
        blockidx_scopes_.push_back(current_scope);
      }
    }

    ancestor_scopes_.push_back(current_scope);
    StmtExprVisitor::VisitStmt_(op);
    ancestor_scopes_.pop_back();
@@ -107,6 +123,18 @@ class LCADetector : public StmtExprVisitor {
    ancestor_scopes_.pop_back();
  }

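  // Low-level TIR launches GPU threads with `thread_extent` AttrStmts instead of thread-bound
  // loops, so `blockIdx` launch sites are recorded here as well.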
  void VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent) {
      const auto* iter = op->node.as<IterVarNode>();
      ICHECK_NOTNULL(iter);
      const runtime::ThreadScope& scope = runtime::ThreadScope::Create(iter->thread_tag);
      if (scope.rank == 0) {
        blockidx_scopes_.push_back(ancestor_scopes_.back());
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    UpdateBufferLCA(op->buffer.get());
    StmtExprVisitor::VisitExpr_(op);
@@ -150,6 +178,19 @@
    }
  }

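  /*! \brief Widen the LCA of each buffer in global memory to also cover every recorded `blockIdx`
   * scope, so that the buffer is never allocated inside a GPU thread block. */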
  void UpdateWithBlockidx() {
    for (const auto& it : buffer_lca_) {
      const runtime::StorageScope& scope =
          runtime::StorageScope::Create(GetRef<Buffer>(it.first).scope());
      if (scope.rank == runtime::StorageRank::kGlobal) {
        const ScopeInfo*& lca = buffer_lca_[it.first];
        for (const ScopeInfo* blockidx_scope : blockidx_scopes_) {
          lca = LowestCommonAncestor(lca, blockidx_scope);
        }
      }
    }
  }

  static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) {
    if (lhs == nullptr) return rhs;
    if (rhs == nullptr) return lhs;
@@ -186,6 +227,8 @@ class LCADetector : public StmtExprVisitor {
  std::unordered_map<const VarNode*, const BufferNode*> buffer_var_map_ = {};
  /*! \brief The match buffers inside blocks. */
  std::unordered_set<const BufferNode*> match_buffers_ = {};
  /*! \brief The ForNodes/BlockNodes which contain immediate `blockIdx` launch. */
  std::vector<const ScopeInfo*> blockidx_scopes_ = {};
  /*! \brief Internal arena. */
  support::Arena arena_;
};
26 changes: 26 additions & 0 deletions
@@ -93,6 +93,19 @@ def match_buffer_func(a: T.handle, b: T.handle) -> None:
            T.evaluate(B1.data)


@T.prim_func
def global_buffer_with_blockidx(
    a: T.Buffer[(1, 32), "int32"], b: T.Buffer[(1, 32), "int32"]
) -> None:
    for i0 in T.thread_binding(0, 1, thread="blockIdx.x"):
        for i1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("copy"):
                i, j = T.axis.remap("SS", [i0, i1])
                T.reads(a[i, j])
                T.writes(b[i, j])
                b[i, j] = a[i, j]


def test_buffer_load_store():
    func = buffer_load_store_func
    A, B = [func.buffer_map[x] for x in func.params]
@@ -154,8 +167,21 @@ def test_match_buffer():
    assert lca[B] == block


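# Regression test for the LCA fix: for global buffers accessed under a GPU
# thread hierarchy, the reported LCA is the `blockIdx`-bound loop rather than
# a scope inside the thread block.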
def test_global_buffer_with_blockidx():
    func = global_buffer_with_blockidx
    A, B = [func.buffer_map[x] for x in func.params]
    lca = tir.analysis.detect_buffer_access_lca(func)

    root_block = func.body.block
    blockidx_loop = root_block.body
    # LCA of both A and B should be the loop bound to `blockIdx`
    assert lca[A] == blockidx_loop
    assert lca[B] == blockidx_loop


if __name__ == "__main__":
    test_buffer_load_store()
    test_opaque_access()
    test_lca_func_root()
    test_match_buffer()
    test_global_buffer_with_blockidx()
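As a side note (an illustrative sketch, not part of this commit): during lowering, `PlanAndUpdateBufferAllocationLocation` is the pass that, to my understanding, consumes this LCA analysis to decide where each buffer is allocated, so the effect of the fix can be inspected by running the pass directly, e.g. on the hypothetical `global_intermediate` function sketched in the commit description above.

import tvm

# With the fixed detector, the global intermediate buffer should stay
# allocated at the root block, outside the `blockIdx`-bound loop.
mod = tvm.IRModule({"main": global_intermediate})
mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
print(mod.script())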
