
[Relay][ONNX] Batch_matmul to dense optimization #8440

Merged (2 commits) on Jul 13, 2021

Conversation

@ymwangg (Contributor) commented on Jul 9, 2021

This PR ports the PyTorch frontend's matmul implementation, which applies the batch_matmul-to-dense optimization. This significantly improves the performance of some matmul ops in BERT models, such as [4, 128, 768] x [768, 768], when using cuBLAS or TensorRT.
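The idea behind the optimization can be sketched outside TVM as a NumPy equivalence check (shapes and variable names here are illustrative, not from the PR): when the right-hand operand is a static 2-D weight, a batched matmul is equivalent to flattening the batch and row dimensions into one, running a single 2-D GEMM (what `nn.dense` lowers to), and reshaping back.

```python
import numpy as np

# Batched input [batch, M, K] and a static 2-D weight [K, N],
# matching the [4, 128, 768] x [768, 768] case from the PR description.
x = np.random.rand(4, 128, 768).astype("float32")
w = np.random.rand(768, 768).astype("float32")

# Naive path: batch_matmul broadcasts the weight across the batch dim.
batched = np.matmul(x, w)                      # shape [4, 128, 768]

# Optimized path: collapse [batch, M] into one dim, run a single GEMM,
# then restore the original batch shape.
flat = x.reshape(-1, x.shape[-1])              # shape [512, 768]
dense_out = flat @ w                           # one 2-D GEMM, [512, 768]
restored = dense_out.reshape(x.shape[0], x.shape[1], -1)

assert np.allclose(batched, restored, atol=1e-5)
```

A single large GEMM is generally friendlier to libraries like cuBLAS than many small batched multiplies, which is where the reported speedup comes from.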

@comaniac (Contributor) left a comment:

Otherwise LGTM. Do we need a test case for this?

(Review threads on python/tvm/relay/frontend/onnx.py, resolved)
@ymwangg (Contributor, Author) commented on Jul 10, 2021

@comaniac Thanks for the valuable comments. I found that the PyTorch implementation https://github.com/apache/tvm/blob/main/python/tvm/relay/frontend/pytorch.py#L1624 cannot handle dynamic shapes, so I added this optimization based on the ONNX implementation. For some reason, the weight matrix of nn.dense must be static for the TVM codegen to work correctly, although using libraries like MKL works just fine. I created issue #8441 for this.

@comaniac (Contributor) left a comment:

LGTM. Just a nit.

@comaniac (Contributor) left a comment:

LGTM

@comaniac merged commit 136f218 into apache:main on Jul 13, 2021

@comaniac (Contributor) commented: Thanks @ymwangg

ylc pushed a commit to ylc/tvm that referenced this pull request Sep 29, 2021
* [ONNX]Add batch_matmul to dense optimization

* Add extra check to avoid unnecessary reshape

Co-authored-by: Ubuntu <ubuntu@ip-172-31-14-16.us-west-2.compute.internal>
zxy844288792 pushed a commit to zxy844288792/tvm that referenced this pull request Mar 4, 2022
* [ONNX]Add batch_matmul to dense optimization

* Add extra check to avoid unnecessary reshape

Co-authored-by: Ubuntu <ubuntu@ip-172-31-14-16.us-west-2.compute.internal>