From 4e94fd95b6160f421846e1ab7b83a0d92ecd8c0c Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Fri, 1 May 2020 13:27:44 -0700 Subject: [PATCH] [REFACTOR][BOYC] Non recursive partitioning (#5493) * non recursive partitioning * refactor maps * rebase upstream * refactor shared output * address comments Co-authored-by: Cody Yu --- src/relay/transforms/partition_graph.cc | 393 +++++++----------------- 1 file changed, 115 insertions(+), 278 deletions(-) diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 3b0d6bc6f8e0..634434d3a3d0 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -54,39 +54,30 @@ namespace partitioning { static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); -/*! - * \brief The checker that verifies if a Relay program is annotated correctly - * for partitioning. +/*! \brief This struct maintains the required metadata for a region to generate a corresponding + * global function and function call. Global function will be passed to the target specific codegen + * and function call will be used in the transform Relay graph to invoke the function in runtime. */ -class AnnotationChecker : public ExprVisitor { - public: - bool Check() { - if (!found_start_ && !found_end_) { - LOG(WARNING) << "No compiler annotation found"; - } else if (!found_start_) { - LOG(ERROR) << "compiler_begin annotation is missing"; - return false; - } else if (!found_end_) { - LOG(ERROR) << "compiler_end annotation is missing"; - return false; - } - return true; - } +struct RegionFuncMetadata { + /*! \brief The call node of the generated global function for this region. */ + Call func_call; - void VisitExpr_(const CallNode* call) final { - auto op_node = call->op.as(); - if (op_node == nullptr || call->attrs.as() == nullptr) { - return; - } else if (call->op == compiler_begin_op) { - found_start_ = true; - } else if (call->op == compiler_end_op) { - found_end_ = true; - } - } + /*! \brief A list of argument pairs. Each pair includes (var, expr). var is used + * as a function node argument; input expression is used as a function call parameter. + */ + std::vector> args; - private: - bool found_start_{false}; - bool found_end_{false}; + /*! \brief Map from each region output expr (compiler end) node to + * the corresponding function output expr. + */ + std::unordered_map region_func_out; + + /*! \brief Map from each region input expression (compiler begin) to + * the corresponding function input variable. This cache is used to make sure + * a region function will not have duplicated inputs even if it refers to + * the same expr multiple times. + */ + std::unordered_map region_func_in; }; /*! \brief This class partitions the expr labeled with begin and end annotations @@ -124,37 +115,35 @@ class AnnotationChecker : public ExprVisitor { * the compiler name. */ -class Partitioner : public ExprMutator { +class Partitioner : public MixedModeMutator { public: explicit Partitioner(const IRModule& module) : module_(module) { for (auto f : module->functions) { GlobalVar f_var = f.first; BaseFunc f_func = f.second; - // Creating regionset per function in the module + // Creating regionset per function in the module. auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op, partitioning::compiler_end_op); regions_sets_[region_set] = f_func; } } - Expr VisitExpr_(const CallNode* call) final { + Expr Rewrite_(const CallNode* call, const Expr& post) final { auto op_node = call->op.as(); if (op_node == nullptr || call->attrs.as() == nullptr) { - return ExprMutator::VisitExpr_(call); + return post; } else if (call->op == compiler_begin_op) { - // The annotation node is inserted on edge so it must have only one - // argument. + // The annotation node is inserted on edge so it must have only one argument. CHECK_EQ(call->args.size(), 1U); // Traverse the rest graph. Expr parent = call->args[0]; - auto input_expr = VisitExpr(parent); + auto input_expr = Downcast(post)->args[0]; // Backtrace the parent to find the first ancestor node that is not a begin or end op while (const auto* parent_call = parent.as()) { - if (parent_call->op == compiler_begin_op || - parent_call->op == compiler_end_op) { + if (parent_call->op == compiler_begin_op || parent_call->op == compiler_end_op) { parent = parent_call->args[0]; } else { break; @@ -165,8 +154,8 @@ class Partitioner : public ExprMutator { int index = GetArgIdx(sg, GetRef(call)); CHECK_NE(index, -1); - if (shared_output_.count(parent) && shared_output_[parent].count(sg)) { - return shared_output_[parent][sg]; + if (region_func_meta_[sg].region_func_in.count(parent)) { + return region_func_meta_[sg].region_func_in[parent]; } else { // The type of the created variable is the same as the compiler_begin // node. @@ -177,11 +166,11 @@ class Partitioner : public ExprMutator { std::pair cand = std::make_pair(var, input_expr); - if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) == - region_args[sg].end()) { - region_args[sg].push_back(cand); + if (std::find(region_func_meta_[sg].args.begin(), region_func_meta_[sg].args.end(), cand) == + region_func_meta_[sg].args.end()) { + region_func_meta_[sg].args.push_back(cand); } - shared_output_[parent][sg] = var; + region_func_meta_[sg].region_func_in[parent] = var; return std::move(var); } } else { @@ -197,114 +186,21 @@ class Partitioner : public ExprMutator { BaseFunc f = GetFunc(GetRef(call)); // Traverse subgraph inputs. - auto input = VisitExpr(call->args[0]); + auto input = Downcast(post)->args[0]; CHECK(region.defined()) << "Region not defined for " << GetRef(call); // functions are created for each annotated regions, // when their first output is encountered. // If multiple outputs are there, a tuple node is inserted at the end. - // region_function_calls is map that maintains - // (each annotated regions) --> created function - if (region_function_calls.find(region) == region_function_calls.end()) { - // First time this region is encountered in the traversal. - // Creating the function. + if (!region_func_meta_[region].func_call.defined()) { + // First time this region is encountered in the traversal. Creating the function. CreateFunction(region, call); } - // Retrieve this particular output of function. - return GetFunctionOutput(region, GetRef(call)); - } - } - - Expr VisitExpr_(const TupleNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Array fields; - for (auto field : op->fields) { - fields.push_back(VisitExpr(field)); - } - return Tuple(fields); - } - } - - Expr VisitExpr_(const TupleGetItemNode* g) final { - auto region = GetRegion(GetRef(g)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(g); - } else { - auto t = VisitExpr(g->tuple); - return TupleGetItem(t, g->index); - } - } - - Expr VisitExpr_(const FunctionNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Array params; - for (auto param : op->params) { - Var new_param = Downcast(VisitExpr(param)); - params.push_back(new_param); - } - auto body = VisitExpr(op->body); - return Function(params, body, op->ret_type, op->type_params, op->attrs); - } - } - - Expr VisitExpr_(const LetNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Var var = Downcast(VisitExpr(op->var)); - auto value = VisitExpr(op->value); - auto body = VisitExpr(op->body); - return Let(var, value, body); - } - } - - Expr VisitExpr_(const IfNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - auto guard = VisitExpr(op->cond); - auto true_b = VisitExpr(op->true_branch); - auto false_b = VisitExpr(op->false_branch); - return If(guard, true_b, false_b); - } - } - - Expr VisitExpr_(const RefCreateNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Expr value = VisitExpr(op->value); - return RefCreate(value); - } - } - Expr VisitExpr_(const RefReadNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Expr ref = VisitExpr(op->ref); - return RefRead(ref); - } - } - - Expr VisitExpr_(const RefWriteNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Expr ref = VisitExpr(op->ref); - Expr value = VisitExpr(op->value); - return RefWrite(ref, value); + // Retrieve this particular output of function. + Expr region_out_expr = Downcast(GetRef(call))->args[0]; + CHECK(region_func_meta_[region].region_func_out.count(region_out_expr)); + return region_func_meta_[region].region_func_out[region_out_expr]; } } @@ -370,24 +266,22 @@ class Partitioner : public ExprMutator { } /*! - * \brief This function is called first time that we encounter a compiler_end - * node to create the function for the subgraph. + * \brief Create a function and its function call for the given region. If the function has + * multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes + * will be created to serve output consumers. */ - void CreateFunction(AnnotatedRegion region, const CallNode* call) { - // Create fields which is a unique list of outputs. Also populate - // region_return_indices_ map which maps parent of compiler_end node to - // corresponding index in fields. + void CreateFunction(AnnotatedRegion region, const CallNode* end_node) { + // Create fields which is a unique list of outputs. Array fields; - int i = 0; - for (auto ret : region->GetOutputs()) { - auto ret_node = Downcast(ret)->args[0]; + std::unordered_map out_expr_to_idx; + int out_idx = 0; + for (auto region_end_node : region->GetOutputs()) { + auto ret_node = Downcast(region_end_node)->args[0]; // Don't duplicate outputs. - if (!region_return_indices_.count(region) || - !region_return_indices_[region].count(ret_node)) { - auto ret_expr = VisitExpr(ret_node); + if (!out_expr_to_idx.count(ret_node)) { + auto ret_expr = MixedModeMutator::VisitExpr(ret_node); fields.push_back(ret_expr); - region_return_indices_[region][ret_node] = i; - i++; + out_expr_to_idx[ret_node] = out_idx++; } } @@ -396,20 +290,14 @@ class Partitioner : public ExprMutator { Map params_bind; auto IsConstant = [](const Expr& expr) { - if (expr->IsInstance()) - return true; - if (expr->IsInstance()) { - auto tuple = expr.as(); - for (const auto& field : tuple->fields) { - if (!field->IsInstance()) - return false; - } - return true; - } - return false; + if (expr->IsInstance()) return true; + if (!expr->IsInstance()) return false; + const auto* tn = expr.as(); + return std::all_of(tn->fields.begin(), tn->fields.end(), + [](const Expr& e) { return e->IsInstance(); }); }; - for (auto pair : region_args[region]) { + for (auto pair : region_func_meta_[region].args) { params.push_back(pair.first); if (IsConstant(pair.second)) { params_bind.Set(pair.first, pair.second); @@ -422,23 +310,21 @@ class Partitioner : public ExprMutator { if (fields.size() == 1) { // If there are only a single output; no need to add a tuple global_region_func = - Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs()); + Function(params, fields[0], end_node->args[0]->checked_type_, {}, DictAttrs()); } else { auto tuple = Tuple(fields); global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs()); } - std::string target = call->attrs.as()->compiler; + std::string target = end_node->attrs.as()->compiler; std::string name = target + "_" + std::to_string(region->GetID()); - global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, - runtime::String(name)); global_region_func = - WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1)); - global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler, - tvm::runtime::String(target)); + WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, runtime::String(name)); + global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1)); global_region_func = - WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1)); + WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target)); + global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1)); // Constant propagation if (!params_bind.empty()) { @@ -446,8 +332,7 @@ class Partitioner : public ExprMutator { } std::string fname = name; - CHECK(!module_->ContainGlobalVar(fname)) - << "Global function " << fname << " already exists"; + CHECK(!module_->ContainGlobalVar(fname)) << "Global function " << fname << " already exists"; // Create a global function and add it to the IRModule for the region. // This way we lift the functions that should be handled by external // codegen to the module scope and rely on the pass manager to prevent @@ -456,129 +341,81 @@ class Partitioner : public ExprMutator { GlobalVar glob_func(fname); module_->Add(glob_func, global_region_func); - // The return type of callnode is the same as the type of the - // compiler_end node. - auto ret = Call(glob_func, param_expr); - region_function_calls[region] = ret; - } + // Create a call node for the function. + auto call = Call(glob_func, param_expr); + region_func_meta_[region].func_call = call; - /*! - * \brief Get the return(output) of the function for compiler end node "end_arg". - * This will return either a Call (for a function with a single output) or a - * TupleGetItem (for a function with multiple outputs). - */ - Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) { - Expr arg = Downcast(end_arg)->args[0]; - // Function has one output. - if (region_return_indices_[region].size() == 1) { - return region_function_calls[region]; - } - // Function has multiple outputs. - // Use already made TupleGetItem. - if (region_return_tuplegetitem_.count(region) && - region_return_tuplegetitem_[region].count(arg)) { - return region_return_tuplegetitem_[region][arg]; + // Create output expr(s) for the function call. + if (out_expr_to_idx.size() == 1) { + // Single output direcly uses the call node as the output expr. + region_func_meta_[region].region_func_out[out_expr_to_idx.begin()->first] = call; + } else { + // Multiple outptus need to create TupleGetItem nodes as output exprs. + for (auto pair : out_expr_to_idx) { + Expr region_out_expr = pair.first; // The arg of a compiler end node of this region. + int idx = pair.second; // Corresponding function output tuple index. + auto tuple_get_item = TupleGetItem(call, idx); + tuple_get_item->checked_type_ = region_out_expr->checked_type_; + region_func_meta_[region].region_func_out[region_out_expr] = tuple_get_item; + } } - // Create new TupleGetItem. - CHECK(region_return_indices_.count(region) && - region_return_indices_[region].count(arg)); - int index = region_return_indices_[region][arg]; - - auto func_call = region_function_calls[region]; - auto tuple_get_item_ = TupleGetItem(func_call, index); - tuple_get_item_->checked_type_ = arg->checked_type_; - region_return_tuplegetitem_[region][arg] = tuple_get_item_; - return std::move(tuple_get_item_); } - /*! - * \brief This map maintains the already created function calls. - * This is required in the multi-output scenario, to link rest of the outputs - * to call - */ - std::unordered_map region_function_calls; - - /*! - * \brief This map maintains arguments (of region) visits through visitor - * patterns. Those arguement var and expression will be used to when creating - * the function. - */ - std::unordered_map>, ObjectHash, ObjectEqual> - region_args; - - /*! - * \brief This map maintains the index of an output in the subgraph function - * for a given region. If there are multiple entries for a region, then the - * function has a tuple of multiple outputs for its return. - */ - using RegionRetIndexMap = std::unordered_map; - std::unordered_map - region_return_indices_; + /*! \brief Map from each region to its metadata of the generated function. */ + std::unordered_map + region_func_meta_; - /*! - * \brief This map holds already created TupleGetItem nodes for accessing - * outputs of a function. - */ - using RegionRetTupleGetItemMap = std::unordered_map; - std::unordered_map - region_return_tuplegetitem_; - - /*! - * \brief Each region set is associated with a function in the module. + /*! \brief Each region set is associated with a function in the module. * This map maintains the mapping between regionsets and the function it * belongs to */ std::unordered_map regions_sets_; - /*!\brief Cache the output that is shared by different nodes. */ - using RegionOutputMap = std::unordered_map; - std::unordered_map shared_output_; - /*!\brief The IRModule used for partitioning. */ IRModule module_; }; -class DefaultRemover : public ExprMutator { - public: - explicit DefaultRemover(const IRModule& module) : module_(module) {} +IRModule RemoveDefaultAnnotations(IRModule module) { + class DefaultRemover : public ExprRewriter { + public: + DefaultRemover() = default; - IRModule Remove() { - auto glob_funcs = module_->functions; - for (const auto& pair : glob_funcs) { - if (auto* fn = pair.second.as()) { - auto func = GetRef(fn); - func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - func->attrs); - module_->Update(pair.first, func); + Expr Rewrite_(const CallNode* call, const Expr& post) final { + auto attrs = call->attrs.as(); + if (attrs != nullptr && attrs->compiler == "default") { + return Downcast(post)->args[0]; } + return post; } - return module_; - } + }; - Expr VisitExpr_(const CallNode* call) final { - auto attrs = call->attrs.as(); - if (attrs != nullptr && attrs->compiler == "default") { - return VisitExpr(call->args[0]); + auto glob_funcs = module->functions; + // module is mutable, hence, we make a copy of it. + module.CopyOnWrite(); + for (const auto& pair : glob_funcs) { + if (auto* fn = pair.second.as()) { + auto func = GetRef(fn); + DefaultRemover remover; + auto removed = PostOrderRewrite(func->body, &remover); + func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs); + module->Update(pair.first, func); } - return ExprMutator::VisitExpr_(call); } - - private: - IRModule module_; -}; + return module; +} } // namespace partitioning namespace transform { Pass PartitionGraph() { - runtime::TypedPackedFunc part_func = - [=](IRModule m, PassContext pc) { - // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute - // by treating them as un-annotated, but we don't have it yet. This workaround pass removes - // all "default" annotations and should be deleted in the future. - auto new_m = partitioning::DefaultRemover(m).Remove(); - return partitioning::Partitioner(new_m).Partition(); + runtime::TypedPackedFunc part_func = [=](IRModule m, + PassContext pc) { + // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute + // by treating them as un-annotated, but we don't have it yet. This workaround pass removes + // all "default" annotations and should be deleted in the future. + auto new_m = partitioning::RemoveDefaultAnnotations(m); + return partitioning::Partitioner(new_m).Partition(); }; auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {}); return Sequential({partitioned, InferType()});