vllm/v1/attention/backends/gdn_attn.py (3 changes: 2 additions & 1 deletion)
@@ -209,7 +209,8 @@ def build(  # type: ignore[override]

         # prepare tensors for cudagraph
         if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
-                and num_spec_decodes <= self.decode_cudagraph_max_bs):
+                and num_spec_decodes <= self.decode_cudagraph_max_bs
+                and m.num_actual_tokens <= self.decode_cudagraph_max_bs):
Contributor (high):

This condition m.num_actual_tokens <= self.decode_cudagraph_max_bs appears to have a unit mismatch. m.num_actual_tokens is the number of tokens, while self.decode_cudagraph_max_bs is used as a limit on the number of sequences for sizing tensors like spec_state_indices_tensor and spec_sequence_masks.

Comparing tokens to sequences is likely incorrect and makes this check overly restrictive. For instance, with num_spec=7 and decode_cudagraph_max_bs=32, this change limits num_spec_decodes to 4 (since 4 * 8 <= 32), whereas the original code allowed up to 32 sequences.

The underlying issue is that batch_size can exceed self.decode_cudagraph_max_bs due to token padding. The batch_size is calculated as self.vllm_config.pad_for_cudagraph(m.num_actual_tokens) // (self.num_spec + 1).

A more accurate check would be to compute this batch_size and compare it against self.decode_cudagraph_max_bs, while also ensuring m.num_actual_tokens does not exceed self.compilation_config.max_capture_size to prevent errors from pad_for_cudagraph.
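As a rough sketch, the suggested check could look like the following standalone helper (hypothetical name and signature; the arguments mirror the builder attributes referenced above, and this is illustration only, not vLLM code):

```python
# Hypothetical helper illustrating the suggested check; the parameter names
# mirror the builder attributes discussed above, but this is not vLLM code.
def fits_decode_cudagraph(num_actual_tokens: int, num_spec: int,
                          decode_cudagraph_max_bs: int,
                          max_capture_size: int,
                          pad_for_cudagraph) -> bool:
    # pad_for_cudagraph can only pad up to the largest captured size, so
    # reject token counts beyond max_capture_size outright.
    if num_actual_tokens > max_capture_size:
        return False
    # Padded token count, then the effective number of spec-decode sequences
    # (each sequence contributes num_spec + 1 tokens).
    num_total_tokens = pad_for_cudagraph(num_actual_tokens)
    batch_size = num_total_tokens // (num_spec + 1)
    # Bound sequences, not tokens, by decode_cudagraph_max_bs, since that is
    # the dimension used to size spec_state_indices_tensor and
    # spec_sequence_masks.
    return batch_size <= decode_cudagraph_max_bs
```

With num_spec=7 and decode_cudagraph_max_bs=32 as in the example above, this admits up to 32 speculative-decode sequences (256 tokens) rather than 4, as long as the padded token count stays within 256 and within max_capture_size.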

Contributor:

I think keeping m.num_actual_tokens <= self.decode_cudagraph_max_bs is fine. In that case, the num_spec_decodes <= self.decode_cudagraph_max_bs check seems unnecessary, since num_spec_decodes can never exceed m.num_actual_tokens.

             num_total_tokens = self.vllm_config.pad_for_cudagraph(
                 m.num_actual_tokens)
             batch_size = num_total_tokens // (self.num_spec + 1)
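Under the second comment's suggestion, the guard would reduce to roughly the following (a sketch only, not the code that was merged):

```python
# Sketch: keep only the token-count bound on entering the cudagraph path.
if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
        and m.num_actual_tokens <= self.decode_cudagraph_max_bs):
    ...  # prepare tensors for cudagraph as above
```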