Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY] Add MergeCompilerRegions pass #5134

Merged
merged 18 commits into from
Mar 30, 2020
11 changes: 11 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,17 @@ def MergeComposite(pattern_table):
return _ffi_api.MergeComposite(pattern_names, patterns)


def MergeCompilerRegions():
"""Merge together compiler regions.
Returns
-------
ret : tvm.relay.Pass
The registered pass that merges compiler regions.
"""
return _ffi_api.MergeCompilerRegions()


def RewriteAnnotatedOps(fallback_device):
"""Rewrite the annotated program where annotation operators, e.g.
`on_deivce`, mark which device an expression should be scheduled to.
Expand Down
6 changes: 5 additions & 1 deletion src/relay/analysis/annotated_region_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,18 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
}
// if any of the outputs of src are inputs of dest, they become internal nodes
// so remove them from outs
std::vector<Expr> ins_to_remove;
for (const auto& input : dest->ins) {
auto call = Downcast<Call>(input);
auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]);
if (it != src->outs.end()) {
dest->outs.remove(*it);
dest->ins.remove(input);
ins_to_remove.push_back(input);
}
}
for (const auto& input : ins_to_remove) {
dest->ins.remove(input);
}
regions_.erase(src);
}

Expand Down
150 changes: 124 additions & 26 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,46 +38,144 @@ class AnnotateTargetWrapper : public ExprMutator {
public:
explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {}

Expr Annotate(const Expr& expr) {
return InsertEnd(Mutate(expr));
}

bool IsSupported(const Expr& expr) {
if (expr->IsInstance<CallNode>()) {
Call call = Downcast<Call>(expr);
auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
Op op = Downcast<Op>(call->op);
CHECK(op.defined());
if (fannotate.count(op)) {
return fannotate[op](call->attrs, call->args);
}
}
return false;
}

Expr InsertEnd(const Expr& arg) {
if (IsSupported(arg)) {
const auto *end_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_end");
CHECK(end_op);
Expr end = (*end_op)(arg, target_);
return end;
}
return arg;
}

Expr VisitExpr_(const CallNode* cn) {
// TODO(@zhiics, @comaniac) Handle composite functions.
auto new_e = ExprMutator::VisitExpr_(cn);

Call call = Downcast<Call>(new_e);
auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
Op op = Downcast<Op>(call->op);
CHECK(op.defined());

if (fannotate.count(op)) {
bool external = fannotate[op](call->attrs, call->args);
if (external) {
tvm::Array<tvm::relay::Expr> compiler_begins;
for (const auto& it : call->args) {
const auto* begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
CHECK(begin_op);
Expr begin = (*begin_op)(it, target_);
compiler_begins.push_back(begin);
}
Expr update_call = Call(call->op, compiler_begins, call->attrs);
const auto* end_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_end");
CHECK(end_op);
Expr end = (*end_op)(update_call, target_);
return end;

// add end annotations if the args are supported
Array<Expr> compiler_ends;
for (const auto& it : call->args) {
compiler_ends.push_back(InsertEnd(it));
}
call = Call(call->op, compiler_ends, call->attrs);

// add begin annotations if the call node is supported
if (IsSupported(call)) {
tvm::Array<tvm::relay::Expr> compiler_begins;
const auto* begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
for (const auto& it : call->args) {
CHECK(begin_op);
Expr begin = (*begin_op)(it, target_);
compiler_begins.push_back(begin);
}
} else {
LOG(WARNING) << op->name << " in " << target_
<< " is not registered. It will be executed on CPU.";
call = Call(call->op, compiler_begins, call->attrs);
}
return new_e;

return std::move(call);
}

Expr VisitExpr_(const TupleNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto tup = Downcast<Tuple>(new_e);
Array<Expr> new_fields;
for (auto field : tup->fields) {
new_fields.push_back(InsertEnd(field));
}
return Tuple(new_fields);
}

Expr VisitExpr_(const TupleGetItemNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto get = Downcast<TupleGetItem>(new_e);
return TupleGetItem(
InsertEnd(get->tuple),
get->index);
}

Expr VisitExpr_(const FunctionNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto func = Downcast<Function>(new_e);
return Function(
func->params,
InsertEnd(func->body),
func->ret_type,
func->type_params,
func->attrs);
}

Expr VisitExpr_(const LetNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto let = Downcast<Let>(new_e);
return Let(
let->var,
InsertEnd(let->value),
InsertEnd(let->body));
}

Expr VisitExpr_(const IfNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto iff = Downcast<If>(new_e);
return If(
InsertEnd(iff->cond),
InsertEnd(iff->true_branch),
InsertEnd(iff->false_branch));
}

Expr VisitExpr_(const RefCreateNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto create = Downcast<RefCreate>(new_e);
return RefCreate(InsertEnd(create->value));
}

Expr VisitExpr_(const RefReadNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto read = Downcast<RefRead>(new_e);
return RefRead(InsertEnd(read->ref));
}

Expr VisitExpr_(const RefWriteNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto write = Downcast<RefWrite>(new_e);
return RefWrite(
InsertEnd(write->ref),
InsertEnd(write->value));
}

private:
std::string target_;
};

Expr AnnotateTarget(const Expr& expr, const std::string& target) {
return AnnotateTargetWrapper(target).Mutate(expr);
return AnnotateTargetWrapper(target).Annotate(expr);
}

} // namespace annotate_target
Expand Down
Loading