From 6aefc262b3244190a676919344368f514200e62d Mon Sep 17 00:00:00 2001
From: Siyuan Feng
Date: Sat, 17 Apr 2021 04:52:40 +0800
Subject: [PATCH] [TensorIR][M1c] LCA detector (#7848)

Co-authored-by: Ruihang Lai
Co-authored-by: Junru Shao
---
 include/tvm/tir/analysis.h                    |   9 +
 python/tvm/tir/analysis/analysis.py           |  22 ++-
 .../analysis/buffer_access_lca_detector.cc    | 173 ++++++++++++++++++
 ...t_tir_analysis_detect_buffer_access_lca.py | 107 +++++++++++
 4 files changed, 310 insertions(+), 1 deletion(-)
 create mode 100644 src/tir/analysis/buffer_access_lca_detector.cc
 create mode 100644 tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py

diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 250a84e782a2..9282d6412d49 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -178,6 +178,15 @@ Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
  */
 TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr);
 
+/*!
+ * \brief Detect the lowest common ancestor(LCA) of buffer access, including both high-level
+ *        access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access).
+ *        The LCA may be a For loop or a Block.
+ * \param func The PrimFunc to be detected.
+ * \return The Map from buffer to the LCA of all access to it.
+ */
+TVM_DLL Map<Buffer, Stmt> DetectBufferAccessLCA(const PrimFunc& func);
+
 // Pass variants of verification analysis
 // directly throws RuntimeError when verification fails.
 namespace transform {
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 829eb8bbdedb..a462853f9d55 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -16,8 +16,10 @@
 # under the License.
 """Wrapping existing analysis utils."""
 # pylint: disable=invalid-name
-
+from typing import Dict
 from . import _ffi_api
+from ..function import PrimFunc
+from .. import Buffer, Stmt
 
 
 def expr_deep_equal(lhs, rhs):
@@ -129,3 +131,21 @@ def get_block_access_region(block, buffer_var_map):
             - third: opaque regions
     """
     return _ffi_api.get_block_access_region(block, buffer_var_map)
+
+
+def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
+    """Detect the lowest common ancestor(LCA) of buffer access, including both high-level
+    access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access).
+    The LCA may be a For loop or a Block.
+
+    Parameters
+    ----------
+    func: tvm.tir.PrimFunc
+        The function to be detected.
+
+    Returns
+    -------
+    result : Dict[Buffer, Stmt]
+        Map from buffer to the LCA of all access to it.
+    """
+    return _ffi_api.detect_buffer_access_lca(func)  # pylint: disable=no-member
diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc
new file mode 100644
index 000000000000..23e60e16fc62
--- /dev/null
+++ b/src/tir/analysis/buffer_access_lca_detector.cc
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/buffer_access_lca_detector.cc
+ * \brief Detect the lowest common ancestor(LCA) of buffer access
+ */
+
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../support/arena.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Detect the lowest common ancestor(LCA) position of Buffer access.
+ * \note Only consider BlockNode and ForNode to be the LCA nodes.
+ */
+class LCADetector : public StmtExprVisitor {
+ public:
+  static Map<Buffer, Stmt> Detect(const PrimFunc& func) {
+    LCADetector detector;
+    for (const auto& kv : func->buffer_map) {
+      const Buffer& buffer = kv.second;
+      detector.buffer_var_map_.emplace(buffer->data.get(), buffer.get());
+    }
+    detector(func->body);
+    // Convert the raw-pointer bookkeeping into the returned Map<Buffer, Stmt>
+    Map<Buffer, Stmt> buffer_lca;
+    for (const auto& kv : detector.buffer_lca_) {
+      buffer_lca.Set(GetRef<Buffer>(kv.first), GetRef<Stmt>(kv.second->stmt));
+    }
+    return buffer_lca;
+  }
+
+ private:
+  /*!
+   * \brief The AST node information for querying LCA.
+   * \note Only BlockNode and ForNode are considered, since they are the only statements whose
+   *       body can be a SeqStmt (the LCA of buffer access) in TensorIR.
+   */
+  struct ScopeInfo {
+    // The info of the enclosing scope (nullptr for the outermost For/Block)
+    const ScopeInfo* parent_scope_info;
+    // The stmt node (ForNode or BlockNode) that opens this scope
+    const StmtNode* stmt;
+    // The scope depth in the AST (the outermost For/Block has depth 1)
+    int depth;
+    ScopeInfo(const ScopeInfo* parent_info, const StmtNode* stmt, int depth)
+        : parent_scope_info(parent_info), stmt(stmt), depth(depth) {}
+  };
+
+  void VisitStmt_(const ForNode* op) final {
+    int n = ancestor_scopes_.size();
+    const ScopeInfo* parent_scope = ancestor_scopes_.back();
+    auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);
+    ancestor_scopes_.push_back(current_scope);
+    StmtExprVisitor::VisitStmt_(op);
+    ancestor_scopes_.pop_back();
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    int n = ancestor_scopes_.size();
+    for (const Buffer& buf : op->alloc_buffers) {
+      buffer_var_map_.emplace(buf->data.get(), buf.get());
+    }
+    const ScopeInfo* parent_scope = ancestor_scopes_.back();
+    auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);
+    ancestor_scopes_.push_back(current_scope);
+    StmtExprVisitor::VisitStmt_(op);
+    ancestor_scopes_.pop_back();
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    UpdateBufferLCA(op->buffer.get());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) final {
+    UpdateBufferLCA(op->buffer.get());
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const BufferRealizeNode* op) final {
+    buffer_var_map_.emplace(op->buffer->data.get(), op->buffer.get());
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  // Works for Load/Store and opaque access.
+  void VisitExpr_(const VarNode* op) final { VisitBufferVar(op); }
+
+  // Explicitly visit the buffer data var of Load and Store nodes.
+  void VisitExpr_(const LoadNode* op) final {
+    ExprVisitor::VisitExpr_(op);
+    VisitBufferVar(op->buffer_var.get());
+  }
+
+  void VisitStmt_(const StoreNode* op) final {
+    StmtVisitor::VisitStmt_(op);
+    VisitBufferVar(op->buffer_var.get());
+  }
+
+  void VisitBufferVar(const VarNode* op) {
+    auto it = buffer_var_map_.find(op);
+    if (it != buffer_var_map_.end()) {
+      UpdateBufferLCA(it->second);
+    }
+  }
+
+  void UpdateBufferLCA(const BufferNode* buffer) {
+    const ScopeInfo*& lca = buffer_lca_[buffer];
+    lca = LowestCommonAncestor(lca, ancestor_scopes_.back());
+  }
+
+  static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) {
+    ICHECK(lhs || rhs);
+    if (lhs == nullptr) return rhs;
+    if (rhs == nullptr) return lhs;
+    while (lhs->parent_scope_info != nullptr &&  //
+           rhs->parent_scope_info != nullptr &&  //
+           lhs != rhs) {
+      if (lhs->depth == rhs->depth) {
+        lhs = lhs->parent_scope_info;
+        rhs = rhs->parent_scope_info;
+      } else if (lhs->depth < rhs->depth) {
+        rhs = rhs->parent_scope_info;
+      } else {
+        lhs = lhs->parent_scope_info;
+      }
+    }
+    if (lhs->parent_scope_info == nullptr) {
+      return lhs;
+    }
+    if (rhs->parent_scope_info == nullptr) {
+      return rhs;
+    }
+    ICHECK(lhs == rhs);
+    return lhs;
+  }
+
+  /*! \brief The ancestor scope stacks info (Block and For), initialized with Null. */
+  std::vector<const ScopeInfo*> ancestor_scopes_ = {nullptr};
+  /*! \brief The map from Buffer to its LCA ForNode/BlockNode. */
+  std::unordered_map<const BufferNode*, const ScopeInfo*> buffer_lca_ = {};
+  /*! \brief The map from Buffer data to the Buffer. */
+  std::unordered_map<const VarNode*, const BufferNode*> buffer_var_map_ = {};
+  /*! \brief Internal arena. */
+  support::Arena arena_;
+};
+
+Map<Buffer, Stmt> DetectBufferAccessLCA(const PrimFunc& func) { return LCADetector::Detect(func); }
+
+TVM_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca").set_body_typed(DetectBufferAccessLCA);
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
new file mode 100644
index 000000000000..7ac61a705fbd
--- /dev/null
+++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import tir
+from tvm.script import ty
+
+
+@tvm.script.tir
+def buffer_load_store_func(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128), "float32")
+    B = tir.match_buffer(b, (128, 128), "float32")
+    C = tir.alloc_buffer((128, 128), "float32")
+    D = tir.alloc_buffer((128, 128), "float32")
+    with tir.block([128, 128]) as [i, j]:
+        A[i, j] = tir.float32(0)
+    with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]:
+        with tir.init():
+            for ii, jj in tir.grid(4, 4):
+                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
+        for ii, jj in tir.grid(4, 4):
+            for kk in range(0, 4):
+                B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk]
+            for kk in range(0, 4):
+                B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk]
+
+
+@tvm.script.tir
+def buffer_opaque_access(b: ty.handle, c: ty.handle) -> None:
+    B = tir.match_buffer(b, [16, 16], "float32")
+    C = tir.match_buffer(c, [16, 16], "float32")
+
+    with tir.block([]):
+        tir.reads([])
+        tir.writes(B[0:16, 0:16])
+        A = tir.allocate([256], "float32", "global")
+        for i, j in tir.grid(16, 16):
+            tir.store(A, i * 16 + j, 1)
+        for i in range(0, 16):
+            for j in range(0, 16):
+                tir.evaluate(tir.load("float32", A, i * 16 + j))
+            for j in range(0, 16):
+                tir.evaluate(
+                    tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, tir.float32(0), dtype="handle")
+                )
+
+    for i, j in tir.grid(16, 16):
+        with tir.block([16, 16]) as [vi, vj]:
+            tir.bind(vi, i)
+            tir.bind(vj, j)
+            C[vi, vj] = B[vi, vj]
+
+
+def test_buffer_load_store():
+    func = buffer_load_store_func
+    A, B = [func.buffer_map[x] for x in func.params]
+    C, D = func.body.block.alloc_buffers
+    lca = tir.analysis.detect_buffer_access_lca(func)
+
+    # LCA of Buffer A is the root block (A is accessed in both top-level blocks)
+    root_block = func.body.block
+    assert lca[A] == func.body.block
+
+    # LCA of Buffer B is the reduction block (B is accessed in its init and in its body)
+    reduce_block = root_block.body[1].body.body.body.block
+    assert lca[B] == reduce_block
+
+    # LCA of Buffer C is loop jj (C is read in both kk loops nested under it)
+    loop_jj = reduce_block.body.body
+    assert lca[C] == loop_jj
+
+    # LCA of Buffer D is the second loop kk (D is only read there)
+    loop_kk = loop_jj.body[1]
+    assert lca[D] == loop_kk
+
+
+def test_opaque_access():
+    func = buffer_opaque_access
+    B, C = [func.buffer_map[x] for x in func.params]
+    lca = tir.analysis.detect_buffer_access_lca(func)
+
+    # Cannot detect buffer A since it is defined by a low-level Allocate
+
+    # LCA of Buffer B is the root block (opaque B.data access in one block, B read in the other)
+    root_block = func.body.block
+    assert lca[B] == func.body.block
+
+    # LCA of Buffer C is the corresponding block (the only place C is accessed)
+    assert lca[C] == root_block.body[1].body.body.block
+
+
+if __name__ == "__main__":
+    test_buffer_load_store()
+    test_opaque_access()