-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TOPI][Relay][OP] Add a strided_set operation. #4303
Changes from 7 commits
4e83927
e2577f1
c4fc34e
cc08d52
721086d
ba33a9b
3bc2bf5
77f8462
d5b23fe
3b5f0c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,6 +48,7 @@ | |
_reg.register_schedule("cast_like", schedule_injective) | ||
_reg.register_schedule("reinterpret", schedule_injective) | ||
_reg.register_schedule("strided_slice", schedule_injective) | ||
_reg.register_schedule("strided_set", schedule_injective) | ||
_reg.register_schedule("slice_like", schedule_injective) | ||
_reg.register_schedule("split", schedule_injective) | ||
_reg.register_schedule("take", schedule_injective) | ||
|
@@ -304,6 +305,28 @@ def compute_argwhere(attrs, inputs, output_type, _): | |
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32") | ||
return [topi.argwhere(new_output_type, inputs[0])] | ||
|
||
@_reg.register_compute("strided_set") | ||
def compute_strided_set(attrs, inputs, output_type, _): | ||
"""Compute definition of strided_set""" | ||
begin = attrs.begin | ||
end = attrs.end | ||
strides = attrs.strides | ||
n = len(inputs[0].shape) | ||
strides = list(strides) + [1] * (n - len(strides)) | ||
lb = len(begin) | ||
if lb < n: | ||
begin = list(begin) | ||
for i in range(lb, n): | ||
begin.append(0 if strides[i] >= 0 else inputs.shape[i]) | ||
le = len(end) | ||
if le < n: | ||
end = list(end) | ||
for i in range(le, n): | ||
lim = inputs[0].shape[i] + 1 | ||
end.append(lim if strides[i] >= 0 else -lim) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as for begin. |
||
|
||
return [topi.strided_set(inputs[0], inputs[1], begin, end, strides)] | ||
|
||
@script | ||
def _layout_transform_shape_func(data_shape, | ||
out_layout_len, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -631,6 +631,36 @@ def strided_slice(data, begin, end, strides=None): | |
return _make.strided_slice(data, list(begin), list(end), list(strides)) | ||
|
||
|
||
def strided_set(data, v, begin, end, strides=None): | ||
"""Strided set of an array. | ||
|
||
Parameters | ||
---------- | ||
data : relay.Expr | ||
The source array to be sliced. | ||
|
||
v : relay.Expr | ||
The data to be set. | ||
|
||
begin: list of int | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for adding this op. Currently @yongwww is modifying strided_slice(#4312) to support begin, end and strides to be expression instead of just list of int. The reason is that in some DL frameworks, begin, end or strides can be a tensor. Also making it more dynamic can help us when building other ops, such as NMS. Considering this op is similar to strided_slice, should we keep it align with stride_slice, and allow begin, end and strides to be Expr? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That could be done, yes. The underlying TOPI op does support TVM expressions so it shouldn't be too hard to do. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is now done. |
||
The indices to begin with in the slicing. | ||
|
||
end: list of int | ||
Indices indicating end of the slice. | ||
|
||
strides: list of int, optional | ||
Specifies the stride values, it can be negative in that case, | ||
the input tensor will be reversed in that particular axis. | ||
|
||
Returns | ||
------- | ||
ret : relay.Expr | ||
The computed result. | ||
""" | ||
strides = strides or [] | ||
return _make.strided_set(data, v, list(begin), list(end), list(strides)) | ||
|
||
|
||
def slice_like(data, shape_like, axes=None): | ||
"""Slice the first input with respect to the second input. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,8 @@ | |
import tvm | ||
import topi | ||
from . import cpp | ||
from . import tag | ||
from .util import within_index, make_idx | ||
|
||
|
||
def expand_dims(a, axis, num_newaxis=1): | ||
|
@@ -155,6 +157,74 @@ def strided_slice(a, begin, end, strides=None): | |
strides = [] | ||
return cpp.strided_slice(a, begin, end, strides) | ||
|
||
@tvm.tag_scope(tag=tag.INJECTIVE+",strided_set") | ||
def strided_set(a, v, begin, end, strides=None): | ||
"""Set slice of an array. | ||
|
||
Parameters | ||
---------- | ||
a : tvm.Tensor | ||
The tensor to be sliced. | ||
|
||
v : tvm.Tensor | ||
The values to set | ||
|
||
begin: list of Expr | ||
The indices to begin with in the slicing. | ||
|
||
end: list of Expr | ||
Indicies indicating end of the slice. | ||
|
||
strides: list of Expr, optional | ||
Specifies the stride values, it can be negative | ||
in that case, the input tensor will be reversed | ||
in that particular axis. | ||
|
||
Returns | ||
------- | ||
ret : tvm.Tensor | ||
""" | ||
n = len(a.shape) | ||
if strides is None: | ||
strides = [1] * n | ||
|
||
if len(begin) != n: | ||
raise ValueError("size mismatch") | ||
if len(end) != n: | ||
raise ValueError("size mismatch") | ||
if len(strides) != n: | ||
raise ValueError("size mismatch") | ||
|
||
begin = list(map(tvm.convert, begin)) | ||
end = list(map(tvm.convert, end)) | ||
strides = list(map(tvm.convert, strides)) | ||
|
||
def _max(a, b): | ||
return tvm.expr.Select(a > b, a, b) | ||
|
||
# Convert negative indexes | ||
for i in range(n): | ||
begin[i] = tvm.if_then_else(begin[i] < 0, | ||
begin[i] + a.shape[i], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Even There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By OOB do you mean out of bounds of the indexed array? If yes, that is not a problem because the code will never try to fetch indices that are out the array shape. This code just tries to handle numpy-style negative indexing (starts from the end of the array) just like strided_slice does. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, I don’t have much idea about this either There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Even if the logical result is very similar to strided_slice the implementation has a difference. In this operation: The core kernel loops over all valid indexes for In all cases it doesn't matter if |
||
begin[i]) | ||
end[i] = tvm.if_then_else(end[i] < 0, | ||
end[i] + a.shape[i], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. OOB is not completely detected. |
||
end[i]) | ||
|
||
def _select(*indices): | ||
from_val = [] | ||
index_tuple = [] | ||
for i in range(n): | ||
from_val.append( | ||
within_index(begin[i], end[i], strides[i], indices[i])) | ||
index_tuple.append( | ||
make_idx(begin[i], end[i], strides[i], a.shape[i], indices[i])) | ||
return tvm.if_then_else(tvm.all(*from_val), | ||
v(*index_tuple), | ||
a(*indices)) | ||
|
||
return tvm.compute(a.shape, _select, name="strided_set") | ||
|
||
|
||
def reshape(a, newshape): | ||
"""Reshape the array | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a bit different because the value to append might be different at each index. It depends on the stride for that index. I'm not sure how to shorten that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It’s fine, no worries :-)