Skip to content

Commit

Permalink
[RELAY] Fixes to MergeCompilerRegions (#5195)
Browse files Browse the repository at this point in the history
* [RELAY] Fixed issues with MergeCompilerRegions

This PR addresses a few outstanding issues with
the implementation of MergeCompilerRegions. In
particular, it now handles TupleGetItem nodes properly
and other minor bugs related to region merging have
been fixed.

Change-Id: I07783afc56183a6f798a510209f23b0a5f252255

* Fixed issue using pre-merged regions

Change-Id: I0a844ac59bda1089ae0c67cef52f0b0c7ab2cbd7

* Removed some debugging logic

Change-Id: Ib6f2eede6f38bbb270073eb8d4c4dc19f60832c6

* Remove default annotations

Change-Id: I9b7696a51c95871491cbea33c40f92ec327e417f

* Annotate default 'if's

Change-Id: I0098bd1bf6788dd6366810dcefa84f1ebbffaab0

* Clang format

Change-Id: I944365cd3080a97a9261f643a8f1efa5a63cf82b

* Use src/dest in merge

Change-Id: Ie43113492bda8f1ce63eaf9615cb645bb9e2ee86

* Fixed partition test

Change-Id: I46f9e349b1a813a9140f7e4f8a2241687e2df73b

* Removed comments

Change-Id: I309afdd1951d7e796e41d13788aa487707e0ac4c
  • Loading branch information
mbaret authored Apr 1, 2020
1 parent 2f41a39 commit 0449966
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 149 deletions.
10 changes: 5 additions & 5 deletions src/relay/analysis/annotated_region_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
regions_.erase(src);
}

void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion region, const Expr& expr) {
auto region2 = GetRegion(expr);
if (region2.defined()) {
MergeRegions(region, region2);
void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr) {
auto src = GetRegion(expr);
if (src.defined()) {
MergeRegions(src, dest);
} else {
region->nodes.insert(expr);
dest->nodes.insert(expr);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/relay/analysis/annotated_region_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@ class AnnotatedRegionSetNode : public Object {
/*!
* \brief Add an expression to a region.
*
* \param region The region to add the expression to.
* \param dest The region to add the expression to.
* \param expr The expression.
*/
void AddToRegion(AnnotatedRegion region, const Expr& expr);
void AddToRegion(AnnotatedRegion dest, const Expr& expr);

/*!
* \brief Make a new region.
Expand Down
21 changes: 18 additions & 3 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ namespace tvm {
namespace relay {
namespace annotate_target {

// Cache compiler_begin op for equivalence check.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");

// A helper class to insert annotation boundaries for a program region that will
// be handled by a specific compiler.
class AnnotateTargetWrapper : public ExprMutator {
Expand All @@ -52,6 +55,13 @@ class AnnotateTargetWrapper : public ExprMutator {
return fannotate[op](call->attrs, call->args);
}
}
if (expr->IsInstance<TupleGetItemNode>()) {
TupleGetItem get = Downcast<TupleGetItem>(expr);
if (get->tuple->IsInstance<CallNode>() &&
get->tuple.as<CallNode>()->op == compiler_begin_op) {
return true;
}
}
return false;
}

Expand Down Expand Up @@ -110,9 +120,14 @@ class AnnotateTargetWrapper : public ExprMutator {
auto new_e = ExprMutator::VisitExpr_(op);

auto get = Downcast<TupleGetItem>(new_e);
return TupleGetItem(
InsertEnd(get->tuple),
get->index);
if (IsSupported(get->tuple)) {
const auto* begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
CHECK(begin_op);
return TupleGetItem((*begin_op)(InsertEnd(get->tuple), target_), get->index);
} else {
return TupleGetItem(InsertEnd(get->tuple), get->index);
}
}

Expr VisitExpr_(const FunctionNode* op) {
Expand Down
Loading

0 comments on commit 0449966

Please sign in to comment.