Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay, TOPI] Support dynamic slicing on first few axes, keeping the rest static #8068

Merged
merged 6 commits into from
May 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -569,16 +569,25 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
std::string tag = topi::kInjective) {
int64_t src_tensor_dim = x->shape.size();
Array<PrimExpr> out_shape;
for (int64_t i = 0; i < src_tensor_dim; ++i) {
const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
for (int64_t i = 0; i < num_dynamic_axes; ++i) {
out_shape.push_back(tvm::tir::Var("dim"));
}
for (int64_t i = num_dynamic_axes; i < src_tensor_dim; ++i) {
out_shape.push_back(x->shape[i]);
}
return te::compute(
out_shape,
[&](const Array<tvm::tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (int32_t i = 0; i < src_tensor_dim; ++i) {
// dynamic slicing
for (int32_t i = 0; i < num_dynamic_axes; ++i) {
real_indices.push_back(indices[i] * strides(i) + tvm::min(begin(i), x->shape[i] - 1));
}
// keep input dim
for (int32_t i = num_dynamic_axes; i < src_tensor_dim; ++i) {
real_indices.push_back(indices[i]);
}
return x(real_indices);
},
name, tag);
Expand Down
6 changes: 1 addition & 5 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2703,11 +2703,7 @@ def conditionally_squeeze_scalar(x):
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold
)

three = _op.const(np.array([3]), dtype="int64")
begin = _op.const(np.array([0, 0]), dtype="int64")
end = _op.concatenate([nms_out[1], three], axis=0)
strides = _op.const(np.array([1, 1]), dtype="int64")
return _op.strided_slice(nms_out[0], begin, end, strides)
return _op.strided_slice(nms_out[0], _op.const([0], dtype="int64"), nms_out[1])


class ATen(OnnxOpConverter):
Expand Down
11 changes: 5 additions & 6 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,12 +911,11 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
end = const(list(end))
if isinstance(strides, (tuple, list)):
strides = const(list(strides))
begin = _make.where(
begin < cast_like(const(0), begin), begin + cast_like(shape_of(data), begin), begin
)
begin = _make.where(
begin >= cast_like(shape_of(data), begin), cast_like(shape_of(data), begin), begin
)

ishape = cast_like(shape_of(data), begin)
ishape_slice = slice_like(ishape, begin)
begin = _make.where(begin < cast_like(const(0), begin), begin + ishape_slice, begin)
begin = _make.where(begin >= ishape_slice, ishape_slice, begin)
return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)
return _make.strided_slice(data, begin, end, strides, slice_mode)

Expand Down
21 changes: 15 additions & 6 deletions src/relay/op/dyn/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -466,12 +466,20 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
auto dshape = data->shape;
int64_t num_axis = dshape.size();

const auto* begin = types[1].as<TensorTypeNode>();
ICHECK(begin);

// calculate output shape
std::vector<IndexExpr> oshape(num_axis);
for (int64_t i = 0; i < num_axis; ++i) {
int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
for (int64_t i = 0; i < num_dynamic_axes; ++i) {
oshape[i] = Any();
}

for (int64_t i = num_dynamic_axes; i < num_axis; ++i) {
oshape[i] = dshape[i];
}

reporter->Assign(types[4], TensorType(oshape, data->dtype));
return true;
}
Expand All @@ -484,11 +492,12 @@ Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor
te::Tensor strides = inputs[3];
// Dynamic computation
int64_t data_rank = data->shape.size();
ICHECK(begin->shape[0].as<IntImmNode>()->value == data_rank &&
end->shape[0].as<IntImmNode>()->value == data_rank &&
strides->shape[0].as<IntImmNode>()->value == data_rank)
<< "begin, end, and strides are required to have the same length"
<< " if they are dynamic variables.";
int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
ICHECK(end->shape[0].as<IntImmNode>()->value == num_dynamic_axes &&
strides->shape[0].as<IntImmNode>()->value == num_dynamic_axes)
<< "begin, end, strides should have the same length if they are dynamic variables";
ICHECK(num_dynamic_axes <= data_rank)
<< "the number of dynamic axes to slice should be less than or equal to the data rank";
return Array<te::Tensor>{topi::dynamic_strided_slice(data, begin, end, strides)};
}

Expand Down
46 changes: 27 additions & 19 deletions tests/python/relay/dyn/test_dynamic_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,19 @@

@tvm.testing.uses_gpu
def test_dynamic_strided_slice():
def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, dtype="int32"):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

output is not used in the test, so I removed it.

def verify(dshape, begin, end, strides, slice_mode="end", test_ref=True, dtype="int32"):
x = relay.var("x", relay.TensorType(dshape, "float32"))
ndim = len(dshape)
slice_dim = len(begin)
begin = begin if begin else [0] * ndim
end = end if end else list(dshape)
end = end if end else list(dshape)[:slice_dim]
if strides:
if len(strides) == 1:
strides = strides * ndim
strides = strides * slice_dim
else:
strides = [1] * ndim
strides = [1] * slice_dim

num_static_axes = len(dshape) - len(begin)

# target numpy result
x_data = np.random.uniform(size=dshape).astype("float32")
Expand All @@ -54,7 +57,10 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True,
func = relay.Function(inputs, z)

func = run_infer_type(func)
text = func.astext()

if num_static_axes > 0:
oshape = run_infer_type(z).checked_type.shape
assert tuple(oshape[-num_static_axes:]) == dshape[-num_static_axes:]

if not test_ref:
return
Expand All @@ -69,22 +75,24 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True,
[0, 20, 20, 0],
[1, 140, 140, 3],
[1, 1, 1, 1],
(1, 120, 120, 3),
dtype="int64",
)
verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3), dtype="int16")
verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2))
verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3))
verify((3, 4, 3), [1, 1, 0], [4, 4, 4], None, (2, 3, 3))
verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
verify((20, 10, 5), [20, 10, 4], [0, 0, 1], [-1, -3, -2], (19, 3, 2))
verify(
(3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
)
verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True)
verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], dtype="int16")
verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None)
verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None)
verify((3, 4, 3), [1, 1, 0], [4, 4, 4], None)
verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None)
verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1])
verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
verify((20, 10, 5), [20, 10, 4], [0, 0, 1], [-1, -3, -2])
verify((3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], slice_mode="size", test_ref=False)
verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], slice_mode="size", test_ref=True)

# Slicing along first few axes, where the rest of axes remain static
verify((3, 4, 3), [0], [2], None)
verify((3, 4, 3), [1], [4], [2])
verify((3, 4, 3), [1, 0], [4, 2], [2, 1])


if __name__ == "__main__":
Expand Down