Skip to content

Conversation

@fhl2000
Copy link
Contributor

@fhl2000 fhl2000 commented Aug 26, 2025

Purpose

This PR has done several things to fix/refactor speculative decoding.

  1. Fix the broken cudagraph of the eagle drafter. (now temporarily fixed by [Spec-Decode] Support piecewise cudagraphs for Eagle head #25109)

    The cudagraph ability of the EAGLE(3) model in V1 has been broken for a while since the launch of [Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer #20059, because for now vllm torch.compile would never implicitly run a cudagraph until we explicitly dispatch it via the cudagraph dispatcher. This PR fixes that by appending a new dispatcher for drafter that is separate from the main model. The reason we don't share a single dispatcher is that for the spec decode scenario, the uniform decode query len for the main model is 1+ num_spec_tokens while for a drafer it is just 1 (generally speaking).

  2. Make "cudagraph_capture_size" of uniform decode divisible by uniform_query_len = 1+ num_spec_decode when enabled spec_decode. This requires a new non-configurable property uniform_cudagraph_capture_sizes to be added to Compilation_config. This is post-computed from and also separated from cudagraph_capture_sizes, which is configurable.

    The reason for doing this is that some attention backends support AttentionCGSupport.UNIFORM_BATCH requires token uniformity strictly, i.e., we couldn't pad to a batch size that is not divisible by uniform_query_len, unlike varlen support backends like FlashAttention(v2/v3) and Triton Attention. Also reference to the "Padded speculation" for drafer proposed in [Performance]: Padded Speculative Decoding #21984 and the example in [Spec Decode] Efficient padded speculation #24539

  3. The cudagraph capturing of the drafter is also separated from the capturing of the main model, since the main model uses "uniform_cudagraph_capture_sizes" for the validation phase (uniform_decode_query_len>1), while the eagle drafter uses chunked "cudagraph_capture_sizes" as it is generally uniform_decode_query_len=1 (after the first draft token).

  4. The cudagraph_dispatcher and batch_descriptor are further refactored to support multiple uniform_decode_query_len together. This is to pave the way for full cudagraph support of "Padded speculation" in [Performance]: Padded Speculative Decoding #21984 and [Spec Decode] Efficient padded speculation #24539, where for a pure spec-decode loop (no prefill tokens), the uniform_decode_query_len can be 1+num_spec_token for the first draft token and be only 1 for the consecutive draft tokens.

    The prototype of batch_descriptor is now:

class BatchDescriptor(NamedTuple):
    num_tokens: int
    uniform_decode: bool = False
    uniform_query_len: int = 0

For non-uniform decode, uniform_query_len is forced to 0 to ensure the uniqueness of a cudagraph batch.

  1. The decision of cudagraphs token padding and cudagraph dispatching could be simplified into a single call of the dispatcher at runtime (typically for drafter model where dbo and dp_padding are ignored), i.e., we can use one-shot planning instead of scatteredly doing that:
cudagraph_runtime_mode, batch_descriptor, num_input_tokens = \
            self.cudagraph_dispatcher.fast_plan(
                num_scheduled_tokens=num_scheduled_tokens,
                num_reqs=self.input_batch.num_reqs,
                max_query_len=max_query_len)

For more details, please see the codes.

Test Plan

To be planned! (call for suggestions, as we can not do much for a full cudagraph of spec-decode currently and most codes are prepared for full cudagraph feature of spec-decode in the near future)

Can confirm that the piecewise cudagraph ability of drafter is fixed!


*Note for its compatibility with DP:

While enabling DP, we have to pad the token sizes to the same size across DP for cudagraph (mainly for the collective communication of MOEs). That makes it hard to hit a cudagraph size to fit across all DP ranks when we both have uniform-decode batches in some DP ranks and non-uniform batches in others, since we have a different list of cudagraph sizes.

To address this:

we only run full-CGs (i.e. uniform) if all ranks are running full-CG; I dont think theres much value in a rank running full-CG if another isn't (since they'll wait for the straggler in full-CG ranks

(Thanks for the contexts from @SageMoore and ideas from @LucasWilkinson, respectively.)

So a uniform-decode batch is potentially padded to the size of a non-uniform batch and falls back into piecewise cudagraph.
The attention part is still on eager execution, so no problem for supporting "Padded speculation"


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl <2410591650@qq.com>
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request provides a comprehensive fix and refactoring for speculative decoding with CUDA graphs. The changes are well-structured and address several key issues, including the broken CUDA graph for the EAGLE drafter, by introducing a separate dispatcher and a unified planning mechanism. The refactoring of the CudagraphDispatcher to support multiple uniform query lengths and the consolidation of padding and dispatching logic into a single plan method significantly improve code clarity and maintainability. The removal of build_for_cudagraph_capture from attention backends is a good simplification. Overall, the code quality is high, and the changes appear to be correct and robust. I have reviewed the changes in detail and could not find any issues of high or critical severity.

@fhl2000
Copy link
Contributor Author

fhl2000 commented Aug 26, 2025

@ProExpertProg
Copy link
Collaborator

Thanks for doing this, will take a look soon.

Remove build_for_cudagraph_capture from attention_metadata_builder.

This might have issues, currently triton forces seq_lens to 1 (we could ask the builder what seq_lens it wants, or make it optional, not sure) and AITER FA does something with total tokens - although a) that seems unused and b) it was reported broken.

@fhl2000
Copy link
Contributor Author

fhl2000 commented Aug 26, 2025

currently triton forces seq_lens to 1 (we could ask the builder what seq_lens it wants, or make it optional, not sure) and AITER FA does something with total tokens - although a) that seems unused and b) it was #22683 (comment).

I have confirmed that forcing seq_lens to 1 is not necessary for speed up Triton attention capturing. Currently, with the frozen CG feature, the time is almost the same as FA2, and has no influence on the generated contents.

Yeah, the total_tokens for AITER FA is unused currently; they use num_actual_kv_tokens instead. I think fixing that can be done by properly modifying common_metadata, not relying on the build_for_cudagraph_capture

Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
@mergify
Copy link

mergify bot commented Sep 23, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 23, 2025
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
@nvpohanh
Copy link
Contributor

nvpohanh commented Oct 1, 2025

cc @elvischenv for vis

@mergify
Copy link

mergify bot commented Oct 3, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 3, 2025
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
@mergify mergify bot removed the needs-rebase label Oct 6, 2025
@mergify
Copy link

mergify bot commented Oct 8, 2025

Documentation preview: https://vllm--23679.org.readthedocs.build/en/23679/

@mergify
Copy link

mergify bot commented Oct 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 8, 2025
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
@mergify mergify bot removed the needs-rebase label Oct 14, 2025
@mergify
Copy link

mergify bot commented Oct 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 14, 2025
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
@mergify mergify bot removed the needs-rebase label Oct 14, 2025
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
@mergify
Copy link

mergify bot commented Oct 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 15, 2025
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
@mergify mergify bot removed the needs-rebase label Oct 15, 2025
@fhl2000
Copy link
Contributor Author

fhl2000 commented Oct 15, 2025

Already prepared for the next round of viewing! The main updates are for resolving the potential cudagraph missing issue when DP size >1.

cc @ProExpertProg @LucasWilkinson @SageMoore.

Edit: also add a new flag to disable the uniform_alignment feature introduced in this PR.

@fhl2000 fhl2000 changed the title [Spec-decode] fix and refoctor cudagraphs for spec-decode [Spec-decode] Refoctor cudagraphs for spec-decode;support uniform_alignment of cudagraph sizes. Oct 15, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation speculative-decoding v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants