Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Pass] Fix bug in re-processing call node in MergeComposite pass #4879

Merged
merged 10 commits into from
Feb 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions src/relay/pass/merge_composite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class MergeCompositeWrapper : public ExprMutator {
* a new Relay expression ready to be wrapped into a composite function.
*/
Expr ExtractPattern(const Call& pattern, const Call& root,
Map<std::string, Array<Expr>>* var_map) {
Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
// check to make sure both calls are to operators (not functions)
if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
return Expr();
Expand All @@ -99,14 +99,20 @@ class MergeCompositeWrapper : public ExprMutator {
for (const auto& arg : pattern->args) {
Expr new_arg;
if (arg->IsInstance<CallNode>()) {
// fail if the root argument is not also a call node
if (!root->args[i]->IsInstance<CallNode>()) {
return Expr();
// if we've already processed this call node, return the previous result
if (call_map->find(arg) != call_map->end()) {
new_arg = (*call_map)[arg];
} else {
// fail if the root argument is not also a call node
if (!root->args[i]->IsInstance<CallNode>()) {
return Expr();
}
// if it's a call node, recursively call this function
new_arg = ExtractPattern(Downcast<Call>(arg),
Downcast<Call>(root->args[i]),
var_map, call_map);
call_map->Set(arg, new_arg);
}
// if it's a call node, recursively call this function
new_arg = ExtractPattern(Downcast<Call>(arg),
Downcast<Call>(root->args[i]),
var_map);
} else if (arg->IsInstance<VarNode>()) {
// if there's a var in the pattern, it must be a free var
// so call the function to update the var_map
Expand Down Expand Up @@ -155,7 +161,8 @@ class MergeCompositeWrapper : public ExprMutator {
Call pattern = Downcast<Call>(pattern_);
CHECK(pattern.defined());
Map<std::string, Array<Expr>> args_map;
auto extract = ExtractPattern(pattern, call, &args_map);
Map<Expr, Expr> call_map;
auto extract = ExtractPattern(pattern, call, &args_map, &call_map);
if (extract.defined()) {
auto free_vars = FreeVars(extract);
// make the composite function
Expand Down
82 changes: 82 additions & 0 deletions tests/python/relay/test_pass_merge_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,26 @@ def make_conv_bias_relu_pattern():
return r


def make_add_add_add_pattern():
"""Create a pattern to match the following graph.
Useful for testing re-using a call node.

x y
/ \ /
| add
\ | \
add |
| /
add
"""
x = relay.var('x')
y = relay.var('y')
add_node = relay.add(x, y)
add_node_1 = relay.add(x, add_node)
r = relay.add(add_node_1, add_node)
return r


def test_simple_merge():
"""Test composite function is correctly produced from simple graph.

Expand Down Expand Up @@ -239,6 +259,67 @@ def expected():
assert relay.analysis.alpha_equal(result, expected)


def test_reuse_call_merge():
"""Test composite function is correctly produced from simple graph
which re-uses call nodes.

We could expect the pattern `make_add_add_add` to be merged
into a single op `add_add_add`.

x y
\ / \
sub | x y
/ | / \ / |
| add ====> sub |
\ | \ | /
add | add_add_add
| /
add

"""
pattern_table = [
("add_add_add", make_add_add_add_pattern())
]

def before():
a = relay.var('a', shape=(10, 10))
b = relay.var('b', shape=(10, 10))
sub_node = relay.subtract(a, b)

# pattern
add_node = relay.add(sub_node, b)
add_node_1 = relay.add(sub_node, add_node)
r = relay.add(add_node_1, add_node)

return relay.Function([a, b], r)

def expected():
a = relay.var('a', shape=(10, 10))
b = relay.var('b', shape=(10, 10))

# add_relu_add function
in_1 = relay.var('in_1', shape=(10, 10))
in_2 = relay.var('in_2', shape=(10, 10))
add_node = relay.add(in_1, in_2)
add_node_1 = relay.add(in_1, add_node)
add_node_2 = relay.add(add_node_1, add_node)
add_add_add = relay.Function([in_1, in_2], add_node_2)
add_add_add = add_add_add.set_attribute("Primitive",
tir.IntImm("int32", 1))
add_add_add = add_add_add.set_attribute("Composite",
tir.StringImm("add_add_add"))

# merged function
sub_node = relay.subtract(a, b)
call = relay.Call(add_add_add, [sub_node, b])
return relay.Function([a, b], call)

result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)


def test_multiple_patterns():
"""Test different patterns are merged correctly in the graph.

Expand Down Expand Up @@ -608,3 +689,4 @@ def after_B():
test_merge_order()
test_parallel_merge()
test_multiple_input_subgraphs()
test_reuse_call_merge()