diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index f8be3e2b28ce..ca83e06c6e95 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """The Relay Pattern Language and tooling.""" -from tvm.relay import Expr +from tvm.relay.expr import RelayExpr as Expr import tvm._ffi from ...ir.base import Node from ...ir import make_node diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 71ef430ec9c6..bd3dd83c244d 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -32,8 +32,8 @@ - The other way is to implement the function by themselves to check the attributes of the op and decide if it should be offloaded to DNNL. """ -from ... import expr as _expr from ... import op as _op +from ...dataflow_pattern import wildcard, is_op from .register import register_pattern_table @@ -68,15 +68,15 @@ def _func_wrapper(attrs, args): def make_pattern(with_bias=True): - data = _expr.var("data") - weight = _expr.var("weight") - bias = _expr.var("bias") - conv = _op.nn.conv2d(data, weight) + data = wildcard() + weight = wildcard() + bias = wildcard() + conv = is_op('nn.conv2d')(data, weight) if with_bias: - conv_out = _op.add(conv, bias) + conv_out = is_op('add')(conv, bias) else: conv_out = conv - return _op.nn.relu(conv_out) + return is_op('nn.relu')(conv_out) @register_pattern_table("dnnl") diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 1d17c421b5a6..7222ff26ee49 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -380,7 +380,7 @@ def MergeComposite(pattern_table): Parameters ---------- - pattern_table : list(tuple) + pattern_table : List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Function]] A list of (pattern_name, pattern, check) tuples. The order of the patterns in the list will determine the order of priority in which they are matched. 
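As an aside (not part of this patch): a minimal sketch of how a backend-style pattern table built with the dataflow pattern language shown above might be fed to MergeComposite. The "fused_conv2d_relu" name, the tensor shapes, and the check predicate are invented for illustration; only the wildcard/is_op calls and the (pattern_name, pattern, check) tuple format follow the API in this diff.

import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

def make_conv_relu_pattern():
    # conv2d followed by relu; wildcard() inputs let the pattern attach anywhere
    conv = is_op('nn.conv2d')(wildcard(), wildcard())
    return is_op('nn.relu')(conv)

def check(extract):
    # 'extract' is the matched subgraph rooted at the relu call;
    # only accept NCHW convolutions (illustrative predicate)
    conv = extract.args[0]
    return conv.attrs.data_layout == "NCHW"

pattern_table = [
    ("fused_conv2d_relu", make_conv_relu_pattern(), check),
]

data = relay.var("data", shape=(1, 3, 224, 224))
weight = relay.var("weight", shape=(16, 3, 3, 3))
out = relay.nn.relu(relay.nn.conv2d(data, weight, kernel_size=(3, 3)))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
mod = relay.transform.InferType()(mod)
mod = relay.transform.MergeComposite(pattern_table)(mod)
# The conv2d + relu pair is now wrapped in a function carrying
# Composite="fused_conv2d_relu" and PartitionedFromPattern="nn.conv2d_nn.relu_".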
diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc index 027e5123365e..04e41fad0046 100644 --- a/src/relay/transforms/merge_composite.cc +++ b/src/relay/transforms/merge_composite.cc @@ -26,6 +26,7 @@ */ #include +#include #include #include #include @@ -35,191 +36,17 @@ namespace tvm { namespace relay { namespace merge_composite { -class MergeCompositeWrapper : public ExprMutator { - public: - explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern, - const PackedFunc& check) - : pattern_name_(pattern_name), pattern_(pattern), check_(check) {} - - Expr ExtractPattern(const Var& pattern, const Expr& root, - Map>* var_map) { - if (var_map->find(pattern->name_hint()) == var_map->end()) { - // if we haven't encountered this var yet, make a new free var and associate - // it with the value at 'root' - auto free_var = Var(pattern->name_hint(), root->checked_type()); - free_var->checked_type_ = root->checked_type(); - var_map->Set(pattern->name_hint(), Array({free_var, root})); - return std::move(free_var); - } else { - // if we have encountered this var already, return the free var that was created - auto vars = (*var_map)[pattern->name_hint()]; - auto free_var = vars[0]; - auto graph_expr = vars[1]; - // make sure to first check they both map to the same node in the graph - if (graph_expr != root) { - return Expr(); - } - return (*var_map)[pattern->name_hint()][0]; - } - } - - Expr ExtractPattern(const Constant& pattern, const Expr& root, - Map>* var_map) { - return root; - } - - Expr ExtractPattern(const TupleGetItem& pattern, const Expr& root, - Map>* var_map, Map* call_map) { - if (!root->IsInstance()) { - return Expr(); - } - auto root_node = Downcast(root); - if (pattern->index != root_node->index) { - return Expr(); - } - if (pattern->tuple->IsInstance() && root_node->tuple->IsInstance()) { - Expr new_arg; - if (call_map->find(pattern->tuple) != call_map->end()) { - new_arg = (*call_map)[pattern->tuple]; - } else { - new_arg = ExtractPattern(Downcast(pattern->tuple), Downcast(root_node->tuple), - var_map, call_map); - call_map->Set(pattern->tuple, new_arg); - } - return TupleGetItem(new_arg, root_node->index); - } - return Expr(); - } - - /*! - * \brief Try and extract a given pattern from a graph as a subgraph. - * \param pattern The pattern to extract. - * \param root The graph to extract from. - * \param var_map A map between free vars in the subgraph and nodes in the graph. - * \return The extracted subgraph. - * - * \note How does this work? - * - * A pattern consists of Relay expression containing only operator call nodes, constants - * and free variables. The free variables indicate where the pattern can 'attach' in your - * graph. This function takes the final call node of the pattern and the call node currently - * being traversed in the Relay graph. It traverses through the pattern in lockstep with call node - * from the graph (referred to as the 'root' node here) to check they're identical. If at any - * point they differ, an empty expression is returned to signify the extract failed. If a free var - * is reached in the pattern, the corresponding value in the root is associated with the name of - * the free var (via the var_map) so that when we construct the composite function, the inputs - * match up correctly with the rest of the graph. The return value of this function when - * successful is a new Relay expression ready to be wrapped into a composite function. 
- */ - Expr ExtractPattern(const Call& pattern, const Call& root, Map>* var_map, - Map* call_map) { - // check to make sure both calls are to operators (not functions) - if (!root.defined()) return Expr(); - if (!pattern->op->IsInstance() || !root->op->IsInstance()) return Expr(); - if (pattern->op.as()->name != root->op.as()->name) return Expr(); - - unsigned int i = 0; - Array new_args; - for (const auto& arg : pattern->args) { - Expr new_arg; - if (arg->IsInstance() && root->args[i]->IsInstance()) { - new_arg = - ExtractPattern(Downcast(arg), Downcast(root->args[i]), var_map, call_map); - // if we've already processed this call node, return the previous result - if (call_map->find(arg) != call_map->end() && new_arg.defined()) { - new_arg = (*call_map)[arg]; - } else { - call_map->Set(arg, new_arg); - } - } else if (arg->IsInstance()) { - // if there's a var in the pattern, it must be a free var - // so call the function to update the var_map - new_arg = ExtractPattern(Downcast(arg), root->args[i], var_map); - } else if (arg->IsInstance()) { - // if there's a constant, simply get the corresponding - // value of the constant from the root - new_arg = ExtractPattern(Downcast(arg), root->args[i], var_map); - } else if (arg->IsInstance()) { - new_arg = ExtractPattern(Downcast(arg), root->args[i], var_map, call_map); - } - if (!new_arg.defined()) { - return Expr(); - } - new_args.push_back(new_arg); - i++; - } - Call new_call = Call(root->op, new_args, root->attrs); - new_call->checked_type_ = root->checked_type(); - return std::move(new_call); - } - - Expr VisitExpr_(const CallNode* cn) { - Call call = GetRef(cn); - if (call->op->IsInstance()) { - Function func = Downcast(call->op); - CHECK(func.defined()); - auto name_node = func->GetAttr(attr::kComposite); - // don't step into existing composite functions - if (name_node.defined() && name_node != "") { - tvm::Array new_args; - for (const auto& arg : call->args) { - auto new_e = this->Mutate(arg); - new_args.push_back(new_e); - } - Call new_call = Call(call->op, new_args, call->attrs); - new_call->checked_type_ = call->checked_type(); - return std::move(new_call); - } - } - - Expr expr = ExprMutator::VisitExpr_(cn); - call = Downcast(expr); - call->checked_type_ = cn->checked_type(); - if (!call->op->IsInstance()) return std::move(call); - - // only call patterns are supported - Call pattern = Downcast(pattern_); - CHECK(pattern.defined()); - Map> args_map; - Map call_map; - auto extract = ExtractPattern(pattern, call, &args_map, &call_map); - if (extract.defined() && static_cast(check_(extract))) { - auto free_vars = FreeVars(extract); - // make the composite function - auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs()); - f = WithAttr(std::move(f), attr::kComposite, runtime::String(pattern_name_)); - // find the expressions associated with the free vars using the args_map - // this tells us which expressions should be given as inputs to the composite function - Array args; - for (const auto& free_var : free_vars) { - args.push_back(args_map[free_var->name_hint()][1]); - } - auto new_call = Call(f, args); - new_call->checked_type_ = call->checked_type(); - return std::move(new_call); - } - return std::move(call); - } - - private: - /*! \brief The name of the pattern to match */ - std::string pattern_name_; - /*! \brief The pattern to match */ - Expr pattern_; - /*! 
\brief The function to check whether an extract is supported */ - PackedFunc check_; -}; - -Expr MergeComposite(const Expr& expr, const Array& pattern_names, - const Array& patterns, const std::vector& checks) { +Expr MergeComposite(const Function& func, const Array& pattern_names, + const Array& patterns, const std::vector& checks) { CHECK_EQ(pattern_names.size(), patterns.size()); - Expr merged_expr = expr; + Expr merged_expr = func->body; // merge the patterns one-by-one in order for (size_t i = 0; i < patterns.size(); i++) { - merged_expr = - MergeCompositeWrapper(pattern_names[i], patterns[i], checks[i]).Mutate(merged_expr); + Map attrs; + attrs.Set("Composite", pattern_names[i]); + merged_expr = PartitionPattern(patterns[i], merged_expr, attrs, checks[i]); } - return merged_expr; + return Function(func->params, merged_expr, func->ret_type, func->type_params, func->attrs); } } // namespace merge_composite @@ -227,7 +54,7 @@ Expr MergeComposite(const Expr& expr, const Array& pattern_name namespace transform { Pass MergeComposite(const tvm::Array& pattern_names, - const tvm::Array& patterns, const std::vector& checks) { + const tvm::Array& patterns, const std::vector& checks) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast( @@ -239,7 +66,7 @@ Pass MergeComposite(const tvm::Array& pattern_names, TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) { tvm::Array pattern_names = args[0]; - tvm::Array patterns = args[1]; + tvm::Array patterns = args[1]; std::vector checks; for (int i = 2; i < args.size(); i++) { checks.push_back(args[i]); diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index 3a79f6ad860f..12679c45c7c6 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -16,10 +16,11 @@ # under the License. """Unit tests for merge composite.""" import tvm -from tvm import relay -from tvm import tir +from tvm import relay, tir +from tvm.relay.dataflow_pattern import TupleGetItemPattern, is_op, wildcard from tvm.relay.testing import run_opt_pass + """ The merge composite pass is designed to merge multiple relay operators, that match a given pattern, and combine them into a single relay function. @@ -64,37 +65,32 @@ def make_add_sub_mul_pattern(): - """Create a pattern to match the following graph. + r"""Create a pattern to match the following graph. add sub \ / \ / mul """ - x = relay.var('x') - y = relay.var('y') - add_node = relay.add(x, y) - sub_node = relay.subtract(x, y) - mul_node = relay.multiply(add_node, sub_node) - return mul_node + x = wildcard() + y = wildcard() + return (x + y) * (x - y) def make_add_relu_pattern(): - """Create a pattern to match the following graph. + r"""Create a pattern to match the following graph. add | relu """ - x = relay.var('x') - y = relay.var('y') - add_node = relay.add(x, y) - r = relay.nn.relu(add_node) + add_node = wildcard() + wildcard() + r = is_op('nn.relu')(add_node) return r def make_conv_bias_relu_pattern(): - """Create a pattern to match the following graph. + r"""Create a pattern to match the following graph. 
conv2d | @@ -102,17 +98,35 @@ def make_conv_bias_relu_pattern(): | relu """ - x = relay.var('x') - y = relay.var('y') - z = relay.var('z') - conv_node = relay.nn.conv2d(x, y) - bias_node = relay.nn.bias_add(conv_node, z) - r = relay.nn.relu(bias_node) + x = wildcard() + y = wildcard() + z = wildcard() + conv_node = is_op('nn.conv2d')(x, y) + bias_node = is_op('nn.bias_add')(conv_node, z) + r = is_op('nn.relu')(bias_node) + return r + + +def make_pattern_with_optional(): + r"""Create a pattern to match the following graph. Note that relu is optional. + + conv2d + | + bias_add + | + (relu) + """ + x = wildcard() + y = wildcard() + z = wildcard() + conv_node = is_op('nn.conv2d')(x, y) + bias_node = is_op('nn.bias_add')(conv_node, z) + r = bias_node.optional(lambda x: is_op('nn.relu')(x)) + return r def make_add_add_add_pattern(): - """Create a pattern to match the following graph. + r"""Create a pattern to match the following graph. Useful for testing re-using a call node. x y @@ -123,15 +137,15 @@ def make_add_add_add_pattern(): | / add """ - x = relay.var('x') - y = relay.var('y') - add_node = relay.add(x, y) - add_node_1 = relay.add(x, add_node) - r = relay.add(add_node_1, add_node) + x = wildcard() + y = wildcard() + add_node = is_op('add')(x, y) + add_node_1 = is_op('add')(x, add_node) + r = is_op('add')(add_node_1, add_node) return r def make_bn_relu_pattern(): - """Create a pattern to match the following graph. + r"""Create a pattern to match the following graph. batch_norm | @@ -139,19 +153,27 @@ def make_bn_relu_pattern(): | relu """ - x = relay.var('x') - gamma = relay.var("gamma") - beta = relay.var("beta") - moving_mean = relay.var("moving_mean") - moving_var = relay.var("moving_var") - bn_node = relay.nn.batch_norm(x, gamma, beta, moving_mean, moving_var) - tuple_get_item_node = bn_node[0] - r = relay.nn.relu(tuple_get_item_node) + x = wildcard() + gamma = wildcard() + beta = wildcard() + moving_mean = wildcard() + moving_var = wildcard() + bn_node = is_op('nn.batch_norm')(x, gamma, beta, moving_mean, moving_var) + tuple_get_item_node = TupleGetItemPattern(bn_node, 0) + r = is_op('nn.relu')(tuple_get_item_node) return r +def check_result(pattern_table, graph, expected_graph): + """Utility function to check merge composite results.""" + result = run_opt_pass(graph, relay.transform.MergeComposite(pattern_table)) + assert not relay.analysis.free_vars(result), \ + "Found free vars in the result graph: {0}".format(str(result)) + expected = run_opt_pass(expected_graph, relay.transform.InferType()) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True), \ + "Graph mismatch: output vs. expected\n{0}\n=====\n{1}".format(str(result), str(expected)) def test_simple_merge(): - """Test composite function is correctly produced from simple graph. + r"""Test composite function is correctly produced from simple graph. We could expect the pattern `make_add_relu_pattern` to be merged into a single op `add_relu`.
@@ -185,19 +207,17 @@ def expected(): relu_node = relay.nn.relu(add_node) add_relu = relay.Function([in_1, in_2], relu_node) add_relu = add_relu.with_attr("Composite", "add_relu") + add_relu = add_relu.with_attr("PartitionedFromPattern", "add_nn.relu_") # merged function r = relay.Call(add_relu, [a, b]) return relay.Function([a, b], r) - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(expected(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), expected()) def test_branch_merge(): - """Test composite function is correctly produced from branching graph. + r"""Test composite function is correctly produced from branching graph. We would expect the pattern `make_add_sub_mul_pattern` to be merged into a single op `add_sub_mul`. @@ -250,6 +270,7 @@ def expected(): mul_node = relay.multiply(add_node, sub_node) add_sub_mul = relay.Function([in_1, in_2], mul_node) add_sub_mul = add_sub_mul.with_attr("Composite", "add_sub_mul") + add_sub_mul = add_sub_mul.with_attr("PartitionedFromPattern", "add_subtract_multiply_") # add_sub_mul1 function in_3 = relay.var('in_3', shape=(10, 10)) @@ -259,6 +280,7 @@ def expected(): mul_node_1 = relay.multiply(add_node_1, sub_node_1) add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1) add_sub_mul_1 = add_sub_mul_1.with_attr("Composite", "add_sub_mul") + add_sub_mul_1 = add_sub_mul_1.with_attr("PartitionedFromPattern", "add_subtract_multiply_") # merged function m_add_sub_mul_1 = relay.Call(add_sub_mul, [a, b]) @@ -266,14 +288,11 @@ def expected(): r = relay.nn.relu(m_add_sub_mul_2) return relay.Function([a, b, c], r) - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(expected(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), expected()) def test_reuse_call_merge(): - """Test composite function is correctly produced from simple graph + r"""Test composite function is correctly produced from simple graph which re-uses call nodes. We could expect the pattern `make_add_add_add` to be merged @@ -318,20 +337,18 @@ def expected(): add_node_2 = relay.add(add_node_1, add_node) add_add_add = relay.Function([in_1, in_2], add_node_2) add_add_add = add_add_add.with_attr("Composite", "add_add_add") + add_add_add = add_add_add.with_attr("PartitionedFromPattern", "add_add_add_") # merged function sub_node = relay.subtract(a, b) call = relay.Call(add_add_add, [sub_node, b]) return relay.Function([a, b], call) - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(expected(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), expected()) def test_multiple_patterns(): - """Test different patterns are merged correctly in the graph. + r"""Test different patterns are merged correctly in the graph. We would expect the pattern `make_conv_bias_relu_pattern` to be merged into a single op `conv_bias_relu`. 
We would also expect `make_add_relu_pattern` @@ -402,6 +419,8 @@ def expected(): conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r) conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite", "conv2d_bias_relu") + conv_bias_add_relu = conv_bias_add_relu.with_attr("PartitionedFromPattern", + "nn.conv2d_nn.bias_add_nn.relu_") # add_relu function in_4 = relay.var('in_4', shape=(1, 256, 28, 28)) @@ -410,6 +429,7 @@ def expected(): r = relay.nn.relu(add_node) add_relu = relay.Function([in_4, in_5], r) add_relu = add_relu.with_attr("Composite", "add_relu") + add_relu = add_relu.with_attr("PartitionedFromPattern", "add_nn.relu_") # merged function conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias]) @@ -417,14 +437,79 @@ def expected(): r = relay.multiply(add_relu_1, b) return relay.Function([data, kernel, bias, a, b], r) - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(expected(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), expected()) + + +def test_optional_pattern(): + r"""Test the pattern with optional operators. We can define a pattern with some operators + optional. The merge composite pass will create composite functions for all matched patterns, + but with different "PartitionedFromPattern" attributes. We expect the backend codegen to + analyze that attribute and determine the corresponding action. + + Pattern: Matched Case A: Matched Case B: + + conv2d conv2d conv2d + | | | + bias_add bias_add bias_add + | | + (relu) relu + + In the above example, the composite function for matched case A would have + PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_" while the one for matched case B + would be "nn.conv2d_nn.bias_add_".
+ """ + pattern_table = [("layer", make_pattern_with_optional())] + + def before(): + x = relay.var('x', shape=(1, 3, 7, 7)) + w1 = relay.var('w', shape=(3, 3, 1, 1)) + b1 = relay.var('b', shape=(3, )) + w2 = relay.var('w', shape=(3, 3, 1, 1)) + b2 = relay.var('b', shape=(3, )) + conv = relay.nn.conv2d(x, w1, kernel_size=(1, 1)) + bias = relay.nn.bias_add(conv, b1) + relu = relay.nn.relu(bias) + conv = relay.nn.conv2d(relu, w2, kernel_size=(1, 1)) + bias = relay.nn.bias_add(conv, b2) + return relay.Function([x, w1, w2, b1, b2], bias) + + def expected(): + # Matched composite function A + x = relay.var('x') + w = relay.var('w') + b = relay.var('b') + conv = relay.nn.conv2d(x, w, kernel_size=(1, 1)) + bias = relay.nn.bias_add(conv, b) + relu = relay.nn.relu(bias) + func1 = relay.Function([x, w, b], relu) + func1 = func1.with_attr("Composite", "layer") + func1 = func1.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_") + + # Matched composite function B + x = relay.var('x') + w = relay.var('w') + b = relay.var('b') + conv = relay.nn.conv2d(x, w, kernel_size=(1, 1)) + bias = relay.nn.bias_add(conv, b) + func2 = relay.Function([x, w, b], bias) + func2 = func2.with_attr("Composite", "layer") + func2 = func2.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_") + + # Main function + x = relay.var('x', shape=(1, 3, 7, 7)) + w1 = relay.var('w', shape=(3, 3, 1, 1)) + b1 = relay.var('b', shape=(3, )) + w2 = relay.var('w', shape=(3, 3, 1, 1)) + b2 = relay.var('b', shape=(3, )) + out1 = func1(x, w1, b1) + out2 = func2(out1, w2, b2) + return relay.Function([x, w1, w2, b1, b2], out2) + + check_result(pattern_table, before(), expected()) def test_merge_order(): - """Test that patterns are merged in the order they exist in the pattern table. + r"""Test that patterns are merged in the order they exist in the pattern table. There can be cases where one pattern is a subgraph of another, in which case it is not clear which match should take priority. 
The priority should come @@ -441,24 +526,24 @@ def test_merge_order(): """ def pattern_A(): - x = relay.var('x') - y = relay.var('y') - out = relay.add(x, y) - out = relay.abs(out) - out = relay.nn.relu(out) + x = wildcard() + y = wildcard() + out = is_op('add')(x, y) + out = is_op('abs')(out) + out = is_op('nn.relu')(out) return out def pattern_B(): - x = relay.var('x') - y = relay.var('y') - out = relay.add(x, y) - out = relay.abs(out) + x = wildcard() + y = wildcard() + out = is_op('add')(x, y) + out = is_op('abs')(out) return out def pattern_C(): - x = relay.var('x') - out = relay.abs(x) - out = relay.nn.relu(x) + x = wildcard() + out = is_op('abs')(x) + out = is_op('nn.relu')(out) return out def before(): @@ -469,7 +554,7 @@ def before(): out = relay.nn.relu(out) return relay.Function([input_1, input_2], out) - def after_A_priority(composite_name): + def after_A_priority(): input_1 = relay.var('input_1', shape=(10, 10)) input_2 = relay.var('input_2', shape=(10, 10)) x = relay.var('x') @@ -478,46 +563,65 @@ def after_A_priority(composite_name): out = relay.abs(out) out = relay.nn.relu(out) merged_func = relay.Function([x, y], out) - merged_func = merged_func.with_attr('Composite', composite_name) + merged_func = merged_func.with_attr('Composite', 'A') + merged_func = merged_func.with_attr('PartitionedFromPattern', 'add_abs_nn.relu_') ret = relay.Call(merged_func, [input_1, input_2]) return relay.Function([input_1, input_2], ret) + def after_B_priority(): + input_1 = relay.var('input_1', shape=(10, 10)) + input_2 = relay.var('input_2', shape=(10, 10)) + x = relay.var('x') + y = relay.var('y') + out = relay.add(x, y) + out = relay.abs(out) + merged_func = relay.Function([x, y], out) + merged_func = merged_func.with_attr('Composite', 'B') + merged_func = merged_func.with_attr('PartitionedFromPattern', 'add_abs_') + out = relay.Call(merged_func, [input_1, input_2]) + ret = relay.nn.relu(out) + return relay.Function([input_1, input_2], ret) + + def after_C_priority(): + input_1 = relay.var('input_1', shape=(10, 10)) + input_2 = relay.var('input_2', shape=(10, 10)) + x = relay.var('x') + out = relay.abs(x) + out = relay.nn.relu(out) + merged_func = relay.Function([x], out) + merged_func = merged_func.with_attr('Composite', 'C') + merged_func = merged_func.with_attr('PartitionedFromPattern', 'abs_nn.relu_') + out = relay.add(input_1, input_2) + ret = relay.Call(merged_func, [out]) + return relay.Function([input_1, input_2], ret) + # check A highest priority pattern_table = [ ("A", pattern_A()), ("B", pattern_B()), ("C", pattern_C()), ] - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(after_A_priority("A"), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), after_A_priority()) # check B highest priority pattern_table = [ - ("B", pattern_A()), - ("C", pattern_B()), - ("A", pattern_C()), + ("B", pattern_B()), + ("C", pattern_C()), + ("A", pattern_A()), ] - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(after_A_priority("B"), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), after_B_priority()) # check C highest priority pattern_table = [ - ("C", pattern_A()), - ("A", pattern_B()), - ("B", pattern_C()), + ("C", 
pattern_C()), + ("A", pattern_A()), + ("B", pattern_B()), ] - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(after_A_priority("C"), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), after_C_priority()) def test_parallel_merge(): - """Tests that parallel patterns relying on the same inputs are correctly merged. + r"""Tests that parallel patterns relying on the same inputs are correctly merged. The test graph is difficult to draw out as ascii art. It is essentially two parallel add-sub-mul units which both consume input_1 and input_2 with their results being multiplied @@ -536,7 +640,7 @@ def before(): out = relay.multiply(branch_1, branch_2) return relay.Function([input_1, input_2], out) - def after(): + def expected(): input_1 = relay.var('input_1', shape=(10, 10)) input_2 = relay.var('input_2', shape=(10, 10)) x = relay.var('x') @@ -544,12 +648,14 @@ def after(): branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y)) func_1 = relay.Function([x, y], branch_1) func_1 = func_1.with_attr('Composite', "add_sub_mul") + func_1 = func_1.with_attr('PartitionedFromPattern', "add_subtract_multiply_") call_1 = relay.Call(func_1, [input_1, input_2]) x1 = relay.var('x1') y1 = relay.var('y1') branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1)) func_2 = relay.Function([x1, y1], branch_2) func_2 = func_2.with_attr('Composite', "add_sub_mul") + func_2 = func_2.with_attr('PartitionedFromPattern', "add_subtract_multiply_") call_2 = relay.Call(func_2, [input_1, input_2]) out = relay.multiply(call_1, call_2) return relay.Function([input_1, input_2], out) @@ -557,14 +663,11 @@ def after(): pattern_table = [ ("add_sub_mul", make_add_sub_mul_pattern()) ] - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(after(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), expected()) def test_multiple_input_subgraphs(): - """Test the case when multiple input subgraphs feed into another subgraph. + r"""Test the case when multiple input subgraphs feed into another subgraph. 
(1) (2) (3) (4) add add add add @@ -629,6 +732,7 @@ def after_A(): add_relu_1 = relay.nn.relu(add_relu_1) add_relu_1 = relay.Function([x, y], add_relu_1) add_relu_1 = add_relu_1.with_attr('Composite', 'add_relu') + add_relu_1 = add_relu_1.with_attr('PartitionedFromPattern', 'add_nn.relu_') add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]]) x1 = relay.var('x1') y1 = relay.var('y1') @@ -636,6 +740,7 @@ def after_A(): add_relu_2 = relay.nn.relu(add_relu_2) add_relu_2 = relay.Function([x1, y1], add_relu_2) add_relu_2 = add_relu_2.with_attr('Composite', 'add_relu') + add_relu_2 = add_relu_2.with_attr('PartitionedFromPattern', 'add_nn.relu_') add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]]) x2 = relay.var('x2') y2 = relay.var('y2') @@ -644,6 +749,7 @@ def after_A(): add_sub_mul = relay.multiply(add, sub) add_sub_mul = relay.Function([x2, y2], add_sub_mul) add_sub_mul = add_sub_mul.with_attr('Composite', 'add_sub_mul') + add_sub_mul = add_sub_mul.with_attr('PartitionedFromPattern', 'add_subtract_multiply_') add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2]) return relay.Function(inputs, add_sub_mul_call) @@ -657,6 +763,7 @@ def after_B(): add_relu = relay.nn.relu(add_relu) add_relu = relay.Function([x, y], add_relu) add_relu = add_relu.with_attr('Composite', 'add_relu') + add_relu = add_relu.with_attr('PartitionedFromPattern', 'add_nn.relu_') add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]]) add_relu_calls.append(add_relu_call) @@ -669,17 +776,8 @@ def after_B(): ("add_sub_mul", make_add_sub_mul_pattern()), ("add_relu", make_add_relu_pattern()) ] - # check case 'A' - result = run_opt_pass(before()['A'], relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(after_A(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) - - # check case 'B' - result = run_opt_pass(before()['B'], relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(after_B(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before()['A'], after_A()) + check_result(pattern_table, before()['B'], after_B()) def test_tuple_get_item_merge(): @@ -717,15 +815,14 @@ def expected(): relu_node = relay.nn.relu(tuple_get_item_node) bn_relu = relay.Function([in_1, in_2, in_3, in_4, in_5], relu_node) bn_relu = bn_relu.with_attr("Composite", "bn_relu") + bn_relu = bn_relu.with_attr("PartitionedFromPattern", + "nn.batch_norm_TupleGetItem0_nn.relu_") # merged function r = relay.Call(bn_relu, [x, gamma, beta, moving_mean, moving_var]) return relay.Function([x, gamma, beta, moving_mean, moving_var], r) - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(expected(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), expected()) def test_pattern_with_check(): @@ -750,23 +847,35 @@ def _check_false(extract): conv = extract.args[0].args[0] return conv.attrs.data_layout == "NCHW" - pattern_table_true = [ - ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_true) - ] + def expected(): + x = relay.var('x') + w = relay.var('w') + b = relay.var('b') + conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), kernel_layout="OIHW", 
data_layout="NHWC") + bias = relay.nn.bias_add(conv, b) + relu = relay.nn.relu(bias) + func = relay.Function([x, w, b], relu) + func = func.with_attr("Composite", "conv_bias_relu") + func = func.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_") + + x = relay.var('x', shape=(1, 10, 10, 10)) + w = relay.var('w', shape=(10, 10, 3, 3)) + b = relay.var('b', shape=(8,)) + return relay.Function([x, w, b], func(x, w, b)) + pattern_table_false = [ ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_false) ] + check_result(pattern_table_false, before(), before()) - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_false)) - expected = run_opt_pass(before(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) - - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_true)) - assert result.body.op.attrs["Composite"] == "conv_bias_relu" + pattern_table_true = [ + ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_true) + ] + check_result(pattern_table_true, before(), expected()) def test_diamond_not_merge(): - """ + r""" The pattern on the left shouldn't match the structure on the right relu relu @@ -779,8 +888,8 @@ def test_diamond_not_merge(): """ def get_pattern(): conv = make_conv_bias_relu_pattern() - clip = relay.op.clip(conv, 0, 255) - return relay.op.multiply(conv, clip) + clip = is_op('clip')(conv, wildcard(), wildcard()) + return is_op('multiply')(conv, clip) def get_net(): data = relay.var('data', shape=(1, 512, 28, 28)) @@ -796,11 +905,9 @@ def get_net(): mul = relay.op.multiply(relu, clip2) return relay.Function(relay.analysis.free_vars(mul), mul) - pat_table = [("pat", get_pattern())] + pattern_table = [("pat", get_pattern())] net = get_net() - result = run_opt_pass(net, relay.transform.MergeComposite(pat_table)) - expected = run_opt_pass(net, relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, net, net) def test_type_check(): @@ -818,6 +925,23 @@ def before(): relu = relay.nn.relu(bias) return relay.Function([x, w, b], relu) + def expected(): + x = relay.var('x') + w = relay.var('w') + b = relay.var('b') + conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC") + bias = relay.nn.bias_add(conv, b) + relu = relay.nn.relu(bias) + func = relay.Function([x, w, b], relu) + func = func.with_attr("Composite", "conv_bias_relu") + func = func.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_") + + x = relay.var('x', shape=(1, 10, 10, 10)) + w = relay.var('w', shape=(10, 10, 3, 3)) + b = relay.var('b', shape=(8, )) + return relay.Function([x, w, b], func(x, w, b)) + + def _check_type_true(extract): conv = extract.args[0].args[0] typ = conv.checked_type @@ -828,25 +952,22 @@ def _check_type_false(extract): typ = conv.checked_type return bool(typ.shape[0] != 1) - pattern_table_true = [ - ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_true) - ] pattern_table_false = [ ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_false) ] + check_result(pattern_table_false, before(), before()) - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_false)) - expected = run_opt_pass(before(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) - - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_true)) - assert 
result.body.op.attrs["Composite"] == "conv_bias_relu" + pattern_table_true = [ + ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_true) + ] + check_result(pattern_table_true, before(), expected()) if __name__ == "__main__": test_simple_merge() test_branch_merge() test_multiple_patterns() + test_optional_pattern() test_merge_order() test_parallel_merge() test_multiple_input_subgraphs()