Skip to content

Conversation

david6666666
Copy link
Contributor

@david6666666 david6666666 commented Aug 21, 2025

Purpose

Add option to run MiniCPM-V-4 vision encoder in data parallel manner while the main model is in TP. Can be enabled by flag: mm_encoder_tp_mode="data"

Test Plan & Result

lm_eval:

TP8(baseline):

lm_eval \
  --model vllm-vlm \
  --model_args "pretrained=openbmb/MiniCPM-V-4,tensor_parallel_size=8,gpu_memory_utilization=0.9,trust_remote_code=True" \
  --tasks chartqa \
  --limit 100 \
  --batch_size auto
| Tasks |Version|Filter|n-shot|     Metric      |   |Value |   |Stderr|
|-------|------:|------|-----:|-----------------|---|-----:|---|-----:|
|chartqa|      0|none  |     0|anywhere_accuracy|↑  |0.5640|±  |0.0099|
|       |       |none  |     0|exact_match      |↑  |0.4792|±  |0.0100|
|       |       |none  |     0|relaxed_accuracy |↑  |0.5640|±  |0.0099|

DP8:

lm_eval \
  --model vllm-vlm \
  --model_args "pretrained=openbmb/MiniCPM-V-4,mm_encoder_tp_mode=data,tensor_parallel_size=8,gpu_memory_utilization=0.9,trust_remote_code=True" \
  --tasks chartqa \
  --limit 100 \
  --batch_size auto
| Tasks |Version|Filter|n-shot|     Metric      |   |Value |   |Stderr|
|-------|------:|------|-----:|-----------------|---|-----:|---|-----:|
|chartqa|      0|none  |     0|anywhere_accuracy|↑  |0.5604|±  |0.0099|
|       |       |none  |     0|exact_match      |↑  |0.4740|±  |0.0100|
|       |       |none  |     0|relaxed_accuracy |↑  |0.5684|±  |0.0099|

Note: The accuracy is lower than the official standard, perhaps because chat_template was not applied, as MiniCPM did not provide it in the uploaded HF model.

performance:

TP8(baseline):

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

vllm serve openbmb/MiniCPM-V-4 \
    --gpu-memory-utilization 0.9 \
    --trust-remote-code \
    --tensor-parallel-size 8 \
    --host 0.0.0.0 \
    --port 20001

DP8:

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

vllm serve openbmb/MiniCPM-V-4 \
    --gpu-memory-utilization 0.9 \
    --mm_encoder_tp_mode data \
    --trust-remote-code \
    --tensor-parallel-size 8 \
    --host 0.0.0.0 \
    --port 20001

run benchmark_serving.py:

python3 benchmarks/benchmark_serving.py  \
    --backend openai-chat   \
    --model openbmb/MiniCPM-V-4 \
    --endpoint /v1/chat/completions   \
    --trust_remote_code \
    --dataset-name hf   \
    --dataset-path lmarena-ai/VisionArena-Chat   \
    --hf-split train   \
    --num-prompts 100 \
    --host localhost \
    --port 20001

result:

TP8(baseline):
============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  26.45     
Total input tokens:                      6975      
Total generated tokens:                  9078      
Request throughput (req/s):              3.78      
Output token throughput (tok/s):         343.22    
Total Token throughput (tok/s):          606.93    
---------------Time to First Token----------------
Mean TTFT (ms):                          8593.27   
Median TTFT (ms):                        9483.75   
P99 TTFT (ms):                           21059.12  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          438.66    
Median TPOT (ms):                        170.73    
P99 TPOT (ms):                           4539.28   
---------------Inter-token Latency----------------
Mean ITL (ms):                           181.50    
Median ITL (ms):                         36.19     
P99 ITL (ms):                            3375.19   
==================================================

DP8:
============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  17.58     
Total input tokens:                      6975      
Total generated tokens:                  8920      
Request throughput (req/s):              5.69      
Output token throughput (tok/s):         507.49    
Total Token throughput (tok/s):          904.32    
---------------Time to First Token----------------
Mean TTFT (ms):                          6977.63   
Median TTFT (ms):                        6616.90   
P99 TTFT (ms):                           13008.27  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          180.00    
Median TPOT (ms):                        107.37    
P99 TPOT (ms):                           832.16    
---------------Inter-token Latency----------------
Mean ITL (ms):                           101.05    
Median ITL (ms):                         35.87     
P99 ITL (ms):                            946.84    
==================================================

(Optional) Documentation Update


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Signed-off-by: ycyaw66 <497410282@qq.com>
Signed-off-by: ycyaw66 <497410282@qq.com>
@david6666666 david6666666 changed the title support DP for ViT on MiniCPM-V-4 [Model] Support DP for ViT on MiniCPM-V-4 Aug 21, 2025
@mergify mergify bot added the multi-modality Related to multi-modality (#4194) label Aug 21, 2025
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@david6666666 david6666666 marked this pull request as ready for review August 22, 2025 03:12
Copy link
Member

@DarkLight1337 DarkLight1337 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, can you update the doc page with this model?

Signed-off-by: ycyaw66 <497410282@qq.com>
@CarrotShoo CarrotShoo requested a review from hmellor as a code owner August 22, 2025 03:31
@mergify mergify bot added the documentation Improvements or additions to documentation label Aug 22, 2025
Signed-off-by: ycyaw66 <497410282@qq.com>
Copy link
Member

@DarkLight1337 DarkLight1337 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, LGTM

@DarkLight1337
Copy link
Member

cc @jio-H you may consider using this feature

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) August 22, 2025 03:39
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 22, 2025
@DarkLight1337 DarkLight1337 merged commit 23c939f into vllm-project:main Aug 23, 2025
42 of 44 checks passed
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: ycyaw66 <497410282@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: ycyaw66 <497410282@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: ycyaw66 <497410282@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
self.q_size = self.num_heads * self.head_dim
self.qkv_proj = ReplicatedLinear(
self.embed_dim,
3 * self.q_size,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry to bother you, but I was wondering why it's 3 * self.q_size instead of self.head_dim?

Copy link
Contributor

@CarrotShoo CarrotShoo Aug 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps you misunderstood the meaning of the parameters? In ReplicatedRinear, 3 * self.q_size is the output_size, self.head_dim in QKVParallelLinear means head_size. Hope my answer has been helpful to you.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining! Got it now.

mengxingkongzhouhan pushed a commit to mengxingkongzhouhan/vllm that referenced this pull request Aug 30, 2025
Signed-off-by: ycyaw66 <497410282@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Sep 3, 2025
Signed-off-by: ycyaw66 <497410282@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
ekagra-ranjan pushed a commit to ekagra-ranjan/vllm that referenced this pull request Sep 4, 2025
Signed-off-by: ycyaw66 <497410282@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: ycyaw66 <497410282@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation multi-modality Related to multi-modality (#4194) ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants