diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 9e66301dfcbb2..fa8a778e80536 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -2002,7 +2002,7 @@ Array<Array<Layout> > StridedSliceInferCorrectLayout(
   auto layout = old_in_layouts[0];
   if (layout.defined() && new_in_layouts.defined()) {
-    CHECK_EQ(new_in_layouts.size(), 1);
+    CHECK_GE(new_in_layouts.size(), 1);
     auto new_layout = new_in_layouts[0];
     auto shape = old_in_shapes[0];
@@ -2059,9 +2059,25 @@ Array<Array<Layout> > StridedSliceInferCorrectLayout(
         new_end.push_back(tvm::Integer(ed / factor));
       }
     }
+    layout = new_layout;
+
+    DLContext ctx;
+    ctx.device_type = kDLCPU;
+    ctx.device_id = 0;
+    auto begin_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())}, Type2TVMType(Int(64)), ctx);
+    auto end_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())}, Type2TVMType(Int(64)), ctx);
+    auto strides_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())}, Type2TVMType(Int(64)), ctx);
+    int64_t* begin_data = static_cast<int64_t*>(begin_ndarray->data);
+    int64_t* end_data = static_cast<int64_t*>(end_ndarray->data);
+    for (size_t i = 0; i < new_begin.size(); ++i) {
+      begin_data[i] = new_begin[i];
+      end_data[i] = new_end[i];
+    }
+    params->begin = ConstantNode::make(begin_ndarray);
+    params->end = ConstantNode::make(end_ndarray);
   }
-  return {{layout}, {layout}};
+  return {{layout, Layout("C"), Layout("C"), Layout("C")}, {layout}};
 }
 
 inline Tensor DynamicStridedSlice(const tvm::Tensor& input,
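For intuition, the begin/end rescaling in the `StridedSliceInferCorrectLayout` hunk above can be sketched in Python. This is an illustrative re-implementation, not TVM source: `scale_slice_for_layout` and `factors` are names invented here, and `factors[i]` is assumed to be the sub-dimension size the new layout introduces on the i-th sliced axis (1 when that axis is not split).

```python
def scale_slice_for_layout(begin, end, factors):
    """Hypothetical mirror of the `ed / factor` rescaling above: when the
    new layout splits an axis into sub-dimensions of size `factor`, the
    slice bounds on that axis shrink by the same factor."""
    new_begin, new_end = [], []
    for b, e, factor in zip(begin, end, factors):
        # Assume the bounds divide evenly by the factor; the patch's
        # integer division would otherwise truncate the slice.
        assert b % factor == 0 and e % factor == 0
        new_begin.append(b // factor)
        new_end.append(e // factor)
    return new_begin, new_end

# NCHW -> NCHW4c splits the channel axis by 4, so begin=[0, 16],
# end=[1, 32] becomes begin=[0, 4], end=[1, 8], the values the updated
# test below expects.
print(scale_slice_for_layout([0, 16], [1, 32], [1, 4]))  # ([0, 4], [1, 8])
```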
"int32")) y = relay.layout_transform(y, "NCHW4c", "NCHW") y = relay.Function(analysis.free_vars(y), y) return y