Fixes for FlashAttention #2126
Conversation
@tmm1 Thanks! This is merged.
This code looks quite similar to the patch in my blog post. We have since noticed that 70B models with GQA are not supported. Have you seen the same issue?
Hi @philschmid, yes, you could refer to this approach:
Yeah, I saw the commit as well. Just wanted to share here so that you are aware it's not working for 70B at the moment.
@philschmid I'm looking at this again now to add 70B support. Did you end up doing any more work in this area? I'm also interested in making the patch work for the forward pass with past_key_value support, which is something that's still not quite working.
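For context on the 70B issue: GQA models have fewer key/value heads than query heads, so a flash-attention patch has to expand the key/value states before the attention call. Below is a minimal sketch of that expansion; the function and variable names are illustrative assumptions, not the exact code in this patch.

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand key/value heads so a GQA model's K/V match the query head count.

    hidden_states: (batch, num_kv_heads, seq_len, head_dim)
    returns:       (batch, num_kv_heads * n_rep, seq_len, head_dim)
    """
    if n_rep == 1:
        return hidden_states
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


# Inside a patched attention forward, before calling flash-attn (illustrative):
# n_rep = num_attention_heads // num_key_value_heads   # e.g. 64 // 8 for Llama-2-70B
# key_states = repeat_kv(key_states, n_rep)
# value_states = repeat_kv(value_states, n_rep)
```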
Why are these changes needed?
We need to mark the LlamaRMSNorm layers as bf16 after the PEFT conversion to fix errors from flash-attn about dtype support.
I also updated the flash attention patch.
I am able to run training on Llama 2 with this change.
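As a rough sketch of the first point, the bf16 cast can be done by walking the module tree after the PEFT conversion and casting the norm layers back; the helper name and where it is called are assumptions for illustration, not necessarily the code in this PR.

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm


def cast_rmsnorm_to_bf16(model: torch.nn.Module) -> None:
    """Cast LlamaRMSNorm layers to bf16 so flash-attn does not see fp32 inputs.

    PEFT's k-bit preparation upcasts norm layers to fp32 for stability, which
    then trips flash-attn's fp16/bf16-only dtype check.
    """
    for module in model.modules():
        if isinstance(module, LlamaRMSNorm):
            module.to(torch.bfloat16)


# Illustrative usage after wrapping the base model with LoRA:
# model = get_peft_model(model, lora_config)
# cast_rmsnorm_to_bf16(model)
```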
Related issue number (if applicable)
Closes #1828
Checks
I've run format.sh to lint the changes in this PR.