Add a combine batch_matmul pass #5791
Conversation
Contrary to what you might expect, this doesn't share as much code with the combine dense pass as it does with the combine 2d conv pass. This is because it concatenates the "output feature" dimensions.
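For concreteness, here is a minimal sketch of the rewrite the pass performs, assuming it is exposed on the Python side as `relay.transform.CombineParallelBatchMatmul`, mirroring the existing combine passes; the shapes and the `min_num_branches` value are illustrative:

```python
import tvm
from tvm import relay

# Two parallel batch_matmuls sharing the same LHS. In Relay,
# nn.batch_matmul(x, w) computes x @ w^T per batch, so with
# x: (8, 16, 32) and w1, w2: (8, 24, 32), each output is (8, 16, 24).
x = relay.var("x", shape=(8, 16, 32))
w1 = relay.var("w1", shape=(8, 24, 32))
w2 = relay.var("w2", shape=(8, 24, 32))
y1 = relay.nn.batch_matmul(x, w1)
y2 = relay.nn.batch_matmul(x, w2)
mod = tvm.IRModule.from_expr(
    relay.Function([x, w1, w2], relay.Tuple([y1, y2])))

# The pass rewrites the two branches into a single batch_matmul whose
# RHS is w1 and w2 concatenated along the output-feature axis, giving
# one (8, 16, 48) product that is then sliced back into y1 and y2.
mod = relay.transform.CombineParallelBatchMatmul(min_num_branches=2)(mod)
print(mod)
```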
This might be off topic, but I think it's possible to use the pattern language to replace all of the "combine parallel XX op" passes. Maybe we could create an issue to look into it.
@t-vi Interesting PR! Thank you for submitting it! I presume the use case for this is in Transformer-like models? Do you see a perf benefit from the rewrite?
I would imagine so; it would be pretty easy to implement this as a pattern and a rewrite. A pattern-based solution might have one complication, though: I don't think we currently have a way to match 2, 3, or 4 branches with a single pattern, so it would require a number of patterns. I'll think about that.
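As a hypothetical illustration of that limitation, here is roughly what a pattern for exactly two parallel branches could look like; this is a sketch only, and each additional branch count would need its own pattern:

```python
from tvm.relay.dataflow_pattern import is_op, wildcard

# Hypothetical sketch for exactly two branches: the shared LHS is one
# wildcard, and each branch is a batch_matmul on it with its own weight.
data = wildcard()
branch_a = is_op("nn.batch_matmul")(data, wildcard())
branch_b = is_op("nn.batch_matmul")(data, wildcard())

# The missing piece is a variadic construct tying N such branches
# together in one pattern; matching three or four parallel branches
# would need separate patterns (or a callback per branch count), and
# the rewrite would concatenate the weights and slice the result.
```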
@mbrookhart Yes, my use case is transformers. The PyTorch frontend translates the matmul used in HuggingFace models to batch_matmul. I imagine that it would be cool to move the pass to pattern matching. I would expect that it would replace the code currently shared by the combine passes for conv2d, dense, and batch_matmul.
LGTM
cc @vinx13 please also help to take a look
* Add a combine batch_matmul pass. Contrary to what you might expect, this doesn't share as much code with the combine dense pass as it does with the combine 2d conv pass. This is because it concatenates the "output feature" dimensions. * fix docstring
Contrary to what you might expect, this doesn't share as much code with the combine dense pass as it does with the combine 2d conv pass. This is because it concatenates the "output feature" dimensions just like the 2d conv pass concatenates output channels, whereas combine dense stacks the various matmul arguments.
I'm not sure if there is a deeper reason not to concatenate for dense, too, but maybe there is.
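As a sanity check of the equivalence the concatenation relies on, here is a small numpy illustration (plain numpy, not TVM code; the shapes are arbitrary). Concatenating the transposed weights along the output-feature axis and doing one matmul gives the same result as running the matmuls separately and concatenating the outputs:

```python
import numpy as np

x = np.random.randn(8, 16, 32)   # (batch, M, K)
w1 = np.random.randn(8, 24, 32)  # (batch, N, K), batch_matmul convention
w2 = np.random.randn(8, 24, 32)

# One matmul over weights concatenated along the output-feature axis.
combined = x @ np.concatenate([w1, w2], axis=1).transpose(0, 2, 1)

# Two separate matmuls, outputs concatenated afterwards.
separate = np.concatenate([x @ w1.transpose(0, 2, 1),
                           x @ w2.transpose(0, 2, 1)], axis=2)

assert np.allclose(combined, separate)  # both are (8, 16, 48)
```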