
Conversation

@yiz-liu
Collaborator

@yiz-liu yiz-liu commented Aug 28, 2025

What this PR does / why we need it?

  • Unify execution paths: Consolidates the quantized and non-quantized execution paths into a single fused_experts function, removing duplicated logic and making the control flow clearer and easier to maintain.
  • W8A8 dynamic quantization: Adds support for W8A8 dynamic quantization inside the unified MoE kernel. Communication routines are updated to correctly handle dynamic quantization scales for activations.
  • Weight pre-processing: Pre-transpose the w13 and w2 weight matrices (as implemented in PR [Refactor] Pre-transpose MoE weights for improved performance #2025) so that quantized and non-quantized models follow the same code path for the MoE gating, up-projection, and down-projection operations.
  • All-to-all communication: Adds an all-to-all collective communication pattern. For large token counts on modern hardware, all-to-all is more efficient than the previous all-gather strategy. However, all-to-all cannot actually be captured and replayed: it performs multiple D2H operations that trigger synchronization and raise errors during graph capture, so we only use all-to-all when falling back to compiled_graph_for_general_shape.
  • Dynamic communication selection: The model runner now selects the optimal MoE communication method (mc2, allgather, or alltoall) at runtime based on token count and the Ascend SoC version (see the sketch after this list).
  • Limitation: all-gather is not yet supported for quantized models, so there is still work left to do on A2.
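As a rough sketch of the dynamic selection described above (the function name, the SoC flag, and the 512-token threshold are illustrative assumptions, not the actual values or interfaces in this PR):

```python
def select_moe_comm_method(num_tokens: int, soc_supports_alltoall: bool,
                           mc2_token_capacity: int = 512) -> str:
    """Illustrative only: pick a MoE communication method from the token
    count and SoC capability, mirroring the selection logic at a high level."""
    if num_tokens <= mc2_token_capacity:
        # Small batches: the fused mc2 dispatch/combine path.
        return "mc2"
    # Large batches: all-to-all where the SoC and the non-captured execution
    # path allow it, otherwise fall back to all-gather.
    return "alltoall" if soc_supports_alltoall else "allgather"
```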

Does this PR introduce any user-facing change?

None.

How was this patch tested?

No further test cases needed.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for all-to-all communication in MoE layers and unifies the execution paths for quantized and non-quantized models. The changes are well-structured and align with the description. I've found one potential issue with the padding logic in the new AlltoAllCommImpl, which may not be robust for all input sizes. My suggestion aims to make the padding logic more general and correct.

self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
self.num_tokens, _ = hidden_states.shape
pad_size = self.tp_size - self.num_tokens
Contributor


high

The current padding logic pad_size = self.tp_size - self.num_tokens is not robust. It only handles cases where num_tokens < tp_size and does not pad when num_tokens > tp_size and is not a multiple of tp_size. This can lead to uneven tensor splits, which might be suboptimal or cause issues with certain collective operations. A more robust approach is to pad the number of tokens to the next multiple of tp_size.

Suggested change
pad_size = self.tp_size - self.num_tokens
pad_size = (-self.num_tokens) % self.tp_size
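To see why the modulo form is safer, a quick arithmetic check (values chosen arbitrarily for illustration):

```python
tp_size = 4
for num_tokens in (3, 4, 5, 9):
    old = tp_size - num_tokens       # goes negative once num_tokens > tp_size
    new = (-num_tokens) % tp_size    # always rounds up to a multiple of tp_size
    print(num_tokens, old, new)
# 3 -> old=1,  new=1   (3 + 1 = 4)
# 4 -> old=0,  new=0
# 5 -> old=-1, new=3   (5 + 3 = 8)
# 9 -> old=-5, new=3   (9 + 3 = 12)
```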

@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

@codecov

codecov bot commented Aug 28, 2025

Codecov Report

❌ Patch coverage is 18.47826% with 75 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.41%. Comparing base (600b08f) to head (1002f4b).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| vllm_ascend/distributed/moe_comm_method.py | 18.91% | 30 Missing ⚠️ |
| vllm_ascend/ops/common_fused_moe.py | 19.44% | 29 Missing ⚠️ |
| vllm_ascend/worker/model_runner_v1.py | 0.00% | 11 Missing ⚠️ |
| vllm_ascend/quantization/w8a8_dynamic.py | 37.50% | 5 Missing ⚠️ |

❌ Your patch check has failed because the patch coverage (18.47%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2614      +/-   ##
==========================================
- Coverage   72.61%   72.41%   -0.20%     
==========================================
  Files         147      147              
  Lines       21805    21883      +78     
==========================================
+ Hits        15833    15846      +13     
- Misses       5972     6037      +65     
| Flag | Coverage Δ |
|---|---|
| unittests | 72.41% <18.47%> (-0.20%) ⬇️ |

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.


@yiz-liu yiz-liu force-pushed the support-all-to-all branch from 383b307 to ff45e0e Compare August 28, 2025 15:45
@yiz-liu yiz-liu force-pushed the support-all-to-all branch from ff45e0e to fcc3b69 Compare August 28, 2025 16:07
@yiz-liu yiz-liu changed the title [Feat][Graph] Support all-to-all and quantized models with ACL Graph [3/N][Feat][Graph] Support all-to-all and quantized models with ACL Graph Aug 29, 2025
@github-actions

This pull request has conflicts, please resolve those before we can evaluate the pull request.

Refactors the Fused MoE implementation by unifying the quantized and non-quantized execution paths into a single `fused_experts` function. This simplifies the codebase and centralizes MoE logic.

Adds support for W8A8 dynamic quantization within the unified MoE kernel. Communication methods are updated to handle dynamic scales for quantized activations.
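For context, per-token (dynamic) W8A8 activation quantization conceptually derives one int8 scale per token from that token's absolute maximum. The sketch below is a plain-PyTorch illustration of the idea, not the fused NPU kernel this PR actually calls:

```python
import torch

def dynamic_per_token_int8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One scale per token (row), derived from the row's absolute maximum.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    # The per-token scales must travel with x_q through the MoE communication
    # path so that experts can dequantize correctly after dispatch.
    return x_q, scale.squeeze(-1)
```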

Additionally, this change introduces a weight pre-processing step that transposes and converts weights to the `NZ` format, optimizing `matmul` performance on NPU hardware.
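A minimal sketch of such a pre-processing step, assuming torch_npu's npu_format_cast and the FRACTAL_NZ ACL format id (29); the helper name and details here are assumptions, not the PR's actual code:

```python
import torch
import torch_npu  # Ascend PyTorch adapter

ACL_FORMAT_FRACTAL_NZ = 29  # assumed format id for the NZ layout

def preprocess_moe_weight(w: torch.Tensor) -> torch.Tensor:
    # Pre-transpose so matmul consumes the weight directly (no per-step
    # transpose), then cast to the NZ layout favored by NPU matmuls.
    w_t = w.transpose(-1, -2).contiguous()
    return torch_npu.npu_format_cast(w_t, ACL_FORMAT_FRACTAL_NZ)
```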

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This method leverages an `all-to-all` collective communication pattern, which is more efficient than the existing `all-gather` strategy for large token counts on newer hardware.

The model runner now dynamically selects the optimal MoE communication method (`mc2`, `allgather`, or `alltoall`) based on the token count and the underlying Ascend SoC version.

Note that all-gather does not yet support quantized models.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
- Moves the token dispatcher import into the `AlltoAllCommImpl` constructor to enable lazy loading.
- Restricts MoE communication method logging to the global first rank to reduce log verbosity.
- Updates MoE communication tests to accommodate a new parameter in the `permute` function.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
@yiz-liu yiz-liu force-pushed the support-all-to-all branch 3 times, most recently from a3c69a9 to 9048a87 Compare August 30, 2025 01:54
Removes unnecessary weight transpose operations within the fused MoE expert function to improve performance.

Refactors how quantization flags are passed for MoE communication primitives.

Skips a W8A8 MoE test, as the required All-Gather communication operation does not yet support this quantization mode.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
@yiz-liu yiz-liu force-pushed the support-all-to-all branch from 9048a87 to d3f1a46 Compare August 30, 2025 02:47
@wangxiyuan wangxiyuan merged commit d3c93fb into vllm-project:main Aug 30, 2025
23 of 24 checks passed
@yiz-liu yiz-liu deleted the support-all-to-all branch August 30, 2025 03:00
845473182 pushed a commit to raindaywhu/vllm-ascend that referenced this pull request Sep 1, 2025
…into main_829

* 'main_829' of https://github.com/raindaywhu/vllm-ascend:
  [torchair]remove aicpu op (vllm-project#2640)
  bugfix for torchair graph (vllm-project#2639)
  [CI] fix UT error. (vllm-project#2644)
  [3/N][Feat][Graph] Support `all-to-all` and quantized models with ACL Graph (vllm-project#2614)
  [Bugfix] Fix mc2 operator error in aclgraph + ep<16 scenario (vllm-project#2609)
wenba0 pushed a commit to wenba0/vllm-ascend that referenced this pull request Sep 5, 2025
… Graph (vllm-project#2614)


return hidden_states

def permute(
Collaborator


How is the CPU logic in alltoall handled with respect to the computation graph here? I only see the registration of fake functions, but no explicit mechanism to break out of the graph.
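For readers unfamiliar with the mechanism referenced above: registering a fake (meta) implementation for a custom op lets graph tracing reason about shapes and dtypes without executing the real kernel or its CPU/D2H bookkeeping. A generic sketch using the standard torch.library API; the op name and signature are hypothetical, not this PR's actual dispatcher:

```python
import torch

@torch.library.custom_op("npu_moe::alltoall_dispatch", mutates_args=())
def alltoall_dispatch(hidden_states: torch.Tensor,
                      topk_ids: torch.Tensor) -> torch.Tensor:
    # A real eager implementation would run the collective plus the
    # D2H-driven CPU logic; omitted in this sketch.
    raise NotImplementedError

@alltoall_dispatch.register_fake
def _(hidden_states: torch.Tensor, topk_ids: torch.Tensor) -> torch.Tensor:
    # Fake impl: produces only metadata (shape/dtype), so tracing never
    # runs the CPU logic or triggers a synchronization.
    return torch.empty_like(hidden_states)
```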

wangxiaoteng888 pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Sep 25, 2025
… Graph (vllm-project#2614)

chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025
… Graph (vllm-project#2614)

wangxiyuan pushed a commit that referenced this pull request Oct 16, 2025
…10.0rc1 (#3455)

Pin the version that runs stably on the 310I Duo to vllm-ascend v0.10.0rc1.

### What this PR does / why we need it?
Since PR #2614, the 310I Duo has been broken. Although we are currently
working on a fix, there is no confirmed timeline in the short term. To
allow users to quickly find a working version instead of going back and
forth with trial and error, this PR pins the version in the 310I Duo
guide.

### Does this PR introduce _any_ user-facing change?
NA

### How was this patch tested?
NA

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: leo-pony <nengjunma@outlook.com>
MrZ20 pushed a commit to MrZ20/vllm-ascend that referenced this pull request Oct 17, 2025
…10.0rc1 (vllm-project#3455)

Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025
… Graph (vllm-project#2614)
