
Conversation

@benchislett (Collaborator) commented Jul 29, 2025

Purpose

At large batch sizes, fill_bitmask becomes a bottleneck and should be parallelized. Since the libraries we use (xGrammar, Guidance, Outlines-core) make low-level calls, they release the GIL, so the work can be trivially parallelized with threads. This PR adds shortcut logic (sketched after the list below) that dispatches the fill_bitmask calls to a thread pool when:

  1. The number of requests in the batch using structured outputs exceeds 128. In my testing, this is the threshold at which the threading overhead pays for itself.
  2. Speculative decoding is disabled. Parallelizing that case is possible in theory, but it would be much more complex, and we rarely run speculative decoding at large batch sizes anyway.
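
For illustration, here is a minimal sketch of the dispatch pattern (the executor setup, names, and signatures are simplified and illustrative, not the exact vLLM code):

```python
from concurrent.futures import ThreadPoolExecutor

# Threshold above which the thread-pool dispatch pays for itself (see point 1).
PARALLEL_THRESHOLD = 128

_executor = ThreadPoolExecutor(max_workers=8)  # pool size chosen arbitrarily for the sketch

def fill_bitmasks(grammars, bitmask):
    """Fill one bitmask row per structured-output request."""
    if len(grammars) > PARALLEL_THRESHOLD:
        # Each fill_bitmask call drops into native code (xGrammar / Guidance /
        # Outlines-core) and releases the GIL, so threads run concurrently.
        futures = [
            _executor.submit(grammar.fill_bitmask, bitmask, row)
            for row, grammar in enumerate(grammars)
        ]
        for fut in futures:
            fut.result()  # propagate any exception from the worker threads
    else:
        # Small batches: the threading overhead outweighs the benefit, stay serial.
        for row, grammar in enumerate(grammars):
            grammar.fill_bitmask(bitmask, row)
```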

This PR also caches the grammar.is_terminated() value in the xGrammar backend, since it is read far more often than it changes.
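
Roughly, the caching pattern looks like this (a hedged sketch with assumed class and attribute names, not the literal backend code; note that any rollback of the FSM state must refresh the cached flag as well):

```python
import xgrammar as xgr

class XgrammarGrammarState:
    """Wraps an xgr.GrammarMatcher and caches its terminated flag."""

    def __init__(self, matcher: xgr.GrammarMatcher):
        self.matcher = matcher
        self._is_terminated = False  # read far more often than it changes

    def accept_tokens(self, token_ids: list[int]) -> bool:
        ok = all(self.matcher.accept_token(t) for t in token_ids)
        # Refresh the cache only at the points where the state can actually change.
        self._is_terminated = self.matcher.is_terminated()
        return ok

    def is_terminated(self) -> bool:
        # Hot path: return the cached flag instead of calling into the native matcher.
        return self._is_terminated
```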

Also, the xGrammar apply_token_bitmask_inplace_torch_compile function does not get compiled in some cases because the input indices are passed as a List[int]. That is unavoidable for mixed batches, but when every request in the batch uses structured outputs we can skip the indices entirely and take a slightly simpler path that always compiles properly.
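
Sketched roughly (hedged; the surrounding vLLM plumbing is simplified and names may differ):

```python
import torch
import xgrammar as xgr

def apply_grammar_bitmask(logits: torch.Tensor,
                          bitmask: torch.Tensor,
                          struct_out_indices: list[int]) -> None:
    if len(struct_out_indices) == logits.shape[0]:
        # Every request in the batch uses structured outputs: skip the
        # List[int] indices argument so the compiled kernel path
        # specializes cleanly on tensors only.
        xgr.apply_token_bitmask_inplace(logits, bitmask)
    else:
        # Mixed batch: only mask the rows belonging to structured-output
        # requests. Passing a Python list of indices is unavoidable here.
        xgr.apply_token_bitmask_inplace(logits, bitmask, indices=struct_out_indices)
```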

Test Plan

No additional testing is added as no new functionality is introduced.

Perf Results

Benchmarks indicate a strong performance gain with small models and large batch sizes. Internal experiments show a 1.5x-1.8x speedup. Because of the fallback condition, other workloads should not experience any disruption.

Below are reference outputs using:

python3 benchmarks/benchmark_serving_structured_output.py --backend vllm --model Qwen/Qwen3-1.7B --dataset json --structured-output-ratio 1.0 --request-rate 1000 --num-prompts 2000

with thinking disabled.

Benchmark TL;DR:

  • Baseline: 92 req/s (1x)
  • This PR, serial: 105 req/s (1.14x)
  • This PR, parallel: 148 req/s (1.6x)

This PR

============ Serving Benchmark Result ============
Successful requests:                     2000      
Benchmark duration (s):                  13.54     
Total input tokens:                      244000    
Total generated tokens:                  190000    
Request throughput (req/s):              147.69    
Output token throughput (tok/s):         14030.77  
Total Token throughput (tok/s):          32049.24  
---------------Time to First Token----------------
Mean TTFT (ms):                          3095.11   
Median TTFT (ms):                        4343.39   
P99 TTFT (ms):                           5420.21   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          43.06     
Median TPOT (ms):                        42.31     
P99 TPOT (ms):                           53.82     
---------------Inter-token Latency----------------
Mean ITL (ms):                           42.61     
Median ITL (ms):                         40.37     
P99 ITL (ms):                            153.23    
==================================================
correct_rate(%) 100.0

This PR, parallel disabled

============ Serving Benchmark Result ============
Successful requests:                     2000      
Benchmark duration (s):                  18.94     
Total input tokens:                      244000    
Total generated tokens:                  190000    
Request throughput (req/s):              105.61    
Output token throughput (tok/s):         10033.13  
Total Token throughput (tok/s):          22917.79  
---------------Time to First Token----------------
Mean TTFT (ms):                          3923.97   
Median TTFT (ms):                        2487.88   
P99 TTFT (ms):                           9240.62   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          83.71     
Median TPOT (ms):                        88.58     
P99 TPOT (ms):                           96.88     
---------------Inter-token Latency----------------
Mean ITL (ms):                           82.83     
Median ITL (ms):                         86.90     
P99 ITL (ms):                            294.66    
==================================================
correct_rate(%) 100.0

Previous main

============ Serving Benchmark Result ============
Successful requests:                     2000      
Benchmark duration (s):                  21.77     
Total input tokens:                      244000    
Total generated tokens:                  190000    
Request throughput (req/s):              91.87     
Output token throughput (tok/s):         8728.04   
Total Token throughput (tok/s):          19936.68  
---------------Time to First Token----------------
Mean TTFT (ms):                          4449.79   
Median TTFT (ms):                        2004.81   
P99 TTFT (ms):                           10171.69  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          97.09     
Median TPOT (ms):                        99.62     
P99 TPOT (ms):                           111.67    
---------------Inter-token Latency----------------
Mean ITL (ms):                           96.07     
Median ITL (ms):                         98.41     
P99 ITL (ms):                            328.47    
==================================================

Guided decoding disabled

============ Serving Benchmark Result ============
Successful requests:                     2000      
Benchmark duration (s):                  12.52     
Total input tokens:                      244000    
Total generated tokens:                  256000    
Request throughput (req/s):              159.78    
Output token throughput (tok/s):         20451.67  
Total Token throughput (tok/s):          39944.67  
---------------Time to First Token----------------
Mean TTFT (ms):                          2650.65   
Median TTFT (ms):                        3277.64   
P99 TTFT (ms):                           4164.70   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          24.24     
Median TPOT (ms):                        28.72     
P99 TPOT (ms):                           41.58     
---------------Inter-token Latency----------------
Mean ITL (ms):                           24.88     
Median ITL (ms):                         34.51     
P99 ITL (ms):                            108.56    
==================================================
correct_rate(%) 99.6

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

```diff
  # Serialization of np.ndarray is much more efficient than a tensor,
  # so we receive it in that format.
- grammar_bitmask = torch.from_numpy(grammar_bitmask)
+ grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
```
Collaborator Author

This is just for sanity, could someone confirm that the above numpy-indexing logic with torch.from_numpy will always create a contiguous array and not a row-wise view?

Collaborator

afaik torch.from_numpy creates a view, so as long as the original numpy array is contiguous, it should be ok?
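
A quick way to check that assumption empirically (an illustrative snippet, not part of the PR):

```python
import numpy as np
import torch

mask = np.zeros((8, 4), dtype=np.int32)
rows = np.array([0, 2, 5])

selected = mask[rows]           # NumPy advanced indexing copies into a new C-contiguous array
t = torch.from_numpy(selected)  # zero-copy view over that NumPy buffer

print(selected.flags['C_CONTIGUOUS'])  # True
print(t.is_contiguous())               # True
# .contiguous() returns the same tensor (no copy) when it is already contiguous.
```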

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request introduces performance improvements for guided decoding at high throughput by parallelizing fill_bitmask. The changes also include caching grammar.is_terminated() and optimizing an xGrammar function call. The performance gains demonstrated in the description are impressive.

My review has identified two critical correctness issues that need to be addressed:

  1. The logic for advancing the FSM state during speculative decoding appears to be incorrect, which could lead to invalid bitmasks.
  2. The newly introduced caching for the grammar's terminated state is not correctly updated when the FSM state is rolled back, which can cause incorrect behavior.

Addressing these issues is crucial for the stability and correctness of the new implementation.

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
@aarnphm (Collaborator) left a comment

One tiny comment about np; otherwise, we might want to do this for guidance as well.

@benchislett (Collaborator, Author)

@aarnphm The parallelization code is independent of backend. I tested that it works for xGrammar and Guidance, contributing a significant speedup to both.

The termination-caching trick is only for xGrammar, since I noticed it being slow there; I assume the other backends handle this internally in one way or another. We can add it in a follow-up if it still shows up as a slowdown.

@mgoin added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Aug 1, 2025
@mgoin (Member) left a comment

LGTM and the optimization is compelling. I do wish we had a throughput test for structured output; I don't think we flex that anywhere in CI.

@russellb (Member) left a comment

thank you!

@russellb enabled auto-merge (squash) August 4, 2025 19:03
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
auto-merge was automatically disabled August 5, 2025 14:41

Head branch was pushed to by a user without write access

@benchislett force-pushed the ben/guided-speedups branch from 2a6c2ff to 19fd946 on August 5, 2025 14:41
@russellb enabled auto-merge (squash) August 5, 2025 14:58
@vllm-bot merged commit 7e6544c into vllm-project:main Aug 6, 2025
64 of 66 checks passed
This pull request was subsequently referenced by commits in downstream forks:

  • npanpaliya (odh-on-pz/vllm-upstream), Aug 6, 2025
  • myselvess (myselvess/vllm), Aug 7, 2025
  • jinzhen-lin (jinzhen-lin/vllm), Aug 9, 2025
  • noamgat (noamgat/vllm), Aug 9, 2025
  • paulpak58 (paulpak58/vllm), Aug 13, 2025
  • diegocastanibm (diegocastanibm/vllm), Aug 15, 2025
  • epwalsh (epwalsh/vllm), Aug 28, 2025
  • xiao-llm (xiao-llm/vllm), Aug 28, 2025
  • zhewenl (zhewenl/vllm), Aug 28, 2025

Labels

ready (ONLY add when PR is ready to merge/full CI is needed), structured-output, v1

Projects

Status: Done

6 participants