From 04496d326553419f6b2970e88b80c60c4dbb174d Mon Sep 17 00:00:00 2001 From: mbaret <55580676+mbaret@users.noreply.github.com> Date: Fri, 12 Jun 2020 16:43:33 +0100 Subject: [PATCH] [BYOC][FIX] Infer types in MergeComposite (#5766) If InferType isn't run between partitioning passes, function calls are inserted which don't have a type. This can result in failures for patterns which want to check types. This works around it simply by running InferType after every partitioning. Change-Id: Ie0887f0564a41eb0913bfe42a362e8effe9681b9 --- src/relay/transforms/merge_composite.cc | 13 +++- .../python/relay/test_pass_merge_composite.py | 64 ++++++++++++++----- 2 files changed, 59 insertions(+), 18 deletions(-) diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc index 34ebca900ada..324b2cb3a1c4 100644 --- a/src/relay/transforms/merge_composite.cc +++ b/src/relay/transforms/merge_composite.cc @@ -36,17 +36,24 @@ namespace tvm { namespace relay { namespace merge_composite { +Function InferType(const Function& expr) { + auto mod = IRModule::FromExpr(expr); + mod = transform::InferType()(mod); + return Downcast<Function>(mod->Lookup("main")); +} + Expr MergeComposite(const Function& func, const Array<runtime::String>& pattern_names, const Array<DFPattern>& patterns, const std::vector<runtime::PackedFunc>& checks) { CHECK_EQ(pattern_names.size(), patterns.size()); - Expr merged_expr = func->body; + Function merged_func = func; // merge the patterns one-by-one in order for (size_t i = 0; i < patterns.size(); i++) { Map<String, ObjectRef> attrs; attrs.Set("Composite", pattern_names[i]); - merged_expr = PartitionPattern(patterns[i], merged_expr, attrs, checks[i]); + merged_func = Downcast<Function>(PartitionPattern(patterns[i], merged_func, attrs, checks[i])); + merged_func = InferType(merged_func); } - return Function(func->params, merged_expr, func->ret_type, func->type_params, func->attrs); + return std::move(merged_func); } } // namespace merge_composite diff --git a/tests/python/relay/test_pass_merge_composite.py 
b/tests/python/relay/test_pass_merge_composite.py index 12679c45c7c6..f2d615e9046a 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -916,31 +916,63 @@ def before(): x = relay.var('x', shape=(1, 10, 10, 10)) w = relay.var('w', shape=(10, 10, 3, 3)) b = relay.var('b', shape=(8,)) - conv = relay.nn.conv2d(x, + add = relay.op.add(x, x) + relu = relay.nn.relu(add) + conv = relay.nn.conv2d(relu, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC") bias = relay.nn.bias_add(conv, b) - relu = relay.nn.relu(bias) - return relay.Function([x, w, b], relu) + relu2 = relay.nn.relu(bias) + return run_opt_pass(relay.Function([x, w, b], relu2), relay.transform.InferType()) - def expected(): - x = relay.var('x') - w = relay.var('w') - b = relay.var('b') - conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC") + def expected_false(): + x = relay.var('x', shape=(1, 10, 10, 10)) + w = relay.var('w', shape=(10, 10, 3, 3)) + b = relay.var('b', shape=(8, )) + + x0 = relay.var('x') + y0 = relay.var('y') + + add = relay.op.add(y0, y0) + relu = relay.nn.relu(add) + func = relay.Function([x0, y0], relu) + func = func.with_attr("PartitionedFromPattern", "add_nn.relu_") + func = func.with_attr("Composite", "add_relu") + call = relay.Call(func, [x, x]) + + conv = relay.nn.conv2d(call, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC") bias = relay.nn.bias_add(conv, b) - relu = relay.nn.relu(bias) - func = relay.Function([x, w, b], relu) - func = func.with_attr("Composite", "conv_bias_relu") - func = func.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_") + relu2 = relay.nn.relu(bias) + return relay.Function([x, w, b], relu2) + def expected_true(): x = relay.var('x', shape=(1, 10, 10, 10)) w = relay.var('w', shape=(10, 10, 3, 3)) b = relay.var('b', shape=(8, )) - return relay.Function([x, w, b], func(x, w, b)) + x0 = relay.var('x') + y0 = 
relay.var('y') + + add = relay.op.add(y0, y0) + relu = relay.nn.relu(add) + func = relay.Function([x0, y0], relu) + func = func.with_attr("PartitionedFromPattern", "add_nn.relu_") + func = func.with_attr("Composite", "add_relu") + call = relay.Call(func, [x, x]) + + x2 = relay.var('x') + w1 = relay.var('w') + b1 = relay.var('b') + conv = relay.nn.conv2d(x2, w1, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC") + bias = relay.nn.bias_add(conv, b1) + relu2 = relay.nn.relu(bias) + func = relay.Function([x2, w1, b1], relu2) + func = func.with_attr("Composite", "conv_bias_relu") + func = func.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_") + call = relay.Call(func, [call, w, b]) + return relay.Function([x, w, b], call) def _check_type_true(extract): conv = extract.args[0].args[0] @@ -953,14 +985,16 @@ def _check_type_false(extract): return bool(typ.shape[0] != 1) pattern_table_false = [ + ("add_relu", make_add_relu_pattern()), ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_false) ] - check_result(pattern_table_false, before(), before()) + check_result(pattern_table_false, before(), expected_false()) pattern_table_true = [ + ("add_relu", make_add_relu_pattern()), ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_true) ] - check_result(pattern_table_true, before(), expected()) + check_result(pattern_table_true, before(), expected_true()) if __name__ == "__main__":