[Relay, TOPI] Complete rewrite of where op to support broadcasting #6759

Merged: 13 commits, Oct 28, 2020

masahi
Member

@masahi masahi commented Oct 25, 2020

This is a follow up to the discussion in #6383.

Recently I hit a model (gpt2 from transformers) that has a `where` op with the following shapes:

  • condition: (1, 1, 12, 12)
  • x (true case): (1, 8, 12, 12)
  • y (false case): scalar

Such combinations of shapes cannot be supported by yet another ad hoc band-aid, so I decided to implement full support for NumPy-style broadcasting. This turned out to be easy thanks to the existing broadcasting utilities in relay/topi. I think this is a desirable feature, since it aligns us with other frameworks that follow the NumPy-style `where` op (PyTorch, ONNX, etc.).
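For reference, the shape combination listed above can be reproduced directly with NumPy, whose broadcasting semantics this PR follows. This is a minimal sketch and not code from the PR: the random inputs and the scalar fill value `-1e4` are assumptions chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Shapes from the PR description; the values themselves are made up.
cond = rng.random((1, 1, 12, 12)) > 0.5  # condition: (1, 1, 12, 12)
x = rng.random((1, 8, 12, 12))           # true case: (1, 8, 12, 12)
y = -1e4                                 # false case: scalar (hypothetical value)

# All three operands are broadcast to a common shape before selection.
out = np.where(cond, x, y)
print(out.shape)  # (1, 8, 12, 12)

# The same result, written with the condition broadcast explicitly.
cond_b = np.broadcast_to(cond, out.shape)
assert np.array_equal(out[cond_b], x[cond_b])
assert np.all(out[~cond_b] == y)
```

The condition is stretched along the second axis and the scalar along every axis, which is the case the earlier special-cased handling could not express.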

I added tests that replicate the examples in the NumPy `where` documentation: https://numpy.org/doc/stable/reference/generated/numpy.where.html. I think my implementation is correct, but since this is my first attempt at implementing an op with broadcasting semantics, reviews from people who are familiar with broadcasting issues are highly appreciated.

With this PR, I can load the gpt2 model from transformers using the PyTorch frontend, compile it, and get outputs that match PyTorch.

Please review @zhiics @junrushao1994 @t-vi @kevinthesun @jwfromm @tqchen

@junrushao
Member

Nice work @masahi!

Member

@junrushao junrushao left a comment

looks good!

src/relay/op/tensor/transform.cc (outdated review thread, resolved)
Contributor

@t-vi t-vi left a comment

Thank you @masahi, looks good. ❤️ moving this to the full broadcasting.

include/tvm/topi/broadcast.h (outdated review thread, resolved)
@masahi
Member Author

masahi commented Oct 27, 2020

@t-vi In d8c5076, I ported the C++ broadcast_shape_tensors function to hybridscript so that it can generate a runtime assertion for the broadcasting validity check.

Now invalid shapes are caught at runtime and result in an error. I added one test that intentionally runs a compiled model with an invalid combination of shapes and verifies that the invalid broadcast is properly caught. Please have a look!
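To illustrate the kind of validity check described above, here is a minimal sketch in plain Python. It is not the hybridscript shape function from this PR: the name `broadcast_shape` and the use of `ValueError` are assumptions made for the example. It applies the usual NumPy rule of aligning shapes from the trailing dimension, where each pair of dimensions must be equal or one of them must be 1.

```python
def broadcast_shape(*shapes):
    """Return the NumPy-style broadcast of `shapes`, or raise ValueError.

    Hypothetical helper for illustration; not part of TVM.
    """
    ndim = max((len(s) for s in shapes), default=0)
    out = []
    # Walk the axes from the trailing dimension towards the leading one.
    for axis in range(1, ndim + 1):
        dim = 1
        for s in shapes:
            if axis > len(s):
                continue  # a missing leading dimension behaves like 1
            d = s[-axis]
            if d == 1 or d == dim:
                continue  # compatible with what has been seen so far
            if dim != 1:
                raise ValueError(
                    f"shapes {shapes} are not broadcastable: {d} vs {dim}"
                )
            dim = d
        out.append(dim)
    return tuple(reversed(out))


# The gpt2 case from the PR description: condition, x, and a scalar y.
print(broadcast_shape((1, 1, 12, 12), (1, 8, 12, 12), ()))  # (1, 8, 12, 12)

# An invalid combination is rejected instead of silently producing output.
try:
    broadcast_shape((3, 4), (2, 4))
except ValueError as err:
    print("rejected:", err)
```

A test for the failure path can follow the same pattern: feed a deliberately incompatible set of shapes and assert that an error is raised.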

python/tvm/topi/broadcast.py (outdated review thread, resolved)
@t-vi
Contributor

t-vi commented Oct 27, 2020

Looks awesome to me. Thank you @masahi!

@masahi
Member Author

masahi commented Oct 27, 2020

@kevinthesun I've updated the `where` shape function, and I think it is ready to merge. Can you take a final look?

Contributor

@kevinthesun kevinthesun left a comment

LGTM

@kevinthesun kevinthesun merged commit f092e1d into apache:main Oct 28, 2020
@kevinthesun
Contributor

Thanks @masahi @junrushao1994 @t-vi

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Oct 29, 2020
…pache#6759)

* where type rel with broadcast

* add tests for where with broadcast

* clean up tests

* uncomment other tests

* add more tests

* update doc

* CHECK -> ICHECK

* add where any test

* fix format

* remove useless detections for one

* set manual seed

* ported shape broadcast helper func to hybridscript

* remove shape function helper from cpp

Co-authored-by: masa <masa@pop-os.localdomain>
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Dec 2, 2020
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Dec 4, 2020
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Dec 4, 2020
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Mar 31, 2021