
Commit 48fd2a1

[bugfix] Pass in mc2 param according to soc_version and is_torchair_graph_mode
Signed-off-by: zzzzwwjj <1183291235@qq.com>
1 parent 0c7375f

File tree

8 files changed (+138, -83 lines)

vllm_ascend/ascend_forward_context.py

Lines changed: 22 additions & 1 deletion
@@ -1,11 +1,27 @@
 from contextlib import contextmanager
+from enum import Enum
 from typing import Any, Optional
 
 import torch
 from vllm.config import VllmConfig
 from vllm.forward_context import get_forward_context, set_forward_context
 
-from vllm_ascend.utils import get_fused_moe_state
+
+class FusedMoEState(Enum):
+    AllGather = 0
+    All2All = 1
+    MC2 = 2
+
+
+# TODO(zzzzwwjj): add soc_version to choose branch
+def get_fused_moe_state(ep_size: int, with_prefill: bool):
+    if ep_size == 1:
+        return FusedMoEState.AllGather
+    # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
+    elif ep_size < 16 or with_prefill:
+        return FusedMoEState.All2All
+    else:
+        return FusedMoEState.MC2
 
 
 @contextmanager
@@ -31,11 +47,16 @@ def set_ascend_forward_context(
 
         ep_size = torch.distributed.get_world_size(
         ) if vllm_config.parallel_config.enable_expert_parallel else 1
+
         fused_moe_state = get_fused_moe_state(ep_size, with_prefill)
+
         forward_context.fused_moe_state = fused_moe_state
 
         forward_context.in_profile_run = in_profile_run
 
+        forward_context.max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
+        )
+
         try:
             yield
         finally:
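
For reference, a minimal sketch of how the new selector behaves for a few (ep_size, with_prefill) inputs; the expected results in the asserts follow directly from the branch conditions introduced above.

# Minimal usage sketch of the selector added in this commit (assumes the
# vllm_ascend package from this commit is importable).
from vllm_ascend.ascend_forward_context import FusedMoEState, get_fused_moe_state

assert get_fused_moe_state(ep_size=1, with_prefill=True) == FusedMoEState.AllGather
assert get_fused_moe_state(ep_size=8, with_prefill=False) == FusedMoEState.All2All  # ep_size < 16
assert get_fused_moe_state(ep_size=16, with_prefill=True) == FusedMoEState.All2All  # prefill can't use MC2
assert get_fused_moe_state(ep_size=16, with_prefill=False) == FusedMoEState.MC2     # decode-only, EP >= 16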

vllm_ascend/models/deepseek_dbo.py

Lines changed: 1 addition & 1 deletion
@@ -627,7 +627,7 @@ def _forward_ms_layer(
             if self.dp_size > 1:
                 if attn_metadata[i] is not None:
                     max_num_tokens_across_dp = get_forward_context(
-                    ).dp_metadata.max_tokens_across_dp_cpu
+                    ).max_tokens_across_dp
                     if num_tokens[i] < max_num_tokens_across_dp:
                         hidden_states[i] = nn.functional.pad(
                             hidden_states[i],
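
This one-line change reads the integer cached on the forward context (stored via .item() in ascend_forward_context.py above) instead of the CPU tensor held by dp_metadata. A rough sketch of the DP padding pattern it feeds, using made-up shapes and pad arguments (the pad tuple itself is not part of this diff):

import torch
import torch.nn.functional as F

max_tokens_across_dp = 8            # plain Python int, as cached on the forward context
hidden_states = torch.randn(5, 16)  # 5 local tokens, hidden size 16 (illustrative)
num_tokens = hidden_states.shape[0]
if num_tokens < max_tokens_across_dp:
    # pad the token dimension up to the DP-wide maximum so all ranks match
    hidden_states = F.pad(hidden_states, (0, 0, 0, max_tokens_across_dp - num_tokens))
assert hidden_states.shape == (8, 16)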

vllm_ascend/models/deepseek_v2.py

Lines changed: 1 addition & 2 deletions
@@ -35,7 +35,7 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
-from vllm.distributed import (get_dp_group, get_ep_group, get_pp_group,
+from vllm.distributed import (get_dp_group, get_pp_group,
                               get_tensor_model_parallel_world_size,
                               get_tp_group)
 from vllm.forward_context import get_forward_context
@@ -284,7 +284,6 @@ def __init__(
 
         self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group
-        self.ep_group = get_ep_group()
         self.kv_consumer = None
         transfer_config = get_current_vllm_config().kv_transfer_config
         if transfer_config is not None:

vllm_ascend/ops/fused_moe.py

Lines changed: 47 additions & 35 deletions
@@ -15,6 +15,7 @@
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm/tests/kernels/test_moe.py
 
+import math
 import os
 from typing import Any, Callable, List, Optional, Tuple, Union
 
@@ -37,9 +38,11 @@
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
-from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
-                               npu_stream_switch, npu_wait_tensor)
+from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
+                               get_ascend_soc_version, npu_stream_switch,
+                               npu_wait_tensor)
 
 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
 
@@ -117,9 +120,24 @@ def fused_experts_with_mc2(
     top_k: int,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
-    shared_experts: Optional[Any] = None
+    shared_experts: Optional[Any] = None,
+    is_torchair: bool = False,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    global_bs = 0
+    quant_mode = 0
+    ep_group = get_ep_group()
+    ep_rank_id = ep_group.rank_in_group
+    ep_world_size = ep_group.world_size
+    tp_world_size = get_tp_group().world_size
+
+    # NOTE: `global_bs` should be equal to `max_num_tokens_across_dp` * `ep_world_size`,
+    # and `max_num_tokens_across_dp` has been split into `tp_world_size` parts before.
+    global_bs = math.ceil(get_forward_context().max_tokens_across_dp /
+                          tp_world_size) * ep_world_size
+
+    # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
+    need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
+                       or is_torchair)
+
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
         "x": hidden_states,
@@ -130,23 +148,20 @@ def fused_experts_with_mc2(
         "global_bs": global_bs,
     }
 
-    quant_mode = 0
-    ep_group = get_ep_group().device_group
-    assert torch.distributed.get_world_size() == ep_group.world_size
-    local_rank = torch.distributed.get_rank(group=ep_group)
-    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
-
     stage1_kwargs = {
         "scales": None,
         "quant_mode": quant_mode,
         "group_ep": moe_all_to_all_group_name,
-        "ep_world_size": all_to_all_group_size,
-        "ep_rank_id": local_rank,
-        # "group_tp": self.moe_rs_group_name,
-        "group_tp": moe_all_to_all_group_name,
-        "tp_world_size": 1,
-        "tp_rank_id": 0,
+        "ep_world_size": ep_world_size,
+        "ep_rank_id": ep_rank_id,
     }
+    if need_extra_args:
+        stage1_kwargs.update({
+            "group_tp": moe_all_to_all_group_name,
+            "tp_world_size": 1,
+            "tp_rank_id": 0,
+        })
+
     kwargs_mc2.update(stage1_kwargs)
 
     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
@@ -205,14 +220,16 @@ def fused_experts_with_mc2(
     stage3_kwargs = {
         "ep_send_counts": ep_recv_counts,
         "group_ep": moe_all_to_all_group_name,
-        "ep_world_size": all_to_all_group_size,
-        "ep_rank_id": local_rank,
-        "tp_send_counts": tp_recv_counts,
-        # "group_tp": self.moe_rs_group_name,
-        "group_tp": moe_all_to_all_group_name,
-        "tp_world_size": 1,
-        "tp_rank_id": 0,
+        "ep_world_size": ep_world_size,
+        "ep_rank_id": ep_rank_id,
     }
+    if need_extra_args:
+        stage3_kwargs.update({
+            "tp_send_counts": tp_recv_counts,
+            "group_tp": moe_all_to_all_group_name,
+            "tp_world_size": 1,
+            "tp_rank_id": 0,
+        })
     kwargs_mc2.update(stage3_kwargs)
 
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
@@ -842,17 +859,14 @@ def __init__(self, moe: MoEConfig = None):
         super().__init__(moe=moe)
         vllm_config = get_current_vllm_config()
 
-        self.ep_group = get_ep_group()
-        self.ep_size = self.ep_group.world_size
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
-        self.local_batch_size = self.global_batch_size // self.ep_size
         self.max_model_len = vllm_config.model_config.max_model_len
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
         try:
-            device_group = self.ep_group.device_group
+            device_group = get_ep_group().device_group
             # TODO: Try local_rank = ep_group.rank_in_group
             local_rank = torch.distributed.get_rank(group=device_group)
             backend = device_group._get_backend(torch.device("npu"))
@@ -939,7 +953,8 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                shared_experts=shared_experts)
+                shared_experts=shared_experts,
+                is_torchair=self.torchair_graph_enabled)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
@@ -1049,17 +1064,15 @@ def __init__(
             self.local_num_experts, self.expert_map = \
                 expert_load_balancer.get_rank_placement_map(
                     self.moe_instance_id,
-                    get_ep_group().rank_in_group)
+                    self.ep_rank)
             self.log2phy = expert_load_balancer.get_rank_log2phy_map(
-                self.moe_instance_id,
-                get_ep_group().rank_in_group)
+                self.moe_instance_id, self.ep_rank)
             self.global_redundant_expert_num = \
                 expert_load_balancer.get_global_redundant_expert_num()
         else:
             # Create a tensor of size num_experts filled with -1
             self.local_num_experts, self.expert_map = determine_expert_map(
-                self.ep_size,
-                get_ep_group().rank_in_group, self.global_num_experts)
+                self.ep_size, self.ep_rank, self.global_num_experts)
 
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_multistream_moe = \
@@ -1102,7 +1115,6 @@ def __init__(
                 in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size
 
-        self.ep_group = get_ep_group()
         # NOTE: self.tp_group is not expert_tp_group
         self.tp_group = get_tp_group().device_group
         self.quant_method.create_weights(layer=self, **moe_quant_params)
@@ -1148,7 +1160,7 @@ def forward(self,
             # NOTE: When in torchair graph, it has been padded in model_runner_v1
             if not self.torchair_graph_enabled or is_prefill:
                 max_num_tokens_across_dp = get_forward_context(
-                ).dp_metadata.max_tokens_across_dp_cpu
+                ).max_tokens_across_dp
                 if num_tokens < max_num_tokens_across_dp:
                     hidden_states = nn.functional.pad(
                         hidden_states,
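
The two NOTE blocks added to fused_experts_with_mc2 carry the core of this fix: global_bs is now derived from the DP-wide token maximum instead of being hard-coded to 0, and the TP-side dispatch/combine parameters are only passed on A3 SoCs or when the torchair graph is enabled. A small standalone sketch of that arithmetic and gating, with illustrative numbers; the SocVersion enum and helper names below are stand-ins rather than the real vllm_ascend.utils API:

import math
from enum import Enum


class SocVersion(Enum):  # stand-in for vllm_ascend.utils.AscendSocVersion
    A2 = 2
    A3 = 3


def mc2_global_bs(max_tokens_across_dp: int, tp_world_size: int, ep_world_size: int) -> int:
    # max_tokens_across_dp was already split across tp_world_size ranks, so each
    # rank contributes ceil(max / tp) tokens; MC2 wants the EP-wide total.
    return math.ceil(max_tokens_across_dp / tp_world_size) * ep_world_size


def mc2_extra_args(soc: SocVersion, is_torchair: bool, group_name: str) -> dict:
    # TP-related kwargs go into dispatch/combine only on A3 or in torchair graph mode.
    if soc == SocVersion.A3 or is_torchair:
        return {"group_tp": group_name, "tp_world_size": 1, "tp_rank_id": 0}
    return {}


# e.g. 1000 tokens max across DP, TP=2, EP=16 -> ceil(1000 / 2) * 16 = 8000
assert mc2_global_bs(1000, 2, 16) == 8000
assert mc2_extra_args(SocVersion.A2, is_torchair=False, group_name="moe") == {}
assert "group_tp" in mc2_extra_args(SocVersion.A2, is_torchair=True, group_name="moe")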

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 41 additions & 29 deletions
@@ -15,18 +15,21 @@
 # limitations under the License.
 #
 
+import math
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 import torch_npu
-from vllm.distributed import GroupCoordinator, get_ep_group
+from vllm.distributed import GroupCoordinator, get_ep_group, get_tp_group
 from vllm.forward_context import get_forward_context
 
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.ops.fused_moe import select_experts
-from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
-                               npu_stream_switch, npu_wait_tensor)
+from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
+                               get_ascend_soc_version, npu_stream_switch,
+                               npu_wait_tensor)
 
 
 def apply_mlp(hidden_states: torch.Tensor,
@@ -116,10 +119,25 @@ def fused_experts_with_mc2(
     log2phy: torch.Tensor = None,
     global_redundant_expert_num: int = 0,
     shared_experts: Optional[Any] = None,
+    is_torchair: bool = False,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     if log2phy:
         topk_ids = log2phy[topk_ids]
-    global_bs = 0
+    quant_mode = 2
+    ep_group = get_ep_group()
+    ep_rank_id = ep_group.rank_in_group
+    ep_world_size = ep_group.world_size
+    tp_world_size = get_tp_group().world_size
+
+    # NOTE: `global_bs` should be equal to `max_num_tokens_across_dp` * `ep_world_size`,
+    # and `max_num_tokens_across_dp` has been split into `tp_world_size` parts before.
+    global_bs = math.ceil(get_forward_context().max_tokens_across_dp /
+                          tp_world_size) * ep_world_size
+
+    # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
+    need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
+                       or is_torchair)
+
     if (expert_map is not None):
         moe_expert_num = len(expert_map) + global_redundant_expert_num
     else:
@@ -134,28 +152,19 @@ def fused_experts_with_mc2(
         "global_bs": global_bs,
     }
 
-    rank = torch.distributed.get_rank()
-
-    quant_mode = 2
-    ep_group = get_ep_group().device_group
-    local_rank = torch.distributed.get_rank(group=ep_group)
-    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
-
-    world_szie = torch.distributed.get_world_size()
-    tp_size = world_szie // all_to_all_group_size
-    tp_rank = rank % tp_size
-
     stage1_kwargs = {
         "scales": None,
         "quant_mode": quant_mode,
         "group_ep": moe_all_to_all_group_name,
-        "ep_world_size": all_to_all_group_size,
-        "ep_rank_id": local_rank,
-        # "group_tp": self.moe_rs_group_name,
-        "group_tp": moe_all_to_all_group_name,
-        "tp_world_size": tp_size,
-        "tp_rank_id": tp_rank,
+        "ep_world_size": ep_world_size,
+        "ep_rank_id": ep_rank_id,
     }
+    if need_extra_args:
+        stage1_kwargs.update({
+            "group_tp": moe_all_to_all_group_name,
+            "tp_world_size": 1,
+            "tp_rank_id": 0,
+        })
     kwargs_mc2.update(stage1_kwargs)
 
     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
@@ -196,14 +205,16 @@ def fused_experts_with_mc2(
     stage3_kwargs = {
         "ep_send_counts": ep_recv_counts,
        "group_ep": moe_all_to_all_group_name,
-        "ep_world_size": all_to_all_group_size,
-        "ep_rank_id": local_rank,
-        "tp_send_counts": tp_recv_counts,
-        # "group_tp": self.moe_rs_group_name,
-        "group_tp": moe_all_to_all_group_name,
-        "tp_world_size": tp_size,
-        "tp_rank_id": tp_rank,
+        "ep_world_size": ep_world_size,
+        "ep_rank_id": ep_rank_id,
     }
+    if need_extra_args:
+        stage3_kwargs.update({
+            "tp_send_counts": tp_recv_counts,
+            "group_tp": moe_all_to_all_group_name,
+            "tp_world_size": 1,
+            "tp_rank_id": 0,
+        })
     kwargs_mc2.update(stage3_kwargs)
 
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
@@ -679,7 +690,8 @@ def apply(
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts)
+                shared_experts=shared_experts,
+                is_torchair=self.torchair_graph_enabled)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
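
The quantized path mirrors the float path above (quant_mode=2 instead of 0), drops the locally derived tp_size/tp_rank (including the world_szie typo) in favor of the gated constants, and threads the same is_torchair flag in from torchair_graph_enabled. One step it keeps that the float path does not is the optional logical-to-physical expert remap; a toy sketch of that indexing with an invented 4-expert map (real maps come from ExpertLoadBalancer):

import torch

log2phy = torch.tensor([0, 2, 1, 3])       # logical expert i -> physical expert log2phy[i]
topk_ids = torch.tensor([[1, 3], [0, 2]])  # per-token top-k logical expert ids
topk_ids = log2phy[topk_ids]               # advanced indexing keeps the [num_tokens, k] shape
assert topk_ids.tolist() == [[2, 3], [0, 1]]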
