Skip to content

Conversation

Isotr0py
Copy link
Member

@Isotr0py Isotr0py commented Aug 16, 2025

Purpose

Test Plan

python examples/offline_inference/basic/generate.py --model RedHatAI/DeepSeek-V2.5-1210-FP8 -tp 4 --max-model-len 4096 --enforce-eager

Test Result

  • FP8 deepseek-v3 model can still work with fused_qkv_a_proj after reverting MergedReplicatedParallelLinear back to MergedColumnParallelLinear

(Optional) Documentation Update


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the deepseek Related to DeepSeek models label Aug 16, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a disable_tp flag to parallel linear layers, allowing them to fall back to a replicated mode. This is a valuable feature for models that need to conditionally disable tensor parallelism, for instance when data parallelism is enabled. The implementation across the various linear layer classes in vllm/model_executor/layers/linear.py is clean and consistent. The refactoring in other model files to leverage this new flag simplifies the code and demonstrates its utility. I have identified a critical typo and a potential prefixing issue in vllm/model_executor/models/step3_vl.py that should be addressed to ensure correctness, particularly for quantized models.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@Isotr0py Isotr0py changed the title [WIP] Allow disabling TP sharding for parallel Linear layer [Core] Allow disabling TP sharding for parallel Linear layer Aug 19, 2025
@Isotr0py Isotr0py marked this pull request as ready for review August 19, 2025 13:54
Copy link
Member

@DarkLight1337 DarkLight1337 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @mgoin

Copy link

mergify bot commented Aug 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Isotr0py.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 19, 2025
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@mergify mergify bot removed the needs-rebase label Aug 19, 2025
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@Isotr0py Isotr0py requested a review from sighingnow as a code owner August 19, 2025 16:22
@mergify mergify bot added the qwen Related to Qwen models label Aug 19, 2025
@Isotr0py Isotr0py mentioned this pull request Aug 22, 2025
4 tasks
@Isotr0py
Copy link
Member Author

@mgoin Can you please take a look to this PR? It could help us simplify the work to finish #22743. Thanks!

@Isotr0py Isotr0py requested a review from 22quinn as a code owner September 6, 2025 03:12
@simon-mo simon-mo merged commit 53b19cc into vllm-project:main Sep 6, 2025
40 of 43 checks passed
@Isotr0py Isotr0py deleted the disable-linear-tp branch September 6, 2025 06:07
minosfuture added a commit to minosfuture/vllm that referenced this pull request Sep 6, 2025
@minosfuture
Copy link
Contributor

I think this PR broke TP on trunk. I see errors like

(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588] Traceback (most recent call last):
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]   File "/data/users/yming/gitrepos/vllm/vllm/v1/executor/multiproc_executor.py", line 562, in worker_main
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]     worker = WorkerProc(*args, **kwargs)
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]   File "/data/users/yming/gitrepos/vllm/vllm/v1/executor/multiproc_executor.py", line 431, in __init__
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]     self.worker.load_model()
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]   File "/data/users/yming/gitrepos/vllm/vllm/v1/worker/gpu_worker.py", line 213, in load_model
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]     self.model_runner.load_model(eep_scale_up=eep_scale_up)
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]   File "/data/users/yming/gitrepos/vllm/vllm/v1/worker/gpu_model_runner.py", line 2131, in load_model
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]     self.model = model_loader.load_model(
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]                  ^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]   File "/data/users/yming/gitrepos/vllm/vllm/model_executor/model_loader/base_loader.py", line 49, in load_model
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]     self.load_weights(model, model_config)
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]   File "/data/users/yming/gitrepos/vllm/vllm/model_executor/model_loader/default_loader.py", line 264, in load_weights
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]     loaded_weights = model.load_weights(
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]                      ^^^^^^^^^^^^^^^^^^^
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]   File "/data/users/yming/gitrepos/vllm/vllm/model_executor/models/deepseek_v2.py", line 906, in load_weights
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]     weight_loader(param, loaded_weight, shard_id)
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]   File "/data/users/yming/gitrepos/vllm/vllm/model_executor/layers/linear.py", line 807, in weight_loader_v2
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]     param.load_merged_column_weight(loaded_weight=loaded_weight,
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]   File "/data/users/yming/gitrepos/vllm/vllm/model_executor/parameter.py", line 145, in load_merged_column_weight
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]     loaded_weight = loaded_weight.narrow(self.output_dim,
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP5 pid=969158) ERROR 09-06 00:42:40 [multiproc_executor.py:588] IndexError: start out of range (expected to be in range of [-576, 576], but got 2880)

reverting this fixed it. Lemme know if you need reproducing steps. I think kimi TP16 or deepseek DP2TP8EP would reproduce it.

@Isotr0py
Copy link
Member Author

Isotr0py commented Sep 6, 2025

I think kimi TP16 or deepseek DP2TP8EP would reproduce it.

Ooops, can you provide the code for reproduction?

@minosfuture
Copy link
Contributor

#node 0
vllm serve deepseek-ai/DeepSeek-V3-0324   --tensor-parallel-size 8   --enable-expert-parallel   --data-parallel-size 2   --data-parallel-size-local 1   --data-parallel-address $MASTER_IP  --data-parallel-rpc-port 13345

#node 1
vllm serve deepseek-ai/DeepSeek-V3-0324   --tensor-parallel-size 8   --enable-expert-parallel   --data-parallel-size 2   --data-parallel-size-local 1   --data-parallel-start-rank 1   --data-parallel-address  $MASTER_IP   --data-parallel-rpc-port 13345   --headless

I think single node test of a smaller model should also be able to reproduce it but I haven't tried.

@Isotr0py
Copy link
Member Author

Isotr0py commented Sep 6, 2025

Hi @minosfuture, I think #24367 should fix this issue. Can you have a look? Thanks!

tlrmchlsmth added a commit to tlrmchlsmth/vllm that referenced this pull request Sep 7, 2025
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request Sep 9, 2025
…oject#23024)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
MengqingCao pushed a commit to vllm-project/vllm-ascend that referenced this pull request Sep 10, 2025
### What this PR does / why we need it?
1. Initial support disable tp for integrating with
[vllm-commit](vllm-project/vllm#23024)
2. [vllm@commit](vllm-project/vllm#23673) now
use `bytes` to save the `BlockHash` to reduce GC overhead, this pr add
the integration

- vLLM version: main
- vLLM main:
vllm-project/vllm@e408272

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
…oject#23024)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
offline893 pushed a commit to offline893/vllm-ascend that referenced this pull request Sep 16, 2025
### What this PR does / why we need it?
1. Initial support disable tp for integrating with
[vllm-commit](vllm-project/vllm#23024)
2. [vllm@commit](vllm-project/vllm#23673) now
use `bytes` to save the `BlockHash` to reduce GC overhead, this pr add
the integration

- vLLM version: main
- vLLM main:
vllm-project/vllm@e408272

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: offline0806 <z00858301@china.huawei.com>
wangxiaoteng888 pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Sep 25, 2025
### What this PR does / why we need it?
1. Initial support disable tp for integrating with
[vllm-commit](vllm-project/vllm#23024)
2. [vllm@commit](vllm-project/vllm#23673) now
use `bytes` to save the `BlockHash` to reduce GC overhead, this pr add
the integration

- vLLM version: main
- vLLM main:
vllm-project/vllm@e408272

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…oject#23024)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025
### What this PR does / why we need it?
1. Initial support disable tp for integrating with
[vllm-commit](vllm-project/vllm#23024)
2. [vllm@commit](vllm-project/vllm#23673) now
use `bytes` to save the `BlockHash` to reduce GC overhead, this pr add
the integration

- vLLM version: main
- vLLM main:
vllm-project/vllm@e408272

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
…oject#23024)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

deepseek Related to DeepSeek models qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants