-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TensorIR][M1c] LCA detector (#7848)
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com> Co-authored-by: Junru Shao <junrushao1994@gmail.com>
- Loading branch information
1 parent
cc79e8f
commit 6aefc26
Showing
4 changed files
with
310 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tir/analysis/buffer_access_lca_detector.cc | ||
* \brief Detect the lowest common ancestor(LCA) of buffer access | ||
*/ | ||
|
||
#include <tvm/tir/analysis.h> | ||
#include <tvm/tir/stmt_functor.h> | ||
|
||
#include "../../support/arena.h" | ||
|
||
namespace tvm { | ||
namespace tir { | ||
|
||
/*! | ||
* \brief Detect the lowest common ancestor(LCA) position of Buffer access. | ||
* \note Only consider BlockNode and ForNode to be the LCA nodes. | ||
*/ | ||
class LCADetector : public StmtExprVisitor { | ||
public: | ||
static Map<Buffer, Stmt> Detect(const PrimFunc& func) { | ||
LCADetector detector; | ||
for (const auto& kv : func->buffer_map) { | ||
const Buffer& buffer = kv.second; | ||
detector.buffer_var_map_.emplace(buffer->data.get(), buffer.get()); | ||
} | ||
detector(func->body); | ||
// Prepare the return | ||
Map<Buffer, Stmt> buffer_lca; | ||
for (const auto& kv : detector.buffer_lca_) { | ||
buffer_lca.Set(GetRef<Buffer>(kv.first), GetRef<Stmt>(kv.second->stmt)); | ||
} | ||
return buffer_lca; | ||
} | ||
|
||
private: | ||
/*! | ||
* \brief The AST node information for querying LCA. | ||
* \note Only BlockNode and ForNode are considered, since they are the only statements whose | ||
* body can be a SeqStmt (the LCA of buffer access) in TensorIR. | ||
*/ | ||
struct ScopeInfo { | ||
// The parent scope info | ||
const ScopeInfo* parent_scope_info; | ||
// The parent scope stmt node | ||
const StmtNode* stmt; | ||
// The scope depth in the AST | ||
int depth; | ||
ScopeInfo(const ScopeInfo* parent_info, const StmtNode* stmt, int depth) | ||
: parent_scope_info(parent_info), stmt(stmt), depth(depth) {} | ||
}; | ||
|
||
void VisitStmt_(const ForNode* op) final { | ||
int n = ancestor_scopes_.size(); | ||
const ScopeInfo* parent_scope = ancestor_scopes_.back(); | ||
auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n); | ||
ancestor_scopes_.push_back(current_scope); | ||
StmtExprVisitor::VisitStmt_(op); | ||
ancestor_scopes_.pop_back(); | ||
} | ||
|
||
void VisitStmt_(const BlockNode* op) final { | ||
int n = ancestor_scopes_.size(); | ||
for (const Buffer& buf : op->alloc_buffers) { | ||
buffer_var_map_.emplace(buf->data.get(), buf.get()); | ||
} | ||
const ScopeInfo* parent_scope = ancestor_scopes_.back(); | ||
auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n); | ||
ancestor_scopes_.push_back(current_scope); | ||
StmtExprVisitor::VisitStmt_(op); | ||
ancestor_scopes_.pop_back(); | ||
} | ||
|
||
void VisitExpr_(const BufferLoadNode* op) final { | ||
UpdateBufferLCA(op->buffer.get()); | ||
StmtExprVisitor::VisitExpr_(op); | ||
} | ||
|
||
void VisitStmt_(const BufferStoreNode* op) final { | ||
UpdateBufferLCA(op->buffer.get()); | ||
StmtExprVisitor::VisitStmt_(op); | ||
} | ||
|
||
void VisitStmt_(const BufferRealizeNode* op) final { | ||
buffer_var_map_.emplace(op->buffer->data.get(), op->buffer.get()); | ||
StmtExprVisitor::VisitStmt_(op); | ||
} | ||
|
||
// Works for Load/Store and opaque access. | ||
void VisitExpr_(const VarNode* op) final { VisitBufferVar(op); } | ||
|
||
// Explict to visit buffer data in Load and Store node. | ||
void VisitExpr_(const LoadNode* op) final { | ||
ExprVisitor::VisitExpr_(op); | ||
VisitBufferVar(op->buffer_var.get()); | ||
} | ||
|
||
void VisitStmt_(const StoreNode* op) final { | ||
StmtVisitor::VisitStmt_(op); | ||
VisitBufferVar(op->buffer_var.get()); | ||
} | ||
|
||
void VisitBufferVar(const VarNode* op) { | ||
auto it = buffer_var_map_.find(op); | ||
if (it != buffer_var_map_.end()) { | ||
UpdateBufferLCA(it->second); | ||
} | ||
} | ||
|
||
void UpdateBufferLCA(const BufferNode* buffer) { | ||
const ScopeInfo*& lca = buffer_lca_[buffer]; | ||
lca = LowestCommonAncestor(lca, ancestor_scopes_.back()); | ||
} | ||
|
||
static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) { | ||
ICHECK(lhs || rhs); | ||
if (lhs == nullptr) return rhs; | ||
if (rhs == nullptr) return lhs; | ||
while (lhs->parent_scope_info != nullptr && // | ||
rhs->parent_scope_info != nullptr && // | ||
lhs != rhs) { | ||
if (lhs->depth == rhs->depth) { | ||
lhs = lhs->parent_scope_info; | ||
rhs = rhs->parent_scope_info; | ||
} else if (lhs->depth < rhs->depth) { | ||
rhs = rhs->parent_scope_info; | ||
} else { | ||
lhs = lhs->parent_scope_info; | ||
} | ||
} | ||
if (lhs->parent_scope_info == nullptr) { | ||
return lhs; | ||
} | ||
if (rhs->parent_scope_info == nullptr) { | ||
return rhs; | ||
} | ||
ICHECK(lhs == rhs); | ||
return lhs; | ||
} | ||
|
||
/*! \brief The ancestor scope stacks info (Block and For), initialized with Null. */ | ||
std::vector<const ScopeInfo*> ancestor_scopes_ = {nullptr}; | ||
/*! \brief The map from Buffer to its LCA ForNode/BlockNode. */ | ||
std::unordered_map<const BufferNode*, const ScopeInfo*> buffer_lca_ = {}; | ||
/*! \brief The map from Buffer data to the Buffer. */ | ||
std::unordered_map<const VarNode*, const BufferNode*> buffer_var_map_ = {}; | ||
/*! \brief Internal arena. */ | ||
support::Arena arena_; | ||
}; | ||
|
||
Map<Buffer, Stmt> DetectBufferAccessLCA(const PrimFunc& func) { return LCADetector::Detect(func); } | ||
|
||
TVM_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca").set_body_typed(DetectBufferAccessLCA); | ||
} // namespace tir | ||
} // namespace tvm |
107 changes: 107 additions & 0 deletions
107
tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
import tvm | ||
from tvm import tir | ||
from tvm.script import ty | ||
|
||
|
||
@tvm.script.tir | ||
def buffer_load_store_func(a: ty.handle, b: ty.handle) -> None: | ||
A = tir.match_buffer(a, (128, 128), "float32") | ||
B = tir.match_buffer(b, (128, 128), "float32") | ||
C = tir.alloc_buffer((128, 128), "float32") | ||
D = tir.alloc_buffer((128, 128), "float32") | ||
with tir.block([128, 128]) as [i, j]: | ||
A[i, j] = tir.float32(0) | ||
with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]: | ||
with tir.init(): | ||
for ii, jj in tir.grid(4, 4): | ||
B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] | ||
for ii, jj in tir.grid(4, 4): | ||
for kk in range(0, 4): | ||
B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] | ||
for kk in range(0, 4): | ||
B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] | ||
|
||
|
||
@tvm.script.tir | ||
def buffer_opaque_access(b: ty.handle, c: ty.handle) -> None: | ||
B = tir.match_buffer(b, [16, 16], "float32") | ||
C = tir.match_buffer(c, [16, 16], "float32") | ||
|
||
with tir.block([]): | ||
tir.reads([]) | ||
tir.writes(B[0:16, 0:16]) | ||
A = tir.allocate([256], "float32", "global") | ||
for i, j in tir.grid(16, 16): | ||
tir.store(A, i * 16 + j, 1) | ||
for i in range(0, 16): | ||
for j in range(0, 16): | ||
tir.evaluate(tir.load("float32", A, i * 16 + j)) | ||
for j in range(0, 16): | ||
tir.evaluate( | ||
tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, tir.float32(0), dtype="handle") | ||
) | ||
|
||
for i, j in tir.grid(16, 16): | ||
with tir.block([16, 16]) as [vi, vj]: | ||
tir.bind(vi, i) | ||
tir.bind(vj, j) | ||
C[vi, vj] = B[vi, vj] | ||
|
||
|
||
def test_buffer_load_store(): | ||
func = buffer_load_store_func | ||
A, B = [func.buffer_map[x] for x in func.params] | ||
C, D = func.body.block.alloc_buffers | ||
lca = tir.analysis.detect_buffer_access_lca(func) | ||
|
||
# LCA of Buffer A is root | ||
root_block = func.body.block | ||
assert lca[A] == func.body.block | ||
|
||
# LCA of Buffer B is reduction block | ||
reduce_block = root_block.body[1].body.body.body.block | ||
assert lca[B] == reduce_block | ||
|
||
# LCA of Buffer C is the second loop kk | ||
loop_jj = reduce_block.body.body | ||
assert lca[C] == loop_jj | ||
|
||
# LCA of Buffer D is loop jj | ||
loop_kk = loop_jj.body[1] | ||
assert lca[D] == loop_kk | ||
|
||
|
||
def test_opaque_access(): | ||
func = buffer_opaque_access | ||
B, C = [func.buffer_map[x] for x in func.params] | ||
lca = tir.analysis.detect_buffer_access_lca(func) | ||
|
||
# Cannot detect buffer A since it is define by low-level Allocate | ||
|
||
# LCA of Buffer B is root | ||
root_block = func.body.block | ||
assert lca[B] == func.body.block | ||
|
||
# LCA of Buffer C is the correspond block | ||
assert lca[C] == root_block.body[1].body.body.block | ||
|
||
|
||
if __name__ == "__main__": | ||
test_buffer_load_store() | ||
test_opaque_access() |