Optimized DeepSeek V2/V3 implementation (MLA + flash attention) #12227
Conversation
@slaren I found some duplicated code in llama-arch.cpp whilst doing this that probably should be removed in master.
I'm unsure if the other backends' FA code will work or not (due to the unusual/large 576 head dimension). I'm unsure if you can change ...
I'm running on a much larger input now of 56463 tokens, and after only 8k tokens in it already seems to be using way more compute than it did before, so I will see how it gets on and report back...
Yeah, the CUDA backend is using an absolutely insane amount of compute for the FA version compared to the non-FA version, so either it's not optimised for "Multi-Query Attention" (MQA) (ie: 1 K/V for 128 Q) or it doesn't like the 576 head dimension. Hopefully @JohannesGaessler will be able to help with this (I've no idea if the non-CUDA backends will suffer from the same problem with FA and would be interested if somebody could try it!).
The biggest limitation for CUDA kernels is register pressure. You have 64k registers per streaming multiprocessor that are shared between all concurrently running threads, and each thread can use a maximum of 256 registers. If you write a kernel that uses more registers you get the GPU equivalent of thrashing: registers need to be swapped with VRAM and the performance is terrible. For this reason the CUDA compiler is by default very conservative with how many registers it uses, but this also leaves a lot of potential performance on the table. So a lot of CUDA performance optimization comes down to writing a kernel that is just below the maximum number of registers that are available.

For the CUDA FlashAttention kernels, register pressure and head size are tightly linked: the larger the head is, the more registers are needed. I wrote and tuned the code with head sizes up to 256 in mind. If you run it for larger head sizes as-is, it is expected that the performance will be essentially unusable. That is also the reason why KV types other than FP16 are unsupported for head sizes > 128; the register pressure made the kernel effectively unusable (and the swapping also massively increases compilation time).

What you can try doing is reducing the number of Q columns for a head size of 576. The more Q columns a CUDA block works on in parallel, the more registers are needed. The downside of doing this is a lower arithmetic intensity (how much work the GPU can do per loaded data value).
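To make that concrete with the numbers above (64k registers per SM, at most 256 registers per thread, 32 threads per warp), a kernel sitting at the per-thread register limit can only keep

$$
\frac{65536\ \text{registers/SM}}{256\ \text{registers/thread}} = 256\ \text{threads/SM} = 8\ \text{warps/SM}
$$

resident per SM, so every additional Q column processed per block (more registers per thread) either reduces occupancy further or forces register spilling.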
Thanks - I've tried adding in a special case:

```cpp
if (Q->ne[0] == 576) {
    constexpr int cols_per_block  = 1;
    constexpr int parallel_blocks = 1;
    if (logit_softcap == 0.0f) {
        constexpr bool use_logit_softcap = false;
        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
    } else {
        constexpr bool use_logit_softcap = true;
        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
    }
    return;
}
```

and will leave it running to see how it compares for the same 56463 token prompt (it still looks like it's using a lot more compute than the non-FA version though).
Had to kill it as it got about 1/2 way through and it was clear it was going to be way slower. Was worth a try anyway :) It will be interesting to see what it does with the CPU backend.

I also have another idea of how I might be able to squeeze the dimensions down to 256 or lower using SVD (this time doing everything on the ...). There is some chance that the ... If this shows promise, then it might be possible to squeeze the dimension(s) even more to get a similar storage/accuracy trade-off as we currently get from quantising the K/V caches.
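Roughly, the SVD idea would be something like this (just a sketch of one way it could work, writing W for the 512-column up-projection that consumes the cached latent):

$$
W \;\approx\; U_r \, \Sigma_r \, V_r^\top , \qquad r < 512
$$

Fold U_r Σ_r into the projection on the query/output side and V_r^T into the projection that produces the latent, so the cache only needs to store r elements (plus the 64 RoPE elements) per token instead of 512 + 64, and the flash-attention head size drops from 576 to r + 64. How much accuracy is lost would depend on how quickly the singular values decay.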
@jukofyork Just skimming some of these threads (haven't found time to dig into the details), but are you confident that the head size of R1 is really 576? This seems very suspicious. It should be something smaller, like 128 / 192. I think last time I checked, the only issue was that the K and V heads had different sizes, which we currently don't support. But I think the sizes were normal.
Yes, just double-checked again:
Not sure what you mean by head size of 576. |
@ggerganov I think the idea here is that @jukofyork passes (cached) concatenated [kv_compressed | k_pe ] as "K" and "V" vectors to flash attention kernel. kv_compressed has dim kv_lora_rank (512), k_pe has dim n_embd_head_qk_rope (64), hence 576. I seriously felt like a retard for a while when trying to understand how it works (migraine didn't help), but it's equivalent to something like this (translated to variables from my MLA implementation):
So kv_compressed essentially gets multiplied by q_nope2 and k_pe by q_pe, and both are summed while calculating the dot product, giving kq_nope + kq_pe. Then after the softmax we multiply [kv_compressed | k_pe] by the result. Because of the extra concatenated k_pe in the "V" vectors we waste some compute, but avoid the need for a transposed kv_compressed. The idea is brilliant, but I'm not sure about the performance on a CPU. I'm running the benchmark right now, will post some numbers when it finishes.
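Written out as equations, the equivalence is (using W_UK for the key up-projection; that symbol is mine, the rest follows the variable names above):

$$
q^\top k \;=\; q_{\text{nope}}^\top \underbrace{(W_{UK}\, c_{KV})}_{k_{\text{nope}}} + q_{\text{pe}}^\top k_{\text{pe}}
\;=\; \underbrace{(W_{UK}^\top q_{\text{nope}})^\top}_{q_{\text{nope2}}^\top} c_{KV} + q_{\text{pe}}^\top k_{\text{pe}}
\;=\; \big[\,q_{\text{nope2}} \,|\, q_{\text{pe}}\,\big]^\top \big[\,c_{KV} \,|\, k_{\text{pe}}\,\big]
$$

so the cached 512 + 64 = 576-element vector [kv_compressed | k_pe] can be fed to the attention kernel directly as "K" (and reused as "V"), with W_UK absorbed into the query side.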
I see - I misunderstood the computation here. Ignore my comments.
Yeah, @fairydreaming had already done 99% of the work and this just came about when I saw a way to remove the permutations and the "Copy Elision" stuff:

If we decompress the 512-element cached vectors, we end up doing standard multi-head attention with 128 heads of 192-element keys.

If we don't decompress, and instead use the concatenated cache directly (ie: 512 for the latent plus 64 for the RoPE part, giving 576), we end up doing multi-query attention with a single 576-element K/V shared by all 128 query heads.

At first sight this seems like a daft thing to do, because we still have to take the same number of dot products for MHA as for MQA, but now we are taking dot-products of 576 elements instead of 192 elements... But what this doesn't consider, and what @fairydreaming's code was already taking into account (I just rearranged it a bit), is the decompression from the 512-element cached latent that the decompressed path has to do.
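As rough arithmetic (per query head and per cached token, using the head sizes mentioned in this thread): each attention score becomes a dot product over

$$
\frac{576}{192} = 3\times
$$

as many elements, but in exchange the per-token decompression of the 512-element latent back into per-head keys and values (512 -> 128 elements per head, for 128 heads, where 128 = 192 minus the 64 RoPE elements) is skipped entirely, and the cache can stay in its compressed form.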
@jukofyork The CPU flash attention implementation doesn't seem to perform well here, but the non-FA code performs a bit better than my implementation, which is great! Tested on an Epyc 9374F, 384 GB RAM, 32 threads with NPS1 (single NUMA node) settings. Note: each point shows either the mean token generation rate over 128 tokens at the given context length, or the mean prompt processing rate over 512 tokens at the given context length.
It will be interesting to see, as the CUDA performance was disappointing to say the least :/ Even if the FA stuff doesn't work, you might get a boost from not having to do all the batched matrix multiplies and just churning through one huge 2D x 2D multiply each token. I'm 100% sure I can see a way to compress the latent dimensions down from 512 using SVD, so there is still some hope left that it might improve things enough to work for CUDA, but it will be tomorrow before I can try this (I tried once already but the ...).
Thanks! It looks like the 576-element flash-attention stuff is working about as well as it does on CUDA 😬 I'm glad the other optimization helped a bit though - for CUDA it makes a very small difference (something like 0.15 tokens/s on top of 4 tokens/s generation), but it does seem to fix the weird numerical problems I was getting, and also the problem with the compute buffer only getting allocated 2/3rds of what it needs (neither of which I could figure out the exact cause of). I'll try and implement the SVD stuff tomorrow as that has the potential to make a really big difference and may even let the FA stuff start to work properly again.
@jukofyork have you seen this? ikawrakow/ik_llama.cpp#241
Closing this as superseded by #12313
I'll leave the flash attention stuff to others, as I had hoped it might "just work" but it obviously didn't, and doing this properly would require lots of other changes.
This is a continuation of @fairydreaming's draft-PR #11446.
The main changes are:
- Caches only the compressed latent KV (reused for both K and V).
- Removed the use of ggml_permute where possible to make the calculation of KV a single 2D x 2D matrix multiply instead of unnecessarily using batch matrix multiply.
- Added -fa to use flash-attention (TESTED ON: CUDA WITH KV-TYPE=F16, and CPU BACKENDS ONLY).

Using flash-attention has no real performance gain/loss [see below], but does use a lot less memory.

Without -fa option:
...

With -fa option:
...

The fact that we can reuse the same "compressed KV-cache" for both K and V halves the storage for this:

KV self size = 2196.00 MiB, K (f16): 2196.00 MiB, V (f16): 0.00 MiB

and the compute buffers are much smaller too:
...
NOTES
- The flash-attention path uses 576/576 --> 512 head sizes and currently only works with F16 for large head dimensions.
- It's a bit messy and I will look at trying to tidy it up, and possibly also move the logic into the proper KV-cache handling code later (this will require either pre-multiplying/absorbing wv_b x wo and ~3.7x more compute for the multiply, or passing wv_b to these functions as an extra option like the w_b bias option is passed now, etc).
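For reference, a rough accounting of where the ~3.7x could come from, assuming DeepSeek-V3-ish dimensions (128 heads, kv_lora_rank = 512 and v_head_dim = 128 from this thread, plus a hidden size of 7168, which is not stated here): per token, applying wv_b and then wo costs about 128·512·128 + 128·128·7168 multiply-adds, whereas the absorbed wv_b x wo matrix costs 128·512·7168:

$$
\frac{128 \cdot 512 \cdot 7168}{128 \cdot 512 \cdot 128 \;+\; 128 \cdot 128 \cdot 7168}
= \frac{469{,}762{,}048}{125{,}829{,}120} \approx 3.7
$$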