Add group kv support and fix past kv from cache #2263
Conversation
Hi @siddartha-RE, I was looking into this the other day too (inference support for FA2 Llama 2, i.e. past_kv/cache), and ended up using …
No, there is an issue with alignment. I opened an issue in the flash-attention repo. The owner has already responded and is planning to adjust the behavior. Bump the issue if you want to let him know there is more interest. This PR has a pretty comprehensive correctness check.
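For context, here is a minimal PyTorch sketch of the two possible causal-mask alignments when q_len != kv_len; the helper names are illustrative and do not correspond to flash-attn's API.

```python
import torch

def causal_mask_top_left(q_len: int, kv_len: int) -> torch.Tensor:
    # Treats the queries as the *first* q_len positions of the sequence:
    # query i may attend to key positions j <= i.
    i = torch.arange(q_len).unsqueeze(1)
    j = torch.arange(kv_len).unsqueeze(0)
    return j <= i

def causal_mask_bottom_right(q_len: int, kv_len: int) -> torch.Tensor:
    # Treats the queries as the *last* q_len positions (absolute position
    # kv_len - q_len + i), which is the alignment needed when decoding
    # with a key/value cache.
    offset = kv_len - q_len
    i = torch.arange(q_len).unsqueeze(1) + offset
    j = torch.arange(kv_len).unsqueeze(0)
    return j <= i

if __name__ == "__main__":
    # q_len == 1 (one new token), kv_len == 4 (3 cached + 1 new):
    # top-left alignment only lets the new token see the first key,
    # bottom-right alignment lets it see the entire cache.
    print(causal_mask_top_left(1, 4))      # [[ True, False, False, False]]
    print(causal_mask_bottom_right(1, 4))  # [[ True,  True,  True,  True]]
```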
@siddartha-RE Thanks for the contribution. If this is comprehensively tested, can we make it the default implementation and replace the old monkey patch?
The test runs the attention block against the current implementation, the new one, and the HF reference. The current and new implementations are exactly identical, and both differ from the HF implementation within tolerance. I have also tested training 70B with this and the loss looks correct, so the group kv handling appears to be right. I could add a test with group kv enabled and compare the new implementation against HF to further verify it. FYI, I have an issue open on the flash-attn code base.
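A minimal sketch of the kind of comparison described above, assuming hypothetical `current_fn` / `new_fn` / `hf_fn` wrappers around the three attention forward paths; the actual test in this PR may be structured differently.

```python
import torch

def compare_attention_outputs(current_fn, new_fn, hf_fn, bsz=2, seq_len=128,
                              num_heads=8, head_dim=64, atol=2e-3, rtol=2e-3):
    # Random hidden states fed through all three attention implementations.
    torch.manual_seed(0)
    hidden = torch.randn(bsz, seq_len, num_heads * head_dim, dtype=torch.float16)

    out_current = current_fn(hidden)
    out_new = new_fn(hidden)
    out_hf = hf_fn(hidden)

    # The two flash-attn paths should match exactly (same kernel), while both
    # are only expected to match the HF eager implementation within tolerance.
    assert torch.equal(out_current, out_new)
    assert torch.allclose(out_new.float(), out_hf.float(), atol=atol, rtol=rtol)
```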
One more comment: I tested the properly optimized implementation of use_cache against a build of flash-attn from this PR and confirmed that the test I added in the file still passes.
@siddartha-RE Is this ready for merge? Which one do you prefer?
This is ready for merge. It is well tested, so I am comfortable updating the train hook, but I suggest keeping both versions for now so that people can switch back if they see issues. I have a change that further fixes the inference usage, but it will need to wait until the underlying fix in flash-attention for the handling of the causal flag when q != kv is released.
@siddartha-RE Thanks! It is merged.
```python
):
    # [bsz, seq_len]
    if past_key_values_length > 0 and attention_mask is not None:
        attention_mask = torch.cat(
```
Is this concatenation necessary? The attention_mask passed in should have shape (bsz, kv_len), and past_kv_len is already included in kv_len.
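A small shape sketch of the reviewer's point, assuming the HF-style convention that the 2-D padding mask handed to the model already covers the cached tokens:

```python
import torch

# During cached decoding, the padding mask spans past tokens plus the new ones,
# so its second dimension is already past_kv_len + q_len.
bsz, past_kv_len, q_len = 2, 16, 1
attention_mask = torch.ones(bsz, past_kv_len + q_len, dtype=torch.bool)

kv_len = past_kv_len + q_len
assert attention_mask.shape == (bsz, kv_len)

# Prepending another past_kv_len worth of ones would give
# (bsz, past_kv_len + kv_len), which no longer matches the key/value length.
```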
Why are these changes needed?
The primary goal is to enable support for grouped key-value attention in Llama 2 models, specifically the 70B model.
In addition, it fixes the handling of past key values from the cache.
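For illustration, here is a minimal sketch of the grouped key-value expansion that group-kv support relies on; it mirrors the shape convention of the HF Llama `repeat_kv` helper but is not the exact code added in this PR.

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand grouped key/value heads so each query head has a matching kv head.

    Input:  (batch, num_key_value_heads, seq_len, head_dim)
    Output: (batch, num_key_value_heads * n_rep, seq_len, head_dim)
    """
    if n_rep == 1:
        return hidden_states
    bsz, num_kv_heads, seq_len, head_dim = hidden_states.shape
    hidden_states = hidden_states[:, :, None, :, :].expand(
        bsz, num_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)

# Llama 2 70B: 64 query heads share 8 key/value heads, so n_rep = 8.
k = torch.randn(1, 8, 10, 128)
k_expanded = repeat_kv(k, n_rep=64 // 8)
assert k_expanded.shape == (1, 64, 10, 128)
```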
Related issue number (if applicable)
Addresses #2229
Checks
- I've run format.sh to lint the changes in this PR.