ggml : add Flash Attention #5021
Conversation
Since we are doing this from scratch, wouldn't it be better to remove the custom attention mask entirely and pass a list of KV cells used in each sequence? Considering our implementation of batching, I think we should be looking at implementing something closer to paged attention rather than flash attention. I suppose it is possible to convert the mask to a list of sequences in the kernels, but it would be less efficient.
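To make the two representations concrete, here is a small illustrative sketch (plain C; all names are ours and not part of ggml) of a dense additive KQ mask versus a per-token list of KV cells, including the mask-to-list conversion mentioned above:

```c
#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define N_TOKENS 4  // tokens in the current batch
#define N_KV     8  // KV cells considered for this batch

int main(void) {
    // dense additive mask: 0.0f where attention is allowed, -INFINITY where it is not
    float kq_mask[N_TOKENS][N_KV];

    // build a simple causal mask for illustration: token i sits at absolute
    // position N_KV - N_TOKENS + i and may attend to all cells up to that position
    for (int i = 0; i < N_TOKENS; ++i) {
        const int pos = N_KV - N_TOKENS + i;
        for (int j = 0; j < N_KV; ++j) {
            kq_mask[i][j] = (j <= pos) ? 0.0f : -INFINITY;
        }
    }

    // alternative representation: for each token, an explicit list of the KV cells it uses
    int32_t cells[N_TOKENS][N_KV];
    int32_t n_cells[N_TOKENS];

    // converting mask -> lists (possible in-kernel, but noted above as less efficient)
    for (int i = 0; i < N_TOKENS; ++i) {
        n_cells[i] = 0;
        for (int j = 0; j < N_KV; ++j) {
            if (kq_mask[i][j] == 0.0f) {
                cells[i][n_cells[i]++] = j;
            }
        }
    }

    // the lists have different lengths per token - the GPU-efficiency concern
    // raised in the reply below
    for (int i = 0; i < N_TOKENS; ++i) {
        printf("token %d attends to %d KV cells\n", i, n_cells[i]);
    }
    return 0;
}
```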
Yes, we can pass a list instead of a mask. I am not sure of the format though - if each list has a different length, I feel it will hinder GPU performance. Edit: I just got an idea - we can pass both the mask and the list.
We could use a vector with dimension […]
It seems that vLLM has added a new version of paged attention since I last looked into the implementation (vllm-project/vllm#1348). I am not sure what the changes are, but I think it is worth looking into what they are doing. The kernel is in https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cu
Alibi could also be done in this kernel.
Regarding the Alibi, I feel reinterpreting it as a mask is the better option. It remains to be seen though if the approach is viable. Will take a look at the vLLM code. I've updated the description with some of the things from this discussion.
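To make that idea concrete, here is a minimal sketch (plain C; helper names are ours, and it assumes the common power-of-two head-count slope formula) of folding ALiBi into the same additive KQ mask that the kernel already consumes:

```c
#include <math.h>
#include <stddef.h>
#include <stdint.h>

// standard ALiBi slope for head h (0-based) out of n_head heads,
// assuming n_head is a power of two
static float alibi_slope(int h, int n_head) {
    return powf(2.0f, -8.0f * (float)(h + 1) / (float)n_head);
}

// mask is laid out as mask[h][i][j]: additive bias for head h, query token i
// (absolute position pos[i]) and KV cell j (at position j); entries are 0.0f
// where attention is allowed and -INFINITY where it is not
static void add_alibi_to_mask(float * mask, const int32_t * pos,
                              int n_head, int n_tokens, int n_kv) {
    for (int h = 0; h < n_head; ++h) {
        const float slope = alibi_slope(h, n_head);
        for (int i = 0; i < n_tokens; ++i) {
            for (int j = 0; j < n_kv; ++j) {
                // distance penalty; adding a finite value to -INFINITY leaves it
                // -INFINITY, so masked-out cells are unaffected
                mask[((size_t)h*n_tokens + i)*n_kv + j] += -slope * (float)(pos[i] - j);
            }
        }
    }
}
```

Note that the bias is head-dependent, so a mask-based formulation needs either a per-head mask or the slopes passed separately to the kernel.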
@ggerganov @slaren Together with @JohannesGaessler and @FSSRepo we are working on the same thing over at Pints-AI#1, which we intend to submit as a pull request to llama.cpp once the work is done. However, I think we will converge with this one. Given the amount of work here, @ggerganov @slaren how do you want to organise this? The three of us are actually in a temporary Discord group to work this out - perhaps we can use that? What are your thoughts?
Discord is not an option for me - I prefer to communicate over GitHub issues / discussions / e-mail. Happy to see you have started work on the CUDA implementation. Please take into account the API proposed here - note that it is still a WIP and can change. I can review the implementation that you have when you think it is in a good state. I would prefer PRs that are compatible with this branch so we can verify correctness using test-backend-ops.
@ggerganov Got it. Let us work on a plan to converge with this PR.
Any performance numbers?
Metal and other GPU backends with full offload only use one thread; however, in Metal the number of threads is also used as the number of command buffers.
@slaren ah ok, thanks for the explanation! I'm not seeing any effect of […]
It's currently disabled, yes. |
Sadly, I'm not seeing any benefit from this: no reduction in VRAM usage, no speedup, even when fully offloading. In fact, I'm only seeing slower speeds when using partial offloading.
For me (Windows, CUDA, 24GB VRAM) the difference is definitely there, but it depends on the model, and I get the best results with a large amount of context data. The most pronounced for me is Mixtral-8x7B-Instruct-v0.1-requant-imat-IQ3_XS, which I can fully offload.

Edit: I saw the old timings below across at least 4x runs each last night, but today w/o FA is hitting close to 39-40 t/s, so it must have been an edge case - though FA seemed to help with it.

Other models are less remarkable, but I'm able to store a lot more context.

New tests: llama-bench with -p 512,1024 is less dramatic but measurable, TG ~46 -> ~50. The differences are more obvious at -p 8096, 16192, 32384: from PP 819 -> 1005 @ 16K, and OOM -> 879 @ 32K.
Performance on a MacBook Air M2 24GB using latest llama.cpp, before and after enabling the -fa flag. TL;DR: generation speed increases from 8.70 t/s to 9.69 t/s, memory usage decreases slightly; prompt processing was not tested in this case.
Hi, does the server support flash attention yet? Or does it use flash attention automatically? Edit: just added -fa in the server too - got it.
Hi, I am having issues building this on CUDA 11.4 after this PR. Notably, I am getting […]. This is not the first time this has happened; previously we added […].
@LostRuins can you check whether this fix #7019 works?
It seems this only applies to low context like 4K. Testing a very small LLM without GQA on my system with a context size of 13,000 tokens, the difference is massive: VRAM drops from 2.8 to 1.2 GB, text generation goes from 37 to 71 tokens/s, and prompt processing from 1300 to 2300 tokens/s. Great work!
From the dialogue above, I think I understand that support for -fa needs to be coded per backend. Can someone confirm that? I'm not having much luck using -fa with the Vulkan backend. I do not expect said support to materialize either, just want to clarify.
It does need to be implemented per backend.
Why does the Metal test fail? […]
@LukeLIN-web because you're compiling with LLAMA_CUBLAS (which is deprecated by the way, use LLAMA_CUDA). You can't use CUDA on a MacBook |
Any updates on context shift compatibility?
Implemented ggml_sycl_op_soft_max() F16 src1 (mask) support, for which a pragma deprecation warning was added during #5021. To do this, had to decouple it from ggml_sycl_op_flatten, which always considered src1 to be of fp32 type (many OP functions depend on it).
* SYCL: SOFTMAX F16 mask support and other fixes
* test-backend-ops: Add F16 mask test cases
ref #3365
Setting up what's needed for Flash Attention support in ggml and llama.cpp. The proposed operator performs the fused scale + mask + softmax attention over Q, K and V in a single call.
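As a reference for what the fused operator is expected to produce, here is a minimal scalar sketch in plain C (layouts and names are illustrative only, not the ggml API); it computes softmax(scale * (Q @ K^T) + mask) @ V for a single head:

```c
#include <math.h>

// q:    n_tokens x head_dim   query rows for one head
// k:    n_kv     x head_dim   key rows
// v:    n_kv     x head_dim   value rows
// mask: n_tokens x n_kv       additive bias (0.0f or -INFINITY, may also carry ALiBi)
// res:  n_tokens x head_dim   output
static void attn_ref(const float * q, const float * k, const float * v,
                     const float * mask, float * res,
                     int n_tokens, int n_kv, int head_dim, float scale) {
    for (int i = 0; i < n_tokens; ++i) {
        float s[n_kv]; // attention scores for query row i (C99 VLA, fine for a sketch)

        // s_j = scale * (q_i . k_j) + mask_ij, tracking the row maximum for a stable softmax
        float smax = -INFINITY;
        for (int j = 0; j < n_kv; ++j) {
            float dot = 0.0f;
            for (int d = 0; d < head_dim; ++d) {
                dot += q[i*head_dim + d] * k[j*head_dim + d];
            }
            s[j] = scale*dot + mask[i*n_kv + j];
            if (s[j] > smax) {
                smax = s[j];
            }
        }

        // softmax over the row
        float sum = 0.0f;
        for (int j = 0; j < n_kv; ++j) {
            s[j] = expf(s[j] - smax);
            sum += s[j];
        }

        // res_i = sum_j softmax(s)_j * v_j
        for (int d = 0; d < head_dim; ++d) {
            float acc = 0.0f;
            for (int j = 0; j < n_kv; ++j) {
                acc += (s[j]/sum) * v[j*head_dim + d];
            }
            res[i*head_dim + d] = acc;
        }
    }
}
```

The point of the new operator is to fuse this whole chain, currently expressed as separate mul_mat / soft_max_ext / mul_mat nodes in the graph, into a single op that backends can implement with flash-attention-style kernels.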
Suggestions and comments for the API are welcome.
Looking for help in implementing efficient GPU kernels - please open a PR against this branch if you have proposals
- ggml API: ggml_flash_attn_ext()
- llama.cpp: use in llm_build_kqv()
- test-backend-ops test
- GGML_PREC_F32 support (CUDA) (CUDA: faster FlashAttention for batch sizes > 1 #6646)
- GGML_PREC_F32 support (Metal)

Changes to ggml / llama:

- Add GGML_OP_FLASH_ATTN_EXT and the ggml_flash_attn_ext() call (before merging we can consider reusing the old GGML_OP_FLASH_ATTN and removing the legacy code)
- Change the mask type to F16 for ggml_soft_max_ext() and require that it is padded to GGML_KQ_MASK_PAD 32
- n_kv, denoting the number of computed tokens from the KV cache, is now padded to 128 (from 32) to support larger FA blocks without making out-of-bounds accesses (see the padding sketch at the end of this description)
- The minimum llama_context_params.n_batch that can be used is GGML_KQ_MASK_PAD 32, to avoid out-of-bounds access in the FA kernels for small batch sizes
- The V tensor is no longer transposed when storing it in the KV cache

Things to consider:

- […] ggml_add()? (low-prio)

Testing:

- main, server: add -fa
- llama-bench: add -fa 1

Benchmark:

- Baseline
- FA kernel
- Text-generation after long prompt
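To make the padding rules in the "Changes to ggml / llama" list concrete, here is a small illustrative sketch (the helper function is ours; the constants are the ones from the list above):

```c
#include <stdio.h>

#define GGML_KQ_MASK_PAD 32   // padding of the F16 KQ mask rows / minimum n_batch
#define FA_KV_PAD        128  // n_kv padding used by the FA kernels (was 32)

// round x up to the next multiple of n (same idea as ggml's GGML_PAD macro)
static int pad_to(int x, int n) {
    return ((x + n - 1) / n) * n;
}

int main(void) {
    const int n_tokens = 5;    // tokens in the current batch
    const int cell_max = 1000; // highest used KV cell + 1

    const int n_kv_padded   = pad_to(cell_max, FA_KV_PAD);        // -> 1024
    const int mask_rows_pad = pad_to(n_tokens, GGML_KQ_MASK_PAD); // -> 32

    // roughly speaking, the F16 KQ mask then covers mask_rows_pad x n_kv_padded
    // entries, with the unused KV cells marked as -INFINITY so they cannot
    // contribute to the softmax, and the extra token rows simply left unused
    printf("n_kv padded: %d, mask rows padded: %d\n", n_kv_padded, mask_rows_pad);
    return 0;
}
```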