Skip to content

Commit

Permalink
[RELAY] Annotate compiler_ends on each edge
Browse files Browse the repository at this point in the history
This alters the behaviour of the AnnotateTarget
pass to enforce the property that all compiler
annotations exist along a single data flow edge.
Specifically, this means they should have exactly
one parent and one child.

Change-Id: I0e74803a77767f4f377d17755a13a74a30909797
  • Loading branch information
mbaret committed Mar 23, 2020
1 parent 88146cd commit 451a031
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 26 deletions.
150 changes: 124 additions & 26 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,46 +38,144 @@ class AnnotateTargetWrapper : public ExprMutator {
public:
explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {}

Expr Annotate(const Expr& expr) {
return InsertEnd(Mutate(expr));
}

bool IsSupported(const Expr& expr) {
if (expr->IsInstance<CallNode>()) {
Call call = Downcast<Call>(expr);
auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
Op op = Downcast<Op>(call->op);
CHECK(op.defined());
if (fannotate.count(op)) {
return fannotate[op](call->attrs, call->args);
}
}
return false;
}

Expr InsertEnd(const Expr& arg) {
if (IsSupported(arg)) {
const auto *end_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_end");
CHECK(end_op);
Expr end = (*end_op)(arg, target_);
return end;
}
return arg;
}

Expr VisitExpr_(const CallNode* cn) {
// TODO(@zhiics, @comaniac) Handle composite functions.
auto new_e = ExprMutator::VisitExpr_(cn);

Call call = Downcast<Call>(new_e);
static auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
Op op = Downcast<Op>(call->op);
CHECK(op.defined());

if (fannotate.count(op)) {
bool external = fannotate[op](call->attrs, call->args);
if (external) {
tvm::Array<tvm::relay::Expr> compiler_begins;
for (const auto& it : call->args) {
const auto* begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
CHECK(begin_op);
Expr begin = (*begin_op)(it, target_);
compiler_begins.push_back(begin);
}
Expr update_call = CallNode::make(call->op, compiler_begins, call->attrs);
const auto* end_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_end");
CHECK(end_op);
Expr end = (*end_op)(update_call, target_);
return end;

// add end annotations if the args are supported
Array<Expr> compiler_ends;
for (const auto& it : call->args) {
compiler_ends.push_back(InsertEnd(it));
}
call = CallNode::make(call->op, compiler_ends, call->attrs);

// add begin annotations if the call node is supported
if (IsSupported(call)) {
tvm::Array<tvm::relay::Expr> compiler_begins;
for (const auto& it : call->args) {
const auto* begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
CHECK(begin_op);
Expr begin = (*begin_op)(it, target_);
compiler_begins.push_back(begin);
}
} else {
LOG(WARNING) << op->name << " in " << target_
<< " is not registered. It will be executed on CPU.";
call = CallNode::make(call->op, compiler_begins, call->attrs);
}
return new_e;

return std::move(call);
}

Expr VisitExpr_(const TupleNode *op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto tup = Downcast<Tuple>(new_e);
Array<Expr> new_fields;
for (auto field : tup->fields) {
new_fields.push_back(InsertEnd(field));
}
return TupleNode::make(new_fields);
}

Expr VisitExpr_(const TupleGetItemNode *op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto get = Downcast<TupleGetItem>(new_e);
return TupleGetItemNode::make(
InsertEnd(get->tuple),
get->index);
}

Expr VisitExpr_(const FunctionNode *op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto func = Downcast<Function>(new_e);
return Function(
func->params,
InsertEnd(func->body),
func->ret_type,
func->type_params,
func->attrs);
}

Expr VisitExpr_(const LetNode *op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto let = Downcast<Let>(new_e);
return LetNode::make(
let->var,
InsertEnd(let->value),
InsertEnd(let->body));
}

Expr VisitExpr_(const IfNode *op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto iff = Downcast<If>(new_e);
return IfNode::make(
InsertEnd(iff->cond),
InsertEnd(iff->true_branch),
InsertEnd(iff->false_branch));
}

Expr VisitExpr_(const RefCreateNode *op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto create = Downcast<RefCreate>(new_e);
return RefCreateNode::make(InsertEnd(create->value));
}

Expr VisitExpr_(const RefReadNode *op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto read = Downcast<RefRead>(new_e);
return RefReadNode::make(InsertEnd(read->ref));
}

Expr VisitExpr_(const RefWriteNode *op) {
auto new_e = ExprMutator::VisitExpr_(op);

auto write = Downcast<RefWrite>(new_e);
return RefWriteNode::make(
InsertEnd(write->ref),
InsertEnd(write->value));
}

private:
std::string target_;
};

Expr AnnotateTarget(const Expr& expr, const std::string& target) {
return AnnotateTargetWrapper(target).Mutate(expr);
return AnnotateTargetWrapper(target).Annotate(expr);
}

} // namespace annotate_target
Expand Down
36 changes: 36 additions & 0 deletions tests/python/relay/test_annotate_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import tvm
import tvm.relay.testing
import tvm.relay.op as reg
import tvm.relay.transform as transform
from tvm import relay
from tvm import runtime
Expand Down Expand Up @@ -183,6 +184,41 @@ def test_extern_dnnl_mobilenet():
(1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)


@reg.register("nn.relu", "target.test")
def relu(attrs, args):
return True


def test_multiple_ends():
def before():
x = relay.var("x", shape=(10, 10))
r = relay.nn.relu(x)
a_1 = relay.abs(r)
a_2 = relay.abs(r)
out = relay.add(a_1, a_2)
f = relay.Function([x], out)
mod = tvm.IRModule.from_expr(f)
return mod

def after():
x = relay.var("x", shape=(10, 10))
cb_1 = relay.annotation.compiler_begin(x, "test")
r = relay.nn.relu(cb_1)
ce_1 = relay.annotation.compiler_end(r, "test")
ce_2 = relay.annotation.compiler_end(r, "test")
a_1 = relay.abs(ce_1)
a_2 = relay.abs(ce_2)
out = relay.add(a_1, a_2)
f = relay.Function([x], out)
mod = tvm.IRModule.from_expr(f)
return mod

result = transform.AnnotateTarget("test")(before())
expected = transform.InferType()(after())
assert relay.analysis.alpha_equal(expected, result)


if __name__ == "__main__":
test_multiple_ends()
test_extern_dnnl()
test_extern_dnnl_mobilenet()

0 comments on commit 451a031

Please sign in to comment.