From 7e68b4d413dfc0ec0042634a40d94e9bab59dbde Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Sat, 3 Apr 2021 03:12:13 -0600 Subject: [PATCH] [PatternMatcher] Support matching tuples, call nodes, and functions with variable numbers of inputs (#7754) * Allow TuplePattern to have null fields and match any tuple * support matching functions and call nodes with variable numbers of parameters * remove development code that was commented out * add docs for fuzzy matching --- docs/langref/relay_pattern.rst | 16 +++ python/tvm/relay/dataflow_pattern/__init__.py | 5 +- src/relay/ir/dataflow_matcher.cc | 107 ++++++++++++------ src/relay/ir/dataflow_pattern_functor.cc | 18 ++- src/relay/ir/indexed_graph.cc | 18 ++- tests/python/relay/test_dataflow_pattern.py | 80 ++++++++++++- 6 files changed, 192 insertions(+), 52 deletions(-) diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index d77a51980f23..efb98045480c 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -307,6 +307,22 @@ The final example is matching diamonds with a post-dominator relationship. We em assert diamond.match(out) +Matching Fuzzy Patterns +======================= + +The Dominator analysis above lets one match a subgraph of Relay AST that doesn't correspond to a set of pattern nodes exactly 1-to-1. There are a few other places where we support such "fuzzy" matching. + +Tuples, Functions, and Call nodes with any number of inputs can be matched by passing `None` as the argument value, i.e.:: + + tuple_pattern = is_tuple(None) + func_pattern = FunctionPattern(None, wildcard() + wildcard()) + call_pattern = func_pattern(None) + +These patterns allow matching more generic classes of patterns by constraining the use of the arguments rather than the number of arguments. + +Additionally, we support matching Functions with fuzzy bodies, i.e., a function body that is under-constrained by the pattern. 
The pattern `FunctionPattern([is_var(), is_var()], wildcard() + wildcard())` will match `relay.Function([x, y], x + y)`, but it will also match `relay.Function([x, y], x * x + y)`. In the second case, the pattern doesn't perfectly constrain the body of the function, so the resulting match is fuzzy. + + Pattern Language Design ======================= diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index d4a8481d106e..b368f4e5175e 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -47,7 +47,10 @@ class DFPattern(Node): """Base class of all Patterns.""" def __call__(self, *args): - return CallPattern(self, list(args)) + args = list(args) + if len(args) == 1 and args[0] is None: + args = None + return CallPattern(self, args) def __or__(self, other): return AltPattern(self, other) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 43a6473fb632..6ed24d5053c4 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -242,6 +242,7 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex } return false; }; + // logic auto watermark = matched_nodes_.size(); if (const auto* call_node = expr.as()) { @@ -253,13 +254,15 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex const Array expr_args) { bool matches = true; size_t i = 0; - if (pattern_args.size() == expr_args.size()) { - while (matches && i < pattern_args.size()) { - matches &= VisitDFPattern(pattern_args[i], expr_args[i]); - ++i; + if (pattern_args.defined()) { + if (pattern_args.size() == expr_args.size()) { + while (matches && i < pattern_args.size()) { + matches &= VisitDFPattern(pattern_args[i], expr_args[i]); + ++i; + } + } else { + matches = false; } - } else { - matches = false; } if (!matches) { ClearMap(watermark2); } @@ -381,14 +384,16 @@ bool 
DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr bool matches = false; if (const auto* func = expr.as()) { matches = true; - size_t i = 0; - if (op->params.size() == func->params.size()) { - while (matches && i < op->params.size()) { - matches &= VisitDFPattern(op->params[i], func->params[i]); - ++i; + if (op->params.defined()) { + size_t i = 0; + if (op->params.size() == func->params.size()) { + while (matches && i < op->params.size()) { + matches &= VisitDFPattern(op->params[i], func->params[i]); + ++i; + } + } else { + matches = false; } - } else { - matches = false; } if (matches) { matches &= VisitDFPattern(op->body, func->body); @@ -409,12 +414,16 @@ bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) { bool matches = false; if (const auto* tuple_node = expr.as()) { - if (op->fields.size() == tuple_node->fields.size()) { - matches = true; - size_t i = 0; - while (matches && i < op->fields.size()) { - matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]); - ++i; + matches = true; + if (op->fields.defined()) { + if (op->fields.size() == tuple_node->fields.size()) { + size_t i = 0; + while (matches && i < op->fields.size()) { + matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]); + ++i; + } + } else { + matches = false; } } } @@ -657,7 +666,6 @@ class PatternGrouper { int var_number = 0; auto node_map = matcher_->GetMemo(); - // Get fuzzy patterns std::unordered_set fuzzy_matches; for (auto node : pattern_graph_.topological_order_) { @@ -669,11 +677,13 @@ class PatternGrouper { } } } - // Don't treat Function params as input variables for partition - if (auto op = node->ref_.as()) { - for (auto fuzzy_op : op->params) { - for (auto match : node_map[fuzzy_op]) { - fuzzy_matches.insert(match); + // Don't treat Function params or body as input variables for partition + if (node->ref_.as()) { + auto matches 
= node_map[node->ref_]; + for (auto match : matches) { + auto graph = CreateIndexedGraph(match.as()->body); + for (auto node : graph.topological_order_) { + fuzzy_matches.insert(node->ref_); } } } @@ -686,22 +696,46 @@ class PatternGrouper { std::unordered_map inputs; Array params; + for (auto node : pattern_graph_.topological_order_) { - if (node->inputs_.size() == 0) { + auto make_input = [&](const Expr& input) { + if (fuzzy_matches.count(input) == 0 && input.as() == nullptr && + input.as() == nullptr && !EmbedConst(input, node->ref_)) { + inputs[input] = + Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number), + NullValue()); + group.args.push_back(input); + params.push_back(inputs[input]); + var_number++; + } + }; + auto tuple = node->ref_.as(); + auto call = node->ref_.as(); + if (tuple && !tuple->fields.defined()) { if (node_map.count(node->ref_)) { auto matches = node_map[node->ref_]; for (auto match : matches) { - if (fuzzy_matches.count(match) == 0 && match.as() == nullptr && - match.as() == nullptr && !EmbedConst(match, node->ref_)) { - inputs[match] = Var( - "FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number), - NullValue()); - group.args.push_back(match); - params.push_back(inputs[match]); - var_number++; + for (auto input : match.as()->fields) { + make_input(input); } } } + } else if (call && !call->args.defined()) { + if (node_map.count(node->ref_)) { + auto matches = node_map[node->ref_]; + for (auto match : matches) { + for (auto input : match.as()->args) { + make_input(input); + } + } + } + } else if (node->inputs_.size() == 0) { + if (node_map.count(node->ref_)) { + auto matches = node_map[node->ref_]; + for (auto match : matches) { + make_input(match); + } + } } } @@ -898,6 +932,11 @@ class PatternPartitioner : protected MixedModeMutator { public: Expr Partition(const DFPattern& pattern, const Expr& pre, const Map& attrs, PackedFunc check) { + if (pattern.as()) { + LOG(WARNING) << 
"Partioning a Function that isn't called doesn't make sense, skipping" + << pattern; + return pre; + } auto grouper = PatternGrouper(); groups_ = grouper.GroupMatches(pattern, pre); gid_assignments_ = grouper.GetGIDAssignments(); diff --git a/src/relay/ir/dataflow_pattern_functor.cc b/src/relay/ir/dataflow_pattern_functor.cc index 828e867b332c..290f72df1deb 100644 --- a/src/relay/ir/dataflow_pattern_functor.cc +++ b/src/relay/ir/dataflow_pattern_functor.cc @@ -45,8 +45,10 @@ void DFPatternVisitor::VisitDFPattern_(const AttrPatternNode* op) { VisitDFPatte void DFPatternVisitor::VisitDFPattern_(const CallPatternNode* op) { VisitDFPattern(op->op); - for (auto arg : op->args) { - VisitDFPattern(arg); + if (op->args.defined()) { + for (auto arg : op->args) { + VisitDFPattern(arg); + } } } @@ -63,8 +65,10 @@ void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) { void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {} void DFPatternVisitor::VisitDFPattern_(const FunctionPatternNode* op) { - for (auto param : op->params) { - VisitDFPattern(param); + if (op->params.defined()) { + for (auto param : op->params) { + VisitDFPattern(param); + } } VisitDFPattern(op->body); } @@ -76,8 +80,10 @@ void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) { } void DFPatternVisitor::VisitDFPattern_(const TuplePatternNode* op) { - for (auto field : op->fields) { - VisitDFPattern(field); + if (op->fields.defined()) { + for (auto field : op->fields) { + VisitDFPattern(field); + } } } diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index 36789e6f808a..e4d9585470a6 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -242,8 +242,10 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { void VisitDFPattern_(const CallPatternNode* op, NodePtr parent) override { VisitDFPattern(op->op, graph_.node_map_[GetRef(op)]); - for (auto arg : op->args) { - VisitDFPattern(arg, 
graph_.node_map_[GetRef(op)]); + if (op->args.defined()) { + for (auto arg : op->args) { + VisitDFPattern(arg, graph_.node_map_[GetRef(op)]); + } } } @@ -262,8 +264,10 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {} void VisitDFPattern_(const FunctionPatternNode* op, NodePtr parent) override { - for (auto param : op->params) { - VisitDFPattern(param, graph_.node_map_[GetRef(op)]); + if (op->params.defined()) { + for (auto param : op->params) { + VisitDFPattern(param, graph_.node_map_[GetRef(op)]); + } } VisitDFPattern(op->body, graph_.node_map_[GetRef(op)]); } @@ -277,8 +281,10 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { } void VisitDFPattern_(const TuplePatternNode* op, NodePtr parent) override { - for (auto field : op->fields) { - VisitDFPattern(field, graph_.node_map_[GetRef(op)]); + if (op->fields.defined()) { + for (auto field : op->fields) { + VisitDFPattern(field, graph_.node_map_[GetRef(op)]); + } } } diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index a8e4b65f1bc6..8e2c74ab44b8 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -196,6 +196,11 @@ def test_match_call(): add_pattern = is_op("add")(wildcard(), wildcard()) assert add_pattern.match(x + y) + # Match call with any number of inputs + call_pattern = wildcard()(None) + assert call_pattern.match(relay.op.nn.relu(x)) + assert call_pattern.match(relay.op.add(x, y)) + def test_no_match_call(): x = relay.var("x") @@ -212,6 +217,11 @@ def test_match_func(): func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2) assert func_pattern.match(relay.Function([x, y], x + y)) + # Match Function with any number of inputs + func_pattern = FunctionPattern(None, wildcard()) + assert func_pattern.match(relay.Function([x], x)) + assert func_pattern.match(relay.Function([x, y], x + y)) + def 
test_no_match_func(): x = relay.var("x") @@ -369,6 +379,13 @@ def test_match_tuple(): assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1)) assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 2)) + # Match tuple with any inputs + tuple_pattern = is_tuple(None) + concat_pattern = is_op("concatenate")(tuple_pattern) + assert concat_pattern.match(relay.op.concatenate(relay.expr.Tuple((x,)), axis=0)) + assert concat_pattern.match(relay.op.concatenate(relay.expr.Tuple((x, y)), axis=0)) + assert concat_pattern.match(relay.op.concatenate(relay.expr.Tuple((x, y, z)), axis=0)) + def test_no_match_tuple(): x = relay.var("x") @@ -1375,6 +1392,63 @@ def test_partition_overused(): assert pattern.partition(out) == out +def test_partition_fuzzy_tuple(): + x = relay.var("x") + y = relay.var("y") + z = x + y + tuple_pattern = is_tuple(None) + concat_pattern = is_op("concatenate")(tuple_pattern) + + xp = relay.var("xp") + yp = relay.var("yp") + zp = relay.var("zp") + + def create_func(args, body): + return relay.Function(args, body).with_attr("PartitionedFromPattern", "Tuple_concatenate_") + + def concat(*args): + return relay.op.concatenate(relay.expr.Tuple(args), axis=0) + + one = concat_pattern.partition(concat(x)) + assert tvm.ir.structural_equal(one, create_func([xp], concat(xp))(x)) + two = concat_pattern.partition(concat(x, y)) + assert tvm.ir.structural_equal(two, create_func([xp, yp], concat(xp, yp))(x, y)) + three = concat_pattern.partition(concat(x, y, z)) + assert tvm.ir.structural_equal(three, create_func([xp, yp, zp], concat(xp, yp, zp))(x, y, z)) + + +def test_partition_fuzzy_function_args(): + + func_pattern = FunctionPattern(None, wildcard() + wildcard())(None) + wildcard() + x = relay.var("x") + y = relay.var("y") + z = relay.var("z") + b = relay.var("b") + xp = relay.var("xp") + yp = relay.var("yp") + zp = relay.var("zp") + + def create_func(call): + N = len(call.op.params) + new_params = 
[relay.var(str(i)) for i in range(N + 1)] + label = "add_FunctionCall_add_" + if N == 3: + label = "add_" + label + return relay.Function( + new_params, relay.Call(call.op, (new_params[0:-1])) + new_params[-1] + ).with_attr("PartitionedFromPattern", label)(*([x, y, z][0:N] + [b])) + + f1 = relay.Function([xp], xp + xp)(x) + one = func_pattern.partition(f1 + b) + assert tvm.ir.structural_equal(one, create_func(f1)) + f2 = relay.Function([xp, yp], xp + yp)(x, y) + two = func_pattern.partition(f2 + b) + assert tvm.ir.structural_equal(two, create_func(f2)) + f3 = relay.Function([xp, yp, zp], xp + yp + zp)(x, y, z) + three = func_pattern.partition(f3 + b) + assert tvm.ir.structural_equal(three, create_func(f3)) + + def test_partition_check(): pattern = is_op("nn.relu")(is_op("nn.conv2d")(is_var("input"), wildcard())) @@ -1529,10 +1603,6 @@ def callback(self, pre, post, node_map): assert tvm.ir.structural_equal(x + w, x + w) -@pytest.mark.skip( - """TODO(mbrookhart): The current partitioner can't properly handle - the partitioned inputs on the fuzzy body""" -) def test_partition_function_with_fuzzy_body(): """ Allow Rewriting a function with a fuzzy body via dominator analysis @@ -1560,7 +1630,7 @@ def test_partition_function_with_fuzzy_body(): w2 = relay.var("w2") b2 = relay.var("b2") func2 = relay.Function([x2, w2, b2], func(x2, w2) + b2).with_attr( - "PartitionedFromPattern", "FunctionCall_add_" + "PartitionedFromPattern", "nn.conv2d_FunctionCall_add_" ) expr2 = func2(x, w, b) + b assert tvm.ir.structural_equal(pattern.partition(expr), expr2)