From ed684e5a43e842d4a9924f2052b9753304a39e9c Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 28 Apr 2020 18:10:55 +0000 Subject: [PATCH 1/5] non recursive partitioning --- src/relay/transforms/partition_graph.cc | 210 +++++------------------- 1 file changed, 41 insertions(+), 169 deletions(-) diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 3b0d6bc6f8e0..12b425e02942 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -54,41 +54,6 @@ namespace partitioning { static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); -/*! - * \brief The checker that verifies if a Relay program is annotated correctly - * for partitioning. - */ -class AnnotationChecker : public ExprVisitor { - public: - bool Check() { - if (!found_start_ && !found_end_) { - LOG(WARNING) << "No compiler annotation found"; - } else if (!found_start_) { - LOG(ERROR) << "compiler_begin annotation is missing"; - return false; - } else if (!found_end_) { - LOG(ERROR) << "compiler_end annotation is missing"; - return false; - } - return true; - } - - void VisitExpr_(const CallNode* call) final { - auto op_node = call->op.as(); - if (op_node == nullptr || call->attrs.as() == nullptr) { - return; - } else if (call->op == compiler_begin_op) { - found_start_ = true; - } else if (call->op == compiler_end_op) { - found_end_ = true; - } - } - - private: - bool found_start_{false}; - bool found_end_{false}; -}; - /*! \brief This class partitions the expr labeled with begin and end annotations * into function containing multiple regions. Each region is labeled with * a compiler attribute so that it will be handled by any compilers that are not @@ -124,7 +89,7 @@ class AnnotationChecker : public ExprVisitor { * the compiler name. 
*/ -class Partitioner : public ExprMutator { +class Partitioner : public MixedModeMutator { public: explicit Partitioner(const IRModule& module) : module_(module) { for (auto f : module->functions) { @@ -138,10 +103,10 @@ class Partitioner : public ExprMutator { } } - Expr VisitExpr_(const CallNode* call) final { + Expr Rewrite_(const CallNode* call, const Expr& post) final { auto op_node = call->op.as(); if (op_node == nullptr || call->attrs.as() == nullptr) { - return ExprMutator::VisitExpr_(call); + return post; } else if (call->op == compiler_begin_op) { // The annotation node is inserted on edge so it must have only one // argument. @@ -149,7 +114,7 @@ class Partitioner : public ExprMutator { // Traverse the rest graph. Expr parent = call->args[0]; - auto input_expr = VisitExpr(parent); + auto input_expr = Downcast(post)->args[0]; // Backtrace the parent to find the first ancestor node that is not a begin or end op while (const auto* parent_call = parent.as()) { @@ -177,9 +142,9 @@ class Partitioner : public ExprMutator { std::pair cand = std::make_pair(var, input_expr); - if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) == - region_args[sg].end()) { - region_args[sg].push_back(cand); + if (std::find(region_args_[sg].begin(), region_args_[sg].end(), cand) == + region_args_[sg].end()) { + region_args_[sg].push_back(cand); } shared_output_[parent][sg] = var; return std::move(var); @@ -197,15 +162,15 @@ class Partitioner : public ExprMutator { BaseFunc f = GetFunc(GetRef(call)); // Traverse subgraph inputs. - auto input = VisitExpr(call->args[0]); + auto input = Downcast(post)->args[0]; CHECK(region.defined()) << "Region not defined for " << GetRef(call); // functions are created for each annotated regions, // when their first output is encountered. // If multiple outputs are there, a tuple node is inserted at the end. 
- // region_function_calls is map that maintains + // region_function_calls_ is map that maintains // (each annotated regions) --> created function - if (region_function_calls.find(region) == region_function_calls.end()) { + if (region_function_calls_.find(region) == region_function_calls_.end()) { // First time this region is encountered in the traversal. // Creating the function. CreateFunction(region, call); @@ -215,99 +180,6 @@ class Partitioner : public ExprMutator { } } - Expr VisitExpr_(const TupleNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Array fields; - for (auto field : op->fields) { - fields.push_back(VisitExpr(field)); - } - return Tuple(fields); - } - } - - Expr VisitExpr_(const TupleGetItemNode* g) final { - auto region = GetRegion(GetRef(g)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(g); - } else { - auto t = VisitExpr(g->tuple); - return TupleGetItem(t, g->index); - } - } - - Expr VisitExpr_(const FunctionNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Array params; - for (auto param : op->params) { - Var new_param = Downcast(VisitExpr(param)); - params.push_back(new_param); - } - auto body = VisitExpr(op->body); - return Function(params, body, op->ret_type, op->type_params, op->attrs); - } - } - - Expr VisitExpr_(const LetNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Var var = Downcast(VisitExpr(op->var)); - auto value = VisitExpr(op->value); - auto body = VisitExpr(op->body); - return Let(var, value, body); - } - } - - Expr VisitExpr_(const IfNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - auto guard = VisitExpr(op->cond); - auto true_b = VisitExpr(op->true_branch); - auto false_b 
= VisitExpr(op->false_branch); - return If(guard, true_b, false_b); - } - } - - Expr VisitExpr_(const RefCreateNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Expr value = VisitExpr(op->value); - return RefCreate(value); - } - } - - Expr VisitExpr_(const RefReadNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Expr ref = VisitExpr(op->ref); - return RefRead(ref); - } - } - - Expr VisitExpr_(const RefWriteNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Expr ref = VisitExpr(op->ref); - Expr value = VisitExpr(op->value); - return RefWrite(ref, value); - } - } - IRModule Partition() { auto glob_funcs = module_->functions; for (const auto& pair : glob_funcs) { @@ -384,7 +256,7 @@ class Partitioner : public ExprMutator { // Don't duplicate outputs. if (!region_return_indices_.count(region) || !region_return_indices_[region].count(ret_node)) { - auto ret_expr = VisitExpr(ret_node); + auto ret_expr = MixedModeMutator::VisitExpr(ret_node); fields.push_back(ret_expr); region_return_indices_[region][ret_node] = i; i++; @@ -409,7 +281,7 @@ class Partitioner : public ExprMutator { return false; }; - for (auto pair : region_args[region]) { + for (auto pair : region_args_[region]) { params.push_back(pair.first); if (IsConstant(pair.second)) { params_bind.Set(pair.first, pair.second); @@ -459,7 +331,7 @@ class Partitioner : public ExprMutator { // The return type of callnode is the same as the type of the // compiler_end node. auto ret = Call(glob_func, param_expr); - region_function_calls[region] = ret; + region_function_calls_[region] = ret; } /*! @@ -471,7 +343,7 @@ class Partitioner : public ExprMutator { Expr arg = Downcast(end_arg)->args[0]; // Function has one output. 
if (region_return_indices_[region].size() == 1) { - return region_function_calls[region]; + return region_function_calls_[region]; } // Function has multiple outputs. // Use already made TupleGetItem. @@ -484,7 +356,7 @@ class Partitioner : public ExprMutator { region_return_indices_[region].count(arg)); int index = region_return_indices_[region][arg]; - auto func_call = region_function_calls[region]; + auto func_call = region_function_calls_[region]; auto tuple_get_item_ = TupleGetItem(func_call, index); tuple_get_item_->checked_type_ = arg->checked_type_; region_return_tuplegetitem_[region][arg] = tuple_get_item_; @@ -496,15 +368,15 @@ class Partitioner : public ExprMutator { * This is required in the multi-output scenario, to link rest of the outputs * to call */ - std::unordered_map region_function_calls; + std::unordered_map region_function_calls_; /*! * \brief This map maintains arguments (of region) visits through visitor - * patterns. Those arguement var and expression will be used to when creating + * patterns. Those argument var and expression will be used to when creating * the function. */ std::unordered_map>, ObjectHash, ObjectEqual> - region_args; + region_args_; /*! 
* \brief This map maintains the index of an output in the subgraph function @@ -538,34 +410,34 @@ class Partitioner : public ExprMutator { IRModule module_; }; -class DefaultRemover : public ExprMutator { - public: - explicit DefaultRemover(const IRModule& module) : module_(module) {} +IRModule Remove(IRModule module) { + class DefaultRemover : public ExprRewriter { + public: + DefaultRemover() = default; - IRModule Remove() { - auto glob_funcs = module_->functions; - for (const auto& pair : glob_funcs) { - if (auto* fn = pair.second.as()) { - auto func = GetRef(fn); - func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - func->attrs); - module_->Update(pair.first, func); + Expr Rewrite_(const CallNode* call, const Expr& post) final { + auto attrs = call->attrs.as(); + if (attrs != nullptr && attrs->compiler == "default") { + return Downcast(post)->args[0]; } + return post; } - return module_; - } + }; - Expr VisitExpr_(const CallNode* call) final { - auto attrs = call->attrs.as(); - if (attrs != nullptr && attrs->compiler == "default") { - return VisitExpr(call->args[0]); + auto glob_funcs = module->functions; + // module is mutable, hence, we make a copy of it. + module.CopyOnWrite(); + for (const auto& pair : glob_funcs) { + if (auto* fn = pair.second.as()) { + auto func = GetRef(fn); + DefaultRemover remover; + auto removed = PostOrderRewrite(func->body, &remover); + func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs); + module->Update(pair.first, func); } - return ExprMutator::VisitExpr_(call); } - - private: - IRModule module_; -}; + return module; +} } // namespace partitioning @@ -577,7 +449,7 @@ Pass PartitionGraph() { // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute // by treating them as un-annotated, but we don't have it yet. This workaround pass removes // all "default" annotations and should be deleted in the future. 
- auto new_m = partitioning::DefaultRemover(m).Remove(); + auto new_m = partitioning::Remove(m); return partitioning::Partitioner(new_m).Partition(); }; auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {}); From b8bfc59a5d84858d1a9ee378b944260afa74c8d0 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 30 Apr 2020 21:07:05 +0000 Subject: [PATCH 2/5] refactor maps --- src/relay/transforms/partition_graph.cc | 145 +++++++++++------------- 1 file changed, 65 insertions(+), 80 deletions(-) diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 12b425e02942..24dcb6150c57 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -54,6 +54,23 @@ namespace partitioning { static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); +/*! \brief This struct maintains the required metadata for a region to generate a corresponding + * global function and function call. Global function will be passed to the target specific codegen + * and function call will be used in the transform Relay graph to invoke the function in runtime. + */ +struct RegionFuncMetadata { + /*! \brief The call node of the generated global function for this region. */ + Call func_call; + + /*! \brief A list of argument pairs. Each pair includes (var, expr). var is used + * as a function node argument; input expression is used as a function call parameter. + */ + std::vector> args; + + /*! \brief Map from each region output expr node to its output index and TupleGetItem node. */ + std::unordered_map, ObjectHash, ObjectEqual> out_expr_indices; +}; + /*! \brief This class partitions the expr labeled with begin and end annotations * into function containing multiple regions. 
Each region is labeled with * a compiler attribute so that it will be handled by any compilers that are not @@ -96,10 +113,15 @@ class Partitioner : public MixedModeMutator { GlobalVar f_var = f.first; BaseFunc f_func = f.second; - // Creating regionset per function in the module + // Creating regionset per function in the module. auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op, partitioning::compiler_end_op); regions_sets_[region_set] = f_func; + + // Initial region function metadata. + for (auto region : region_set) { + region_func_meta_[region] = RegionFuncMetadata(); + } } } @@ -108,8 +130,7 @@ class Partitioner : public MixedModeMutator { if (op_node == nullptr || call->attrs.as() == nullptr) { return post; } else if (call->op == compiler_begin_op) { - // The annotation node is inserted on edge so it must have only one - // argument. + // The annotation node is inserted on edge so it must have only one argument. CHECK_EQ(call->args.size(), 1U); // Traverse the rest graph. 
@@ -118,8 +139,7 @@ class Partitioner : public MixedModeMutator { // Backtrace the parent to find the first ancestor node that is not a begin or end op while (const auto* parent_call = parent.as()) { - if (parent_call->op == compiler_begin_op || - parent_call->op == compiler_end_op) { + if (parent_call->op == compiler_begin_op || parent_call->op == compiler_end_op) { parent = parent_call->args[0]; } else { break; @@ -142,9 +162,9 @@ class Partitioner : public MixedModeMutator { std::pair cand = std::make_pair(var, input_expr); - if (std::find(region_args_[sg].begin(), region_args_[sg].end(), cand) == - region_args_[sg].end()) { - region_args_[sg].push_back(cand); + if (std::find(region_func_meta_[sg].args.begin(), region_func_meta_[sg].args.end(), cand) == + region_func_meta_[sg].args.end()) { + region_func_meta_[sg].args.push_back(cand); } shared_output_[parent][sg] = var; return std::move(var); @@ -167,10 +187,8 @@ class Partitioner : public MixedModeMutator { // functions are created for each annotated regions, // when their first output is encountered. // If multiple outputs are there, a tuple node is inserted at the end. - // region_function_calls_ is map that maintains - // (each annotated regions) --> created function - if (region_function_calls_.find(region) == region_function_calls_.end()) { + if (!region_func_meta_[region].func_call.defined()) { // First time this region is encountered in the traversal. // Creating the function. CreateFunction(region, call); @@ -246,19 +264,16 @@ class Partitioner : public MixedModeMutator { * node to create the function for the subgraph. */ void CreateFunction(AnnotatedRegion region, const CallNode* call) { - // Create fields which is a unique list of outputs. Also populate - // region_return_indices_ map which maps parent of compiler_end node to - // corresponding index in fields. + // Create fields which is a unique list of outputs. 
Array fields; int i = 0; for (auto ret : region->GetOutputs()) { auto ret_node = Downcast(ret)->args[0]; // Don't duplicate outputs. - if (!region_return_indices_.count(region) || - !region_return_indices_[region].count(ret_node)) { + if (!region_func_meta_[region].out_expr_indices.count(ret_node)) { auto ret_expr = MixedModeMutator::VisitExpr(ret_node); fields.push_back(ret_expr); - region_return_indices_[region][ret_node] = i; + region_func_meta_[region].out_expr_indices[ret_node] = {i, TupleGetItem()}; i++; } } @@ -281,7 +296,7 @@ class Partitioner : public MixedModeMutator { return false; }; - for (auto pair : region_args_[region]) { + for (auto pair : region_func_meta_[region].args) { params.push_back(pair.first); if (IsConstant(pair.second)) { params_bind.Set(pair.first, pair.second); @@ -303,14 +318,12 @@ class Partitioner : public MixedModeMutator { std::string target = call->attrs.as()->compiler; std::string name = target + "_" + std::to_string(region->GetID()); - global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, - runtime::String(name)); global_region_func = - WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1)); - global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler, - tvm::runtime::String(target)); + WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, runtime::String(name)); + global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1)); global_region_func = - WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1)); + WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target)); + global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1)); // Constant propagation if (!params_bind.empty()) { @@ -318,8 +331,7 @@ class Partitioner : public MixedModeMutator { } std::string fname = name; - CHECK(!module_->ContainGlobalVar(fname)) - << "Global function " 
<< fname << " already exists"; + CHECK(!module_->ContainGlobalVar(fname)) << "Global function " << fname << " already exists"; // Create a global function and add it to the IRModule for the region. // This way we lift the functions that should be handled by external // codegen to the module scope and rely on the pass manager to prevent @@ -328,10 +340,9 @@ class Partitioner : public MixedModeMutator { GlobalVar glob_func(fname); module_->Add(glob_func, global_region_func); - // The return type of callnode is the same as the type of the - // compiler_end node. + // The return type of callnode is the same as the type of the compiler_end node. auto ret = Call(glob_func, param_expr); - region_function_calls_[region] = ret; + region_func_meta_[region].func_call = ret; } /*! @@ -342,67 +353,41 @@ class Partitioner : public MixedModeMutator { Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) { Expr arg = Downcast(end_arg)->args[0]; // Function has one output. - if (region_return_indices_[region].size() == 1) { - return region_function_calls_[region]; + if (region_func_meta_[region].out_expr_indices.size() == 1) { + return region_func_meta_[region].func_call; } + // Function has multiple outputs. // Use already made TupleGetItem. - if (region_return_tuplegetitem_.count(region) && - region_return_tuplegetitem_[region].count(arg)) { - return region_return_tuplegetitem_[region][arg]; + if (region_func_meta_[region].out_expr_indices.count(arg) && + region_func_meta_[region].out_expr_indices[arg].second.defined()) { + return region_func_meta_[region].out_expr_indices[arg].second; } // Create new TupleGetItem. 
- CHECK(region_return_indices_.count(region) && - region_return_indices_[region].count(arg)); - int index = region_return_indices_[region][arg]; + CHECK(region_func_meta_[region].out_expr_indices.count(arg)); + int index = region_func_meta_[region].out_expr_indices[arg].first; - auto func_call = region_function_calls_[region]; + auto func_call = region_func_meta_[region].func_call; auto tuple_get_item_ = TupleGetItem(func_call, index); tuple_get_item_->checked_type_ = arg->checked_type_; - region_return_tuplegetitem_[region][arg] = tuple_get_item_; + region_func_meta_[region].out_expr_indices[arg].second = tuple_get_item_; return std::move(tuple_get_item_); } - /*! - * \brief This map maintains the already created function calls. - * This is required in the multi-output scenario, to link rest of the outputs - * to call - */ - std::unordered_map region_function_calls_; - - /*! - * \brief This map maintains arguments (of region) visits through visitor - * patterns. Those argument var and expression will be used to when creating - * the function. - */ - std::unordered_map>, ObjectHash, ObjectEqual> - region_args_; - - /*! - * \brief This map maintains the index of an output in the subgraph function - * for a given region. If there are multiple entries for a region, then the - * function has a tuple of multiple outputs for its return. - */ - using RegionRetIndexMap = std::unordered_map; - std::unordered_map - region_return_indices_; - - /*! - * \brief This map holds already created TupleGetItem nodes for accessing - * outputs of a function. - */ - using RegionRetTupleGetItemMap = std::unordered_map; - std::unordered_map - region_return_tuplegetitem_; + /*! \brief Map from each region to its metadata of the generated function. */ + std::unordered_map + region_func_meta_; - /*! - * \brief Each region set is associated with a function in the module. + /*! \brief Each region set is associated with a function in the module. 
* This map maintains the mapping between regionsets and the function it * belongs to */ std::unordered_map regions_sets_; - /*!\brief Cache the output that is shared by different nodes. */ + /*! \brief Map from region output exprs to descendant region inputs (vars). This cache is used + * to make sure descendant region will not have duplicated inputs even it refers the same expr + * multuple times. + */ using RegionOutputMap = std::unordered_map; std::unordered_map shared_output_; @@ -444,13 +429,13 @@ IRModule Remove(IRModule module) { namespace transform { Pass PartitionGraph() { - runtime::TypedPackedFunc part_func = - [=](IRModule m, PassContext pc) { - // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute - // by treating them as un-annotated, but we don't have it yet. This workaround pass removes - // all "default" annotations and should be deleted in the future. - auto new_m = partitioning::Remove(m); - return partitioning::Partitioner(new_m).Partition(); + runtime::TypedPackedFunc part_func = [=](IRModule m, + PassContext pc) { + // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute + // by treating them as un-annotated, but we don't have it yet. This workaround pass removes + // all "default" annotations and should be deleted in the future. 
+ auto new_m = partitioning::Remove(m); + return partitioning::Partitioner(new_m).Partition(); }; auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {}); return Sequential({partitioned, InferType()}); From aaed2b48a8a46b4dd668d218b6a666a671b89aef Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Thu, 30 Apr 2020 21:30:27 +0000 Subject: [PATCH 3/5] rebase upstream --- src/relay/transforms/partition_graph.cc | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 24dcb6150c57..6d52af0a8204 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -283,17 +283,11 @@ class Partitioner : public MixedModeMutator { Map params_bind; auto IsConstant = [](const Expr& expr) { - if (expr->IsInstance()) - return true; - if (expr->IsInstance()) { - auto tuple = expr.as(); - for (const auto& field : tuple->fields) { - if (!field->IsInstance()) - return false; - } - return true; - } - return false; + if (expr->IsInstance()) return true; + if (!expr->IsInstance()) return false; + const auto* tn = expr.as(); + return std::all_of(tn->fields.begin(), tn->fields.end(), + [](const Expr& e) { return e->IsInstance(); }); }; for (auto pair : region_func_meta_[region].args) { From e8ffa8985687908039fa89d825c2b67891cdabad Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 30 Apr 2020 21:32:28 +0000 Subject: [PATCH 4/5] refactor shared output --- src/relay/transforms/partition_graph.cc | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 6d52af0a8204..89c07febc183 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -69,6 +69,12 @@ struct RegionFuncMetadata { /*! \brief Map from each region output expr node to its output index and TupleGetItem node. 
*/ std::unordered_map, ObjectHash, ObjectEqual> out_expr_indices; + + /*! \brief Map from each input expression to the corresponding input variable of this region. + * This cache is used to make sure a region function will not have duplicated inputs even + * it refers the same expr multuple times. + */ + std::unordered_map in_expr_vars; }; /*! \brief This class partitions the expr labeled with begin and end annotations @@ -120,7 +126,7 @@ class Partitioner : public MixedModeMutator { // Initial region function metadata. for (auto region : region_set) { - region_func_meta_[region] = RegionFuncMetadata(); + region_func_meta_[region]; } } } @@ -150,8 +156,8 @@ class Partitioner : public MixedModeMutator { int index = GetArgIdx(sg, GetRef(call)); CHECK_NE(index, -1); - if (shared_output_.count(parent) && shared_output_[parent].count(sg)) { - return shared_output_[parent][sg]; + if (region_func_meta_[sg].in_expr_vars.count(parent)) { + return region_func_meta_[sg].in_expr_vars[parent]; } else { // The type of the created variable is the same as the compiler_begin // node. @@ -166,7 +172,7 @@ class Partitioner : public MixedModeMutator { region_func_meta_[sg].args.end()) { region_func_meta_[sg].args.push_back(cand); } - shared_output_[parent][sg] = var; + region_func_meta_[sg].in_expr_vars[parent] = var; return std::move(var); } } else { @@ -378,13 +384,6 @@ class Partitioner : public MixedModeMutator { */ std::unordered_map regions_sets_; - /*! \brief Map from region output exprs to descendant region inputs (vars). This cache is used - * to make sure descendant region will not have duplicated inputs even it refers the same expr - * multuple times. - */ - using RegionOutputMap = std::unordered_map; - std::unordered_map shared_output_; - /*!\brief The IRModule used for partitioning. 
*/ IRModule module_; }; From 26404c14d74ea0e9b6abb12ee4d942dd019748e6 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 1 May 2020 18:01:50 +0000 Subject: [PATCH 5/5] address comments --- src/relay/transforms/partition_graph.cc | 107 +++++++++++------------- 1 file changed, 47 insertions(+), 60 deletions(-) diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 89c07febc183..634434d3a3d0 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -67,14 +67,17 @@ struct RegionFuncMetadata { */ std::vector> args; - /*! \brief Map from each region output expr node to its output index and TupleGetItem node. */ - std::unordered_map, ObjectHash, ObjectEqual> out_expr_indices; + /*! \brief Map from each region output expr (compiler end) node to + * the corresponding function output expr. + */ + std::unordered_map region_func_out; - /*! \brief Map from each input expression to the corresponding input variable of this region. - * This cache is used to make sure a region function will not have duplicated inputs even - * it refers the same expr multuple times. + /*! \brief Map from each region input expression (compiler begin) to + * the corresponding function input variable. This cache is used to make sure + * a region function will not have duplicated inputs even if it refers to + * the same expr multiple times. */ - std::unordered_map in_expr_vars; + std::unordered_map region_func_in; }; /*! \brief This class partitions the expr labeled with begin and end annotations @@ -123,11 +126,6 @@ class Partitioner : public MixedModeMutator { auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op, partitioning::compiler_end_op); regions_sets_[region_set] = f_func; - - // Initial region function metadata. 
- for (auto region : region_set) { - region_func_meta_[region]; - } } } @@ -156,8 +154,8 @@ class Partitioner : public MixedModeMutator { int index = GetArgIdx(sg, GetRef(call)); CHECK_NE(index, -1); - if (region_func_meta_[sg].in_expr_vars.count(parent)) { - return region_func_meta_[sg].in_expr_vars[parent]; + if (region_func_meta_[sg].region_func_in.count(parent)) { + return region_func_meta_[sg].region_func_in[parent]; } else { // The type of the created variable is the same as the compiler_begin // node. @@ -172,7 +170,7 @@ class Partitioner : public MixedModeMutator { region_func_meta_[sg].args.end()) { region_func_meta_[sg].args.push_back(cand); } - region_func_meta_[sg].in_expr_vars[parent] = var; + region_func_meta_[sg].region_func_in[parent] = var; return std::move(var); } } else { @@ -195,12 +193,14 @@ class Partitioner : public MixedModeMutator { // If multiple outputs are there, a tuple node is inserted at the end. if (!region_func_meta_[region].func_call.defined()) { - // First time this region is encountered in the traversal. - // Creating the function. + // First time this region is encountered in the traversal. Creating the function. CreateFunction(region, call); } + // Retrieve this particular output of function. - return GetFunctionOutput(region, GetRef(call)); + Expr region_out_expr = Downcast(GetRef(call))->args[0]; + CHECK(region_func_meta_[region].region_func_out.count(region_out_expr)); + return region_func_meta_[region].region_func_out[region_out_expr]; } } @@ -266,21 +266,22 @@ class Partitioner : public MixedModeMutator { } /*! - * \brief This function is called first time that we encounter a compiler_end - * node to create the function for the subgraph. + * \brief Create a function and its function call for the given region. If the function has + * multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes + * will be created to serve output consumers. 
*/ - void CreateFunction(AnnotatedRegion region, const CallNode* call) { + void CreateFunction(AnnotatedRegion region, const CallNode* end_node) { // Create fields which is a unique list of outputs. Array fields; - int i = 0; - for (auto ret : region->GetOutputs()) { - auto ret_node = Downcast(ret)->args[0]; + std::unordered_map out_expr_to_idx; + int out_idx = 0; + for (auto region_end_node : region->GetOutputs()) { + auto ret_node = Downcast(region_end_node)->args[0]; // Don't duplicate outputs. - if (!region_func_meta_[region].out_expr_indices.count(ret_node)) { + if (!out_expr_to_idx.count(ret_node)) { auto ret_expr = MixedModeMutator::VisitExpr(ret_node); fields.push_back(ret_expr); - region_func_meta_[region].out_expr_indices[ret_node] = {i, TupleGetItem()}; - i++; + out_expr_to_idx[ret_node] = out_idx++; } } @@ -309,13 +310,13 @@ class Partitioner : public MixedModeMutator { if (fields.size() == 1) { // If there are only a single output; no need to add a tuple global_region_func = - Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs()); + Function(params, fields[0], end_node->args[0]->checked_type_, {}, DictAttrs()); } else { auto tuple = Tuple(fields); global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs()); } - std::string target = call->attrs.as()->compiler; + std::string target = end_node->attrs.as()->compiler; std::string name = target + "_" + std::to_string(region->GetID()); global_region_func = @@ -340,38 +341,24 @@ class Partitioner : public MixedModeMutator { GlobalVar glob_func(fname); module_->Add(glob_func, global_region_func); - // The return type of callnode is the same as the type of the compiler_end node. - auto ret = Call(glob_func, param_expr); - region_func_meta_[region].func_call = ret; - } + // Create a call node for the function. + auto call = Call(glob_func, param_expr); + region_func_meta_[region].func_call = call; - /*! 
- * \brief Get the return(output) of the function for compiler end node "end_arg". - * This will return either a Call (for a function with a single output) or a - * TupleGetItem (for a function with multiple outputs). - */ - Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) { - Expr arg = Downcast(end_arg)->args[0]; - // Function has one output. - if (region_func_meta_[region].out_expr_indices.size() == 1) { - return region_func_meta_[region].func_call; - } - - // Function has multiple outputs. - // Use already made TupleGetItem. - if (region_func_meta_[region].out_expr_indices.count(arg) && - region_func_meta_[region].out_expr_indices[arg].second.defined()) { - return region_func_meta_[region].out_expr_indices[arg].second; + // Create output expr(s) for the function call. + if (out_expr_to_idx.size() == 1) { + // Single output directly uses the call node as the output expr. + region_func_meta_[region].region_func_out[out_expr_to_idx.begin()->first] = call; + } else { + // Multiple outputs need to create TupleGetItem nodes as output exprs. + for (auto pair : out_expr_to_idx) { + Expr region_out_expr = pair.first; // The arg of a compiler end node of this region. + int idx = pair.second; // Corresponding function output tuple index. + auto tuple_get_item = TupleGetItem(call, idx); + tuple_get_item->checked_type_ = region_out_expr->checked_type_; + region_func_meta_[region].region_func_out[region_out_expr] = tuple_get_item; + } } - // Create new TupleGetItem. - CHECK(region_func_meta_[region].out_expr_indices.count(arg)); - int index = region_func_meta_[region].out_expr_indices[arg].first; - - auto func_call = region_func_meta_[region].func_call; - auto tuple_get_item_ = TupleGetItem(func_call, index); - tuple_get_item_->checked_type_ = arg->checked_type_; - region_func_meta_[region].out_expr_indices[arg].second = tuple_get_item_; - return std::move(tuple_get_item_); } /*! \brief Map from each region to its metadata of the generated function. 
*/ @@ -388,7 +375,7 @@ class Partitioner : public MixedModeMutator { IRModule module_; }; -IRModule Remove(IRModule module) { +IRModule RemoveDefaultAnnotations(IRModule module) { class DefaultRemover : public ExprRewriter { public: DefaultRemover() = default; @@ -427,7 +414,7 @@ Pass PartitionGraph() { // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute // by treating them as un-annotated, but we don't have it yet. This workaround pass removes // all "default" annotations and should be deleted in the future. - auto new_m = partitioning::Remove(m); + auto new_m = partitioning::RemoveDefaultAnnotations(m); return partitioning::Partitioner(new_m).Partition(); }; auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});