diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py
index 20dafdf7ac..a33a69b80a 100644
--- a/vllm_ascend/models/deepseek_dbo.py
+++ b/vllm_ascend/models/deepseek_dbo.py
@@ -154,47 +154,6 @@ def __init__(
         CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok
 
         self.config = config
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
-        forward_context = get_forward_context()
-        if attn_metadata is None:
-            attn_metadata = forward_context.attn_metadata
-
-        # when profile runs, force experts to load balanced tokens
-        # to avoid high memory consumption on a single rank.
-        enable_force_load_balance = forward_context.in_profile_run
-
-        is_prefill = forward_context.with_prefill
-        # If this node is kv_consumer, we force the moe always runs in decode path to make sure
-        # the behaviour aligned between dummy_run and normal model_execute.
-        if self.kv_consumer:
-            is_prefill = False
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        experts_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=CustomDeepseekDBOMoE.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-            shared_experts=self.shared_experts)
-
-        shared_experts_hidden = experts_hidden_states[1]
-        if not (self.shared_experts.down_proj.reduce_results
-                and self.shared_experts.down_proj.tp_size > 1):
-            shared_experts_hidden = tensor_model_parallel_all_reduce(
-                shared_experts_hidden)
-
-        hidden_states = (
-            experts_hidden_states[0] * self.routed_scaling_factor +
-            shared_experts_hidden)
-
-        return hidden_states
-
     # ----------------------------------------- TBO-related --------------------------------------------
     def _forward_ms_op_shared_expert(
         self,
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 3fa9c8be74..d2ce4114f4 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -1266,6 +1266,9 @@ def forward(
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
                 shared_hidden_states = shared_experts(hidden_states)
+                if not shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
+                    shared_hidden_states = tensor_model_parallel_all_reduce(
+                        shared_hidden_states)
 
         mc2_mask = forward_context.mc2_mask
         tp_size = get_tensor_model_parallel_world_size()
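
For readers skimming the patch: the change removes CustomDeepseekDBOMoE.forward from deepseek_dbo.py and instead performs the shared-expert reduction inside the fused-MoE forward in vllm_ascend/ops/fused_moe.py, right where shared_experts is invoked. The rule added there is: all-reduce the shared-expert output across tensor-parallel ranks only when the shared experts' down_proj did not already reduce it (reduce_results is False) and the weights are actually sharded (tp_size > 1). The following is a minimal, self-contained sketch of that rule, not the vllm-ascend implementation; SharedExpertsStub, DownProjStub, fake_all_reduce, and run_shared_experts are illustrative stand-ins.

# Minimal sketch of the reduction rule added in fused_moe.py.
# All names below except `reduce_results` and `tp_size` are stand-ins.
import torch


def fake_all_reduce(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for tensor_model_parallel_all_reduce; a real run would sum the
    # partial results held by each tensor-parallel rank.
    return t


class DownProjStub:
    def __init__(self, reduce_results: bool, tp_size: int):
        self.reduce_results = reduce_results
        self.tp_size = tp_size


class SharedExpertsStub:
    def __init__(self, reduce_results: bool, tp_size: int):
        self.down_proj = DownProjStub(reduce_results, tp_size)

    def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Identity stand-in for the shared-expert MLP.
        return hidden_states


def run_shared_experts(shared_experts: SharedExpertsStub,
                       hidden_states: torch.Tensor) -> torch.Tensor:
    shared_hidden_states = shared_experts(hidden_states)
    # Mirrors the condition added in fused_moe.py: reduce only if the down
    # projection skipped its own reduction and the weights are TP-sharded.
    if not shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
        shared_hidden_states = fake_all_reduce(shared_hidden_states)
    return shared_hidden_states


if __name__ == "__main__":
    x = torch.randn(4, 8)
    out = run_shared_experts(SharedExpertsStub(reduce_results=False, tp_size=2), x)
    print(out.shape)  # torch.Size([4, 8])

Keeping the check next to the shared_experts call presumably lets every model that routes through this fused-MoE path get the reduction for free, instead of each model-side MoE wrapper (such as the deleted CustomDeepseekDBOMoE.forward) having to repeat it.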