Skip to content

Commit

Permalink
[RELAY] Add 'check' functions to MergeComposite
Browse files Browse the repository at this point in the history
Currently, MergeComposite can only perform structural
matches. This patch introduces the ability to specify
a 'check' function alongside the pattern which can include
custom logic to determine whether an extracted pattern
should be merged.

For example, if you only want to merge 'NHWC' convolutions,
you can specify a 'check' function which queries the
data_layout value of the extracted pattern (see the test).

Change-Id: I9337ce39f10997051a286d888be38ed0d410d340
  • Loading branch information
mbaret committed Apr 7, 2020
1 parent 41b8fd1 commit b18951f
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 11 deletions.
17 changes: 14 additions & 3 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,12 @@ def MergeComposite(pattern_table):
Parameters
----------
pattern_table : list(tuple)
A list of (pattern_name, pattern) tuples.
A list of (pattern_name, pattern, check) tuples.
The order of the patterns in the list will determine the order
of priority in which they are matched.
'check' is a function to check whether an extracted pattern matches.
It can be implemented by pattern writer but if not specified it will
always return True.
Returns
-------
Expand All @@ -390,11 +393,19 @@ def MergeComposite(pattern_table):
"""
pattern_names = []
patterns = []
for pattern_name, pattern in pattern_table:
checks = []
for tup in pattern_table:
if len(tup) == 2:
pattern_name, pattern = tup
check = lambda extract: True
elif len(tup) == 3:
pattern_name, pattern, check = tup

pattern_names.append(pattern_name)
patterns.append(pattern)
checks.append(check)

return _ffi_api.MergeComposite(pattern_names, patterns)
return _ffi_api.MergeComposite(pattern_names, patterns, *checks)


def MergeCompilerRegions():
Expand Down
28 changes: 20 additions & 8 deletions src/relay/transforms/merge_composite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ namespace merge_composite {

class MergeCompositeWrapper : public ExprMutator {
public:
explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern)
: pattern_name_(pattern_name), pattern_(pattern) {}
explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern, const PackedFunc& check)
: pattern_name_(pattern_name), pattern_(pattern), check_(check) {}

Expr ExtractPattern(const Var& pattern, const Expr& root,
Map<std::string, Array<Expr>>* var_map) {
Expand Down Expand Up @@ -193,7 +193,7 @@ class MergeCompositeWrapper : public ExprMutator {
Map<std::string, Array<Expr>> args_map;
Map<Expr, Expr> call_map;
auto extract = ExtractPattern(pattern, call, &args_map, &call_map);
if (extract.defined()) {
if (extract.defined() && static_cast<bool>(check_(extract))) {
auto free_vars = FreeVars(extract);
// make the composite function
auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs());
Expand All @@ -215,17 +215,20 @@ class MergeCompositeWrapper : public ExprMutator {
std::string pattern_name_;
/*! \brief The pattern to match */
Expr pattern_;
/*! \brief The function to check whether an extract is supported */
PackedFunc check_;
};

Expr MergeComposite(const Expr& expr,
const Array<tir::StringImm>& pattern_names, const Array<Expr>& patterns) {
const Array<tir::StringImm>& pattern_names, const Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
CHECK_EQ(pattern_names.size(), patterns.size());
Expr merged_expr = expr;
// merge the patterns one-by-one in order
for (size_t i = 0; i < patterns.size(); i++) {
std::string pattern_name = pattern_names[i]->value;
Expr pattern = patterns[i];
merged_expr = MergeCompositeWrapper(pattern_name, pattern).Mutate(merged_expr);
PackedFunc check = checks[i];
merged_expr = MergeCompositeWrapper(pattern_name, pattern, check).Mutate(merged_expr);
}
return merged_expr;
}
Expand All @@ -235,18 +238,27 @@ Expr MergeComposite(const Expr& expr,
namespace transform {

Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names,
const tvm::Array<Expr>& patterns) {
const tvm::Array<Expr>& patterns,
const std::vector<PackedFunc>& checks) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(
relay::merge_composite::MergeComposite(f, pattern_names, patterns));
relay::merge_composite::MergeComposite(f, pattern_names, patterns, checks));
};
auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {});
return func_pass;
}

TVM_REGISTER_GLOBAL("relay._transform.MergeComposite")
.set_body_typed(MergeComposite);
.set_body([](TVMArgs args, TVMRetValue* rv) {
tvm::Array<tir::StringImm> pattern_names = args[0];
tvm::Array<Expr> patterns = args[1];
std::vector<PackedFunc> checks;
for (int i=2;i<args.size();i++) {
checks.push_back(args[i]);
}
*rv = MergeComposite(pattern_names, patterns, checks);
});

} // namespace transform

Expand Down
38 changes: 38 additions & 0 deletions tests/python/relay/test_pass_merge_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,43 @@ def expected():
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)


def test_pattern_with_check():
def before():
x = relay.var('x', shape=(1, 10, 10, 10))
w = relay.var('w', shape=(10, 10, 3, 3))
b = relay.var('b', shape=(8,))
conv = relay.nn.conv2d(x,
w,
kernel_size=(3, 3),
kernel_layout="OIHW",
data_layout="NHWC")
bias = relay.nn.bias_add(conv, b)
relu = relay.nn.relu(bias)
return relay.Function([x, w, b], relu)

def _check_true(extract):
conv = extract.args[0].args[0]
return conv.attrs.data_layout == "NHWC"

def _check_false(extract):
conv = extract.args[0].args[0]
return conv.attrs.data_layout == "NCHW"

pattern_table_true = [
("conv_bias_relu", make_conv_bias_relu_pattern(), _check_true)
]
pattern_table_false = [
("conv_bias_relu", make_conv_bias_relu_pattern(), _check_false)
]

result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_false))
expected = run_opt_pass(before(), relay.transform.InferType())
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)

result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_true))
assert result.body.op.attrs["Composite"] == "conv_bias_relu"


if __name__ == "__main__":
test_simple_merge()
test_branch_merge()
Expand All @@ -741,3 +778,4 @@ def expected():
test_multiple_input_subgraphs()
test_reuse_call_merge()
test_tuple_get_item_merge()
test_pattern_with_check()

0 comments on commit b18951f

Please sign in to comment.