From 03891e80e7d3b0195c82a9b4bf16cbf9a8b60214 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Tue, 13 Nov 2018 19:50:48 +0800
Subject: [PATCH] Get channels info from type instead of attrs

Read the output-channel count from the checked type of conv2d's weight
argument instead of the `channels` attribute, and check kernel-size
compatibility through the weight shape, so parallel conv2d calls can be
combined even when these attributes are unset.

---
 src/relay/pass/combine_parallel_conv2d.cc          | 54 ++++++++---------
 src/relay/pass/pattern_util.h                      | 13 +++++
 .../test_pass_combine_parallel_conv2d.py           | 58 +++++++------------
 3 files changed, 63 insertions(+), 62 deletions(-)

diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc
index 60f643e436b2d..61dfc632086fd 100644
--- a/src/relay/pass/combine_parallel_conv2d.cc
+++ b/src/relay/pass/combine_parallel_conv2d.cc
@@ -48,9 +48,8 @@ std::tuple<Expr, IndexExpr> TransformWeight(std::vector<const CallNode*> convolu
   Array<Expr> weights;
   for (const CallNode* n : convolutions) {
     weights.push_back(n->args[1]);
-    auto channels = as_const_int(n->attrs.as<Conv2DAttrs>()->channels);
-    CHECK(channels);
-    num_filters += *channels;
+    auto channels = GetConv2DSuperChannelsDim(n);
+    num_filters += channels;
   }
   auto index = convolutions[0]->attrs.as<Conv2DAttrs>()->weight_layout.find('O');
   CHECK_NE(index, std::string::npos);
@@ -60,17 +59,26 @@
 
 // Two 2d convolutions can be combined if they have the same attributes or only have
 // different output channels.
-bool IsCompatibleConv2D(const Conv2DAttrs& a, const Conv2DAttrs& b) {
+bool IsCompatibleConv2D(const CallNode* a, const CallNode* b) {
   AttrsEqual eq;
-  return eq(a.strides, b.strides) &&
-      eq(a.padding, b.padding) &&
-      eq(a.dilation, b.dilation) &&
-      eq(a.groups, b.groups) &&
-      eq(a.kernel_size, b.kernel_size) &&
-      eq(a.data_layout, b.data_layout) &&
-      eq(a.weight_layout, b.weight_layout) &&
-      eq(a.out_dtype, b.out_dtype) &&
-      eq(a.out_layout, b.out_layout);
+  static const Layout kOIHW("OIHW");
+  auto attrs_a = a->attrs.as<Conv2DAttrs>();
+  auto attrs_b = b->attrs.as<Conv2DAttrs>();
+  auto tweight_a = a->args[1]->type_as<TensorTypeNode>();
+  auto tweight_b = b->args[1]->type_as<TensorTypeNode>();
+  auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->weight_layout, kOIHW);
+  auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->weight_layout, kOIHW);
+
+  return eq(attrs_a->strides, attrs_b->strides) &&
+      eq(attrs_a->padding, attrs_b->padding) &&
+      eq(attrs_a->dilation, attrs_b->dilation) &&
+      eq(attrs_a->groups, attrs_b->groups) &&
+      eq(attrs_a->data_layout, attrs_b->data_layout) &&
+      eq(attrs_a->weight_layout, attrs_b->weight_layout) &&
+      eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
+      eq(attrs_a->out_layout, attrs_b->out_layout) &&
+      eq(shape_a[2], shape_b[2]) &&
+      eq(shape_a[3], shape_b[3]);
 }
 
 Expr MakeCombinedConv2D(const Expr& data, const std::vector<const CallNode*>& convolutions) {
@@ -113,14 +121,11 @@ Expr CombineParallelConv2D(const Expr& expr) {
 
   for (size_t i = 0; i < children.size(); i++) {
     const CallNode* n = children[i];
-    auto args = n->attrs.as<Conv2DAttrs>();
 
     // assign a group id or create a new group for each conv2d
    auto it = std::find_if(groups.begin(), groups.end(),
                           [&](const std::vector<const CallNode*>& group) {
-        const CallNode* group_root = *(group.begin());
-        auto group_args = group_root->attrs.as<Conv2DAttrs>();
-        return IsCompatibleConv2D(*args, *group_args);
+        return IsCompatibleConv2D(n, group[0]);
       });
 
     if (it != groups.end()) {
@@ -134,22 +139,19 @@ Expr CombineParallelConv2D(const Expr& expr) {
   }
 
   for (const auto& convs : groups) {
-    if (convs.size() < 2) {
-      continue;
-    }
-    auto new_conv2d = MakeCombinedConv2D(data, convs);
+    if (convs.size() < 2) continue;
 
+    auto new_conv2d = MakeCombinedConv2D(data, convs);
     int64_t start = 0;
     // replace original conv2d with slice of output of the new conv2d
-    for (const auto& conv2d : convs) {
+    for (const CallNode* conv2d : convs) {
       auto params = conv2d->attrs.as<Conv2DAttrs>();
-      auto channels = as_const_int(params->channels);
-      CHECK(channels);
-      auto indices = MakeConstantArrayFromRange(Int(64), start, start + *channels);
+      auto channels = GetConv2DSuperChannelsDim(conv2d);
+      auto indices = MakeConstantArrayFromRange(Int(64), start, start + channels);
       auto channel_index = params->data_layout.find('C');
       CHECK_NE(channel_index, std::string::npos);
       auto take = MakeTake(new_conv2d, indices, channel_index);
-      start += *channels;
+      start += channels;
       subst_map[GetRef<Expr>(conv2d)] = take;
     }
   }
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
index fdcf8ddd0b6c1..cb91b8f285565 100644
--- a/src/relay/pass/pattern_util.h
+++ b/src/relay/pass/pattern_util.h
@@ -120,6 +120,19 @@ inline bool IsDepthwiseConv2D(const Call& call,
          is_const_int(wshape[1], 1);
 }
 
+/*!
+ * \brief Get super-dimension of output channels of conv2d
+ * \param call The conv2d call.
+ * \return Super-dimension size of output channels of conv2d.
+ */
+inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
+  auto param = call->attrs.as<Conv2DAttrs>();
+  auto tweight = call->args[1]->type_as<TensorTypeNode>();
+  auto index = param->weight_layout.find('O');
+  CHECK_NE(index, std::string::npos);
+  auto channels = as_const_int(tweight->shape[index]);
+  return *channels;
+}
 
 /*!
  * \brief Create a Constant with a scalar
diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py
index ce34f0caed89d..25c788bc9ff2a 100644
--- a/tests/python/relay/test_pass_combine_parallel_conv2d.py
+++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py
@@ -4,25 +4,13 @@
 
 def test_combine_parallel_conv2d():
     """Simple testcase."""
-    def before(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4):
+    def before(x, w1, w2, w3, w4):
         args = [x, w1, w2, w3, w4]
-        y1 = relay.nn.conv2d(x, w1,
-                             channels=channels1,
-                             kernel_size=(3, 3),
-                             padding=(1, 1))
-        y2 = relay.nn.conv2d(x, w2,
-                             channels=channels2,
-                             kernel_size=(3, 3),
-                             padding=(1, 1))
+        y1 = relay.nn.conv2d(x, w1)
+        y2 = relay.nn.conv2d(x, w2)
         # y3 is not foldable
-        y3 = relay.nn.conv2d(x, w3,
-                             channels=channels3,
-                             kernel_size=(1, 1),
-                             padding=(1, 1))
-        y4 = relay.nn.conv2d(x, w4,
-                             channels=channels4,
-                             kernel_size=(3, 3),
-                             padding=(1, 1))
+        y3 = relay.nn.conv2d(x, w3)
+        y4 = relay.nn.conv2d(x, w4)
         y = relay.Tuple((y1, y2, y3, y4))
         return relay.Function(args, y)
 
@@ -30,35 +18,33 @@ def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4):
         # use a fixed order of args so alpha equal check can pass
         args = [x, w1, w2, w3, w4]
         w = relay.concatenate((w1, w2, w4), axis=0)
-        y = relay.nn.conv2d(x, w,
-                            channels=channels1 + channels2 + channels4,
-                            kernel_size=(3, 3),
-                            padding=(1, 1))
+        y = relay.nn.conv2d(x, w, channels=channels1 + channels2 + channels4)
         y1 = relay.take(y, relay.const(np.arange(channels1, dtype='int64')), axis=1)
         y2 = relay.take(y, relay.const(np.arange(channels1, channels1 + channels2, dtype='int64')), axis=1)
-        y3 = relay.nn.conv2d(x, w3,
-                             channels=channels3,
-                             kernel_size=(1, 1),
-                             padding=(1, 1))
+        y3 = relay.nn.conv2d(x, w3)
         y4 = relay.take(y, relay.const(np.arange(channels1 + channels2,
                                                  channels1 + channels2 + channels4, dtype='int64')), axis=1)
         y = relay.Tuple((y1, y2, y3, y4))
         return relay.Function(args, y)
 
-    def check(channels1, channels2, channels3, channels4):
-        x = relay.var("x")
-        w1 = relay.var("w1")
-        w2 = relay.var("w2")
-        w3 = relay.var("w3")
-        w4 = relay.var("w4")
-
-        y_before = before(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
-        y = relay.ir_pass.combine_parallel_conv2d(y_before)
+    def check(x_shape, channels1, channels2, channels3, channels4):
+        x = relay.var("x", shape=x_shape)
+        in_c = x_shape[1]
+        w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
+        w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
+        w3 = relay.var("w3", shape=(channels3, in_c, 3, 3))
+        w4 = relay.var("w4", shape=(channels4, in_c, 1, 1))
+
+        y_before = before(x, w1, w2, w3, w4)
+        y = relay.ir_pass.infer_type(y_before)
+        y = relay.ir_pass.combine_parallel_conv2d(y)
+        y = relay.ir_pass.infer_type(y)
         y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
+        y_expected = relay.ir_pass.infer_type(y_expected)
         assert relay.ir_pass.alpha_equal(y, y_expected)
 
-    check(4, 4, 4, 4)
-    check(4, 8, 4, 7)
+    check((1, 4, 16, 16), 4, 4, 4, 4)
+    check((1, 4, 16, 16), 4, 8, 4, 7)
 
 
 if __name__ == "__main__":
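
Not part of the patch, for illustration only: the change derives a conv2d's output-channel count from the extent of the axis labelled 'O' in the weight type's shape (GetConv2DSuperChannelsDim), and compares kernel sizes via the OIHW-converted weight shapes (IsCompatibleConv2D). Below is a minimal standalone Python sketch of those two lookups; the function names and shapes are illustrative, not TVM API.

def convert_layout(shape, src_layout, dst_layout="OIHW"):
    """Reorder a weight shape from src_layout into dst_layout order."""
    return tuple(shape[src_layout.find(axis)] for axis in dst_layout)


def super_channels_dim(weight_shape, weight_layout):
    """Output-channel extent: the size of the axis labelled 'O'."""
    index = weight_layout.find('O')
    assert index != -1, "weight layout must contain an 'O' axis"
    return weight_shape[index]


def kernels_match(shape_a, layout_a, shape_b, layout_b):
    """Kernel compatibility: equal H and W once both shapes are in OIHW."""
    a = convert_layout(shape_a, layout_a)
    b = convert_layout(shape_b, layout_b)
    return a[2] == b[2] and a[3] == b[3]


if __name__ == "__main__":
    # An OIHW weight with 8 filters over 4 input channels and a 3x3 kernel.
    assert super_channels_dim((8, 4, 3, 3), "OIHW") == 8
    # The same weight stored in HWIO layout.
    assert super_channels_dim((3, 3, 4, 8), "HWIO") == 8
    # Kernels match even when the output-channel counts differ,
    assert kernels_match((8, 4, 3, 3), "OIHW", (16, 4, 3, 3), "OIHW")
    # but not when the spatial extents differ.
    assert not kernels_match((8, 4, 3, 3), "OIHW", (8, 4, 1, 1), "OIHW")

This is also why the updated test no longer passes channels or kernel_size to relay.nn.conv2d: after infer_type, both are recoverable from the weight shapes bound by relay.var.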