Skip to content

Commit

Permalink
Support priority order in merge_composite
Browse files Browse the repository at this point in the history
The order in which the patterns are matched
was currently random as an unordered_map was
used to store the pattern table. This uses
arrays instead so that a distinct priority
order of matching can be defined. Additional
tests have also been added to verify this
behaviour.

Change-Id: Ief347df4262639138d5d9d7c8cee7ef233af7b56
  • Loading branch information
mbaret committed Feb 3, 2020
1 parent 60581ff commit 8f764a6
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 41 deletions.
32 changes: 29 additions & 3 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,21 +513,47 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"):
return _transform.Legalize(legalize_map_attr_name)


def MergeComposite(compiler):
"""Merge multiple operators into a single composite relay function.
def AnnotateCompiler(compiler):
"""Annotate ops in an experession with a provied compiler and then use it
for codegen.
Parameters
----------
compiler : str
The compiler used for codegen.
Returns
-------
ret : tvm.relay.Pass
The annotated pass that wrapps ops with subgraph_start and
subgraph_end.
"""
return _transform.AnnotateCompiler(compiler)


def MergeComposite(pattern_table):
"""Merge multiple operators into a single composite relay function.
Parameters
----------
pattern_table : list(tuple)
A list of (pattern_name, pattern) tuples.
The order of the patterns in the list will determine the order
of priority in which they are matched.
Returns
-------
ret : tvm.relay.Pass
The registered pass that merges operators into a single composite
relay function.
"""
return _transform.MergeComposite(compiler)
pattern_names = []
patterns = []
for pattern_name, pattern in pattern_table:
pattern_names.append(pattern_name)
patterns.append(pattern)

return _transform.MergeComposite(pattern_names, patterns)


def RewriteAnnotatedOps(fallback_device):
Expand Down
65 changes: 37 additions & 28 deletions src/relay/pass/merge_composite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ namespace merge_composite {

class MergeCompositeWrapper : public ExprMutator {
public:
explicit MergeCompositeWrapper(const tvm::Map<std::string, Expr>& pattern_map)
: pattern_map_(pattern_map) {}
explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern)
: pattern_name_(pattern_name), pattern_(pattern) {}

bool MatchPattern(const Call& pattern, const Call& root) {
if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
Expand Down Expand Up @@ -135,49 +135,58 @@ class MergeCompositeWrapper : public ExprMutator {

Op op = Downcast<Op>(call->op);
CHECK(op.defined());
for (const auto& x : pattern_map_) {
Call pattern = Downcast<Call>(x.second);
if (Downcast<Op>(pattern->op)->name != op->name)
continue;

if (MatchPattern(pattern, call)) {
Map<std::string, Array<Expr>> args_map;
auto extract = ExtractPattern(pattern, call, &args_map);
auto free_vars = FreeVars(extract);
Function new_func = FunctionNode::make(free_vars, extract,
call->checked_type_, {}, Attrs());
new_func = FunctionSetAttr(new_func, attr::kComposite,
tir::StringImmNode::make(x.first));
new_func = FunctionSetAttr(new_func, attr::kPrimitive,
tvm::Integer(1));
Array<Expr> args;
for (const auto& free_var : free_vars) {
args.push_back(args_map[free_var->name_hint()][1]);
}
auto new_call = CallNode::make(new_func, args);
return std::move(new_call);
Call pattern = Downcast<Call>(pattern_);
if (Downcast<Op>(pattern->op)->name != op->name)
return std::move(call);

if (MatchPattern(pattern, call)) {
Map<std::string, Array<Expr>> args_map;
auto extract = ExtractPattern(pattern, call, &args_map);
auto free_vars = FreeVars(extract);
Function new_func = FunctionNode::make(free_vars, extract,
call->checked_type_, {}, Attrs());
new_func = FunctionSetAttr(new_func, attr::kComposite,
tir::StringImmNode::make(pattern_name_));
new_func = FunctionSetAttr(new_func, attr::kPrimitive,
tvm::Integer(1));
Array<Expr> args;
for (const auto& free_var : free_vars) {
args.push_back(args_map[free_var->name_hint()][1]);
}
auto new_call = CallNode::make(new_func, args);
return std::move(new_call);
}

return std::move(call);
}

private:
tvm::Map<std::string, Expr> pattern_map_;
std::string pattern_name_;
Expr pattern_;
};

Expr MergeComposite(const Expr& expr, const tvm::Map<std::string, Expr>& pattern) {
return MergeCompositeWrapper(pattern).Mutate(expr);
Expr MergeComposite(const Expr& expr,
const Array<tir::StringImm>& pattern_names, const Array<Expr>& patterns) {
CHECK(pattern_names.size() == patterns.size());
Expr merged_expr = expr;
for (size_t i = 0; i < patterns.size(); i++) {
std::string pattern_name = pattern_names[i]->value;
Expr pattern = patterns[i];
merged_expr = MergeCompositeWrapper(pattern_name, pattern).Mutate(merged_expr);
}
return merged_expr;
}

} // namespace merge_composite

namespace transform {

Pass MergeComposite(const tvm::Map<std::string, Expr>& pattern) {
Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names,
const tvm::Array<Expr>& patterns) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::merge_composite::MergeComposite(f, pattern));
return Downcast<Function>(
relay::merge_composite::MergeComposite(f, pattern_names, patterns));
};
auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {});
return func_pass;
Expand Down
143 changes: 133 additions & 10 deletions tests/python/relay/test_pass_merge_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Unit tests for merge composite."""
from tvm import expr
from tvm import relay
from tvm.relay.testing import run_opt_pass

Expand Down Expand Up @@ -122,9 +123,9 @@ def test_simple_merge():
relu
"""
pattern_table = {
"add_relu": make_add_relu_pattern()
}
pattern_table = [
("add_relu", make_add_relu_pattern())
]

def before():
a = relay.var('a', shape=(10, 10))
Expand Down Expand Up @@ -178,9 +179,9 @@ def test_branch_merge():
relu
"""

pattern_table = {
"add_sub_mul": make_add_sub_mul_pattern()
}
pattern_table = [
("add_sub_mul", make_add_sub_mul_pattern())
]

def before():
a = relay.var('a', shape=(10, 10))
Expand Down Expand Up @@ -244,10 +245,10 @@ def test_multiple_patterns():
| /
mul
"""
pattern_table = {
"conv2d_bias_relu": make_conv_bias_relu_pattern(),
"add_relu": make_add_relu_pattern()
}
pattern_table = [
("conv2d_bias_relu", make_conv_bias_relu_pattern()),
("add_relu", make_add_relu_pattern())
]

def before():
data = relay.var('data', shape=(1, 512, 28, 28))
Expand Down Expand Up @@ -310,7 +311,129 @@ def expected():
assert relay.analysis.alpha_equal(result, expected)


def test_merge_order():
"""Test that patterns are merged in the order they exist in the pattern table.
There can be cases where one pattern is a subgraph of another, in which case
it is not clear which match should take priority. The priority should come
from the order in which the patterns are declared in the pattern table. The
first patterns will be merged with highest priority and the last with lowest.
A: B: C:
add add abs
| | |
abs abs relu
|
relu
"""

def pattern_A():
x = relay.var('x')
y = relay.var('y')
out = relay.add(x, y)
out = relay.abs(out)
out = relay.nn.relu(out)
return out

def pattern_B():
x = relay.var('x')
y = relay.var('y')
out = relay.add(x, y)
out = relay.abs(out)
return out

def pattern_C():
x = relay.var('x')
out = relay.abs(x)
out = relay.nn.relu(x)
return out

def before():
input_1 = relay.var('input_1', shape=(10, 10))
input_2 = relay.var('input_2', shape=(10, 10))
out = relay.add(input_1, input_2)
out = relay.abs(out)
out = relay.nn.relu(out)
return relay.Function([input_1, input_2], out)

def after_A_priority():
input_1 = relay.var('input_1', shape=(10, 10))
input_2 = relay.var('input_2', shape=(10, 10))
x = relay.var('x')
y = relay.var('y')
out = relay.add(x, y)
out = relay.abs(out)
out = relay.nn.relu(out)
merged_func = relay.Function([x, y], out)
merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
merged_func = merged_func.set_attribute('Composite', expr.StringImm('A'))
ret = relay.Call(merged_func, [input_1, input_2])
return relay.Function([input_1, input_2], ret)

def after_B_priority():
input_1 = relay.var('input_1', shape=(10, 10))
input_2 = relay.var('input_2', shape=(10, 10))
x = relay.var('x')
y = relay.var('y')
out = relay.add(x, y)
out = relay.abs(out)
merged_func = relay.Function([x, y], out)
merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
merged_func = merged_func.set_attribute('Composite', expr.StringImm('B'))
merged_call = relay.Call(merged_func, [input_1, input_2])
ret = relay.nn.relu(merged_call)
return relay.Function([input_1, input_2], ret)

def after_C_priority():
input_1 = relay.var('input_1', shape=(10, 10))
input_2 = relay.var('input_2', shape=(10, 10))
add = relay.add(input_1, input_2)
x = relay.var('x')
out = relay.abs(x)
out = relay.nn.relu(out)
merged_func = relay.Function([x], out)
merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
merged_func = merged_func.set_attribute('Composite', expr.StringImm('C'))
ret = relay.Call(merged_func, [add])
return relay.Function([input_1, input_2], ret)

# check A highest priority
pattern_table = [
("A", pattern_A()),
("B", pattern_B()),
("C", pattern_C()),
]
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A_priority(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)

# check B highest priority
pattern_table = [
("B", pattern_A()),
("C", pattern_B()),
("A", pattern_C()),
]
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A_priority(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)

# check C highest priority
pattern_table = [
("C", pattern_A()),
("A", pattern_B()),
("B", pattern_C()),
]
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A_priority(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)


if __name__ == "__main__":
test_simple_merge()
test_branch_merge()
test_multiple_patterns()
test_merge_order()

0 comments on commit 8f764a6

Please sign in to comment.