[Qwen3Next] Fixes the cuda graph capture conditions under large batch sizes (#24660) #24667
Conversation
… sizes. Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Force-pushed from 127ad66 to fb73abd
Code Review
This pull request attempts to fix a CUDA graph capture condition for large batch sizes in the GDNAttentionBackend. However, the added condition m.num_actual_tokens <= self.decode_cudagraph_max_bs introduces a unit mismatch by comparing a token count with a sequence limit. This makes the check overly restrictive and prevents CUDA graph usage in many valid scenarios. My review provides a detailed explanation of the issue and suggests a more accurate approach to fix the underlying problem.
```diff
 if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
-        and num_spec_decodes <= self.decode_cudagraph_max_bs):
+        and num_spec_decodes <= self.decode_cudagraph_max_bs
+        and m.num_actual_tokens <= self.decode_cudagraph_max_bs):
```
This condition m.num_actual_tokens <= self.decode_cudagraph_max_bs appears to have a unit mismatch. m.num_actual_tokens is the number of tokens, while self.decode_cudagraph_max_bs is used as a limit on the number of sequences for sizing tensors like spec_state_indices_tensor and spec_sequence_masks.
Comparing tokens to sequences is likely incorrect and makes this check overly restrictive. For instance, with num_spec=7 and decode_cudagraph_max_bs=32, this change limits num_spec_decodes to 4 (since 4 * 8 <= 32), whereas the original code allowed up to 32 sequences.
The underlying issue is that batch_size can exceed self.decode_cudagraph_max_bs due to token padding. The batch_size is calculated as self.vllm_config.pad_for_cudagraph(m.num_actual_tokens) // (self.num_spec + 1).
A more accurate check would be to compute this batch_size and compare it against self.decode_cudagraph_max_bs, while also ensuring m.num_actual_tokens does not exceed self.compilation_config.max_capture_size to prevent errors from pad_for_cudagraph.
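A minimal sketch of that suggested check, written as a standalone function with the attribute names quoted in this thread (num_actual_tokens, num_spec, decode_cudagraph_max_bs, max_capture_size, pad_for_cudagraph) passed in as parameters; the actual guard in vLLM may differ:

```python
# Illustrative sketch, not the vLLM implementation. `pad_for_cudagraph` stands in
# for vLLM's helper that rounds a token count up to the nearest captured graph size.
from typing import Callable

def can_use_full_cuda_graph(
    num_actual_tokens: int,
    num_spec: int,
    decode_cudagraph_max_bs: int,
    max_capture_size: int,
    pad_for_cudagraph: Callable[[int], int],
) -> bool:
    # Padding beyond the largest captured size is invalid, so bound tokens first.
    if num_actual_tokens > max_capture_size:
        return False
    # Padding can inflate the effective batch size past num_spec_decodes.
    # E.g. (illustrative numbers): num_spec = 7, decode_cudagraph_max_bs = 24,
    # and 24 sequences -> 192 tokens; if padding rounds 192 up to 256, then
    # 256 // 8 = 32 > 24, overflowing tensors sized for 24 sequences.
    batch_size = pad_for_cudagraph(num_actual_tokens) // (num_spec + 1)
    return batch_size <= decode_cudagraph_max_bs
```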
I think keeping m.num_actual_tokens <= self.decode_cudagraph_max_bs is good. Then the num_spec_decodes <= self.decode_cudagraph_max_bs check seems unnecessary.
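A sketch of the simplification proposed here (variable names assumed from the thread): every spec-decode sequence contributes at least one token, so num_actual_tokens >= num_spec_decodes always holds and the token-count bound subsumes the sequence-count bound.

```python
# Illustrative: the token bound alone implies the sequence bound, since
# num_actual_tokens >= num_spec_decodes for any batch.
use_graph = (use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
             and num_actual_tokens <= decode_cudagraph_max_bs)
```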
Around 20 lines below there is one more similar place.
LucasWilkinson left a comment
This makes sense to me, assuming @vadiklyutiy confirms the fix.
Longer term, hopefully #23789 / #24002 will resolve this more broadly, but this makes sense as a temporary fix 👍
Are you referring to: "I think that's safe due to …"?
agree
… sizes (vllm-project#24660) (vllm-project#24667) Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
…-project#24660) (vllm-project#24667) Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
… sizes (vllm-project#24660) (vllm-project#24667) Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: bbartels <benjamin@bartels.dev>
…ge batch sizes (vllm-project#24660) (vllm-project#24667)" This reverts commit 89da8d9.
…rge batch sizes (vllm-project#24660) (vllm-project#24667)" This reverts commit 02da9a5.
…under large batch sizes (vllm-project#24660) (vllm-project#24667)"" This reverts commit a1124c4.
… under large batch sizes (vllm-project#24660) (vllm-project#24667)"" This reverts commit 3a72536.
… sizes (vllm-project#24660) (vllm-project#24667) Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>