[Flash Attention 2] Add flash attention 2 for GPT-J #28295
Conversation
Hi @bytebarde, what is the error message?
Hi @susnato, thank you so much for your attention to this PR! I believe the error originates from two factors: (1) my preliminary implementation removed the original transposing of the query/key/value tensors, and (2) it changed how the QKV cache is concatenated. To address these issues, I have reinstated the original transposing operations and reverted the QKV cache concatenation. Additionally, I overwrote `test_flash_attn_2_generate_padding_right`. Currently, the code still has some remaining problems that I am looking into.
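For context, a minimal sketch of the layout handling being discussed, assuming the flash-attn 2 `flash_attn_func` API; this is illustrative only and not the exact code in this PR:

```python
# Illustrative sketch only -- not the exact PR code. Eager GPT-J attention keeps
# q/k/v as (batch, num_heads, seq_len, head_dim) and concatenates the KV cache in
# that layout; flash_attn_func expects (batch, seq_len, num_heads, head_dim) in
# fp16/bf16, so the tensors are only transposed around the flash-attn call.
from flash_attn import flash_attn_func


def flash_attention_forward(query, key, value, dropout_p: float, training: bool):
    q = query.transpose(1, 2)  # -> (batch, seq_len, num_heads, head_dim)
    k = key.transpose(1, 2)
    v = value.transpose(1, 2)
    out = flash_attn_func(q, k, v, dropout_p=dropout_p if training else 0.0, causal=True)
    return out.transpose(1, 2)  # back to the eager layout for the rest of the module
```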
(force-pushed from 41c9d4a to e47ef13)
Hi @younesbelkada, I believe this pull request is now ready for your review. I'd like to highlight a few changes, especially the removal of the copy line in `CodeGenBlock`. I would really appreciate your guidance on this. If there's a more standard or preferable way to handle such intricate changes, please let me know so I can make the necessary adjustments. Thank you for your time on this!
Looks clean on my end already! Would you be happy to address the comment about the copy mechanism that has been removed?
Also, can you run the benchmarking script here: https://gist.github.com/younesbelkada/02f35734da906cc0f2389ae4f665c58f with a gpt-j checkpoint and check the speedup (the result should be similar to #26414 (review))? I can take care of pushing the images to the Hub, and we'll just need to update the docs similarly to 7d4c688.
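For reference, a latency comparison along the lines of the linked gist could look roughly like the sketch below; the gist is the authoritative script, and the checkpoint name and token counts here are assumptions:

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "EleutherAI/gpt-j-6b"  # assumed checkpoint; any GPT-J checkpoint works
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")


def time_generation(attn_implementation: str) -> float:
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint, torch_dtype=torch.float16, attn_implementation=attn_implementation
    ).to("cuda")
    model.generate(**inputs, max_new_tokens=8)  # warmup
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=128)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    del model
    torch.cuda.empty_cache()
    return elapsed


for impl in ("eager", "flash_attention_2"):
    print(f"{impl}: {time_generation(impl):.2f}s for 128 new tokens")
```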
Hi @younesbelkada, thank you very much for your valuable input and guidance! I apologize for the delayed response. I've addressed the comment regarding the copy mechanism, and the branch successfully passed the CI checks. Additionally, I've conducted the speed test. However, the observed speedup was not as significant as what we noted with OPT. The test was performed on an Nvidia RTX 4090 using the benchmarking script with a gpt-j checkpoint. Could you also perform the test on an A100 GPU for comparison? Thank you once again for your time. I look forward to hearing your thoughts on this!
Thank you! Can you just rebase / merge with main to make sure the CI passes?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM! Thanks for adding flash attention support!
```diff
@@ -293,7 +560,11 @@ def __init__(self, config):
         super().__init__()
         inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
         self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
-        self.attn = GPTJAttention(config)
+        self.attn = (
```
let's define and use `GPTJ_ATTENTION_CLASSES`!
Suggested change (replacing `self.attn = (`):

```python
GPTJ_ATTENTION_CLASSES = {
    "eager": GPTJAttention,
    "flash_attention_2": GPTJFlashAttention,
}
```
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Hi @bytebarde
Can you address this comment? https://github.com/huggingface/transformers/pull/28295/files#r1470429885
It shouldn't be super hard, you just need to do something similar to what we do in Llama, specifically https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L746 and:

```python
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
```
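A GPT-J analogue of that Llama line would look roughly like the sketch below. This is a sketch only: the class names follow the `GPTJ_ATTENTION_CLASSES` suggestion above, and `GPTJAttention`, `GPTJFlashAttention`, and `GPTJMLP` refer to the classes defined in modeling_gptj.py; the final merged names may differ slightly.

```python
from torch import nn

# Mapping from config._attn_implementation to the attention class to instantiate.
GPTJ_ATTENTION_CLASSES = {
    "eager": GPTJAttention,
    "flash_attention_2": GPTJFlashAttention,
}


class GPTJBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        # dispatch on the attention implementation selected in the config
        self.attn = GPTJ_ATTENTION_CLASSES[config._attn_implementation](config)
        self.mlp = GPTJMLP(inner_dim, config)
```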
Hi @ArthurZucker and @younesbelkada, thank you so much for your additional suggestions! I am sorry for the misunderstanding. I have now added `GPTJ_ATTENTION_CLASSES` and made the necessary code modifications. Please let me know if there's anything more I can do!
Good for merging on my side! 🤗
Last nits!
```python
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
```
`require_bitsandbytes` is needed here!
Also, let's add the expected text explicitly, to make sure we always have what we want!
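Putting both nits together, the requested test would look roughly like the sketch below. This is hedged: the prompts and expected strings are placeholders (the merged test pins the real outputs), and `require_bitsandbytes` comes from `transformers.testing_utils`.

```python
# Sketch of the test shape, e.g. inside GPTJModelTest in tests/models/gptj/test_modeling_gptj.py.
@require_torch_gpu
@require_bitsandbytes
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
    model = GPTJForCausalLM.from_pretrained(
        "EleutherAI/gpt-j-6b",
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
        load_in_4bit=True,  # this is why @require_bitsandbytes is needed
    )
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    texts = ["hi", "Hello this is a very long sentence"]  # placeholder prompts
    inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)

    # Pin the generations explicitly so regressions are caught immediately.
    EXPECTED_OUTPUTS = ["<placeholder generation 1>", "<placeholder generation 2>"]

    output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
    decoded = tokenizer.batch_decode(output, skip_special_tokens=True)
    self.assertListEqual(decoded, EXPECTED_OUTPUTS)
```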
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hey @bytebarde, could you rebase and add the explicit expected outputs? Or should I do it? 🤗
Hi @ArthurZucker, good morning! I have added the explicit expected outputs. Thank you so much!
Hi @bytebarde
Thanks! Can you run the styling checks (`make fixup` and/or `make fix-copies`)? After that we can merge.
Hi @younesbelkada, thank you for taking the time to review this! I have run the styling checks with `make fix-copies`. Please let me know if any further changes are needed. Thank you!
Thanks again!
* initial implementation of flash attention for gptj
* modify flash attention and overwrite test_flash_attn_2_generate_padding_right
* update flash attention support list
* remove the copy line in the `CodeGenBlock`
* address copy mechanism
* Update src/transformers/models/gptj/modeling_gptj.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Add GPTJ attention classes
* add expected outputs in the gptj test
* Ensure repo consistency with 'make fix-copies'
---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
What does this PR do?
Adds Flash Attention 2 for GPT-J.
Fixes #26350
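Once merged, Flash Attention 2 for GPT-J is opted into the same way as for other supported models. A usage sketch, assuming a CUDA GPU, `flash-attn` installed, and the `EleutherAI/gpt-j-6b` checkpoint:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Opt into the new attention backend when loading the model.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6b",
    torch_dtype=torch.float16,  # flash attention requires fp16 or bf16
    attn_implementation="flash_attention_2",
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")

inputs = tokenizer("Flash Attention 2 speeds up", return_tensors="pt").to("cuda")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))
```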
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
cc: @younesbelkada