Skip to content

Commit

Permalink
Added further merge composite tests
Browse files Browse the repository at this point in the history
Change-Id: Ib1d800409fca4c1834c7fe0cab5a26ab99a26820
  • Loading branch information
mbaret committed Feb 10, 2020
1 parent 7d63fd1 commit f156aab
Showing 1 changed file with 170 additions and 0 deletions.
170 changes: 170 additions & 0 deletions tests/python/relay/test_pass_merge_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,178 @@ def after_C_priority():
assert relay.analysis.alpha_equal(result, expected)


def test_parallel_merge():
"""Tests that parallel patterns relying on the same inputs are correctly merged.
The test graph is difficult to draw out as ascii art. It is essentially two parallel
add-sub-mul units which both consume input_1 and input_2 with their results being multiplied
to give the output. We expect both parallel branches should get merged and both should still
consume the same input variables, input_1 and input_2."""

def before():
input_1 = relay.var('input_1', shape=(10, 10))
input_2 = relay.var('input_2', shape=(10, 10))
branch_1_add = relay.add(input_1, input_2)
branch_1_sub = relay.subtract(input_1, input_2)
branch_1 = relay.multiply(branch_1_add, branch_1_sub)
branch_2_add = relay.add(input_1, input_2)
branch_2_sub = relay.subtract(input_1, input_2)
branch_2 = relay.multiply(branch_2_add, branch_2_sub)
out = relay.multiply(branch_1, branch_2)
return relay.Function([input_1, input_2], out)

def after():
input_1 = relay.var('input_1', shape=(10, 10))
input_2 = relay.var('input_2', shape=(10, 10))
x = relay.var('x')
y = relay.var('y')
branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
func_1 = relay.Function([x, y], branch_1)
call_1 = relay.Call(func_1, [input_1, input_2])
x1 = relay.var('x1')
y1 = relay.var('y1')
branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
func_2 = relay.Function([x1, y1], branch_2)
call_2 = relay.Call(func_2, [input_1, input_2])
out = relay.multiply(call_1, call_2)
return relay.Function([input_1, input_2], out)

pattern_table = [
("add_sub_mul", make_add_sub_mul_pattern())
]
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)


def test_multiple_input_subgraphs():
"""Test the case when multiple input subgraphs feed into another subgraph.
(1) (2) (3) (4)
add add add add
| | | |
relu relu relu relu
\ / \ /
\ / \ /
add sub
\ /
\ /
\ /
mul
----> When 1=3 and 2=4 (Case 'A')
add_relu add_relu
\ /
\ /
add_sub_mul
----> When 1!=3 and 2!=4 (Case 'B')
add_relu add_relu add_relu add_relu
\ / \ /
\ / \ /
add sub
\ /
-------- -----
\ /
mul
The difference in behaviour comes from the fact that add_sub_mul expects that the
inputs to add and sub are identical (the same two relay expressions). So when you
have 4 independent inputs, the pattern should not be merged.
"""

def before():
before_funcs = {}
inputs = [relay.var('input_' + str(i), shape=(10, 10)) for i in range(8)]
add_relu_1 = relay.add(inputs[0], inputs[1])
add_relu_1 = relay.nn.relu(add_relu_1)
add_relu_2 = relay.add(inputs[2], inputs[3])
add_relu_2 = relay.nn.relu(add_relu_2)
add_relu_3 = relay.add(inputs[4], inputs[5])
add_relu_3 = relay.nn.relu(add_relu_3)
add_relu_4 = relay.add(inputs[6], inputs[7])
add_relu_4 = relay.nn.relu(add_relu_4)
add = relay.add(add_relu_1, add_relu_2)
sub = relay.subtract(add_relu_3, add_relu_4)
out = relay.multiply(add, sub)
before_funcs['B'] = relay.Function(inputs, out)
sub = relay.subtract(add_relu_1, add_relu_2)
out = relay.multiply(add, sub)
before_funcs['A'] = relay.Function(inputs[:4], out)
return before_funcs

def after_A():
inputs = [relay.var('input_' + str(i), shape=(10, 10)) for i in range(4)]
x = relay.var('x')
y = relay.var('y')
add_relu_1 = relay.add(x, y)
add_relu_1 = relay.nn.relu(add_relu_1)
add_relu_1 = relay.Function([x, y], add_relu_1)
add_relu_1 = add_relu_1.set_attribute('Primitive', expr.IntImm('int32', 1))
add_relu_1 = add_relu_1.set_attribute('Composite', expr.StringImm('add_relu'))
add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
x1 = relay.var('x1')
y1 = relay.var('y1')
add_relu_2 = relay.add(x1, y1)
add_relu_2 = relay.nn.relu(add_relu_2)
add_relu_2 = relay.Function([x1, y1], add_relu_2)
add_relu_2 = add_relu_2.set_attribute('Primitive', expr.IntImm('int32', 1))
add_relu_2 = add_relu_2.set_attribute('Composite', expr.StringImm('add_relu'))
add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
x2 = relay.var('x2')
y2 = relay.var('y2')
add = relay.add(x2, y2)
sub = relay.subtract(x2, y2)
add_sub_mul = relay.multiply(add, sub)
add_sub_mul = relay.Function([x2, y2], add_sub_mul)
add_sub_mul = add_sub_mul.set_attribute('Primitive', expr.IntImm('int32', 1))
add_sub_mul = add_sub_mul.set_attribute('Composite', expr.StringImm('add_sub_mul'))
add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
return relay.Function(inputs, add_sub_mul_call)

def after_B():
inputs = [relay.var('input_' + str(i), shape=(10, 10)) for i in range(8)]
add_relu_calls = []
for i in range(4):
x = relay.var('x' + str(i))
y = relay.var('x' + str(i))
add_relu = relay.add(x, y)
add_relu = relay.nn.relu(add_relu)
add_relu = relay.Function([x, y], add_relu)
add_relu = add_relu.set_attribute('Primitive', expr.IntImm('int32', 1))
add_relu = add_relu.set_attribute('Composite', expr.StringImm('add_relu'))
add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
add_relu_calls.append(add_relu_call)

add = relay.add(add_relu_calls[0], add_relu_calls[1])
sub = relay.subtract(add_relu_calls[2], add_relu_calls[3])
out = relay.multiply(add, sub)
return relay.Function(inputs, out)

pattern_table = [
("add_sub_mul", make_add_sub_mul_pattern()),
("add_relu", make_add_relu_pattern())
]
# check case 'A'
result = run_opt_pass(before()['A'], relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)

# check case 'B'
result = run_opt_pass(before()['B'], relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_B(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)


if __name__ == "__main__":
test_simple_merge()
test_branch_merge()
test_multiple_patterns()
test_merge_order()
test_parallel_merge()
test_multiple_input_subgraphs()

0 comments on commit f156aab

Please sign in to comment.