|
2 | 2 |
|
3 | 3 | import torch |
4 | 4 | import torch_npu |
| 5 | +from vllm.distributed.parallel_state import get_tp_group |
5 | 6 | from vllm.forward_context import ForwardContext, get_forward_context |
6 | 7 | from vllm.utils import direct_register_custom_op |
7 | 8 |
|
| 9 | +from vllm_ascend.distributed.parallel_state import get_mc2_group |
| 10 | +from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version |
| 11 | + |
8 | 12 |
|
9 | 13 | class MoECommMethod(ABC): |
10 | 14 | """Base class for MoE communication methods.""" |
@@ -76,6 +80,7 @@ def _pre_process( |
76 | 80 | expert_map: torch.Tensor, |
77 | 81 | num_experts: int, |
78 | 82 | ) -> tuple[torch.Tensor, torch.Tensor, int]: |
| 83 | + print("Using AllGatherCommImpl for MoE communication.") |
79 | 84 | num_tokens = hidden_states.shape[0] |
80 | 85 |
|
81 | 86 | # Generate token indices and flatten |
@@ -164,6 +169,7 @@ def _pre_process( |
164 | 169 | expert_map: torch.Tensor, # noqa: F841 |
165 | 170 | num_experts: int, |
166 | 171 | ) -> tuple[torch.Tensor, torch.Tensor, int]: |
| 172 | + print("Using AllReduceCommImpl for MoE communication.") |
167 | 173 | num_tokens = hidden_states.shape[0] |
168 | 174 |
|
169 | 175 | self.topk_weights = topk_weights |
@@ -229,6 +235,145 @@ def _post_process(self, mlp_output: torch.Tensor, |
229 | 235 | ) |
230 | 236 |
|
231 | 237 |
|
| 238 | +class MC2CommImpl(MoECommMethod): |
| 239 | + """This implementation is for the scenarios listed below: |
| 240 | +    1. `enable_expert_parallel=True` (the `enable_expert_parallel=False` case is |
| 241 | +       not supported). |
| 242 | +    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available. |
| 243 | + |
| 244 | + This implementation uses the MC2 communication method, which is optimized for |
| 245 | + Communication and Computation parallelism on Ascend devices. |
| 246 | + """ |
| 247 | + |
| 248 | + def __init__( |
| 249 | + self, |
| 250 | + device: torch.device, |
| 251 | + dtype: torch.dtype, |
| 252 | + top_k_num: int, |
| 253 | + global_num_experts: int, |
| 254 | + ): |
| 255 | + super().__init__(device, dtype, top_k_num, global_num_experts) |
| 256 | + |
| 257 | + # Shared communication configurations |
| 258 | + ep_group = get_mc2_group() |
| 259 | + self.ep_rank_id = ep_group.rank_in_group |
| 260 | + self.ep_world_size = ep_group.world_size |
| 261 | + self.tp_world_size = get_tp_group().world_size |
| 262 | + |
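|  | +        # Resolve the HCCL communicator name of the EP group; the dispatch/combine |
|  | +        # kernels refer to their communication domain by this name (group_ep/group_tp). |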
| 263 | + device_group = ep_group.device_group |
| 264 | + local_rank = torch.distributed.get_rank(group=device_group) |
| 265 | + backend = device_group._get_backend(torch.device("npu")) |
| 266 | + self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) |
| 267 | + |
| 268 | + # Feature flags |
| 269 | + self.enable_dispatch_v2 = hasattr(torch_npu, |
| 270 | + "npu_moe_distribute_dispatch_v2") |
| 271 | + self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3 |
| 272 | +        self.need_extra_args = self.is_ascend_a3  # torchair graph mode would also need these |
| 273 | + |
| 274 | + # Intermediate tensors to be passed from pre_process to post_process |
| 275 | + self.topk_ids = None |
| 276 | + self.topk_weights = None |
| 277 | + self.mc2_mask = None |
| 278 | + self.assist_info_for_combine = None |
| 279 | + self.ep_recv_counts = None |
| 280 | + self.tp_recv_counts = None |
| 281 | + |
| 282 | + def _pre_process( |
| 283 | + self, |
| 284 | + hidden_states: torch.Tensor, |
| 285 | + topk_ids: torch.Tensor, |
| 286 | + topk_weights: torch.Tensor, |
| 287 | + expert_map: torch.Tensor, |
| 288 | + num_experts: int, |
| 289 | + ) -> tuple[torch.Tensor, torch.Tensor, int]: |
| 290 | + # Store tensors needed for post_process |
| 291 | + self.topk_ids = topk_ids.clone() |
| 292 | + self.topk_weights = topk_weights |
| 293 | + self.mc2_mask = get_forward_context().mc2_mask |
| 294 | + |
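|  | +        # Arguments for the dispatch kernel: tokens are sent to the ranks owning their |
|  | +        # experts; quantization is disabled here (scales=None, quant_mode=0) and no |
|  | +        # ranks are reserved for shared experts (shared_expert_rank_num=0). |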
| 295 | + dispatch_kwargs = { |
| 296 | + "x": hidden_states, |
| 297 | + "expert_ids": self.topk_ids, |
| 298 | + "expert_shard_type": 0, |
| 299 | + "shared_expert_rank_num": 0, |
| 300 | + "moe_expert_num": self.global_num_experts, |
| 301 | + "global_bs": 0, |
| 302 | + "scales": None, |
| 303 | + "quant_mode": 0, |
| 304 | + "group_ep": self.moe_all_to_all_group_name, |
| 305 | + "ep_world_size": self.ep_world_size, |
| 306 | + "ep_rank_id": self.ep_rank_id, |
| 307 | + } |
| 308 | + |
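|  | +        # On Ascend A3 the kernel also expects TP-group arguments; the same HCCL |
|  | +        # group is reused with a tp_world_size of 1. |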
| 309 | + if self.need_extra_args: |
| 310 | + dispatch_kwargs.update({ |
| 311 | + "group_tp": self.moe_all_to_all_group_name, |
| 312 | + "tp_world_size": 1, |
| 313 | + "tp_rank_id": 0, |
| 314 | + }) |
| 315 | + if self.is_ascend_a3 and self.enable_dispatch_v2: |
| 316 | + dispatch_kwargs.update({ |
| 317 | + "x_active_mask": self.mc2_mask, |
| 318 | + }) |
| 319 | + |
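|  | +        # Prefer the v2 dispatch kernel when this torch_npu build provides it. |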
| 320 | + dispatch = torch_npu.npu_moe_distribute_dispatch_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch |
| 321 | + |
| 322 | + ( |
| 323 | + permuted_hidden_states, |
| 324 | + _, # dynamic_scale is not used |
| 325 | + self.assist_info_for_combine, |
| 326 | + expert_tokens, |
| 327 | + self.ep_recv_counts, |
| 328 | + self.tp_recv_counts, |
| 329 | +        ) = dispatch(**dispatch_kwargs)[:6] |
| 330 | + |
| 331 | + group_list_type = 1 |
| 332 | + |
| 333 | + return permuted_hidden_states, expert_tokens, group_list_type |
| 334 | + |
| 335 | + def _post_process(self, mlp_output: torch.Tensor, |
| 336 | + hidden_states: torch.Tensor) -> None: |
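|  | +        # Combine reverses the dispatch: expert outputs are gathered back to their |
|  | +        # source ranks and written in-place into hidden_states. |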
| 337 | + combine_kwargs = { |
| 338 | + "expand_x": mlp_output, |
| 339 | + "expert_ids": self.topk_ids, |
| 340 | + "expert_scales": self.topk_weights.to(torch.float32), |
| 341 | + "expert_shard_type": 0, |
| 342 | + "shared_expert_rank_num": 0, |
| 343 | + "moe_expert_num": self.global_num_experts, |
| 344 | + "global_bs": 0, |
| 345 | + "ep_send_counts": self.ep_recv_counts, |
| 346 | + "group_ep": self.moe_all_to_all_group_name, |
| 347 | + "ep_world_size": self.ep_world_size, |
| 348 | + "ep_rank_id": self.ep_rank_id, |
| 349 | + } |
| 350 | + |
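|  | +        # The v2 combine kernel accepts the dispatch metadata under a different |
|  | +        # keyword than the v1 kernel. |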
| 351 | + if self.enable_dispatch_v2: |
| 352 | + combine_kwargs[ |
| 353 | + "assist_info_for_combine"] = self.assist_info_for_combine |
| 354 | + else: |
| 355 | + combine_kwargs["expand_idx"] = self.assist_info_for_combine |
| 356 | + |
| 357 | + if self.need_extra_args: |
| 358 | + combine_kwargs.update({ |
| 359 | + "tp_send_counts": self.tp_recv_counts, |
| 360 | + "group_tp": self.moe_all_to_all_group_name, |
| 361 | + "tp_world_size": 1, |
| 362 | + "tp_rank_id": 0, |
| 363 | + }) |
| 364 | + if self.is_ascend_a3 and self.enable_dispatch_v2: |
| 365 | + combine_kwargs.update({ |
| 366 | + "x_active_mask": self.mc2_mask, |
| 367 | + }) |
| 368 | + |
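|  | +        # Use the combine kernel matching the dispatch generation selected above. |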
| 369 | + if self.enable_dispatch_v2: |
| 370 | + hidden_states[:] = torch_npu.npu_moe_distribute_combine_v2( |
| 371 | + **combine_kwargs) |
| 372 | + else: |
| 373 | + hidden_states[:] = torch_npu.npu_moe_distribute_combine( |
| 374 | + **combine_kwargs) |
| 375 | + |
| 376 | + |
232 | 377 | def moe_comm_pre_process( |
233 | 378 | hidden_states: torch.Tensor, |
234 | 379 | topk_ids: torch.Tensor, |
|