[Relay, TOPI] Support dynamic slicing on first few axes, keeping the rest static (apache#8068)

* Supporting dynamic slice on first few axes

* fix index normalization

* update dynamic slice tests

* pylint fix

* fix loop index dtype

* fix more dtype issues
masahi authored and trevor-m committed Jun 17, 2021
1 parent 43917ff commit 8d6776b
Showing 5 changed files with 59 additions and 38 deletions.
13 changes: 11 additions & 2 deletions include/tvm/topi/transform.h
@@ -569,16 +569,25 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
                                         std::string tag = topi::kInjective) {
   int64_t src_tensor_dim = x->shape.size();
   Array<PrimExpr> out_shape;
-  for (int64_t i = 0; i < src_tensor_dim; ++i) {
+  const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
+  for (int64_t i = 0; i < num_dynamic_axes; ++i) {
     out_shape.push_back(tvm::tir::Var("dim"));
   }
+  for (int64_t i = num_dynamic_axes; i < src_tensor_dim; ++i) {
+    out_shape.push_back(x->shape[i]);
+  }
   return te::compute(
       out_shape,
       [&](const Array<tvm::tir::Var>& indices) {
         Array<PrimExpr> real_indices;
-        for (int32_t i = 0; i < src_tensor_dim; ++i) {
+        // dynamic slicing
+        for (int32_t i = 0; i < num_dynamic_axes; ++i) {
           real_indices.push_back(indices[i] * strides(i) + tvm::min(begin(i), x->shape[i] - 1));
         }
+        // keep input dim
+        for (int32_t i = num_dynamic_axes; i < src_tensor_dim; ++i) {
+          real_indices.push_back(indices[i]);
+        }
         return x(real_indices);
       },
       name, tag);
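Note: as a rough guide to the new semantics, here is a NumPy sketch (the helper name is illustrative, and it ignores the begin-clamping that the TE kernel above applies): only the first len(begin) axes are sliced; the remaining axes pass through at their full, static extent.

import numpy as np

def dynamic_strided_slice_ref(x, begin, end, strides):
    # Slice only the first len(begin) axes, mirroring the two index
    # loops in the TE compute: dynamic axes first, then pass-through.
    dyn = [slice(b, e, s) for b, e, s in zip(begin, end, strides)]
    keep = [slice(None)] * (x.ndim - len(begin))
    return x[tuple(dyn + keep)]

x = np.arange(36, dtype="float32").reshape(3, 4, 3)
assert dynamic_strided_slice_ref(x, [1], [3], [1]).shape == (2, 4, 3)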
6 changes: 1 addition & 5 deletions python/tvm/relay/frontend/onnx.py
@@ -2703,11 +2703,7 @@ def conditionally_squeeze_scalar(x):
             boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold
         )
 
-        three = _op.const(np.array([3]), dtype="int64")
-        begin = _op.const(np.array([0, 0]), dtype="int64")
-        end = _op.concatenate([nms_out[1], three], axis=0)
-        strides = _op.const(np.array([1, 1]), dtype="int64")
-        return _op.strided_slice(nms_out[0], begin, end, strides)
+        return _op.strided_slice(nms_out[0], _op.const([0], dtype="int64"), nms_out[1])
 
 
 class ATen(OnnxOpConverter):
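Note: the one-liner works because the ONNX NonMaxSuppression output has a static trailing dimension of 3 (each row is a [batch_index, class_index, box_index] triple); only the row count is dynamic, so a rank-1 begin/end now suffices. A hedged sketch of the shapes involved (variable names are illustrative, not from the converter):

from tvm import relay

# indices stands in for nms_out[0], count for nms_out[1].
indices = relay.var("indices", relay.TensorType((relay.Any(), 3), "int64"))
count = relay.var("count", relay.TensorType((1,), "int64"))
# Rank-1 begin/end slice only axis 0; the trailing extent 3 stays static.
out = relay.strided_slice(indices, relay.const([0], dtype="int64"), count)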
11 changes: 5 additions & 6 deletions python/tvm/relay/op/transform.py
@@ -950,12 +950,11 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
             end = const(list(end))
         if isinstance(strides, (tuple, list)):
             strides = const(list(strides))
-        begin = _make.where(
-            begin < cast_like(const(0), begin), begin + cast_like(shape_of(data), begin), begin
-        )
-        begin = _make.where(
-            begin >= cast_like(shape_of(data), begin), cast_like(shape_of(data), begin), begin
-        )
+
+        ishape = cast_like(shape_of(data), begin)
+        ishape_slice = slice_like(ishape, begin)
+        begin = _make.where(begin < cast_like(const(0), begin), begin + ishape_slice, begin)
+        begin = _make.where(begin >= ishape_slice, ishape_slice, begin)
         return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)
     return _make.strided_slice(data, begin, end, strides, slice_mode)
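Note: slice_like trims the shape vector to begin's length, so the negative-index normalization stays elementwise even when begin covers only the first few axes. A NumPy mirror of the logic (a sketch; the helper name is made up):

import numpy as np

def normalize_begin(begin, dshape):
    # Keep only the extents of the axes being sliced (the role of
    # slice_like above), wrap negative indices, then clamp to the extent.
    ishape = np.asarray(dshape[: len(begin)], dtype=np.int64)
    b = np.asarray(begin, dtype=np.int64)
    b = np.where(b < 0, b + ishape, b)
    return np.where(b >= ishape, ishape, b)

# begin=[-1] on a (3, 4, 3) input normalizes to [2]; previously the
# length-1 begin was compared against the full length-3 shape vector.
assert normalize_begin([-1], (3, 4, 3)).tolist() == [2]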
21 changes: 15 additions & 6 deletions src/relay/op/dyn/tensor/transform.cc
@@ -466,12 +466,20 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
   auto dshape = data->shape;
   int64_t num_axis = dshape.size();
 
+  const auto* begin = types[1].as<TensorTypeNode>();
+  ICHECK(begin);
+
   // calculate output shape
   std::vector<IndexExpr> oshape(num_axis);
-  for (int64_t i = 0; i < num_axis; ++i) {
+  int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
+  for (int64_t i = 0; i < num_dynamic_axes; ++i) {
     oshape[i] = Any();
   }
+
+  for (int64_t i = num_dynamic_axes; i < num_axis; ++i) {
+    oshape[i] = dshape[i];
+  }
+
   reporter->Assign(types[4], TensorType(oshape, data->dtype));
   return true;
 }
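Note: with this relation, only the first num_dynamic_axes output dimensions become Any; the rest keep their static extents. A quick sketch of the effect (mirrors the assertion added to the tests below):

from tvm import relay
from tvm.relay.testing import run_infer_type

x = relay.var("x", relay.TensorType((3, 4, 3), "float32"))
begin = relay.var("begin", relay.TensorType((1,), "int64"))
end = relay.var("end", relay.TensorType((1,), "int64"))
z = relay.strided_slice(x, begin, end, relay.const([1], "int64"))
# Expected inferred type: Tensor[(?, 4, 3), float32], i.e. one Any axis
# followed by the preserved static extents.
print(run_infer_type(z).checked_type)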
@@ -484,11 +492,12 @@ Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor
   te::Tensor strides = inputs[3];
   // Dynamic computation
   int64_t data_rank = data->shape.size();
-  ICHECK(begin->shape[0].as<IntImmNode>()->value == data_rank &&
-         end->shape[0].as<IntImmNode>()->value == data_rank &&
-         strides->shape[0].as<IntImmNode>()->value == data_rank)
-      << "begin, end, and strides are required to have the same length"
-      << " if they are dynamic variables.";
+  int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
+  ICHECK(end->shape[0].as<IntImmNode>()->value == num_dynamic_axes &&
+         strides->shape[0].as<IntImmNode>()->value == num_dynamic_axes)
+      << "begin, end, strides should have the same length if they are dynamic variables";
+  ICHECK(num_dynamic_axes <= data_rank)
+      << "the number of dynamic axes to slice should be less than or equal to the data rank";
   return Array<te::Tensor>{topi::dynamic_strided_slice(data, begin, end, strides)};
 }
46 changes: 27 additions & 19 deletions tests/python/relay/dyn/test_dynamic_op_level4.py
@@ -25,16 +25,19 @@
 
 @tvm.testing.uses_gpu
 def test_dynamic_strided_slice():
-    def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, dtype="int32"):
+    def verify(dshape, begin, end, strides, slice_mode="end", test_ref=True, dtype="int32"):
         x = relay.var("x", relay.TensorType(dshape, "float32"))
         ndim = len(dshape)
+        slice_dim = len(begin)
         begin = begin if begin else [0] * ndim
-        end = end if end else list(dshape)
+        end = end if end else list(dshape)[:slice_dim]
         if strides:
             if len(strides) == 1:
-                strides = strides * ndim
+                strides = strides * slice_dim
         else:
-            strides = [1] * ndim
+            strides = [1] * slice_dim
+
+        num_static_axes = len(dshape) - len(begin)
 
         # target numpy result
         x_data = np.random.uniform(size=dshape).astype("float32")
@@ -54,7 +57,10 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True,
         func = relay.Function(inputs, z)
 
         func = run_infer_type(func)
-        text = func.astext()
+
+        if num_static_axes > 0:
+            oshape = run_infer_type(z).checked_type.shape
+            assert tuple(oshape[-num_static_axes:]) == dshape[-num_static_axes:]
 
         if not test_ref:
             return
@@ -69,22 +75,24 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True,
         [0, 20, 20, 0],
         [1, 140, 140, 3],
         [1, 1, 1, 1],
-        (1, 120, 120, 3),
         dtype="int64",
     )
-    verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3), dtype="int16")
-    verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2))
-    verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
-    verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3))
-    verify((3, 4, 3), [1, 1, 0], [4, 4, 4], None, (2, 3, 3))
-    verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
-    verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
-    verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
-    verify((20, 10, 5), [20, 10, 4], [0, 0, 1], [-1, -3, -2], (19, 3, 2))
-    verify(
-        (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
-    )
-    verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True)
+    verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], dtype="int16")
+    verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
+    verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None)
+    verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None)
+    verify((3, 4, 3), [1, 1, 0], [4, 4, 4], None)
+    verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None)
+    verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1])
+    verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
+    verify((20, 10, 5), [20, 10, 4], [0, 0, 1], [-1, -3, -2])
+    verify((3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], slice_mode="size", test_ref=False)
+    verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], slice_mode="size", test_ref=True)
+
+    # Slicing along first few axes, where the rest of axes remain static
+    verify((3, 4, 3), [0], [2], None)
+    verify((3, 4, 3), [1], [4], [2])
+    verify((3, 4, 3), [1, 0], [4, 2], [2, 1])
 
 
 if __name__ == "__main__":
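Note: to exercise the new partial-slice path end to end, a hedged sketch along the lines of the tests above (executor and target choices are illustrative; dynamic output shapes need the VM rather than the graph executor):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", relay.TensorType((3, 4, 3), "float32"))
begin = relay.var("begin", relay.TensorType((1,), "int64"))
end = relay.var("end", relay.TensorType((1,), "int64"))
body = relay.strided_slice(x, begin, end, relay.const([1], "int64"))
mod = tvm.IRModule.from_expr(relay.Function([x, begin, end], body))

x_np = np.random.uniform(size=(3, 4, 3)).astype("float32")
result = relay.create_executor("vm", mod=mod).evaluate()(
    x_np, np.array([1], dtype="int64"), np.array([3], dtype="int64")
)
# Only axis 0 is sliced dynamically; axes 1 and 2 keep extents (4, 3).
np.testing.assert_allclose(result.numpy(), x_np[1:3])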
