
Commit de9d711

fix(moe): fix moe_comm_method test case
- Moves the token dispatcher import into the `AlltoAllCommImpl` constructor to enable lazy loading.
- Restricts MoE communication method logging to the global first rank to reduce log verbosity.
- Updates MoE communication tests to accommodate a new parameter in the `permute` function.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent d46e1f5 commit de9d711

File tree: 3 files changed (+11, -6 lines)


tests/e2e/multicard/moe/test_moe_comm.py

Lines changed: 6 additions & 2 deletions

@@ -33,6 +33,7 @@
 @pytest.mark.parametrize("top_k_num", [2, 4])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("ep_rank", [0, 1])
+@pytest.mark.parametrize("use_a8", [False])
 def test_all_gather_comm_impl(
     num_tokens,
     hidden_size,
@@ -41,6 +42,7 @@ def test_all_gather_comm_impl(
     top_k_num,
     dtype,
     ep_rank,
+    use_a8,
     mocker,
 ):
     """
@@ -118,8 +120,9 @@ def test_all_gather_comm_impl(
         native_permuted_hidden,
         native_expert_tokens,
         _,
+        _,
     ) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map,
-                            num_experts)
+                            num_experts, use_a8)
     # Simulate MLP output
     native_mlp_output = torch.randn_like(native_permuted_hidden)
     native_impl.unpermute(native_mlp_output, native_hidden_states_out)
@@ -130,8 +133,9 @@
         all_gather_permuted_hidden,
         all_gather_expert_tokens,
         _,
+        _,
     ) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights,
-                                expert_map, num_experts)
+                                expert_map, num_experts, use_a8)
 
     # Use the same simulated MLP output for a fair comparison
     all_gather_mlp_output = native_mlp_output.clone()
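For context, the `permute` call in this test now takes an extra `use_a8` flag and returns one more value, so each unpacking gains a trailing placeholder. A minimal sketch of that call-site shape, using a stand-in function (the names, return contents, and shapes below are illustrative assumptions, not the real `AllGatherCommImpl` API):

import torch

# Stand-in for the permute() call-site change; return contents are assumed
# for illustration and do not mirror the actual implementation.
def permute_stub(hidden_states, topk_ids, topk_weights, expert_map,
                 num_experts, use_a8=False):
    permuted_hidden = hidden_states.repeat_interleave(topk_ids.shape[1], dim=0)
    expert_tokens = torch.bincount(topk_ids.flatten(), minlength=num_experts)
    # One extra trailing value compared to the old signature, hence the added
    # "_" placeholder in the updated test unpacking.
    return permuted_hidden, expert_tokens, None, None

hidden_states = torch.randn(8, 16, dtype=torch.float16)
topk_ids = torch.randint(0, 4, (8, 2))
topk_weights = torch.rand(8, 2)

(
    permuted_hidden,
    expert_tokens,
    _,
    _,  # new trailing return value, added alongside use_a8
) = permute_stub(hidden_states, topk_ids, topk_weights,
                 expert_map=None, num_experts=4, use_a8=False)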

vllm_ascend/distributed/moe_comm_method.py

Lines changed: 2 additions & 2 deletions

@@ -14,8 +14,6 @@
 from vllm_ascend.distributed.communication_op import \
     data_parallel_reduce_scatter
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
-    get_token_dispatcher
 from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
 
 
@@ -477,6 +475,8 @@ class AlltoAllCommImpl(MoECommMethod):
 
     def __init__(self, moe_config: Optional[FusedMoEConfig]):
         super().__init__(moe_config)
+        from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
+            get_token_dispatcher
         self.token_dispatcher = get_token_dispatcher(
             "TokenDispatcherWithAll2AllV")
         self._restore_tp_across_dp()
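The change above is a deferred (lazy) import: `get_token_dispatcher` is now resolved only when an `AlltoAllCommImpl` instance is constructed, so importing `moe_comm_method` no longer requires the dispatcher module at load time, which helps avoid import-time cycles. A self-contained sketch of the pattern (the class name and the `json` stand-in are illustrative, not the actual dispatcher):

class LazyDispatcherHolder:
    """Illustrates the deferred-import pattern used in AlltoAllCommImpl.__init__."""

    def __init__(self):
        # The import runs at construction time instead of module import time,
        # so merely importing this module never pulls in the dependency.
        # `json` stands in for the token dispatcher module here.
        import json
        self._dispatcher = json

    def describe(self) -> str:
        return self._dispatcher.dumps({"dispatcher": "TokenDispatcherWithAll2AllV"})


if __name__ == "__main__":
    print(LazyDispatcherHolder().describe())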

vllm_ascend/worker/model_runner_v1.py

Lines changed: 3 additions & 2 deletions

@@ -1632,8 +1632,9 @@ def _select_moe_comm_method(self, num_tokens: int) -> str:
         else:
             raise ValueError(f"Unsupported soc_version: {soc_version}")
 
-        logger.debug(f"num_tokens: {num_tokens}, "
-                     f"moe_comm_method: {moe_comm_method}")
+        if is_global_first_rank():
+            logger.debug(f"num_tokens: {num_tokens}, "
+                         f"moe_comm_method: {moe_comm_method}")
 
         return moe_comm_method
 
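This hunk gates the per-call debug line on the global first rank, so multi-rank runs emit it once rather than once per worker. A self-contained sketch of the same gating idea; the rank check and the token threshold below are stand-ins, not vLLM's actual `is_global_first_rank` helper or selection logic:

import logging
import os

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("moe_comm")


def _is_global_first_rank() -> bool:
    # Stand-in rank check: launchers such as torchrun expose the global rank
    # through the RANK environment variable.
    return int(os.environ.get("RANK", "0")) == 0


def select_moe_comm_method(num_tokens: int) -> str:
    # Threshold and method names are assumptions for illustration only.
    moe_comm_method = "allgather" if num_tokens <= 256 else "alltoall"
    if _is_global_first_rank():
        # Only rank 0 logs, so large EP/TP deployments do not repeat this
        # line once per worker process.
        logger.debug(f"num_tokens: {num_tokens}, "
                     f"moe_comm_method: {moe_comm_method}")
    return moe_comm_method


if __name__ == "__main__":
    select_moe_comm_method(128)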

0 commit comments