[Relay, TOPI] Complete rewrite of where op to support broadcasting #6759
Nice work @masahi!
looks good!
Thank you @masahi, looks good. ❤️ moving this to the full broadcasting.
@t-vi In d8c5076, I ported the C++ shape broadcast helper to hybrid script. Now, invalid shapes are caught at runtime and result in an error. I added one test that intentionally tries to run a compiled model with an invalid combination of shapes and verifies that invalid broadcasts are properly caught. Please have a look!
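As a rough illustration of the kind of runtime check described in this comment, here is a minimal pure-Python sketch of NumPy-style shape broadcasting for the three inputs of `where`. It is not the hybrid script shape function from this PR; the name `broadcast_shape` and the example shapes are hypothetical and chosen only for illustration.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Return the NumPy-style broadcast of several shapes (hypothetical helper).

    Raises ValueError when the shapes cannot be broadcast together.
    """
    out = []
    # Align shapes on their trailing axis; missing leading axes count as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        # Along one axis, every extent that is not 1 must be identical.
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"cannot broadcast shapes {shapes}")
        out.append(sizes.pop() if sizes else 1)
    return tuple(reversed(out))


# Illustrative shapes for condition, x and y (assumed, not taken from the PR).
cond_shape, x_shape, y_shape = (3, 1), (1, 4), (4,)
print(broadcast_shape(cond_shape, x_shape, y_shape))

# An incompatible combination is rejected instead of silently producing output.
try:
    broadcast_shape((3,), (4,), (1,))
except ValueError as err:
    print("rejected:", err)
```

With dynamic shapes the concrete extents are only known when the model runs, which is why a check of this kind has to happen at runtime rather than at compile time.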
Looks awesome to me. Thank you @masahi!
@kevinthesun I've updated |
LGTM
Thanks @masahi @junrushao1994 @t-vi
…pache#6759) * where type rel with broadcast * add tests for where with broadcast * clean up tests * uncomment other tests * add more tests * update doc * CHECK -> ICHECK * add where any test * fix format * remove useless detections for one * set manual seed * ported shape broadcast helper func to hybridscript * remove shape function helper from cpp Co-authored-by: masa <masa@pop-os.localdomain>
This is a follow-up to the discussion in #6383.
Recently I hit a model (`gpt2` from transformers) that has a `where` op with the following shapes:

It is not possible to support such combinations of shapes with yet another ad hoc band-aid, so I decided to implement full support for NumPy-style broadcasting. This turned out to be easy thanks to the existing broadcasting utilities in Relay/TOPI. I think this is a desirable feature that aligns with other frameworks that follow the NumPy-style `where` op (PyTorch, ONNX, etc.).
I added tests that replicate the examples in the NumPy `where` documentation: https://numpy.org/doc/stable/reference/generated/numpy.where.html. I think my implementation is correct, but since this is my first attempt at implementing an op with broadcasting semantics, reviews from people who are familiar with broadcasting issues would be highly appreciated.
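For readers unfamiliar with the semantics being targeted, the broadcasting example from the NumPy documentation linked above is sketched below in plain NumPy. This is only a reference for the expected behavior, not the test code added in this PR.

```python
import numpy as np

# x has shape (3, 1) and y has shape (1, 4); np.ogrid builds open grids.
x, y = np.ogrid[:3, :4]

# condition, x and 10 + y are all broadcast to the common shape (3, 4).
out = np.where(x < y, x, 10 + y)

print(out.shape)
print(out)
```

The condition `x < y` already has the broadcast shape `(3, 4)`, while the two value arguments are broadcast up to it element by element.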
With this PR, I can load the `gpt2` model from transformers using the PyTorch frontend, compile it, and get correct outputs compared to PyTorch.

Please review @zhiics @junrushao1994 @t-vi @kevinthesun @jwfromm @tqchen