
[WIP] Support for cached multi-query attention towards speculative decoding #1679

Closed
wants to merge 14 commits

Conversation

skrider (Contributor) commented Nov 16, 2023

This is an initial prototype of cached multi-query attention that exploits implementation details of the existing single-query cached attention kernel to adapt it to the multi-query setting.

Given n sequences, each with a draft of at most k tokens to be verified, the kernel greedily caches all keys and values, then calls paged_attention on the n * k query vectors, "symbolically linking" the KV caches of drafts of the same sequence to the original and masking out "future" tokens by varying the sequence_len passed to the paged attention kernel from context_len to context_len + draft_len.

While this kernel supports dynamic draft lengths, it does so somewhat inefficiently through masking rather than dynamic shapes, which leaves potential room for improvement.

Performance has yet to be profiled. The intention behind this PR is to serve as a reference implementation against which a more performant MQA kernel can be developed.
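
To make the bookkeeping above concrete, here is a minimal sketch of the per-query metadata such a scheme needs. The function and variable names (build_mqa_metadata, context_lens, block_tables) are hypothetical and not taken from this PR, and it assumes sequence_len counts the query token itself:

```python
import torch

# Hypothetical sketch: per-query metadata for verifying n drafts of length k
# with a single-query paged attention kernel. Names are illustrative only.
def build_mqa_metadata(context_lens: torch.Tensor, block_tables: torch.Tensor, k: int):
    """context_lens: (n,) committed (non-draft) tokens per sequence.
    block_tables: (n, max_blocks) KV-cache block indices per sequence.
    Returns seq_lens and block tables for the n * k flattened draft queries."""
    n = context_lens.shape[0]
    # Draft position j may attend to the committed context plus draft tokens 0..j,
    # so its effective sequence length runs from context_len + 1 to context_len + k.
    draft_offsets = torch.arange(1, k + 1, device=context_lens.device)   # (k,)
    seq_lens = (context_lens[:, None] + draft_offsets[None, :]).reshape(n * k)
    # "Symbolically link" each draft query to its sequence's existing KV cache by
    # repeating the block table k times; no KV blocks are copied.
    query_block_tables = block_tables.repeat_interleave(k, dim=0)        # (n * k, max_blocks)
    return seq_lens, query_block_tables
```

The draft keys and values are assumed to have already been appended to each sequence's cache, after which the n * k query vectors and this metadata are passed to paged_attention as if they were n * k independent single-token queries.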

beginlner (Contributor) commented Nov 23, 2023

I've made a pull request to flash-attention that adds support for a blocked KV cache in flash-decoding, which supports MQA. The performance is nearly identical to the original. You might want to check it out.
Dao-AILab/flash-attention#678
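
For reference, a rough sketch of how that blocked-KV flash-decoding path might be invoked, assuming the flash_attn_with_kvcache API with the block_table argument proposed in that PR; the shapes and values below are illustrative, not taken from this thread:

```python
import torch
from flash_attn import flash_attn_with_kvcache  # assumes a build that includes the blocked-KV-cache PR

batch, seqlen_q, nheads, headdim = 2, 4, 32, 128   # seqlen_q > 1 => multi-query (draft verification)
num_blocks, block_size, max_blocks_per_seq = 16, 256, 4

q = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.randn(num_blocks, block_size, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.randn_like(k_cache)
# Per-sequence block indices into the paged KV cache, and tokens already cached per sequence.
block_table = torch.arange(batch * max_blocks_per_seq, device="cuda",
                           dtype=torch.int32).reshape(batch, max_blocks_per_seq)
cache_seqlens = torch.tensor([100, 37], device="cuda", dtype=torch.int32)

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    causal=True,
)
```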

skrider (Contributor, Author) commented Nov 27, 2023

@beginlner thanks for the info. Reading https://github.com/microsoft/DeepSpeed-Kernels/blob/main/dskernels/inf_flash_attn/blocked_flash/flash_fwd_kernel.h as well.

Lvjinhong commented

So far, is there any progress on enabling speculative decoding for vLLM? Additionally, I'm wondering whether this kernel's implementation might increase GPU memory usage.

Lvjinhong commented

When can this branch be merged? In the version I am currently using, there is:

op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
                (is_hip()) else None,

Is the Flash operation supported only for HIP?
