Get channels info from type instead of attrs
vinx13 committed Nov 13, 2018
1 parent f4b859f commit 03891e8
Showing 3 changed files with 63 additions and 62 deletions.
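This commit makes the parallel-conv2d combiner read the number of output channels from the weight tensor's inferred type (via the 'O' axis of the weight layout) instead of from the channels attribute, which may be left unset. Below is a minimal pure-Python sketch of the idea, using a hypothetical helper name; the actual implementation is the C++ GetConv2DSuperChannelsDim helper added to pattern_util.h further down:

# Hypothetical illustration: locate the output-channel ('O') axis in the
# weight layout string and read the channel count from the weight shape.
def conv2d_out_channels(weight_shape, weight_layout="OIHW"):
    index = weight_layout.find("O")
    assert index != -1, "weight layout must name an output-channel axis"
    return weight_shape[index]

print(conv2d_out_channels((8, 4, 3, 3), "OIHW"))  # 8
print(conv2d_out_channels((3, 3, 4, 8), "HWIO"))  # 8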
54 changes: 28 additions & 26 deletions src/relay/pass/combine_parallel_conv2d.cc
@@ -48,9 +48,8 @@ std::tuple<Expr, IndexExpr> TransformWeight(std::vector<const CallNode*> convolu
Array<Expr> weights;
for (const CallNode* n : convolutions) {
weights.push_back(n->args[1]);
auto channels = as_const_int(n->attrs.as<Conv2DAttrs>()->channels);
CHECK(channels);
num_filters += *channels;
auto channels = GetConv2DSuperChannelsDim(n);
num_filters += channels;
}
auto index = convolutions[0]->attrs.as<Conv2DAttrs>()->weight_layout.find('O');
CHECK_NE(index, std::string::npos);
@@ -60,17 +59,26 @@ std::tuple<Expr, IndexExpr> TransformWeight(std::vector<const CallNode*> convolu

// Two 2d convolutions can be combined if they have the same attributes or only have
// different output channels.
bool IsCompatibleConv2D(const Conv2DAttrs& a, const Conv2DAttrs& b) {
bool IsCompatibleConv2D(const CallNode* a, const CallNode* b) {
AttrsEqual eq;
return eq(a.strides, b.strides) &&
eq(a.padding, b.padding) &&
eq(a.dilation, b.dilation) &&
eq(a.groups, b.groups) &&
eq(a.kernel_size, b.kernel_size) &&
eq(a.data_layout, b.data_layout) &&
eq(a.weight_layout, b.weight_layout) &&
eq(a.out_dtype, b.out_dtype) &&
eq(a.out_layout, b.out_layout);
static const Layout kOIHW("OIHW");
auto attrs_a = a->attrs.as<Conv2DAttrs>();
auto attrs_b = b->attrs.as<Conv2DAttrs>();
auto tweight_a = a->args[1]->type_as<TensorTypeNode>();
auto tweight_b = b->args[1]->type_as<TensorTypeNode>();
auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->weight_layout, kOIHW);
auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->weight_layout, kOIHW);

return eq(attrs_a->strides, attrs_b->strides) &&
eq(attrs_a->padding, attrs_b->padding) &&
eq(attrs_a->dilation, attrs_b->dilation) &&
eq(attrs_a->groups, attrs_b->groups) &&
eq(attrs_a->data_layout, attrs_b->data_layout) &&
eq(attrs_a->weight_layout, attrs_b->weight_layout) &&
eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
eq(attrs_a->out_layout, attrs_b->out_layout) &&
eq(shape_a[2], shape_b[2]) &&
eq(shape_a[3], shape_b[3]);
}
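Since kernel_size is no longer compared as an attribute, compatibility now also requires equal kernel spatial extents, read from the weight types after normalising both shapes to OIHW (indices 2 and 3 above). A hypothetical Python sketch of that rule; the helper names are illustrative, not part of the patch:

def to_oihw(shape, layout):
    # Reorder a weight shape into OIHW order using the layout string.
    return tuple(shape[layout.index(c)] for c in "OIHW")

def kernels_match(shape_a, layout_a, shape_b, layout_b):
    a, b = to_oihw(shape_a, layout_a), to_oihw(shape_b, layout_b)
    return a[2:] == b[2:]  # same kernel H and W; output channels may differ

print(kernels_match((8, 4, 3, 3), "OIHW", (16, 4, 3, 3), "OIHW"))  # True
print(kernels_match((8, 4, 1, 1), "OIHW", (16, 4, 3, 3), "OIHW"))  # False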

Expr MakeCombinedConv2D(const Expr& data, const std::vector<const CallNode*>& convolutions) {
@@ -113,14 +121,11 @@ Expr CombineParallelConv2D(const Expr& expr) {

for (size_t i = 0; i < children.size(); i++) {
const CallNode* n = children[i];
auto args = n->attrs.as<Conv2DAttrs>();

// assign a group id or create a new group for each conv2d
auto it = std::find_if(groups.begin(), groups.end(),
[&](const std::vector<const CallNode*>& group) {
const CallNode* group_root = *(group.begin());
auto group_args = group_root->attrs.as<Conv2DAttrs>();
return IsCompatibleConv2D(*args, *group_args);
return IsCompatibleConv2D(n, group[0]);
});

if (it != groups.end()) {
@@ -134,22 +139,19 @@ Expr CombineParallelConv2D(const Expr& expr) {
}

for (const auto& convs : groups) {
if (convs.size() < 2) {
continue;
}
auto new_conv2d = MakeCombinedConv2D(data, convs);
if (convs.size() < 2) continue;

auto new_conv2d = MakeCombinedConv2D(data, convs);
int64_t start = 0;
// replace original conv2d with slice of output of the new conv2d
for (const auto& conv2d : convs) {
for (const CallNode* conv2d : convs) {
auto params = conv2d->attrs.as<Conv2DAttrs>();
auto channels = as_const_int(params->channels);
CHECK(channels);
auto indices = MakeConstantArrayFromRange(Int(64), start, start + *channels);
auto channels = GetConv2DSuperChannelsDim(conv2d);
auto indices = MakeConstantArrayFromRange(Int(64), start, start + channels);
auto channel_index = params->data_layout.find('C');
CHECK_NE(channel_index, std::string::npos);
auto take = MakeTake(new_conv2d, indices, channel_index);
start += *channels;
start += channels;
subst_map[GetRef<Call>(conv2d)] = take;
}
}
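Each original conv2d in a group is then replaced by a channel slice of the combined conv2d's output, taken along the 'C' axis of the data layout. A rough numpy analogue of that slicing step, with illustrative shapes only:

import numpy as np

combined = np.random.rand(1, 12, 16, 16)           # NCHW output of the fused conv: 8 + 4 channels
y1 = np.take(combined, np.arange(0, 8), axis=1)    # first branch: channels 0..7
y2 = np.take(combined, np.arange(8, 12), axis=1)   # second branch: channels 8..11
print(y1.shape, y2.shape)                          # (1, 8, 16, 16) (1, 4, 16, 16)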
13 changes: 13 additions & 0 deletions src/relay/pass/pattern_util.h
@@ -120,6 +120,19 @@ inline bool IsDepthwiseConv2D(const Call& call,
is_const_int(wshape[1], 1);
}

/*!
* \brief Get super-dimension of output channels of conv2d
* \param call The conv2d call.
* \return Super-dimension size of output channels of conv2d.
*/
inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
auto param = call->attrs.as<Conv2DAttrs>();
auto tweight = call->args[1]->type_as<TensorTypeNode>();
auto index = param->weight_layout.find('O');
CHECK_NE(index, std::string::npos);
auto channels = as_const_int(tweight->shape[index]);
return *channels;
}

/*!
* \brief Create a Constant with a scalar
58 changes: 22 additions & 36 deletions tests/python/relay/test_pass_combine_parallel_conv2d.py
@@ -4,61 +4,47 @@

def test_combine_parallel_conv2d():
"""Simple testcase."""
def before(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4):
def before(x, w1, w2, w3, w4):
args = [x, w1, w2, w3, w4]
y1 = relay.nn.conv2d(x, w1,
channels=channels1,
kernel_size=(3, 3),
padding=(1, 1))
y2 = relay.nn.conv2d(x, w2,
channels=channels2,
kernel_size=(3, 3),
padding=(1, 1))
y1 = relay.nn.conv2d(x, w1)
y2 = relay.nn.conv2d(x, w2)
# y3 is not foldable
y3 = relay.nn.conv2d(x, w3,
channels=channels3,
kernel_size=(1, 1),
padding=(1, 1))
y4 = relay.nn.conv2d(x, w4,
channels=channels4,
kernel_size=(3, 3),
padding=(1, 1))
y3 = relay.nn.conv2d(x, w3)
y4 = relay.nn.conv2d(x, w4)
y = relay.Tuple((y1, y2, y3, y4))
return relay.Function(args, y)

def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4):
# use a fixed order of args so alpha equal check can pass
args = [x, w1, w2, w3, w4]
w = relay.concatenate((w1, w2, w4), axis=0)
y = relay.nn.conv2d(x, w,
channels=channels1 + channels2 + channels4,
kernel_size=(3, 3),
padding=(1, 1))
y = relay.nn.conv2d(x, w, channels=channels1 + channels2 + channels4)
y1 = relay.take(y, relay.const(np.arange(channels1, dtype='int64')), axis=1)
y2 = relay.take(y, relay.const(np.arange(channels1, channels1 + channels2, dtype='int64')), axis=1)
y3 = relay.nn.conv2d(x, w3,
channels=channels3,
kernel_size=(1, 1),
padding=(1, 1))
y3 = relay.nn.conv2d(x, w3)
y4 = relay.take(y, relay.const(np.arange(channels1 + channels2,
channels1 + channels2 + channels4, dtype='int64')), axis=1)
y = relay.Tuple((y1, y2, y3, y4))
return relay.Function(args, y)

def check(channels1, channels2, channels3, channels4):
x = relay.var("x")
w1 = relay.var("w1")
w2 = relay.var("w2")
w3 = relay.var("w3")
w4 = relay.var("w4")

y_before = before(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
y = relay.ir_pass.combine_parallel_conv2d(y_before)
def check(x_shape, channels1, channels2, channels3, channels4):
x = relay.var("x", shape=x_shape)
in_c = x_shape[1]
w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
w3 = relay.var("w3", shape=(channels3, in_c, 3, 3))
w4 = relay.var("w4", shape=(channels4, in_c, 1, 1))

y_before = before(x, w1, w2, w3, w4)
y = relay.ir_pass.infer_type(y_before)
y = relay.ir_pass.combine_parallel_conv2d(y)
y = relay.ir_pass.infer_type(y)
y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
y_expected = relay.ir_pass.infer_type(y_expected)
assert relay.ir_pass.alpha_equal(y, y_expected)

check(4, 4, 4, 4)
check(4, 8, 4, 7)
check((1, 4, 16, 16), 4, 4, 4, 4)
check((1, 4, 16, 16), 4, 8, 4, 7)


if __name__ == "__main__":
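Because channel counts are now taken from the weight types, the updated test declares explicit weight shapes and runs type inference before invoking the pass. A hypothetical minimal driver in the same style, assuming the 2018-era relay.ir_pass API used above:

from tvm import relay

x = relay.var("x", shape=(1, 4, 16, 16))
w1 = relay.var("w1", shape=(8, 4, 3, 3))
w2 = relay.var("w2", shape=(4, 4, 3, 3))
y = relay.Tuple((relay.nn.conv2d(x, w1, padding=(1, 1)),
                 relay.nn.conv2d(x, w2, padding=(1, 1))))
f = relay.Function([x, w1, w2], y)
f = relay.ir_pass.infer_type(f)               # weight types must be known first
f = relay.ir_pass.combine_parallel_conv2d(f)  # fuses the two branches into one conv2d
print(relay.ir_pass.infer_type(f))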
