Skip to content

Commit

Permalink
[BYOC, MergeComposite] Add additional check before re-using the cache…
Browse files Browse the repository at this point in the history
…d match (apache#5552)

* Add additional check before re-using the cached match in merge composite

* clean up ExtractPattern calls
  • Loading branch information
masahi authored and trevor-m committed Jun 18, 2020
1 parent 3de5727 commit 6d54e3d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 8 deletions.
11 changes: 3 additions & 8 deletions src/relay/transforms/merge_composite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,12 @@ class MergeCompositeWrapper : public ExprMutator {
for (const auto& arg : pattern->args) {
Expr new_arg;
if (arg->IsInstance<CallNode>()) {
new_arg =
ExtractPattern(Downcast<Call>(arg), Downcast<Call>(root->args[i]), var_map, call_map);
// if we've already processed this call node, return the previous result
if (call_map->find(arg) != call_map->end()) {
if (call_map->find(arg) != call_map->end() && new_arg.defined()) {
new_arg = (*call_map)[arg];
} else {
// fail if the root argument is not also a call node
if (!root->args[i]->IsInstance<CallNode>()) {
return Expr();
}
// if it's a call node, recursively call this function
new_arg =
ExtractPattern(Downcast<Call>(arg), Downcast<Call>(root->args[i]), var_map, call_map);
call_map->Set(arg, new_arg);
}
} else if (arg->IsInstance<VarNode>()) {
Expand Down
39 changes: 39 additions & 0 deletions tests/python/relay/test_pass_merge_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,44 @@ def _check_false(extract):
assert result.body.op.attrs["Composite"] == "conv_bias_relu"


def test_diamond_not_merge():
"""
The pattern on the left shouldn't match the structure on the right
relu relu
| \ | \
| clip | add
| / | |
mul | clip
| /
mul
"""
def get_pattern():
conv = make_conv_bias_relu_pattern()
clip = relay.op.clip(conv, 0, 255)
return relay.op.multiply(conv, clip)

def get_net():
data = relay.var('data', shape=(1, 512, 28, 28))
kernel = relay.var('kernel', shape=(256, 512, 1, 1))
conv = relay.nn.conv2d(data, kernel,
kernel_size=(1, 1),
padding=(0, 0),
strides=(1, 1))
bias = relay.nn.bias_add(conv, relay.var('bias', shape=(256,)))
relu = relay.nn.relu(bias)
add = relay.op.add(relu, relay.const(1.0))
clip2 = relay.op.clip(add, 0, 255)
mul = relay.op.multiply(relu, clip2)
return relay.Function(relay.analysis.free_vars(mul), mul)

pat_table = [("pat", get_pattern())]
net = get_net()
result = run_opt_pass(net, relay.transform.MergeComposite(pat_table))
expected = run_opt_pass(net, relay.transform.InferType())
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)


if __name__ == "__main__":
test_simple_merge()
test_branch_merge()
Expand All @@ -775,3 +813,4 @@ def _check_false(extract):
test_reuse_call_merge()
test_tuple_get_item_merge()
test_pattern_with_check()
test_diamond_not_merge()

0 comments on commit 6d54e3d

Please sign in to comment.