-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TOPI][Relay][OP] Add a strided_set operation. #4303
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
4e83927
Add the strided_set function in TOPI
abergeron e2577f1
Add the c++ parts for strided_set in Relay
abergeron c4fc34e
Add the python parts of strided_set
abergeron cc08d52
WIP tests relay
abergeron 721086d
Fixes relay
abergeron ba33a9b
Trigger CI
abergeron 3bc2bf5
Fixes from review
abergeron 77f8462
Change to use dynamic indexes
abergeron d5b23fe
Make the indexes fully dynamic
abergeron 3b5f0c2
Fix lint
abergeron File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,8 @@ | |
import tvm | ||
import topi | ||
from . import cpp | ||
from . import tag | ||
from .util import within_index, make_idx | ||
|
||
|
||
def expand_dims(a, axis, num_newaxis=1): | ||
|
@@ -155,6 +157,97 @@ def strided_slice(a, begin, end, strides=None): | |
strides = [] | ||
return cpp.strided_slice(a, begin, end, strides) | ||
|
||
@tvm.tag_scope(tag=tag.INJECTIVE+",strided_set") | ||
def strided_set(a, v, begin, end, strides=None): | ||
"""Set slice of an array. | ||
|
||
Parameters | ||
---------- | ||
a : tvm.Tensor | ||
The tensor to be sliced. | ||
|
||
v : tvm.Tensor | ||
The values to set | ||
|
||
begin: tvm.Tensor | ||
The indices to begin with in the slicing. | ||
|
||
end: tvm.Tensor | ||
Indicies indicating end of the slice. | ||
|
||
strides: tvm.Tensor, optional | ||
Specifies the stride values, it can be negative | ||
in that case, the input tensor will be reversed | ||
in that particular axis. | ||
|
||
Returns | ||
------- | ||
ret : tvm.Tensor | ||
""" | ||
n = len(a.shape) | ||
|
||
if len(begin.shape) != 1: | ||
raise ValueError("begin should be a vector") | ||
if not begin.dtype == 'int32': | ||
raise TypeError("begin should be int32") | ||
if len(end.shape) != 1: | ||
raise ValueError("end should be a vector") | ||
if not end.dtype == 'int32': | ||
raise TypeError("end should be int32") | ||
if strides is not None: | ||
if len(strides.shape) != 1: | ||
raise ValueError("strides should be a vector") | ||
if not strides.dtype == 'int32': | ||
raise TypeError("strides should be int32") | ||
|
||
def _max(a, b): | ||
return tvm.expr.Select(a > b, a, b) | ||
|
||
if strides is None: | ||
strides = [tvm.const(1, 'int32')] * n | ||
else: | ||
strides = [tvm.if_then_else(strides.shape[0] > i, | ||
strides[i], | ||
tvm.const(1, 'int32')) | ||
for i in range(n)] | ||
|
||
begin = [tvm.if_then_else(begin.shape[0] > i, | ||
begin[i], | ||
tvm.expr.Select(strides[i] > 0, | ||
tvm.const(0, 'int32'), | ||
a.shape[i])) | ||
for i in range(n)] | ||
end = [tvm.if_then_else(end.shape[0] > i, | ||
end[i], | ||
tvm.expr.Select(strides[i] > 0, | ||
a.shape[i] + 1, | ||
-(a.shape[i] + 1))) | ||
for i in range(n)] | ||
|
||
|
||
# Convert negative indexes | ||
for i in range(n): | ||
begin[i] = tvm.if_then_else(begin[i] < 0, | ||
begin[i] + a.shape[i], | ||
begin[i]) | ||
end[i] = tvm.if_then_else(end[i] < 0, | ||
end[i] + a.shape[i], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. OOB is not completely detected. |
||
end[i]) | ||
|
||
def _select(*indices): | ||
from_val = [] | ||
index_tuple = [] | ||
for i in range(n): | ||
from_val.append( | ||
within_index(begin[i], end[i], strides[i], indices[i])) | ||
index_tuple.append( | ||
make_idx(begin[i], end[i], strides[i], a.shape[i], indices[i])) | ||
return tvm.if_then_else(tvm.all(*from_val), | ||
v(*index_tuple), | ||
a(*indices)) | ||
|
||
return tvm.compute(a.shape, _select, name="strided_set") | ||
|
||
|
||
def reshape(a, newshape): | ||
"""Reshape the array | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even
begin[i] + a.shape[i]
could trigger OOB. I am not sure how to assert the bound...Could anyone help here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By OOB do you mean out of bounds of the indexed array? If yes, that is not a problem because the code will never try to fetch indices that are out the array shape.
This code just tries to handle numpy-style negative indexing (starts from the end of the array) just like strided_slice does.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, I don’t have much idea about this either
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even if the logical result is very similar to strided_slice the implementation has a difference.
In this operation:
a[begin:end:stride] = b
The core kernel loops over all valid indexes for
a
and check if that index is part of the values selected by the combination ofbegin
,end
andstride
. If it is, it will compute the corresponding index inb
and map the output to that value. Otherwise it will pick up the value froma
at that index.In all cases it doesn't matter if
begin
,end
, orstride
doesn't fall within the bounds ofa
because they are never used to directly or indirectly index intoa
.