Skip to content

Commit

Permalink
fix InferCorrectLayout
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Nov 18, 2019
1 parent bcd0ab1 commit cbd3dbf
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
20 changes: 18 additions & 2 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2002,7 +2002,7 @@ Array<Array<Layout> > StridedSliceInferCorrectLayout(

auto layout = old_in_layouts[0];
if (layout.defined() && new_in_layouts.defined()) {
CHECK_EQ(new_in_layouts.size(), 1);
CHECK_GE(new_in_layouts.size(), 1);
auto new_layout = new_in_layouts[0];
auto shape = old_in_shapes[0];

Expand Down Expand Up @@ -2059,9 +2059,25 @@ Array<Array<Layout> > StridedSliceInferCorrectLayout(
new_end.push_back(tvm::Integer(ed / factor));
}
}

layout = new_layout;

DLContext ctx;
ctx.device_type = kDLCPU;
ctx.device_id = 0;
auto begin_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())}, Type2TVMType(Int(64)), ctx);
auto end_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())}, Type2TVMType(Int(64)), ctx);
auto strides_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())}, Type2TVMType(Int(64)), ctx);
int64_t* begin_data = static_cast<int64_t*>(begin_ndarray->data);
int64_t* end_data = static_cast<int64_t*>(end_ndarray->data);
for (size_t i = 0; i < new_begin.size(); ++i) {
begin_data[i] = new_begin[i];
end_data[i] = new_end[i];
}
params->begin = ConstantNode::make(begin_ndarray);
params->end = ConstantNode::make(end_ndarray);
}
return {{layout}, {layout}};
return {{layout, Layout("C"), Layout("C"), Layout("C")}, {layout}};
}

inline Tensor DynamicStridedSlice(const tvm::Tensor& input,
Expand Down
6 changes: 3 additions & 3 deletions src/relay/pass/combine_parallel_conv2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,9 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
DLContext ctx;
ctx.device_type = kDLCPU;
ctx.device_id = 0;
auto begin_ndarray = runtime::NDArray::Empty({1}, Type2TVMType(Int(64)), ctx);
auto end_ndarray = runtime::NDArray::Empty({1}, Type2TVMType(Int(64)), ctx);
auto strides_ndarray = runtime::NDArray::Empty({1}, Type2TVMType(Int(64)), ctx);
auto begin_ndarray = runtime::NDArray::Empty({int64_t(begin.size())}, Type2TVMType(Int(64)), ctx);
auto end_ndarray = runtime::NDArray::Empty({int64_t(begin.size())}, Type2TVMType(Int(64)), ctx);
auto strides_ndarray = runtime::NDArray::Empty({int64_t(begin.size())}, Type2TVMType(Int(64)), ctx);
int64_t* begin_data = static_cast<int64_t*>(begin_ndarray->data);
int64_t* end_data = static_cast<int64_t*>(end_ndarray->data);

Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def before():
x = relay.var("x", shape=(1, 32, 28, 28))
weight = relay.var('weight', shape=(32, 32, 3, 3))
y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
y = relay.strided_slice(y, begin=[0, 16], end=[None, None])
y = relay.strided_slice(y, begin=relay.const([0, 16], "int32"), end=relay.const([1, 32], "int32"))
y = relay.Function(analysis.free_vars(y), y)
return y

Expand All @@ -546,7 +546,7 @@ def expected():
x = relay.layout_transform(x, "NCHW", "NCHW4c")
y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1),
data_layout="NCHW4c")
y = relay.strided_slice(y, begin=[0, 4], end=[None, 8])
y = relay.strided_slice(y, begin=relay.const([0, 4], "int32"), end=relay.const([1, 8], "int32"))
y = relay.layout_transform(y, "NCHW4c", "NCHW")
y = relay.Function(analysis.free_vars(y), y)
return y
Expand Down

0 comments on commit cbd3dbf

Please sign in to comment.