[Torch] Various updates for PyTorch frontend #7348
LGTM
Looks good!
Since torch.sort returns both sorted values and indices while the Relay one doesn't, torch.sort conversion is not efficient, especially for multidimensional input (currently does both sort and argsort!). Suggestions for a better implementation are welcome.
I think gather is the direct equivalent of how you use take on 1d. Maybe it's worth fixing this in this PR.
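To make the suggestion above concrete, here is a minimal sketch of the idea: compute the sort indices once with an argsort, then gather the values along the same axis, so values and indices come from a single sort. This is not the code in this PR; it uses plain Python lists as a stand-in for tensors, and the helper names (`argsort_last_axis`, `gather_last_axis`, `sort_with_indices`) are hypothetical.

```python
def argsort_last_axis(rows):
    """For each row, return the indices that would sort that row ascending."""
    return [sorted(range(len(row)), key=row.__getitem__) for row in rows]


def gather_last_axis(rows, indices):
    """Pick rows[r][indices[r][c]] for every position, like a gather on the last axis."""
    return [[row[i] for i in idx] for row, idx in zip(rows, indices)]


def sort_with_indices(rows):
    """Return (sorted values, indices), using one argsort plus a gather."""
    indices = argsort_last_axis(rows)
    values = gather_last_axis(rows, indices)
    return values, indices


values, indices = sort_with_indices([[3, 1, 2], [9, 7, 8]])
print(values)   # sorted values per row
print(indices)  # positions of those values in the original rows
```

For 1-D input the gather step reduces to a plain `take`, which is why the two are interchangeable there; for multidimensional input the per-row indices need a gather along the sorted axis.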
@t-vi Thanks, I think I tried
Thanks @siju-samuel @t-vi
* add conversion for detr
* remove explicit broadcast_to before batched matmul
* use take with wrap mode
* add test for transformer and negative indices
* add sort and argsort
* add logical_and
* support masked_select
* add gpu targets to masked_select test
* improve sort conversion
This PR adds various updates to the PyTorch frontend to fully support DETR, the recent transformer-based object detection model from Facebook (https://github.com/facebookresearch/detr). After this PR, DETR runs on TVM and produces the correct results. In my environment, TVM with auto-scheduled GPU conv2d and batched matmul is 1.4x faster than PyTorch.
This PR also adds various missing ops reported by the hummingbird project. @interesaaat
* Add `logical_and`, `masked_select`, `sort` and `argsort`. `masked_select` requires the VM to run.
* Add `cumsum` and `masked_fill` ops, to enable importing DETR (https://github.com/facebookresearch/detr). `cumsum` was also requested by hummingbird.
* Use `mode="wrap"` in the Relay `take` op. Without this, the result doesn't match DETR.
* Remove the explicit `broadcast_to` before `batch_matmul`. With the explicit broadcast, memory usage blows up during constant evaluation of DETR. This is the same issue as "[Topi] Allow batch_matmul to broadcast along batch dimension" (#6616); please see the explanation there.
* Add `torch.nn.Transformer` to the tests.

Since `torch.sort` returns both sorted values and indices while the Relay one doesn't, the `torch.sort` conversion is not efficient, especially for multidimensional input (currently does both sort and argsort!). Suggestions for a better implementation are welcome. UPDATE: fixed

please review @siju-samuel @jwfromm @kevinthesun @t-vi
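As an illustration of why the `take` mode matters for negative indices, the sketch below contrasts wrap-style and clip-style index handling on a plain Python list. It is not TVM code: `take_wrap` and `take_clip` are hypothetical helpers, and the assumption is that wrap mode reduces each index modulo the axis length while clip mode clamps it into the valid range.

```python
def take_wrap(data, indices):
    """Wrap mode: out-of-range indices are reduced modulo len(data)."""
    n = len(data)
    return [data[i % n] for i in indices]


def take_clip(data, indices):
    """Clip mode: out-of-range indices are clamped to [0, len(data) - 1]."""
    n = len(data)
    return [data[min(max(i, 0), n - 1)] for i in indices]


data = [10, 20, 30, 40]
idx = [-1, 0, 5]

print(take_wrap(data, idx))  # -1 selects the last element, as in PyTorch indexing
print(take_clip(data, idx))  # -1 is clamped to index 0 instead
```

Under these assumptions, an index of `-1` selects the last element only in wrap mode, which matches PyTorch's negative indexing; with clamping, the same index silently selects the first element and the outputs diverge.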