Skip to content

Commit

Permalink
Merge composite additional test
Browse files Browse the repository at this point in the history
Change-Id: I9bc7d6053c575e9468ac5abc31214c6ad8507e46
  • Loading branch information
lhutton1 authored and mbaret committed Jan 24, 2020
1 parent d0d7645 commit ccca643
Showing 1 changed file with 163 additions and 5 deletions.
168 changes: 163 additions & 5 deletions tests/python/relay/test_pass_merge_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,48 @@
from tvm import relay
from tvm.relay.testing import run_opt_pass

"""
The merge composite pass is designed to merge multiple relay operators, that
match a given pattern, and combine them into a single relay function.
For example suppose we have the graph:
conv2d
| (merge composite pass)
bias_add ====> conv2d_bias_relu
| (our target)
relu
Our Relay IR before the pass:
fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
%bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
%0 = nn.conv2d(%data, %kernel, kernel_size=[1, 1])
/* ty=Tensor[(1, 256, 28, 28), float32] */;
%1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
}
Our Relay IR after the pass:
fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
%bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
%2 = fn (%x: Tensor[(1, 512, 28, 28), float32], %y: Tensor[(256, 512, 1, 1), float32],
%z: Tensor[(256), float32], Primitive=1, Composite="conv2d_bias_relu") ->
Tensor[(1, 256, 28, 28), float32] {
%0 = nn.conv2d(%x, %y, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
%1 = nn.bias_add(%0, %z) /* ty=Tensor[(1, 256, 28, 28), float32] */;
nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
};
%2(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */
}
As you can see in the second relay example, the pattern we specified has been wrapped
in a function. The function is then called, producing the same result as the first relay
example.
One convenient use for this pass is to offload multiple operators to a single external
codegen function.
"""


def make_add_sub_mul_pattern():
"""Create a pattern to match the following graph.
Expand All @@ -40,7 +82,7 @@ def make_add_relu_pattern():
add
|
ReLu
relu
"""
x = relay.var('x')
y = relay.var('y')
Expand All @@ -49,6 +91,24 @@ def make_add_relu_pattern():
return r


def make_conv_bias_relu_pattern():
"""Create a pattern to match the following graph.
conv2d
|
bias_add
|
relu
"""
x = relay.var('x')
y = relay.var('y')
z = relay.var('z')
conv_node = relay.nn.conv2d(x, y)
bias_node = relay.nn.bias_add(conv_node, z)
r = relay.nn.relu(bias_node)
return r


def test_simple_merge():
"""Test composite function is correctly produced from simple graph.
Expand All @@ -59,11 +119,11 @@ def test_simple_merge():
\ / a b
add ====> \ /
| add_relu
ReLu
relu
"""
pattern_table = {
"add_sub_mul": make_add_relu_pattern()
"add_relu": make_add_relu_pattern()
}

def before():
Expand All @@ -89,6 +149,7 @@ def expected():
return relay.Function([a, b], r)

result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)

Expand All @@ -109,12 +170,12 @@ def test_branch_merge():
c / c | ====> add_sub_mul
\/ \/ |
add sub |
\ / ReLu
\ / relu
\ /
mul
|
|
ReLu
relu
"""

pattern_table = {
Expand Down Expand Up @@ -154,5 +215,102 @@ def expected():
return relay.Function([a, b, c], r)

result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)


def test_multiple_patterns():
"""Test different patterns are merged correctly in the graph.
We would expect the pattern `make_conv_bias_relu_pattern` to be merged
into a single op `conv_bias_relu`. We would also expect `make_add_relu_pattern`
to be merged into a single op `add_relu`.
data kernel
\ /
\ /
conv2d data kernel bias
| \ | /
| bias conv2d_bias_relu
| / |
bias_add ====> | a
| | /
relu a add_relu
\ / |
add | b
| | /
relu b mul
| /
mul
"""
pattern_table = {
"conv2d_bias_relu": make_conv_bias_relu_pattern(),
"add_relu": make_add_relu_pattern()
}

def before():
data = relay.var('data', shape=(1, 512, 28, 28))
kernel = relay.var('kernel', shape=(256, 512, 1, 1))
bias = relay.var('bias', shape=(256,))
a = relay.var('a', shape=(1, 256, 28, 28))
b = relay.var('b', shape=(1, 256, 28, 28))

conv_node = relay.nn.conv2d(data,
kernel,
kernel_size=(1, 1),
padding=(0, 0),
strides=(1, 1))

bias_node = relay.nn.bias_add(conv_node, bias)
relu_node = relay.nn.relu(bias_node)
add_node = relay.add(relu_node, a)
relu_node_2 = relay.nn.relu(add_node)
r = relay.multiply(relu_node_2, b)
return relay.Function([data, kernel, bias], r)

def expected():
data = relay.var('data', shape=(1, 512, 28, 28))
kernel = relay.var('kernel', shape=(256, 512, 1, 1))
bias = relay.var('bias', shape=(256,))
a = relay.var('a', shape=(1, 256, 28, 28))
b = relay.var('b', shape=(1, 256, 28, 28))

# conv_bias_relu function
in_1 = relay.var('in_1', shape=(1, 512, 28, 28))
in_2 = relay.var('in_2', shape=(256, 512, 1, 1))
in_3 = relay.var('in_3', shape=(256,))

conv_node = relay.nn.conv2d(in_1,
in_2,
kernel_size=(1, 1),
padding=(0, 0),
strides=(1, 1))

bias_node = relay.nn.bias_add(conv_node, in_3)
r = relay.nn.relu(bias_node)
conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)

# add_relu function
in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
in_5 = relay.var('in_5', shape=(1, 256, 28, 28))
add_node = relay.add(in_4, in_5)
r = relay.nn.relu(add_node)
add_relu = relay.Function([in_4, in_5], r)

# merged function
conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
add_relu_1 = relay.Call(add_relu, [conv_bias_add_relu_1, a])
r = relay.multiply(add_relu_1, b)
return relay.Function([data, kernel, bias, a, b], r)

result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)


if __name__ == "__main__":
test_simple_merge()
test_branch_merge()
test_multiple_patterns()

0 comments on commit ccca643

Please sign in to comment.