
Conversation

tjtanaa
Contributor

@tjtanaa tjtanaa commented Aug 12, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR enables data parallelism (DP) for the ViT encoder (the LLM remains in tensor parallel, TP).

Load-balancing logic has also been implemented so that images of very different sizes are distributed evenly across the DP ranks.
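The general idea is sketched below: choose an assignment of images to GPUs that keeps the total per-GPU workload roughly even. This is a self-contained illustration of a greedy largest-first strategy, not the actual get_load_balance_assignment implementation, whose signature and details may differ.

# Illustrative only: greedy "largest sample to least-loaded GPU" balancing.
# The PR's get_load_balance_assignment may differ in signature and details.
import torch

def greedy_balance(sizes: torch.Tensor, num_gpus: int) -> list[list[int]]:
    assignments = [[] for _ in range(num_gpus)]
    loads = [0] * num_gpus
    # Place the largest samples first so they spread across GPUs.
    for idx in torch.argsort(sizes, descending=True).tolist():
        gpu_id = loads.index(min(loads))  # currently least-loaded GPU
        assignments[gpu_id].append(idx)
        loads[gpu_id] += sizes[idx].item()
    return assignments

print(greedy_balance(torch.tensor([1, 2, 100, 101]), num_gpus=2))
# -> [[3, 0], [2, 1]]; per-GPU loads are 102 and 102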

Test Plan

Run lm_eval on the ChartQA dataset.

Evaluate the performance gain.

Add unit tests covering the functions get_load_balance_assignment and run_dp_sharded_mrope_vision_model.

Test Result

lm_eval on the ChartQA dataset with Qwen/Qwen2.5-VL-72B-Instruct

TP8 Baseline
================================================================================
Metrics:
{
    "explicit_prompt_relaxed_correctness": 0.8864,
    "anywhere_in_answer_relaxed_correctness": 0.8908
}
================================================================================

DP8+TP8 (this PR)
================================================================================
Metrics:
{
    "explicit_prompt_relaxed_correctness": 0.8844,
    "anywhere_in_answer_relaxed_correctness": 0.8884
}
================================================================================

Performance Gain:

Server command:

HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
MIOPEN_USER_DB_PATH=/app/vl/miopen \
MIOPEN_FIND_MODE=FAST \
VLLM_USE_V1=1 \
VLLM_ROCM_USE_AITER=1 \
SAFETENSORS_FAST_GPU=1 \
vllm serve Qwen/Qwen2.5-VL-72B-Instruct \
--tensor_parallel_size=8 \
--trust_remote_code \
--port 7899 \
--enable_multimodal_encoder_data_parallel

Client command:

python3 benchmarks/benchmark_serving.py  \
--backend openai-chat   \
--model Qwen/Qwen2.5-VL-72B-Instruct   \
--endpoint /v1/chat/completions   \
--dataset-name hf   \
--dataset-path lmarena-ai/VisionArena-Chat   \
--hf-split train   \
--num-prompts 1000 \
--max-concurrency 64 \
--port 7899

| Metric | ViT TP | ViT DP | DP vs TP Improvement |
|---|---|---|---|
| Throughput | | | |
| Request throughput (req/s) | 1.79 | 2.63 | +47% |
| Output token throughput (tok/s) | 206.66 | 302.89 | +47% |
| Total token throughput (tok/s) | 375.53 | 551.31 | +47% |
| Latency | | | |
| Benchmark duration (s) | 558.57 | 379.70 | -32% (faster) |
| Mean TTFT (ms) | 4,584.50 | 2,062.35 | -55% (faster) |
| Median TTFT (ms) | 2,280.17 | 1,513.51 | -34% (faster) |
| P99 TTFT (ms) | 25,097.62 | 10,313.05 | -59% (faster) |
| Mean TPOT (ms) | 285.83 | 198.03 | -31% (faster) |
| Median TPOT (ms) | 283.86 | 203.93 | -28% (faster) |
| P99 TPOT (ms) | 515.32 | 274.71 | -47% (faster) |
| Mean ITL (ms) | 337.49 | 249.15 | -26% (faster) |
| Median ITL (ms) | 33.11 | 131.26 | +296% (slower) |
| P99 ITL (ms) | 3,699.28 | 1,291.46 | -65% (faster) |

Most of the improvement comes from data-parallelizing the Conv3d patch embedding.
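Presumably this is because, in the TP configuration, the Conv3d patch embedding is not sharded across ranks, so every GPU convolves every image, whereas with DP each rank only processes its assigned shard and the resulting embeddings are gathered afterwards. A rough sketch of that sharding pattern is shown below; the helper name, arguments, and collective used here are illustrative and are not taken from the PR:

# Rough sketch of image-level data parallelism for the vision encoder.
# Not the PR's actual run_dp_sharded_mrope_vision_model implementation.
import torch
import torch.distributed as dist

def run_vision_dp(vision_model, images, assignments):
    """images: list of pixel tensors; assignments: per-rank lists of image indices."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Each rank runs the (expensive) Conv3d patch embedding and ViT blocks
    # only on the images assigned to it by the load balancer.
    local_out = [vision_model(images[i]) for i in assignments[rank]]
    # Gather every rank's embeddings so the TP LLM sees all of them.
    gathered = [None] * world_size
    dist.all_gather_object(gathered, local_out)
    # Restore the original image order before handing off to the LLM.
    outputs = [None] * len(images)
    for r, idxs in enumerate(assignments):
        for out, i in zip(gathered[r], idxs):
            outputs[i] = out
    return outputs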

(Optional) Documentation Update

Trace of Qwen2.5-VL-72B-Instruct with 16 concurrent prompts

Before enabling DP (ViT in TP mode): [trace screenshot]

After enabling DP (ViT in DP mode): [trace screenshot]

Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the qwen (Related to Qwen models) label Aug 12, 2025

mergify bot commented Aug 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 12, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces data parallelism for the Vision Transformer in the Qwen2.5VL model, which results in significant performance improvements as shown by the benchmarks. The implementation includes a load balancing mechanism to distribute images across GPUs. The code is well-structured, but I've identified a potential issue in the load balancing logic that could lead to imbalanced workloads under certain conditions, even though it doesn't affect the current usage in this PR. I've provided a detailed comment with a suggested fix for this.

Comment on lines 275 to 291
    # Assign minimum samples to each GPU
    # (round-robin with smallest samples first)
    small_to_large_indices = torch.argsort(sizes, descending=False)

    for gpu_id in range(num_gpus):
        samples_assigned = 0
        for idx in small_to_large_indices:
            if idx.item() not in used_indices and \
                    samples_assigned < min_samples_per_gpu:
                gpu_assignments[gpu_id].append(idx.item())
                gpu_loads[gpu_id] += sizes[idx]
                used_indices.add(idx.item())
                samples_assigned += 1

            if samples_assigned >= min_samples_per_gpu:
                break

Contributor


Severity: high

The current implementation for Phase 1 of load balancing does not perform a round-robin assignment as the comment suggests. Instead, it assigns blocks of the smallest samples to each GPU sequentially. This can lead to significant load imbalance if min_samples_per_gpu > 1.

For example, with sizes = [1, 2, 100, 101], num_gpus=2, and min_samples_per_gpu=2, GPU 0 would get samples of size 1 and 2 (total load 3), while GPU 1 would get samples of size 100 and 101 (total load 201).

While the current usage in this PR sets min_samples_per_gpu=0 (making this a non-issue for now), the function's default is 1, and it's a latent bug for other potential uses.

Here is a suggested fix that implements a proper round-robin assignment for Phase 1.

    # Assign minimum samples to each GPU
    # (round-robin with smallest samples first)
    if min_samples_per_gpu > 0:
        small_to_large_indices = torch.argsort(sizes, descending=False)

        unassigned_indices_iter = iter(idx.item() for idx in small_to_large_indices)

        for _ in range(min_samples_per_gpu):
            for gpu_id in range(num_gpus):
                try:
                    # Find the next available sample
                    idx = next(unassigned_indices_iter)

                    gpu_assignments[gpu_id].append(idx)
                    gpu_loads[gpu_id] += sizes[idx]
                    used_indices.add(idx)
                except StopIteration:
                    # Not enough samples to satisfy min_samples_per_gpu for all GPUs
                    break
            else:
                continue
            break
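
To make the imbalance concrete, here is a tiny standalone comparison of the two phase-1 strategies on the example above (sizes [1, 2, 100, 101], 2 GPUs, min_samples_per_gpu=2); it is just a demonstration, not code from the PR or the suggested patch:

sizes = [1, 2, 100, 101]
order = sorted(range(len(sizes)), key=lambda i: sizes[i])  # smallest first

# Block assignment (current behavior): GPU 0 takes the two smallest, GPU 1 the rest.
block = [order[0:2], order[2:4]]
print([sum(sizes[i] for i in g) for g in block])        # [3, 201]

# Round-robin assignment (suggested fix): interleave smallest-first across GPUs.
round_robin = [order[0::2], order[1::2]]
print([sum(sizes[i] for i in g) for g in round_robin])  # [101, 103]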

  shard_size = self.output_sizes[loaded_shard_id]

- param[shard_offset:shard_offset + shard_size] = loaded_weight
+ param.data[shard_offset:shard_offset + shard_size] = loaded_weight
Member


Is this change necessary?

Contributor Author


I get this error if I don't assign to param.data directly:

  File "/app/tritonmrope/dp-qwen2vl/vllm/model_executor/layers/linear.py", line 446, in weight_loader
    param[shard_offset:shard_offset + shard_size] = loaded_weight
RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.
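
For reference, this is standard PyTorch behavior rather than anything vLLM-specific: an in-place write into a view of a leaf tensor that requires grad raises exactly this error, while writing through .data bypasses autograd. A minimal repro:

import torch

param = torch.nn.Parameter(torch.zeros(4, 4))  # leaf tensor, requires_grad=True
loaded_weight = torch.ones(2, 4)

try:
    param[0:2] = loaded_weight  # same pattern as the original weight_loader line
except RuntimeError as err:
    print(err)  # a view of a leaf Variable that requires grad is being used in an in-place operation.

# Writing through .data skips autograd tracking, so this succeeds.
param.data[0:2] = loaded_weight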

Member


@youkaichao @mgoin any idea how this can happen?

Member


It looks like some params are incorrectly created with requires_grad=True, but since other Linear layers' weight_loaders all slice into param.data, I think this change is fine for this PR:

assert param.size() == loaded_weight.size(), (
    f"Tried to load weights of size {loaded_weight.size()}"
    f"to a parameter of size {param.size()}")
param.data.copy_(loaded_weight)
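
If the root cause is that some vision parameters get created with requires_grad=True, the usual inference-time pattern is to create them without grad in the first place; a hypothetical example (where exactly these parameters are constructed is not shown in this thread):

import torch
from torch import nn

# Inference-only weights should not track gradients; creating them like this
# would also make the original in-place slice assignment legal.
weight = nn.Parameter(torch.empty(4096, 1280), requires_grad=False)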

@tjtanaa
Contributor Author

tjtanaa commented Aug 13, 2025

CC @wuhuikx

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@tjtanaa tjtanaa marked this pull request as ready for review August 13, 2025 09:49
@tjtanaa tjtanaa requested a review from sighingnow as a code owner August 13, 2025 09:49
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@tjtanaa tjtanaa requested a review from ywang96 as a code owner August 13, 2025 15:47
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
auto-merge was automatically disabled August 19, 2025 12:56

Head branch was pushed to by a user without write access

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) August 19, 2025 13:19
@DarkLight1337 DarkLight1337 merged commit 1298c67 into vllm-project:main Aug 19, 2025
42 checks passed
princepride pushed a commit to princepride/vllm that referenced this pull request Aug 20, 2025
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
divakar-amd pushed a commit to divakar-amd/vllm_upstream that referenced this pull request Aug 20, 2025
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
cyang49 pushed a commit to cyang49/vllm that referenced this pull request Aug 20, 2025
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
djmmoss pushed a commit to djmmoss/vllm that referenced this pull request Aug 21, 2025
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
mengxingkongzhouhan pushed a commit to mengxingkongzhouhan/vllm that referenced this pull request Aug 30, 2025
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Sep 3, 2025
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

Labels

ci/build
multi-modality (Related to multi-modality (#4194))
performance (Performance-related issues)
qwen (Related to Qwen models)
ready (ONLY add when PR is ready to merge/full CI is needed)


3 participants