
Commit 4969046

Refactor MoE parameter initialization for flexibility
Pass the Hugging Face configuration object directly to the MoE communication method constructor. This allows the method to handle different attribute names for MoE parameters, such as `num_experts` and `n_routed_experts`. This change improves robustness and makes the implementation more compatible with various MoE model configurations.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent fc3899e commit 4969046
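For illustration, here is the lookup pattern this commit introduces, extracted into a standalone helper (the helper name `resolve_moe_params` is hypothetical; the attribute names and defaults come from the diff below):

from transformers.configuration_utils import PretrainedConfig


def resolve_moe_params(hf_config: PretrainedConfig) -> tuple[int, int]:
    # Top-k routing width; defaults to 0 when the config lacks the field.
    top_k_num = getattr(hf_config, "num_experts_per_tok", 0)
    # The global expert count goes by different names across MoE models,
    # e.g. `num_experts` (Qwen-style) vs. `n_routed_experts` (DeepSeek-style).
    for key in ("num_experts", "n_routed_experts"):
        if hasattr(hf_config, key):
            return top_k_num, getattr(hf_config, key)
    return top_k_num, 0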

2 files changed: 16 additions & 14 deletions

vllm_ascend/distributed/moe_comm_method.py

Lines changed: 13 additions & 7 deletions
@@ -2,6 +2,7 @@
 
 import torch
 import torch_npu
+from transformers.configuration_utils import PretrainedConfig
 from vllm.distributed.parallel_state import get_ep_group, get_tp_group
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
@@ -17,13 +18,19 @@ def __init__(
         self,
         device: torch.device,
         dtype: torch.dtype,
-        top_k_num: int,
-        global_num_experts: int,
+        hf_config: PretrainedConfig,
     ):
         self.device = device
         self.dtype = dtype
-        self.top_k_num = top_k_num
-        self.global_num_experts = global_num_experts
+        self.top_k_num = getattr(hf_config, "num_experts_per_tok", 0)
+        # global_num_experts may be called num_experts or n_routed_experts in different models.
+        possible_keys = ["num_experts", "n_routed_experts"]
+        for key in possible_keys:
+            if hasattr(hf_config, key):
+                self.global_num_experts = getattr(hf_config, key)
+                break
+        else:
+            self.global_num_experts = 0
 
     @abstractmethod
     def _pre_process(
@@ -232,10 +239,9 @@ def __init__(
         self,
         device: torch.device,
         dtype: torch.dtype,
-        top_k_num: int,
-        global_num_experts: int,
+        hf_config: PretrainedConfig,
     ):
-        super().__init__(device, dtype, top_k_num, global_num_experts)
+        super().__init__(device, dtype, hf_config)
 
         # Shared communication configurations
         ep_group = get_mc2_group()
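As a quick sanity check of the fallback in the constructor above (building bare `PretrainedConfig` objects here is purely illustrative; real callers pass `self.model_config.hf_config`):

from transformers.configuration_utils import PretrainedConfig

# Two minimal configs that expose the expert count under different names.
qwen_like = PretrainedConfig(num_experts_per_tok=8, num_experts=64)
deepseek_like = PretrainedConfig(num_experts_per_tok=8, n_routed_experts=64)

for cfg in (qwen_like, deepseek_like):
    top_k_num = getattr(cfg, "num_experts_per_tok", 0)
    global_num_experts = next(
        (getattr(cfg, key) for key in ("num_experts", "n_routed_experts")
         if hasattr(cfg, key)), 0)
    assert (top_k_num, global_num_experts) == (8, 64)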

vllm_ascend/worker/model_runner_v1.py

Lines changed: 3 additions & 7 deletions
@@ -1324,10 +1324,8 @@ def _process_reqs(
                 num_tokens_across_dp=num_tokens_across_dp,
                 with_prefill=with_prefill,
                 reserved_mc2_mask=self.reserved_mc2_mask,
-                moe_comm_method=moe_comm_method(
-                    self.device, self.dtype,
-                    self.model_config.hf_config.num_experts_per_tok,
-                    self.model_config.hf_config.num_experts),
+                moe_comm_method=moe_comm_method(self.device, self.dtype,
+                                                self.model_config.hf_config),
                 num_actual_tokens=total_num_scheduled_tokens):
             with ProfileExecuteDuration().capture_async("forward"):
                 self.maybe_setup_kv_connector(scheduler_output)
@@ -1990,9 +1988,7 @@ def _dummy_run(
                     in_profile_run=self.in_profile_run,
                     reserved_mc2_mask=self.reserved_mc2_mask,
                     moe_comm_method=moe_comm_method(
-                        self.device, self.dtype,
-                        self.model_config.hf_config.num_experts_per_tok,
-                        self.model_config.hf_config.num_experts),
+                        self.device, self.dtype, self.model_config.hf_config),
                     num_actual_tokens=0,
             ):
                 model_kwargs = {}
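For contrast, the previous call sites read `self.model_config.hf_config.num_experts` directly, which would fail on configs that only define `n_routed_experts`; a minimal reproduction of that failure mode (purely illustrative, not code from the repository):

from transformers.configuration_utils import PretrainedConfig

deepseek_like = PretrainedConfig(num_experts_per_tok=6, n_routed_experts=160)

try:
    _ = deepseek_like.num_experts  # the old call-site pattern
except AttributeError:
    print("old pattern fails: this config only defines n_routed_experts")

# New pattern: hand the whole config to the communication method and let it
# resolve the attribute name itself.
print(getattr(deepseek_like, "n_routed_experts"))  # 160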
