
Commit 383b307

feat(moe): Add All-to-All communication method
This method leverages an `all-to-all` collective communication pattern, which is more efficient than the existing `all-gather` strategy for large token counts on newer hardware. The model runner now dynamically selects the optimal MoE communication method (`mc2`, `allgather`, or `alltoall`) based on the token count and the underlying Ascend SoC version. Note that all-gather does not yet support quantized models.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent 4812f00 commit 383b307
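
For orientation, the selection rule added in `vllm_ascend/worker/model_runner_v1.py` (full diff below) boils down to the following standalone sketch. The free-function form and the explicit `mc2_tokens_capacity` argument are illustrative only; the enum values and threshold come from the diff.

```python
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version


def select_moe_comm_method(num_tokens: int, mc2_tokens_capacity: int) -> str:
    """Illustrative restatement of the runner's _select_moe_comm_method."""
    if num_tokens <= mc2_tokens_capacity:
        return "mc2"        # small batches always use MC2 dispatch/combine
    soc_version = get_ascend_soc_version()
    if soc_version == AscendSocVersion.A2:
        return "allgather"  # A2 SoCs keep the all-gather path
    if soc_version == AscendSocVersion.A3:
        return "alltoall"   # A3 SoCs switch to the new all-to-all path
    raise ValueError(f"Unsupported soc_version: {soc_version}")
```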

File tree: 4 files changed, +163 −55 lines changed


vllm_ascend/distributed/moe_comm_method.py

Lines changed: 92 additions & 4 deletions
@@ -14,6 +14,8 @@
 from vllm_ascend.distributed.communication_op import \
     data_parallel_reduce_scatter
 from vllm_ascend.distributed.parallel_state import get_mc2_group
+from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
+    get_token_dispatcher
 from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
 
 
@@ -55,7 +57,7 @@ def permute(
         expert_map: torch.Tensor,
         num_experts: int,
         use_a8: bool,
-    ) -> tuple[torch.Tensor, torch.Tensor, int]:
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         """Pre-process before MLP.
 
         Args:
@@ -161,7 +163,7 @@ def permute(
         expert_map: torch.Tensor,  # noqa: F841
         num_experts: int,
         use_a8: bool,
-    ) -> tuple[torch.Tensor, torch.Tensor, int]:
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         num_tokens = hidden_states.shape[0]
 
         self.topk_weights = topk_weights
@@ -222,7 +224,7 @@ def permute(
         expert_map: torch.Tensor,
         num_experts: int,
         use_a8: bool,
-    ) -> tuple[torch.Tensor, torch.Tensor, int]:
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         num_tokens = hidden_states.shape[0]
 
         # Generate token indices and flatten
@@ -379,7 +381,7 @@ def permute(
         expert_map: torch.Tensor,
         num_experts: int,
         use_a8: bool,
-    ) -> tuple[torch.Tensor, torch.Tensor, int]:
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         # Store tensors needed for post_process
         self.topk_ids = topk_ids
         self.topk_weights = topk_weights.to(torch.float32)
@@ -461,3 +463,89 @@ def unpermute(self, mlp_output: torch.Tensor,
         combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine
 
         hidden_states[:] = combine(**combine_kwargs)
+
+
+class AlltoAllCommImpl(MoECommMethod):
+    """This implementation is for the scenarios listed below:
+    1. `enable_expert_parallel=True`.
+    2. `npu_grouped_matmul` is available.
+
+    This implementation uses all-to-all communication to exchange tokens
+    between data parallel ranks before and after the MLP computation. It should
+    have better performance than AllGatherCommImpl when DP size > 1.
+    """
+
+    def __init__(self, moe_config: Optional[FusedMoEConfig]):
+        super().__init__(moe_config)
+        self.token_dispatcher = get_token_dispatcher(
+            "TokenDispatcherWithAll2AllV")
+        self._restore_tp_across_dp()
+
+    def _restore_tp_across_dp(self):
+        # NOTE: Since vLLM flatten tp across dp, we need to restore the original
+        # tp_size and tp_rank.
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+    def prepare(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self.num_tokens, _ = hidden_states.shape
+        pad_size = self.tp_size - self.num_tokens
+
+        if pad_size > 0:
+            hidden_states = nn.functional.pad(hidden_states,
+                                              (0, 0, 0, pad_size))
+            router_logits = nn.functional.pad(router_logits,
+                                              (0, 0, 0, pad_size))
+
+        if self.tp_size > 1:
+            split_hidden_states = torch.tensor_split(hidden_states,
+                                                     self.tp_size,
+                                                     dim=0)
+            split_router_logits = torch.tensor_split(router_logits,
+                                                     self.tp_size,
+                                                     dim=0)
+            self.split_hidden_states = split_hidden_states
+
+            hidden_states = split_hidden_states[self.tp_rank]
+            router_logits = split_router_logits[self.tp_rank]
+
+        return hidden_states, router_logits
+
+    def finalize(self, hidden_states: torch.Tensor,
+                 reduce_results: bool) -> torch.Tensor:
+        """If TP size > 1, all-gather the hidden states to get the final output.
+
+        Also, unpad the hidden states if needed.
+        """
+        if self.tp_size > 1:
+            dist.all_gather(list(self.split_hidden_states), hidden_states,
+                            self.moe_config.tp_group.device_group)
+            hidden_states = torch.cat(self.split_hidden_states, dim=0)
+
+        if self.num_tokens < hidden_states.shape[0]:
+            hidden_states = hidden_states[:self.num_tokens]
+
+        return hidden_states
+
+    def permute(
+        self,
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        expert_map: torch.Tensor,
+        num_experts: int,
+        use_a8: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
+        results = self.token_dispatcher.token_dispatch(hidden_states,
+                                                       topk_weights,
+                                                       topk_ids,
+                                                       None,
+                                                       log2phy=None)
+        return results["hidden_states"], results["group_list"], results[
+            "dynamic_scale"], results["group_list_type"]
+
+    def unpermute(self, mlp_output: torch.Tensor,
+                  hidden_states: torch.Tensor) -> None:
+        hidden_states[:] = self.token_dispatcher.token_combine(mlp_output)
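
The `prepare()` step above pads the local batch so that every TP rank receives a chunk, then keeps only this rank's slice; the collective exchange itself happens later in `permute()`/`unpermute()` via the token dispatcher. Below is a minimal CPU-only illustration of just the pad-and-split arithmetic; the tensor sizes and the `tp_size`/`tp_rank` values are made up.

```python
import torch
import torch.nn as nn

# Toy values; in AlltoAllCommImpl these come from the tensor-parallel group.
tp_size, tp_rank = 4, 1
hidden_states = torch.randn(3, 8)    # 3 tokens, hidden size 8
router_logits = torch.randn(3, 16)   # 16 experts

# Pad the token dimension only when there are fewer tokens than TP ranks.
pad_size = tp_size - hidden_states.shape[0]
if pad_size > 0:
    hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
    router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))

split_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)
local_chunk = split_hidden_states[tp_rank]  # this rank's share of the tokens
print(local_chunk.shape)                    # torch.Size([1, 8])
```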

vllm_ascend/ops/common_fused_moe.py

Lines changed: 32 additions & 34 deletions
@@ -19,20 +19,20 @@
 
 import torch
 import torch_npu
-from vllm.config import CompilationLevel, get_current_vllm_config
 from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod)
 
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
-                                                      MC2CommImpl,
-                                                      MoECommMethod)
+                                                      AlltoAllCommImpl,
+                                                      MC2CommImpl)
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.ops.fused_moe import apply_mlp, fused_experts_moge
+from vllm_ascend.ops.fused_moe import fused_experts_moge
 from vllm_ascend.ops.layers.experts_selector import select_experts
-from vllm_ascend.utils import is_310p, ACL_FORMAT_FRACTAL_NZ
+from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
+    setup_token_dispatchers
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
 
 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
 
@@ -66,26 +66,32 @@ def fused_experts(
     # Check constraints
     assert hidden_states.shape[1] == w1.shape[1], (
         f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}")
-
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
     assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
     assert hidden_states.dtype in [
         torch.float32, torch.float16, torch.bfloat16
     ]
+    if (use_int8_w8a8 or use_int4_w4a8):
+        assert w1_scale is not None and w2_scale is not None, \
+            "INT8 quantization requires weight scales."
+
+        w1_scale = w1_scale.to(torch.float32)
+        down_scale = [w2_scale]
+        down_output_dtype = w2_scale.dtype
+    else:
+        down_scale = None
+        down_output_dtype = None
 
     moe_comm_method = get_forward_context().moe_comm_method
     assert moe_comm_method is not None, "Missing communication context"
 
     num_experts = w1.shape[0]
 
     permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = moe_comm_method.permute(
-        hidden_states, topk_ids, topk_weights, expert_map, num_experts, use_int8_w8a8 or use_int4_w4a8)
-
-    if (use_int8_w8a8 or use_int4_w4a8) and dynamic_scale is None:
-        permuted_hidden_states, dynamic_scale = torch_npu.npu_dynamic_quant(
-            permuted_hidden_states)
+        hidden_states, topk_ids, topk_weights, expert_map, num_experts,
+        use_int8_w8a8 or use_int4_w4a8)
 
     gate_up_output = torch_npu.npu_grouped_matmul(
         x=[permuted_hidden_states],
@@ -97,10 +103,10 @@ def fused_experts(
         output_dtype=torch.int32 if use_int8_w8a8 else None,
     )[0]
 
-    if use_int8_w8a8:
+    if (use_int8_w8a8 or use_int4_w4a8):
         activated_output, activated_output_scale = torch_npu.npu_dequant_swiglu_quant(
             x=gate_up_output,
-            weight_scale=w1_scale.to(torch.float32),
+            weight_scale=w1_scale,
             activation_scale=dynamic_scale,
             bias=None,
             quant_scale=None,
@@ -109,42 +115,28 @@ def fused_experts(
             activate_left=True,
             quant_mode=1,
         )
+        activated_output_scale = [activated_output_scale]
     else:
         activated_output = torch_npu.npu_swiglu(gate_up_output)
         activated_output_scale = None
 
     down_output = torch_npu.npu_grouped_matmul(
         x=[activated_output],
         weight=[w2],
-        scale=[w2_scale] if use_int8_w8a8 else None,
-        per_token_scale=[activated_output_scale] if use_int8_w8a8 else None,
+        scale=down_scale,
+        per_token_scale=activated_output_scale,
         split_item=2,
         group_list_type=group_list_type,
         group_type=0,
         group_list=expert_tokens,
-        output_dtype=w2_scale.dtype if use_int8_w8a8 else None,
+        output_dtype=down_output_dtype,
     )[0]
 
     moe_comm_method.unpermute(down_output, hidden_states)
 
     return hidden_states
 
 
-def unquantized_fused_moe_init_func(self, *args, **kwargs):
-    original_unquantized_fused_moe_init_func(self, *args, **kwargs)
-    vllm_config = get_current_vllm_config()
-    self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-
-    ascend_config = get_ascend_config()
-
-    if ascend_config.torchair_graph_config.enabled:
-        self.use_aclgraph = False
-    else:
-        self.use_aclgraph = (vllm_config.compilation_config.level
-                             == CompilationLevel.PIECEWISE
-                             and not vllm_config.model_config.enforce_eager)
-
-
 def forward_oot(
         self,
         layer: torch.nn.Module,
@@ -276,12 +268,19 @@ def __init__(
             has_bias,
         )
 
+        with_quant = quant_config is not None
+        setup_token_dispatchers(self.moe_config.ep_size,
+                                top_k=self.top_k,
+                                num_experts=self.global_num_experts,
+                                num_local_experts=self.local_num_experts,
+                                with_quant=with_quant)
+
         self.moe_config.tp_group = get_tp_group()
         self.moe_config.dp_group = get_dp_group()
         self.moe_config.ep_group = get_ep_group()
         self.moe_config.mc2_group = get_mc2_group()
 
-        for method in {AllGatherCommImpl, MC2CommImpl}:
+        for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}:
             setattr(
                 self, method.__name__.lower(),
                 method(moe_config=self.moe_config))  # type: ignore[abstract]
@@ -332,6 +331,5 @@ def forward_impl(self, hidden_states: torch.Tensor,
         return final_hidden_states
 
 
-UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
 UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
 UnquantizedFusedMoEMethod.forward_oot = forward_oot
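
One detail worth calling out in the constructor change above: each communication implementation is instantiated once and stored on the layer under its lowercased class name, so adding `AlltoAllCommImpl` to the set is all that is needed to register it. The toy below reproduces only that naming convention; the stub classes and the `FakeLayer` wrapper are made up for illustration.

```python
class AllGatherCommImpl:
    pass

class AlltoAllCommImpl:
    pass

class MC2CommImpl:
    pass

class FakeLayer:
    def __init__(self):
        # Same pattern as the diff: one instance per comm method,
        # stored under the lowercased class name.
        for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}:
            setattr(self, method.__name__.lower(), method())

layer = FakeLayer()
print(sorted(vars(layer)))  # ['allgathercommimpl', 'alltoallcommimpl', 'mc2commimpl']
```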

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 22 additions & 14 deletions
@@ -19,15 +19,17 @@
 
 import torch
 import torch_npu
+from vllm.config import CompilationLevel, get_current_vllm_config
 from vllm.distributed import get_ep_group
 from vllm.forward_context import get_forward_context
 
 import vllm_ascend.envs as envs_ascend
+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
 from vllm_ascend.ops.common_fused_moe import \
     fused_experts as unified_fused_experts
+from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
 from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, dispose_tensor
 
@@ -285,6 +287,13 @@ def __init__(self):
 
         self.ep_group = get_ep_group()
 
+        vllm_config = get_current_vllm_config()
+        ascend_config = get_ascend_config()
+        self.use_aclgraph = (
+            vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
+            and not vllm_config.model_config.enforce_eager
+            and not ascend_config.torchair_graph_config.enabled)
+
         try:
             device_group = get_mc2_group().device_group
             # TODO: Try local_rank = ep_group.rank_in_group
@@ -377,19 +386,18 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts)
 
-        moe_comm_method = get_forward_context().moe_comm_method
-
-        return unified_fused_experts(
-            hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            use_int8_w8a8=True,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            expert_map=expert_map,
-        )
+        if self.use_aclgraph:
+            return unified_fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                use_int8_w8a8=True,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                expert_map=expert_map,
+            )
 
         fused_moe_state = get_forward_context().fused_moe_state
         shared_gate_up, shared_dequant_scale = None, None
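
The quantized w8a8-dynamic path now precomputes a `use_aclgraph` flag in `__init__` and only takes the unified `fused_experts` route when that flag is set; otherwise it falls through to the existing eager implementation. A plain-Python restatement of that gate, with the three config fields reduced to booleans purely for illustration:

```python
# Stand-ins for: compilation level == CompilationLevel.PIECEWISE,
# model_config.enforce_eager, and torchair_graph_config.enabled.
def use_aclgraph(piecewise_compilation: bool, enforce_eager: bool,
                 torchair_graph_enabled: bool) -> bool:
    return (piecewise_compilation and not enforce_eager
            and not torchair_graph_enabled)

assert use_aclgraph(True, False, False) is True    # ACL graph path
assert use_aclgraph(True, False, True) is False    # torchair graph disables it
assert use_aclgraph(False, False, False) is False  # needs piecewise compilation
```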

vllm_ascend/worker/model_runner_v1.py

Lines changed: 17 additions & 3 deletions
@@ -89,7 +89,8 @@
 from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
 from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               ProfileExecuteDuration, is_310p,
+                               AscendSocVersion, ProfileExecuteDuration,
+                               get_ascend_soc_version, is_310p,
                                vllm_version_is)
 from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
@@ -1614,8 +1615,21 @@ def _pool(
         )
 
     def _select_moe_comm_method(self, num_tokens: int) -> str:
-        return ("mc2"
-                if num_tokens <= self.mc2_tokens_capacity else "allgather")
+        soc_version = get_ascend_soc_version()
+
+        if num_tokens <= self.mc2_tokens_capacity:
+            moe_comm_method = "mc2"
+        elif soc_version in {AscendSocVersion.A2}:
+            moe_comm_method = "allgather"
+        elif soc_version in {AscendSocVersion.A3}:
+            moe_comm_method = "alltoall"
+        else:
+            raise ValueError(f"Unsupported soc_version: {soc_version}")
+
+        logger.debug(f"num_tokens: {num_tokens}, "
+                     f"moe_comm_method: {moe_comm_method}")
+
+        return moe_comm_method
 
     @torch.inference_mode()
     def execute_model(
