From ca502ee187f7a92f3aaeedbca094b5a3bfe6b8d2 Mon Sep 17 00:00:00 2001
From: masa
Date: Tue, 29 Jun 2021 18:47:39 +0900
Subject: [PATCH] merge storage_alloc refactor branch

commit 0ca192454308b8c8b7d50cd2a7233d3f7cc88e73
Author: Masahiro Masuda
Date:   Tue Jun 29 18:21:41 2021 +0900

    remove alloc map usage in cuda codegen

commit fd07b355c32be02347595c78f4ec2560bca60057
Author: Masahiro Masuda
Date:   Tue Jun 29 17:51:33 2021 +0900

    remove storage_scope map from storage_access.cc

commit 0ba5c71706b7aa161aa51aa040ca8f22fb3aee27
Author: Masahiro Masuda
Date:   Tue Jun 29 17:36:07 2021 +0900

    remove realize_scope

commit f392e472999302f3cde04cdebbd0e93aac157c18
Author: Masahiro Masuda
Date:   Tue Jun 29 17:31:13 2021 +0900

    fix passing storage scope

commit bf132cf42dc388c869fca656fa13a4d5bbd2b8e8
Author: Masahiro Masuda
Date:   Tue Jun 29 13:11:55 2021 +0900

    make global storage scope the default

commit 6630cf332beafa2c8eeb13fd52fd1e0e349c0213
Author: Masahiro Masuda
Date:   Tue Jun 29 13:10:37 2021 +0900

    remove realize_scope from schedule_ops

commit 868820510fdc8ea03ba425a43b906cce736a3be9
Author: Masahiro Masuda
Date:   Tue Jun 29 13:10:18 2021 +0900

    remove storage_scope attr from storage_rewrite

commit b382bb39c9ebc0a4afc646df7c851bba88fb1f2a
Author: Masahiro Masuda
Date:   Tue Jun 29 12:50:31 2021 +0900

    remove attr::realize_scope from storage_flatten

commit e20e195d2eaa5a30b755c0ef400bb92f9f94ddba
Author: Masahiro Masuda
Date:   Tue Jun 29 12:46:13 2021 +0900

    begin removing storage_scope attr

commit d39a470109c2555d30e5c107e8245615952032f4
Author: Masahiro Masuda
Date:   Tue Jun 29 12:23:46 2021 +0900

    thread storage scope through pipeline to buffer creation

commit 496a2151052694a774276bccd4df4ef8dda742dc
Author: Masahiro Masuda
Date:   Tue Jun 29 11:46:36 2021 +0900

    add storage_scope to ProducerRealize

commit c58683445000145a3b97322bfdd5d6d2487ef3c5
Author: Swift.Sun
Date:   Mon Jun 28 12:32:19 2021 +0800

    [AutoScheduler] Simplify the code (#8351)

commit 4ff5cef6123bef6bc0c297926709b5c5a35eafd8
Author: Rafael Stahl
Date:   Sun Jun 27 19:18:34 2021 +0200

    ffi: add missing binding for FixedPointMultiplyAttrs (#8353)

commit b71b837b96746b50c1ea44a893f02418693d2ee0
Author: Matthew Brookhart
Date:   Sun Jun 27 02:00:32 2021 -0600

    Remove an extra print from the relay astext tests (#8342)
---
 include/tvm/te/operation.h                    |  14 +--
 include/tvm/tir/buffer.h                      |   4 +-
 include/tvm/tir/stmt.h                        |   9 +-
 python/tvm/relay/op/op_attrs.py               |   5 +
 python/tvm/tir/ir_builder.py                  |   2 +-
 src/auto_scheduler/search_policy/utils.cc     | 101 ++++++++----------
 src/runtime/thread_storage_scope.h            |   4 +-
 src/target/llvm/codegen_amdgpu.cc             |   7 +-
 src/target/llvm/codegen_llvm.cc               |   8 +-
 src/target/llvm/codegen_llvm.h                |   2 -
 src/target/llvm/codegen_nvptx.cc              |   8 +-
 src/target/source/codegen_cuda.cc             |   5 +-
 src/target/spirv/codegen_spirv.cc             |  10 +-
 src/te/operation/compute_op.cc                |   4 +-
 src/te/operation/extern_op.cc                 |   4 +-
 src/te/operation/hybrid_op.cc                 |   4 +-
 src/te/operation/placeholder_op.cc            |   2 +-
 src/te/operation/scan_op.cc                   |   4 +-
 src/te/schedule/schedule_ops.cc               |  10 +-
 .../schedule/schedule_postproc_to_primfunc.cc |  19 ++--
 src/tir/ir/buffer.cc                          |  12 ++-
 src/tir/ir/stmt.cc                            |   5 +-
 src/tir/transforms/storage_access.cc          |  20 ++--
 src/tir/transforms/storage_access.h           |   4 +-
 src/tir/transforms/storage_flatten.cc         |  12 +--
 src/tir/transforms/storage_rewrite.cc         |  37 +++----
 src/tir/transforms/thread_storage_sync.cc     |  28 ++---
 tests/python/relay/test_ir_text_printer.py    |   1 -
 28 files changed, 148 insertions(+), 197 deletions(-)
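[Editor's note, not part of the patch: this series replaces the old AttrStmt-based
plumbing (attr::realize_scope / attr::storage_scope) with a storage scope stored
directly on the buffer variable's PointerType. A minimal sketch of the new
representation, assuming a TVM build that includes this series:

    import tvm
    from tvm import tir

    # The scope now lives on the pointer type of the buffer var itself.
    var = tir.Var("A", tvm.ir.PointerType(tvm.ir.PrimType("float32"), "shared"))
    print(var.type_annotation.storage_scope)  # -> "shared"
]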
diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h
index 27e48999a7d1..8b158cb1c881 100644
--- a/include/tvm/te/operation.h
+++ b/include/tvm/te/operation.h
@@ -128,8 +128,8 @@ class TVM_DLL OperationNode : public Object {
    * \return A realization statement that wraps body.
    */
   virtual Stmt BuildRealize(const Stage& stage,
-                            const std::unordered_map<IterVar, Range>& realize_map,
-                            const Stmt& body) const = 0;
+                            const std::unordered_map<IterVar, Range>& realize_map, const Stmt& body,
+                            String storage_scope = "") const = 0;
   /*!
    * \brief Build the statement that provide the output tensors.
    * \param stage The schedule stage of the op.
@@ -168,7 +168,7 @@ class PlaceholderOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;
@@ -212,7 +212,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
 
   virtual size_t num_schedulable_dims() const = 0;
 
   static constexpr const char* _type_key = "BaseComputeOp";
@@ -370,7 +370,7 @@ class ScanOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;
@@ -433,7 +433,7 @@ class ExternOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;
@@ -498,7 +498,7 @@ class HybridOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index a01d69b372d2..8f46e0a826c3 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -195,7 +195,9 @@ class Buffer : public ObjectRef {
  * \sa Buffer for complete constructor.
  */
 TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
-                           String name = "buffer", Span span = Span());
+                           String name = "buffer", String storage_scope = "", Span span = Span());
+
+TVM_DLL String GetStorageScope(Var buffer_var);
 
 /*!
  * \brief Base node for data producers.
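[Editor's note: GetStorageScope, declared above and implemented in
src/tir/ir/buffer.cc further down, simply reads the scope back off the pointer
type. A rough Python rendering of that helper, for illustration only:

    import tvm

    def get_storage_scope(buffer_var):
        # Mirrors the C++ helper: the var must carry a PointerType annotation.
        ptr_type = buffer_var.type_annotation
        assert isinstance(ptr_type, tvm.ir.PointerType), "not a pointer-typed var"
        return ptr_type.storage_scope
]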
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index cc10c218c8ff..6b747ac7d1d0 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -465,17 +465,21 @@ class ProducerRealizeNode : public StmtNode {
   /*! \brief The body of realization. */
   Stmt body;
 
+  String storage_scope;
+
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("producer", &producer);
     v->Visit("bounds", &bounds);
     v->Visit("condition", &condition);
     v->Visit("body", &body);
+    v->Visit("storage_scope", &storage_scope);
     v->Visit("span", &span);
   }
 
   bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const {
     return equal(producer, other->producer) && equal(bounds, other->bounds) &&
-           equal(condition, other->condition) && equal(body, other->body);
+           equal(condition, other->condition) && equal(body, other->body) &&
+           equal(storage_scope, other->storage_scope);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
@@ -483,6 +487,7 @@ class ProducerRealizeNode : public StmtNode {
     hash_reduce(bounds);
     hash_reduce(condition);
     hash_reduce(body);
+    hash_reduce(storage_scope);
   }
 
   static constexpr const char* _type_key = "tir.ProducerRealize";
@@ -496,7 +501,7 @@ class ProducerRealizeNode : public StmtNode {
 class ProducerRealize : public Stmt {
  public:
   TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
-                          Span span = Span());
+                          String storage_scope = "", Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
 };
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 2e13d1f042a2..74c4e2f1da49 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -577,3 +577,8 @@ class UniformAttrs(Attrs):
 @tvm._ffi.register_object("relay.attrs.NLLLossAttrs")
 class NLLLossAttrs(Attrs):
     """Attributes for nn.nll_loss"""
+
+
+@tvm._ffi.register_object("relay.attrs.FixedPointMultiplyAttrs")
+class FixedPointMultiplyAttrs(Attrs):
+    """Attributes used in fixed_point_multiply operators"""
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 4934bf04727f..5aae068f4d58 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -416,7 +416,7 @@ def allocate(self, dtype, shape, name="buf", scope=None):
         buffer : BufferVar
             The buffer var representing the buffer.
         """
-        buffer_var = _expr.Var(name, PointerType(PrimType(dtype)))
+        buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope))
         if not isinstance(shape, (list, tuple, _container.Array)):
             shape = [shape]
         if scope:
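[Editor's note: with the one-line ir_builder change above, allocate() now stamps
the scope onto the freshly created buffer var in addition to the AttrStmt it
still emits. Hedged usage sketch, assuming this patch is applied:

    import tvm

    ib = tvm.tir.ir_builder.create()
    a_buf = ib.allocate("float32", 64, name="a_buf", scope="shared")
    stmt = ib.get()  # Allocate whose buffer var carries a "shared" pointer type
]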
""" - buffer_var = _expr.Var(name, PointerType(PrimType(dtype))) + buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope)) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] if scope: diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc index ce8dc39922e0..ac1cf2dd82c9 100644 --- a/src/auto_scheduler/search_policy/utils.cc +++ b/src/auto_scheduler/search_policy/utils.cc @@ -153,24 +153,21 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo if (spatial_split_step_ids == nullptr) { spatial_split_step_ids = &temp_split_step_ids; } + spatial_split_step_ids->clear(); + std::vector> space_levels; std::vector> reduce_levels; std::vector space_outer, space_inner, reduce_outer, reduce_inner; - Array split_res; - for (const auto c : format) { - if (tolower(c) == 's') { - space_levels.emplace_back(); - } else if (tolower(c) == 'r') { - reduce_levels.emplace_back(); - } else { - LOG(FATAL) << "Invalid multi-level tiling format: " << format; - } + size_t n_space = + std::count(format.begin(), format.end(), 's') + std::count(format.begin(), format.end(), 'S'); + size_t n_reduce = + std::count(format.begin(), format.end(), 'r') + std::count(format.begin(), format.end(), 'R'); + if (n_space + n_reduce != format.size()) { + LOG(FATAL) << "Invalid multi-level tiling format: " << format; } - size_t n_space = space_levels.size(); - size_t n_reduce = reduce_levels.size(); - - spatial_split_step_ids->clear(); + space_levels.resize(n_space); + reduce_levels.resize(n_reduce); State tmp_s = state; const Stage& stage = state->stages[stage_id]; @@ -179,31 +176,28 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner) : std::set(); + auto sr_levels = [&](int size, const Iterator& iter, std::vector>& levels) { + ICHECK_GE(size, 1); + if (size == 1) { + levels[0].push_back(iter); + } else { + Array split_res = + tmp_s.split(stage_id, iter, Array>(size - 1, NullOpt)); + for (int i = 0; i < size; i++) { + levels[i].push_back(split_res[i]); + } + if (iter->iter_kind == IteratorKind::kSpatial) { + spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1); + } + } + }; + for (const auto& iter : state->stages[stage_id]->iters) { if (!no_split_at_inner_name_set.count(iter->name)) { if (iter->iter_kind == IteratorKind::kSpatial) { - ICHECK_GE(n_space, 1); - - if (n_space == 1) { - space_levels[0].push_back(iter); - } else { - split_res = tmp_s.split(stage_id, iter, Array>(n_space - 1, NullOpt)); - for (size_t i = 0; i < n_space; i++) { - space_levels[i].push_back(split_res[i]); - } - spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1); - } + sr_levels(n_space, iter, space_levels); } else if (iter->iter_kind == IteratorKind::kReduction) { - ICHECK_GE(n_reduce, 1); - - if (n_reduce == 1) { - reduce_levels[0].push_back(iter); - } else { - split_res = tmp_s.split(stage_id, iter, Array>(n_reduce - 1, NullOpt)); - for (size_t i = 0; i < n_reduce; i++) { - reduce_levels[i].push_back(split_res[i]); - } - } + sr_levels(n_reduce, iter, reduce_levels); } else { LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind); } @@ -218,40 +212,29 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo } } - if (!space_outer.empty()) { - ICHECK(!space_levels.empty()); - space_levels.front().insert(space_levels.front().begin(), - std::make_move_iterator(space_outer.begin()), - 
diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index b09f594d02eb..88e3496d6b9d 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -122,7 +122,9 @@ struct StorageScope {
   */
  static StorageScope Create(const std::string& s) {
    StorageScope r;
-    if (s.compare(0, 6, "global") == 0) {
+    if (s == "") {
+      r.rank = StorageRank::kGlobal;
+    } else if (s.compare(0, 6, "global") == 0) {
      r.rank = StorageRank::kGlobal;
      r.tag = s.substr(6, std::string::npos);
    } else if (s.compare(0, 6, "shared") == 0) {
diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc
index ed2f47d9b243..6753196480ed 100644
--- a/src/target/llvm/codegen_amdgpu.cc
+++ b/src/target/llvm/codegen_amdgpu.cc
@@ -74,7 +74,8 @@ class CodeGenAMDGPU : public CodeGenLLVM {
       llvm::Value* buf = nullptr;
       StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
 
-      if (info.scope.rank == runtime::StorageRank::kDynShared) {
+      auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var));
+      if (storage_scope.rank == runtime::StorageRank::kDynShared) {
         buf = AllocateSharedMemory(op->dtype, 0, 3, std::min(info.alignment, 16),
                                    llvm::GlobalValue::ExternalLinkage);
       } else {
@@ -88,7 +89,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
         if (info.alignment > 16) {
           info.alignment = 16;
         }
-        if (info.scope.rank == runtime::StorageRank::kLocal) {
+        if (storage_scope.rank == runtime::StorageRank::kLocal) {
           // const int local_address_space = 5;
           // TODO(tqchen): for higher version of LLVM, local address space can be set.
           llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
@@ -103,7 +104,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
           }
           buf = alloca;
         } else {
-          ICHECK(info.scope.rank == runtime::StorageRank::kShared)
+          ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
              << "Can only allocate shared or local memory inside kernel";
          // Shared memory: address space == 3
          buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment,
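[Editor's note: StorageScope::Create now maps the empty string to the global
rank, which is what lets an unannotated pointer type default to global. A
Python mimic of the prefix dispatch (only the ranks visible in this hunk plus
"local"; the real function handles more ranks and also splits off the tag):

    def storage_rank(scope):
        if scope == "":
            return "global"  # new behaviour: empty scope means global
        for prefix in ("global", "shared", "local"):
            if scope.startswith(prefix):
                return prefix
        raise ValueError("unknown storage scope: " + scope)

    assert storage_rank("") == "global" and storage_rank("shared") == "shared"
]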
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 9dcd6c19faed..a8d5b76c34a5 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -501,7 +501,8 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index,
   auto it = alloc_storage_info_.find(buf_var);
   if (it != alloc_storage_info_.end()) {
     const StorageInfo& info = it->second;
-    *p_native_bits = NativeVectorBits(info.scope);
+    *p_native_bits =
+        NativeVectorBits(runtime::StorageScope::Create(GetStorageScope(GetRef<Var>(buf_var))));
     max_align_bits = info.alignment * 8;
   } else {
     *p_native_bits = native_vector_bits_;
@@ -1407,11 +1408,6 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
         analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
       }
     }
-  } else if (op->attr_key == tir::attr::storage_scope) {
-    const VarNode* v = op->node.as<VarNode>();
-    ICHECK(v);
-    alloc_storage_info_[v].scope =
-        runtime::StorageScope::Create(op->value.as<StringImmNode>()->value);
   } else if (op->attr_key == tir::attr::storage_alignment) {
     const VarNode* v = op->node.as<VarNode>();
     ICHECK(v);
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index da0dda2fbf14..52c5b98a0025 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -163,8 +163,6 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
  protected:
   /*! \brief The storage information */
   struct StorageInfo {
-    /*! \brief The storage scope */
-    runtime::StorageScope scope;
     /*! \brief The alignment of allocation */
     int alignment{0};
   };
diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc
index 683e19411142..4a5f0cd0ba02 100644
--- a/src/target/llvm/codegen_nvptx.cc
+++ b/src/target/llvm/codegen_nvptx.cc
@@ -53,8 +53,8 @@ class CodeGenNVPTX : public CodeGenLLVM {
       if (info.alignment > 16) {
         info.alignment = 16;
       }
-
-      if (info.scope.rank == runtime::StorageRank::kDynShared) {
+      auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var));
+      if (storage_scope.rank == runtime::StorageRank::kDynShared) {
         buf = AllocateSharedMemory(op->dtype, 0, 3, info.alignment,
                                    llvm::GlobalValue::ExternalLinkage);
       } else {
@@ -64,7 +64,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
         if (constant_size % 4 == 0 && info.alignment == 0) {
           info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
         }
-        if (info.scope.rank == runtime::StorageRank::kLocal) {
+        if (storage_scope.rank == runtime::StorageRank::kLocal) {
          // const int local_address_space = 5;
          // TODO(tqchen): for higher version of LLVM, local address space can be set.
          llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
@@ -79,7 +79,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
           }
           buf = alloca;
         } else {
-          ICHECK(info.scope.rank == runtime::StorageRank::kShared)
+          ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
               << "Can only allocate shared or local memory inside kernel";
           buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment,
                                      llvm::GlobalValue::PrivateLinkage);
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 7db9395ca532..ae3dc9b03792 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -705,12 +705,9 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
   std::string vid = AllocVarID(op->buffer_var.get());
 
   this->PrintIndent();
+  std::string scope = GetStorageScope(op->buffer_var);
   const VarNode* buffer = op->buffer_var.as<VarNode>();
-  auto it = alloc_storage_scope_.find(buffer);
-  ICHECK(it != alloc_storage_scope_.end())
-      << "Buffer " << op->buffer_var << " is missing an AttrStmt with a \"storage_scope\" key";
-  std::string scope = it->second;
   if (scope.find("wmma.") == 0) {
     if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
       ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 2628406f6f49..45323e8901f1 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -23,6 +23,7 @@
  */
 #include "codegen_spirv.h"
 
+#include <tvm/tir/buffer.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op.h>
 
 #include <string>
@@ -638,13 +639,14 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
   ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";
   spirv::Value buf;
   StorageInfo& info = storage_info_[op->buffer_var.get()];
+  auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var));
   spirv::SType etype = builder_->GetSType(op->dtype);
-  if (info.scope.rank == runtime::StorageRank::kLocal) {
+  if (storage_scope.rank == runtime::StorageRank::kLocal) {
     buf =
         builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassFunction);
   } else {
     // shared memory
-    ICHECK(info.scope.rank == runtime::StorageRank::kShared)
+    ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
         << "Can only allocate shared or local memory inside kernel";
     // Shared memory
     buf =
@@ -667,10 +669,6 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
         var_map_[iv->var.get()] = GetThreadIndex(iv, op->value);
       }
     }
-  } else if (op->attr_key == tir::attr::storage_scope) {
-    const VarNode* v = op->node.as<VarNode>();
-    ICHECK(v);
-    storage_info_[v].scope = runtime::StorageScope::Create(op->value.as<StringImmNode>()->value);
   } else if (op->attr_key == tir::attr::volatile_scope) {
     const VarNode* v = op->node.as<VarNode>();
     ICHECK(v);
diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc
index 9a4eadb35619..26c08955f5ad 100644
--- a/src/te/operation/compute_op.cc
+++ b/src/te/operation/compute_op.cc
@@ -260,7 +260,7 @@ void BaseComputeOpNode::GatherBound(const Operation& self,
 
 Stmt BaseComputeOpNode::BuildRealize(const Stage& stage,
                                      const std::unordered_map<IterVar, Range>& realize_map,
-                                     const Stmt& body) const {
+                                     const Stmt& body, String storage_scope) const {
   ICHECK_EQ(stage->op.get(), this);
   Region bounds;
   for (IterVar iv : this->axis) {
@@ -269,7 +269,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage,
   Stmt realize = body;
   for (int i = this->num_outputs(); i > 0; --i) {
     Tensor t = stage->op.output(i - 1);
-    realize = tir::ProducerRealize(t, bounds, const_true(), realize);
+    realize = tir::ProducerRealize(t, bounds, const_true(), realize, storage_scope);
     // alignment requirement, only useful for compute
     for (size_t i = 0; i < num_schedulable_dims(); ++i) {
       auto it = stage->iter_var_attrs.find(this->axis[i]);
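[Editor's note: a quick end-to-end smoke test for the codegen and TE changes
above; a stage with a non-global scope should still lower once realize_scope is
gone. Sketch only, assuming a build with this series applied:

    import tvm
    from tvm import te

    A = te.placeholder((128,), name="A")
    B = te.compute((128,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    AA = s.cache_read(A, "shared", [B])  # stage with storage scope "shared"
    bx, tx = s[B].split(B.op.axis[0], factor=32)
    s[B].bind(bx, te.thread_axis("blockIdx.x"))
    s[B].bind(tx, te.thread_axis("threadIdx.x"))
    s[AA].compute_at(s[B], bx)
    print(tvm.lower(s, [A, B], simple_mode=True))
]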
diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc
index 1c9a3cb336ae..b602efcfc28b 100644
--- a/src/te/operation/extern_op.cc
+++ b/src/te/operation/extern_op.cc
@@ -124,7 +124,7 @@ void ExternOpNode::GatherBound(const Operation& self,
 
 Stmt ExternOpNode::BuildRealize(const Stage& stage,
                                 const std::unordered_map<IterVar, Range>& realize_map,
-                                const Stmt& body) const {
+                                const Stmt& body, String storage_scope) const {
   ICHECK_EQ(stage->op.get(), this);
   Stmt realize_body = body;
   for (int k = 0; k < num_outputs(); ++k) {
@@ -133,7 +133,7 @@ Stmt ExternOpNode::BuildRealize(const Stage& stage,
     for (size_t i = 0; i < t->shape.size(); ++i) {
       bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
     }
-    realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body);
+    realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope);
   }
   return realize_body;
 }
diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc
index 65b8660ca1fb..5d2412abb3d2 100644
--- a/src/te/operation/hybrid_op.cc
+++ b/src/te/operation/hybrid_op.cc
@@ -144,7 +144,7 @@ void HybridOpNode::GatherBound(const Operation& self,
 
 Stmt HybridOpNode::BuildRealize(const Stage& stage,
                                 const std::unordered_map<IterVar, Range>& realize_map,
-                                const Stmt& body) const {
+                                const Stmt& body, String storage_scope) const {
   // TODO(@were): Add attribute inject here and remove it from hybrid parser.
   ICHECK_EQ(stage->op.get(), this);
   Stmt realize_body = body;
@@ -154,7 +154,7 @@ Stmt HybridOpNode::BuildRealize(const Stage& stage,
     for (size_t i = 0; i < t->shape.size(); ++i) {
       bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
     }
-    realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body);
+    realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope);
   }
   return realize_body;
 }
diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc
index c51e53e16cd1..4f5df7ad3024 100644
--- a/src/te/operation/placeholder_op.cc
+++ b/src/te/operation/placeholder_op.cc
@@ -85,7 +85,7 @@ void PlaceholderOpNode::GatherBound(const Operation& self,
 
 Stmt PlaceholderOpNode::BuildRealize(const Stage& stage,
                                      const std::unordered_map<IterVar, Range>& realize_map,
-                                     const Stmt& body) const {
+                                     const Stmt& body, String storage_scope) const {
   return body;
 }
diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc
index a555e86097b7..39689bd9654a 100644
--- a/src/te/operation/scan_op.cc
+++ b/src/te/operation/scan_op.cc
@@ -234,7 +234,7 @@ void ScanOpNode::GatherBound(const Operation& self,
 }
 
 Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
-                              const Stmt& body) const {
+                              const Stmt& body, String storage_scope) const {
   arith::Analyzer analyzer;
   ICHECK_EQ(stage->op.get(), this);
   Range sdom = dom_map.at(this->scan_axis);
@@ -250,7 +250,7 @@ Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
       IterVar sp_ax = this->spatial_axis_[sp_idx];
      bounds.push_back(dom_map.at(sp_ax));
    }
-    ret = tir::ProducerRealize(t, bounds, const_true(), ret);
+    ret = tir::ProducerRealize(t, bounds, const_true(), ret, storage_scope);
  }
  return ret;
}
diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc
index 355e3c39494b..6a783006b1c5 100644
--- a/src/te/schedule/schedule_ops.cc
+++ b/src/te/schedule/schedule_ops.cc
@@ -51,11 +51,8 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map<IterVar, Range>& dom_map,
   if (consumer.defined() && !is_no_op(consumer)) {
     pipeline = SeqStmt({producer, consumer});
   }
-  pipeline = s->op->BuildRealize(s, dom_map, pipeline);
-  // use attribute to mark scope of the operation.
-  pipeline = AttrStmt(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline);
-  return pipeline;
+  return s->op->BuildRealize(s, dom_map, pipeline, s->scope);
 }
 
 // inject the operator's realization on the stmt.
@@ -175,8 +172,7 @@ class SchedulePostProc : public StmtExprMutator {
         thread_extent_scope_.erase(op->node.get());
         return ret;
       }
-    } else if (op->attr_key == tir::attr::realize_scope ||
-               op->attr_key == tir::attr::double_buffer_scope) {
+    } else if (op->attr_key == tir::attr::double_buffer_scope) {
       auto it = replace_op_.find(op->node.get());
       if (it != replace_op_.end()) {
         if (it->second.defined()) {
@@ -218,7 +214,7 @@ class SchedulePostProc : public StmtExprMutator {
     auto it = replace_realize_.find(key);
     if (it != replace_realize_.end()) {
       if (it->second.defined()) {
-        Stmt ret = ProducerRealize(it->second, op->bounds, op->condition, op->body);
+        Stmt ret = ProducerRealize(it->second, op->bounds, op->condition, op->body, op->storage_scope);
         return this->VisitStmt(ret);
       } else {
         return this->VisitStmt(op->body);
diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc
index 5c59961fe011..2063fc7cad6a 100644
--- a/src/te/schedule/schedule_postproc_to_primfunc.cc
+++ b/src/te/schedule/schedule_postproc_to_primfunc.cc
@@ -49,12 +49,12 @@ namespace tvm {
 namespace te {
 
 // create a buffer for tensor.
-Buffer CreateBufferFor(const Tensor& tensor) {
+Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") {
   std::string name = tensor->op->name;
   if (tensor->op->num_outputs() != 1) {
     name += ".v" + std::to_string(tensor->value_index);
   }
-  Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name);
+  Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name, storage_scope);
   return buffer;
 }
 
@@ -67,10 +67,7 @@ class TensorToBufferMapper : public StmtExprMutator {
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     auto ret = StmtExprMutator::VisitStmt_(op);
     op = ret.as<AttrStmtNode>();
-    // TODO(tvm-team): remove realize_scope, turn the info into
-    // Buffer's scope field in this pass.
-    if (op->attr_key == tir::attr::realize_scope ||
-        op->attr_key == tir::attr::double_buffer_scope) {
+    if (op->attr_key == tir::attr::double_buffer_scope) {
       Stmt body = op->body;
       Operation operation = Downcast<Operation>(op->node);
       for (int i = operation->num_outputs(); i != 0; --i) {
@@ -95,7 +92,7 @@ class TensorToBufferMapper : public StmtExprMutator {
 
   Stmt VisitStmt_(const ProducerRealizeNode* op) final {
     Tensor tensor = Downcast<Tensor>(op->producer);
-    Buffer buffer = GetOrAllocBuffer(tensor);
+    Buffer buffer = GetOrAllocBuffer(tensor, op->storage_scope);
 
     auto ret = StmtExprMutator::VisitStmt_(op);
     op = ret.as<ProducerRealizeNode>();
@@ -122,14 +119,16 @@ class TensorToBufferMapper : public StmtExprMutator {
   }
 
  private:
-  Buffer GetOrAllocBuffer(const Tensor& tensor) { return GetBuffer(tensor, true); }
+  Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "") {
+    return GetBuffer(tensor, storage_scope, true);
+  }
 
-  Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) {
+  Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) {
     auto it = buffer_map_.find(tensor);
     if (it != buffer_map_.end()) return it->second;
     ICHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor;
 
-    auto buffer = CreateBufferFor(tensor);
+    auto buffer = CreateBufferFor(tensor, storage_scope);
     buffer_map_[tensor] = buffer;
     return buffer;
   }
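[Editor's note: ProducerRealize now carries the scope from BuildRealize down to
buffer creation, replacing the realize_scope attribute removed above. Building
one from Python to inspect the new field (sketch; the FFI registration in
src/tir/ir/stmt.cc below keeps the old arity and passes "" for the scope):

    import tvm
    from tvm import te, tir

    A = te.placeholder((4,), name="A")
    realize = tir.ProducerRealize(
        A, [tvm.ir.Range(0, 4)], tir.const(True, "bool"),
        tir.Evaluate(tir.const(0, "int32")),
    )
    print(realize.storage_scope)  # "" -> treated as global downstream
]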
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc
index 1667eb7d1fbd..851d440a6378 100644
--- a/src/tir/ir/buffer.cc
+++ b/src/tir/ir/buffer.cc
@@ -45,12 +45,20 @@ Array<PrimExpr> SimplifyArray(arith::Analyzer* ana, Array<PrimExpr> array) {
   return array;
 }
 
-Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name, Span span) {
+Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name, String storage_scope,
+                   Span span) {
   DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
-  return Buffer(Var(name, PointerType(PrimType(storage_dtype)), span), dtype, shape,
+  return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape,
                 Array<PrimExpr>(), PrimExpr(), name, "", 0, 0, kDefault, span);
 }
 
+String GetStorageScope(Var buffer_var) {
+  auto type = buffer_var->type_annotation;
+  const auto* ptr_type = type.as<PointerTypeNode>();
+  ICHECK(ptr_type) << "The provided variable is not of pointer type";
+  return ptr_type->storage_scope;
+}
+
 // Split the given expression w.r.t the add operator
 inline std::vector<const PrimExpr*> ExprSplitAddition(const PrimExpr& expr) {
   using namespace tir;
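[Editor's note: decl_buffer gains storage_scope just before span, so C++ call
sites pass it positionally. From Python, the pre-existing scope argument is the
analogous knob; whether it reaches the pointer type through this C++ path is an
assumption of this sketch:

    import tvm

    buf = tvm.tir.decl_buffer((128,), "float32", name="B", scope="shared")
    # After this series the scope is recoverable from the buffer var:
    # tir::GetStorageScope(buf->data) in C++,
    # buf.data.type_annotation.storage_scope here (assumption).
]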
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index b2016eb74c91..a49787d22d05 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -377,7 +377,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 // ProducerRealize
 ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition,
-                                 Stmt body, Span span) {
+                                 Stmt body, String storage_scope, Span span) {
   for (size_t i = 0; i < bounds.size(); ++i) {
     ICHECK(bounds[i]->min.defined());
     ICHECK(bounds[i]->extent.defined());
@@ -394,13 +394,14 @@ ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition,
   node->condition = std::move(condition);
   node->body = std::move(body);
   node->span = std::move(span);
+  node->storage_scope = std::move(storage_scope);
   data_ = std::move(node);
 }
 
 TVM_REGISTER_GLOBAL("tir.ProducerRealize")
     .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
                        Span span) {
-      return ProducerRealize(producer, bounds, condition, body, span);
+      return ProducerRealize(producer, bounds, condition, body, "", span);
     });
 
 TVM_REGISTER_NODE_TYPE(ProducerRealizeNode);
diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc
index 00002d3587db..8f5b8d75c1d4 100644
--- a/src/tir/transforms/storage_access.cc
+++ b/src/tir/transforms/storage_access.cc
@@ -35,7 +35,7 @@ namespace tir {
 
 void StorageAccessVisitor::VisitExpr_(const LoadNode* op) {
   const VarNode* buf = op->buffer_var.as<VarNode>();
-  StorageScope scope = GetScope(buf);
+  StorageScope scope = GetScope(op->buffer_var);
   if (Enabled(buf, scope)) {
     ICHECK(allow_append_) << op << " " << scope.to_string();
     AccessEntry e;
@@ -56,7 +56,7 @@ void StorageAccessVisitor::VisitStmt_(const StoreNode* op) {
   ICHECK_EQ(curr_stmt_.access.size(), 0U);
   curr_stmt_.stmt = op;
   const VarNode* buf = op->buffer_var.as<VarNode>();
-  StorageScope scope = GetScope(buf);
+  StorageScope scope = GetScope(op->buffer_var);
   if (Enabled(buf, scope)) {
     AccessEntry e;
     e.threads = env_threads();
@@ -90,11 +90,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) {
 }
 
 void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
-  if (op->attr_key == attr::storage_scope) {
-    const VarNode* buf = op->node.as<VarNode>();
-    storage_scope_[buf] = StorageScope::Create(op->value.as<StringImmNode>()->value);
-    StmtExprVisitor::VisitStmt_(op);
-  } else if (op->attr_key == attr::double_buffer_write) {
+  if (op->attr_key == attr::double_buffer_write) {
     ICHECK(double_buffer_write_ == nullptr);
     double_buffer_write_ = op->node.as<VarNode>();
     scope_.push_back(std::vector<StmtEntry>());
@@ -208,7 +204,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
     PrimExpr offset = op->args[2];
     PrimExpr extent = op->args[3];
     const IntImmNode* flag = op->args[4].as<IntImmNode>();
-    StorageScope scope = GetScope(buffer);
+    StorageScope scope = GetScope(GetRef<Var>(buffer));
     // The buffer scope.
     if (Enabled(buffer, scope)) {
       ICHECK(allow_append_);
@@ -244,12 +240,8 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
   }
 }
 
-StorageScope StorageAccessVisitor::GetScope(const VarNode* buf) const {
-  auto it = storage_scope_.find(buf);
-  StorageScope s;
-  s.rank = StorageRank::kGlobal;
-  if (it == storage_scope_.end()) return s;
-  return it->second;
+StorageScope StorageAccessVisitor::GetScope(Var buffer_var) const {
+  return StorageScope::Create(GetStorageScope(buffer_var));
 }
 
 }  // namespace tir
diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h
index 663c570fd15c..9dc4c923b054 100644
--- a/src/tir/transforms/storage_access.h
+++ b/src/tir/transforms/storage_access.h
@@ -118,7 +118,7 @@ class StorageAccessVisitor : public StmtExprVisitor {
    * \brief Get the scope of the buffer array.
    * \return The scope of the final buffer array.
    */
-  StorageScope GetScope(const VarNode* buf) const;
+  StorageScope GetScope(Var buffer_var) const;
 
   // access scope
   std::vector<std::vector<StmtEntry> > scope_;
@@ -135,8 +135,6 @@ class StorageAccessVisitor : public StmtExprVisitor {
   StmtEntry curr_stmt_;
   // The involving threads
   Array<IterVar> env_threads_;
-  // The storage scope of each buffer
-  std::unordered_map<const VarNode*, StorageScope> storage_scope_;
 };
 
 }  // namespace tir
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 43fc1f1ec53f..075e0fc49a34 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -78,11 +78,7 @@ class StorageFlattener : public StmtExprMutator {
   }
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::realize_scope) {
-      storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
-      return this->VisitStmt(op->body);
-    } else if (op->attr_key == attr::double_buffer_scope &&
-               op->node->IsInstance<tir::BufferNode>()) {
+    if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance<tir::BufferNode>()) {
       auto buffer = Downcast<tir::Buffer>(op->node);
       Stmt body = this->VisitStmt(op->body);
       auto it = buf_map_.find(buffer);
@@ -156,10 +152,8 @@ class StorageFlattener : public StmtExprMutator {
       shape.push_back(r->extent);
     }
     // deduce current storage scope.
-    auto it = storage_scope_.find(op->buffer.get());
-    ICHECK(it != storage_scope_.end()) << "Cannot find storage scope of " << op->buffer;
     StorageScope skey;
-    const std::string& strkey = it->second;
+    std::string strkey = GetStorageScope(op->buffer->data);
     if (strkey.length() == 0) {
       if (curr_thread_scope_.size() != 0) {
         skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
@@ -491,8 +485,6 @@ class StorageFlattener : public StmtExprMutator {
   std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
   // Dimension alignment
   std::unordered_map<Buffer, Array<te::DimAlignInfo>, ObjectPtrHash, ObjectPtrEqual> dim_align_;
-  // Storage scope
-  std::unordered_map<const Object*, std::string> storage_scope_;
   // The current thread scope.
   std::vector<ThreadScope> curr_thread_scope_;
   // Collects shapes.
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 36eeddb17d89..a3d2c1a31cf5 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -75,8 +75,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
   };
   // The scope of each allocation
   struct AllocEntry {
-    // Scope used for allocation.
-    StorageScope storage_scope;
     // scope level
     size_t level{0};
     // allocation stmt
@@ -86,13 +84,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
   void VisitStmt_(const AllocateNode* op) final {
     size_t level = scope_.size();
     const VarNode* buf = op->buffer_var.get();
-    auto it = alloc_info_.find(buf);
-    ICHECK(it != alloc_info_.end()) << "Could not find buffer `" << buf->name_hint
-                                    << "` in the list of allocated buffers. Perhaps you are "
-                                       "missing a storage_scope attr for this buffer.";
-    ICHECK(it->second.alloc == nullptr);
-    it->second.alloc = op;
-    it->second.level = level;
+    alloc_info_[buf].alloc = op;
+    alloc_info_[buf].level = level;
     StmtExprVisitor::VisitStmt_(op);
   }
   void VisitStmt_(const StoreNode* op) final {
@@ -180,10 +173,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
       VisitNewScope(op);
     } else if (op->attr_key == attr::virtual_thread) {
       VisitNewScope(op);
-    } else if (op->attr_key == attr::storage_scope) {
-      const VarNode* buf = op->node.as<VarNode>();
-      alloc_info_[buf].storage_scope = StorageScope::Create(op->value.as<StringImmNode>()->value);
-      StmtExprVisitor::VisitStmt_(op);
     } else {
       StmtExprVisitor::VisitStmt_(op);
     }
@@ -409,10 +398,8 @@ class StoragePlanRewriter : public StmtExprMutator {
   }
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::storage_scope) {
-      return this->VisitStmt(op->body);
-    } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread ||
-               attr::IsPragmaKey(op->attr_key)) {
+    if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread ||
+        attr::IsPragmaKey(op->attr_key)) {
       // remake all the allocation at the attach scope.
       if (attach_map_.count(op)) {
         auto& svec = attach_map_[op];
@@ -716,7 +703,8 @@ class StoragePlanRewriter : public StmtExprMutator {
 
       for (const VarNode* var : it->second.gen) {
        ICHECK(alloc_info.count(var));
-        const AllocEntry& ae = alloc_info.at(var);
+        const AllocateNode* alloc = alloc_info.at(var).alloc;
+        auto storage_scope = StorageScope::Create(GetStorageScope(GetRef<Var>(var)));
        StorageEntry* dst_entry = nullptr;
        // inplace detection
        if (detect_inplace) {
@@ -726,13 +714,12 @@ class StoragePlanRewriter : public StmtExprMutator {
           if (!inplace_flag.count(src) && alloc_map_.count(src)) {
             InplaceOpVerifier visitor;
             StorageEntry* src_entry = alloc_map_.at(src);
-            if (src_entry->scope == ae.storage_scope &&
+            if (src_entry->scope == storage_scope &&
                 src_entry->attach_scope_ == thread_scope_ &&
-                src_entry->elem_type == ae.alloc->dtype.element_of() &&
+                src_entry->elem_type == alloc->dtype.element_of() &&
                 visitor.Check(s.stmt, var, src)) {
-              uint64_t const_nbits =
-                  static_cast<uint64_t>(ae.alloc->constant_allocation_size()) *
-                  ae.alloc->dtype.bits() * ae.alloc->dtype.lanes();
+              uint64_t const_nbits = static_cast<uint64_t>(alloc->constant_allocation_size()) *
+                                     alloc->dtype.bits() * alloc->dtype.lanes();
               if (src_entry->const_nbits == const_nbits && !inplace_found) {
                 // successfully inplace
                 dst_entry = src_entry;
@@ -744,9 +731,9 @@ class StoragePlanRewriter : public StmtExprMutator {
           }
         }
         if (dst_entry == nullptr) {
-          dst_entry = FindAlloc(ae.alloc, thread_scope_, ae.storage_scope);
+          dst_entry = FindAlloc(alloc, thread_scope_, storage_scope);
         }
-        dst_entry->allocs.emplace_back(ae.alloc);
+        dst_entry->allocs.emplace_back(alloc);
         alloc_map_[var] = dst_entry;
       }
     }
diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc
index 8f757171afbd..896224c0e956 100644
--- a/src/tir/transforms/thread_storage_sync.cc
+++ b/src/tir/transforms/thread_storage_sync.cc
@@ -22,6 +22,7 @@
  */
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/analysis.h>
+#include <tvm/tir/buffer.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
@@ -223,14 +224,14 @@ class ThreadSyncInserter : public StmtExprMutator {
   }
   PrimExpr VisitExpr_(const LoadNode* op) final {
     if (sync_scope_.rank == StorageRank::kGlobal &&
-        GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
+        GetScope(op->buffer_var).rank == StorageRank::kGlobal) {
       ++rw_stats_[op->buffer_var].read_count;
     }
     return StmtExprMutator::VisitExpr_(op);
   }
   Stmt VisitStmt_(const StoreNode* op) final {
     if (sync_scope_.rank == StorageRank::kGlobal &&
-        GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
+        GetScope(op->buffer_var).rank == StorageRank::kGlobal) {
       ++rw_stats_[op->buffer_var].write_count;
     }
     return StmtExprMutator::VisitStmt_(op);
@@ -250,10 +251,6 @@ class ThreadSyncInserter : public StmtExprMutator {
         is_lead_ = PrimExpr();
       }
       return ret;
-    } else if (op->attr_key == attr::storage_scope) {
-      const VarNode* buf = op->node.as<VarNode>();
-      storage_scope_[buf] = StorageScope::Create(op->value.as<StringImmNode>()->value);
-      return StmtExprMutator::VisitStmt_(op);
     } else {
       return StmtExprMutator::VisitStmt_(op);
     }
@@ -264,16 +261,15 @@ class ThreadSyncInserter : public StmtExprMutator {
       PrimExpr expr = StmtExprMutator::VisitExpr_(op);
       op = expr.as<CallNode>();
       ICHECK_EQ(op->args.size(), 5U);
-      const VarNode* buffer_var = op->args[1].as<VarNode>();
-      Var var(GetRef<Var>(buffer_var));
+      Var buffer_var(GetRef<Var>(op->args[1].as<VarNode>()));
       const IntImmNode* flag = op->args[4].as<IntImmNode>();
       if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal &&
           GetScope(buffer_var).rank == StorageRank::kGlobal) {
-        ++rw_stats_[var].read_count;
+        ++rw_stats_[buffer_var].read_count;
       }
       if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal &&
           GetScope(buffer_var).rank == StorageRank::kGlobal) {
-        ++rw_stats_[var].write_count;
+        ++rw_stats_[buffer_var].write_count;
       }
       return expr;
     } else {
@@ -287,14 +283,12 @@ class ThreadSyncInserter : public StmtExprMutator {
     int read_count{0};
     int write_count{0};
   };
+
   // Get current storage scope.
-  StorageScope GetScope(const VarNode* buf) const {
-    auto it = storage_scope_.find(buf);
-    StorageScope s;
-    s.rank = StorageRank::kGlobal;
-    if (it == storage_scope_.end()) return s;
-    return it->second;
+  StorageScope GetScope(Var buffer_var) const {
+    return StorageScope::Create(GetStorageScope(buffer_var));
   }
+
   // private functions.
   Stmt InitGlobalBarrier(const AttrStmtNode* op) {
     ICHECK(op != nullptr);
@@ -337,8 +331,6 @@ class ThreadSyncInserter : public StmtExprMutator {
   // data structure.
   StorageScope sync_scope_;
   const std::unordered_set<const Object*>& syncs_;
-  // The storage scope of each buffer
-  std::unordered_map<const VarNode*, StorageScope> storage_scope_;
   // The read write statistics of storage
   std::unordered_map<Var, Entry, ObjectPtrHash, ObjectPtrEqual> rw_stats_;
   // The statistics for global barrier
diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py
index b4d02e4815fb..af3c737fccd6 100644
--- a/tests/python/relay/test_ir_text_printer.py
+++ b/tests/python/relay/test_ir_text_printer.py
@@ -31,7 +31,6 @@ def astext(program, unify_free_vars=False):
     text = program.astext()
-    print(text)
     if isinstance(program, Expr):
         roundtrip_program = tvm.parser.parse_expr(text)
     else: