-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay, TOPI] Add searchsorted op #9184
Conversation
8bb70f2
to
2e178d0
Compare
I'll try to get to this monday. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few comments, correctness looks good to me mostly but might need another pair of eyes 👀
TVM_ATTR_FIELD(side).set_default("left").describe( | ||
"Controls which index is returned if a value lands exactly on one of sorted values."); | ||
TVM_ATTR_FIELD(dtype) | ||
.set_default(DataType::Int(32)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmmm, just curious if there is any convention on the dtype of indices, there is a lot of index code with dyn gather I believe has all the indices in Int(64). Int(64) might be a better default.
The other attributes in this file have NullValue<DataType>()
as the default value which is interesting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, other ops have NullValue<DataType>()
as the default here, but if we look at the python definition at https://github.com/apache/tvm/blob/main/python/tvm/relay/op/algorithm.py#L47, they say the default is int32. So I thought we should make that explicit in attrs/algorithm.h
as well.
python/tvm/topi/searchsorted.py
Outdated
|
||
|
||
def binary_search( | ||
ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, side, out_dtype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might want to provide a brief docstring on these variables since they are not immediately obvious to me and this uses mutation which is kind of odd
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added some description, let me know if things are not clear
python/tvm/topi/searchsorted.py
Outdated
with ib.else_scope(): | ||
hi[0] = mid | ||
|
||
out_indices[index] = lo[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we just return lo[0] and set out_indices[index] = ... below? Might be more reusable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Due to some peculiarity in the IR builder, that doesn't work on the vulkan target. It works fine on llvm, but on vulkan I get all zero output:
Mismatched elements: 149857 / 150000 (99.9%)
Max absolute difference: 1024
Max relative difference: 1.
x: array([[[[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]],...
y: array([[[[ 116, 195, 291, ..., 338, 196, 890],
[ 609, 93, 659, ..., 977, 563, 693],
[ 675, 922, 53, ..., 1019, 429, 486]],...
I'll test on the cuda target too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah this is very interesting, let's just add a comment and move on.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My bad, when I tested the change above, I forget to update the GPU definition in topi/cuda/searchsorted.py
to assign the returned index to the output buffer. No wonder I got all zero output!
I fixed my mistake and it works on vulkan as well. Not only it removed in-place mutation, it also removed some arguments from binary_search
. Things look much cleaner now, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
pinging for review. |
Can somebody take a look at this PR? I believe it should be ready to go. |
* Add relay definition * 1D cpu test working * multi dim working * gpu version working * check shape in type rel * support side * use target specfic max threads * add relay boilerplate * relay test working * cleanup topi test * fix test * add torch converter * handle other cases * more topi test * support torch bucketize * update doc * fix tests * fix lint * rebase fix * make the test case smaller * add tests for edge cases * replace "side" attribute with boolean "right" * add more descrition to binear_search IR gen params * return index from binary_search rather than update inplace * remove unused argument * format fix
* Add relay definition * 1D cpu test working * multi dim working * gpu version working * check shape in type rel * support side * use target specfic max threads * add relay boilerplate * relay test working * cleanup topi test * fix test * add torch converter * handle other cases * more topi test * support torch bucketize * update doc * fix tests * fix lint * rebase fix * make the test case smaller * add tests for edge cases * replace "side" attribute with boolean "right" * add more descrition to binear_search IR gen params * return index from binary_search rather than update inplace * remove unused argument * format fix
This adds
searchsorted
op supported by numpy, PyTorch, and TF. This is a simple op that runs many binary search in parallel. The same op can be used for numpydigitize
and PTbucketize
op.Numpy
searchsorted
op only supports 1Dsorted_sequence
, but both PT and TF generalize it for N-D. Our op also supports N-Dsorted_sequence
.References:
https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html
https://numpy.org/doc/stable/reference/generated/numpy.digitize.html
https://pytorch.org/docs/stable/generated/torch.searchsorted.html
https://www.tensorflow.org/api_docs/python/tf/searchsorted
cc @mbrookhart @AndrewZhaoLuo @interesaaat