-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Frontend][Tensorflow]add batch_dim support for gatherV2 #7951
Conversation
@comaniac @icemelon9 @yongwww @kevinthesun |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM. Just nits.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@zxy844288792 I notice you haven't update the gradient for take to accept batch_dims. Could you either update the gradient or add a check that batch_dims = 0 in the gradient? |
I just added a check for batch_dims == 0 in the gradient |
Thanks @zxy844288792 @tkonolige @yongwww |
* add batch_dim support * fix lint * add check for num of arguments for topi.take * fix gpu test cases * add check for batch_dims in take_grad
* add batch_dim support * fix lint * add check for num of arguments for topi.take * fix gpu test cases * add check for batch_dims in take_grad
* add batch_dim support * fix lint * add check for num of arguments for topi.take * fix gpu test cases * add check for batch_dims in take_grad
* add batch_dim support * fix lint * add check for num of arguments for topi.take * fix gpu test cases * add check for batch_dims in take_grad
Encounter a special cases when batch_dims=1 in gather() from centernet_hourglass_512x512_1 from tensorflow hub model zoo.
Implement the batch_dims logic according to tensorflow implementation: https://www.tensorflow.org/api_docs/python/tf/gather
https://github.com/tensorflow/tensorflow/blob/5dcfc51118817f27fad5246812d83e5dccdc5f72/tensorflow/core/kernels/gather_op.cc
I have not added testcases for topi and relay since numpy does not have that attribute, I am open to see any suggestion for that. Test cases for tensorflow frontend parser have been added.
Thanks for contributing to TVM! Please refer to guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from Reviewers by @ them in the pull request thread.