Skip to content

Commit

Permalink
[PatternLang]Conditionally Embedding Constants in Partitioned Functio…
Browse files Browse the repository at this point in the history
…ns (#5693)

* Embed constants in the partition function if the pattern explicity requests constants

fix rst

fix pylint

* improve comments based on Cody's feedback
  • Loading branch information
Matthew Brookhart authored May 30, 2020
1 parent 1ae7162 commit 2cd5117
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 33 deletions.
8 changes: 4 additions & 4 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ The next example is matching a pattern of batch_norm -> get(0) -> relu:
pat.match(out)
The next example is matching a constant node regarding its values. This is useful to check
if a specific parameter in a subgraph has been bind or not.
if a specific parameter in a subgraph has been bound or not.

.. code-block:: python
Expand Down Expand Up @@ -266,10 +266,10 @@ Attribute Pattern

Check that the operator matched by the pattern has an attribute with a particular value.

Input
*****
Variable Pattern
****************

Check that the expression is an input, i.e has no parents and is a variable.
Check that the expression is a relay Variable, and optional provide a name to match to the Variable name.


Alternate
Expand Down
7 changes: 3 additions & 4 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,15 +318,14 @@ class VarPattern(DFPattern):
Parameters
----------
name_hint: str
The name of the variable.
This name only acts as a hint, and is not used
for equality.
The name of the variable. Optional, if not provided,
the pattern will match any VarNode.
type_annotation: tvm.relay.Type, optional
The type annotation on the variable.
"""

def __init__(self, name_hint: str, type_annotation=None):
def __init__(self, name_hint="", type_annotation=None):
self.__init_handle_by_constructor__(
ffi.VarPattern, name_hint, type_annotation)

Expand Down
36 changes: 33 additions & 3 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ class PatternGrouper {
auto matches = node_map[node->ref_];
for (auto match : matches) {
if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
match.as<FunctionNode>() == nullptr) {
match.as<FunctionNode>() == nullptr && !EmbedConst(match, node->ref_)) {
inputs[match] = Var(
"FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
NullValue<Type>());
Expand All @@ -582,8 +582,8 @@ class PatternGrouper {
auto extractor = MatchExtractor(inputs);
auto body = extractor.Mutate(expr);

// Verify the pattern still holds, no longer valid if we're not embedding constants in the
// graph, keep here for future debug CHECK(DFPatternMatcher(body).Match(pattern_, body));
// Verify the pattern still holds
CHECK(DFPatternMatcher(body).Match(pattern_, body));
group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
group.name = extractor.GetName();
// Check to make sure we aren't overlapping with another group
Expand Down Expand Up @@ -613,6 +613,36 @@ class PatternGrouper {
CHECK_EQ(groups_[gid_].gid, gid_);
}

/* \brief EmbedConst implements rules for embedding constants into partitioned functions or
* lifting them into the function arguments.
*
* The rules depend on what pattern the ConstantNode matched.
*
* The basic rules are:
* If the constant matches ExprPattern(relay.const(*)) or a ConstantPattern(), embed the constant
* in the partitioned function. If the constant matched an AltPattern, recursively check the
* matched side of the pattern. For any other matching pattern (i.e, wildcard, VarPattern, etc),
* lift the constant into the arguments of the partitioned function.
*/
bool EmbedConst(const Expr& expr, const DFPattern pattern) {
bool embed = false;
if (expr.as<ConstantNode>()) {
if (pattern.as<ConstantPatternNode>() != nullptr) {
embed = true;
} else if (auto expr_pat = pattern.as<ExprPatternNode>()) {
if (expr_pat->expr.as<ConstantNode>()) {
embed = true;
}
} else if (auto alt_pat = pattern.as<AltPatternNode>()) {
if (matcher_->Match(alt_pat->left, expr)) {
embed = EmbedConst(expr, alt_pat->left);
} else {
embed = EmbedConst(expr, alt_pat->right);
}
}
}
return embed;
}
// Internal State
DFPattern pattern_;
std::vector<Group> groups_;
Expand Down
107 changes: 85 additions & 22 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def __init__(self):
self.mean = wildcard()
self.beta = wildcard()
self.gamma = wildcard()
self.eps = wildcard()
self.eps = ConstantPattern()

self.pattern = self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + \
self.beta
Expand Down Expand Up @@ -765,7 +765,7 @@ def algebraic_simplify(expr):

class ElwiseNullCallback(DFPatternCallback):
def callback(self, pre, post, node_map):
return node_map[self.x][0] # pylint: disable=no-member
return node_map[self.x][0] # pylint: disable=no-member

class AddCallback(ElwiseNullCallback):
def __init__(self):
Expand Down Expand Up @@ -1001,15 +1001,15 @@ def test_partition_batchnorm():
meanf = relay.var('meanf')
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
epsf = relay.var('epsf')
# Put the arguments in toplogological order for the reference
f = relay.Function([gammaf, xf, meanf, varf, epsf, betaf],
f = relay.Function([gammaf, xf, meanf, varf, betaf],
get_BN(xf, varf, meanf, betaf, gammaf,
epsf)).with_attr("PartitionedFromPattern",
"subtract_multiply_add_sqrt_divide_add_")
eps)).with_attr("PartitionedFromPattern",
"subtract_multiply_add_sqrt_divide_add_")

partitioned = BatchnormCallback().pattern.partition(BN)
assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, eps, beta))
reference = f(gamma, x, mean, var, beta)
assert tvm.ir.structural_equal(partitioned, reference)


def test_partition_double_batchnorm():
Expand All @@ -1028,25 +1028,23 @@ def test_partition_double_batchnorm():
meanf = relay.var('meanf')
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
epsf = relay.var('epsf')
f1 = relay.Function([gammaf, xf, meanf, varf, epsf, betaf],
f1 = relay.Function([gammaf, xf, meanf, varf, betaf],
get_BN(xf, varf, meanf, betaf, gammaf,
epsf)).with_attr("PartitionedFromPattern",
"subtract_multiply_add_sqrt_divide_add_")
eps)).with_attr("PartitionedFromPattern",
"subtract_multiply_add_sqrt_divide_add_")
# The partitioner doesn't replace duplicates, so we use two copies of the function
xf2 = relay.var('xf2')
varf2 = relay.var('varf2')
meanf2 = relay.var('meanf2')
betaf2 = relay.var('betaf2')
gammaf2 = relay.var('gammaf2')
epsf2 = relay.var('epsf2')
f2 = relay.Function([gammaf2, xf2, meanf2, varf2, epsf2, betaf2],
f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2],
get_BN(xf2, varf2, meanf2, betaf2, gammaf2,
epsf2)).with_attr("PartitionedFromPattern",
"subtract_multiply_add_sqrt_divide_add_")
eps)).with_attr("PartitionedFromPattern",
"subtract_multiply_add_sqrt_divide_add_")

partitioned = BatchnormCallback().pattern.partition(BN2)
reference = f2(gamma, f1(gamma, x, mean, var, eps, beta), mean, var, eps, beta)
reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
assert tvm.ir.structural_equal(partitioned, reference)


Expand Down Expand Up @@ -1106,6 +1104,13 @@ def check(pre):
assert relu == pattern.partition(relu, check=check)


def conv_bias_relu(x, w, b):
conv2d = relay.op.nn.conv2d(x, w)
bias_add = relay.op.nn.bias_add(conv2d, b)
relu = relay.op.nn.relu(bias_add)
return relu


def test_partition_option():
x = relay.var('x')
w = relay.var('w')
Expand All @@ -1119,12 +1124,6 @@ def test_partition_option():
bias = is_op('nn.bias_add')(conv2d, wildcard())
pattern2 = bias.optional(lambda x: is_op('nn.relu')(x))

def conv_bias_relu(x, w, b):
conv2d = relay.op.nn.conv2d(x, w)
bias_add = relay.op.nn.bias_add(conv2d, b)
relu = relay.op.nn.relu(bias_add)
return relu

relu = conv_bias_relu(x, w, b)

xf = relay.var('x')
Expand Down Expand Up @@ -1153,6 +1152,69 @@ def callback(self, pre, post, node_map):
out = rewrite(TestRewrite(), mod['tensor_concatenate_int64'])
assert tvm.ir.structural_equal(mod['tensor_concatenate_int64'], out)

def test_partition_constant_embedding():
x = relay.var('x')
w = relay.var('w')
wc = relay.const(1)
b = relay.var('b')

xf = relay.var('x')
wf = relay.var('w')
bf = relay.var('b')
embeded_func = relay.Function([xf, bf],
conv_bias_relu(xf, wc,
bf)).with_attr("PartitionedFromPattern",
"nn.conv2d_nn.bias_add_nn.relu_")
xf = relay.var('x')
wf = relay.var('w')
bf = relay.var('b')
lifted_func = relay.Function([xf, wf, bf],
conv_bias_relu(xf, wf,
bf)).with_attr("PartitionedFromPattern",
"nn.conv2d_nn.bias_add_nn.relu_")
relu = conv_bias_relu(x, w, b)
reluc = conv_bias_relu(x, wc, b)

# Check lifting of wildcard matches
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), wildcard()),
wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(lifted_func(x, wc, b), pattern.partition(reluc))

# Check lifting of input matches
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()),
wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(reluc, pattern.partition(reluc)) #Constants are not Inputs

# Check embedding of constant matches
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(),
ConstantPattern()),
wildcard()))
assert tvm.ir.structural_equal(relu, pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))

# Check embedding of constant ExprPatterns
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(),
ExprPattern(wc)),
wildcard()))
assert tvm.ir.structural_equal(relu, pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))

# Check lifting/embedding of Alt matches
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()
| ConstantPattern()),
wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))

# Check lifting/embedding of Alt matches with the other ordering
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(
wildcard(), ConstantPattern() | is_input()), wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))


if __name__ == "__main__":
test_expr_pattern()
test_var_pattern()
Expand Down Expand Up @@ -1209,3 +1271,4 @@ def callback(self, pre, post, node_map):
test_partition_check_types()
test_partition_option()
test_match_match()
test_partition_constant_embedding()

0 comments on commit 2cd5117

Please sign in to comment.