diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index b546f05b46e4e..1a90664415cc7 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -49,10 +49,25 @@ class AnnotateTargetWrapper : public ExprMutator { if (expr->IsInstance()) { Call call = Downcast(expr); auto fannotate = Op::GetAttr("target." + target_); - Op op = Downcast(call->op); - CHECK(op.defined()); - if (fannotate.count(op)) { - return fannotate[op](call->attrs, call->args); + if (call->op->IsInstance()) { + Op op = Downcast(call->op); + CHECK(op.defined()); + if (fannotate.count(op)) { + return fannotate[op](call->attrs, call->args); + } + } + // handle composite functions + else if (call->op->IsInstance()) { + Function func = Downcast(call->op); + CHECK(func.defined()); + auto comp_name = func->GetAttr(attr::kComposite); + if (comp_name.defined()) { + size_t i = comp_name->value.find('.'); + if(i != std::string::npos) { + std::string target = comp_name->value.substr(0, i); + if (target == target_) return true; + } + } } } if (expr->IsInstance()) { @@ -77,7 +92,6 @@ class AnnotateTargetWrapper : public ExprMutator { } Expr VisitExpr_(const CallNode* cn) { - // TODO(@zhiics, @comaniac) Handle composite functions. auto new_e = ExprMutator::VisitExpr_(cn); Call call = Downcast(new_e); @@ -130,13 +144,22 @@ class AnnotateTargetWrapper : public ExprMutator { } } - Expr VisitExpr_(const FunctionNode* op) { - auto new_e = ExprMutator::VisitExpr_(op); + Expr VisitExpr_(const FunctionNode* fn) { + Function func; + Expr new_body; + // don't step into composite functions + if (fn->GetAttr(attr::kComposite).defined()) { + func = GetRef(fn); + new_body = func->body; + } else { + auto new_e = ExprMutator::VisitExpr_(fn); + func = Downcast(new_e); + new_body = InsertEnd(func->body); + } - auto func = Downcast(new_e); return Function( func->params, - InsertEnd(func->body), + new_body, func->ret_type, func->type_params, func->attrs); diff --git a/tests/python/relay/test_annotate_target.py b/tests/python/relay/test_annotate_target.py index 87cf7616e232e..0a2abd73d5eb2 100644 --- a/tests/python/relay/test_annotate_target.py +++ b/tests/python/relay/test_annotate_target.py @@ -219,7 +219,53 @@ def after(): assert tvm.ir.structural_equal(expected, result) +def test_composite_function(): + def before(): + a = relay.var('a', shape=(10, 10)) + b = relay.var('b', shape=(10, 10)) + + # add_relu function + in_1 = relay.var('in_1', shape=(10, 10)) + in_2 = relay.var('in_2', shape=(10, 10)) + add_node = relay.add(in_1, in_2) + relu_node = relay.nn.relu(add_node) + add_relu = relay.Function([in_1, in_2], relu_node) + add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu")) + + # merged function + r = relay.Call(add_relu, [a, b]) + f = relay.Function([a, b], r) + mod = tvm.IRModule.from_expr(f) + return mod + + def after(): + a = relay.var('a', shape=(10, 10)) + b = relay.var('b', shape=(10, 10)) + + # add_relu function + in_1 = relay.var('in_1', shape=(10, 10)) + in_2 = relay.var('in_2', shape=(10, 10)) + add_node = relay.add(in_1, in_2) + relu_node = relay.nn.relu(add_node) + add_relu = relay.Function([in_1, in_2], relu_node) + add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu")) + + # merged function + cb_1 = relay.annotation.compiler_begin(a, "test") + cb_2 = relay.annotation.compiler_begin(b, "test") + r = relay.Call(add_relu, [cb_1, cb_2]) + ce_1 = relay.annotation.compiler_end(r, "test") + f = relay.Function([a, b], ce_1) + mod = tvm.IRModule.from_expr(f) + return mod + + result = transform.AnnotateTarget("test")(before()) + expected = transform.InferType()(after()) + assert tvm.ir.structural_equal(expected, result) + + if __name__ == "__main__": test_multiple_ends() test_extern_dnnl() test_extern_dnnl_mobilenet() + test_composite_function()