diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 379234a5c65a..b2132b75fab9 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -569,16 +569,25 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b std::string tag = topi::kInjective) { int64_t src_tensor_dim = x->shape.size(); Array out_shape; - for (int64_t i = 0; i < src_tensor_dim; ++i) { + const int64_t num_dynamic_axes = begin->shape[0].as()->value; + for (int64_t i = 0; i < num_dynamic_axes; ++i) { out_shape.push_back(tvm::tir::Var("dim")); } + for (int64_t i = num_dynamic_axes; i < src_tensor_dim; ++i) { + out_shape.push_back(x->shape[i]); + } return te::compute( out_shape, [&](const Array& indices) { Array real_indices; - for (int32_t i = 0; i < src_tensor_dim; ++i) { + // dynamic slicing + for (int32_t i = 0; i < num_dynamic_axes; ++i) { real_indices.push_back(indices[i] * strides(i) + tvm::min(begin(i), x->shape[i] - 1)); } + // keep input dim + for (int32_t i = num_dynamic_axes; i < src_tensor_dim; ++i) { + real_indices.push_back(indices[i]); + } return x(real_indices); }, name, tag); diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 2a57cba53cd2..b9fabdebb330 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2703,11 +2703,7 @@ def conditionally_squeeze_scalar(x): boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold ) - three = _op.const(np.array([3]), dtype="int64") - begin = _op.const(np.array([0, 0]), dtype="int64") - end = _op.concatenate([nms_out[1], three], axis=0) - strides = _op.const(np.array([1, 1]), dtype="int64") - return _op.strided_slice(nms_out[0], begin, end, strides) + return _op.strided_slice(nms_out[0], _op.const([0], dtype="int64"), nms_out[1]) class ATen(OnnxOpConverter): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 05963a043f18..2501bef488a8 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -950,12 +950,11 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): end = const(list(end)) if isinstance(strides, (tuple, list)): strides = const(list(strides)) - begin = _make.where( - begin < cast_like(const(0), begin), begin + cast_like(shape_of(data), begin), begin - ) - begin = _make.where( - begin >= cast_like(shape_of(data), begin), cast_like(shape_of(data), begin), begin - ) + + ishape = cast_like(shape_of(data), begin) + ishape_slice = slice_like(ishape, begin) + begin = _make.where(begin < cast_like(const(0), begin), begin + ishape_slice, begin) + begin = _make.where(begin >= ishape_slice, ishape_slice, begin) return _dyn_make.strided_slice(data, begin, end, strides, slice_mode) return _make.strided_slice(data, begin, end, strides, slice_mode) diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index cf8f3689b045..d8ee1c84a99c 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -466,12 +466,20 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr auto dshape = data->shape; int64_t num_axis = dshape.size(); + const auto* begin = types[1].as(); + ICHECK(begin); + // calculate output shape std::vector oshape(num_axis); - for (int64_t i = 0; i < num_axis; ++i) { + int64_t num_dynamic_axes = begin->shape[0].as()->value; + for (int64_t i = 0; i < num_dynamic_axes; ++i) { oshape[i] = Any(); } + for (int64_t i = num_dynamic_axes; i < num_axis; ++i) { + oshape[i] = dshape[i]; + } + reporter->Assign(types[4], TensorType(oshape, data->dtype)); return true; } @@ -484,11 +492,12 @@ Array StridedSliceCompute(const Attrs& attrs, const Arrayshape.size(); - ICHECK(begin->shape[0].as()->value == data_rank && - end->shape[0].as()->value == data_rank && - strides->shape[0].as()->value == data_rank) - << "begin, end, and strides are required to have the same length" - << " if they are dynamic variables."; + int64_t num_dynamic_axes = begin->shape[0].as()->value; + ICHECK(end->shape[0].as()->value == num_dynamic_axes && + strides->shape[0].as()->value == num_dynamic_axes) + << "begin, end, strides should have the same length if they are dynamic variables"; + ICHECK(num_dynamic_axes <= data_rank) + << "the number of dynamic axes to slice should be less than or equal to the data rank"; return Array{topi::dynamic_strided_slice(data, begin, end, strides)}; } diff --git a/tests/python/relay/dyn/test_dynamic_op_level4.py b/tests/python/relay/dyn/test_dynamic_op_level4.py index 43e5beba199f..01e5056c72cb 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level4.py +++ b/tests/python/relay/dyn/test_dynamic_op_level4.py @@ -25,16 +25,19 @@ @tvm.testing.uses_gpu def test_dynamic_strided_slice(): - def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, dtype="int32"): + def verify(dshape, begin, end, strides, slice_mode="end", test_ref=True, dtype="int32"): x = relay.var("x", relay.TensorType(dshape, "float32")) ndim = len(dshape) + slice_dim = len(begin) begin = begin if begin else [0] * ndim - end = end if end else list(dshape) + end = end if end else list(dshape)[:slice_dim] if strides: if len(strides) == 1: - strides = strides * ndim + strides = strides * slice_dim else: - strides = [1] * ndim + strides = [1] * slice_dim + + num_static_axes = len(dshape) - len(begin) # target numpy result x_data = np.random.uniform(size=dshape).astype("float32") @@ -54,7 +57,10 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, func = relay.Function(inputs, z) func = run_infer_type(func) - text = func.astext() + + if num_static_axes > 0: + oshape = run_infer_type(z).checked_type.shape + assert tuple(oshape[-num_static_axes:]) == dshape[-num_static_axes:] if not test_ref: return @@ -69,22 +75,24 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, [0, 20, 20, 0], [1, 140, 140, 3], [1, 1, 1, 1], - (1, 120, 120, 3), dtype="int64", ) - verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3), dtype="int16") - verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) - verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) - verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)) - verify((3, 4, 3), [1, 1, 0], [4, 4, 4], None, (2, 3, 3)) - verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) - verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) - verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) - verify((20, 10, 5), [20, 10, 4], [0, 0, 1], [-1, -3, -2], (19, 3, 2)) - verify( - (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False - ) - verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], dtype="int16") + verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None) + verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None) + verify((3, 4, 3), [1, 1, 0], [4, 4, 4], None) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None) + verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1]) + verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1]) + verify((20, 10, 5), [20, 10, 4], [0, 0, 1], [-1, -3, -2]) + verify((3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], slice_mode="size", test_ref=False) + verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], slice_mode="size", test_ref=True) + + # Slicing along first few axes, where the rest of axes remain static + verify((3, 4, 3), [0], [2], None) + verify((3, 4, 3), [1], [4], [2]) + verify((3, 4, 3), [1, 0], [4, 2], [2, 1]) if __name__ == "__main__":