
Commit e559e13

yiz-liu authored and Angazenn committed
[3/N][Feat][Graph] Support all-to-all and quantized models with ACL Graph (vllm-project#2614)
### What this PR does / why we need it?

* **Unify execution paths:** Consolidates the quantized and non-quantized execution paths into a single `fused_experts` function, removing duplicated logic and making the control flow clearer and easier to maintain.
* **W8A8 dynamic quantization:** Adds support for W8A8 dynamic quantization inside the unified MoE kernel. Communication routines are updated to correctly handle dynamic quantization scales for activations.
* **Weight pre-processing:** Pre-transposes the `w13` and `w2` weight matrices (as implemented in PR vllm-project#2025) so that quantized and non-quantized models follow the same code path for the MoE gating, up-projection, and down-projection operations.
* **All-to-all communication:** Adds an `all-to-all` collective communication pattern. For large token counts on modern hardware, `all-to-all` is more efficient than the previous `all-gather` strategy. However, `all-to-all` cannot be captured and replayed: it performs multiple D2H operations that force synchronization and therefore raise errors during graph capture. Consequently, `all-to-all` is used only on the `compiled_graph_for_general_shape` fallback path.
* **Dynamic communication selection:** The model runner now selects the optimal MoE communication method (`mc2`, `allgather`, or `alltoall`) at runtime based on token count and the Ascend SoC version (see the illustrative sketch below).
* **Limitation:** `all-gather` is not yet supported for quantized models, so there is still work left to do on A2.

### Does this PR introduce _any_ user-facing change?

None.

### How was this patch tested?

No further test cases needed.

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@d660c98

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
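For illustration only, here is a minimal sketch of what such a runtime selection could look like. The function name `select_moe_comm_method`, the `mc2_token_limit` threshold, and the `graph_capturable` flag are assumptions made for this sketch, not code added by this PR; the real logic lives in the model runner and also inspects the Ascend SoC version.

```python
# Hypothetical sketch of per-batch MoE communication selection.
def select_moe_comm_method(num_tokens: int,
                           graph_capturable: bool,
                           mc2_token_limit: int = 512) -> str:
    """Return one of "mc2", "alltoall", or "allgather" (names from the PR)."""
    if num_tokens <= mc2_token_limit:
        # Small batches: MC2 dispatch/combine is the preferred path.
        return "mc2"
    if not graph_capturable:
        # all-to-all issues D2H copies that force synchronization, so it is
        # only usable on the compiled_graph_for_general_shape fallback.
        return "alltoall"
    # Large batches inside a captured ACL graph fall back to all-gather.
    return "allgather"


# Example: a 1024-token batch running on the eager fallback path.
print(select_moe_comm_method(1024, graph_capturable=False))  # -> "alltoall"
```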
1 parent e40e66c commit e559e13

File tree

7 files changed: +248 -41 lines

tests/e2e/multicard/moe/test_moe_comm.py

Lines changed: 6 additions & 2 deletions
@@ -33,6 +33,7 @@
 @pytest.mark.parametrize("top_k_num", [2, 4])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("ep_rank", [0, 1])
+@pytest.mark.parametrize("apply_a8_quantization", [False])
 def test_all_gather_comm_impl(
     num_tokens,
     hidden_size,
@@ -41,6 +42,7 @@ def test_all_gather_comm_impl(
     top_k_num,
     dtype,
     ep_rank,
+    apply_a8_quantization,
     mocker,
 ):
     """
@@ -118,8 +120,9 @@ def test_all_gather_comm_impl(
         native_permuted_hidden,
         native_expert_tokens,
         _,
+        _,
     ) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map,
-                            num_experts)
+                            num_experts, apply_a8_quantization)
     # Simulate MLP output
     native_mlp_output = torch.randn_like(native_permuted_hidden)
     native_impl.unpermute(native_mlp_output, native_hidden_states_out)
@@ -130,8 +133,9 @@ def test_all_gather_comm_impl(
         all_gather_permuted_hidden,
         all_gather_expert_tokens,
         _,
+        _,
     ) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights,
-                                expert_map, num_experts)
+                                expert_map, num_experts, apply_a8_quantization)

     # Use the same simulated MLP output for a fair comparison
     all_gather_mlp_output = native_mlp_output.clone()

tests/e2e/multicard/test_qwen3_moe.py

Lines changed: 1 addition & 1 deletion
@@ -107,4 +107,4 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH():
             tensor_parallel_size=2,
             enforce_eager=False,
     ) as vllm_model:
-        vllm_model.generate_greedy(example_prompts, max_tokens)
+        vllm_model.generate_greedy(example_prompts, max_tokens)

vllm_ascend/distributed/moe_comm_method.py

Lines changed: 106 additions & 9 deletions
@@ -54,7 +54,8 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, int]:
+        apply_a8_quantization: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         """Pre-process before MLP.

         Args:
@@ -64,6 +65,7 @@ def permute(
             expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
                 Mapping from global expert IDs to local expert IDs.
             num_experts (int): Number of local experts (experts on this device).
+            apply_a8_quantization (bool): Whether to apply A8 quantization (W4A8 and W8A8).

         Returns:
             tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing:
@@ -72,6 +74,8 @@ def permute(
                 hidden_states based on topk_ids.
             - expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
                 Number of tokens assigned to each expert.
+            - dynamic_scale (torch.Tensor, optional): Tensor of shape (num_experts, )
+                Dynamic scale for each expert, used for quantization.
             - group_list_type (int): Type of group list, 0 for `cumsum`
                 and 1 for `count`. This is mainly for `npu_grouped_matmul`
                 to determine how to handle the output.
@@ -159,7 +163,8 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,  # noqa: F841
         num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, int]:
+        apply_a8_quantization: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         num_tokens = hidden_states.shape[0]

         self.topk_weights = topk_weights
@@ -194,7 +199,7 @@ def permute(

         group_list_type = 1  # `count` mode

-        return permuted_hidden_states, expert_tokens, group_list_type
+        return permuted_hidden_states, expert_tokens, None, group_list_type

     def unpermute(self, mlp_output: torch.Tensor,
                   hidden_states: torch.Tensor) -> None:
@@ -219,7 +224,8 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, int]:
+        apply_a8_quantization: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         num_tokens = hidden_states.shape[0]

         # Generate token indices and flatten
@@ -269,7 +275,7 @@ def permute(

         group_list_type = 1  # `count` mode

-        return permuted_hidden_states, expert_tokens, group_list_type
+        return permuted_hidden_states, expert_tokens, None, group_list_type

     def unpermute(self, mlp_output: torch.Tensor,
                   hidden_states: torch.Tensor) -> None:
@@ -375,7 +381,8 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, int]:
+        apply_a8_quantization: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         # Store tensors needed for post_process
         self.topk_ids = topk_ids
         self.topk_weights = topk_weights.to(torch.float32)
@@ -388,7 +395,7 @@ def permute(
             "moe_expert_num": self.moe_config.num_experts,
             "global_bs": 0,
             "scales": None,
-            "quant_mode": 0,
+            "quant_mode": 2 if apply_a8_quantization else 0,
             "group_ep": self.mc2_comm_name,
             "ep_world_size": self.moe_config.ep_size,
             "ep_rank_id": self.moe_config.ep_rank,
@@ -409,7 +416,7 @@ def permute(

         (
             permuted_hidden_states,
-            _,  # dynamic_scale is not used
+            dynamic_scale,
             self.assist_info_for_combine,
             expert_tokens,
             self.ep_recv_counts,
@@ -418,7 +425,7 @@ def permute(

         group_list_type = 1

-        return permuted_hidden_states, expert_tokens, group_list_type
+        return permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type

     def unpermute(self, mlp_output: torch.Tensor,
                   hidden_states: torch.Tensor) -> None:
@@ -457,3 +464,93 @@ def unpermute(self, mlp_output: torch.Tensor,
         combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine

         hidden_states[:] = combine(**combine_kwargs)
+
+
+class AlltoAllCommImpl(MoECommMethod):
+    """This implementation is for the scenarios listed below:
+    1. `enable_expert_parallel=True`.
+    2. `npu_grouped_matmul` is available.
+
+    This implementation uses all-to-all communication to exchange tokens
+    between data parallel ranks before and after the MLP computation. It should
+    have better performance than AllGatherCommImpl when DP size > 1.
+    """
+
+    def __init__(self, moe_config: Optional[FusedMoEConfig]):
+        super().__init__(moe_config)
+        from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
+            get_token_dispatcher
+        self.token_dispatcher = get_token_dispatcher(
+            "TokenDispatcherWithAll2AllV")
+        self._restore_tp_across_dp()
+
+    def _restore_tp_across_dp(self):
+        # NOTE: Since vLLM flatten tp across dp, we need to restore the original
+        # tp_size and tp_rank.
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+    def prepare(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self.num_tokens, _ = hidden_states.shape
+        pad_size = self.tp_size - self.num_tokens
+
+        if pad_size > 0:
+            hidden_states = nn.functional.pad(hidden_states,
+                                              (0, 0, 0, pad_size))
+            router_logits = nn.functional.pad(router_logits,
+                                              (0, 0, 0, pad_size))
+
+        if self.tp_size > 1:
+            split_hidden_states = torch.tensor_split(hidden_states,
+                                                     self.tp_size,
+                                                     dim=0)
+            split_router_logits = torch.tensor_split(router_logits,
+                                                     self.tp_size,
+                                                     dim=0)
+            self.split_hidden_states = split_hidden_states
+
+            hidden_states = split_hidden_states[self.tp_rank]
+            router_logits = split_router_logits[self.tp_rank]
+
+        return hidden_states, router_logits
+
+    def finalize(self, hidden_states: torch.Tensor,
+                 reduce_results: bool) -> torch.Tensor:
+        """If TP size > 1, all-gather the hidden states to get the final output.
+
+        Also, unpad the hidden states if needed.
+        """
+        if self.tp_size > 1:
+            dist.all_gather(list(self.split_hidden_states), hidden_states,
+                            self.moe_config.tp_group.device_group)
+            hidden_states = torch.cat(self.split_hidden_states, dim=0)
+
+        if self.num_tokens < hidden_states.shape[0]:
+            hidden_states = hidden_states[:self.num_tokens]
+
+        return hidden_states
+
+    def permute(
+        self,
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        expert_map: torch.Tensor,
+        num_experts: int,
+        apply_a8_quantization: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
+        results = self.token_dispatcher.token_dispatch(
+            hidden_states,
+            topk_weights,
+            topk_ids,
+            None,
+            log2phy=None,
+            with_quant=apply_a8_quantization)
+        return results["hidden_states"], results["group_list"], results[
+            "dynamic_scale"], results["group_list_type"]
+
+    def unpermute(self, mlp_output: torch.Tensor,
+                  hidden_states: torch.Tensor) -> None:
+        hidden_states[:] = self.token_dispatcher.token_combine(mlp_output)
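To make the updated `permute`/`unpermute` contract concrete, here is a hedged, illustrative sketch (not code from this PR) of how a caller such as the unified `fused_experts` path might consume the new four-element return value. The helper parameters `mlp_fp` and `mlp_w8a8` are hypothetical placeholders for the non-quantized and W8A8 grouped-matmul kernels.

```python
from typing import Callable

import torch


def run_fused_experts(
    comm_impl,                               # a MoECommMethod instance
    mlp_fp: Callable[..., torch.Tensor],     # placeholder: bf16/fp16 grouped MLP
    mlp_w8a8: Callable[..., torch.Tensor],   # placeholder: W8A8 grouped MLP
    hidden_states: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    expert_map: torch.Tensor,
    num_experts: int,
    apply_a8_quantization: bool,
) -> torch.Tensor:
    """Hypothetical caller of the updated permute()/unpermute() contract."""
    permuted, expert_tokens, dynamic_scale, group_list_type = comm_impl.permute(
        hidden_states, topk_ids, topk_weights, expert_map, num_experts,
        apply_a8_quantization)

    if dynamic_scale is not None:
        # Quantized path: per-token activation scales come back from the
        # dispatcher (e.g. MC2 with quant_mode=2 or the all-to-all dispatcher).
        mlp_output = mlp_w8a8(permuted, expert_tokens, dynamic_scale,
                              group_list_type)
    else:
        # Non-quantized path: the all-gather implementations return None here.
        mlp_output = mlp_fp(permuted, expert_tokens, group_list_type)

    out = torch.empty_like(hidden_states)
    comm_impl.unpermute(mlp_output, out)  # writes the combined result in place
    return out
```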
