From cfa991adc7605d793fdfdefed2b619f851a19f78 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Tue, 25 May 2021 16:28:35 +0100 Subject: [PATCH] Address comments Change-Id: Id69ae51bad05b85e95e35132ea43434a17c7a89e --- src/relay/backend/aot_executor_codegen.cc | 133 +++++++++++----------- 1 file changed, 69 insertions(+), 64 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index b4578c93f501d..6fe0ed8f05c31 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -46,22 +46,22 @@ namespace relay { namespace backend { /** - * Struct to contain information about intermediate variables in the + * Struct to contain information about the intermediate tensors in the * runner function */ struct StorageInfo { - /*! \brief unique integer identifier of the particular intermediate variable */ - std::vector ids; + /*! \brief storage integer identifier of the particular intermediate buffer */ + int sid; /*! \brief exact size of the temporary */ - std::vector sizes_bytes; - /*! \brief device type of the temporary variable */ - std::vector dev_types; + int size_bytes; + /*! \brief device type of the intermediate tensor */ + int dev_type; }; using IntegerArray = Array; using TargetsMap = std::unordered_map; -using StorageMap = - std::unordered_map; +using StorageMap = std::unordered_map, runtime::ObjectPtrHash, + runtime::ObjectPtrEqual>; /** * This is an on demand allocator for AOT. A new temporary @@ -74,10 +74,10 @@ class AOTOnDemandAllocator : public ExprVisitor { node_device_map_ = CollectDeviceInfo(func); for (Expr param : func->params) { - CreateSid(param.operator->()); + CreateStorage(param.operator->()); } - GetSid(func->body); + GetStorage(func->body); } std::vector GetReturnIds() const { return return_ids_; } @@ -85,15 +85,15 @@ class AOTOnDemandAllocator : public ExprVisitor { StorageMap GetStorageMap() const { return storage_device_map_; } void VisitExpr_(const ConstantNode* op) final { - CreateSid(op); + CreateStorage(op); AssignReturnSid(GetRef(op)); } void VisitExpr_(const CallNode* op) final { // create token for the call node. - CreateSid(op); + CreateStorage(op); for (Expr arg : op->args) { - GetSid(arg); + GetStorage(arg); } AssignReturnSid(GetRef(op)); } @@ -116,28 +116,22 @@ class AOTOnDemandAllocator : public ExprVisitor { } void VisitExpr_(const TupleNode* op) final { - StorageInfo field_sid; + std::vector field_sids; Expr expr = GetRef(op); for (Expr field : op->fields) { - auto sid = GetSid(field); - field_sid.ids.insert(field_sid.ids.end(), sid.ids.begin(), sid.ids.end()); - field_sid.dev_types.insert(field_sid.dev_types.end(), sid.dev_types.begin(), - sid.dev_types.end()); - field_sid.sizes_bytes.insert(field_sid.sizes_bytes.end(), sid.sizes_bytes.begin(), - sid.sizes_bytes.end()); + auto sid = GetStorage(field); + field_sids.insert(field_sids.end(), sid.begin(), sid.end()); } - storage_device_map_[expr] = field_sid; + storage_device_map_[expr] = field_sids; AssignReturnSid(expr); } void VisitExpr_(const TupleGetItemNode* op) final { Expr expr = GetRef(op); - const auto& sid = GetSid(op->tuple); - ICHECK_LT(static_cast(op->index), sid.ids.size()); - storage_device_map_[expr].ids = {sid.ids[op->index]}; - storage_device_map_[expr].sizes_bytes = {sid.sizes_bytes[op->index]}; - storage_device_map_[expr].dev_types = {sid.dev_types[op->index]}; + const auto& sids = GetStorage(op->tuple); + ICHECK_LT(static_cast(op->index), sids.size()); + storage_device_map_[expr] = {sids[op->index]}; AssignReturnSid(expr); } @@ -147,9 +141,13 @@ class AOTOnDemandAllocator : public ExprVisitor { private: void AssignReturnSid(Expr e) { - auto iter = storage_device_map_.find(e); - if (iter != storage_device_map_.end()) { - return_ids_ = (*iter).second.ids; + if (storage_device_map_.find(e) != storage_device_map_.end()) { + auto buffers = storage_device_map_[e]; + std::vector return_ids; + for (auto buffer : buffers) { + return_ids.push_back(buffer.sid); + } + return_ids_ = return_ids; } } /*! @@ -178,37 +176,45 @@ class AOTOnDemandAllocator : public ExprVisitor { return size; } /*! - * \brief Get the necessary token. + * \brief Get the necessary storage for the expression. * \param expr The expression. * \return The corresponding token. */ - StorageInfo GetSid(const Expr& expr) { + std::vector GetStorage(const Expr& expr) { this->VisitExpr(expr); auto it = storage_device_map_.find(expr); ICHECK(it != storage_device_map_.end()); return it->second; } - void CreateSid(const ExprNode* op) { - StorageInfo sid; + /*! + * \brief Create storage for the expression. + * \param expr The expression. + */ + void CreateStorage(const ExprNode* op) { + std::vector buffers; Expr expr = GetRef(op); int device_type = node_device_map_.count(GetRef(op)) ? node_device_map_[expr]->value : 0; if (const auto* tuple_type = op->checked_type().as()) { for (Type t : tuple_type->fields) { const auto* ttype = t.as(); ICHECK(ttype); - sid.ids.push_back(sid_++); - sid.dev_types.push_back(device_type); - sid.sizes_bytes.push_back(GetMemorySize(ttype)); + StorageInfo buffer; + buffer.sid = sid_++; + buffer.size_bytes = GetMemorySize(ttype); + buffer.dev_type = device_type; + buffers.push_back(buffer); } } else { const auto* ttype = op->checked_type().as(); ICHECK(ttype); - sid.ids.push_back(sid_++); - sid.dev_types.push_back(device_type); - sid.sizes_bytes.push_back(GetMemorySize(ttype)); + StorageInfo buffer; + buffer.sid = sid_++; + buffer.size_bytes = GetMemorySize(ttype); + buffer.dev_type = device_type; + buffers.push_back(buffer); } - storage_device_map_[expr] = sid; + storage_device_map_[expr] = buffers; } /*! \brief mapping of expression -> storageInfo*/ StorageMap storage_device_map_; @@ -216,7 +222,7 @@ class AOTOnDemandAllocator : public ExprVisitor { Map node_device_map_; /*! \brief current id of the temporary allocated*/ int sid_{0}; - /*! \brief the set of identifiers that are return variables */ + /*! \brief the set of intermediate tensors that are return variables */ std::vector return_ids_; }; @@ -248,18 +254,18 @@ class AOTExecutorCodegen : public ExprVisitor { * \brief Return a vector of variables that represents the sids for the given Relay Expr */ std::vector PackSid(Expr expr) { - auto sids = storage_device_map_[expr]; - std::vector sid_vars; + auto buffers = storage_device_map_[expr]; + std::vector buffer_vars; // Note that an expression can have multiple sids associated with it // e.g., returning multiple values from a function - for (const auto& sid : sids.ids) { + for (const auto& buffer : buffers) { // Determine if an sid is an output buffer - int sid_int = sid; - auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid_int); + int sid = buffer.sid; + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid); if (output_iter != return_sid_.end()) { int output_index = std::distance(return_sid_.begin(), output_iter); - sid_vars.push_back(main_signature_[input_vars_.size() + output_index]); + buffer_vars.push_back(main_signature_[input_vars_.size() + output_index]); continue; } // Pack the sid inside the TVMValue @@ -269,9 +275,9 @@ class AOTExecutorCodegen : public ExprVisitor { tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), {sid_array, 0, tir::builtin::kArrData, sid_value}); stmts_.push_back(tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor))); - sid_vars.push_back(sid_array); + buffer_vars.push_back(sid_array); } - return sid_vars; + return buffer_vars; } /*! @@ -518,8 +524,7 @@ class AOTExecutorCodegen : public ExprVisitor { } ICHECK_GE(storage_device_map_.count(expr), 0); - auto& device_type = storage_device_map_[expr].dev_types; - auto call_dev_type = device_type[0]; + auto call_dev_type = storage_device_map_[expr][0].dev_type; // Normal Relay Function if (targets_.size() == 1) { // homogeneous execution. @@ -556,14 +561,14 @@ class AOTExecutorCodegen : public ExprVisitor { // If the Var node is an output node we need to copy the content of the variable to the output // It's safe to check the SID here because Var StorageToken are never reallocated - auto sids = storage_device_map_[expr]; + auto buffers = storage_device_map_[expr]; - auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sids.ids[0]); + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), buffers[0].sid); if (output_iter != return_sid_.end()) { int output_index = std::distance(return_sid_.begin(), output_iter); auto var_expr = FindExpr(expr); CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0], - sids.sizes_bytes[0]); + buffers[0].size_bytes); } } @@ -572,18 +577,18 @@ class AOTExecutorCodegen : public ExprVisitor { size_t index = params_.size(); std::string name = "p" + std::to_string(index); - param_storage_ids_[name] = storage_device_map_[expr].ids[0]; + param_storage_ids_[name] = storage_device_map_[expr][0].sid; params_[name] = op->data; params_by_expr_.Set(expr, name); // If the Constant node is an output node we need to copy the content of the parameter to the // output A Var node can only produce a single output - auto sids = storage_device_map_[expr]; - auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sids.ids[0]); + auto buffers = storage_device_map_[expr]; + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), buffers[0].sid); if (output_iter != return_sid_.end()) { int output_index = std::distance(return_sid_.begin(), output_iter); CopyToOutput(main_signature_[input_vars_.size() + output_index], PackParam(expr), - sids.sizes_bytes[0]); + buffers[0].size_bytes); } } @@ -639,9 +644,9 @@ class AOTExecutorCodegen : public ExprVisitor { continue; } - for (unsigned int i = 0; i < kv.second.ids.size(); i++) { - int size = kv.second.sizes_bytes[i]; - int sid = kv.second.ids[i]; + for (unsigned int i = 0; i < kv.second.size(); i++) { + int size = kv.second[i].size_bytes; + int sid = kv.second[i].sid; if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) { continue; @@ -733,9 +738,9 @@ class AOTExecutorCodegen : public ExprVisitor { // Define the storage allocator ids for (auto kv : storage_device_map_) { - for (const auto& sid : kv.second.ids) { - te::Var sid_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8)))); - sids_table_[sid] = sid_var; + for (const auto& buffer : kv.second) { + te::Var buffer_var(MakeString("sid_", buffer.sid), PointerType(PrimType(DataType::Int(8)))); + sids_table_[buffer.sid] = buffer_var; } }