From 8edd4f13dd4cbb735d48aad43797a7d4e2a0937c Mon Sep 17 00:00:00 2001 From: masa Date: Thu, 8 Jul 2021 07:43:57 +0900 Subject: [PATCH 1/7] Remove all attr::storage_scope usage --- include/tvm/tir/stmt.h | 2 - python/tvm/contrib/hexagon.py | 16 ++++--- python/tvm/script/scope_handler.py | 2 +- python/tvm/tir/ir_builder.py | 2 - src/printer/tvmscript_printer.cc | 48 ++++++++----------- src/relay/backend/aot_executor_codegen.cc | 2 - src/target/source/codegen_c.cc | 15 ++---- src/te/operation/cross_thread_reduction.cc | 3 -- src/tir/analysis/verify_gpu_code.cc | 21 +++----- src/tir/ir/stmt.cc | 10 ---- src/tir/transforms/flatten_buffer.cc | 1 - src/tir/transforms/inject_copy_intrin.cc | 15 ++---- src/tir/transforms/inject_double_buffer.cc | 17 ++----- src/tir/transforms/ir_utils.cc | 10 ---- .../lower_device_storage_access_info.cc | 46 ++++-------------- src/tir/transforms/lower_thread_allreduce.cc | 12 +---- src/tir/transforms/lower_warp_memory.cc | 17 +------ src/tir/transforms/storage_flatten.cc | 1 - src/tir/transforms/storage_rewrite.cc | 8 +--- .../transforms/tensorcore_infer_fragment.cc | 11 +---- .../update_pointer_storage_scope.cc | 11 ----- .../transforms/update_pointer_storage_scope.h | 1 - tests/python/unittest/test_tir_ir_builder.py | 2 - .../test_tir_transform_coproc_sync.py | 4 +- ...test_tir_transform_inject_double_buffer.py | 4 +- ...est_tir_transform_inject_virtual_thread.py | 12 ++--- .../test_tir_transform_lift_attr_scope.py | 4 +- .../test_tir_transform_loop_partition.py | 4 +- .../test_tir_transform_lower_warp_memory.py | 5 +- .../test_tir_transform_storage_flatten.py | 6 +-- .../test_tir_transform_storage_rewrite.py | 24 ++++------ .../test_tir_transform_thread_sync.py | 2 +- 32 files changed, 97 insertions(+), 241 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 9997a4d95694..c41cac2a3a25 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1240,8 +1240,6 @@ constexpr const char* extern_scope = "extern_scope"; * This can hint some code generator to create a new function for compute. */ constexpr const char* compute_scope = "compute_scope"; -/*! \brief Mark storage scope of buffers */ -constexpr const char* storage_scope = "storage_scope"; /*! \brief Mark storage alignement requirement of buffers */ constexpr const char* storage_alignment = "storage_alignment"; /*! \brief Mark storage scope of realization */ diff --git a/python/tvm/contrib/hexagon.py b/python/tvm/contrib/hexagon.py index 34b37537776f..c2197af22d2a 100644 --- a/python/tvm/contrib/hexagon.py +++ b/python/tvm/contrib/hexagon.py @@ -176,23 +176,27 @@ def buf_align(var): def visit(stmt): """Collect information about VTCM buffers and their alignments.""" if isinstance(stmt, tvm.tir.AttrStmt): - if stmt.attr_key == "storage_scope" and stmt.value == "local.vtcm": - vtcm_buffers.append(stmt.node) - elif stmt.attr_key == "storage_alignment": + if stmt.attr_key == "storage_alignment": if not stmt.node in alignments: alignments[stmt.node] = [] alignments[stmt.node].append(stmt.value) + elif isinstance(stmt, tvm.tir.Allocate): + scope = stmt.buffer_var.type_annotation.storage_scope + if scope == "local.vtcm": + vtcm_buffers.append(stmt.node) + def mutate(stmt): """Insert calls to VTCM allocation and deallocation routines.""" if isinstance(stmt, tvm.tir.AttrStmt): - if stmt.attr_key == "storage_scope" and stmt.value == "local.vtcm": - vtcm_buffers.pop() - elif stmt.attr_key == "storage_alignment": + if stmt.attr_key == "storage_alignment": alignments[stmt.node].pop() return stmt if isinstance(stmt, tvm.tir.Allocate): var = stmt.buffer_var + scope = var.type_annotation.storage_scope + if scope == "local.vtcm": + vtcm_buffers.pop() if var in vtcm_buffers: is_null = tvm.tir.call_intrin("bool", tvm.ir.Op.get("tir.isnullptr"), var) throw_error = tvm.tir.call_intrin( diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index d07209485bd4..971580343763 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -113,7 +113,7 @@ def allocate(extents, dtype, scope, condition=True, span=None): body = tvm.tir.Allocate( self.buffer_var, dtype, extents, condition, self.body, span=span ) - return tvm.tir.AttrStmt(self.buffer_var, "storage_scope", scope, body, span=span) + return body super().__init__(allocate, concise_scope=True, def_symbol=True) self.buffer_var = None diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 35932540fe68..978c630b17ad 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -419,8 +419,6 @@ def allocate(self, dtype, shape, name="buf", scope=""): buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope)) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] - if scope: - self.scope_attr(buffer_var, "storage_scope", scope) self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x)) return BufferVar(self, buffer_var, shape, dtype) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 01f79bd0c750..cc8aa48f3cd1 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -37,6 +37,7 @@ #include #include +#include "../tir/transforms/ir_utils.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" @@ -579,31 +580,6 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) { Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) { Doc doc; - // merge attr with allocate when possible - if (op->node->IsInstance() && op->attr_key == "storage_scope" && - op->body->IsInstance()) { - const auto* alloc = Downcast(op->body).get(); - if (alloc->buffer_var.same_as(op->node)) { - var_not_in_headers.insert(alloc->buffer_var.get()); - if (current_num_ != num_child_ - 1) { - doc << "with tir.allocate(" << Print(alloc->extents) << ", " << PrintDType(alloc->dtype) - << ", " << Print(op->value); - if (!is_one(alloc->condition)) { - doc << ", " << Print(alloc->condition); - } - doc << ") as " << Print(op->node) << ":"; - doc << Doc::Indent(4, Doc::NewLine() << PrintBody(alloc->body)); - } else { - doc << Print(op->node) << " = tir.allocate(" << Print(alloc->extents) << ", " - << PrintDType(alloc->dtype) << ", " << Print(op->value); - if (!is_one(alloc->condition)) { - doc << ", " << Print(alloc->condition); - } - doc << ")" << Doc::NewLine() << PrintBody(alloc->body); - } - return doc; - } - } // merge attr with realize when possible if (op->node->IsInstance() && op->attr_key == "realize_scope" && op->body->IsInstance()) { @@ -681,8 +657,26 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) { } Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { - LOG(FATAL) << "TVM Script Printer Internal Error: All the Allocate should be folded with Attr"; - return Doc(); + var_not_in_headers.insert(op->buffer_var.get()); + Doc doc; + auto storage_scope = GetPtrStorageScope(op->buffer_var); + if (current_num_ != num_child_ - 1) { + doc << "with tir.allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype) << ", " + << Print(storage_scope); + if (!is_one(op->condition)) { + doc << ", " << Print(op->condition); + } + doc << ") as " << Print(op->buffer_var) << ":"; + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + } else { + doc << Print(op->buffer_var) << " = tir.allocate(" << Print(op->extents) << ", " + << PrintDType(op->dtype) << ", " << Print(storage_scope); + if (!is_one(op->condition)) { + doc << ", " << Print(op->condition); + } + doc << ")" << Doc::NewLine() << PrintBody(op->body); + } + return doc; } Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) { diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 4df38b9449ae..fd6ee27eb6be 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -625,8 +625,6 @@ class AOTExecutorCodegen : public ExprVisitor { // so we don't pay the price of allocation for every inference if (!allocated[sid]) { body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size}, tir::const_true(), body); - body = tir::AttrStmt(sids_table_[sid], tir::attr::storage_scope, tir::StringImm("global"), - body); } allocated[sid] = true; } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 99c9452975d4..8397044e8b93 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -861,12 +861,11 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) { this->PrintIndent(); int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; - const VarNode* buffer = op->buffer_var.as(); - auto it = alloc_storage_scope_.find(buffer); - if (it != alloc_storage_scope_.end()) { - std::string scope = alloc_storage_scope_.at(buffer); - PrintStorageScope(scope, stream); - } + + auto scope = GetPtrStorageScope(op->buffer_var); + alloc_storage_scope_[op->buffer_var.get()] = scope; + PrintStorageScope(scope, stream); + PrintType(op->dtype, stream); stream << ' ' << vid << '[' << constant_size << "];\n"; @@ -882,10 +881,6 @@ void CodeGenC::VisitStmt_(const AttrStmtNode* op) { BindThreadIndex(iv); } } - } else if (op->attr_key == tir::attr::storage_scope) { - const VarNode* v = op->node.as(); - ICHECK(v); - alloc_storage_scope_[v] = op->value.as()->value; } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); ICHECK(v); diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index f844090ca6f5..2ed5fd4029a2 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -225,12 +225,9 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = AttrStmt(res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body); if (!normal_red.empty()) { body = Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = - AttrStmt(normal_res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body); } } body = Substitute(body, value_map); diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index afd3c7add605..10d857bdc953 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -30,6 +30,8 @@ #include #include +#include "../transforms/ir_utils.h" + namespace tvm { namespace tir { @@ -58,11 +60,12 @@ class GPUCodeVerifier : public StmtExprVisitor { void VisitStmt_(const AllocateNode* op) final { StmtVisitor::VisitStmt_(op); + auto scope = GetPtrStorageScope(op->buffer_var); // visit an allocation of a buffer in shared memory, record its size - if (visited_local_buffers_.count(op->buffer_var.get()) != 0) { + if (scope == "local") { size_t size = static_cast(op->constant_allocation_size()); local_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); - } else if (visited_shared_buffers_.count(op->buffer_var.get()) != 0) { + } else if (scope == "shared") { size_t size = static_cast(op->constant_allocation_size()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } @@ -78,15 +81,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - std::string op_value = op->value.as()->value; - if (op_value == "local") { - visited_local_buffers_.insert(op->node.as()); - } else if (op_value == "shared") { - visited_shared_buffers_.insert(op->node.as()); - } - StmtVisitor::VisitStmt_(op); - } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { if (nest_level_ == 0) { // enter a new kernel, reset statistics Reset_(); @@ -211,8 +206,6 @@ class GPUCodeVerifier : public StmtExprVisitor { private: int nest_level_{0}; - std::unordered_set visited_local_buffers_; - std::unordered_set visited_shared_buffers_; std::unordered_set visited_threads_; size_t thread_x_extent_, thread_y_extent_, thread_z_extent_; @@ -230,8 +223,6 @@ class GPUCodeVerifier : public StmtExprVisitor { std::vector errors_; void Reset_() { - visited_local_buffers_.clear(); - visited_shared_buffers_.clear(); local_memory_per_block_ = 0; shared_memory_per_block_ = 0; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 42ef60bb86d7..6fdeb30ec100 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -61,16 +61,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { - if (attr_key == attr::storage_scope) { - const VarNode* buf = node.as(); - ICHECK(buf); - const auto* ptr_type = buf->type_annotation.as(); - ICHECK(ptr_type) << "The provided variable is not of pointer type"; - auto attr_scope = value.as()->value; - ICHECK_EQ(attr_scope, ptr_type->storage_scope) - << "Storage scopes attached to AttrStmt and buffer var are different. " << attr_scope - << ", " << ptr_type->storage_scope; - } auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 88c254a8cb5e..f1f914fa2f5c 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -130,7 +130,6 @@ class BufferFlattener : public StmtExprMutator { String storage_scope = buffer.scope(); PrimExpr area = BufferArea(buffer); body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body)); - body = AttrStmt(buffer->data, attr::storage_scope, StringImm(storage_scope), std::move(body)); return body; } diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index 40f0e368d93d..74538bcb6806 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -29,6 +29,7 @@ #include #include "../../arith/pattern_match.h" +#include "ir_utils.h" namespace tvm { namespace tir { @@ -42,10 +43,7 @@ class CopyIntrinInjector : public StmtMutator { flower_copy_fromto_(flower_copy_fromto) {} Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - storage_scope_[buf] = op->value.as()->value; - } else if (op->attr_key == pragma_key_) { + if (op->attr_key == pragma_key_) { Stmt ret; ICHECK(MatchCopyPattern(op->body, &ret)) << "Cannot match copy pattern of " << op->body; return ret; @@ -157,19 +155,12 @@ class CopyIntrinInjector : public StmtMutator { } // Get storage scope std::string GetStorageScope(const VarNode* var) const { - auto it = storage_scope_.find(var); - if (it != storage_scope_.end()) { - return it->second; - } else { - return ""; - } + return GetPtrStorageScope(GetRef(var)); } // pragma key std::string pragma_key_; // function to lower copy intrinsics. const PackedFunc& flower_copy_fromto_; - // Storage scope - std::unordered_map storage_scope_; // arith analyzer arith::Analyzer analyzer_; }; diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 7a16c06d8058..0b45bde28dfe 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -95,16 +95,7 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - auto it = dbuffer_info_.find(buf); - if (it != dbuffer_info_.end()) { - it->second.scope = op->value.as()->value; - return this->VisitStmt(op->body); - } else { - return StmtExprMutator::VisitStmt_(op); - } - } else if (op->attr_key == attr::double_buffer_scope) { + if (op->attr_key == attr::double_buffer_scope) { return MakeProducer(op); } else { return StmtExprMutator::VisitStmt_(op); @@ -112,8 +103,10 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const AllocateNode* op) final { - auto it = dbuffer_info_.find(op->buffer_var.get()); + const VarNode* buf = op->buffer_var.as(); + auto it = dbuffer_info_.find(buf); if (it != dbuffer_info_.end()) { + it->second.scope = GetPtrStorageScope(op->buffer_var); it->second.stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes(); @@ -125,8 +118,6 @@ class DoubleBufferInjector : public StmtExprMutator { } ICHECK(it->second.loop != nullptr); auto& alloc_nest = loop_allocs_[it->second.loop]; - alloc_nest.emplace_back( - AttrStmt(op->buffer_var, attr::storage_scope, StringImm(it->second.scope), Evaluate(0))); alloc_nest.emplace_back( Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0))); return op->body; diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index f7ece25d3fcd..b7348fe09fe2 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -172,16 +172,6 @@ class IRConvertSSA final : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { if (const VarNode* v = op->node.as()) { - if (op->attr_key == attr::storage_scope) { - const AllocateNode* alloc = op->body.as(); - if (alloc && op->node.same_as(alloc->buffer_var)) { - Stmt new_alloc = this->VisitStmt(op->body); - if (new_alloc.same_as(op->body)) return GetRef(op); - alloc = new_alloc.as(); - ICHECK(alloc); - return AttrStmt(alloc->buffer_var, op->attr_key, op->value, new_alloc); - } - } Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(v) && scope_[v].size() != 0) { diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index eafed837cee3..0893f02d7443 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -44,13 +44,13 @@ class StorageAccessInfoLower : public StmtExprMutator { // Lower allocate to device allocate when needed. Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - // For special memory, remove allocate, or use head expr - auto it = storage_info_.find(op->buffer_var.get()); - if (it != storage_info_.end() && it->second.info.defined()) { - const MemoryInfo& info = it->second.info; - ++it->second.alloc_count; - ICHECK_LE(it->second.alloc_count, 1) - << "Double allocation of " << it->second.scope.to_string(); + auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + ICHECK(storage_info_.find(op->buffer_var.get()) == storage_info_.end()) + << "Double allocation of " << scope.to_string(); + if (scope.tag.length() != 0) { + auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var)); + ICHECK(info.defined()) << "Cannot find memory info of " << scope.to_string(); + storage_info_[op->buffer_var.get()] = info; if (info->head_address.defined()) { return LetStmt(op->buffer_var, info->head_address, op->body); @@ -61,23 +61,6 @@ class StorageAccessInfoLower : public StmtExprMutator { return stmt; } } - Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - StorageScope scope = StorageScope::Create(op->value.as()->value); - StorageEntry e; - e.scope = scope; - if (scope.tag.length() != 0 && scope.tag != ".dyn") { - e.info = GetMemoryInfo(op->value.as()->value); - ICHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string(); - } - storage_info_[buf] = e; - return StmtExprMutator::VisitStmt_(op); - - } else { - return StmtExprMutator::VisitStmt_(op); - } - } PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_access_ptr())) { @@ -99,8 +82,8 @@ class StorageAccessInfoLower : public StmtExprMutator { Var buffer_var = Downcast(op->args[1]); PrimExpr offset = op->args[2]; auto it = storage_info_.find(buffer); - if (it != storage_info_.end() && it->second.info.defined()) { - return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, it->second.info); + if (it != storage_info_.end() && it->second.defined()) { + return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, it->second); } ICHECK(op->dtype.is_handle()); // Change to address_of @@ -118,17 +101,8 @@ class StorageAccessInfoLower : public StmtExprMutator { return cast(ptr_type, analyzer_.Simplify( offset / make_const(offset.dtype(), info->unit_bits / dtype_bits))); } - // The storage entry. - struct StorageEntry { - // Whether it is tagged memory. - StorageScope scope; - // The memory info if any. - MemoryInfo info; - // Allocation counter - int alloc_count{0}; - }; // The storage scope of each buffer - std::unordered_map storage_info_; + std::unordered_map storage_info_; // analyzer arith::Analyzer analyzer_; }; diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 25a2f4e060dd..481b1bfd4b19 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -53,8 +53,7 @@ class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScop // use volatile access to shared buffer. body = AttrStmt(remapped, attr::volatile_scope, 1, body); } - body = Allocate(remapped, op->dtype, op->extents, op->condition, body); - return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), body); + return Allocate(remapped, op->dtype, op->extents, op->condition, body); } return StmtExprMutator::VisitStmt_(op); } @@ -71,15 +70,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Stmt ret = StmtExprMutator::VisitStmt_(op); thread_extents_.pop_back(); return ret; - } else if (op->attr_key == attr::storage_scope) { - Stmt ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); - const VarNode* v = op->node.as(); - if (alloc_remap_.count(v)) { - return op->body; - } else { - return ret; - } } else if (op->attr_key == attr::reduce_scope) { const CommReducerNode* combiner = op->node.as(); ICHECK(combiner); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 060b02c3d137..8cc6d3f2541f 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -364,28 +364,15 @@ class WarpMemoryRewriter : private StmtMutator { Stmt VisitStmt_(const AllocateNode* op) { auto ret = StmtMutator::VisitStmt_(op); op = ret.as(); - if (warp_buffer_.count(op->buffer_var.get())) { + if (GetPtrStorageScope(op->buffer_var) == "warp") { + new_storage_scopes_[op->buffer_var.get()] = "local"; WarpAccessRewriter rewriter(warp_size_, &analyzer_); ret = rewriter.Rewrite(op); } return ret; } - Stmt VisitStmt_(const AttrStmtNode* op) { - using runtime::StorageScope; - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - StorageScope scope = StorageScope::Create(op->value.as()->value); - if (scope.rank == runtime::StorageRank::kWarp) { - warp_buffer_.insert(buf); - new_storage_scopes_[buf] = "local"; - } - } - return StmtMutator::VisitStmt_(op); - } - int warp_size_{0}; - std::unordered_set warp_buffer_; arith::Analyzer analyzer_; // variable domain std::unordered_map var_dom_; diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 5de22fe8665d..38b3a77b1a0c 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -224,7 +224,6 @@ class StorageFlattener : public StmtExprMutator { ret = Allocate(e.buffer->data, storage_type, shape, make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } - ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(skey.to_string()), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index b216b8b848db..3a2990c928c7 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -398,10 +398,8 @@ class StoragePlanRewriter : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - return this->VisitStmt(op->body); - } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || - attr::IsPragmaKey(op->attr_key)) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || + attr::IsPragmaKey(op->attr_key)) { // remake all the allocation at the attach scope. if (attach_map_.count(op)) { auto& svec = attach_map_[op]; @@ -485,8 +483,6 @@ class StoragePlanRewriter : public StmtExprMutator { std::vector nest; for (StorageEntry* e : svec) { if (e->new_alloc.defined()) { - nest.emplace_back(AttrStmt(e->alloc_var, attr::storage_scope, - StringImm(e->scope.to_string()), Evaluate(0))); nest.push_back(e->new_alloc); } } diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index d0f58074ada0..1836b8ecec0d 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -69,7 +69,7 @@ class FragmentGetter : public StmtExprVisitor { ICHECK(k); ICHECK(layout); - std::string scope = scopes[buffer_var]; + std::string scope = GetPtrStorageScope(GetRef(buffer_var)); if (fragments.count(buffer_var)) { // check if the fragment has met before FragmentInfo info = fragments[buffer_var]; @@ -102,7 +102,7 @@ class FragmentGetter : public StmtExprVisitor { ICHECK(n); ICHECK(k); - std::string scope = scopes[buffer_var]; + std::string scope = GetPtrStorageScope(GetRef(buffer_var)); // Only wmma.accumulator can use tvm_fill_fragment ICHECK_EQ(scope, "wmma.accumulator"); if (fragments.count(buffer_var)) { @@ -119,16 +119,9 @@ class FragmentGetter : public StmtExprVisitor { // Get memory scope void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buffer = op->node.as(); - ICHECK(buffer); - scopes[buffer] = op->value.as()->value; - } StmtExprVisitor::VisitStmt_(op); } - // Memory scope for allocations - std::unordered_map scopes; // Fragment metadata for all fragments std::unordered_map fragments; }; diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc index 0ae02fec9f95..4143577a0b17 100644 --- a/src/tir/transforms/update_pointer_storage_scope.cc +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -64,17 +64,6 @@ PrimExpr UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) { StmtExprMutator::VisitExpr(op->predicate)); } -Stmt UpdatePointerStorageScope::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - auto remapped = Downcast(StmtExprMutator::VisitExpr(GetRef(buf))); - auto new_scope = GetPtrStorageScope(remapped); - return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), - StmtMutator::VisitStmt(op->body)); - } - return StmtMutator::VisitStmt_(op); -} - Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), diff --git a/src/tir/transforms/update_pointer_storage_scope.h b/src/tir/transforms/update_pointer_storage_scope.h index 481536a45b27..f310194a4a51 100644 --- a/src/tir/transforms/update_pointer_storage_scope.h +++ b/src/tir/transforms/update_pointer_storage_scope.h @@ -40,7 +40,6 @@ class UpdatePointerStorageScope : public StmtExprMutator { virtual PrimExpr VisitExpr_(const VarNode*); virtual PrimExpr VisitExpr_(const LoadNode*); - virtual Stmt VisitStmt_(const AttrStmtNode*); virtual Stmt VisitStmt_(const AllocateNode*); virtual Stmt VisitStmt_(const StoreNode*); diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 0329134bb3fa..5b123e883849 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -31,8 +31,6 @@ def test_for(): A[j] = A[j] + 2 body = ib.get() - assert isinstance(body, tvm.tir.AttrStmt) - body = body.body assert isinstance(body, tvm.tir.Allocate) body = body.body assert isinstance(body, tvm.tir.For) diff --git a/tests/python/unittest/test_tir_transform_coproc_sync.py b/tests/python/unittest/test_tir_transform_coproc_sync.py index 2d45118f39f2..7dacd8e046cc 100644 --- a/tests/python/unittest/test_tir_transform_coproc_sync.py +++ b/tests/python/unittest/test_tir_transform_coproc_sync.py @@ -51,7 +51,7 @@ def meminfo_cache(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body - body = stmt.body.body.body + body = stmt.body.body blist = tvm.tir.stmt_list(body) assert blist[1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_read_barrier")) @@ -112,7 +112,7 @@ def __check_list(tvm_array, py_list): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body - slist = tvm.tir.stmt_list(stmt[0].body.body) + slist = tvm.tir.stmt_list(stmt[0].body) push_st = slist[2] slist = tvm.tir.stmt_list(slist[-1]) pop_st = slist[0].body[0] diff --git a/tests/python/unittest/test_tir_transform_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py index ceb32c484c6d..9b37bcaaacbc 100644 --- a/tests/python/unittest/test_tir_transform_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py @@ -47,8 +47,8 @@ def test_double_buffer(): mod = opt(mod) stmt = mod["db"].body - assert isinstance(stmt.body.body, tvm.tir.Allocate) - assert stmt.body.body.extents[0].value == 2 + assert isinstance(stmt.body, tvm.tir.Allocate) + assert stmt.body.extents[0].value == 2 f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index 3e7a5a0cb300..673267a9b1fa 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -49,13 +49,13 @@ def get_vthread(name): stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("vthread"))) - )["main"].body + )["main"] assert stmt.body.body.extents[0].value == 2 stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) - )["main"].body + )["main"] assert len(stmt.body.body.extents) == 3 @@ -94,11 +94,11 @@ def get_vthread(name): stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) - )["main"].body + )["main"] assert stmt.body.body.extents[0].value == 2 - assert stmt.body.body.body.body.body.body.extents[0].value == 2 - assert len(stmt.body.body.body.body.body.body.extents) == 3 + assert stmt.body.body.body.body.extents[0].value == 2 + assert len(stmt.body.body.body.body.extents) == 3 def test_vthread_if_then_else(): @@ -119,7 +119,7 @@ def test_vthread_if_then_else(): stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) - )["main"].body + )["main"] assert stmt.body.body.body[0].else_case != None assert stmt.body.body.body[1].else_case == None diff --git a/tests/python/unittest/test_tir_transform_lift_attr_scope.py b/tests/python/unittest/test_tir_transform_lift_attr_scope.py index 12ad16dfe092..65e317dfbcb8 100644 --- a/tests/python/unittest/test_tir_transform_lift_attr_scope.py +++ b/tests/python/unittest/test_tir_transform_lift_attr_scope.py @@ -38,7 +38,7 @@ def test_coproc_lift(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body + body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"] assert body.body.body.node == cp @@ -58,7 +58,7 @@ def test_coproc_lift(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body + body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"] assert body.body.body.body[1].node == cp assert len(body.body.body.body) == 2 diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 6194024748e0..c632f744bb81 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -178,7 +178,7 @@ def test_vectorize(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].vectorize(x) - stmt = tvm.lower(s, [A, B], name="main")["main"].body + stmt = tvm.lower(s, [A, B], name="main")["main"] body = stmt.body.body.body.body assert x.var.name not in str(body.condition) assert any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))) @@ -229,7 +229,7 @@ def test_thread_axis2(): _, x = s[C].split(x, factor=m) s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) - stmt = tvm.lower(s, [A, B], name="main")["main"].body + stmt = tvm.lower(s, [A, B], name="main")["main"] for_body = stmt.body.body.body.body[0] assert "threadIdx" not in str(for_body.extent) diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index f3baff120cf6..84bf0c4d52fd 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -47,8 +47,9 @@ def test_lower_warp_memory_local_scope(): fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"] mod = tvm.IRModule.from_expr(fdevice) fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] - assert fdevice.body.body.value.value == "local" - assert fdevice.body.body.body.extents[0].value == 2 + allocate = fdevice.body.body + assert allocate.buffer_var.type_annotation.storage_scope == "local" + assert fdevice.body.body.extents[0].value == 2 @tvm.testing.requires_cuda diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 2d1fea01aa32..0e9ab862a9c8 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -79,7 +79,7 @@ def test_flatten_storage_align(): )(mod) stmt = mod["main"].body - assert stmt.body.extents[0].value == 17 * 8 + assert stmt.extents[0].value == 17 * 8 def test_flatten_double_buffer(): @@ -114,8 +114,8 @@ def test_flatten_double_buffer(): )(mod) stmt = mod["main"].body - assert isinstance(stmt.body.body, tvm.tir.Allocate) - assert stmt.body.body.extents[0].value == 2 + assert isinstance(stmt.body, tvm.tir.Allocate) + assert stmt.body.extents[0].value == 2 mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C], stmt).with_attr("global_symbol", "db")) f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 70e77ff69fea..9e738b136b17 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -298,9 +298,9 @@ def test_storage_share_gpu(): alloc_stats = {"global": 0, "shared": 0} def verify(n): - if isinstance(n, tvm.tir.AttrStmt): - if n.attr_key == "storage_scope": - alloc_stats[n.value.value] += 1 + if isinstance(n, tvm.tir.Allocate): + scope = n.buffer_var.type_annotation.storage_scope + alloc_stats[scope] += 1 tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert alloc_stats["global"] == 2 @@ -317,7 +317,7 @@ def test_parallel_alloc(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + body = tvm.tir.transform.StorageRewrite()(mod)["main"] assert isinstance(body.body.body, tvm.tir.Allocate) @@ -334,7 +334,7 @@ def test_parallel_alloc(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + body = tvm.tir.transform.StorageRewrite()(mod)["main"] assert isinstance(body.body.body.body.body, tvm.tir.Allocate) @@ -356,7 +356,6 @@ def get_mod(kind="serial"): mod = get_mod(kind="parallel") # parallel (i, 0, n) { - # // attr [j] storage_scope = "global" # allocate j[int32 * 1] # j[0] = 0 # while((j[0] < 10)){ @@ -366,11 +365,9 @@ def get_mod(kind="serial"): # j[0] = (j[0] + (j[0] + 1)) # } # } - body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + body = tvm.tir.transform.StorageRewrite()(mod)["main"] # parallel (i, 0, n) { - # // attr [j] storage_scope = "global" # allocate j[int32 * 1] - # // attr [A] storage_scope = "global" # allocate A[float32 * n] # j[0] = 0 # while((j[0] < 10)){ @@ -379,11 +376,10 @@ def get_mod(kind="serial"): # } # } assert isinstance(body.body.body, tvm.tir.Allocate) # j - assert isinstance(body.body.body.body.body, tvm.tir.Allocate) # A + assert isinstance(body.body.body.body, tvm.tir.Allocate) # A mod = get_mod(kind="serial") # for (i, 0, n) { - # // attr [j] storage_scope = "global" # allocate j[int32 * 1] # j[0] = 0 # while((j[0] < 10)){ @@ -393,10 +389,8 @@ def get_mod(kind="serial"): # j[0] = (j[0] + (j[0] + 1)) # } # } - body = tvm.tir.transform.StorageRewrite()(mod)["main"].body - # // attr [j] storage_scope = "global" + body = tvm.tir.transform.StorageRewrite()(mod)["main"] # allocate j[int32 * 1] - # // attr [A] storage_scope = "global" # allocate A[float32 * n] # for (i, 0, n) { # j[0] = 0 @@ -406,7 +400,7 @@ def get_mod(kind="serial"): # } # } assert isinstance(body.body, tvm.tir.Allocate) # j - assert isinstance(body.body.body.body, tvm.tir.Allocate) # A + assert isinstance(body.body.body, tvm.tir.Allocate) # A def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024): diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 030c01713927..7fff6a804e4a 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -52,7 +52,7 @@ def test_thread_storage_sync(): mod = tvm.IRModule.from_expr(fdevice) cuda_target = tvm.target.Target("cuda") f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"] - body_list = tvm.tir.stmt_list(f.body.body.body.body) + body_list = tvm.tir.stmt_list(f.body.body.body) assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")) From 05cd4ebe0bfba5a7de338f7148ac568756918431 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 21 Jul 2021 10:03:03 +0900 Subject: [PATCH 2/7] pyformat --- python/tvm/contrib/hexagon.py | 1 - python/tvm/script/scope_handler.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/contrib/hexagon.py b/python/tvm/contrib/hexagon.py index c2197af22d2a..6438c881b7c9 100644 --- a/python/tvm/contrib/hexagon.py +++ b/python/tvm/contrib/hexagon.py @@ -185,7 +185,6 @@ def visit(stmt): if scope == "local.vtcm": vtcm_buffers.append(stmt.node) - def mutate(stmt): """Insert calls to VTCM allocation and deallocation routines.""" if isinstance(stmt, tvm.tir.AttrStmt): diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index 971580343763..bb408f6cdc8f 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -110,10 +110,9 @@ def __init__(self): def allocate(extents, dtype, scope, condition=True, span=None): condition = tvm.runtime.convert(condition) scope = tvm.runtime.convert(scope) - body = tvm.tir.Allocate( + return tvm.tir.Allocate( self.buffer_var, dtype, extents, condition, self.body, span=span ) - return body super().__init__(allocate, concise_scope=True, def_symbol=True) self.buffer_var = None From 5a00d9690b37b2f781e4c989827c2985e831da1e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 21 Jul 2021 16:45:30 +0900 Subject: [PATCH 3/7] fixed VTA tests --- .../transforms/lower_device_storage_access_info.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 0893f02d7443..7443a2449fb2 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -41,24 +41,24 @@ using runtime::StorageScope; class StorageAccessInfoLower : public StmtExprMutator { public: Stmt VisitStmt_(const AllocateNode* op) final { - // Lower allocate to device allocate when needed. - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - ICHECK(storage_info_.find(op->buffer_var.get()) == storage_info_.end()) - << "Double allocation of " << scope.to_string(); if (scope.tag.length() != 0) { auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var)); ICHECK(info.defined()) << "Cannot find memory info of " << scope.to_string(); + ICHECK(storage_info_.find(op->buffer_var.get()) == storage_info_.end()) + << "Double allocation of " << scope.to_string(); storage_info_[op->buffer_var.get()] = info; + // Lower allocate to device allocate when needed. + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); if (info->head_address.defined()) { return LetStmt(op->buffer_var, info->head_address, op->body); } else { return op->body; } } else { - return stmt; + return StmtExprMutator::VisitStmt_(op); } } From 86772739705066bcf9872174ba61513640404a06 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 21 Jul 2021 17:21:39 +0900 Subject: [PATCH 4/7] Update TIR text printer to print storage_scope on allocate --- python/tvm/contrib/hexagon.py | 2 +- src/printer/tir_text_printer.cc | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/hexagon.py b/python/tvm/contrib/hexagon.py index 6438c881b7c9..6364ef749dd9 100644 --- a/python/tvm/contrib/hexagon.py +++ b/python/tvm/contrib/hexagon.py @@ -183,7 +183,7 @@ def visit(stmt): elif isinstance(stmt, tvm.tir.Allocate): scope = stmt.buffer_var.type_annotation.storage_scope if scope == "local.vtcm": - vtcm_buffers.append(stmt.node) + vtcm_buffers.append(stmt.buffer_var) def mutate(stmt): """Insert calls to VTCM allocation and deallocation routines.""" diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 0f3b89932b68..0692fec27872 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -35,6 +35,7 @@ #include #include +#include "../tir/transforms/ir_utils.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" @@ -447,8 +448,9 @@ Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; + auto scope = GetPtrStorageScope(op->buffer_var); doc << "allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", " - << Print(op->extents) << ")"; + << Print(op->extents) << "), storage_scope = " << scope; if (!is_one(op->condition)) { doc << " if " << Print(op->condition); } From a29019506bd8cb3efbb3836643635e209304a4b0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 22 Jul 2021 10:49:00 +0900 Subject: [PATCH 5/7] print storage scope in AllocateNode ReprPrinter --- src/tir/ir/stmt.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 6fdeb30ec100..9a20f3ec9358 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -360,13 +360,15 @@ TVM_REGISTER_NODE_TYPE(AllocateNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); + const auto* ptr_type = op->buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; p->PrintIndent(); p->stream << "allocate " << op->buffer_var << "[" << op->dtype; for (size_t i = 0; i < op->extents.size(); ++i) { p->stream << " * "; p->Print(op->extents[i]); } - p->stream << "]"; + p->stream << "], storage_scope = " << ptr_type->storage_scope; if (!is_one(op->condition)) { p->stream << " if "; p->Print(op->condition); From 9398d96f2d9d3c63dfd81fc37195092d727a8bac Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 27 Jul 2021 23:43:55 +0900 Subject: [PATCH 6/7] Fixed accidently removed scope tag check --- src/tir/transforms/lower_device_storage_access_info.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 7443a2449fb2..b4ec91ba5012 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -42,7 +42,7 @@ class StorageAccessInfoLower : public StmtExprMutator { public: Stmt VisitStmt_(const AllocateNode* op) final { auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - if (scope.tag.length() != 0) { + if (scope.tag.length() != 0 && scope.tag != ".dyn") { auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var)); ICHECK(info.defined()) << "Cannot find memory info of " << scope.to_string(); ICHECK(storage_info_.find(op->buffer_var.get()) == storage_info_.end()) From 64e1beece833e83c3f1f0fd6347e4343ac63110d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 28 Jul 2021 06:18:26 +0900 Subject: [PATCH 7/7] remove unused function --- src/tir/transforms/inject_copy_intrin.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index 74538bcb6806..f99cbd5b5a05 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -153,10 +153,7 @@ class CopyIntrinInjector : public StmtMutator { ICHECK(out->defined()) << "flower function did not return correct stmt"; return true; } - // Get storage scope - std::string GetStorageScope(const VarNode* var) const { - return GetPtrStorageScope(GetRef(var)); - } + // pragma key std::string pragma_key_; // function to lower copy intrinsics.