-
Notifications
You must be signed in to change notification settings - Fork 57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rewrite rules implementation for LLaMA-2/ LLaMA-3 #1811
base: main
Are you sure you want to change the base?
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1811 +/- ##
==========================================
- Coverage 75.95% 73.50% -2.45%
==========================================
Files 228 248 +20
Lines 24246 26893 +2647
Branches 4201 4915 +714
==========================================
+ Hits 18416 19768 +1352
- Misses 5035 6161 +1126
- Partials 795 964 +169 ☔ View full report in Codecov by Sentry. |
initializer.py
Outdated
@@ -0,0 +1,231 @@ | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest excluding this file for now. We can focus on the rewriter rules for this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.
09bb592
to
2771a40
Compare
7894534
to
03ba9e9
Compare
Congrats on your first PR! 🎉 For autofix-able lint errors, you can follow https://github.com/microsoft/onnxscript#coding-style to run the autofix. |
Summary
This PR introduces the implementation of LLaMA 3 and LLaMA 2 rewrite rules for the MLP and LLaMAAttention layers in transformers. The rules are designed to work with transformer versions 4.39 to 4.42, and they handle the optimization and fusion operations.
Key Changes
MLP RewriteRule:
A new rewrite rule for optimizing the LLaMA MLP layer (LlamaMLP) in transformer versions 4.39 to 4.42.
The optimization includes handling different input sizes (5 or 6) and performing matrix multiplication and activation operations to produce an optimized output.
GQA Llama RewriteRule:
Introduces a rewrite rule for the LLaMAAttention layer as well as the first attention (LlamaAttention) with support for specified number of inputs.
Two methods are implemented for handling 2D and 4D cache configurations during the Group Query Attention (GQA) process, enabling optimized matrix multiplication and attention operations.