Skip to content

Commit 804c357

Browse files
committed
run _select_moe_comm_method only in MoE model
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
1 parent ae36ada commit 804c357

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

tests/ut/worker/test_model_runner_v1.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def test_select_moe_comm_method(soc_version, enable_expert_parallel,
6868
with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
6969
return_value=soc_version), \
7070
patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
71+
return_value=True), \
72+
patch('vllm_ascend.worker.model_runner_v1.is_moe_model',
7173
return_value=True):
7274

7375
# Bind the real method to the mock object
@@ -102,6 +104,8 @@ def test_select_moe_comm_method_unsupported_soc():
102104
return_value=unsupported_soc), \
103105
patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
104106
return_value=True), \
107+
patch('vllm_ascend.worker.model_runner_v1.is_moe_model',
108+
return_value=True), \
105109
pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"):
106110

107111
NPUModelRunner._select_moe_comm_method(mock_runner, 100, False)

vllm_ascend/ops/moe/moe_comm_method.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
def get_moe_comm_method(
4141
moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
42-
return _MoECommMethods.get(moe_comm_type)
42+
return _MoECommMethods.get(moe_comm_type, None)
4343

4444

4545
def setup_moe_comm_method(moe_config):

vllm_ascend/worker/model_runner_v1.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1841,7 +1841,7 @@ def _pool(
18411841
)
18421842

18431843
def _select_moe_comm_method(self, num_tokens: int,
1844-
with_prefill: bool) -> MoECommType:
1844+
with_prefill: bool) -> Optional[MoECommType]:
18451845
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
18461846
are designed for expert parallelism.
18471847
2. If expert parallel is enabled, we need to consider the soc version and the
@@ -1864,6 +1864,9 @@ def _select_moe_comm_method(self, num_tokens: int,
18641864
Returns:
18651865
MoECommType: The selected MoE communication method.
18661866
"""
1867+
if not is_moe_model(self.vllm_config):
1868+
return None
1869+
18671870
soc_version = get_ascend_soc_version()
18681871
quant_type = getattr(self.vllm_config.model_config.hf_config,
18691872
'moe_quantize', None)

0 commit comments

Comments
 (0)