[WIP] Support for cached multi-query attention towards speculative decoding #1679
Conversation
I've made a pull request to flash-attention that adds support for a blocked KV cache in flash-decoding, which supports MQA. The performance is nearly identical to the original. You might want to check it out.
@beginlner thanks for the info. Reading https://github.com/microsoft/DeepSpeed-Kernels/blob/main/dskernels/inf_flash_attn/blocked_flash/flash_fwd_kernel.h as well.
So far, is there any progress on enabling speculative decoding for vLLM? Additionally, I'm wondering if the implementation of this kernel might result in increased GPU memory usage.
When can this branch be merged? In the version I am currently using, there is:
Is the Flash operation supported only on HIP?
Initial prototype of cached multi-query attention that exploits implementation details of the single-query cached attention kernel to adapt it to the multi-query setting.
Given n sequences, each with a maximum draft length of k tokens to be verified, the prototype greedily caches all draft keys and values, then calls paged_attention on the n * k query vectors. Drafts of the same sequence are "symbolically linked" to the original's KV cache, and "future" tokens are masked out by interpolating the sequence_len passed to the paged attention kernel from context_len to context_len + draft_len.
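To make the interpolated-sequence-length trick concrete, here is a minimal sketch of how the per-sequence metadata might be expanded into per-draft-token metadata. The helper name (`build_mqa_inputs`) and tensor layouts are hypothetical; only the idea of per-query sequence lengths and shared ("symbolically linked") block tables comes from this PR.

```python
import torch

def build_mqa_inputs(context_lens: torch.Tensor,
                     draft_lens: torch.Tensor,
                     block_tables: torch.Tensor):
    """Expand per-sequence metadata to one entry per draft query vector.

    context_lens: (n,)            tokens already committed to each KV cache
    draft_lens:   (n,)            number of speculative tokens to verify
    block_tables: (n, max_blocks) KV-cache block indices per sequence
    (Hypothetical layouts, chosen for illustration.)
    """
    seq_lens, tables = [], []
    for i, (ctx, k) in enumerate(zip(context_lens.tolist(), draft_lens.tolist())):
        for j in range(k):
            # Draft token j may attend to the context plus drafts 0..j, so
            # passing ctx + j + 1 as its sequence length lets the unmodified
            # single-query kernel mask out the "future" draft tokens.
            seq_lens.append(ctx + j + 1)
            # Every draft query of sequence i reuses ("symbolically links"
            # to) the same block table, since the draft keys/values were
            # greedily appended to the original sequence's cache.
            tables.append(block_tables[i])
    return torch.tensor(seq_lens, dtype=torch.int32), torch.stack(tables)
```

With a fixed draft length this yields exactly n * k query vectors; variable draft lengths fall out of the same loop, which is where the masking overhead discussed below comes from.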
While this kernel supports dynamic draft lengths, it does so somewhat inefficiently via masking rather than dynamic shapes; there is potential room for improvement here.
Performance has yet to be profiled. The intention behind this PR is to serve as a reference implementation against which a more performant MQA kernel can be developed.