diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 12c10e5455ffc..5af9b99ae0ed8 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -33,8 +33,12 @@ namespace tvm { namespace relay { namespace annotate_target { -const PackedFunc* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); -const PackedFunc* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); +static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); +static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); + +const PackedFunc* make_begin_op = + runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); +const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); // A helper class to insert annotation boundaries for a program region that will // be handled by a specific compiler. @@ -59,18 +63,32 @@ class AnnotateTargetWrapper : public ExprMutator { std::string ref_target = ""; Array compiler_ends; for (auto arg : args) { - if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) { - std::string arg_target = op_expr_to_target_[arg]; - compiler_ends.push_back(InsertAnnotation(arg, arg_target, end_op)); - if (ref_target == "") { - ref_target = arg_target; - } else if (ref_target != arg_target) { - ref_target = "default"; + std::string arg_target = "defualt"; + const CallNode* call = arg.as(); + + if (call && call->op == compiler_begin_op) { + // Argument is already compiler begin node meaning that this is not the first time + // running this pass, so we simply remove it and will add a new one later. + CHECK_EQ(call->args.size(), 1U); + const CallNode* end = call->args[0].as(); + if (end->op == compiler_end_op) { + arg_target = end->attrs.as()->compiler; } + compiler_ends.push_back(call->args[0]); + } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) { + arg_target = op_expr_to_target_[arg]; + compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op)); } else { // Input vars. compiler_ends.push_back(arg); } + + // Maintain reference target in case the target of the current node is unassigned. + if (ref_target == "") { + ref_target = arg_target; + } else if (ref_target != arg_target) { + ref_target = "default"; + } } // Determine compiler begin target. @@ -78,7 +96,7 @@ class AnnotateTargetWrapper : public ExprMutator { Array compiler_begins; for (const auto& end : compiler_ends) { - compiler_begins.push_back(InsertAnnotation(end, op_target, begin_op)); + compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op)); } return {op_target, compiler_begins}; @@ -94,8 +112,34 @@ class AnnotateTargetWrapper : public ExprMutator { // Supported targets for this node. The order implies the priority. std::vector supported_targets; + auto op_node = cn->op.as(); + + // This graph has annotations, meaning that this is not the first time running this pass. + if (op_node && cn->op == compiler_begin_op) { + // Bypass compiler begin due to lack of target information. It will be processed + // when the following op handling arguments. + CHECK_EQ(cn->args.size(), 1U); + return VisitExpr(cn->args[0]); + } else if (op_node && cn->op == compiler_end_op) { + // Override compiler end with the new target. + CHECK_EQ(cn->args.size(), 1U); + auto input_expr = VisitExpr(cn->args[0]); + CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end()); + return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op); + } + + // Peek the first argument. If it is compiler begin then this node had annotated by + // another target before, so we also consider that target as a supported target. + const CallNode* first_arg_call = cn->args[0].as(); + if (first_arg_call && first_arg_call->op == compiler_begin_op) { + std::string arg_target = first_arg_call->attrs.as()->compiler; + if (arg_target != "default") { + supported_targets.push_back(arg_target); + } + } + // Check which targets this op can be offloaded. - if (cn->op->IsInstance()) { + if (op_node) { // TVM operators: Check target specific op checking function and add to supported_targets // if it is supported. Op op = Downcast(cn->op); @@ -179,7 +223,7 @@ class AnnotateTargetWrapper : public ExprMutator { func = Downcast(new_e); new_body = func->body; if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) { - new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], end_op); + new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op); op_expr_to_target_[new_body] = op_expr_to_target_[func->body]; } } diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index ca7381754cc2d..11eb3d8859abe 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -338,6 +338,32 @@ def after(): assert tvm.ir.structural_equal(expected, result) +def test_multiple_runs(): + @reg.register("nn.relu", "target.A") + def relu(attrs, args): # pylint: disable=unused-variable + return True + + @reg.register("add", "target.B") + def add(attrs, args): # pylint: disable=unused-variable + return True + + def before(): + x = relay.var("x", shape=(10, 5)) + a_1 = relay.nn.relu(x) + a_2 = relay.abs(a_1) + a_3 = relay.nn.relu(a_1) + out = relay.add(a_2, a_3) + + f = relay.Function([x], out) + mod = tvm.IRModule.from_expr(f) + return mod + + mod = transform.AnnotateTarget("A")(before()) + mod = transform.AnnotateTarget("B")(mod) + expected = transform.AnnotateTarget(["A", "B"])(before()) + assert tvm.ir.structural_equal(expected, mod) + + if __name__ == "__main__": test_extern_dnnl() test_composite_function() @@ -345,3 +371,4 @@ def after(): test_multiple_ends() test_type_propagation() test_tuple() + test_multiple_runs()