diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst
index 7bb7bdfa6e82..7d5deb2b1ca9 100644
--- a/docs/langref/relay_pattern.rst
+++ b/docs/langref/relay_pattern.rst
@@ -137,7 +137,7 @@ The next example is matching a pattern of batch_norm -> get(0) -> relu:
     pat.match(out)
 
 The next example is matching a constant node regarding its values. This is useful to check
-if a specific parameter in a subgraph has been bind or not.
+if a specific parameter in a subgraph has been bound or not.
 
 .. code-block:: python
 
@@ -266,10 +266,10 @@ Attribute Pattern
 
 Check that the operator matched by the pattern has an attribute with a particular value.
 
-Input
-*****
+Variable Pattern
+****************
 
-Check that the expression is an input, i.e has no parents and is a variable.
+Check that the expression is a relay Variable, and optionally provide a name to match against the Variable's name.
 
 
 Alternate
diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py
index e8f73ed08f4e..f1d07845c1e0 100644
--- a/python/tvm/relay/dataflow_pattern/__init__.py
+++ b/python/tvm/relay/dataflow_pattern/__init__.py
@@ -318,15 +318,14 @@ class VarPattern(DFPattern):
 
     Parameters
    ----------
     name_hint: str
-        The name of the variable.
-        This name only acts as a hint, and is not used
-        for equality.
+        The name of the variable. Optional: if not provided,
+        the pattern will match any VarNode.
 
     type_annotation: tvm.relay.Type, optional
         The type annotation on the variable.
     """
 
-    def __init__(self, name_hint: str, type_annotation=None):
+    def __init__(self, name_hint="", type_annotation=None):
         self.__init_handle_by_constructor__(
             ffi.VarPattern, name_hint, type_annotation)
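For illustration, a minimal sketch of the matching behavior the ``VarPattern`` change above is expected to enable; this is hypothetical usage, not part of the patch, using the ``match`` method that the docs above already demonstrate:

.. code-block:: python

    import tvm
    from tvm import relay
    from tvm.relay.dataflow_pattern import VarPattern

    x = relay.var('x')

    # With no name_hint, the pattern matches any relay Variable.
    assert VarPattern().match(x)

    # With a name_hint, the Variable's name must match it.
    assert VarPattern('x').match(x)
    assert not VarPattern('y').match(x)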
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index a7e4b3714fc1..70fce2fc49c3 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -562,7 +562,7 @@ class PatternGrouper {
       auto matches = node_map[node->ref_];
       for (auto match : matches) {
         if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
-            match.as<FunctionNode>() == nullptr) {
+            match.as<FunctionNode>() == nullptr && !EmbedConst(match, node->ref_)) {
           inputs[match] = Var(
               "FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
               NullValue<Type>());
@@ -582,8 +582,8 @@ class PatternGrouper {
     auto extractor = MatchExtractor(inputs);
     auto body = extractor.Mutate(expr);
 
-    // Verify the pattern still holds, no longer valid if we're not embedding constants in the
-    // graph, keep here for future debug CHECK(DFPatternMatcher(body).Match(pattern_, body));
+    // Verify the pattern still holds
+    CHECK(DFPatternMatcher(body).Match(pattern_, body));
     group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
     group.name = extractor.GetName();
     // Check to make sure we aren't overlapping with another group
@@ -613,6 +613,36 @@ class PatternGrouper {
     CHECK_EQ(groups_[gid_].gid, gid_);
   }
 
+  /* \brief EmbedConst implements rules for embedding constants into partitioned functions or
+   * lifting them into the function arguments.
+   *
+   * The rules depend on what pattern the ConstantNode matched.
+   *
+   * The basic rules are:
+   * If the constant matches ExprPattern(relay.const(*)) or a ConstantPattern(), embed the constant
+   * in the partitioned function. If the constant matched an AltPattern, recursively check the
+   * matched side of the pattern. For any other matching pattern (i.e., wildcard, VarPattern, etc.),
+   * lift the constant into the arguments of the partitioned function.
+   */
+  bool EmbedConst(const Expr& expr, const DFPattern pattern) {
+    bool embed = false;
+    if (expr.as<ConstantNode>()) {
+      if (pattern.as<ConstantPatternNode>() != nullptr) {
+        embed = true;
+      } else if (auto expr_pat = pattern.as<ExprPatternNode>()) {
+        if (expr_pat->expr.as<ConstantNode>()) {
+          embed = true;
+        }
+      } else if (auto alt_pat = pattern.as<AltPatternNode>()) {
+        if (matcher_->Match(alt_pat->left, expr)) {
+          embed = EmbedConst(expr, alt_pat->left);
+        } else {
+          embed = EmbedConst(expr, alt_pat->right);
+        }
+      }
+    }
+    return embed;
+  }
   // Internal State
   DFPattern pattern_;
   std::vector<Group> groups_;
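The embed-vs-lift rules above are easiest to see from the Python side. A sketch, assuming ``partition`` behaves as the tests below assert; the names here are illustrative, borrowed from those tests:

.. code-block:: python

    import tvm
    from tvm import relay
    from tvm.relay.dataflow_pattern import is_op, wildcard, ConstantPattern

    x = relay.var('x')
    w = relay.const(1)  # a constant weight
    conv = relay.op.nn.conv2d(x, w)

    # ConstantPattern(): the matched constant is embedded, so the
    # partitioned function takes only `x` as an argument.
    embed_pat = is_op('nn.conv2d')(wildcard(), ConstantPattern())
    embedded = embed_pat.partition(conv)

    # wildcard(): the constant is lifted into the function arguments,
    # so the partitioned function takes both `x` and the weight.
    lift_pat = is_op('nn.conv2d')(wildcard(), wildcard())
    lifted = lift_pat.partition(conv)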
diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py
index 467e30bc769d..89abb2ed4025 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -596,7 +596,7 @@ def __init__(self):
         self.mean = wildcard()
         self.beta = wildcard()
         self.gamma = wildcard()
-        self.eps = wildcard()
+        self.eps = ConstantPattern()
 
         self.pattern = self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + \
                        self.beta
@@ -765,7 +765,7 @@ def algebraic_simplify(expr):
 
     class ElwiseNullCallback(DFPatternCallback):
         def callback(self, pre, post, node_map):
-            return node_map[self.x][0] # pylint: disable=no-member
+            return node_map[self.x][0]  # pylint: disable=no-member
 
     class AddCallback(ElwiseNullCallback):
         def __init__(self):
@@ -1001,15 +1001,15 @@ def test_partition_batchnorm():
     meanf = relay.var('meanf')
     betaf = relay.var('betaf')
     gammaf = relay.var('gammaf')
-    epsf = relay.var('epsf')
     # Put the arguments in toplogological order for the reference
-    f = relay.Function([gammaf, xf, meanf, varf, epsf, betaf],
+    f = relay.Function([gammaf, xf, meanf, varf, betaf],
                        get_BN(xf, varf, meanf, betaf, gammaf,
-                              epsf)).with_attr("PartitionedFromPattern",
-                                               "subtract_multiply_add_sqrt_divide_add_")
+                              eps)).with_attr("PartitionedFromPattern",
+                                              "subtract_multiply_add_sqrt_divide_add_")
 
     partitioned = BatchnormCallback().pattern.partition(BN)
-    assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, eps, beta))
+    reference = f(gamma, x, mean, var, beta)
+    assert tvm.ir.structural_equal(partitioned, reference)
 
 
 def test_partition_double_batchnorm():
@@ -1028,25 +1028,23 @@ def test_partition_double_batchnorm():
     meanf = relay.var('meanf')
     betaf = relay.var('betaf')
     gammaf = relay.var('gammaf')
-    epsf = relay.var('epsf')
-    f1 = relay.Function([gammaf, xf, meanf, varf, epsf, betaf],
+    f1 = relay.Function([gammaf, xf, meanf, varf, betaf],
                         get_BN(xf, varf, meanf, betaf, gammaf,
-                               epsf)).with_attr("PartitionedFromPattern",
-                                                "subtract_multiply_add_sqrt_divide_add_")
+                               eps)).with_attr("PartitionedFromPattern",
+                                               "subtract_multiply_add_sqrt_divide_add_")
     # The partitioner doesn't replace duplicates, so we use two copies of the function
     xf2 = relay.var('xf2')
     varf2 = relay.var('varf2')
     meanf2 = relay.var('meanf2')
     betaf2 = relay.var('betaf2')
     gammaf2 = relay.var('gammaf2')
-    epsf2 = relay.var('epsf2')
-    f2 = relay.Function([gammaf2, xf2, meanf2, varf2, epsf2, betaf2],
+    f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2],
                         get_BN(xf2, varf2, meanf2, betaf2, gammaf2,
-                               epsf2)).with_attr("PartitionedFromPattern",
-                                                 "subtract_multiply_add_sqrt_divide_add_")
+                               eps)).with_attr("PartitionedFromPattern",
+                                               "subtract_multiply_add_sqrt_divide_add_")
 
     partitioned = BatchnormCallback().pattern.partition(BN2)
-    reference = f2(gamma, f1(gamma, x, mean, var, eps, beta), mean, var, eps, beta)
+    reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
     assert tvm.ir.structural_equal(partitioned, reference)
 
 
@@ -1106,6 +1104,13 @@ def check(pre):
     assert relu == pattern.partition(relu, check=check)
 
 
+def conv_bias_relu(x, w, b):
+    conv2d = relay.op.nn.conv2d(x, w)
+    bias_add = relay.op.nn.bias_add(conv2d, b)
+    relu = relay.op.nn.relu(bias_add)
+    return relu
+
+
 def test_partition_option():
     x = relay.var('x')
     w = relay.var('w')
@@ -1119,12 +1124,6 @@ def test_partition_option():
     bias = is_op('nn.bias_add')(conv2d, wildcard())
     pattern2 = bias.optional(lambda x: is_op('nn.relu')(x))
 
-    def conv_bias_relu(x, w, b):
-        conv2d = relay.op.nn.conv2d(x, w)
-        bias_add = relay.op.nn.bias_add(conv2d, b)
-        relu = relay.op.nn.relu(bias_add)
-        return relu
-
     relu = conv_bias_relu(x, w, b)
 
     xf = relay.var('x')
@@ -1153,6 +1152,69 @@ def callback(self, pre, post, node_map):
     out = rewrite(TestRewrite(), mod['tensor_concatenate_int64'])
     assert tvm.ir.structural_equal(mod['tensor_concatenate_int64'], out)
 
+
+def test_partition_constant_embedding():
+    x = relay.var('x')
+    w = relay.var('w')
+    wc = relay.const(1)
+    b = relay.var('b')
+
+    xf = relay.var('x')
+    wf = relay.var('w')
+    bf = relay.var('b')
+    embedded_func = relay.Function([xf, bf],
+                                   conv_bias_relu(xf, wc,
+                                                  bf)).with_attr("PartitionedFromPattern",
+                                                                 "nn.conv2d_nn.bias_add_nn.relu_")
+    xf = relay.var('x')
+    wf = relay.var('w')
+    bf = relay.var('b')
+    lifted_func = relay.Function([xf, wf, bf],
+                                 conv_bias_relu(xf, wf,
+                                                bf)).with_attr("PartitionedFromPattern",
+                                                               "nn.conv2d_nn.bias_add_nn.relu_")
+    relu = conv_bias_relu(x, w, b)
+    reluc = conv_bias_relu(x, wc, b)
+
+    # Check lifting of wildcard matches
+    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), wildcard()),
+                                                    wildcard()))
+    assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
+    assert tvm.ir.structural_equal(lifted_func(x, wc, b), pattern.partition(reluc))
+
+    # Check lifting of input matches
+    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()),
+                                                    wildcard()))
+    assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
+    assert tvm.ir.structural_equal(reluc, pattern.partition(reluc))  # Constants are not Inputs
+
+    # Check embedding of constant matches
+    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(),
+                                                                       ConstantPattern()),
+                                                    wildcard()))
+    assert tvm.ir.structural_equal(relu, pattern.partition(relu))
+    assert tvm.ir.structural_equal(embedded_func(x, b), pattern.partition(reluc))
+
+    # Check embedding of constant ExprPatterns
+    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(),
+                                                                       ExprPattern(wc)),
+                                                    wildcard()))
+    assert tvm.ir.structural_equal(relu, pattern.partition(relu))
+    assert tvm.ir.structural_equal(embedded_func(x, b), pattern.partition(reluc))
+
+    # Check lifting/embedding of Alt matches
+    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()
+                                                                       | ConstantPattern()),
+                                                    wildcard()))
+    assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
+    assert tvm.ir.structural_equal(embedded_func(x, b), pattern.partition(reluc))
+
+    # Check lifting/embedding of Alt matches with the other ordering
+    pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(
+        wildcard(), ConstantPattern() | is_input()), wildcard()))
+    assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
+    assert tvm.ir.structural_equal(embedded_func(x, b), pattern.partition(reluc))
+
+
 if __name__ == "__main__":
     test_expr_pattern()
     test_var_pattern()
@@ -1209,3 +1271,4 @@ def callback(self, pre, post, node_map):
     test_partition_check_types()
     test_partition_option()
     test_match_match()
+    test_partition_constant_embedding()
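As a final illustration of the ``AltPattern`` rule, a sketch assuming the helpers exercised in ``test_partition_constant_embedding`` above: whichever side of the alternation actually matched decides whether the constant is embedded or lifted, regardless of the ordering of the alternatives.

.. code-block:: python

    from tvm import relay
    from tvm.relay.dataflow_pattern import is_op, wildcard, is_input, ConstantPattern

    pat = is_op('nn.conv2d')(wildcard(), ConstantPattern() | is_input())

    # A relay.var weight matches is_input(), so it is lifted into the
    # partitioned function's arguments.
    lifted = pat.partition(relay.op.nn.conv2d(relay.var('x'), relay.var('w')))

    # A relay.const weight matches ConstantPattern(), so it is embedded
    # in the partitioned function body.
    embedded = pat.partition(relay.op.nn.conv2d(relay.var('x'), relay.const(1)))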