Description
Benchmarks from vLLM show that the spec-optimized decode kernel scales much better than the prefill kernel for requests with very short query_lens.
decode fp16:

| bs | kv_len | q_len | mean  | std   |
|---:|-------:|------:|------:|------:|
| 1  | 1024   | 1     | 0.032 | 0.001 |
| 4  | 1024   | 1     | 0.032 | 0.001 |
| 8  | 1024   | 1     | 0.033 | 0.001 |
| 16 | 1024   | 1     | 0.034 | 0.001 |
| 32 | 1024   | 1     | 0.045 | 0.004 |
| 1  | 1024   | 4     | 0.032 | 0.001 |
| 4  | 1024   | 4     | 0.034 | 0.001 |
| 8  | 1024   | 4     | 0.039 | 0.002 |
| 16 | 1024   | 4     | 0.048 | 0.005 |
| 32 | 1024   | 4     | 0.064 | 0.005 |
| 1  | 1024   | 8     | 0.033 | 0.001 |
| 4  | 1024   | 8     | 0.039 | 0.002 |
| 8  | 1024   | 8     | 0.047 | 0.004 |
| 16 | 1024   | 8     | 0.063 | 0.005 |
| 32 | 1024   | 8     | 0.095 | 0.004 |
prefill fp16:

| bs | kv_len | q_len | mean  | std   |
|---:|-------:|------:|------:|------:|
| 1  | 1024   | 1     | 0.041 | 0.003 |
| 2  | 1024   | 1     | 0.043 | 0.004 |
| 4  | 1024   | 1     | 0.059 | 0.004 |
| 8  | 1024   | 1     | 0.089 | 0.004 |
| 16 | 1024   | 1     | 0.139 | 0.004 |
| 32 | 1024   | 1     | 0.248 | 0.004 |
| 1  | 1024   | 4     | 0.041 | 0.003 |
| 2  | 1024   | 4     | 0.042 | 0.004 |
| 4  | 1024   | 4     | 0.058 | 0.004 |
| 8  | 1024   | 4     | 0.089 | 0.004 |
| 16 | 1024   | 4     | 0.140 | 0.004 |
| 32 | 1024   | 4     | 0.250 | 0.003 |
| 1  | 1024   | 8     | 0.041 | 0.003 |
| 2  | 1024   | 8     | 0.042 | 0.004 |
| 4  | 1024   | 8     | 0.059 | 0.004 |
| 8  | 1024   | 8     | 0.089 | 0.004 |
| 16 | 1024   | 8     | 0.141 | 0.003 |
| 32 | 1024   | 8     | 0.252 | 0.003 |
From this benchmark we can see that the prefill kernel has essentially the same performance at q_len == 1 and q_len == 8. The decode kernel does not: at the same batch size it is slower at q_len == 8 than at q_len == 1, but it remains consistently faster than the prefill kernel.
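For reference, a minimal sketch of how such a comparison might be timed is below. It assumes FlashAttention's `flash_attn_with_kvcache` as the decode kernel and `flash_attn_varlen_func` as the prefill kernel, with made-up head counts and shapes; the exact kernels and configuration behind the numbers above may differ.

```python
import torch
from flash_attn import flash_attn_with_kvcache, flash_attn_varlen_func


def time_ms(fn, iters=100, warmup=10):
    """Return (mean, std) of CUDA time in milliseconds for fn()."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    samples = []
    for _ in range(iters):
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        samples.append(start.elapsed_time(end))
    t = torch.tensor(samples)
    return t.mean().item(), t.std().item()


def bench(bs, kv_len, q_len, nheads=32, headdim=128):
    dev, dt = "cuda", torch.float16

    # Decode-style call: batched (bs, q_len, H, D) queries against a per-request KV cache.
    q = torch.randn(bs, q_len, nheads, headdim, device=dev, dtype=dt)
    k_cache = torch.randn(bs, kv_len, nheads, headdim, device=dev, dtype=dt)
    v_cache = torch.randn_like(k_cache)
    cache_seqlens = torch.full((bs,), kv_len, dtype=torch.int32, device=dev)

    def decode():
        return flash_attn_with_kvcache(q, k_cache, v_cache,
                                       cache_seqlens=cache_seqlens, causal=True)

    # Prefill-style call: packed (varlen) layout with cumulative sequence lengths.
    qp = torch.randn(bs * q_len, nheads, headdim, device=dev, dtype=dt)
    kp = torch.randn(bs * kv_len, nheads, headdim, device=dev, dtype=dt)
    vp = torch.randn_like(kp)
    cu_q = torch.arange(0, (bs + 1) * q_len, q_len, dtype=torch.int32, device=dev)
    cu_k = torch.arange(0, (bs + 1) * kv_len, kv_len, dtype=torch.int32, device=dev)

    def prefill():
        return flash_attn_varlen_func(qp, kp, vp, cu_q, cu_k, q_len, kv_len, causal=True)

    print(f"bs={bs} kv_len={kv_len} q_len={q_len} "
          f"decode(mean, std)={time_ms(decode)} prefill(mean, std)={time_ms(prefill)}")


for bs in (1, 4, 8, 16, 32):
    for q_len in (1, 4, 8):
        bench(bs, 1024, q_len)
```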
For speculative decoding, we do not always have a batch of decodes where every request has the same q_len. Often many requests have spec disabled (q_len == 1) while a smaller subset has a longer q_len (typically between 3 and 8). This scales poorly: the prefill kernel appears to pad each request internally, and to use the decode kernel we have to do that padding ourselves.
An ideal decode kernel for spec decoding would handle variable query lengths up to a small bound while keeping the performance scaling the current decode kernel shows on batches with a uniform q_len.
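To make the desired semantics concrete, here is a pure-PyTorch reference (illustrative only, not an existing kernel or a proposed API) of what such a variable-q_len decode kernel would compute: each request attends its last query_lens[i] query tokens causally against its own KV cache.

```python
import math

import torch


def varlen_decode_attention_ref(query, query_lens, k_cache, v_cache, kv_lens):
    """query: packed (total_q_tokens, H, D); k_cache/v_cache: (bs, max_kv_len, H, D).

    Hypothetical reference for the requested behavior, not an existing kernel.
    """
    outs = []
    start = 0
    for i, q_len in enumerate(query_lens.tolist()):
        kv_len = int(kv_lens[i])
        q = query[start:start + q_len]          # (q_len, H, D)
        k = k_cache[i, :kv_len]                 # (kv_len, H, D)
        v = v_cache[i, :kv_len]
        scores = torch.einsum("qhd,khd->hqk", q, k) / math.sqrt(q.shape[-1])
        # Causal mask: query token j corresponds to overall position kv_len - q_len + j.
        q_pos = kv_len - q_len + torch.arange(q_len, device=q.device)
        k_pos = torch.arange(kv_len, device=q.device)
        scores = scores.masked_fill(k_pos[None, :] > q_pos[:, None], float("-inf"))
        outs.append(torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v))
        start += q_len
    return torch.cat(outs, dim=0)               # packed (total_q_tokens, H, D)
```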
In practice, q_len < 16 holds essentially always for (chain-drafted) speculative decoding, and q_len <= 8 holds for most speculative decoding (almost always for draft-model or EAGLE spec, often but not always for n-gram spec decoding).
Even if the decode kernel were only changed to accept padding the way the prefill kernel does, while retaining its performance on the padded batch, that would already be a win downstream: today we have to manually pad and unpad the batch with custom kernels that incur non-trivial overhead, as sketched below.
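For illustration, a rough sketch of that pad/unpad shape transform in plain PyTorch (the Python loops stand in for what the actual custom kernels do in parallel; names like `pad_queries` are made up here):

```python
import torch


def pad_queries(query, query_lens, max_q_len):
    """Scatter packed (total_q, H, D) queries into (bs, max_q_len, H, D), zero-padding the rest."""
    bs = query_lens.numel()
    nheads, headdim = query.shape[-2:]
    padded = query.new_zeros(bs, max_q_len, nheads, headdim)
    starts = torch.cumsum(query_lens, 0) - query_lens  # packed start offset per request
    for i in range(bs):  # a real kernel would do this scatter in parallel
        n = int(query_lens[i])
        s = int(starts[i])
        padded[i, :n] = query[s:s + n]
    return padded


def unpad_output(padded_out, query_lens):
    """Gather only the valid rows of (bs, max_q_len, H, D) back into packed (total_q, H, D)."""
    return torch.cat([padded_out[i, :int(n)] for i, n in enumerate(query_lens)], dim=0)
```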