Conversation

kunpengW-code (Contributor) commented Aug 4, 2025

What this PR does / why we need it?

1. Shared Expert Sharding Strategy Update: Switched from TP-aligned to pure DP for shared experts, enabling more efficient execution.
2. O_Proj AllReduce → ReduceScatter: Reduced communication overhead by using ReduceScatter, made possible by pure DP sharding (see the sketch below).
3. AllGather Postponed: Delayed to after the QKV down projection to reduce synchronization impact during prefill.
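For illustration, a minimal sketch of the o_proj change in point 2, assuming a PyTorch-style process-group API; the function and tensor names are illustrative only, not the vllm-ascend implementation:

```python
import torch
import torch.distributed as dist

def o_proj_output_before(partial_out: torch.Tensor, tp_group) -> torch.Tensor:
    # Before: the all-reduce leaves every TP rank with the full
    # [num_tokens, hidden] tensor.
    dist.all_reduce(partial_out, group=tp_group)
    return partial_out

def o_proj_output_after(partial_out: torch.Tensor, tp_group) -> torch.Tensor:
    # After: the reduce-scatter leaves each TP rank with only its token shard
    # ([num_tokens // tp_size, hidden]); the all-gather that restores the full
    # tensor is postponed until after the QKV down projection (point 3).
    tp_size = dist.get_world_size(group=tp_group)
    shard = torch.empty(partial_out.shape[0] // tp_size, partial_out.shape[1],
                        dtype=partial_out.dtype, device=partial_out.device)
    dist.reduce_scatter_tensor(shard, partial_out, group=tp_group)
    return shard
```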

Does this PR introduce any user-facing change?

How was this patch tested?

Added a UT case in tests/ut/attention/test_mla_v1.py.

How to run

Use the parameter `--additional_config='{"enable_shared_expert_dp": true}'`.

a. How to run eager mode

e.g.:
python -m vllm.entrypoints.openai.api_server --model=/model_path --trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002 --max-model-len 5120 --max-num-batched-tokens 16384 --enforce-eager --disable-log-requests --additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp": true,"chunked_prefill_for_mla":true}'

b. How to run graph mode

e.g.:
python -m vllm.entrypoints.openai.api_server --model=/model_path --trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002 --max-model-len 5120 --max-num-batched-tokens 16384 --disable-log-requests --additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp": true,"chunked_prefill_for_mla":true,"torchair_graph_config":{"enabled":true}}'

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
github-actions bot commented Aug 4, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
if self.enable_shared_expert_dp and self.debug_layer_idx > 3 and self.debug_layer_idx < 61:
hidden_states_or_q_c = get_tp_group().all_gather(
Collaborator:

In the case of severe load imbalance between DPs, could the DP domain become stuck?

Contributor Author:

The DP domain will not block here: this is an all_gather across the TP ranks within a single DP domain, so different DP domains do not affect each other.
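For illustration, a hedged sketch of why this holds, using generic torch.distributed groups; the group setup and function names are assumptions for the sketch, not the vllm-ascend code:

```python
import torch
import torch.distributed as dist

def build_tp_groups(world_size: int, tp_size: int):
    # e.g. world_size=16, tp_size=8 -> two disjoint TP groups, one per DP rank;
    # every process calls new_group for each group being created.
    return [dist.new_group(list(range(i * tp_size, (i + 1) * tp_size)))
            for i in range(world_size // tp_size)]

def gather_within_tp(x: torch.Tensor, tp_group) -> torch.Tensor:
    # The collective involves only the ranks of this TP group, so a slow or
    # idle DP domain never blocks the all_gather of another DP domain.
    tp_size = dist.get_world_size(group=tp_group)
    out = torch.empty(x.shape[0] * tp_size, *x.shape[1:],
                      dtype=x.dtype, device=x.device)
    dist.all_gather_into_tensor(out, x, group=tp_group)
    return out
```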

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
github-actions bot added the `documentation` (Improvements or additions to documentation) label Aug 4, 2025
codecov bot commented Aug 4, 2025

Codecov Report

❌ Patch coverage is 53.00000% with 47 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.82%. Comparing base (875a86c) to head (0cf9e1d).
⚠️ Report is 28 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| vllm_ascend/models/deepseek_v2.py | 27.90% | 31 Missing ⚠️ |
| vllm_ascend/attention/mla_v1.py | 52.17% | 11 Missing ⚠️ |
| vllm_ascend/ops/fused_moe.py | 55.55% | 4 Missing ⚠️ |
| vllm_ascend/ascend_config.py | 66.66% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2198      +/-   ##
==========================================
+ Coverage   76.70%   76.82%   +0.11%     
==========================================
  Files         113      113              
  Lines       12944    13051     +107     
==========================================
+ Hits         9929    10026      +97     
- Misses       3015     3025      +10     
| Flag | Coverage Δ |
|---|---|
| unittests | 76.82% <53.00%> (+0.11%) ⬆️ |


| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
| `enable_shared_expert_dp` | bool | `True` | When the shared experts run in DP, performance improves but more memory is consumed. If memory is a concern, this switch can be turned off manually. |
Contributor Author:

Added the note, please take a look, thanks. @wangxiyuan

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
jianzs (Collaborator) commented Aug 5, 2025

Refine your title and description to better reflect your PR.

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
kunpengW-code changed the title from "[main][Prefill Perf] Parallel Strategy Optimizations" to "[main][prefill optimization] Optimize parallel strategies to reduce communication overhead" Aug 5, 2025
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance and not self.use_aclgraph:
if enable_force_load_balance:
Contributor:

This line change appears to be unrelated to the purpose of this PR.

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
SlightwindSec force-pushed the upstream_main_shared_expert_dp branch from bd22a0d to a4117fb on August 6, 2025 11:39
wangxiyuan (Collaborator) commented:

@jianzs please take a look again.

jianzs (Collaborator) commented Aug 11, 2025

Are there any prerequisite features needed for this to function correctly? If yes, I recommend adding configuration validation to catch any missing dependencies early.

SlightwindSec (Contributor) commented:

> Are there any prerequisite features needed for this to function correctly? If yes, I recommend adding configuration validation to catch any missing dependencies early.

Thanks for your suggestion! This feature requires Torchair to be disabled and Expert Parallelism (EP) to be enabled.

self.enable_shared_expert_dp = additional_config.get(
            "enable_shared_expert_dp", True
        ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel

I've already implemented a check for this in the code, so enable_shared_expert_dp will only be activated if these conditions are met. I've also added a warning that notifies users and disables the feature automatically if they try to use it in an unsupported configuration. @jianzs
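A hedged sketch of what such a warning-and-fallback can look like; the helper name and message text are assumptions, and the merged ascend_config.py is authoritative:

```python
import logging

logger = logging.getLogger(__name__)

def resolve_enable_shared_expert_dp(additional_config: dict,
                                    torchair_graph_enabled: bool,
                                    enable_expert_parallel: bool) -> bool:
    requested = additional_config.get("enable_shared_expert_dp", True)
    supported = (not torchair_graph_enabled) and enable_expert_parallel
    if requested and not supported:
        # Hypothetical message; the real warning text lives in the merged code.
        logger.warning(
            "enable_shared_expert_dp requires torchair graph mode to be "
            "disabled and expert parallelism to be enabled; disabling it.")
    return requested and supported
```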

wangxiyuan merged commit dc585f1 into vllm-project:main Aug 12, 2025
25 checks passed
Csrayz added a commit to Csrayz/vllm-ascend that referenced this pull request Aug 13, 2025
* enable mm allreduce test (vllm-project#2192)

### What this PR does / why we need it?
This PR is to add e2e test for using npu_mm_all_reduce_base fusion
kernel.
### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
not involved

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@5d5d419

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>

* [main] remove torch.cat and replace it by List[0] (vllm-project#2153)

### What this PR does / why we need it?
torch_npu.npu_grouped_matmul:

https://www.hiascend.com/document/detail/zh/Pytorch/710/apiref/torchnpuCustomsapi/context/torch_npu-npu_grouped_matmul.md

According to the document, when `split_item` is 2 or 3,
`torch_npu.npu_grouped_matmul` will return a list which has one element.
Therefore, the `torch.cat` after `torch_npu.npu_grouped_matmul` is
unnecessary.
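As a hedged sketch of the resulting simplification (the grouped-matmul call itself is omitted; only the list handling is shown):

```python
import torch

def take_grouped_matmul_output(gmm_out: list[torch.Tensor]) -> torch.Tensor:
    # With split_item set to 2 or 3 the grouped matmul returns a single-element
    # list, so indexing the first element replaces the previous torch.cat.
    return gmm_out[0]  # before: torch.cat(gmm_out, dim=0)
```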

### Does this PR introduce _any_ user-facing change?
not involved

### How was this patch tested?
ut and e2e covered: `tests/ut/ops/test_fused_ops.py`,
`tests/e2e/singlecard/ops/test_fused_moe.py`

**performance**:
(qwen3 30B, 2k->20k)

base:
Total Token throughput (tok/s):          667.76 

remove cat:
Total Token throughput (tok/s):          680.82 


- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@fa00c5d

Signed-off-by: huangxialu <huangxialu1@huawei.com>

* [CI][Quickfix] Fix AscendFusedMoE init error (vllm-project#2268)

### What this PR does / why we need it?
Fix AscendFusedMoE init error. Use `super().__init__()` instead of
`super(FusedMoE, self).__init__()` to ensure the member variables in the
base class can be accessed by the child class.
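A generic-Python illustration of the difference; the classes here are stand-ins, not the real vllm/vllm-ascend definitions:

```python
class FusedMoE:
    def __init__(self):
        self.num_experts = 8            # member variable set by the base class

class AscendFusedMoE(FusedMoE):
    def __init__(self):
        # super(FusedMoE, self).__init__() would start the MRO lookup *after*
        # FusedMoE, skip FusedMoE.__init__, and leave num_experts unset.
        super().__init__()
        self.extra_state = {}           # child-specific state

moe = AscendFusedMoE()
assert moe.num_experts == 8             # AttributeError with the old call
```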

### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
CI passed with new existing test.


- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@766bc81

---------

Signed-off-by: MengqingCao <cmq0113@163.com>

* Fix accuracy test config and add DeepSeek-V2-Lite test (vllm-project#2261)

### What this PR does / why we need it?
This PR fixes the accuracy test related to
vllm-project#2073; users can now
perform accuracy tests on multiple models simultaneously and generate
different report files by running:

```bash
cd ~/vllm-ascend
pytest -sv ./tests/e2e/models/test_lm_eval_correctness.py \
          --config-list-file ./tests/e2e/models/configs/accuracy.txt
```

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
<img width="1648" height="511" alt="image" src="https://github.com/user-attachments/assets/1757e3b8-a6b7-44e5-b701-80940dc756cd" />


- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@766bc81

---------

Signed-off-by: Icey <1790571317@qq.com>

* Fix accuracy test create PR (vllm-project#2274)

### What this PR does / why we need it?

Fix create PR of accuracy test 

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
Local testing: nv-action/vllm-benchmarks#87

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@099c046

---------

Signed-off-by: Icey <1790571317@qq.com>

* Add ut for test_communicator.py (vllm-project#2293)

### What this PR does / why we need it?

Add ut for test_communicator.py 

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@e5ebeeb

Signed-off-by: yangqinghao-cmss <yangqinghao_yewu@cmss.chinamobile.com>

* [CI] Fix broken CI (vllm-project#2302)

1. Disable the test_eagle_ccorrectness test; we'll re-enable it once the OOM
error is fixed.
2. Drop the transformers version limit for main, since vLLM relies on
>=4.55.0, see:
vllm-project/vllm@65552b4
3. Fix the kv_connector_output bug, see:
vllm-project/vllm@796bae0

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@d1af8b7

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>

* [2/N][Refactor] torchair model runner refactor (vllm-project#2204)

There is a lot of torchair code in the model runner, making the code hard to
maintain. We'll create a new torchair_model_runner to split out the
torchair-related logic, following the workflow in vllm-project#2203.

What this PR does:

move the `torchair`-related logic into `_get_forward_metadata_across_dp` and
override it in the torchair model runner
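A hedged sketch of the override pattern being described; the class names and metadata contents are assumptions for illustration:

```python
class NPUModelRunner:
    def _get_forward_metadata_across_dp(self, num_tokens: int, with_prefill: bool):
        # Common behaviour shared by all model runners.
        return {"num_tokens": num_tokens, "with_prefill": with_prefill}

class NPUTorchairModelRunner(NPUModelRunner):
    def _get_forward_metadata_across_dp(self, num_tokens: int, with_prefill: bool):
        # Torchair-specific handling now lives only in this subclass,
        # e.g. padding the token count to a graph-friendly size.
        metadata = super()._get_forward_metadata_across_dp(num_tokens, with_prefill)
        metadata["padded_num_tokens"] = ((num_tokens + 7) // 8) * 8
        return metadata
```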


- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@1b99028

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>

* [core] Support capture custom ops into aclgraph (vllm-project#2113)

### What this PR does / why we need it?
Thanks to PR vllm-project#426,
vllm-ascend supports aclgraph inference to reduce the host overhead. However,
the capability of aclgraph relies heavily on the functionality provided by
`torch.compile`, which is the key feature of torch 2.x. Therefore, capturing a
custom op into an aclgraph is only possible when it can be recognized and
captured by `torch.compile`.

In this PR, we register meta implementations of the current custom ops to
enable fx graph capture. By doing that, inserting those custom ops into the
aclgraph becomes natural for the Ascend runtime.
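A minimal sketch, on recent PyTorch, of giving a custom op a meta ("fake") implementation so it can be traced; the op below is a made-up example, not one of the vllm-ascend ops:

```python
import torch

@torch.library.custom_op("demo::fused_scale", mutates_args=())
def fused_scale(x: torch.Tensor, scale: float) -> torch.Tensor:
    return x * scale                      # real (eager/NPU) implementation

@fused_scale.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Meta implementation: only shapes/dtypes, so torch.compile / fx capture
    # can trace through the op without executing it.
    return torch.empty_like(x)

compiled = torch.compile(lambda t: fused_scale(t, 2.0))
out = compiled(torch.randn(4))
```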

### Does this PR introduce _any_ user-facing change?
No user-facing change.

### How was this patch tested?
Tested in unit tests: we integrate the `rotary_embedding` op into a
small custom model and use `torch.compile` and aclgraph to capture and
replay it to verify its functionality.

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@1b99028

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>

* Bump actions/download-artifact from 4 to 5 (vllm-project#2311)

Bumps
[actions/download-artifact](https://github.com/actions/download-artifact)
from 4 to 5.

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@ebf7605

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [Perf][MTP] Optimize reject sampler in greedy situation. (vllm-project#2137)

This PR ports the optimization in PR vllm-project#2002 to main and makes it cleaner.

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@afa5b7c

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>

* [3/N][Refactor] torchair model runner refactor  (vllm-project#2207)

There is a lot of torchair code in the model runner, making the code hard to
maintain. We'll create a new torchair_model_runner to split out the
torchair-related logic, following the workflow in vllm-project#2203; this is the first PR.

What this PR does:

create common functions `_build_attention_metadata` and
`_generate_dummy_run_hidden_states` for dummy_run

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@ebf7605

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>

* [Feat] chunkprefill mla support torchair graph (vllm-project#1772)

Chunked-prefill MLA only supports eager mode now; we want to optimize it by
supporting the torchair graph. The idea is simple: when all requests are in
the decode phase, use the torchair graph to handle them; otherwise, for
chunked prefill or prefill only, use eager mode.
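A hedged sketch of the dispatch idea; the argument name is an assumption about the attention metadata, not an actual field:

```python
def choose_attention_path(num_prefill_tokens: int) -> str:
    # All requests in decode -> replay the static torchair graph;
    # any prefill / chunked-prefill work -> fall back to eager mode.
    return "torchair_graph" if num_prefill_tokens == 0 else "eager"
```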

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@ebf7605

Signed-off-by: haojiangzheng <justineric096@gmail.com>
Co-authored-by: haojiangzheng <justineric096@gmail.com>

* [4/N][Refactor] torchair model runner refactor (vllm-project#2208)

There is a lot of torchair code in the model runner, making the code hard to
maintain. We'll create a new torchair_model_runner to split out the
torchair-related logic, following the workflow in vllm-project#2203; this is the first PR.

What this PR does:

create common function `_convert_torch_foramt` for initialize_kv_cache


- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@14a5d90

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>

* Configure Gemini (vllm-project#2298)

### What this PR does / why we need it?
This PR requests Gemini AI to review PRs.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
NA

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@14a5d90

Signed-off-by: QwertyJack <7554089+QwertyJack@users.noreply.github.com>

* ut: add ci guard for ut coverage (vllm-project#2317)

### What this PR does / why we need it?
Add a CI guard for UT coverage: if the UT coverage of a patch PR is below 80%,
the CI will fail.

### Does this PR introduce _any_ user-facing change?
not involved

### How was this patch tested?
not involved

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@458e74e

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>

* [main][prefill optimization] Optimize parallel strategies to reduce communication overhead (vllm-project#2198)

### What this PR does / why we need it?
1. Shared Expert Sharding Strategy Update: Switched from TP-aligned to
pure DP for shared experts, enabling more efficient execution.
2. O_Proj AllReduce → ReduceScatter: Reduced communication overhead by
using ReduceScatter, made possible by pure DP sharding.
3. AllGather Postponed: Delayed to after QKV down projection to reduce
synchronization impact during prefill.

### How was this patch tested?
Adding ut case in `tests/ut/attention/test_mla_v1.py`

#### How to run

use parameter `--additional_config='{"enable_shared_expert_dp": true}'`

##### a.How to run eager mode

eg:
python -m vllm.entrypoints.openai.api_server --model=/model_path
--trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002
--max-model-len 5120 --max-num-batched-tokens 16384 --enforce-eager
--disable-log-requests
--additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":
true,"chunked_prefill_for_mla":true}'

##### b.How to run graph mode

eg:
python -m vllm.entrypoints.openai.api_server --model=/model_path
--trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002
--max-model-len 5120 --max-num-batched-tokens 16384
--disable-log-requests
--additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":
true,"chunked_prefill_for_mla":true,"torchair_graph_config":{"enabled":true}}'


- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@9edd1db

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>

* [Doc] Update faq (vllm-project#2334)

### What this PR does / why we need it?
  - update deterministic calculation
  - update support device

### Does this PR introduce _any_ user-facing change?
- Users should update ray and protobuf when using ray as distributed
backend
- Users should change to use `export HCCL_DETERMINISTIC=true` when
enabling deterministic calculation

### How was this patch tested?
N/A

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@ea1292a

Signed-off-by: MengqingCao <cmq0113@163.com>

* [5/N][Refactor] torchair model runner refactor (vllm-project#2216)

There is a lot of torchair code in the model runner, making the code hard to
maintain. We'll create a new torchair_model_runner to split out the
torchair-related logic, following the workflow in vllm-project#2203.

What this PR does:

create common function `_capture_model` for capture_model

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@1891a26

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>

* [1/N][Feat] Support MoE models with ACL Graph and refactor MoE communication logic (vllm-project#2125)

### What this PR does / why we need it?
This PR refactors the MoE (Mixture of Experts) communication logic by
introducing a strategy pattern. It defines an abstract base class,
`MoECommMethod`, which encapsulates different communication strategies
for MoE layers. By decoupling the MoE implementation from any single
communication method, this change makes it simpler to add, replace, or
optimize communication strategies in the future.
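A hedged sketch of that strategy pattern; the class names follow the description above (MoECommMethod, AllGatherImpl) but the method names and bodies are placeholders:

```python
from abc import ABC, abstractmethod
import torch

class MoECommMethod(ABC):
    """One interchangeable way of moving tokens between ranks around an MoE layer."""

    @abstractmethod
    def prepare(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Distribute/collect tokens before the expert computation."""

    @abstractmethod
    def finalize(self, expert_output: torch.Tensor) -> torch.Tensor:
        """Combine expert outputs after the computation."""

class AllGatherImpl(MoECommMethod):
    def prepare(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Placeholder: e.g. all-gather tokens so each rank sees the full batch.
        return hidden_states

    def finalize(self, expert_output: torch.Tensor) -> torch.Tensor:
        # Placeholder: e.g. reduce-scatter back to each rank's local shard.
        return expert_output
```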

Plan / Roadmap

1. Introduce `MoECommMethod`, implement `AllGatherImpl`, and adapt ACL
Graph handling to cover all scenarios (this PR).
2. Implement `MC2CommImpl` and `AllToAllCommImpl` to optimize
performance in specific scenarios.
3. Enable W8A8 / Int8 models to use `unified_fused_experts`.

Other notes

* Data-parallel (DP) communication currently does not work with vLLM's
dispatch/combine mechanisms; an alternative approach is required to
resolve this incompatibility.

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@f7ad6a1

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>

* [Doc] Add container image save/load FAQ for offline environments (vllm-project#2347)

### What this PR does / why we need it?

Add Docker export/import guide for air-gapped environments

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

NA

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@d16aa3d

Signed-off-by: QwertyJack <7554089+QwertyJack@users.noreply.github.com>

* [Bugfix] fix the oom when chunkprefill with long context like 64k (vllm-project#2319)

The attention mask was declared in mla.py; we don't need the splitfuse
mask for MLA chunked prefill, and this mask causes memory problems with
long contexts like 64k or 128k.

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@14a5d90

---------

Signed-off-by: haojiangzheng <justineric096@gmail.com>

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Signed-off-by: huangxialu <huangxialu1@huawei.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Signed-off-by: yangqinghao-cmss <yangqinghao_yewu@cmss.chinamobile.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: haojiangzheng <justineric096@gmail.com>
Signed-off-by: QwertyJack <7554089+QwertyJack@users.noreply.github.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Ronald1995 <ronaldautomobile@163.com>
Co-authored-by: huangxialu <huangxialu1@huawei.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: Icey <1790571317@qq.com>
Co-authored-by: yangqinghao-cmss <yangqinghao_yewu@cmss.chinamobile.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Pleaplusone <pleaplusone.gy@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: whx <56632993+whx-sjtu@users.noreply.github.com>
Co-authored-by: zhenghaojiang <zhjoneson@163.com>
Co-authored-by: haojiangzheng <justineric096@gmail.com>
Co-authored-by: jack <QwertyJack@users.noreply.github.com>
Co-authored-by: Wang Kunpeng <1289706727@qq.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: yiz-liu <136800916+yiz-liu@users.noreply.github.com>
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025
…ommunication overhead (vllm-project#2198)

1. Shared Expert Sharding Strategy Update: Switched from TP-aligned to
pure DP for shared experts, enabling more efficient execution.
2. O_Proj AllReduce → ReduceScatter: Reduced communication overhead by
using ReduceScatter, made possible by pure DP sharding.
3. AllGather Postponed: Delayed to after QKV down projection to reduce
synchronization impact during prefill.

Adding ut case in `tests/ut/attention/test_mla_v1.py`

use parameter `--additional_config='{"enable_shared_expert_dp": true}'`

eg:
python -m vllm.entrypoints.openai.api_server --model=/model_path
--trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002
--max-model-len 5120 --max-num-batched-tokens 16384 --enforce-eager
--disable-log-requests
--additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":
true,"chunked_prefill_for_mla":true}'

eg:
python -m vllm.entrypoints.openai.api_server --model=/model_path
--trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002
--max-model-len 5120 --max-num-batched-tokens 16384
--disable-log-requests
--additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":
true,"chunked_prefill_for_mla":true,"torchair_graph_config":{"enabled":true}}'

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@9edd1db

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
SlightwindSec deleted the upstream_main_shared_expert_dp branch October 13, 2025 01:42
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025
…ommunication overhead (vllm-project#2198)

### What this PR does / why we need it?
1. Shared Expert Sharding Strategy Update: Switched from TP-aligned to
pure DP for shared experts, enabling more efficient execution.
2. O_Proj AllReduce → ReduceScatter: Reduced communication overhead by
using ReduceScatter, made possible by pure DP sharding.
3. AllGather Postponed: Delayed to after QKV down projection to reduce
synchronization impact during prefill.

### How was this patch tested?
Adding ut case in `tests/ut/attention/test_mla_v1.py`

#### How to run

use parameter `--additional_config='{"enable_shared_expert_dp": true}'`

##### a.How to run eager mode

eg:
python -m vllm.entrypoints.openai.api_server --model=/model_path
--trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002
--max-model-len 5120 --max-num-batched-tokens 16384 --enforce-eager
--disable-log-requests
--additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":
true,"chunked_prefill_for_mla":true}'

##### b.How to run graph mode

eg:
python -m vllm.entrypoints.openai.api_server --model=/model_path
--trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002
--max-model-len 5120 --max-num-batched-tokens 16384
--disable-log-requests
--additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":
true,"chunked_prefill_for_mla":true,"torchair_graph_config":{"enabled":true}}'


- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@9edd1db

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>

Labels

documentation (Improvements or additions to documentation), module:core, module:ops, module:tests
