Skip to content

Commit

Permalink
Replace take with strided_slice
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Nov 14, 2018
1 parent a273b0e commit d0599e8
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 28 deletions.
24 changes: 16 additions & 8 deletions src/relay/pass/combine_parallel_conv2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,25 @@ Expr CombineParallelConv2D(const Expr& expr) {
if (convs.size() < 2) continue;

auto new_conv2d = MakeCombinedConv2D(data, convs);
int64_t start = 0;
int64_t index = 0;
// replace original conv2d with slice of output of the new conv2d
for (const CallNode* conv2d : convs) {
auto params = conv2d->attrs.as<Conv2DAttrs>();
auto channels = GetConv2DSuperChannelsDim(conv2d);
auto indices = MakeConstantArrayFromRange(Int(64), start, start + channels);
auto channel_index = params->data_layout.find('C');
CHECK_NE(channel_index, std::string::npos);
auto take = MakeTake(new_conv2d, indices, channel_index);
start += channels;
subst_map[GetRef<Call>(conv2d)] = take;
int64_t channels = GetConv2DSuperChannelsDim(conv2d);
size_t channel_pos = params->data_layout.find('C');
CHECK_NE(channel_pos, std::string::npos);
Array<Integer> begin;
Array<Integer> end;
for (int64_t i = 0; i < static_cast<int64_t>(channel_pos); i++) {
begin.push_back(i);
end.push_back(NullValue<Integer>());
}
begin.push_back(index);
index += channels;
end.push_back(index);
auto slice = MakeStridedSlice(new_conv2d, std::move(begin), std::move(end),
Array<Integer>{});
subst_map[GetRef<Call>(conv2d)] = slice;
}
}
}
Expand Down
16 changes: 1 addition & 15 deletions src/relay/pass/pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,20 +150,6 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
return ConstantNode::make(arr);
}

/*!
 * \brief Build a 1-D constant filled with the arithmetic sequence
 *        start, start + step, start + 2 * step, ... stopping before end.
 *
 * \param dtype Element type of the resulting constant; its bit width must
 *              equal the bit width of T.
 * \param start First value of the sequence.
 * \param end   Exclusive bound of the sequence.
 * \param step  Non-zero increment; a negative step produces a descending sequence.
 * \return A Constant wrapping a CPU NDArray that holds the sequence.
 *
 * \note The element count is ceil(|end - start| / |step|). Sizing the buffer
 *       with truncating division under-allocates whenever step does not
 *       evenly divide the span (e.g. 0..5 step 2 yields 0, 2, 4), and the
 *       fill loop would then write past the end of the array.
 */
template<typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
inline Constant MakeConstantArrayFromRange(DataType dtype, T start, T end, T step = 1) {
  CHECK_EQ(sizeof(T) * 8, dtype.bits()) << "data type mismatch";
  CHECK(step);
  CHECK_GE((end - start) / step, 0);

  // Count exactly the elements the fill loop below will write.
  const bool ascending = step > 0;
  const bool non_empty = ascending ? (start < end) : (start > end);
  int64_t num_elems = 0;
  if (non_empty) {
    const T span = ascending ? static_cast<T>(end - start) : static_cast<T>(start - end);
    const T stride = ascending ? step : static_cast<T>(0 - step);
    // Ceiling division: a trailing partial stride still contributes one element.
    num_elems = static_cast<int64_t>(span / stride) + ((span % stride != 0) ? 1 : 0);
  }

  runtime::NDArray arr = runtime::NDArray::Empty({num_elems},
                                                 Type2TVMType(dtype), {kDLCPU, 0});
  T* data = static_cast<T*>(arr->data);
  for (int64_t i = 0; i < num_elems; ++i) {
    data[i] = start;
    start += step;
  }
  return ConstantNode::make(arr);
}


inline Expr Negative(Expr x) {
static const Op& op = Op::Get("negative");
Expand Down Expand Up @@ -202,7 +188,7 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) {

Expr MakeConcatenate(Expr data, int axis);

Expr MakeTake(Expr data, Expr indices, Integer axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);

} // namespace relay
} // namespace tvm
Expand Down
10 changes: 5 additions & 5 deletions tests/python/relay/test_pass_combine_parallel_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def before(x, w1, w2, w3, w4):
args = [x, w1, w2, w3, w4]
y1 = relay.nn.conv2d(x, w1)
y2 = relay.nn.conv2d(x, w2)
# y3 is not foldable
# y3 cannot be combined
y3 = relay.nn.conv2d(x, w3)
y4 = relay.nn.conv2d(x, w4)
y = relay.Tuple((y1, y2, y3, y4))
Expand All @@ -19,11 +19,11 @@ def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4):
args = [x, w1, w2, w3, w4]
w = relay.concatenate((w1, w2, w4), axis=0)
y = relay.nn.conv2d(x, w, channels=channels1 + channels2 + channels4)
y1 = relay.take(y, relay.const(np.arange(channels1, dtype='int64')), axis=1)
y2 = relay.take(y, relay.const(np.arange(channels1, channels1 + channels2, dtype='int64')), axis=1)
y1 = relay.strided_slice(y, [0, 0], [None, channels1])
y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2])
y3 = relay.nn.conv2d(x, w3)
y4 = relay.take(y, relay.const(np.arange(channels1 + channels2,
channels1 + channels2 + channels4, dtype='int64')), axis=1)
y4 = relay.strided_slice(y, [0, channels1 + channels2],
[None, channels1 + channels2 + channels4])
y = relay.Tuple((y1, y2, y3, y4))
return relay.Function(args, y)

Expand Down

0 comments on commit d0599e8

Please sign in to comment.