[Performance] Use optimized kernels for MQA/GQA #1880
This is a feature that is high on my list for performance reasons. I have scavenged other people's benchmarks and found an interesting one from TensorRT-LLM that also uses PagedAttention: Llama 2 7B with 1x A100 80GB:
Latency benchmark:
Conclusion: Because MHA keeps a full set of KV heads in the KV cache, latency increases as more memory bandwidth is spent reading it. This is especially visible in scenarios where we aim for high throughput, i.e. when handling many requests at once. This means we can see a significant improvement in throughput if we spend less memory on the KV cache thanks to the head grouping in GQA. Reference: NVIDIA/TensorRT-LLM#404
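To make the memory argument concrete, here is a rough back-of-the-envelope estimate of KV cache size per token for Llama-2-7B-like shapes. The 8-KV-head grouping is purely an illustrative assumption (the 7B model itself uses full MHA, which is exactly why GQA would help):

```python
# Rough KV-cache size per token for Llama-2-7B-like shapes in fp16.
# The GQA figure assumes 8 KV heads purely for illustration.
num_layers, head_dim, dtype_bytes = 32, 128, 2

def kv_bytes_per_token(num_kv_heads: int) -> int:
    # Factor 2 accounts for both the K and the V tensors.
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes

mha = kv_bytes_per_token(num_kv_heads=32)  # 524,288 B = 512 KiB per token
gqa = kv_bytes_per_token(num_kv_heads=8)   # 131,072 B = 128 KiB per token
print(f"MHA: {mha / 1024:.0f} KiB/token, GQA-8: {gqa / 1024:.0f} KiB/token, "
      f"saving: {1 - gqa / mha:.0%}")
```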
It seems possible to migrate https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h to fulfill vLLM's requirements.
Sorry, my bad.
@whitelok Oh, not really. While I'm not sure at the moment, it seems there is some code that we can leverage. I will look into it soon. Thanks again for sharing!
It seems the entry point is here: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.5.0/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h#L1015. All vLLM needs to do is modify the KV cache buffer.
FYI, I have modified the FlashAttention kernel to support a paged KV cache, with the restriction that block_size must match kBlockN in the kernel.
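For readers unfamiliar with the layout being discussed, here is a minimal sketch of how a block table maps a logical sequence onto physical KV-cache blocks. The shapes and names (block_table, k_cache, block_size) are illustrative assumptions, not the actual kernel interface:

```python
import torch

# Illustrative shapes only: 16 physical blocks, block_size tokens per block,
# 8 KV heads, head_dim 128. For the restriction mentioned above, block_size
# would have to equal the kernel's kBlockN tile size.
num_blocks, block_size, num_kv_heads, head_dim = 16, 256, 8, 128
k_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_dim)

# A sequence of 600 tokens spans ceil(600 / 256) = 3 physical blocks.
seq_len = 600
block_table = torch.tensor([3, 7, 1])  # physical block ids in logical order

# Gather the logically contiguous K for this sequence (a fused kernel does the
# equivalent lookup on the fly while walking the block table).
k_seq = k_cache[block_table].reshape(-1, num_kv_heads, head_dim)[:seq_len]
print(k_seq.shape)  # torch.Size([600, 8, 128])
```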
It seems they made some changes to GQA/MQA in the new version 0.6.1. Diff: https://www.diffchecker.com/lixZr3Aj/
Great work! As it is very important for performance, do you plan to submit the feature? I would be glad to test it.
@zhaoyang-star See the following:
I've tidied up the code a bit; you could test using the following two branches:
Thanks for the info. I will share the latency comparison after benchmarking.
@beginlner I noticed you wrote a new function
I think it will be a little complicated.
These are great results. I hope that decoding can get a speedup as well, as this is likely to yield a substantial improvement. @zhaoyang-star make sure to fall back to xformers attention, since it supports older GPUs as well; FlashAttention supports Ampere and newer. @beginlner Do you know what it would take for Tri Dao to accept these changes upstream into FlashAttention?
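A minimal sketch of the kind of backend selection being suggested here; the function name and returned strings are placeholders, not vLLM's actual API:

```python
import torch

def choose_attention_backend() -> str:
    """Pick FlashAttention only on GPUs that support it (Ampere, i.e. SM 8.0+),
    otherwise fall back to the xformers kernels. Placeholder logic only."""
    if not torch.cuda.is_available():
        return "xformers"
    major, _minor = torch.cuda.get_device_capability()
    return "flash-attn" if major >= 8 else "xformers"

print(choose_attention_backend())
```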
@beginlner It seems the blocked flash attention unit test failed.
@zhaoyang-star Thanks for the reminder, the error tolerance should be relaxed.
@beginlner The greatest relative difference is 9.48. Is this a little larger than expected? FYI, the original test_flash_attn.py also failed.
Based on my experience, the greatest absolute difference of 0.008 and the greatest relative difference of 9.48 are acceptable. I have pushed a more reliable test to the branch. Additionally, how did the original test_flash_attn.py fail?
@beginlner Sorry, I did not store the log. I suggest you run the test yourself.
It failed on some tests only due to OOM on an A100 40G, because someone else was also using the GPU.
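For context on the tolerance discussion above, a sketch of how such a correctness check might be written; check_close and the tensors below are placeholders, not the actual test in the branch:

```python
import torch

def check_close(out: torch.Tensor, ref: torch.Tensor,
                atol: float = 1e-2, rtol: float = 1e-1) -> None:
    """Report the max absolute/relative error and assert closeness.
    The tolerances are illustrative, loose enough for fp16 kernels."""
    abs_err = (out - ref).abs().max().item()
    rel_err = ((out - ref).abs() / ref.abs().clamp(min=1e-5)).max().item()
    print(f"max abs err = {abs_err:.3g}, max rel err = {rel_err:.3g}")
    # assert_close passes when |out - ref| <= atol + rtol * |ref|, so a large
    # relative error on near-zero values is still tolerated via atol.
    torch.testing.assert_close(out, ref, atol=atol, rtol=rtol)

ref = torch.randn(4, 128)
out = ref + 1e-3 * torch.randn_like(ref)  # stand-in for a kernel's output
check_close(out, ref)
```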
Good news.
@beginlner The unit test has passed. From the kernel benchmark and the e2e benchmark, we can see there is no speedup compared with paged attention.
I would not expect a speedup from implementing GQA for prefilling, only for decoding (which is much harder because of PagedAttention).
Hi, here is my kernel benchmark result on an SXM A100 40GB. I have updated the code at https://github.com/beginlner/vllm/tree/blocked_flash_attn. Note that the shape and the block size of the KV cache are different from vLLM's paged attention.
Yes, GQA is compute-bound during prefilling and memory-bound during decoding, so there is a speedup only for decoding.
@beginlner Yes, the shape and block_size are different from paged attention. I used https://github.com/zhaoyang-star/vllm/tree/blocked_flash_attn_based_on_beginIner, which adds some options on top of your branch blocked_flash_attn. Results using this environment:
Starcoder (MQA) e2e latency with In/Out length=512:
@zhaoyang-star It's expected that the speedup can hardly be seen in the e2e latency benchmark when the batch size is small: when the batch size is small, loading the model parameters is the bottleneck; when the batch size is large, loading the KV cache is the bottleneck.
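A rough illustration of that crossover, reusing the 512 KiB/token MHA estimate from earlier with Llama-2-7B-like numbers; all figures are ballpark assumptions, not measurements:

```python
# Per-decode-step bytes read from HBM: the weights are read once per step
# regardless of batch size, while the KV cache scales with batch * context.
weight_bytes = 7e9 * 2            # ~7B parameters in fp16
kv_bytes_per_token = 512 * 1024   # MHA, from the earlier estimate
context_len = 1024

for batch_size in (1, 8, 64, 256):
    kv_bytes = batch_size * context_len * kv_bytes_per_token
    dominant = "weights" if weight_bytes > kv_bytes else "KV cache"
    print(f"batch={batch_size:>3}: weights {weight_bytes / 1e9:.0f} GB, "
          f"KV cache {kv_bytes / 1e9:.1f} GB -> bottleneck: {dominant}")
```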
@zhaoyang-star Blocked KV cache was added to FlashAttention in 2.5.0. I wonder if the newer implementation gives any performance boost? Either way, it's now in FlashAttention, which makes it easy to use in vLLM.
Great! @beginlner is the core contributor of this feature. I have not benchmarked it under FA 2.5.0, but I think the results will be close to the data we had before. The main question is that FA only supports
We are looking for any contributions to deliver this feature :)
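For reference, a sketch of how the FlashAttention 2.5.0 paged-KV-cache entry point can be called. The shapes follow the flash_attn_with_kvcache interface as I understand it; the tensor sizes are arbitrary examples, and the block-size constraint depends on the installed FA version, so please double-check against its docstring:

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, heads, kv_heads, head_dim = 4, 32, 8, 128   # GQA: 32 query heads, 8 KV heads
num_blocks, block_size = 128, 256                  # block-size constraints vary by FA version

q = torch.randn(batch, 1, heads, head_dim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(num_blocks, block_size, kv_heads, head_dim,
                      dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)
# One row of physical block ids per sequence.
block_table = torch.randint(0, num_blocks, (batch, 8), dtype=torch.int32, device="cuda")
cache_seqlens = torch.full((batch,), 1500, dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    causal=True,
)
print(out.shape)  # torch.Size([4, 1, 32, 128])
```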
This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!
In theory, MQA/GQA can reduce the memory bandwidth needed for reading the KV cache and enable the use of Tensor Cores for the dot products in the attention mechanism. However, this benefit can only be realized with optimized kernels, which vLLM does not have at the moment.
(See vllm/vllm/model_executor/layers/attention.py, lines 121 to 128 at commit e5452dd.)
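The referenced snippet handles MQA/GQA by replicating KV heads so that a vanilla MHA kernel can consume them, which is the unoptimized path this issue wants to replace. A hedged sketch of that pattern follows; the function and argument names are approximations, not the exact vLLM code:

```python
import torch

def expand_kv_for_mha(key: torch.Tensor, value: torch.Tensor,
                      num_queries_per_kv: int):
    """Repeat each KV head num_queries_per_kv times so that an MHA kernel can
    consume the tensors directly. Approximation of the referenced code."""
    if num_queries_per_kv > 1:
        key = key.repeat_interleave(num_queries_per_kv, dim=1)
        value = value.repeat_interleave(num_queries_per_kv, dim=1)
    return key, value

# key/value: [num_tokens, num_kv_heads, head_size]
key = torch.randn(10, 8, 128)
value = torch.randn(10, 8, 128)
k, v = expand_kv_for_mha(key, value, num_queries_per_kv=4)
print(k.shape)  # torch.Size([10, 32, 128])
```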