
Conversation

h-brenoskuk
Contributor

@h-brenoskuk h-brenoskuk commented Aug 18, 2025

Purpose

Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets.

Notes:

  • Works only with the OpenAI Chat-compatible backend (--backend openai-chat) and endpoint /v1/chat/completions.
  • Set --limit-mm-per-prompt on the server to match your model config.
  • Video sampling is not yet implemented. If specifying videos in the bucket config, set their probability to 0.

Vary the number of items per request and use multiple image buckets:

  --random-mm-base-items-per-request 2 \
  --random-mm-num-mm-items-range-ratio 0.5 \
  --random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \
  --random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \

Flags specific to random-mm:

  • --random-mm-base-items-per-request: base number of multimodal items per request.
  • --random-mm-num-mm-items-range-ratio: vary item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items.
  • --random-mm-limit-mm-per-prompt: per-modality hard caps, e.g. '{"image": 3, "video": 0}'.
  • --random-mm-bucket-config: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 (videos) probability to 0 (video sampling not yet supported).
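
A minimal Python sketch of how these two knobs combine, using hypothetical helper names rather than the benchmark's actual code:

import math

def item_count_range(base: int, ratio: float) -> tuple[int, int]:
    # Closed integer range [floor(n*(1-r)), ceil(n*(1+r))] described above.
    return math.floor(base * (1.0 - ratio)), math.ceil(base * (1.0 + ratio))

def normalize_buckets(bucket_config: dict[tuple[int, int, int], float]) -> dict:
    # Drop zero-probability buckets and renormalize the rest to sum to 1.
    kept = {bucket: p for bucket, p in bucket_config.items() if p > 0}
    total = sum(kept.values())
    return {bucket: p / total for bucket, p in kept.items()}

# base=2, ratio=0.5 -> item counts sampled uniformly from [1, 3]
print(item_count_range(2, 0.5))
# The zero-probability (video) bucket is dropped and the rest renormalized.
print(normalize_buckets({(256, 256, 1): 0.7, (720, 1280, 1): 0.2, (720, 1280, 8): 0.0}))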

Behavioral notes:

  • If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping.

How sampling works:

  • Determine per-request item count k by sampling uniformly from the integer range defined by --random-mm-base-items-per-request and --random-mm-num-mm-items-range-ratio, then clamp k to at most the sum of per-modality limits.
  • For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in --random-mm-bucket-config, while tracking how many items of each modality have been added.
  • If a modality (e.g., image) reaches its limit from --random-mm-limit-mm-per-prompt, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing.
    This should be treated as an edge case; it can be avoided by setting --random-mm-limit-mm-per-prompt to a large value, though doing so may trigger errors if it exceeds the engine's --limit-mm-per-prompt.
  • When random-mm is used with the OpenAI Chat backend, prompts remain text and the synthetic image data is attached via multi_modal_data (OpenAI Chat format).
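
A sketch of this sampling loop in Python (hypothetical function, not the actual vLLM benchmark code; random.choices normalizes the remaining weights, which plays the role of the renormalization step):

import random

def modality(bucket: tuple[int, int, int]) -> str:
    return "image" if bucket[2] == 1 else "video"

def sample_mm_items(k, bucket_probs, limits, rng: random.Random):
    # Clamp k to the total number of items the per-modality limits allow.
    k = min(k, sum(limits.values()))
    counts = {"image": 0, "video": 0}
    items = []
    while len(items) < k:
        # Keep only buckets whose modality still has headroom.
        avail = {b: p for b, p in bucket_probs.items()
                 if p > 0 and counts[modality(b)] < limits.get(modality(b), 0)}
        if not avail:
            break
        buckets, weights = zip(*avail.items())
        h, w, t = rng.choices(buckets, weights=weights, k=1)[0]
        items.append((h, w, t))
        counts[modality((h, w, t))] += 1
    return items

rng = random.Random(42)
print(sample_mm_items(3, {(512, 512, 1): 0.25, (720, 1280, 1): 0.75},
                      {"image": 3, "video": 0}, rng))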

Test Plan

Start the server:

vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
    --tensor-parallel-size 1 \
    --pipeline-parallel-size 1 \
    --dtype bfloat16 \
    --gpu-memory-utilization 0.9 \
    --max-model-len 16384 \
    --limit-mm-per-prompt "image=3,video=0" \
    --mm-processor-kwargs max_pixels=1003520 \
    --guided-decoding-backend "xgrammar"
  1. Test the RandomDataset refactor and compare with the previous implementation:

For the experiments here we fix:

    --max-concurrency 10 \
    --random-prefix-len 25 \
    --random-input-len 300 \
    --random-output-len 40 \
    --random-range-ratio 0.2 \
    --request-rate inf \
    --ignore-eos \
    --endpoint-type openai-chat \
    --endpoint "/v1/chat/completions" \
    --seed 42
  2. Test RandomMultiModalDataset:

We use the args above with the addition of the multimodal args:

  --random-range-ratio 0.2 \
  --random-mm-base-items-per-request 2 \
  --random-mm-num-mm-items-range-ratio 0.5 \
  --random-mm-limit-mm-per-prompt '{"image":3,"video":0}' \
  --random-mm-bucket-config '{(512, 512, 1): 0.25, (720, 1280, 1): 0.75}' \

On the benchmark front we test the four cases below:

a. With no mm-data as a sanity check (must align with the RandomDataset tests)
b. With a fixed number of mm-items and fixed dimensions
c. With variable image dimensions
d. With variable image dimensions and a variable number of mm-items per request

Test Results (run on an H100)

  1. Test the RandomDataset refactor and compare with the previous implementation:

vllm bench serve \
    --model Qwen/Qwen2.5-VL-3B-Instruct \
    --dataset-name random \
    --num-prompts 100 \
    --max-concurrency 10 \
    --random-prefix-len 25 \
    --random-input-len 300 \
    --random-output-len 40 \
    --random-range-ratio 0.2 \
    --request-rate inf \
    --ignore-eos \
    --endpoint-type openai-chat \
    --endpoint "/v1/chat/completions" \
    --seed 42

Before refactor:

============ Serving Benchmark Result ============
Successful requests:                     100       
Maximum request concurrency:             10        
Benchmark duration (s):                  2.49      
Total input tokens:                      32230     
Total generated tokens:                  3911      
Request throughput (req/s):              40.24     
Output token throughput (tok/s):         1573.67   
Total Token throughput (tok/s):          14542.03  
---------------Time to First Token----------------
Mean TTFT (ms):                          21.01     
Median TTFT (ms):                        17.60     
P99 TTFT (ms):                           47.88     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          5.66      
Median TPOT (ms):                        5.63      
P99 TPOT (ms):                           5.97      
---------------Inter-token Latency----------------
Mean ITL (ms):                           5.51      
Median ITL (ms):                         5.59      
P99 ITL (ms):                            7.17      
==================================================

After refactor:

============ Serving Benchmark Result ============
Successful requests:                     100       
Maximum request concurrency:             10        
Benchmark duration (s):                  2.55      
Total input tokens:                      32805     
Total generated tokens:                  3982      
Request throughput (req/s):              39.19     
Output token throughput (tok/s):         1560.73   
Total Token throughput (tok/s):          14418.54  
---------------Time to First Token----------------
Mean TTFT (ms):                          20.27     
Median TTFT (ms):                        17.80     
P99 TTFT (ms):                           48.02     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          5.76      
Median TPOT (ms):                        5.79      
P99 TPOT (ms):                           5.85      
---------------Inter-token Latency----------------
Mean ITL (ms):                           5.62      
Median ITL (ms):                         5.71      
P99 ITL (ms):                            7.01      
==================================================
  2. Test RandomMultiModalDataset:

(a) No images:

vllm bench serve \
    --model Qwen/Qwen2.5-VL-3B-Instruct \
    --dataset-name random-mm \
    --num-prompts 100 \
    --max-concurrency 10 \
    --random-prefix-len 25 \
    --random-input-len 300 \
    --random-output-len 40 \
    --random-range-ratio 0.2 \
    --random-mm-base-items-per-request 0 \
    --random-mm-num-mm-items-range-ratio 0 \
    --random-mm-limit-mm-per-prompt '{"image":3,"video":0}' \
    --random-mm-bucket-config '{(256, 256, 1): 0.25, (720, 1280, 1): 0.75}' \
    --request-rate inf \
    --ignore-eos \
    --endpoint-type openai-chat \
    --endpoint "/v1/chat/completions" \
    --seed 42 
============ Serving Benchmark Result ============
Successful requests:                     100       
Maximum request concurrency:             10        
Benchmark duration (s):                  2.51      
Total input tokens:                      32805     
Total generated tokens:                  3982      
Request throughput (req/s):              39.90     
Output token throughput (tok/s):         1588.91   
Total Token throughput (tok/s):          14678.88  
---------------Time to First Token----------------
Mean TTFT (ms):                          21.50     
Median TTFT (ms):                        17.53     
P99 TTFT (ms):                           48.50     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          5.63      
Median TPOT (ms):                        5.62      
P99 TPOT (ms):                           5.81      
---------------Inter-token Latency----------------
Mean ITL (ms):                           5.49      
Median ITL (ms):                         5.58      
P99 ITL (ms):                            7.48      
==================================================

(b) With a fixed number of images and fixed dimensions:

vllm bench serve \
    --model Qwen/Qwen2.5-VL-3B-Instruct \
    --dataset-name random-mm \
    --num-prompts 100 \
    --max-concurrency 10 \
    --random-prefix-len 25 \
    --random-input-len 300 \
    --random-output-len 40 \
    --random-range-ratio 0.2 \
    --random-mm-base-items-per-request 2 \
    --random-mm-num-mm-items-range-ratio 0 \
    --random-mm-limit-mm-per-prompt '{"image":3,"video":0}' \
    --random-mm-bucket-config '{(256, 256, 1): 0.0, (720, 1280, 1): 1.0}' \
    --request-rate inf \
    --ignore-eos \
    --endpoint-type openai-chat \
    --endpoint "/v1/chat/completions" \
    --seed 42 
============ Serving Benchmark Result ============
Successful requests:                     100       
Maximum request concurrency:             10        
Benchmark duration (s):                  17.83     
Total input tokens:                      32805     
Total generated tokens:                  3982      
Request throughput (req/s):              5.61      
Output token throughput (tok/s):         223.37    
Total Token throughput (tok/s):          2063.57   
---------------Time to First Token----------------
Mean TTFT (ms):                          1535.79   
Median TTFT (ms):                        1546.03   
P99 TTFT (ms):                           2127.64   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          4.78      
Median TPOT (ms):                        3.61      
P99 TPOT (ms):                           17.40     
---------------Inter-token Latency----------------
Mean ITL (ms):                           69.52     
Median ITL (ms):                         5.17      
P99 ITL (ms):                            633.67    
==================================================

(c) With variable image dimensions:

vllm bench serve \
    --model Qwen/Qwen2.5-VL-3B-Instruct \
    --dataset-name random-mm \
    --num-prompts 100 \
    --max-concurrency 10 \
    --random-prefix-len 25 \
    --random-input-len 300 \
    --random-output-len 40 \
    --random-range-ratio 0.2 \
    --random-mm-base-items-per-request 2 \
    --random-mm-num-mm-items-range-ratio 0 \
    --random-mm-limit-mm-per-prompt '{"image":3,"video":0}' \
    --random-mm-bucket-config '{(256, 256, 1): 0.5, (720, 1280, 1): 0.5}' \
    --request-rate inf \
    --ignore-eos \
    --endpoint-type openai-chat \
    --endpoint "/v1/chat/completions" \
    --seed 42 
============ Serving Benchmark Result ============
Successful requests:                     100       
Maximum request concurrency:             10        
Benchmark duration (s):                  9.70      
Total input tokens:                      32805     
Total generated tokens:                  3982      
Request throughput (req/s):              10.31     
Output token throughput (tok/s):         410.42    
Total Token throughput (tok/s):          3791.57   
---------------Time to First Token----------------
Mean TTFT (ms):                          653.59    
Median TTFT (ms):                        625.24    
P99 TTFT (ms):                           1204.22   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.46      
Median TPOT (ms):                        7.08      
P99 TPOT (ms):                           15.95     
---------------Inter-token Latency----------------
Mean ITL (ms):                           51.21     
Median ITL (ms):                         5.66      
P99 ITL (ms):                            351.66    
==================================================

(d) With variable image dimensions and a variable number of images per request:

vllm bench serve \
    --model Qwen/Qwen2.5-VL-3B-Instruct \
    --dataset-name random-mm \
    --num-prompts 100 \
    --max-concurrency 10 \
    --random-prefix-len 25 \
    --random-input-len 300 \
    --random-output-len 40 \
    --random-range-ratio 0.2 \
    --random-mm-base-items-per-request 2 \
    --random-mm-num-mm-items-range-ratio 0.5 \
    --random-mm-limit-mm-per-prompt '{"image":3,"video":0}' \
    --random-mm-bucket-config '{(256, 256, 1): 0.5, (720, 1280, 1): 0.5}' \
    --request-rate inf \
    --ignore-eos \
    --endpoint-type openai-chat \
    --endpoint "/v1/chat/completions" \
    --seed 42 
============ Serving Benchmark Result ============
Successful requests:                     100       
Maximum request concurrency:             10        
Benchmark duration (s):                  9.71      
Total input tokens:                      32805     
Total generated tokens:                  3982      
Request throughput (req/s):              10.30     
Output token throughput (tok/s):         410.05    
Total Token throughput (tok/s):          3788.18   
---------------Time to First Token----------------
Mean TTFT (ms):                          678.44    
Median TTFT (ms):                        686.19    
P99 TTFT (ms):                           1057.61   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.04      
Median TPOT (ms):                        6.91      
P99 TPOT (ms):                           16.72     
---------------Inter-token Latency----------------
Mean ITL (ms):                           44.87     
Median ITL (ms):                         5.30      
P99 ITL (ms):                            364.75    
==================================================

RNG isolation test

We also add a small test to verify the RNG robustness and reproducibility of the new RandomDataset implementation. The test asserts that the global RNG does not interfere with the dataset's own RNG.

The previous implementation of RandomDataset fails this test, while the new one passes.
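
A sketch of the kind of check this test performs (hypothetical make_dataset factory and sample() signature; the real test lives in the vLLM test suite):

import random

def test_rng_isolation(make_dataset):
    # Two datasets built with the same seed must yield identical samples,
    # even if the global RNG is perturbed between the two runs.
    first = make_dataset(seed=42).sample(num_requests=8)

    random.seed(0)
    _ = [random.random() for _ in range(1000)]

    second = make_dataset(seed=42).sample(num_requests=8)
    assert first == second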

To do next:

  1. Improve the Serving Benchmark Result output with multimodal-related data, following this issue: [Feature]: Multimodal Benchmarking Support (MMLM) #21887
  2. Integrate with vllm/benchmarks/throughput.py


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the performance (Performance-related issues) label on Aug 18, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new synthetic multimodal benchmark dataset, RandomMultiModalDataset, and refactors the existing RandomDataset to improve modularity and code reuse. The changes are well-organized, with new functionalities encapsulated in separate methods and corresponding command-line arguments added for configuration. My review identifies a potential runtime error where invalid user-provided arguments for image sampling could cause a crash. I've suggested replacing an assertion with a more informative ValueError to handle this case gracefully, aligning with the error handling practices elsewhere in the file.

Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
@h-brenoskuk h-brenoskuk force-pushed the Feature/Benchmark/RandomMMData/Images branch from 307c069 to 1323d9d on August 19, 2025 09:20
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
@DarkLight1337 DarkLight1337 requested review from mgoin and ywang96 August 19, 2025 11:32
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
@h-brenoskuk h-brenoskuk force-pushed the Feature/Benchmark/RandomMMData/Images branch from 67ede37 to 5edd4b6 on August 19, 2025 13:55
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
@h-brenoskuk h-brenoskuk force-pushed the Feature/Benchmark/RandomMMData/Images branch from 57691a6 to adfd529 on August 19, 2025 14:59
…e for old implementation of RandomDataset

Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
Member

@DarkLight1337 DarkLight1337 left a comment


Thanks, I think the code looks good. Can you also perform a benchmark run with real-sized images (around the range of 1024x1024) to see the difference? I usually see much worse throughput and TTFT when using VisionArena dataset, so I am interested in what settings I should use to get similar results using a random dataset.

@h-brenoskuk
Contributor Author

h-brenoskuk commented Aug 19, 2025

Thanks, I think the code looks good. Can you also perform a benchmark run with real-sized images (around the range of 1024x1024) to see the difference? I usually see much worse throughput and TTFT when using VisionArena dataset, so I am interested in what settings I should use to get similar results using a random dataset.
@DarkLight1337
Added the results with 1024x1024 in the PR description!

@DarkLight1337
Member

Thanks, yeah that req/s looks more reasonable with a larger number of images. LGTM!

I'll have @ywang96 @mgoin conduct a second review though

…andomMMData/Images

Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
…andomMMData/Images

Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
…_items_per_request

Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
Member

@ywang96 ywang96 left a comment


Overall LGTM - but can you fix the formatting issue on README?

Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
…andomMMData/Images

Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
@h-brenoskuk h-brenoskuk requested a review from ywang96 August 25, 2025 07:59
Member

@ywang96 ywang96 left a comment


LGTM 🚢

@ywang96 ywang96 merged commit 0cb7b06 into vllm-project:main Aug 25, 2025
37 checks passed
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Sep 3, 2025
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
ekagra-ranjan pushed a commit to ekagra-ranjan/vllm that referenced this pull request Sep 4, 2025
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
@shawn9977

shawn9977 commented Sep 11, 2025

@h-brenoskuk @DarkLight1337 May I ask which vLLM Docker release has the feature? Thank you!

FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
sducouedic pushed a commit to sducouedic/vllm that referenced this pull request Oct 16, 2025
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>

Labels: performance (Performance-related issues), ready (ONLY add when PR is ready to merge/full CI is needed)

Projects: Status: Done

5 participants