| 20 | 20 | import torch | 
| 21 | 21 | import torch_npu | 
| 22 | 22 | from vllm.config import CompilationLevel, get_current_vllm_config | 
| 23 |  | -from vllm.distributed import get_dp_group, get_ep_group, get_tp_group | 
|  | 23 | +from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group, | 
|  | 24 | +                              tensor_model_parallel_all_reduce) | 
| 24 | 25 | from vllm.forward_context import get_forward_context | 
| 25 | 26 | from vllm.model_executor.layers.fused_moe.config import \ | 
| 26 | 27 |     FusedMoEParallelConfig  # isort: skip | 
| @@ -373,6 +374,21 @@ def __init__( | 
| 373 | 374 |                 self, method.__name__.lower(), | 
| 374 | 375 |                 method(moe_config=self.moe_config))  # type: ignore[abstract] | 
| 375 | 376 | 
|  | 377 | +    def maybe_all_reduce_tensor_model_parallel( | 
|  | 378 | +            self, final_hidden_states: torch.Tensor): | 
|  | 379 | +        """NOTE(Yizhou): This overrides the parent class method. In `mc2commimpl` | 
|  | 380 | +        and `alltoallcommimpl`, we do not need to all-reduce the final outputs since | 
|  | 381 | +        the outputs are already aggregated across tensor parallel ranks in the | 
|  | 382 | +        `finalize` function. In `allgathercommimpl`, we still need to all-reduce the | 
|  | 383 | +        outputs since each rank only has partial outputs. | 
|  | 384 | +        """ | 
|  | 385 | +        forward_context = get_forward_context() | 
|  | 386 | +        moe_comm_method_name = forward_context.moe_comm_method_name | 
|  | 387 | +        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}: | 
|  | 388 | +            return final_hidden_states | 
|  | 389 | +        else: | 
|  | 390 | +            return tensor_model_parallel_all_reduce(final_hidden_states) | 
|  | 391 | + | 
| 376 | 392 |     def forward_impl(self, hidden_states: torch.Tensor, | 
| 377 | 393 |                      router_logits: torch.Tensor): | 
| 378 | 394 |         assert self.quant_method is not None | 
| @@ -415,6 +431,38 @@ def forward_impl(self, hidden_states: torch.Tensor, | 
| 415 | 431 |         return final_hidden_states | 
| 416 | 432 | 
| 417 | 433 | 
|  | 434 | +class AscendSharedFusedMoE(AscendFusedMoE): | 
|  | 435 | + | 
|  | 436 | +    def __init__( | 
|  | 437 | +        self, | 
|  | 438 | +        shared_experts: torch.nn.Module, | 
|  | 439 | +        use_overlapped: bool = True, | 
|  | 440 | +        **kwargs, | 
|  | 441 | +    ): | 
|  | 442 | +        super().__init__(**kwargs) | 
|  | 443 | +        self._shared_experts = shared_experts | 
|  | 444 | +        self.use_overlapped = use_overlapped | 
|  | 445 | + | 
|  | 446 | +    def forward( | 
|  | 447 | +        self, | 
|  | 448 | +        hidden_states: torch.Tensor, | 
|  | 449 | +        router_logits: torch.Tensor, | 
|  | 450 | +    ) -> tuple[torch.Tensor, torch.Tensor]: | 
|  | 451 | +        shared_out = self._shared_experts(hidden_states) | 
|  | 452 | + | 
|  | 453 | +        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` | 
|  | 454 | +        forward_context = get_forward_context() | 
|  | 455 | +        moe_comm_method_name = forward_context.moe_comm_method_name | 
|  | 456 | +        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}: | 
|  | 457 | +            shared_out = tensor_model_parallel_all_reduce(shared_out) | 
|  | 458 | + | 
|  | 459 | +        fused_out = super().forward( | 
|  | 460 | +            hidden_states=hidden_states, | 
|  | 461 | +            router_logits=router_logits, | 
|  | 462 | +        ) | 
|  | 463 | +        return shared_out, fused_out | 
|  | 464 | + | 
|  | 465 | + | 
| 418 | 466 | UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func | 
| 419 | 467 | UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading | 
| 420 | 468 | 
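
For reference, a minimal usage sketch of the new `AscendSharedFusedMoE` class added above. The wrapper module, its name, the `moe_kwargs` it forwards, and the final `shared_out + fused_out` combination are illustrative assumptions, not part of this commit; only `AscendSharedFusedMoE` itself comes from the diff.

    import torch

    # Hypothetical MoE block (names are illustrative, not from this commit) that
    # pairs always-on shared experts with routed experts via AscendSharedFusedMoE.
    # Assume AscendSharedFusedMoE is imported from the module patched in this diff.
    class ExampleSharedMoEBlock(torch.nn.Module):

        def __init__(self, shared_experts: torch.nn.Module, **moe_kwargs):
            super().__init__()
            # shared_experts runs on every token; routed experts go through the
            # fused-MoE path inherited from AscendFusedMoE.
            self.experts = AscendSharedFusedMoE(shared_experts=shared_experts,
                                                **moe_kwargs)

        def forward(self, hidden_states: torch.Tensor,
                    router_logits: torch.Tensor) -> torch.Tensor:
            # forward() returns (shared_out, fused_out). Per the NOTE in the diff,
            # shared_out is all-reduced inside AscendSharedFusedMoE only for the
            # mc2 / all-to-all comm implementations, while the routed output's
            # all-reduce is decided by maybe_all_reduce_tensor_model_parallel.
            shared_out, fused_out = self.experts(hidden_states, router_logits)
            return shared_out + fused_out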