Commit 0dfb10d

Simplifies MoE comm; removes unused MC2 params
Removes dead/commented paths in the MoE communication implementation and cleans up legacy chunking/gather remnants.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent 6607b2f commit 0dfb10d

File tree

vllm_ascend/distributed/moe_comm_method.py
vllm_ascend/ops/fused_moe.py

2 files changed (+0, -33)

vllm_ascend/distributed/moe_comm_method.py

Lines changed: 0 additions & 25 deletions
@@ -305,23 +305,6 @@ def _pre_process(
         self.topk_weights = topk_weights.to(torch.float32)
         self.mc2_mask = get_forward_context().mc2_mask
 
-        # tp_size = get_tensor_model_parallel_world_size()
-        # self.chunked_hidden_states = torch.tensor_split(hidden_states,
-        #                                                 tp_size,
-        #                                                 dim=0)
-        # chunked_topk_ids = torch.tensor_split(self.topk_ids,
-        #                                       tp_size,
-        #                                       dim=0)
-        # chunked_topk_weights = torch.tensor_split(self.topk_weights,
-        #                                           tp_size,
-        #                                           dim=0)
-        # chunked_mc2_mask = torch.tensor_split(self.mc2_mask, tp_size, dim=0)
-        # tp_rank = get_tensor_model_parallel_rank()
-        # hidden_states = self.chunked_hidden_states[tp_rank]
-        # self.topk_ids = chunked_topk_ids[tp_rank]
-        # self.topk_weights = chunked_topk_weights[tp_rank]
-        # self.mc2_mask = chunked_mc2_mask[tp_rank]
-
         dispatch_kwargs = {
             "x": hidden_states,
             "expert_ids": self.topk_ids,
@@ -400,14 +383,6 @@ def _post_process(self, mlp_output: torch.Tensor,
 
         hidden_states[:] = combine(**combine_kwargs)
 
-        # final_hidden_states = combine(**combine_kwargs)
-
-        # dist.all_gather(list(self.chunked_hidden_states), final_hidden_states, get_tp_group().device_group)
-
-        # final_hidden_states = torch.cat(self.chunked_hidden_states, dim=0)
-
-        # hidden_states[:] = final_hidden_states
-
 
 def moe_comm_pre_process(
     hidden_states: torch.Tensor,
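
For context, the comments removed above preserved an earlier tensor-parallel chunking scheme: each TP rank dispatched only its torch.tensor_split slice of the batch, and _post_process all-gathered the combined slices back into the full batch. Below is a minimal sketch of that pattern; the helper names (chunk_for_rank, gather_after_combine) are hypothetical, not vllm-ascend APIs, and only the tensor_split / all_gather / cat calls mirror the deleted lines.

# Minimal sketch of the legacy chunk/gather scheme kept in the deleted
# comments. Helper names and signatures are hypothetical.
import torch
import torch.distributed as dist


def chunk_for_rank(hidden_states, topk_ids, topk_weights, tp_rank, tp_size):
    # Split every per-token tensor into tp_size chunks along the token dim
    # and keep only this rank's chunk (mirrors the torch.tensor_split calls).
    chunked_hidden = torch.tensor_split(hidden_states, tp_size, dim=0)
    local_ids = torch.tensor_split(topk_ids, tp_size, dim=0)[tp_rank]
    local_weights = torch.tensor_split(topk_weights, tp_size, dim=0)[tp_rank]
    # The full tuple of hidden-state chunks is returned so the combine step
    # can all-gather back into those same views.
    return chunked_hidden, chunked_hidden[tp_rank], local_ids, local_weights


def gather_after_combine(local_output, chunked_hidden, group):
    # Mirrors the deleted dist.all_gather(...) + torch.cat(...) lines:
    # collect every rank's combined slice, then restore the full batch.
    dist.all_gather(list(chunked_hidden), local_output, group=group)
    return torch.cat(chunked_hidden, dim=0)

With the path gone, _pre_process hands the full batch straight to MC2 dispatch and no chunk state has to survive into _post_process.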

vllm_ascend/ops/fused_moe.py

Lines changed: 0 additions & 8 deletions
@@ -77,10 +77,6 @@ def unified_fused_experts(
     moe_comm_method: Optional[MoECommMethod] = None,
     # For TorchAir graph
     is_torchair: bool = False,
-    # For communication
-    use_mc2: bool = False,
-    moe_all_to_all_group_name: str = "",
-    mc2_mask: Optional[torch.Tensor] = None,
     # For Cube/Vector parallel
     shared_experts: Optional[Any] = None,
     quantized_x_for_share: Optional[Any] = None,
@@ -104,9 +100,6 @@ def unified_fused_experts(
 
     num_experts = w1.shape[0]
 
-    # permuted_hidden_states, expert_tokens, group_list_type = moe_comm_method._pre_process(
-    #     hidden_states, topk_ids, topk_weights, expert_map, num_experts
-    # )
     permuted_hidden_states, expert_tokens, group_list_type = torch.ops.vllm.moe_comm_pre_process(
         hidden_states, topk_ids, topk_weights, expert_map, num_experts)
     mlp_output = apply_mlp(
@@ -116,7 +109,6 @@ def unified_fused_experts(
         expert_tokens,
         group_list_type=group_list_type,
     )
-    # moe_comm_method._post_process(mlp_output, hidden_states)
     torch.ops.vllm.moe_comm_post_process(mlp_output, hidden_states)
 
     return hidden_states
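
Note that the live calls go through torch.ops.vllm.moe_comm_pre_process / moe_comm_post_process rather than the underscore methods the deleted comments used, which keeps the communication step registered with the PyTorch dispatcher as a single opaque op during graph capture. A minimal sketch of that registration pattern using torch.library.custom_op (PyTorch 2.4+); the "demo" namespace and the trivial body are assumptions for illustration, not vllm-ascend's actual registration code.

# Sketch: expose a Python function as a dispatcher-registered custom op so it
# can be called as torch.ops.<namespace>.<name>, the way the diff calls
# torch.ops.vllm.moe_comm_pre_process. Namespace and body are hypothetical.
import torch


@torch.library.custom_op("demo::moe_comm_post_process",
                         mutates_args=("hidden_states",))
def moe_comm_post_process(mlp_output: torch.Tensor,
                          hidden_states: torch.Tensor) -> None:
    # Write the result in place, matching the hidden_states[:] = ... pattern.
    hidden_states.copy_(mlp_output)


mlp_out = torch.randn(4, 8)
hidden = torch.empty(4, 8)
torch.ops.demo.moe_comm_post_process(mlp_out, hidden)  # dispatcher call
assert torch.equal(hidden, mlp_out)

Because the call goes through torch.ops with a declared mutation on hidden_states, compile and capture front-ends see one node instead of tracing into Python method internals.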
