Add option to replace attention with flash attention #25
Some old notes from Slack last week: right now, nothing in torch/Hugging Face can be used directly to do flash attention. One would need to swap out the attention layer, which is possible, since that is what the gpt-neox repo does. I'll have to look at this approach more carefully to see how to do it, similar to how the other Vicuna repo does it for LLaMA.
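To make the "swap out the layer" idea concrete, here is a minimal sketch of replacing every attention module in a PyTorch model with a drop-in subclass. `NaiveAttention`, `PatchedAttention`, and `swap_layers` are hypothetical names for illustration only; this is not the gpt-neox or Vicuna patch, and the placeholder does not call a real flash-attention kernel.

```python
import torch
from torch import nn


class NaiveAttention(nn.Module):
    """Hypothetical stand-in for a model's stock single-head attention layer."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.qkv = nn.Linear(dim, 3 * dim)
        self.out = nn.Linear(dim, dim)

    def attend(self, q, k, v):
        # Plain softmax(QK^T / sqrt(d)) V, materializing the full score matrix.
        scores = torch.softmax(q @ k.transpose(-2, -1) / self.dim ** 0.5, dim=-1)
        return scores @ v

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        return self.out(self.attend(q, k, v))


class PatchedAttention(NaiveAttention):
    """Hypothetical replacement: same parameters, different attention routine."""

    def attend(self, q, k, v):
        # Placeholder only: a real patch would call a fused kernel here.
        return super().attend(q, k, v)


def swap_layers(model: nn.Module, old_cls, new_cls) -> int:
    """Recursively replace modules of exactly old_cls; return how many were swapped."""
    swapped = 0
    for name, child in list(model.named_children()):
        if type(child) is old_cls:
            # Built on the default device/dtype; fine for this CPU-only sketch.
            replacement = new_cls(child.dim)
            replacement.load_state_dict(child.state_dict())  # keep trained weights
            setattr(model, name, replacement)
            swapped += 1
        else:
            swapped += swap_layers(child, old_cls, new_cls)
    return swapped


model = nn.Sequential(NaiveAttention(8), nn.ReLU(), NaiveAttention(8))
x = torch.randn(2, 5, 8)
before = model(x)
count = swap_layers(model, NaiveAttention, PatchedAttention)
after = model(x)
```

Because the replacement reuses the original weights, the outputs before and after the swap should match, which is a cheap sanity check before wiring in an actual fused kernel.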
CUDA 11.7 required.
WIP for NeoX using flash attention in Hugging Face transformers, but there has been no activity for the last 3 months, so it is probably dead: https://github.com/conceptofmind/flash-gpt
Amazon thing: https://aws.amazon.com/blogs/machine-learning/new-performance-improvements-in-amazon-sagemaker-model-parallel-library/
So maybe we should use SageMaker. I think I noticed this somewhere else before. Their example: a 100B parameter GPT-NeoX model on 32 p4d.24xlarge instances.
You can use the same install above to then make LLaMA use flash attention, using the wrappers/patches from the Vicuna model.
Flash attention has already been integrated into gpt-neox models here: https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py#L215
Can add the swapped model definition as an option to the training and generation scripts and benchmark the speed difference.
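For the benchmark, a rough timing harness along these lines might be enough to start with. `time_call` and `speedup` are hypothetical helpers, not part of any training script in this repo. The `sync` hook is there because CUDA kernels launch asynchronously, so on GPU one would pass something like `torch.cuda.synchronize` to avoid timing only the launch.

```python
import statistics
import time
from typing import Callable, Optional


def time_call(
    fn: Callable[[], object],
    *,
    warmup: int = 3,
    iters: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return the median wall-clock seconds per call of fn."""
    for _ in range(warmup):  # warm caches / lazy initialization before measuring
        fn()
    if sync is not None:
        sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for queued device work before reading the clock
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def speedup(baseline: Callable[[], object], candidate: Callable[[], object], **kwargs) -> float:
    """Ratio > 1 means the candidate is faster than the baseline."""
    return time_call(baseline, **kwargs) / time_call(candidate, **kwargs)


# Toy usage with CPU-only stand-ins for "stock" and "swapped" forward passes.
ratio = speedup(lambda: sum(range(200_000)), lambda: sum(range(50_000)))
```

In practice the two callables would wrap a forward (or forward plus backward) pass of the stock and swapped models on identical inputs.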
Converting LLaMA and others might be more work. It uses a pretty standard-looking attention, but I'm not sure how it differs from the PyTorch default. We might just need to remap some layer names: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L160
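If it really does come down to remapping layer names, the conversion could look roughly like the sketch below, which renames checkpoint keys with ordered regex rules. The key names and `RENAME_RULES` are made up for illustration; they are not the actual LLaMA or flash-attn parameter names, which would need to be read off both model definitions.

```python
import re
from typing import Dict, List, Tuple

# Hypothetical rules (NOT real LLaMA / flash-attn names). First match wins,
# so more specific patterns must come before general ones.
RENAME_RULES: List[Tuple[str, str]] = [
    (r"^old\.layers\.(\d+)\.attention\.", r"new.blocks.\1.attn."),
    (r"^old\.", r"new."),
]


def remap_keys(state_dict: Dict[str, object], rules: List[Tuple[str, str]]) -> Dict[str, object]:
    """Return a copy of state_dict with keys renamed by the first matching rule."""
    remapped: Dict[str, object] = {}
    for old_key, value in state_dict.items():
        new_key = old_key
        for pattern, replacement in rules:
            new_key, hits = re.subn(pattern, replacement, old_key)
            if hits:
                break
        if new_key in remapped:
            # Two source keys collapsing into one would silently drop weights.
            raise KeyError(f"duplicate target key: {new_key}")
        remapped[new_key] = value
    return remapped


checkpoint = {
    "old.layers.0.attention.query.weight": "W_q",
    "old.layers.0.mlp.weight": "W_mlp",
    "old.embed.weight": "W_emb",
}
converted = remap_keys(checkpoint, RENAME_RULES)
```

Keys that match no rule pass through unchanged, so loading the result with strict key checking would surface anything the rules missed.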