diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index d767277328423..2fe3e05239c97 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -513,21 +513,47 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"): return _transform.Legalize(legalize_map_attr_name) -def MergeComposite(compiler): - """Merge multiple operators into a single composite relay function. +def AnnotateCompiler(compiler): + """Annotate ops in an experession with a provied compiler and then use it + for codegen. Parameters ---------- compiler : str The compiler used for codegen. + Returns + ------- + ret : tvm.relay.Pass + The annotated pass that wrapps ops with subgraph_start and + subgraph_end. + """ + return _transform.AnnotateCompiler(compiler) + + +def MergeComposite(pattern_table): + """Merge multiple operators into a single composite relay function. + + Parameters + ---------- + pattern_table : list(tuple) + A list of (pattern_name, pattern) tuples. + The order of the patterns in the list will determine the order + of priority in which they are matched. + Returns ------- ret : tvm.relay.Pass The registered pass that merges operators into a single composite relay function. """ - return _transform.MergeComposite(compiler) + pattern_names = [] + patterns = [] + for pattern_name, pattern in pattern_table: + pattern_names.append(pattern_name) + patterns.append(pattern) + + return _transform.MergeComposite(pattern_names, patterns) def RewriteAnnotatedOps(fallback_device): diff --git a/src/relay/pass/merge_composite.cc b/src/relay/pass/merge_composite.cc index 81d50b91a0b7d..d4e445619190e 100644 --- a/src/relay/pass/merge_composite.cc +++ b/src/relay/pass/merge_composite.cc @@ -36,8 +36,8 @@ namespace merge_composite { class MergeCompositeWrapper : public ExprMutator { public: - explicit MergeCompositeWrapper(const tvm::Map& pattern_map) - : pattern_map_(pattern_map) {} + explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern) + : pattern_name_(pattern_name), pattern_(pattern) {} bool MatchPattern(const Call& pattern, const Call& root) { if (!pattern->op->IsInstance() || !root->op->IsInstance()) @@ -135,49 +135,58 @@ class MergeCompositeWrapper : public ExprMutator { Op op = Downcast(call->op); CHECK(op.defined()); - for (const auto& x : pattern_map_) { - Call pattern = Downcast(x.second); - if (Downcast(pattern->op)->name != op->name) - continue; - - if (MatchPattern(pattern, call)) { - Map> args_map; - auto extract = ExtractPattern(pattern, call, &args_map); - auto free_vars = FreeVars(extract); - Function new_func = FunctionNode::make(free_vars, extract, - call->checked_type_, {}, Attrs()); - new_func = FunctionSetAttr(new_func, attr::kComposite, - tir::StringImmNode::make(x.first)); - new_func = FunctionSetAttr(new_func, attr::kPrimitive, - tvm::Integer(1)); - Array args; - for (const auto& free_var : free_vars) { - args.push_back(args_map[free_var->name_hint()][1]); - } - auto new_call = CallNode::make(new_func, args); - return std::move(new_call); + Call pattern = Downcast(pattern_); + if (Downcast(pattern->op)->name != op->name) + return std::move(call); + + if (MatchPattern(pattern, call)) { + Map> args_map; + auto extract = ExtractPattern(pattern, call, &args_map); + auto free_vars = FreeVars(extract); + Function new_func = FunctionNode::make(free_vars, extract, + call->checked_type_, {}, Attrs()); + new_func = FunctionSetAttr(new_func, attr::kComposite, + tir::StringImmNode::make(pattern_name_)); + new_func = FunctionSetAttr(new_func, attr::kPrimitive, + tvm::Integer(1)); + Array args; + for (const auto& free_var : free_vars) { + args.push_back(args_map[free_var->name_hint()][1]); } + auto new_call = CallNode::make(new_func, args); + return std::move(new_call); } return std::move(call); } private: - tvm::Map pattern_map_; + std::string pattern_name_; + Expr pattern_; }; -Expr MergeComposite(const Expr& expr, const tvm::Map& pattern) { - return MergeCompositeWrapper(pattern).Mutate(expr); +Expr MergeComposite(const Expr& expr, + const Array& pattern_names, const Array& patterns) { + CHECK(pattern_names.size() == patterns.size()); + Expr merged_expr = expr; + for (size_t i = 0; i < patterns.size(); i++) { + std::string pattern_name = pattern_names[i]->value; + Expr pattern = patterns[i]; + merged_expr = MergeCompositeWrapper(pattern_name, pattern).Mutate(merged_expr); + } + return merged_expr; } } // namespace merge_composite namespace transform { -Pass MergeComposite(const tvm::Map& pattern) { +Pass MergeComposite(const tvm::Array& pattern_names, + const tvm::Array& patterns) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::merge_composite::MergeComposite(f, pattern)); + return Downcast( + relay::merge_composite::MergeComposite(f, pattern_names, patterns)); }; auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {}); return func_pass; diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index 77b8f27902d1a..53e428c04c17e 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Unit tests for merge composite.""" +from tvm import expr from tvm import relay from tvm.relay.testing import run_opt_pass @@ -122,9 +123,9 @@ def test_simple_merge(): relu """ - pattern_table = { - "add_relu": make_add_relu_pattern() - } + pattern_table = [ + ("add_relu", make_add_relu_pattern()) + ] def before(): a = relay.var('a', shape=(10, 10)) @@ -178,9 +179,9 @@ def test_branch_merge(): relu """ - pattern_table = { - "add_sub_mul": make_add_sub_mul_pattern() - } + pattern_table = [ + ("add_sub_mul", make_add_sub_mul_pattern()) + ] def before(): a = relay.var('a', shape=(10, 10)) @@ -244,10 +245,10 @@ def test_multiple_patterns(): | / mul """ - pattern_table = { - "conv2d_bias_relu": make_conv_bias_relu_pattern(), - "add_relu": make_add_relu_pattern() - } + pattern_table = [ + ("conv2d_bias_relu", make_conv_bias_relu_pattern()), + ("add_relu", make_add_relu_pattern()) + ] def before(): data = relay.var('data', shape=(1, 512, 28, 28)) @@ -310,7 +311,129 @@ def expected(): assert relay.analysis.alpha_equal(result, expected) +def test_merge_order(): + """Test that patterns are merged in the order they exist in the pattern table. + + There can be cases where one pattern is a subgraph of another, in which case + it is not clear which match should take priority. The priority should come + from the order in which the patterns are declared in the pattern table. The + first patterns will be merged with highest priority and the last with lowest. + + A: B: C: + add add abs + | | | + abs abs relu + | + relu + + """ + + def pattern_A(): + x = relay.var('x') + y = relay.var('y') + out = relay.add(x, y) + out = relay.abs(out) + out = relay.nn.relu(out) + return out + + def pattern_B(): + x = relay.var('x') + y = relay.var('y') + out = relay.add(x, y) + out = relay.abs(out) + return out + + def pattern_C(): + x = relay.var('x') + out = relay.abs(x) + out = relay.nn.relu(x) + return out + + def before(): + input_1 = relay.var('input_1', shape=(10, 10)) + input_2 = relay.var('input_2', shape=(10, 10)) + out = relay.add(input_1, input_2) + out = relay.abs(out) + out = relay.nn.relu(out) + return relay.Function([input_1, input_2], out) + + def after_A_priority(): + input_1 = relay.var('input_1', shape=(10, 10)) + input_2 = relay.var('input_2', shape=(10, 10)) + x = relay.var('x') + y = relay.var('y') + out = relay.add(x, y) + out = relay.abs(out) + out = relay.nn.relu(out) + merged_func = relay.Function([x, y], out) + merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1)) + merged_func = merged_func.set_attribute('Composite', expr.StringImm('A')) + ret = relay.Call(merged_func, [input_1, input_2]) + return relay.Function([input_1, input_2], ret) + + def after_B_priority(): + input_1 = relay.var('input_1', shape=(10, 10)) + input_2 = relay.var('input_2', shape=(10, 10)) + x = relay.var('x') + y = relay.var('y') + out = relay.add(x, y) + out = relay.abs(out) + merged_func = relay.Function([x, y], out) + merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1)) + merged_func = merged_func.set_attribute('Composite', expr.StringImm('B')) + merged_call = relay.Call(merged_func, [input_1, input_2]) + ret = relay.nn.relu(merged_call) + return relay.Function([input_1, input_2], ret) + + def after_C_priority(): + input_1 = relay.var('input_1', shape=(10, 10)) + input_2 = relay.var('input_2', shape=(10, 10)) + add = relay.add(input_1, input_2) + x = relay.var('x') + out = relay.abs(x) + out = relay.nn.relu(out) + merged_func = relay.Function([x], out) + merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1)) + merged_func = merged_func.set_attribute('Composite', expr.StringImm('C')) + ret = relay.Call(merged_func, [add]) + return relay.Function([input_1, input_2], ret) + + # check A highest priority + pattern_table = [ + ("A", pattern_A()), + ("B", pattern_B()), + ("C", pattern_C()), + ] + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) + assert not relay.analysis.free_vars(result) + expected = run_opt_pass(after_A_priority(), relay.transform.InferType()) + assert relay.analysis.alpha_equal(result, expected) + + # check B highest priority + pattern_table = [ + ("B", pattern_A()), + ("C", pattern_B()), + ("A", pattern_C()), + ] + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) + assert not relay.analysis.free_vars(result) + expected = run_opt_pass(after_A_priority(), relay.transform.InferType()) + assert relay.analysis.alpha_equal(result, expected) + + # check C highest priority + pattern_table = [ + ("C", pattern_A()), + ("A", pattern_B()), + ("B", pattern_C()), + ] + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) + assert not relay.analysis.free_vars(result) + expected = run_opt_pass(after_A_priority(), relay.transform.InferType()) + assert relay.analysis.alpha_equal(result, expected) + + if __name__ == "__main__": test_simple_merge() test_branch_merge() test_multiple_patterns() + test_merge_order()