From 2563f1fbdfd889f2752ff24dba3a1cbeb5cc09eb Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Mon, 30 Mar 2020 18:41:42 +0100 Subject: [PATCH 1/9] [RELAY] Fixed issues with MergeCompilerRegions This PR addresses a few outstanding issues with the implementation of MergeCompilerRegions. In particular, it now handles TupleGetItem nodes properly and other minor bugs related to region merging have been fixed. Change-Id: I07783afc56183a6f798a510209f23b0a5f252255 --- src/relay/transforms/annotate_target.cc | 21 ++- .../transforms/merge_compiler_regions.cc | 141 ++++++++++-------- 2 files changed, 98 insertions(+), 64 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index c2f7b804cb6a..b546f05b46e4 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -32,6 +32,9 @@ namespace tvm { namespace relay { namespace annotate_target { +// Cache compiler_begin op for equivalence check. +static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); + // A helper class to insert annotation boundaries for a program region that will // be handled by a specific compiler. class AnnotateTargetWrapper : public ExprMutator { @@ -52,6 +55,13 @@ class AnnotateTargetWrapper : public ExprMutator { return fannotate[op](call->attrs, call->args); } } + if (expr->IsInstance()) { + TupleGetItem get = Downcast(expr); + if (get->tuple->IsInstance() && + get->tuple.as()->op == compiler_begin_op) { + return true; + } + } return false; } @@ -110,9 +120,14 @@ class AnnotateTargetWrapper : public ExprMutator { auto new_e = ExprMutator::VisitExpr_(op); auto get = Downcast(new_e); - return TupleGetItem( - InsertEnd(get->tuple), - get->index); + if (IsSupported(get->tuple)) { + const auto* begin_op = + runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); + CHECK(begin_op); + return TupleGetItem((*begin_op)(InsertEnd(get->tuple), target_), get->index); + } else { + return TupleGetItem(InsertEnd(get->tuple), get->index); + } } Expr VisitExpr_(const FunctionNode* op) { diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc index e6ec93aecd42..03be8648ba74 100644 --- a/src/relay/transforms/merge_compiler_regions.cc +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -71,13 +71,13 @@ class AnnotateRestDefault : public ExprMutator { func_ = Downcast(expr); // Corner Case CC1 : If the last node does not belong - // to a region nede to add a compiler_end + // to a region node to add a compiler_end auto region = regions_->GetRegion(func_->body); auto mutated_expr = this->VisitExpr(expr); if (!region.defined()) { func_ = Downcast(mutated_expr); // CC1 : add that compiler end after mutation - auto body = AddCompilerEnd_(func_->body); + auto body = InsertEnd(func_->body); func_ = Function(func_->params, body, body->checked_type_, {}, DictAttrs()); return Downcast(func_); @@ -86,82 +86,100 @@ class AnnotateRestDefault : public ExprMutator { } /*! \brief This function adds compiler ends to nodes that - * have a region AND they should not be arguments of the - * original function + * don't belong to a region already (default). * \param expr The expression to add a compiler end to. * \return expr The expression with or without a compiler end added. */ - Expr AddCompilerEnd(const Expr& expr) { - auto region = regions_->GetRegion(expr); - auto visited_expr = VisitExpr(expr); - - // The compiler ends are added to nodes that does have a region - // AND they should not be arguments of the original function - if (!region.defined() && - std::find(func_->params.begin(), - func_->params.end(), visited_expr) - == func_->params.end()) { - return AddCompilerEnd_(visited_expr); + Expr InsertEnd(const Expr& expr) { + if (annotated_nodes_.find(expr) == annotated_nodes_.end() && + !expr->IsInstance() && !expr->IsInstance()) { + const auto *end_op = + runtime::Registry::Get("relay.op.annotation._make.compiler_end"); + CHECK(end_op); + Expr end = (*end_op)(expr, target_); + return end; } - return visited_expr; + return expr; } - Expr AddCompilerEnd_(const Expr& expr) { - const auto* end_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_end"); - CHECK(end_op); - Expr end = (*end_op)(expr, target_); - return end; + /*! \brief This function adds compiler begins to nodes that + * don't belong to a region already (default). + * \param expr The expression to add a compiler begin to. + * \return expr The expression with or without a compiler begin added. + */ + Expr InsertBegin(const Expr& expr) { + const auto *begin_op = + runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); + CHECK(begin_op); + Expr begin = (*begin_op)(expr, target_); + annotated_nodes_.insert(begin); + return begin; } - Expr VisitExpr_(const CallNode* call) final { - auto op_node = call->op.as(); - auto ret = GetRef(call); + Expr VisitExpr_(const CallNode* cn) final { + auto region = regions_->GetRegion(GetRef(cn)); + auto new_e = ExprMutator::VisitExpr_(cn); + Call call = Downcast(new_e); - Array args; - // Add compiler ends if the parent is supported + // Add compiler ends if the parent isn't annotated + Array args; for (auto arg : call->args) { - args.push_back(AddCompilerEnd(arg)); + args.push_back(InsertEnd(arg)); } - if (op_node == nullptr || call->attrs.as() == nullptr) { - // Skip annotatation ops, only add default compiler to actual compute nodes - - auto region = regions_->GetRegion(ret); - if (!region.defined()) { - // if the current node does not belong to annotated region - // annotate the all incoming edges (args) - // with "default" compile_begin and compiler_end annotations. - tvm::Array compiler_begins; - for (auto arg : args) { - const auto* begin_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); - CHECK(begin_op); - Expr begin = (*begin_op)(arg, target_); - compiler_begins.push_back(begin); - } - Expr update_call = Call(call->op, compiler_begins, call->attrs); - return update_call; + Expr updated_call = Call(call->op, args, call->attrs); + if (!region.defined()) { + // if the current node does not belong to annotated region + // annotate the all incoming edges (args) + // with "default" compiler_begin annotations. + Array compiler_begins; + for (auto arg : args) { + compiler_begins.push_back(InsertBegin(arg)); } + updated_call = Call(call->op, compiler_begins, call->attrs); + } else { + annotated_nodes_.insert(updated_call); } - return Call(call->op, args, call->attrs); + return updated_call; }; Expr VisitExpr_(const TupleNode *op) { + auto region = regions_->GetRegion(GetRef(op)); auto new_e = ExprMutator::VisitExpr_(op); - auto tup = Downcast(new_e); - Array new_fields; + Tuple tup = Downcast(new_e); + + Array fields; for (auto field : tup->fields) { - new_fields.push_back(AddCompilerEnd(field)); + fields.push_back(InsertEnd(field)); } - return Tuple(new_fields); + + Expr updated_tuple = Tuple(fields); + if (!region.defined()) { + Array compiler_begins; + for (const auto& field : fields) { + compiler_begins.push_back(InsertBegin(field)); + } + updated_tuple = Tuple(compiler_begins); + } else { + annotated_nodes_.insert(updated_tuple); + } + return updated_tuple; } Expr VisitExpr_(const TupleGetItemNode *op) { + auto region = regions_->GetRegion(GetRef(op)); auto new_e = ExprMutator::VisitExpr_(op); auto get = Downcast(new_e); - return TupleGetItem(AddCompilerEnd(get->tuple), get->index); + + auto updated_tuple = InsertEnd(get->tuple); + Expr updated_get = TupleGetItem(updated_tuple, get->index); + if (!region.defined()) { + updated_get = TupleGetItem(InsertBegin(updated_tuple), get->index); + } else { + annotated_nodes_.insert(updated_get); + } + return updated_get; } Expr VisitExpr_(const LetNode *op) { @@ -169,43 +187,44 @@ class AnnotateRestDefault : public ExprMutator { auto let = Downcast(new_e); return Let( let->var, - AddCompilerEnd(let->value), - AddCompilerEnd(let->body)); + InsertEnd(let->value), + InsertEnd(let->body)); } Expr VisitExpr_(const IfNode *op) { auto new_e = ExprMutator::VisitExpr_(op); auto iff = Downcast(new_e); return If( - AddCompilerEnd(iff->cond), - AddCompilerEnd(iff->true_branch), - AddCompilerEnd(iff->false_branch)); + InsertEnd(iff->cond), + InsertEnd(iff->true_branch), + InsertEnd(iff->false_branch)); } Expr VisitExpr_(const RefCreateNode *op) { auto new_e = ExprMutator::VisitExpr_(op); auto create = Downcast(new_e); - return RefCreate(AddCompilerEnd(create->value)); + return RefCreate(InsertEnd(create->value)); } Expr VisitExpr_(const RefReadNode *op) { auto new_e = ExprMutator::VisitExpr_(op); auto read = Downcast(new_e); - return RefRead(AddCompilerEnd(read->ref)); + return RefRead(InsertEnd(read->ref)); } Expr VisitExpr_(const RefWriteNode *op) { auto new_e = ExprMutator::VisitExpr_(op); auto write = Downcast(new_e); return RefWrite( - AddCompilerEnd(write->ref), - AddCompilerEnd(write->value)); + InsertEnd(write->ref), + InsertEnd(write->value)); } private: AnnotatedRegionSet regions_; const std::string target_ = "default"; Function func_; + std::unordered_set annotated_nodes_; }; class MergeAnnotations : public ExprMutator { From 3166c8dca82244c2de5a204be69a784c77aeb1d3 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Tue, 31 Mar 2020 18:19:11 +0100 Subject: [PATCH 2/9] Fixed issue using pre-merged regions Change-Id: I0a844ac59bda1089ae0c67cef52f0b0c7ab2cbd7 --- src/relay/analysis/annotated_region_set.cc | 2 +- src/relay/transforms/merge_compiler_regions.cc | 14 +++++++++++++- tests/python/relay/test_annotate_target.py | 13 +++++++------ 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index df2eb9643f61..12085975f776 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -73,7 +73,7 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion region, const Expr& expr) { auto region2 = GetRegion(expr); if (region2.defined()) { - MergeRegions(region, region2); + MergeRegions(region2, region); } else { region->nodes.insert(expr); } diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc index 03be8648ba74..a73c35703078 100644 --- a/src/relay/transforms/merge_compiler_regions.cc +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -258,10 +258,14 @@ class RegionMerger : public ExprVisitor { void VisitExpr_(const CallNode* call) final { if (call->op == compiler_end_op) { auto region = regions_->GetRegion(GetRef(call)); + auto node = (*region->GetOutputs().begin()).as(); + std::string name = ""; + if (node->args[0]->IsInstance()) { + name = node->args[0].as()->op.as()->name; + } // set the region target auto compiler_attrs = call->attrs.as(); region_targets_[region->GetID()] = compiler_attrs->compiler; - std::vector mergeable_regions; // first look at the region args to determine the parent regions for (const auto& arg : region->GetInputs()) { // all args should be begin annotations @@ -275,6 +279,14 @@ class RegionMerger : public ExprVisitor { if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) { VisitExpr(begin->args[0]); } + } + // get the mergeable regions now all the parents have been visited + std::vector mergeable_regions; + for (const auto& arg : region->GetInputs()) { + auto begin = Downcast(arg); + CHECK_EQ(begin->op, compiler_begin_op); + auto parent_region = regions_->GetRegion(begin->args[0]); + if (!parent_region.defined()) continue; mergeable_regions.push_back(parent_region); } auto& region_restrictions = region_restrictions_[region->GetID()]; diff --git a/tests/python/relay/test_annotate_target.py b/tests/python/relay/test_annotate_target.py index 98de3f5e9b06..87cf7616e232 100644 --- a/tests/python/relay/test_annotate_target.py +++ b/tests/python/relay/test_annotate_target.py @@ -113,7 +113,8 @@ def expected(dtype, ishape, w1shape): padding=(1, 1), groups=32) end0 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl") - begin2 = relay.annotation.compiler_begin(end0, "dnnl") + end1 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl") + begin2 = relay.annotation.compiler_begin(end1, "dnnl") begin3 = relay.annotation.compiler_begin(end0, "dnnl") begin4 = relay.annotation.compiler_begin(weight1, "dnnl") depthwise_conv2d_2 = relay.nn.conv2d(begin3, @@ -121,11 +122,11 @@ def expected(dtype, ishape, w1shape): kernel_size=(3, 3), padding=(1, 1), groups=32) - end1 = relay.annotation.compiler_end(depthwise_conv2d_2, "dnnl") - begin5 = relay.annotation.compiler_begin(end1, "dnnl") + end2 = relay.annotation.compiler_end(depthwise_conv2d_2, "dnnl") + begin5 = relay.annotation.compiler_begin(end2, "dnnl") out = relay.add(begin2, begin5) - end2 = relay.annotation.compiler_end(out, "dnnl") - f = relay.Function([data, weight1], end2) + end3 = relay.annotation.compiler_end(out, "dnnl") + f = relay.Function([data, weight1], end3) mod = tvm.IRModule.from_expr(f) return mod @@ -137,7 +138,7 @@ def test_annotate(): mod = annotated(dtype, ishape, w1shape) mod = transform.AnnotateTarget("dnnl")(mod) ref_mod = expected(dtype, ishape, w1shape) - # tvm.ir.assert_structural_equal(mod, ref_mod) + tvm.ir.assert_structural_equal(mod, ref_mod) def test_run(): if not tvm.get_global_func("relay.ext.dnnl", True): From c7ca0c17a922dddf3402af894ec88702030e1bc3 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Wed, 1 Apr 2020 10:07:03 +0100 Subject: [PATCH 3/9] Removed some debugging logic Change-Id: Ib6f2eede6f38bbb270073eb8d4c4dc19f60832c6 --- src/relay/transforms/merge_compiler_regions.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc index a73c35703078..a3776bbf193f 100644 --- a/src/relay/transforms/merge_compiler_regions.cc +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -258,11 +258,6 @@ class RegionMerger : public ExprVisitor { void VisitExpr_(const CallNode* call) final { if (call->op == compiler_end_op) { auto region = regions_->GetRegion(GetRef(call)); - auto node = (*region->GetOutputs().begin()).as(); - std::string name = ""; - if (node->args[0]->IsInstance()) { - name = node->args[0].as()->op.as()->name; - } // set the region target auto compiler_attrs = call->attrs.as(); region_targets_[region->GetID()] = compiler_attrs->compiler; From c7a51eca7c1fd7c6aaa1e979e7aea3c207ab144c Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Wed, 1 Apr 2020 12:05:56 +0100 Subject: [PATCH 4/9] Remove default annotations Change-Id: I9b7696a51c95871491cbea33c40f92ec327e417f --- .../transforms/merge_compiler_regions.cc | 10 +++- .../relay/test_pass_merge_compiler_regions.py | 48 ++++++++----------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc index a3776bbf193f..1f500b70787c 100644 --- a/src/relay/transforms/merge_compiler_regions.cc +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -232,6 +232,14 @@ class MergeAnnotations : public ExprMutator { explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {} Expr VisitExpr_(const CallNode* call) final { + // remove 'default' annotations + auto attrs = call->attrs.as(); + if (attrs != nullptr && attrs->compiler == "default") { + return VisitExpr(call->args[0]); + } + // Merge annotations which are now internal to a region. + // This happens if we see a compiler begin next to a + // compiler end and they're both in the same region. if (call->op == compiler_begin_op) { if (call->args[0]->IsInstance()) { auto arg = Downcast(call->args[0]); @@ -239,7 +247,7 @@ class MergeAnnotations : public ExprMutator { auto region1 = regions_->GetRegion(GetRef(call)); auto region2 = regions_->GetRegion(arg); if (region1 == region2) { - return ExprMutator::VisitExpr(arg->args[0]); + return VisitExpr(arg->args[0]); } } } diff --git a/tests/python/relay/test_pass_merge_compiler_regions.py b/tests/python/relay/test_pass_merge_compiler_regions.py index 364973f0ce8a..31d121acbe6b 100644 --- a/tests/python/relay/test_pass_merge_compiler_regions.py +++ b/tests/python/relay/test_pass_merge_compiler_regions.py @@ -66,13 +66,13 @@ def expected(): O_2 = relay.nn.relu(O_1) ce_3 = compiler_end(O_2, "test") - cb_x = compiler_begin(ce_2, "default") - X = relay.tanh(cb_x) - ce_x1 = compiler_end(X, "default") - ce_x2 = compiler_end(X, "default") + # cb_x = compiler_begin(ce_2, "default") + X = relay.tanh(ce_2) + # ce_x1 = compiler_end(X, "default") + # ce_x2 = compiler_end(X, "default") cb_3 = compiler_begin(ce_3, "test") - cb_4 = compiler_begin(ce_x1, "test") + cb_4 = compiler_begin(X, "test") O_3 = relay.add(cb_3, cb_4) ce_4 = compiler_end(O_3, "test") @@ -162,36 +162,28 @@ def expected(): node1 = relay.add(begin2, begin3) node2 = relay.add(node0, node1) - begin4 = compiler_begin(in_5, "default") - begin5 = compiler_begin(in_6, "default") - begin6 = compiler_begin(in_7, "default") - node3 = relay.subtract(begin4, begin5) - node4 = relay.subtract(begin6, node3) - end0 = compiler_end(node4, "default") - - begin7 = compiler_begin(end0, "test") - begin8 = compiler_begin(in_9, "test") + node3 = relay.subtract(in_5, in_6) + node4 = relay.subtract(in_7, node3) - node5 = relay.add(node2, begin7) + begin4 = compiler_begin(node4, "test") + begin5 = compiler_begin(in_9, "test") + node5 = relay.add(node2, begin4) end1 = compiler_end(node5, "test") - begin9 = compiler_begin(end1, "default") - begin10 = compiler_begin(in_8, "default") - node6 = relay.subtract(begin10, begin9) - end2 = compiler_end(node6, "default") + node6 = relay.subtract(in_8, end1) - node7 = relay.add(begin8, node5) - end3 = compiler_end(node7, "test") - begin11 = compiler_begin(end3, "test") - begin12 = compiler_begin(end2, "test") + node7 = relay.add(begin5, node5) + end2 = compiler_end(node7, "test") + begin6 = compiler_begin(end2, "test") + begin7 = compiler_begin(node6, "test") - node8 = relay.add(begin12, begin11) + node8 = relay.add(begin7, begin6) - begin13 = compiler_begin(in_10, "test") - node9 = relay.add(begin13, node8) - end4 = compiler_end(node9, "test") + begin8 = compiler_begin(in_10, "test") + node9 = relay.add(begin8, node8) + end3 = compiler_end(node9, "test") - f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end4) + f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end3) mod = tvm.IRModule.from_expr(f) return mod From f8e378e8431220ec132a3c05abf34df86cfdf5e5 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Wed, 1 Apr 2020 12:19:59 +0100 Subject: [PATCH 5/9] Annotate default 'if's Change-Id: I0098bd1bf6788dd6366810dcefa84f1ebbffaab0 --- .../transforms/merge_compiler_regions.cc | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc index 1f500b70787c..0b7fe5485709 100644 --- a/src/relay/transforms/merge_compiler_regions.cc +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -182,6 +182,26 @@ class AnnotateRestDefault : public ExprMutator { return updated_get; } + Expr VisitExpr_(const IfNode *op) { + auto region = regions_->GetRegion(GetRef(op)); + auto new_e = ExprMutator::VisitExpr_(op); + auto iff = Downcast(new_e); + + if (!region.defined()) { + return If( + InsertBegin(InsertEnd(iff->cond)), + InsertBegin(InsertEnd(iff->true_branch)), + InsertBegin(InsertEnd(iff->false_branch))); + } else { + Expr updated_iff = If( + InsertEnd(iff->cond), + InsertEnd(iff->true_branch), + InsertEnd(iff->false_branch)); + annotated_nodes_.insert(updated_iff); + return updated_iff; + } + } + Expr VisitExpr_(const LetNode *op) { auto new_e = ExprMutator::VisitExpr_(op); auto let = Downcast(new_e); @@ -191,15 +211,6 @@ class AnnotateRestDefault : public ExprMutator { InsertEnd(let->body)); } - Expr VisitExpr_(const IfNode *op) { - auto new_e = ExprMutator::VisitExpr_(op); - auto iff = Downcast(new_e); - return If( - InsertEnd(iff->cond), - InsertEnd(iff->true_branch), - InsertEnd(iff->false_branch)); - } - Expr VisitExpr_(const RefCreateNode *op) { auto new_e = ExprMutator::VisitExpr_(op); auto create = Downcast(new_e); From 79b9c13bd216b78260c15e0f2c3cdf70916eb00d Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Wed, 1 Apr 2020 15:23:59 +0100 Subject: [PATCH 6/9] Clang format Change-Id: I944365cd3080a97a9261f643a8f1efa5a63cf82b --- .../transforms/merge_compiler_regions.cc | 89 ++++++++----------- 1 file changed, 37 insertions(+), 52 deletions(-) diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc index 0b7fe5485709..4a8ff64a24b8 100644 --- a/src/relay/transforms/merge_compiler_regions.cc +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -30,10 +30,10 @@ * as external functions. */ +#include #include #include #include -#include #include #include @@ -44,7 +44,6 @@ #include "../analysis/annotated_region_set.h" - namespace tvm { namespace relay { namespace partitioning { @@ -63,7 +62,7 @@ static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); class AnnotateRestDefault : public ExprMutator { public: explicit AnnotateRestDefault(const Expr& expr) { - regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op); + regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op); } Expr Annotate(const Expr& expr) { @@ -78,8 +77,7 @@ class AnnotateRestDefault : public ExprMutator { func_ = Downcast(mutated_expr); // CC1 : add that compiler end after mutation auto body = InsertEnd(func_->body); - func_ = Function(func_->params, body, - body->checked_type_, {}, DictAttrs()); + func_ = Function(func_->params, body, body->checked_type_, {}, DictAttrs()); return Downcast(func_); } return mutated_expr; @@ -91,10 +89,9 @@ class AnnotateRestDefault : public ExprMutator { * \return expr The expression with or without a compiler end added. */ Expr InsertEnd(const Expr& expr) { - if (annotated_nodes_.find(expr) == annotated_nodes_.end() && - !expr->IsInstance() && !expr->IsInstance()) { - const auto *end_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_end"); + if (annotated_nodes_.find(expr) == annotated_nodes_.end() && !expr->IsInstance() && + !expr->IsInstance()) { + const auto* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); CHECK(end_op); Expr end = (*end_op)(expr, target_); return end; @@ -103,13 +100,12 @@ class AnnotateRestDefault : public ExprMutator { } /*! \brief This function adds compiler begins to nodes that - * don't belong to a region already (default). - * \param expr The expression to add a compiler begin to. - * \return expr The expression with or without a compiler begin added. - */ + * don't belong to a region already (default). + * \param expr The expression to add a compiler begin to. + * \return expr The expression with or without a compiler begin added. + */ Expr InsertBegin(const Expr& expr) { - const auto *begin_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); + const auto* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); CHECK(begin_op); Expr begin = (*begin_op)(expr, target_); annotated_nodes_.insert(begin); @@ -121,7 +117,6 @@ class AnnotateRestDefault : public ExprMutator { auto new_e = ExprMutator::VisitExpr_(cn); Call call = Downcast(new_e); - // Add compiler ends if the parent isn't annotated Array args; for (auto arg : call->args) { @@ -144,7 +139,7 @@ class AnnotateRestDefault : public ExprMutator { return updated_call; }; - Expr VisitExpr_(const TupleNode *op) { + Expr VisitExpr_(const TupleNode* op) { auto region = regions_->GetRegion(GetRef(op)); auto new_e = ExprMutator::VisitExpr_(op); Tuple tup = Downcast(new_e); @@ -167,7 +162,7 @@ class AnnotateRestDefault : public ExprMutator { return updated_tuple; } - Expr VisitExpr_(const TupleGetItemNode *op) { + Expr VisitExpr_(const TupleGetItemNode* op) { auto region = regions_->GetRegion(GetRef(op)); auto new_e = ExprMutator::VisitExpr_(op); auto get = Downcast(new_e); @@ -182,60 +177,51 @@ class AnnotateRestDefault : public ExprMutator { return updated_get; } - Expr VisitExpr_(const IfNode *op) { + Expr VisitExpr_(const IfNode* op) { auto region = regions_->GetRegion(GetRef(op)); auto new_e = ExprMutator::VisitExpr_(op); auto iff = Downcast(new_e); if (!region.defined()) { - return If( - InsertBegin(InsertEnd(iff->cond)), - InsertBegin(InsertEnd(iff->true_branch)), - InsertBegin(InsertEnd(iff->false_branch))); + return If(InsertBegin(InsertEnd(iff->cond)), InsertBegin(InsertEnd(iff->true_branch)), + InsertBegin(InsertEnd(iff->false_branch))); } else { - Expr updated_iff = If( - InsertEnd(iff->cond), - InsertEnd(iff->true_branch), - InsertEnd(iff->false_branch)); + Expr updated_iff = + If(InsertEnd(iff->cond), InsertEnd(iff->true_branch), InsertEnd(iff->false_branch)); annotated_nodes_.insert(updated_iff); return updated_iff; } } - Expr VisitExpr_(const LetNode *op) { + Expr VisitExpr_(const LetNode* op) { auto new_e = ExprMutator::VisitExpr_(op); auto let = Downcast(new_e); - return Let( - let->var, - InsertEnd(let->value), - InsertEnd(let->body)); + return Let(let->var, InsertEnd(let->value), InsertEnd(let->body)); } - Expr VisitExpr_(const RefCreateNode *op) { + Expr VisitExpr_(const RefCreateNode* op) { auto new_e = ExprMutator::VisitExpr_(op); auto create = Downcast(new_e); return RefCreate(InsertEnd(create->value)); } - Expr VisitExpr_(const RefReadNode *op) { + Expr VisitExpr_(const RefReadNode* op) { auto new_e = ExprMutator::VisitExpr_(op); auto read = Downcast(new_e); return RefRead(InsertEnd(read->ref)); } - Expr VisitExpr_(const RefWriteNode *op) { + Expr VisitExpr_(const RefWriteNode* op) { auto new_e = ExprMutator::VisitExpr_(op); auto write = Downcast(new_e); - return RefWrite( - InsertEnd(write->ref), - InsertEnd(write->value)); + return RefWrite(InsertEnd(write->ref), InsertEnd(write->value)); } private: - AnnotatedRegionSet regions_; - const std::string target_ = "default"; - Function func_; - std::unordered_set annotated_nodes_; + AnnotatedRegionSet regions_; + const std::string target_ = "default"; + Function func_; + std::unordered_set annotated_nodes_; }; class MergeAnnotations : public ExprMutator { @@ -307,8 +293,7 @@ class RegionMerger : public ExprVisitor { for (const auto& parent_region : mergeable_regions) { // add all the parent restrictions to the current region auto parent_restrictions = region_restrictions_[parent_region->GetID()]; - region_restrictions.insert(parent_restrictions.begin(), - parent_restrictions.end()); + region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end()); } for (const auto& parent_region : mergeable_regions) { bool merged = false; @@ -318,7 +303,8 @@ class RegionMerger : public ExprVisitor { if (region_restrictions.find(parent_region->GetID()) == region_restrictions.end()) { // merge the parent region into the current region regions_->MergeRegions(parent_region, region); - // update the restrictions of all other regions to reflect the change in id + // update the restrictions of all other regions to reflect the + // change in id for (const auto& r : regions_) { auto& restrictions = region_restrictions_[r->GetID()]; if (restrictions.find(parent_region->GetID()) != restrictions.end()) { @@ -329,9 +315,9 @@ class RegionMerger : public ExprVisitor { merged = true; } } - // if the parent wasn't merged, add it as a restriction to the current region - if (!merged) - region_restrictions.insert(parent_region->GetID()); + // if the parent wasn't merged, add it as a restriction to the current + // region + if (!merged) region_restrictions.insert(parent_region->GetID()); } merged_regions_.insert(region->GetID()); } @@ -345,15 +331,14 @@ class RegionMerger : public ExprVisitor { std::map region_targets_; }; - Expr MergeCompilerRegions(const Expr& expr) { // Annotate all the nodes that aren't annotated as 'default'. AnnotateRestDefault anno_default(expr); auto expr_all_annotated = anno_default.Annotate(expr); // Create regions using the annotations. - AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr_all_annotated, - compiler_begin_op, compiler_end_op); + AnnotatedRegionSet regions = + AnnotatedRegionSet::Create(expr_all_annotated, compiler_begin_op, compiler_end_op); // By now, all the nodes have some sort of annotation. // Region merger is an ExprVisitor that will update the @@ -381,7 +366,7 @@ Pass MergeCompilerRegions() { } TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions") -.set_body_typed(transform::MergeCompilerRegions); + .set_body_typed(transform::MergeCompilerRegions); } // namespace transform From a29399749ef8004eae1940251e4eabd794c64001 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Wed, 1 Apr 2020 15:49:22 +0100 Subject: [PATCH 7/9] Use src/dest in merge Change-Id: Ie43113492bda8f1ce63eaf9615cb645bb9e2ee86 --- src/relay/analysis/annotated_region_set.cc | 10 +++++----- src/relay/analysis/annotated_region_set.h | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index 12085975f776..f7b9b421a767 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -70,12 +70,12 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, regions_.erase(src); } -void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion region, const Expr& expr) { - auto region2 = GetRegion(expr); - if (region2.defined()) { - MergeRegions(region2, region); +void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr) { + auto src = GetRegion(expr); + if (src.defined()) { + MergeRegions(src, dest); } else { - region->nodes.insert(expr); + dest->nodes.insert(expr); } } diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h index c5db2cc3d202..0b9301133d1c 100644 --- a/src/relay/analysis/annotated_region_set.h +++ b/src/relay/analysis/annotated_region_set.h @@ -178,10 +178,10 @@ class AnnotatedRegionSetNode : public Object { /*! * \brief Add an expression to a region. * - * \param region The region to add the expression to. + * \param dest The region to add the expression to. * \param expr The expression. */ - void AddToRegion(AnnotatedRegion region, const Expr& expr); + void AddToRegion(AnnotatedRegion dest, const Expr& expr); /*! * \brief Make a new region. From 763bca62cd71115a56bfdf838fe54cd0029b757b Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Wed, 1 Apr 2020 16:55:02 +0100 Subject: [PATCH 8/9] Fixed partition test Change-Id: I46f9e349b1a813a9140f7e4f8a2241687e2df73b --- .../python/relay/test_pass_partition_graph.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 0dfc89d469ca..9d4d71179fd7 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -725,12 +725,12 @@ def expected(): mod = tvm.IRModule() # function 0 - data = relay.var("test_target_0_i0", relay.TensorType((1, 3, 224, 224), "float32")) - weight = relay.var("test_target_0_i1", relay.TensorType((16, 3, 3, 3), "float32")) - bn_gamma = relay.var("test_target_0_i2", relay.TensorType((16, ), "float32")) - bn_beta = relay.var("test_target_0_i3", relay.TensorType((16, ), "float32")) - bn_mean = relay.var("test_target_0_i4", relay.TensorType((16, ), "float32")) - bn_var = relay.var("test_target_0_i5", relay.TensorType((16, ), "float32")) + data = relay.var("test_target_2_i0", relay.TensorType((1, 3, 224, 224), "float32")) + weight = relay.var("test_target_2_i1", relay.TensorType((16, 3, 3, 3), "float32")) + bn_gamma = relay.var("test_target_2_i2", relay.TensorType((16, ), "float32")) + bn_beta = relay.var("test_target_2_i3", relay.TensorType((16, ), "float32")) + bn_mean = relay.var("test_target_2_i4", relay.TensorType((16, ), "float32")) + bn_var = relay.var("test_target_2_i5", relay.TensorType((16, ), "float32")) conv_o = relay.nn.conv2d( data=data, @@ -743,7 +743,7 @@ def expected(): bn_var) relu_o = relay.nn.relu(bn_o[0]) - tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2])) + tuple_o = relay.Tuple((bn_o[2], bn_o[1], relu_o)) func0 = relay.Function([data, weight, bn_gamma, bn_beta, bn_mean, bn_var], tuple_o) @@ -752,8 +752,8 @@ def expected(): func0 = func0.with_attr("Compiler", tvm.tir.StringImm("test_target")) func0 = func0.with_attr("ExternalSymbol", - tvm.tir.StringImm("test_target_0")) - gv0 = relay.GlobalVar("test_target_0") + tvm.tir.StringImm("test_target_2")) + gv0 = relay.GlobalVar("test_target_2") mod[gv0] = func0 # body @@ -765,9 +765,9 @@ def expected(): bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32")) f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var) - f0_relu_o = relay.TupleGetItem(f0_o, 0) + f0_relu_o = relay.TupleGetItem(f0_o, 2) f0_mean_o = relay.TupleGetItem(f0_o, 1) - f0_var_o = relay.TupleGetItem(f0_o, 2) + f0_var_o = relay.TupleGetItem(f0_o, 0) f0_mean_abs = relay.abs(f0_mean_o) f0_var_abs = relay.abs(f0_var_o) From 6c02abecda867d116e3c76a217a4a6b567657cc3 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Wed, 1 Apr 2020 17:16:00 +0100 Subject: [PATCH 9/9] Removed comments Change-Id: I309afdd1951d7e796e41d13788aa487707e0ac4c --- tests/python/relay/test_pass_merge_compiler_regions.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/python/relay/test_pass_merge_compiler_regions.py b/tests/python/relay/test_pass_merge_compiler_regions.py index 31d121acbe6b..f316a41a88da 100644 --- a/tests/python/relay/test_pass_merge_compiler_regions.py +++ b/tests/python/relay/test_pass_merge_compiler_regions.py @@ -66,10 +66,7 @@ def expected(): O_2 = relay.nn.relu(O_1) ce_3 = compiler_end(O_2, "test") - # cb_x = compiler_begin(ce_2, "default") X = relay.tanh(ce_2) - # ce_x1 = compiler_end(X, "default") - # ce_x2 = compiler_end(X, "default") cb_3 = compiler_begin(ce_3, "test") cb_4 = compiler_begin(X, "test")