Skip to content

Commit

Permalink
Add a check Callback to the Pattern Paritioner (apache#5646)
Browse files Browse the repository at this point in the history
* add a check callback to the paritioner

* fix doc string

* fix unit test spelling

* add a test with types
  • Loading branch information
Matthew Brookhart authored and Trevor Morris committed Jun 9, 2020
1 parent 855b02b commit 58a8047
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 18 deletions.
7 changes: 6 additions & 1 deletion include/tvm/relay/dataflow_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/relay/dataflow_pattern.h>
#include <tvm/relay/dataflow_pattern_functor.h>

#include <string>
#include <unordered_map>
#include <utility>

Expand Down Expand Up @@ -87,10 +88,14 @@ Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr);
*
* \param pattern The pattern to match
* \param expr The expression to patition
* \param attrs A set of parameter names and values to apply to the partitioned function
* \param check A callback function for checking more complicated properties of the matched
* expressions, returns true if the match is accepted and false otherwise
*
* \return Return the paritioned Expr.
*/
Expr PartitionPattern(DFPattern pattern, Expr expr);
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs,
PackedFunc check);

} // namespace relay
} // namespace tvm
Expand Down
16 changes: 11 additions & 5 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def match(self, expr: Expr) -> bool:
"""
return match(self, expr)

def partition(self, expr: Expr, attrs=None) -> Expr:
def partition(self, expr: Expr, attrs=None, check=lambda x: True) -> Expr:
"""
Parition the expression into functions defined by this pattern
Expand All @@ -119,13 +119,16 @@ def partition(self, expr: Expr, attrs=None) -> Expr:
The expression to match.
attrs : Optional[Dict[str, Object]]
A dictionary of Attribute name/values to add to the paritioned function
check : Function
A function to perform more complicated checks on the matched expression.
Returns true if partitioning should proceed, false otherwise.
Returns
-------
result : tvm.relay.Expr
The Expression with matched subgraphs replaced by function calls to that subgraph
"""
return partition(self, expr, attrs)
return partition(self, expr, attrs, check)

def dominates(self, parent, path=None):
"""
Expand Down Expand Up @@ -561,7 +564,7 @@ def rewrite(callbacks, expr: Expr) -> Expr:

return ffi.rewrite(tmp, expr)

def partition(pattern: DFPattern, expr: Expr, attrs=None) -> Expr:
def partition(pattern: DFPattern, expr: Expr, attrs=None, check=lambda x: True) -> Expr:
"""
Parition the expression into a series of functions that match the pattern
Expand All @@ -571,12 +574,15 @@ def partition(pattern: DFPattern, expr: Expr, attrs=None) -> Expr:
The pattern to match
expr : tvm.relay.Expr
The expression to split into functions
expr : Optional[Dict[str, Object]]
attrs : Optional[Dict[str, Object]]
A dict of attributes to apply to the partitioned function
check : Function
A function to perform more complicated checks on the matched expression.
Returns true if partitioning should proceed, false otherwise.
Returns
-------
result : tvm.relay.Expr
The Expression with matched subgraphs replaced by function calls to that subgraph
"""
return ffi.partition(pattern, expr, attrs)
return ffi.partition(pattern, expr, attrs, check)
17 changes: 10 additions & 7 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -693,11 +693,12 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatt
class PatternPartitioner : protected MixedModeMutator {
public:
Expr Partition(const DFPattern& pattern, const Expr& pre,
const Map<std::string, ObjectRef>& attrs) {
const Map<std::string, ObjectRef>& attrs, PackedFunc check) {
auto grouper = PatternGrouper();
groups_ = grouper.GroupMatches(pattern, pre);
gid_assignments_ = grouper.GetGIDAssignments();
attrs_ = attrs;
check_ = check;
return this->VisitExpr(pre);
}

Expand All @@ -718,7 +719,8 @@ class PatternPartitioner : protected MixedModeMutator {

Expr DispatchVisitExpr(const Expr& pre) override {
auto post = MixedModeMutator::DispatchVisitExpr(pre);
if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) {
if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node &&
static_cast<bool>(check_(pre))) {
post = RewritePartition(groups_[gid_assignments_[pre]]);
}
return post;
Expand All @@ -727,16 +729,17 @@ class PatternPartitioner : protected MixedModeMutator {
Map<std::string, ObjectRef> attrs_;
std::vector<PatternGrouper::Group> groups_;
std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
PackedFunc check_;
};

Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs) {
return PatternPartitioner().Partition(pattern, expr, attrs);
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs,
PackedFunc check) {
return PatternPartitioner().Partition(pattern, expr, attrs, check);
}

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition")
.set_body_typed([](DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs) {
return PartitionPattern(pattern, expr, attrs);
});
.set_body_typed([](DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs,
PackedFunc check) { return PartitionPattern(pattern, expr, attrs, check); });

} // namespace relay
} // namespace tvm
65 changes: 60 additions & 5 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
from tvm.relay.testing import run_opt_pass
import numpy as np

# NB: 1 corresponds to the C++ enum that specicfies this
Expand Down Expand Up @@ -880,7 +881,7 @@ def nested_diamond(inp, weight):
def get_BN(x, var, mean, beta, gamma, eps = 1e-5):
return gamma * (x - mean)/relay.op.sqrt(var + relay.const(eps)) + beta

def test_parition_batchnorm():
def test_partition_batchnorm():
x = relay.var('x')
var = relay.var('var')
mean = relay.var('mean')
Expand All @@ -900,7 +901,7 @@ def test_parition_batchnorm():
partitioned = BatchnormCallback().pattern.partition(BN)
assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, beta))

def test_parition_double_batchnorm():
def test_partition_double_batchnorm():
x = relay.var('x')
var = relay.var('var')
mean = relay.var('mean')
Expand All @@ -916,7 +917,7 @@ def test_parition_double_batchnorm():
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
# The paritioner doesn't replace duplicates, so we use two copies of the function
# The partitioner doesn't replace duplicates, so we use two copies of the function
xf2 = relay.var('xf2')
varf2 = relay.var('varf2')
meanf2 = relay.var('meanf2')
Expand All @@ -928,6 +929,58 @@ def test_parition_double_batchnorm():
reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
assert tvm.ir.structural_equal(partitioned, reference)

def test_partition_check():
pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
def check(pre):
return pre.args[0].attrs.data_layout == "NCHW"

x = relay.var('input')
w = relay.var('weight')
conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)

xf = relay.var('input')
wf = relay.var('weight')
conv2df = relay.op.nn.conv2d(xf, wf)
reluf = relay.op.nn.relu(conv2df)
func = relay.Function([xf, wf], reluf).with_attr("PartitionedFromPattern", "nn.conv2d_nn.relu_")

reference = func(x, w)
partitioned = pattern.partition(relu, check=check)
assert tvm.ir.structural_equal(partitioned, reference)

conv2d = relay.op.nn.conv2d(x, w, data_layout="NHWC")
relu = relay.op.nn.relu(conv2d)
assert relu == pattern.partition(relu, check=check)

def test_partition_check_types():
pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
def check(pre):
conv = pre.args[0]
return (conv.attrs.data_layout == "NCHW") and bool(conv.checked_type.shape[0] == 1)

x = relay.var('input', shape=(1, 10, 10, 10))
w = relay.var('weight', shape=(10, 10, 3, 3))
conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
relu = run_opt_pass(relu, relay.transform.InferType())

partitioned = pattern.partition(relu, check=check)
assert partitioned.op.attrs["PartitionedFromPattern"] == "nn.conv2d_nn.relu_"

conv2d = relay.op.nn.conv2d(x, w, data_layout="NHWC")
relu = relay.op.nn.relu(conv2d)
relu = run_opt_pass(relu, relay.transform.InferType())
assert relu == pattern.partition(relu, check=check)

x = relay.var('input', shape=(2, 10, 10, 10))
w = relay.var('weight', shape=(10, 10, 3, 3))
conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
relu = run_opt_pass(relu, relay.transform.InferType())
assert relu == pattern.partition(relu, check=check)


if __name__ == "__main__":
test_match_op()
test_no_match_op()
Expand Down Expand Up @@ -957,6 +1010,8 @@ def test_parition_double_batchnorm():
test_algebraic_simplify()
test_partition_dominator()
test_quadruple_partition_dominator()
test_parition_batchnorm()
test_parition_double_batchnorm()
test_partition_batchnorm()
test_partition_double_batchnorm()
test_partition_check()
test_partition_check_types()

0 comments on commit 58a8047

Please sign in to comment.