
Conversation

@benchislett (Collaborator) commented Oct 15, 2025

Purpose

TRTLLM-gen kernels support full CUDA graphs, but they are only used with FlashInfer on Blackwell under certain conditions.
It might not be safe to change FlashInfer's cudagraph_support to UNIFORM_BATCH unconditionally, but we can still set it when we know the TRTLLM-gen backend will be used.
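
Conceptually, the idea is to keep a conservative class-level default and only promote the support level on the builder instance when the TRTLLM-gen path is known to be taken. A minimal sketch of that shape (the enum loosely mirrors AttentionCGSupport; the helper name, constructor arguments, and default level here are illustrative, not the exact vLLM code):

```python
from enum import Enum


class AttentionCGSupport(Enum):
    # Ordered so that .value comparisons express "at least this much support".
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2
    ALWAYS = 3


def will_use_trtllm_attention(is_blackwell: bool, head_layout_supported: bool) -> bool:
    # Illustrative stand-in for the real eligibility check
    # (Blackwell GPU, supported head/dtype configuration, etc.).
    return is_blackwell and head_layout_supported


class FlashInferBuilderSketch:
    # Conservative class-level default that is safe for every configuration.
    cudagraph_support: AttentionCGSupport = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    def __init__(self, is_blackwell: bool, head_layout_supported: bool) -> None:
        # Promote on this instance only when the TRTLLM-gen kernels,
        # which support full CUDA graphs, will actually be used.
        if will_use_trtllm_attention(is_blackwell, head_layout_supported):
            self.cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
```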

Also update the docs to reflect the FlashInfer and FlashInferMLA CUDA graph compatibility.

FIX #26856

Test Plan

Ran Llama 3.1 8B-Instruct with EAGLE3 and confirmed that lm_eval gsm8k accuracy is unchanged compared to main and to runs with TRTLLM attention force-disabled. Confirmed via a torch profile that full graphs are now issued for verification when TRTLLM attention is enabled.

Test Result

TRTLLM on:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7726|±  |0.0115|
|     |       |strict-match    |     5|exact_match|↑  |0.7013|±  |0.0126|

TRTLLM off:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7726|±  |0.0115|
|     |       |strict-match    |     5|exact_match|↑  |0.7013|±  |0.0126|

Benchmarks

MT-Bench at concurrency 1 sees a minimal speedup (~2%).

```bash
vllm serve meta-llama/Llama-3.1-8B-Instruct --speculative-config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 3}' &

vllm bench serve --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --max-concurrency 1 --model meta-llama/Llama-3.1-8B-Instruct --base-url http://0.0.0.0:8049
```

Before:

```
============ Serving Benchmark Result ============
Successful requests:                     80        
Maximum request concurrency:             1         
Benchmark duration (s):                  42.58     
Total input tokens:                      8133      
Total generated tokens:                  16955     
Request throughput (req/s):              1.88      
Output token throughput (tok/s):         398.23    
Peak output token throughput (tok/s):    186.00    
Peak concurrent requests:                4.00      
Total Token throughput (tok/s):          589.25    
---------------Time to First Token----------------
Mean TTFT (ms):                          12.11     
Median TTFT (ms):                        11.88     
P99 TTFT (ms):                           14.29     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          2.45      
Median TPOT (ms):                        2.45      
P99 TPOT (ms):                           3.33      
---------------Inter-token Latency----------------
Mean ITL (ms):                           5.36      
Median ITL (ms):                         5.36      
P99 ITL (ms):                            5.61      
==================================================
```

After:

```
============ Serving Benchmark Result ============
Successful requests:                     80        
Maximum request concurrency:             1         
Benchmark duration (s):                  41.73     
Total input tokens:                      8133      
Total generated tokens:                  16795     
Request throughput (req/s):              1.92      
Output token throughput (tok/s):         402.47    
Peak output token throughput (tok/s):    190.00    
Peak concurrent requests:                4.00      
Total Token throughput (tok/s):          597.37    
---------------Time to First Token----------------
Mean TTFT (ms):                          11.86     
Median TTFT (ms):                        11.75     
P99 TTFT (ms):                           14.93     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          2.43      
Median TPOT (ms):                        2.37      
P99 TPOT (ms):                           3.39      
---------------Inter-token Latency----------------
Mean ITL (ms):                           5.26      
Median ITL (ms):                         5.25      
P99 ITL (ms):                            5.48      
==================================================
```

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@benchislett requested a review from mgoin as a code owner, October 15, 2025 19:12
mergify bot commented Oct 15, 2025

Documentation preview: https://vllm--26937.org.readthedocs.build/en/26937/

mergify bot added the documentation and v1 labels, Oct 15, 2025
@gemini-code-assist (bot) left a comment

Code Review

This pull request enables full CUDA graphs for speculative decoding with FlashInfer when TRT-LLM attention kernels are available, which is a valuable performance enhancement. The implementation correctly updates the cudagraph_support attribute in FlashInferMetadataBuilder at runtime based on whether TRT-LLM attention can be used. The change from a class variable to an instance variable for cudagraph_support is appropriate for this dynamic behavior. The documentation has also been updated to reflect these changes. The logic appears sound and the provided test results indicate that correctness is maintained while enabling this optimization.

@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


@vadiklyutiy (Contributor)

Regarding the performance improvement: I tried this on Qwen3-Next with 2 prediction tokens. With batch=1 it improves from 92 tok/s to 222 tok/s.

@mgoin (Member) commented Oct 15, 2025

cc @LucasWilkinson @ProExpertProg regarding updating AttentionCGSupport dynamically

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@LucasWilkinson (Collaborator) commented Oct 16, 2025

> cc @LucasWilkinson @ProExpertProg regarding updating AttentionCGSupport dynamically

Dynamically updating it should be fine, since we only check it on instances here:

```python
if builder.cudagraph_support.value < min_cg_support.value:
```

But if we are going to dynamically update it, I think we should make it an instance property instead of a class variable, just to avoid confusion and future bugs.
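
To make the concern concrete: assignment in __init__ shadows the class attribute on that instance only, while writing through the class mutates shared state for every instance and subclass. A toy example (plain strings stand in for the real AttentionCGSupport enum):

```python
class BuilderBase:
    # Class-level default shared by all instances that do not override it.
    cudagraph_support: str = "NEVER"


class FlashInferLikeBuilder(BuilderBase):
    def __init__(self, trtllm_available: bool) -> None:
        # Instance assignment shadows the class default for this object only.
        if trtllm_available:
            self.cudagraph_support = "UNIFORM_BATCH"


a = FlashInferLikeBuilder(trtllm_available=True)
b = FlashInferLikeBuilder(trtllm_available=False)
assert a.cudagraph_support == "UNIFORM_BATCH"
assert b.cudagraph_support == "NEVER"

# Writing through the class instead silently changes every instance that has
# not shadowed the attribute; this is the kind of confusion to avoid.
BuilderBase.cudagraph_support = "ALWAYS"
assert b.cudagraph_support == "ALWAYS"
assert a.cudagraph_support == "UNIFORM_BATCH"  # already shadowed, unaffected
```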

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
…ables

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
mergify bot added the rocm (Related to AMD ROCm) label, Oct 21, 2025
@LucasWilkinson (Collaborator) left a comment

LGTM; there are some nits that should be addressed (specifically, for the CPU backend I think we should still keep reorder_batch_threshold = 1).

It is a bit harder to see where cudagraph_support is set now :/ I guess the alternative would be to use a function, i.e. add a get_cudagraph_support() function in the base class (I think the current implementation is better, but I'm also flip-flopping haha).
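
For reference, that getter-based alternative might look roughly like this (class names and the get_cudagraph_support signature are only a sketch, not the actual vLLM base-class API):

```python
from enum import Enum


class AttentionCGSupport(Enum):
    NEVER = 0
    UNIFORM_BATCH = 2
    ALWAYS = 3


class AttentionMetadataBuilderBase:
    def get_cudagraph_support(self) -> AttentionCGSupport:
        # Conservative default; backends override when they can do better.
        return AttentionCGSupport.NEVER


class FlashInferLikeBuilder(AttentionMetadataBuilderBase):
    def __init__(self, trtllm_available: bool) -> None:
        self._trtllm_available = trtllm_available

    def get_cudagraph_support(self) -> AttentionCGSupport:
        # The decision point lives in one overridable method,
        # so it is easy to find where support is determined.
        if self._trtllm_available:
            return AttentionCGSupport.UNIFORM_BATCH
        return AttentionCGSupport.NEVER
```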


```diff
 class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
-    reorder_batch_threshold: int = 1
+    reorder_batch_threshold: int
```
Collaborator

nit: I still think this needs to be set?

Collaborator Author

It is set in the constructor: _init_reorder_batch_threshold(1, False)

The type annotation is left to indicate that it will never be None on this class and its subclasses. This is a common pattern in the changes in this PR.
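
In other words: the base class allows the value to be absent, the subclass re-annotates without assigning to narrow the type, and the concrete value is always set in the constructor. A rough sketch (the helper body and its second parameter are simplified stand-ins for the real _init_reorder_batch_threshold):

```python
class AttentionMetadataBuilderBase:
    # In the base class the threshold may legitimately be unset.
    reorder_batch_threshold: int | None = None

    def _init_reorder_batch_threshold(
        self, threshold: int, supports_spec_as_decode: bool
    ) -> None:
        # Simplified stand-in for the real helper, which may adjust the
        # threshold depending on speculative-decode support.
        self.reorder_batch_threshold = threshold


class TorchSDPALikeBuilder(AttentionMetadataBuilderBase):
    # Annotation only, no assignment: this does not create a class-level
    # default, it just tells the type checker that the attribute is always
    # a plain int on this class and its subclasses.
    reorder_batch_threshold: int

    def __init__(self) -> None:
        self._init_reorder_batch_threshold(1, False)
```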

```diff
     AttentionMetadataBuilder[XFormersAttentionMetadata]
 ):
-    reorder_batch_threshold: int = 1
+    reorder_batch_threshold: int
```
Collaborator

nit: is this still needed?

Collaborator Author

yes, as type annotation, see previous comment

```diff
 )

-    reorder_batch_threshold: int = 1
+    reorder_batch_threshold: int
```
Collaborator

nit: is this still needed?

Collaborator Author

yes, as type annotation, see previous comment


Labels

documentation (Improvements or additions to documentation), rocm (Related to AMD ROCm), v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Performance]: FlashInfer attn backend. Use dynamic AttentionCGSupport

4 participants