TRTLLM-gen decode kernels do not support variable length queries #1832

@benchislett

Description

@benchislett

Benchmarks from vLLM show that the spec-optimized decode kernel scales much better than the prefill kernel for requests with very short query_lens.

# decode fp16
bs      kv_len  q_len   mean    std
1       1024    1       0.032   0.001
4       1024    1       0.032   0.001
8       1024    1       0.033   0.001
16      1024    1       0.034   0.001
32      1024    1       0.045   0.004
1       1024    4       0.032   0.001
4       1024    4       0.034   0.001
8       1024    4       0.039   0.002
16      1024    4       0.048   0.005
32      1024    4       0.064   0.005
1       1024    8       0.033   0.001
4       1024    8       0.039   0.002
8       1024    8       0.047   0.004
16      1024    8       0.063   0.005
32      1024    8       0.095   0.004

# prefill fp16
bs      kv_len  q_len   mean    std
1       1024    1       0.041   0.003
2       1024    1       0.043   0.004
4       1024    1       0.059   0.004
8       1024    1       0.089   0.004
16      1024    1       0.139   0.004
32      1024    1       0.248   0.004
1       1024    4       0.041   0.003
2       1024    4       0.042   0.004
4       1024    4       0.058   0.004
8       1024    4       0.089   0.004
16      1024    4       0.140   0.004
32      1024    4       0.250   0.003
1       1024    8       0.041   0.003
2       1024    8       0.042   0.004
4       1024    8       0.059   0.004
8       1024    8       0.089   0.004
16      1024    8       0.141   0.003
32      1024    8       0.252   0.003

From this benchmark, we can see that the prefill kernel has essentially the same performance at q_len == 1 and q_len == 8. This is not true of the decode kernel, which is slower at q_len == 8 than at q_len == 1 for the same batch_size, but the decode kernel is still consistently faster than the prefill kernel across all measured configurations.

For speculative decoding, we do not always have a batch of decodes where all requests share a uniform q_len. Often many requests have spec disabled (q_len == 1) while a smaller subset have a longer q_len (typically between 3 and 8). This leads to poor scaling: the prefill kernel appears to internally pad each request, and to use the decode kernel we must perform that padding ourselves.
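To make the padding cost concrete, here is a minimal sketch (hypothetical illustration, not actual vLLM or TRT-LLM code) of the pad/unpad step required to feed a mixed-q_len speculative batch through a fixed-q_len decode kernel; the function names and pad token are assumptions:

```python
def pad_queries(queries, pad_token=0):
    """Pad each request's query tokens to the max q_len in the batch.

    Returns the padded batch plus the original lengths, which are needed
    later to strip the padded positions from the kernel output.
    """
    max_q_len = max(len(q) for q in queries)
    padded = [q + [pad_token] * (max_q_len - len(q)) for q in queries]
    return padded, [len(q) for q in queries]


def unpad_outputs(padded_outputs, q_lens):
    """Strip the padded positions from each request's output."""
    return [row[:n] for row, n in zip(padded_outputs, q_lens)]


# Mixed batch: one request with spec disabled (q_len == 1), others with
# longer speculative queries (q_len up to 4 in this toy example).
batch = [[11], [21, 22, 23, 24], [31, 32]]
padded, q_lens = pad_queries(batch)
# padded -> [[11, 0, 0, 0], [21, 22, 23, 24], [31, 32, 0, 0]]
# ... fixed-q_len decode kernel would run on `padded` here ...
restored = unpad_outputs(padded, q_lens)
# restored -> original per-request lengths recovered
```

Every element of this round trip is wasted work relative to a kernel that accepted the ragged batch directly.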

An ideal decode kernel for spec decoding would handle variable query lengths up to a small bound, with the same performance scaling that the current decode kernel achieves on batches of constant q_len.

In practice, q_len < 16 holds essentially always for (chain-drafted) speculative decoding, and q_len <= 8 holds for most speculative decoding (almost always for draft-model or EAGLE spec, and often but not always for n-gram spec decoding).

Even if the decode kernel were only modified to handle padding internally, the way the prefill kernel appears to, while retaining its performance on the padded batch, that would still be a win downstream: today we have to manually pad and unpad the batch with custom kernels that incur non-trivial overhead.
