Fixes for FlashAttention #2126

Merged
merged 4 commits into lm-sys:main on Aug 8, 2023

Contributor

@tmm1 tmm1 commented Aug 1, 2023

Why are these changes needed?

We need to cast the LlamaRMSNorm layers to bf16 after the peft conversion to fix flash-attn errors about unsupported dtypes.

I also updated the flash attention patch.

I am able to run training on llama2 with this change.
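
As a rough illustration of the dtype fix described above, here is a minimal sketch. It is not the code from this PR: `ToyRMSNorm` and `cast_norm_layers` are hypothetical names, and the sketch assumes the norm layers end up in fp32 while the rest of the model is in bf16.

```python
import torch
from torch import nn


class ToyRMSNorm(nn.Module):
    """Hypothetical stand-in for LlamaRMSNorm, used only for this sketch."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in fp32 for stability, then return in the weight's dtype.
        variance = x.float().pow(2).mean(-1, keepdim=True)
        x = x.float() * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(self.weight.dtype)


def cast_norm_layers(model: nn.Module, dtype: torch.dtype = torch.bfloat16) -> int:
    """Cast every ToyRMSNorm submodule to `dtype` in place; return the count."""
    count = 0
    for module in model.modules():
        if isinstance(module, ToyRMSNorm):
            module.to(dtype)  # nn.Module.to converts the parameters in place
            count += 1
    return count


# Toy model: a bf16 linear layer followed by a norm that is still in fp32
# (the assumed state before the fix is applied).
model = nn.Sequential(nn.Linear(8, 8).to(torch.bfloat16), ToyRMSNorm(8))
num_cast = cast_norm_layers(model)
```

After the call, the norm layer's output dtype matches the bf16 activations around it, which is what a kernel restricted to half-precision inputs would need.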

Related issue number (if applicable)

Closes #1828

Checks

  • I've run format.sh to lint the changes in this PR.
  • I've included any doc changes needed.
  • I've made sure the relevant tests are passing (if applicable).

@tmm1 tmm1 marked this pull request as ready for review August 1, 2023 07:17
@merrymercy merrymercy merged commit 060c9f1 into lm-sys:main Aug 8, 2023
@merrymercy
Member

@tmm1 Thanks! This is merged.

@philschmid

This code looks quite similar to the patch in my blog. We have now noticed that 70B models with GQA are not supported.

Have you seen the same issue?
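
For context on the GQA remark, the sketch below shows why a patch that assumes one key/value head per query head can break under grouped-query attention: several query heads share each key/value head, so K and V must be expanded (or indexed by group) before the attention call. The function names are hypothetical, and the head counts (64 query heads, 8 key/value heads) are assumptions chosen for illustration.

```python
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Return the index of the key/value head shared by query head `q_head`."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


def repeat_kv_heads(kv_heads: list, num_q_heads: int) -> list:
    """Expand per-head K or V entries so there is one entry per query head."""
    num_kv_heads = len(kv_heads)
    return [
        kv_heads[kv_head_for_query_head(q, num_q_heads, num_kv_heads)]
        for q in range(num_q_heads)
    ]


# Assumed head counts for illustration: 64 query heads sharing 8 key/value heads.
expanded = repeat_kv_heads([f"kv{i}" for i in range(8)], num_q_heads=64)
```

Each consecutive group of eight query heads maps to the same key/value entry, so the expanded list has one entry per query head.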

@tmm1
Contributor Author

tmm1 commented Aug 9, 2023

Hi @philschmid, yes you could refer to this approach:

LAION-AI/Open-Assistant@3c8f93e

@philschmid

Yeah, I saw the commit as well. I just wanted to share it here so that you are aware it's not working for 70B at the moment.

@tmm1
Contributor Author

tmm1 commented Aug 12, 2023

@philschmid I'm looking at this again now to add 70B support. Did you end up doing any more work in this area?

I'm also interested in making the patch work for forward-pass with past_key_value support, which is something that's still not quite working.

@merrymercy
Member

@tmm1 Hi, I did some minor style cleanup in PR #2212.
I also found the current implementation does not support generative inference or 70B.
Could you take a look and fix it? Thanks!

Development

Successfully merging this pull request may close these issues.

LoRA Fine Tuning Crash at FlashAttention Issue