Skip to content

Commit

Permalink
support AnnotateTarget multiple runs
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed Apr 8, 2020
1 parent dbd6301 commit c952722
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 12 deletions.
68 changes: 56 additions & 12 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@ namespace tvm {
namespace relay {
namespace annotate_target {

const PackedFunc* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
const PackedFunc* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");

const PackedFunc* make_begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");

// A helper class to insert annotation boundaries for a program region that will
// be handled by a specific compiler.
Expand All @@ -59,26 +63,40 @@ class AnnotateTargetWrapper : public ExprMutator {
std::string ref_target = "";
Array<Expr> compiler_ends;
for (auto arg : args) {
if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
std::string arg_target = op_expr_to_target_[arg];
compiler_ends.push_back(InsertAnnotation(arg, arg_target, end_op));
if (ref_target == "") {
ref_target = arg_target;
} else if (ref_target != arg_target) {
ref_target = "default";
std::string arg_target = "defualt";
const CallNode* call = arg.as<CallNode>();

if (call && call->op == compiler_begin_op) {
// Argument is already compiler begin node meaning that this is not the first time
// running this pass, so we simply remove it and will add a new one later.
CHECK_EQ(call->args.size(), 1U);
const CallNode* end = call->args[0].as<CallNode>();
if (end->op == compiler_end_op) {
arg_target = end->attrs.as<CompilerAttrs>()->compiler;
}
compiler_ends.push_back(call->args[0]);
} else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
arg_target = op_expr_to_target_[arg];
compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
} else {
// Input vars.
compiler_ends.push_back(arg);
}

// Maintain reference target in case the target of the current node is unassigned.
if (ref_target == "") {
ref_target = arg_target;
} else if (ref_target != arg_target) {
ref_target = "default";
}
}

// Determine compiler begin target.
std::string op_target = (target == "") ? ref_target : target;

Array<Expr> compiler_begins;
for (const auto& end : compiler_ends) {
compiler_begins.push_back(InsertAnnotation(end, op_target, begin_op));
compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op));
}

return {op_target, compiler_begins};
Expand All @@ -94,8 +112,34 @@ class AnnotateTargetWrapper : public ExprMutator {
// Supported targets for this node. The order implies the priority.
std::vector<std::string> supported_targets;

auto op_node = cn->op.as<OpNode>();

// This graph has annotations, meaning that this is not the first time running this pass.
if (op_node && cn->op == compiler_begin_op) {
// Bypass compiler begin due to lack of target information. It will be processed
// when the following op handling arguments.
CHECK_EQ(cn->args.size(), 1U);
return VisitExpr(cn->args[0]);
} else if (op_node && cn->op == compiler_end_op) {
// Override compiler end with the new target.
CHECK_EQ(cn->args.size(), 1U);
auto input_expr = VisitExpr(cn->args[0]);
CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end());
return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op);
}

// Peek the first argument. If it is compiler begin then this node had annotated by
// another target before, so we also consider that target as a supported target.
const CallNode* first_arg_call = cn->args[0].as<CallNode>();
if (first_arg_call && first_arg_call->op == compiler_begin_op) {
std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
if (arg_target != "default") {
supported_targets.push_back(arg_target);
}
}

// Check which targets this op can be offloaded.
if (cn->op->IsInstance<OpNode>()) {
if (op_node) {
// TVM operators: Check target specific op checking function and add to supported_targets
// if it is supported.
Op op = Downcast<Op>(cn->op);
Expand Down Expand Up @@ -179,7 +223,7 @@ class AnnotateTargetWrapper : public ExprMutator {
func = Downcast<Function>(new_e);
new_body = func->body;
if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) {
new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], end_op);
new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op);
op_expr_to_target_[new_body] = op_expr_to_target_[func->body];
}
}
Expand Down
27 changes: 27 additions & 0 deletions tests/python/relay/test_pass_annotate_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,37 @@ def after():
assert tvm.ir.structural_equal(expected, result)


def test_multiple_runs():
@reg.register("nn.relu", "target.A")
def relu(attrs, args): # pylint: disable=unused-variable
return True

@reg.register("add", "target.B")
def add(attrs, args): # pylint: disable=unused-variable
return True

def before():
x = relay.var("x", shape=(10, 5))
a_1 = relay.nn.relu(x)
a_2 = relay.abs(a_1)
a_3 = relay.nn.relu(a_1)
out = relay.add(a_2, a_3)

f = relay.Function([x], out)
mod = tvm.IRModule.from_expr(f)
return mod

mod = transform.AnnotateTarget("A")(before())
mod = transform.AnnotateTarget("B")(mod)
expected = transform.AnnotateTarget(["A", "B"])(before())
assert tvm.ir.structural_equal(expected, mod)


if __name__ == "__main__":
test_extern_dnnl()
test_composite_function()
#test_extern_dnnl_mobilenet()
test_multiple_ends()
test_type_propagation()
test_tuple()
test_multiple_runs()

0 comments on commit c952722

Please sign in to comment.