diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 41aa04095277..ce4ac79a88d0 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -378,9 +378,12 @@ def MergeComposite(pattern_table): Parameters ---------- pattern_table : list(tuple) - A list of (pattern_name, pattern) tuples. + A list of (pattern_name, pattern, check) tuples. The order of the patterns in the list will determine the order of priority in which they are matched. + 'check' is a function to check whether an extracted pattern matches. + It can be implemented by pattern writer but if not specified it will + always return True. Returns ------- @@ -390,11 +393,19 @@ def MergeComposite(pattern_table): """ pattern_names = [] patterns = [] - for pattern_name, pattern in pattern_table: + checks = [] + for tup in pattern_table: + if len(tup) == 2: + pattern_name, pattern = tup + check = lambda extract: True + elif len(tup) == 3: + pattern_name, pattern, check = tup + pattern_names.append(pattern_name) patterns.append(pattern) + checks.append(check) - return _ffi_api.MergeComposite(pattern_names, patterns) + return _ffi_api.MergeComposite(pattern_names, patterns, *checks) def MergeCompilerRegions(): diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index b546f05b46e4..c3d34cb9ab7c 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -49,10 +49,24 @@ class AnnotateTargetWrapper : public ExprMutator { if (expr->IsInstance()) { Call call = Downcast(expr); auto fannotate = Op::GetAttr("target." + target_); - Op op = Downcast(call->op); - CHECK(op.defined()); - if (fannotate.count(op)) { - return fannotate[op](call->attrs, call->args); + if (call->op->IsInstance()) { + Op op = Downcast(call->op); + CHECK(op.defined()); + if (fannotate.count(op)) { + return fannotate[op](call->attrs, call->args); + } + } else if (call->op->IsInstance()) { + // handle composite functions + Function func = Downcast(call->op); + CHECK(func.defined()); + auto comp_name = func->GetAttr(attr::kComposite); + if (comp_name.defined()) { + size_t i = comp_name->value.find('.'); + if (i != std::string::npos) { + std::string target = comp_name->value.substr(0, i); + if (target == target_) return true; + } + } } } if (expr->IsInstance()) { @@ -77,7 +91,6 @@ class AnnotateTargetWrapper : public ExprMutator { } Expr VisitExpr_(const CallNode* cn) { - // TODO(@zhiics, @comaniac) Handle composite functions. auto new_e = ExprMutator::VisitExpr_(cn); Call call = Downcast(new_e); @@ -130,13 +143,22 @@ class AnnotateTargetWrapper : public ExprMutator { } } - Expr VisitExpr_(const FunctionNode* op) { - auto new_e = ExprMutator::VisitExpr_(op); + Expr VisitExpr_(const FunctionNode* fn) { + Function func; + Expr new_body; + // don't step into composite functions + if (fn->GetAttr(attr::kComposite).defined()) { + func = GetRef(fn); + new_body = func->body; + } else { + auto new_e = ExprMutator::VisitExpr_(fn); + func = Downcast(new_e); + new_body = InsertEnd(func->body); + } - auto func = Downcast(new_e); return Function( func->params, - InsertEnd(func->body), + new_body, func->ret_type, func->type_params, func->attrs); diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc index e26ff402c3cd..35b93dced90d 100644 --- a/src/relay/transforms/merge_composite.cc +++ b/src/relay/transforms/merge_composite.cc @@ -25,11 +25,11 @@ * Relay operators map to a single external operator. */ -#include #include #include #include #include +#include namespace tvm { namespace relay { @@ -37,11 +37,12 @@ namespace merge_composite { class MergeCompositeWrapper : public ExprMutator { public: - explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern) - : pattern_name_(pattern_name), pattern_(pattern) {} + explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern, + const PackedFunc& check) + : pattern_name_(pattern_name), pattern_(pattern), check_(check) {} Expr ExtractPattern(const Var& pattern, const Expr& root, - Map>* var_map) { + Map>* var_map) { if (var_map->find(pattern->name_hint()) == var_map->end()) { // if we haven't encountered this var yet, make a new free var and associate // it with the value at 'root' @@ -62,12 +63,12 @@ class MergeCompositeWrapper : public ExprMutator { } Expr ExtractPattern(const Constant& pattern, const Expr& root, - Map>* var_map) { + Map>* var_map) { return root; } Expr ExtractPattern(const TupleGetItem& pattern, const Expr& root, - Map>* var_map, Map* call_map) { + Map>* var_map, Map* call_map) { if (!root->IsInstance()) { return Expr(); } @@ -75,14 +76,12 @@ class MergeCompositeWrapper : public ExprMutator { if (pattern->index != root_node->index) { return Expr(); } - if (pattern->tuple->IsInstance() && - root_node->tuple->IsInstance()) { + if (pattern->tuple->IsInstance() && root_node->tuple->IsInstance()) { Expr new_arg; if (call_map->find(pattern->tuple) != call_map->end()) { new_arg = (*call_map)[pattern->tuple]; } else { - new_arg = ExtractPattern(Downcast(pattern->tuple), - Downcast(root_node->tuple), + new_arg = ExtractPattern(Downcast(pattern->tuple), Downcast(root_node->tuple), var_map, call_map); call_map->Set(pattern->tuple, new_arg); } @@ -104,20 +103,18 @@ class MergeCompositeWrapper : public ExprMutator { * and free variables. The free variables indicate where the pattern can 'attach' in your * graph. This function takes the final call node of the pattern and the call node currently * being traversed in the Relay graph. It traverses through the pattern in lockstep with call node - * from the graph (referred to as the 'root' node here) to check they're identical. If at any point - * they differ, an empty expression is returned to signify the extract failed. If a free var is - * reached in the pattern, the corresponding value in the root is associated with the name of the - * free var (via the var_map) so that when we construct the composite function, the inputs match - * up correctly with the rest of the graph. The return value of this function when successful is - * a new Relay expression ready to be wrapped into a composite function. + * from the graph (referred to as the 'root' node here) to check they're identical. If at any + * point they differ, an empty expression is returned to signify the extract failed. If a free var + * is reached in the pattern, the corresponding value in the root is associated with the name of + * the free var (via the var_map) so that when we construct the composite function, the inputs + * match up correctly with the rest of the graph. The return value of this function when + * successful is a new Relay expression ready to be wrapped into a composite function. */ - Expr ExtractPattern(const Call& pattern, const Call& root, - Map>* var_map, Map* call_map) { + Expr ExtractPattern(const Call& pattern, const Call& root, Map>* var_map, + Map* call_map) { // check to make sure both calls are to operators (not functions) - if (!pattern->op->IsInstance() || !root->op->IsInstance()) - return Expr(); - if (pattern->op.as()->name != root->op.as()->name) - return Expr(); + if (!pattern->op->IsInstance() || !root->op->IsInstance()) return Expr(); + if (pattern->op.as()->name != root->op.as()->name) return Expr(); unsigned int i = 0; Array new_args; @@ -133,27 +130,20 @@ class MergeCompositeWrapper : public ExprMutator { return Expr(); } // if it's a call node, recursively call this function - new_arg = ExtractPattern(Downcast(arg), - Downcast(root->args[i]), - var_map, call_map); + new_arg = + ExtractPattern(Downcast(arg), Downcast(root->args[i]), var_map, call_map); call_map->Set(arg, new_arg); } } else if (arg->IsInstance()) { // if there's a var in the pattern, it must be a free var // so call the function to update the var_map - new_arg = ExtractPattern(Downcast(arg), - root->args[i], - var_map); + new_arg = ExtractPattern(Downcast(arg), root->args[i], var_map); } else if (arg->IsInstance()) { // if there's a constant, simply get the corresponding // value of the constant from the root - new_arg = ExtractPattern(Downcast(arg), - root->args[i], - var_map); + new_arg = ExtractPattern(Downcast(arg), root->args[i], var_map); } else if (arg->IsInstance()) { - new_arg = ExtractPattern(Downcast(arg), - root->args[i], - var_map, call_map); + new_arg = ExtractPattern(Downcast(arg), root->args[i], var_map, call_map); } if (!new_arg.defined()) { return Expr(); @@ -169,8 +159,7 @@ class MergeCompositeWrapper : public ExprMutator { if (call->op->IsInstance()) { Function func = Downcast(call->op); CHECK(func.defined()); - const auto name_node = - func->GetAttr(attr::kComposite); + const auto name_node = func->GetAttr(attr::kComposite); // don't step into existing composite functions if (name_node.defined() && name_node->value != "") { tvm::Array new_args; @@ -184,8 +173,7 @@ class MergeCompositeWrapper : public ExprMutator { Expr expr = ExprMutator::VisitExpr_(cn); call = Downcast(expr); - if (!call->op->IsInstance()) - return std::move(call); + if (!call->op->IsInstance()) return std::move(call); // only call patterns are supported Call pattern = Downcast(pattern_); @@ -193,7 +181,7 @@ class MergeCompositeWrapper : public ExprMutator { Map> args_map; Map call_map; auto extract = ExtractPattern(pattern, call, &args_map, &call_map); - if (extract.defined()) { + if (extract.defined() && static_cast(check_(extract))) { auto free_vars = FreeVars(extract); // make the composite function auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs()); @@ -215,17 +203,20 @@ class MergeCompositeWrapper : public ExprMutator { std::string pattern_name_; /*! \brief The pattern to match */ Expr pattern_; + /*! \brief The function to check whether an extract is supported */ + PackedFunc check_; }; -Expr MergeComposite(const Expr& expr, - const Array& pattern_names, const Array& patterns) { +Expr MergeComposite(const Expr& expr, const Array& pattern_names, + const Array& patterns, const std::vector& checks) { CHECK_EQ(pattern_names.size(), patterns.size()); Expr merged_expr = expr; // merge the patterns one-by-one in order for (size_t i = 0; i < patterns.size(); i++) { std::string pattern_name = pattern_names[i]->value; Expr pattern = patterns[i]; - merged_expr = MergeCompositeWrapper(pattern_name, pattern).Mutate(merged_expr); + PackedFunc check = checks[i]; + merged_expr = MergeCompositeWrapper(pattern_name, pattern, check).Mutate(merged_expr); } return merged_expr; } @@ -235,18 +226,25 @@ Expr MergeComposite(const Expr& expr, namespace transform { Pass MergeComposite(const tvm::Array& pattern_names, - const tvm::Array& patterns) { + const tvm::Array& patterns, const std::vector& checks) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast( - relay::merge_composite::MergeComposite(f, pattern_names, patterns)); + relay::merge_composite::MergeComposite(f, pattern_names, patterns, checks)); }; auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {}); return func_pass; } -TVM_REGISTER_GLOBAL("relay._transform.MergeComposite") -.set_body_typed(MergeComposite); +TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) { + tvm::Array pattern_names = args[0]; + tvm::Array patterns = args[1]; + std::vector checks; + for (int i = 2; i < args.size(); i++) { + checks.push_back(args[i]); + } + *rv = MergeComposite(pattern_names, patterns, checks); +}); } // namespace transform diff --git a/tests/python/relay/test_annotate_target.py b/tests/python/relay/test_annotate_target.py index 87cf7616e232..0a2abd73d5eb 100644 --- a/tests/python/relay/test_annotate_target.py +++ b/tests/python/relay/test_annotate_target.py @@ -219,7 +219,53 @@ def after(): assert tvm.ir.structural_equal(expected, result) +def test_composite_function(): + def before(): + a = relay.var('a', shape=(10, 10)) + b = relay.var('b', shape=(10, 10)) + + # add_relu function + in_1 = relay.var('in_1', shape=(10, 10)) + in_2 = relay.var('in_2', shape=(10, 10)) + add_node = relay.add(in_1, in_2) + relu_node = relay.nn.relu(add_node) + add_relu = relay.Function([in_1, in_2], relu_node) + add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu")) + + # merged function + r = relay.Call(add_relu, [a, b]) + f = relay.Function([a, b], r) + mod = tvm.IRModule.from_expr(f) + return mod + + def after(): + a = relay.var('a', shape=(10, 10)) + b = relay.var('b', shape=(10, 10)) + + # add_relu function + in_1 = relay.var('in_1', shape=(10, 10)) + in_2 = relay.var('in_2', shape=(10, 10)) + add_node = relay.add(in_1, in_2) + relu_node = relay.nn.relu(add_node) + add_relu = relay.Function([in_1, in_2], relu_node) + add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu")) + + # merged function + cb_1 = relay.annotation.compiler_begin(a, "test") + cb_2 = relay.annotation.compiler_begin(b, "test") + r = relay.Call(add_relu, [cb_1, cb_2]) + ce_1 = relay.annotation.compiler_end(r, "test") + f = relay.Function([a, b], ce_1) + mod = tvm.IRModule.from_expr(f) + return mod + + result = transform.AnnotateTarget("test")(before()) + expected = transform.InferType()(after()) + assert tvm.ir.structural_equal(expected, result) + + if __name__ == "__main__": test_multiple_ends() test_extern_dnnl() test_extern_dnnl_mobilenet() + test_composite_function() diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index 3c70cf237c94..110d855216e4 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -732,6 +732,43 @@ def expected(): assert tvm.ir.structural_equal(result, expected, map_free_vars=True) +def test_pattern_with_check(): + def before(): + x = relay.var('x', shape=(1, 10, 10, 10)) + w = relay.var('w', shape=(10, 10, 3, 3)) + b = relay.var('b', shape=(8,)) + conv = relay.nn.conv2d(x, + w, + kernel_size=(3, 3), + kernel_layout="OIHW", + data_layout="NHWC") + bias = relay.nn.bias_add(conv, b) + relu = relay.nn.relu(bias) + return relay.Function([x, w, b], relu) + + def _check_true(extract): + conv = extract.args[0].args[0] + return conv.attrs.data_layout == "NHWC" + + def _check_false(extract): + conv = extract.args[0].args[0] + return conv.attrs.data_layout == "NCHW" + + pattern_table_true = [ + ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_true) + ] + pattern_table_false = [ + ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_false) + ] + + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_false)) + expected = run_opt_pass(before(), relay.transform.InferType()) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_true)) + assert result.body.op.attrs["Composite"] == "conv_bias_relu" + + if __name__ == "__main__": test_simple_merge() test_branch_merge() @@ -741,3 +778,4 @@ def expected(): test_multiple_input_subgraphs() test_reuse_call_merge() test_tuple_get_item_merge() + test_pattern_with_check()