fix InferCorrectLayout
yongwww committed Jan 8, 2020
1 parent daa3c19 commit 01aa0cb
Showing 3 changed files with 30 additions and 8 deletions.
25 changes: 22 additions & 3 deletions src/relay/op/tensor/transform.cc
@@ -1996,7 +1996,7 @@ bool StridedSliceRel(const Array<Type>& types,
       oshape[i] = make_const(dshape[i].dtype(), (slice_range + step - 1) / step);
     }
   } else {
-    for (size_t i = 0; i < num_axis; ++i) {
+    for (int64_t i = 0; i < num_axis; ++i) {
       oshape[i] = Any::make();
     }
   }
@@ -2017,7 +2017,7 @@ Array<Array<Layout> > StridedSliceInferCorrectLayout(

   auto layout = old_in_layouts[0];
   if (layout.defined() && new_in_layouts.defined()) {
-    CHECK_EQ(new_in_layouts.size(), 1);
+    CHECK_GE(new_in_layouts.size(), 1);
     auto new_layout = new_in_layouts[0];
     auto shape = old_in_shapes[0];
@@ -2074,9 +2074,28 @@ Array<Array<Layout> > StridedSliceInferCorrectLayout(
         new_end.push_back(tvm::Integer(ed / factor));
       }
     }
+
     layout = new_layout;
+
+    DLContext ctx;
+    ctx.device_type = kDLCPU;
+    ctx.device_id = 0;
+    auto begin_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())},
+                                                 DataType::Int(64), ctx);
+    auto end_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())},
+                                               DataType::Int(64), ctx);
+    auto strides_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())},
+                                                   DataType::Int(64), ctx);
+    int64_t* begin_data = static_cast<int64_t*>(begin_ndarray->data);
+    int64_t* end_data = static_cast<int64_t*>(end_ndarray->data);
+    for (size_t i = 0; i < new_begin.size(); ++i) {
+      begin_data[i] = new_begin[i];
+      end_data[i] = new_end[i];
+    }
+    params->begin = ConstantNode::make(begin_ndarray);
+    params->end = ConstantNode::make(end_ndarray);
   }
-  return {{layout}, {layout}};
+  return {{layout, Layout("C"), Layout("C"), Layout("C")}, {layout}};
 }

 inline Tensor DynamicStridedSlice(const tvm::Tensor& input,
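
With begin, end, and strides now passed to strided_slice as tensor inputs rather than attributes, StridedSliceInferCorrectLayout must report a layout for all four inputs, which is why the return value gains three rank-1 Layout("C") entries, and it writes the layout-adjusted bounds back into params->begin and params->end as constants. Below is a minimal sketch of how this path is reached from Python, mirroring the updated test further down and assuming the Relay API as of this commit.

# Sketch only: assumes the TVM Python API at this commit, where
# strided_slice takes begin/end as Relay constants instead of Python lists.
from tvm import relay
from tvm.relay import analysis

x = relay.var("x", shape=(1, 32, 28, 28))
weight = relay.var("weight", shape=(32, 32, 3, 3))
y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
# begin and end are Relay constants, so the call site has four inputs
# (data, begin, end, strides) for InferCorrectLayout to assign layouts to.
y = relay.strided_slice(y, begin=relay.const([0, 16], "int32"),
                        end=relay.const([1, 32], "int32"))
func = relay.Function(analysis.free_vars(y), y)
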
9 changes: 6 additions & 3 deletions src/relay/pass/combine_parallel_conv2d.cc
@@ -191,9 +191,12 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
     DLContext ctx;
     ctx.device_type = kDLCPU;
     ctx.device_id = 0;
-    auto begin_ndarray = runtime::NDArray::Empty({1}, DataType::Int(64), ctx);
-    auto end_ndarray = runtime::NDArray::Empty({1}, DataType::Int(64), ctx);
-    auto strides_ndarray = runtime::NDArray::Empty({1}, DataType::Int(64), ctx);
+    auto begin_ndarray = runtime::NDArray::Empty({int64_t(begin.size())},
+                                                 DataType::Int(64), ctx);
+    auto end_ndarray = runtime::NDArray::Empty({int64_t(begin.size())},
+                                               DataType::Int(64), ctx);
+    auto strides_ndarray = runtime::NDArray::Empty({int64_t(begin.size())},
+                                                   DataType::Int(64), ctx);
     int64_t* begin_data = static_cast<int64_t*>(begin_ndarray->data);
     int64_t* end_data = static_cast<int64_t*>(end_ndarray->data);

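
The change above sizes the three auxiliary NDArrays by begin.size() instead of the hard-coded {1}, so a multi-axis slice gets one int64 element per axis. A rough Python analogue of that packing, with hypothetical bounds and tvm.nd.array standing in for runtime::NDArray::Empty:

# Sketch only: hypothetical per-axis bounds, packed the way the C++ above
# fills begin_ndarray and end_ndarray (one int64 per sliced axis).
import numpy as np
import tvm

begin = [0, 16, 0, 0]   # hypothetical start index for each of 4 axes
end = [1, 32, 28, 28]   # hypothetical stop index for each of 4 axes

begin_nd = tvm.nd.array(np.array(begin, dtype="int64"))  # shape (4,), not (1,)
end_nd = tvm.nd.array(np.array(end, dtype="int64"))
assert begin_nd.shape == (len(begin),)
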
4 changes: 2 additions & 2 deletions tests/python/relay/test_pass_alter_op_layout.py
@@ -568,7 +568,7 @@ def before():
         x = relay.var("x", shape=(1, 32, 28, 28))
         weight = relay.var('weight', shape=(32, 32, 3, 3))
         y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
-        y = relay.strided_slice(y, begin=[0, 16], end=[None, None])
+        y = relay.strided_slice(y, begin=relay.const([0, 16], "int32"), end=relay.const([1, 32], "int32"))
         y = relay.Function(analysis.free_vars(y), y)
         return y

@@ -584,7 +584,7 @@ def expected():
         x = relay.layout_transform(x, "NCHW", "NCHW4c")
         y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1),
                             data_layout="NCHW4c")
-        y = relay.strided_slice(y, begin=[0, 4], end=[None, 8])
+        y = relay.strided_slice(y, begin=relay.const([0, 4], "int32"), end=relay.const([1, 8], "int32"))
         y = relay.layout_transform(y, "NCHW4c", "NCHW")
         y = relay.Function(analysis.free_vars(y), y)
         return y
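
The surrounding test harness (outside this hunk) runs before() through AlterOpLayout and checks the result against expected(); a condensed sketch, assuming the run_opt_pass helper defined near the top of this test file and the era's alpha_equal:

# Sketch only: condensed from the usual pattern in test_pass_alter_op_layout.py;
# run_opt_pass is a helper defined earlier in the same file.
from tvm.relay import analysis, transform

a = run_opt_pass(before(), [transform.CanonicalizeOps(), transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)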
