[TOPI][Relay][OP] Add a strided_set operation. #4303

Merged
merged 10 commits, Dec 3, 2019
Changes from 7 commits
23 changes: 23 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -48,6 +48,7 @@
_reg.register_schedule("cast_like", schedule_injective)
_reg.register_schedule("reinterpret", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("strided_set", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective)
_reg.register_schedule("take", schedule_injective)
@@ -304,6 +305,28 @@ def compute_argwhere(attrs, inputs, output_type, _):
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
return [topi.argwhere(new_output_type, inputs[0])]

@_reg.register_compute("strided_set")
def compute_strided_set(attrs, inputs, output_type, _):
"""Compute definition of strided_set"""
begin = attrs.begin
end = attrs.end
strides = attrs.strides
n = len(inputs[0].shape)
strides = list(strides) + [1] * (n - len(strides))
lb = len(begin)
if lb < n:
begin = list(begin)
for i in range(lb, n):
begin.append(0 if strides[i] >= 0 else inputs[0].shape[i])
Member: same here

Contributor Author: This is a bit different because the value to append might be different at each index. It depends on the stride for that index. I'm not sure how to shorten that.

Member: It's fine, no worries :-)

le = len(end)
if le < n:
end = list(end)
for i in range(le, n):
lim = inputs[0].shape[i] + 1
end.append(lim if strides[i] >= 0 else -lim)
Member: same here

Contributor Author: Same as for begin.

return [topi.strided_set(inputs[0], inputs[1], begin, end, strides)]
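
To make the default-filling behaviour discussed in the thread above concrete, here is a minimal plain-Python sketch (the pad_strided_args helper is hypothetical, not part of this PR): missing strides default to 1, while the per-axis defaults for begin and end depend on the sign of that axis's stride, which is why the loop cannot be collapsed into a single list extension.

def pad_strided_args(shape, begin, end, strides):
    """Sketch of how compute_strided_set pads begin/end/strides to full rank."""
    n = len(shape)
    strides = list(strides) + [1] * (n - len(strides))
    begin = list(begin)
    for i in range(len(begin), n):
        # default begin: 0 for a forward stride, shape[i] for a backward one
        begin.append(0 if strides[i] >= 0 else shape[i])
    end = list(end)
    for i in range(len(end), n):
        lim = shape[i] + 1
        # default end: one past the last index for a forward stride, its negation otherwise
        end.append(lim if strides[i] >= 0 else -lim)
    return begin, end, strides

# e.g. pad_strided_args((3, 4, 3), [1, 1], [4, 4], [2]) == ([1, 1, 0], [4, 4, 4], [2, 1, 1])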

@script
def _layout_transform_shape_func(data_shape,
out_layout_len,
30 changes: 30 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -631,6 +631,36 @@ def strided_slice(data, begin, end, strides=None):
return _make.strided_slice(data, list(begin), list(end), list(strides))


def strided_set(data, v, begin, end, strides=None):
"""Strided set of an array.

Parameters
----------
data : relay.Expr
The source array to be sliced.

v : relay.Expr
The data to be set.

begin: list of int
Contributor: Thank you for adding this op. Currently @yongwww is modifying strided_slice (#4312) to support begin, end and strides as expressions instead of just lists of int. The reason is that in some DL frameworks begin, end or strides can be a tensor. Also, making it more dynamic can help us when building other ops, such as NMS. Considering this op is similar to strided_slice, should we keep it aligned with strided_slice and allow begin, end and strides to be Expr?

Contributor Author: That could be done, yes. The underlying TOPI op does support TVM expressions so it shouldn't be too hard to do.

Contributor Author: This is now done.

The indices to begin with in the slicing.

end: list of int
Indices indicating end of the slice.

strides: list of int, optional
Specifies the stride values; they can be negative, in which case
the input tensor will be reversed in that particular axis.

Returns
-------
ret : relay.Expr
The computed result.
"""
strides = strides or []
return _make.strided_set(data, v, list(begin), list(end), list(strides))
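
A short usage sketch of the wrapper above, mirroring the example in the op registration; the concrete shapes and the llvm/CPU graph-executor setup are assumptions borrowed from the tests in this PR.

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", relay.TensorType((3, 4), "float32"))
v = relay.var("v", relay.TensorType((2, 3), "float32"))
z = relay.strided_set(x, v, begin=[0, 1], end=[2, 4], strides=[1, 1])
func = relay.Function([x, v], z)

x_data = np.array([[1., 4., 7., 10.],
                   [2., 5., 8., 11.],
                   [3., 6., 9., 12.]], dtype="float32")
v_data = np.array([[11., 22., 33.],
                   [44., 55., 66.]], dtype="float32")

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
res = intrp.evaluate(func)(x_data, v_data).asnumpy()
# res == [[ 1., 11., 22., 33.],
#         [ 2., 44., 55., 66.],
#         [ 3.,  6.,  9., 12.]]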


def slice_like(data, shape_like, axes=None):
"""Slice the first input with respect to the second input.

50 changes: 50 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -2049,6 +2049,56 @@ Examples::
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", StridedSliceInferCorrectLayout);

// strided_set
bool StridedSetRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
reporter->Assign(types[2], types[0]);
return true;
}

Expr MakeStridedSet(Expr data,
Expr v,
Array<Integer> begin,
Array<Integer> end,
Array<Integer> strides) {
auto attrs = make_node<StridedSliceAttrs>();
attrs->begin = std::move(begin);
attrs->end = std::move(end);
attrs->strides = std::move(strides);
static const Op& op = Op::Get("strided_set");
return CallNode::make(op, {data, v}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.strided_set")
.set_body_typed(MakeStridedSet);


RELAY_REGISTER_OP("strided_set")
.describe(R"code(Strided set of an array.
Example::

x = [[ 1., 4., 7., 10.],
[ 2., 5., 8., 11.],
[ 3., 6., 9., 12.]]

v = [[ 11., 22., 33.]
[ 44., 55., 66.]]

strided_set(x, v, begin=[0, 1], end=[2, 4], stride=[1, 1]) = \
[[ 1., 11., 22., 33.],
[ 2., 44., 55., 66.],
[ 3., 6., 9., 12.]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("v", "Tensor", "The data to set.")
.set_support_level(4)
.set_attrs_type<StridedSliceAttrs>()
.set_attr<TOpPattern>("TOpPattern", kInjective)
.add_type_rel("StridedSet", StridedSetRel);

// relay.split
TVM_REGISTER_NODE_TYPE(SplitAttrs);
37 changes: 37 additions & 0 deletions tests/python/relay/test_op_level4.py
@@ -300,8 +300,45 @@ def verify(dshape, begin, end, strides, output, test_ref=True):
verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3))


def test_strided_set():
def verify(dshape, begin, end, strides, vshape, test_ref=True):
x = relay.var("x", relay.TensorType(dshape, "float32"))
v = relay.var("v", relay.TensorType(vshape, "float32"))
z = relay.strided_set(x, v, begin=begin, end=end, strides=strides)
func = relay.Function([x, v], z)
func = run_infer_type(func)
text = func.astext()
assert "begin=" in text
assert "end=" in text
print(text)
assert func.body.checked_type == relay.ty.TensorType(dshape, "float32")
if not test_ref:
return
x_data = np.random.uniform(size=dshape).astype("float32")
v_data = np.random.uniform(size=vshape).astype("float32")
ref_res = topi.testing.strided_set_python(
x_data, v_data, begin, end, strides)
for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, v_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)

d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
verify((d1, d2, 3), [None, None, 1], [None, None, 2], None, (d1, d2, 1), False)
verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2))
verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3))
verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
verify((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], (1, 2, 2))
verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3))
verify((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3))
verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3))


if __name__ == "__main__":
test_strided_slice()
test_strided_set()
test_binary_op()
test_cmp_type()
test_binary_int_broadcast()
2 changes: 1 addition & 1 deletion topi/python/topi/testing/__init__.py
@@ -20,7 +20,7 @@
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
from .gather_nd_python import gather_nd_python
from .strided_slice_python import strided_slice_python
from .strided_slice_python import strided_slice_python, strided_set_python
from .batch_matmul import batch_matmul
from .slice_axis_python import slice_axis_python
from .sequence_mask_python import sequence_mask
40 changes: 39 additions & 1 deletion topi/python/topi/testing/strided_slice_python.py
@@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""gather_nd in python"""
"""strided_slice/set in python"""


def strided_slice_python(data, begin, end, strides):
"""Python version of strided slice operator.
@@ -46,3 +47,40 @@ def strided_slice_python(data, begin, end, strides):
end[i] if i < len(end) else None,
strides[i] if i < len(strides) else None))
return data[tuple(slices)]


def strided_set_python(data, v, begin, end, strides):
"""Python version of strided slice operator.

Parameters
----------
data : numpy.ndarray
Input data

v : numpy.ndarray
Value data

begin : list
Beginning of the slices.

end : list
End of the slices.

strides : list
The stride of each slice.

Returns
-------
result : numpy.ndarray
The updated result.
"""
strides = [] if strides is None else strides
slices = []
res = data.copy()
for i in range(len(data.shape)):
slices.append(slice(
begin[i] if i < len(begin) else None,
end[i] if i < len(end) else None,
strides[i] if i < len(strides) else None))
res[tuple(slices)] = v
return res
70 changes: 70 additions & 0 deletions topi/python/topi/transform.py
@@ -20,6 +20,8 @@
import tvm
import topi
from . import cpp
from . import tag
from .util import within_index, make_idx


def expand_dims(a, axis, num_newaxis=1):
@@ -155,6 +157,74 @@ def strided_slice(a, begin, end, strides=None):
strides = []
return cpp.strided_slice(a, begin, end, strides)

@tvm.tag_scope(tag=tag.INJECTIVE+",strided_set")
def strided_set(a, v, begin, end, strides=None):
"""Set slice of an array.

Parameters
----------
a : tvm.Tensor
The tensor to be sliced.

v : tvm.Tensor
The values to set

begin: list of Expr
The indices to begin with in the slicing.

end: list of Expr
Indices indicating end of the slice.

strides: list of Expr, optional
Specifies the stride values; they can be negative,
in which case the input tensor will be reversed
in that particular axis.

Returns
-------
ret : tvm.Tensor
"""
n = len(a.shape)
if strides is None:
strides = [1] * n

if len(begin) != n:
raise ValueError("size mismatch")
if len(end) != n:
raise ValueError("size mismatch")
if len(strides) != n:
raise ValueError("size mismatch")

begin = list(map(tvm.convert, begin))
end = list(map(tvm.convert, end))
strides = list(map(tvm.convert, strides))

def _max(a, b):
return tvm.expr.Select(a > b, a, b)

# Convert negative indexes
for i in range(n):
begin[i] = tvm.if_then_else(begin[i] < 0,
begin[i] + a.shape[i],
Member: Even begin[i] + a.shape[i] could trigger OOB. I am not sure how to assert the bound... Could anyone help here?

Contributor Author (@abergeron, Nov 15, 2019): By OOB do you mean out of bounds of the indexed array? If yes, that is not a problem, because the code will never try to fetch indices that are outside the array shape. This code just tries to handle numpy-style negative indexing (which starts from the end of the array), just like strided_slice does.

Member: Yep, I don't have much idea about this either.

Contributor Author (@abergeron, Nov 17, 2019): Even if the logical result is very similar to strided_slice, the implementation has a difference. In this operation, a[begin:end:stride] = b, the core kernel loops over all valid indices of a and checks whether each index is part of the values selected by the combination of begin, end and stride. If it is, it computes the corresponding index in b and maps the output to that value; otherwise it picks up the value from a at that index. In all cases it doesn't matter whether begin, end or stride falls within the bounds of a, because they are never used to directly or indirectly index into a.

begin[i])
end[i] = tvm.if_then_else(end[i] < 0,
end[i] + a.shape[i],
Member: Same here. OOB is not completely detected.

end[i])

def _select(*indices):
from_val = []
index_tuple = []
for i in range(n):
from_val.append(
within_index(begin[i], end[i], strides[i], indices[i]))
index_tuple.append(
make_idx(begin[i], end[i], strides[i], a.shape[i], indices[i]))
return tvm.if_then_else(tvm.all(*from_val),
v(*index_tuple),
a(*indices))

return tvm.compute(a.shape, _select, name="strided_set")
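
To illustrate the per-element behaviour described in the review thread above, here is a NumPy sketch (the strided_set_reference helper is hypothetical and assumes, for simplicity, that begin/end/strides are already expanded to full rank and follow Python range semantics per axis): every output index is read either from v, when the index falls inside the slice, or from a. The begin/end/strides values only gate membership and map indices into v; they never index into a, which is why out-of-range values cannot cause an out-of-bounds access.

import numpy as np

def strided_set_reference(a, v, begin, end, strides):
    out = np.empty_like(a)
    # one range per axis describing which positions the slice selects
    sel = [range(b, e, s) for b, e, s in zip(begin, end, strides)]
    for idx in np.ndindex(*a.shape):
        if all(int(i) in r for i, r in zip(idx, sel)):
            # the position of idx within the selection gives the index into v
            vidx = tuple(r.index(int(i)) for i, r in zip(idx, sel))
            out[idx] = v[vidx]
        else:
            out[idx] = a[idx]
    return out

a = np.zeros((3, 4), dtype="float32")
v = np.ones((2, 3), dtype="float32")
out = strided_set_reference(a, v, [0, 1], [2, 4], [1, 1])
# out[0:2, 1:4] is all ones; every other element keeps its value from a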


def reshape(a, newshape):
"""Reshape the array
72 changes: 72 additions & 0 deletions topi/python/topi/util.py
@@ -345,3 +345,75 @@ def get_shape(src_shape, src_layout, dst_layout):
tvm.convert([i for i in range(len(src_layout))]))

return get_const_tuple(tuple([src_shape[i.value] for i in dst_indices]))


def within_index(b, e, s, i):
"""Return a boolean value that indicates if i is within the given index.

Parameters
----------
b : Expr
beginning of the index

e : Expr
end of the index

s : Expr
strides of index

i : Expr
array position

Returns
-------
selected: Expr
bool expression that is True if the array position would be selected
by the index and False otherwise
"""
bc = tvm.expr.Select(s < 0, i <= e, i < b)
ec = tvm.expr.Select(s < 0, i > b, i >= e)
ss = tvm.if_then_else(s < 0,
((i - e) + (e % tvm.abs(s)) + 1) % tvm.abs(s),
(i - b) % s)
return tvm.expr.Select(tvm.expr.Or(bc, ec), tvm.const(False), ss.equal(0))
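
For the common positive-stride case the predicate above reduces to a plain slice-membership test. A small Python sketch of just that branch (an illustration, not TVM code; the negative-stride branch uses the modular form shown above):

def within_index_pos(b, e, s, i):
    # i is selected by the extended slice b:e:s (with s > 0) iff it lies in
    # [b, e) and sits an exact multiple of the stride away from b
    return b <= i < e and (i - b) % s == 0

# e.g. begin=1, end=7, stride=2 selects positions 1, 3 and 5
assert [i for i in range(8) if within_index_pos(1, 7, 2, i)] == [1, 3, 5]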


def make_idx(b, e, s, z, i):
"""Return the array position in the selection that corresponds to an
array position in the full array.

The returned value is only meaningful if within_index() returns True
for the same set of parameters.

Parameters
----------
b : Expr
beginning of the index

e : Expr
end of the index

s : Expr
strides of index

z : Expr
size of the indexed dimension

i : Expr
array position

Returns
-------
position: Expr
int expression that corresponds to an array position in the selection.
"""
bc = tvm.expr.Select(s < 0, i <= e, i < b)
ec = tvm.expr.Select(s < 0, i > b, i >= e)

# Clamp to array size
b = tvm.expr.Select(z < b, z - 1, b)

ss = tvm.if_then_else(s < 0,
(b - i) // tvm.abs(s),
(i - b) // s)
return tvm.if_then_else(tvm.expr.Or(bc, ec), 88, ss)
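
And the positive-stride branch of make_idx in plain Python (again only an illustration; the clamp against the array size and the negative-stride branch are omitted). The constant 88 returned for out-of-range positions is an arbitrary don't-care filler: as the docstring notes, the result is only meaningful where within_index() is already True.

def make_idx_pos(b, e, s, i):
    # position of i inside the selection b:e:s (s > 0); 88 is a don't-care
    # value for positions outside the selection
    return (i - b) // s if b <= i < e else 88

# e.g. begin=1, end=7, stride=2: the selected positions 1, 3, 5 map to 0, 1, 2
assert [make_idx_pos(1, 7, 2, i) for i in (1, 3, 5)] == [0, 1, 2]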