[Perf] Enable full CUDA graphs for spec decoding with FlashInfer #26937
Conversation
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Documentation preview: https://vllm--26937.org.readthedocs.build/en/26937/
Code Review
This pull request enables full CUDA graphs for speculative decoding with FlashInfer when TRT-LLM attention kernels are available, which is a valuable performance enhancement. The implementation correctly updates the cudagraph_support attribute in FlashInferMetadataBuilder at runtime based on whether TRT-LLM attention can be used. The change from a class variable to an instance variable for cudagraph_support is appropriate for this dynamic behavior. The documentation has also been updated to reflect these changes. The logic appears sound and the provided test results indicate that correctness is maintained while enabling this optimization.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Regarding performance improvement.
cc @LucasWilkinson @ProExpertProg regarding updating AttentionCGSupport dynamically
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Dynamically updating it should be fine, since we only access it on instances, here: vllm/vllm/v1/worker/gpu_model_runner.py Line 3968 in a5464dc
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
…ables Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM; there are some nits that should be addressed (specifically, for the CPU backend I think we should still keep reorder_batch_threshold = 1).
It is a bit harder to see where cudagraph_support is set now :/ I guess the alternative would be to use a function, i.e. add a get_cudagraph_support() function in the base class (I think the current implementation is better, but I'm also flip-flopping, haha).
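For comparison, a rough sketch of the get_cudagraph_support() alternative mentioned above; the names and signatures are assumptions for illustration, not the PR's implementation.

```python
# Rough sketch of the get_cudagraph_support() alternative discussed above.
# Purely illustrative; the real vLLM builder API may differ.
import enum


class AttentionCGSupport(enum.Enum):  # stand-in enum
    NEVER = 0
    UNIFORM_BATCH = 1


class MetadataBuilderBase:
    def get_cudagraph_support(self) -> AttentionCGSupport:
        # Default: no cudagraph support unless a subclass says otherwise.
        return AttentionCGSupport.NEVER


class FlashInferLikeBuilder(MetadataBuilderBase):
    def __init__(self, can_use_trtllm: bool) -> None:
        self._can_use_trtllm = can_use_trtllm

    def get_cudagraph_support(self) -> AttentionCGSupport:
        # The decision lives in one overridable method, which is easier to
        # locate than an attribute assignment, at the cost of indirection.
        if self._can_use_trtllm:
            return AttentionCGSupport.UNIFORM_BATCH
        return AttentionCGSupport.NEVER
```

Either way the decision is made per instance, so the two approaches behave the same at the call site.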
 class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
-    reorder_batch_threshold: int = 1
+    reorder_batch_threshold: int
nit: I still think this needs to be set?
It is set in the constructor: _init_reorder_batch_threshold(1, False)
The type annotation is left to indicate that it will never be "None" on this class and its subclasses; this is a common pattern in the changes in this PR.
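A small sketch of the pattern being described, with a simplified stand-in for _init_reorder_batch_threshold (the real helper and its second argument may behave differently):

```python
# Sketch of the "annotation without a class-level default" pattern.
# The helper below is a simplified stand-in for the real one.
from typing import Optional


class MetadataBuilderBase:
    # May legitimately stay None for builders that never reorder the batch.
    reorder_batch_threshold: Optional[int] = None

    def _init_reorder_batch_threshold(
        self, threshold: int, supports_spec_as_decode: bool
    ) -> None:
        # Simplified: the real helper may adjust the value based on
        # speculative-decode support before storing it.
        self.reorder_batch_threshold = threshold


class CpuLikeBuilder(MetadataBuilderBase):
    # Narrowed annotation: on this class and its subclasses the value is
    # always an int, because the constructor always initializes it.
    reorder_batch_threshold: int

    def __init__(self) -> None:
        self._init_reorder_batch_threshold(1, False)


assert CpuLikeBuilder().reorder_batch_threshold == 1
```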
     AttentionMetadataBuilder[XFormersAttentionMetadata]
 ):
-    reorder_batch_threshold: int = 1
+    reorder_batch_threshold: int
nit: is this still needed?
yes, as type annotation, see previous comment
 )
-    reorder_batch_threshold: int = 1
+    reorder_batch_threshold: int
nit: is this still needed?
yes, as type annotation, see previous comment
Purpose
TRTLLM-gen kernels support full CUDA graphs, but they are only used with FlashInfer on Blackwell under certain conditions.
It might not be safe to always change FlashInfer's cudagraph_support to UNIFORM_BATCH, but we can still set it when we know the TRTLLM-gen backend will be used. Also update the docs to reflect the FlashInfer and FlashInferMLA CUDA graph compatibility.
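Roughly, the guarded decision described above can be pictured as the following sketch; the flag and the non-TRTLLM fallback level are assumptions for illustration, and the real check lives inside the FlashInfer metadata builder.

```python
# Illustrative sketch of the guarded upgrade described in the Purpose.
# The flag name and the fallback support level are assumptions.
import enum


class AttentionCGSupport(enum.Enum):  # stand-in enum
    LIMITED = 1        # placeholder for the backend's existing, more
                       # conservative level of cudagraph support
    UNIFORM_BATCH = 2


def select_cudagraph_support(will_use_trtllm_gen: bool) -> AttentionCGSupport:
    # Advertising UNIFORM_BATCH unconditionally might not be safe for the
    # generic FlashInfer path, so only do so when the TRTLLM-gen kernels
    # are known to be used (Blackwell + a supported configuration).
    if will_use_trtllm_gen:
        return AttentionCGSupport.UNIFORM_BATCH
    return AttentionCGSupport.LIMITED
```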
FIX #26856
Test Plan
Ran Llama 3.1 8B-Instruct with EAGLE3 and confirmed that lm_eval gsm8k results are unchanged compared to main, and compared to running with TRTLLM attention force-disabled. Confirmed via a torch profile that full graphs are now issued for verification when TRTLLM attention is enabled.
Test Result
TRTLLM on:
TRTLLM off:
Benchmarks
MT-Bench at concurrency 1 sees a minimal speedup (~2%)
Before:
After: