
[Flash Attention 2] Add flash attention 2 for GPT-J #28295

Merged
merged 11 commits into huggingface:main on Mar 13, 2024

Conversation

bytebarde (Contributor):

What does this PR do?

Adds Flash Attention 2 for GPT-J
Fixes #26350
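
Once this lands, enabling the new backend should look roughly like this (a minimal usage sketch, not code from this PR; it assumes the flash-attn package is installed and a CUDA GPU with fp16 support is available):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Request the Flash Attention 2 backend added by this PR; everything else is the
# standard GPT-J loading path.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6b",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```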

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

cc: @younesbelkada

@bytebarde (Contributor, Author) commented on Jan 1, 2024:

Current progress with running flash_attn_test. Will dive deeper to fix the error.

[screenshot: flash_attn_test run output]

@susnato (Contributor) commented on Jan 1, 2024:

Hi @bytebarde, what is the error message?
If it is something like "IndexError: tensors used as ...", then updating CUDA could solve it (at least that was the case for me with OPT).

BTW run make fixup to make the CI green!

@bytebarde (Contributor, Author) commented on Jan 2, 2024:

Hi @susnato, thank you so much for your attention to this PR!

I believe the error originates from two factors: (1) my preliminary implementation of GPTJFlashAttention2, which aimed to eliminate "redundant" transposing of the key and query, and (2) the execution of test_flash_attn_2_generate_padding_right using the testing configuration.

To address these issues, I have reinstated the original transposing operations and reverted the QKV cache concatenation. Additionally, I overrode test_flash_attn_2_generate_padding_right to use the actual checkpoint, and all eight tests now pass, similar to what @younesbelkada and you did for Llama 2 and Phi-2.

Currently, the code still has some problems with make fixup; I will work on that next.

[screenshot: flash-attention test run with all eight tests passing]
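
For context on the transposes mentioned above, the layout mismatch is roughly the following (an illustrative sketch, not this PR's code): GPT-J's eager attention and KV cache keep tensors as (batch, num_heads, seq_len, head_dim), while the flash-attn kernels expect (batch, seq_len, num_heads, head_dim).

```python
import torch

# Illustrative shapes only; GPT-J-6B uses 16 heads with head_dim 256.
bsz, num_heads, seq_len, head_dim = 2, 16, 32, 256

# Layout used by the eager attention path and the KV cache.
q = torch.randn(bsz, num_heads, seq_len, head_dim)

# Layout expected by the flash-attn kernels, hence the extra transpose
# reinstated in the Flash Attention 2 forward pass.
q_for_flash = q.transpose(1, 2)
assert q_for_flash.shape == (bsz, seq_len, num_heads, head_dim)
```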

@bytebarde bytebarde changed the title [Flash Attention 2] [WIP] Add flash attention 2 for GPT-J [Flash Attention 2] Add flash attention 2 for GPT-J Jan 4, 2024
@bytebarde (Contributor, Author) commented:

Hi @younesbelkada,

I believe this pull request is now ready for your review.

I'd like to highlight a few changes, especially regarding check_copies.py, that I'm not entirely confident about. To ensure the branch passes the make fixup check, I removed the "copies" lines before both modeling_codegen.CodeGenBlock and test_modeling_gptj.test_flash_attn_2_generate_padding_right, because the changes involved are somewhat complex.
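
(For readers unfamiliar with the mechanism: check_copies.py keys off `# Copied from ...` comments, and make fix-copies forces the marked definition to stay textually in sync with its source. The sketch below illustrates the convention; it does not reproduce the exact lines removed in this PR.)

```python
import torch.nn as nn

# `utils/check_copies.py` keeps the body of a definition carrying this marker in sync
# with the referenced source, applying the listed renames (here GPTJ -> CodeGen).
# Removing the marker exempts the definition from that check.

# Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
class CodeGenBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        # ... layers kept in sync with GPTJBlock ...
```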

I would really appreciate your guidance on this. If there's a more standard or preferable way to handle such intricate changes, please let me know so I can make the necessary adjustments.

Thank you for your time on this!

@younesbelkada (Contributor) left a comment:

Looks clean on my end already! Would you be happy to address the comment about the copy mechanism that has been removed?
Also, can you run the benchmarking script here: https://gist.github.com/younesbelkada/02f35734da906cc0f2389ae4f665c58f with a GPT-J checkpoint and check the speedup (the result should look similar to #26414 (review))? I can take care of pushing the images to the Hub, and then we'll just need to update the docs similarly to 7d4c688.
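
(The linked gist essentially times generation with and without the new backend. A minimal sketch of that kind of comparison, with illustrative settings rather than the gist's exact code:)

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "EleutherAI/gpt-j-6b"  # any GPT-J checkpoint; fp16 to fit in memory
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token
prompts = ["Hello my dog is cute and"] * 8  # batch size 8, as in the runs reported below
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")

for impl in ("eager", "flash_attention_2"):
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint, torch_dtype=torch.float16, attn_implementation=impl
    ).to("cuda")
    model.generate(**inputs, max_new_tokens=8, do_sample=False)  # warmup
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=32, do_sample=False)
    torch.cuda.synchronize()
    print(f"{impl}: {time.perf_counter() - start:.3f}s")
    del model
    torch.cuda.empty_cache()
```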

Review thread on tests/models/gptj/test_modeling_gptj.py (outdated, resolved)
@bytebarde (Contributor, Author) commented:

Hi @younesbelkada, thank you very much for your valuable input and guidance! I apologize for the delayed response.

I've addressed the comment regarding the copy mechanism, and the branch successfully passed the make fixup test.

Additionally, I've conducted the speed test. However, the observed speedup was not as significant as what we noted with OPT. The test was performed on an Nvidia RTX 4090, utilizing max-batch-size=8 and max-seqlen=32 to conserve memory. The model checkpoint used was EleutherAI/gpt-j-6b with the revision set to "float16". I've attached the speedup graph below for your review.

[speedup graph: eager vs. Flash Attention 2 generation on GPT-J]

Could you also perform the test on an A100 GPU for comparison?

Thank you once again for your time. I look forward to hearing your thoughts on this!

@younesbelkada (Contributor) left a comment:

Thank you! Can you just rebase / merge with main to make sure the CI passes?

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment:

LGTM. Thanks for adding flash attention support!

Review thread on src/transformers/models/gptj/modeling_gptj.py (outdated, resolved)
@@ -293,7 +560,11 @@ def __init__(self, config):
         super().__init__()
         inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
         self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
-        self.attn = GPTJAttention(config)
+        self.attn = (
@ArthurZucker (Collaborator) commented on this diff:

Let's define it and use it!

Suggested change:
-        self.attn = (
+GPTJ_ATTENTION_CLASSES = {
+    "eager": GPTJAttention,
+    "flash_attention_2": GPTJFlashAttention,
+}

younesbelkada and others added 2 commits January 30, 2024 02:54
@younesbelkada (Contributor) left a comment:

Hi @bytebarde
Can you address this comment? https://github.com/huggingface/transformers/pull/28295/files#r1470429885
It shouldn't be super hard; you just need to do something similar to what we do in Llama, specifically https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L746 and the line below:

self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
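
(For readers following along, the analogous GPT-J wiring would look roughly like the sketch below. This is a sketch of code inside modeling_gptj.py, not a standalone script or the exact merged diff; GPTJAttention and GPTJMLP are defined earlier in that file, and the flash attention class name is assumed.)

```python
# Dispatch table in the spirit of the Llama line above; the "flash_attention_2" class
# name is an assumption about this PR's naming.
GPTJ_ATTENTION_CLASSES = {
    "eager": GPTJAttention,
    "flash_attention_2": GPTJFlashAttention2,
}


class GPTJBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        # Select the attention implementation from the config, mirroring Llama.
        self.attn = GPTJ_ATTENTION_CLASSES[config._attn_implementation](config)
        self.mlp = GPTJMLP(inner_dim, config)
```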

@bytebarde (Contributor, Author) commented:

Hi @ArthurZucker and @younesbelkada ,

Thank you so much for your additional suggestions!

I am sorry; I had assumed that GPTJ_ATTENTION_CLASSES had already been introduced by @ArthurZucker.

I have now added GPTJ_ATTENTION_CLASSES and made the necessary code modifications.
Furthermore, I re-ran the test suite and successfully passed all the tests.

Please let me know if there's anything more I can do!
Thank you so much!

@ArthurZucker (Collaborator) commented:

Good for me merging! 🤗

@ArthurZucker (Collaborator) left a comment:

Last nits!

@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
@ArthurZucker (Collaborator) commented on this snippet:

require_bitsandbytes is needed here!
Also, let's add the expected text explicitly, to make sure we always have what we want!
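
(Applied to the snippet above, the requested change would look roughly like this. A sketch only: the decorator import path is the usual transformers.testing_utils one, and the body is elided because the explicit expected strings come from running the real checkpoint.)

```python
import pytest
from transformers.testing_utils import require_bitsandbytes, require_torch_gpu, slow


@require_torch_gpu
@require_bitsandbytes  # per the review comment above; typically because the 6B checkpoint is loaded quantized
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
    # Sketch: load EleutherAI/gpt-j-6b with attn_implementation="flash_attention_2",
    # generate from right-padded inputs, and compare the decoded generations against
    # an explicit EXPECTED_OUTPUTS list instead of re-deriving them at test time.
    ...
```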

github-actions bot commented on Mar 7, 2024:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ArthurZucker (Collaborator) commented:

Hey @bytebarde could you rebase and add the explicit expected outputs? Or should I do it? 🤗

@bytebarde (Contributor, Author) commented on Mar 8, 2024:

Hi @ArthurZucker , good morning!

I have added @require_bitsandbytes and the expected outputs to the test function.
Please let me know if there is anything needed to be addressed!

Thank you so much!

@younesbelkada (Contributor) left a comment:

Hi @bytebarde
Thanks! Can you run the styling checks (make fixup and/or make fix-copies)? After that we can merge.

@bytebarde (Contributor, Author) commented:

Hi @younesbelkada,

Thank you for taking the time to review this!

I have run make fix-copies and believe that the previous consistency issues have been addressed.

Please let me know if any further changes are needed. Thank you!

@younesbelkada (Contributor) left a comment:

Thanks again!

@younesbelkada younesbelkada merged commit be3fd8a into huggingface:main Mar 13, 2024
19 checks passed
itazap pushed a commit that referenced this pull request May 14, 2024
* initial implementation of flash attention for gptj

* modify flash attention and overwrite test_flash_attn_2_generate_padding_right

* update flash attention support list

* remove the copy line in the `CodeGenBlock`

* address copy mechanism

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add GPTJ attention classes

* add expected outputs in the gptj test

* Ensure repo consistency with 'make fix-copies'

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>