diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc index 75d95f0378f1..3e3501a691e5 100644 --- a/src/relay/transforms/merge_composite.cc +++ b/src/relay/transforms/merge_composite.cc @@ -121,17 +121,12 @@ class MergeCompositeWrapper : public ExprMutator { for (const auto& arg : pattern->args) { Expr new_arg; if (arg->IsInstance()) { + new_arg = + ExtractPattern(Downcast(arg), Downcast(root->args[i]), var_map, call_map); // if we've already processed this call node, return the previous result - if (call_map->find(arg) != call_map->end()) { + if (call_map->find(arg) != call_map->end() && new_arg.defined()) { new_arg = (*call_map)[arg]; } else { - // fail if the root argument is not also a call node - if (!root->args[i]->IsInstance()) { - return Expr(); - } - // if it's a call node, recursively call this function - new_arg = - ExtractPattern(Downcast(arg), Downcast(root->args[i]), var_map, call_map); call_map->Set(arg, new_arg); } } else if (arg->IsInstance()) { diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index e3c8991c8ebc..317bb421477c 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -765,6 +765,44 @@ def _check_false(extract): assert result.body.op.attrs["Composite"] == "conv_bias_relu" +def test_diamond_not_merge(): + """ + The pattern on the left shouldn't match the structure on the right + + relu relu + | \ | \ + | clip | add + | / | | + mul | clip + | / + mul + """ + def get_pattern(): + conv = make_conv_bias_relu_pattern() + clip = relay.op.clip(conv, 0, 255) + return relay.op.multiply(conv, clip) + + def get_net(): + data = relay.var('data', shape=(1, 512, 28, 28)) + kernel = relay.var('kernel', shape=(256, 512, 1, 1)) + conv = relay.nn.conv2d(data, kernel, + kernel_size=(1, 1), + padding=(0, 0), + strides=(1, 1)) + bias = relay.nn.bias_add(conv, relay.var('bias', shape=(256,))) + relu = relay.nn.relu(bias) + add = relay.op.add(relu, relay.const(1.0)) + clip2 = relay.op.clip(add, 0, 255) + mul = relay.op.multiply(relu, clip2) + return relay.Function(relay.analysis.free_vars(mul), mul) + + pat_table = [("pat", get_pattern())] + net = get_net() + result = run_opt_pass(net, relay.transform.MergeComposite(pat_table)) + expected = run_opt_pass(net, relay.transform.InferType()) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + + if __name__ == "__main__": test_simple_merge() test_branch_merge() @@ -775,3 +813,4 @@ def _check_false(extract): test_reuse_call_merge() test_tuple_get_item_merge() test_pattern_with_check() + test_diamond_not_merge()