From 3ad3d62e53a16cac1555eb207540e04a410edb3e Mon Sep 17 00:00:00 2001
From: Yong Wu
Date: Mon, 18 Nov 2019 14:37:09 -0800
Subject: [PATCH] fix InferCorrectLayout

---
 src/relay/op/tensor/transform.cc              | 23 +++++++++++++++++--
 src/relay/pass/combine_parallel_conv2d.cc     |  9 +++++---
 .../python/relay/test_pass_alter_op_layout.py |  4 ++--
 3 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index bb31a822710af..82aeeadfa42ee 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -2017,7 +2017,7 @@ Array<Array<Layout> > StridedSliceInferCorrectLayout(
 
   auto layout = old_in_layouts[0];
   if (layout.defined() && new_in_layouts.defined()) {
-    CHECK_EQ(new_in_layouts.size(), 1);
+    CHECK_GE(new_in_layouts.size(), 1);
     auto new_layout = new_in_layouts[0];
     auto shape = old_in_shapes[0];
 
@@ -2074,9 +2074,28 @@ Array<Array<Layout> > StridedSliceInferCorrectLayout(
         new_end.push_back(tvm::Integer(ed / factor));
       }
     }
+    layout = new_layout;
+
+    DLContext ctx;
+    ctx.device_type = kDLCPU;
+    ctx.device_id = 0;
+    auto begin_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())},
+                                                 DataType::Int(64), ctx);
+    auto end_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())},
+                                               DataType::Int(64), ctx);
+    auto strides_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())},
+                                                   DataType::Int(64), ctx);
+    int64_t* begin_data = static_cast<int64_t*>(begin_ndarray->data);
+    int64_t* end_data = static_cast<int64_t*>(end_ndarray->data);
+    for (size_t i = 0; i < new_begin.size(); ++i) {
+      begin_data[i] = new_begin[i];
+      end_data[i] = new_end[i];
+    }
+    params->begin = ConstantNode::make(begin_ndarray);
+    params->end = ConstantNode::make(end_ndarray);
   }
 
-  return {{layout}, {layout}};
+  return {{layout, Layout("C"), Layout("C"), Layout("C")}, {layout}};
 }
 
 inline Tensor DynamicStridedSlice(const tvm::Tensor& input,
diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc
index 0eec8a576639f..e092216f0822d 100644
--- a/src/relay/pass/combine_parallel_conv2d.cc
+++ b/src/relay/pass/combine_parallel_conv2d.cc
@@ -191,9 +191,12 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
     DLContext ctx;
     ctx.device_type = kDLCPU;
     ctx.device_id = 0;
-    auto begin_ndarray = runtime::NDArray::Empty({1}, DataType::Int(64), ctx);
-    auto end_ndarray = runtime::NDArray::Empty({1}, DataType::Int(64), ctx);
-    auto strides_ndarray = runtime::NDArray::Empty({1}, DataType::Int(64), ctx);
+    auto begin_ndarray = runtime::NDArray::Empty({int64_t(begin.size())},
+                                                 DataType::Int(64), ctx);
+    auto end_ndarray = runtime::NDArray::Empty({int64_t(begin.size())},
+                                               DataType::Int(64), ctx);
+    auto strides_ndarray = runtime::NDArray::Empty({int64_t(begin.size())},
+                                                   DataType::Int(64), ctx);
     int64_t* begin_data = static_cast<int64_t*>(begin_ndarray->data);
     int64_t* end_data = static_cast<int64_t*>(end_ndarray->data);
 
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 3f02e1db625e2..235b216a950e8 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -568,7 +568,7 @@ def before():
         x = relay.var("x", shape=(1, 32, 28, 28))
         weight = relay.var('weight', shape=(32, 32, 3, 3))
         y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
-        y = relay.strided_slice(y, begin=[0, 16], end=[None, None])
+        y = relay.strided_slice(y, begin=relay.const([0, 16], "int32"), end=relay.const([1, 32], "int32"))
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
@@ -584,7 +584,7 @@ def expected():
         x = relay.layout_transform(x, "NCHW", "NCHW4c")
         y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3),
                             padding=(1, 1), data_layout="NCHW4c")
-        y = relay.strided_slice(y, begin=[0, 4], end=[None, 8])
+        y = relay.strided_slice(y, begin=relay.const([0, 4], "int32"), end=relay.const([1, 8], "int32"))
         y = relay.layout_transform(y, "NCHW4c", "NCHW")
         y = relay.Function(analysis.free_vars(y), y)
         return y
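
---

Usage note (not part of the patch): a minimal sketch of the behavior this
change targets, assuming a TVM checkout from around this revision. The
relay.Module / transform.PassContext entry points and the NCHW4c conv2d
alter_op_layout hook registered in the updated test are era-specific
assumptions, not guaranteed by this diff:

    import tvm
    from tvm import relay
    from tvm.relay import analysis, transform

    # conv2d followed by a channel slice, mirroring the updated test case.
    x = relay.var("x", shape=(1, 32, 28, 28))
    weight = relay.var("weight", shape=(32, 32, 3, 3))
    y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3),
                        padding=(1, 1))
    # begin/end are Constant exprs now; these are the values that
    # StridedSliceInferCorrectLayout rewrites through ConstantNode::make.
    y = relay.strided_slice(y, begin=relay.const([0, 16], "int32"),
                            end=relay.const([1, 32], "int32"))
    func = relay.Function(analysis.free_vars(y), y)

    # With a conv2d alter hook that switches data_layout to NCHW4c,
    # AlterOpLayout should divide the sliced channel bounds by the packing
    # factor ([0, 16]/[1, 32] -> [0, 4]/[1, 8]) instead of tripping the
    # old CHECK_EQ(new_in_layouts.size(), 1).
    mod = relay.Module.from_expr(func)
    with transform.PassContext(opt_level=3):
        mod = transform.AlterOpLayout()(mod)
    print(mod["main"])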