vllm_ascend/models/deepseek_dbo.py (0 additions & 41 deletions)
@@ -154,47 +154,6 @@ def __init__(
         CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok
         self.config = config
 
-    def forward(
-            self,
-            hidden_states: torch.Tensor,
-            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
-        forward_context = get_forward_context()
-        if attn_metadata is None:
-            attn_metadata = forward_context.attn_metadata
-
-        # when profile runs, force experts to load balanced tokens
-        # to avoid high memory consumption on a single rank.
-        enable_force_load_balance = forward_context.in_profile_run
-
-        is_prefill = forward_context.with_prefill
-        # If this node is kv_consumer, we force the moe always runs in decode path to make sure
-        # the behaviour aligned between dummy_run and normal model_execute.
-        if self.kv_consumer:
-            is_prefill = False
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        experts_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=CustomDeepseekDBOMoE.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-            shared_experts=self.shared_experts)
-
-        shared_experts_hidden = experts_hidden_states[1]
-        if not (self.shared_experts.down_proj.reduce_results
-                and self.shared_experts.down_proj.tp_size > 1):
-            shared_experts_hidden = tensor_model_parallel_all_reduce(
-                shared_experts_hidden)
-
-        hidden_states = (
-            experts_hidden_states[0] * self.routed_scaling_factor +
-            shared_experts_hidden)
-
-        return hidden_states
-
     # ----------------------------------------- TBO-related --------------------------------------------
     def _forward_ms_op_shared_expert(
             self,
vllm_ascend/ops/fused_moe.py (3 additions & 0 deletions)
@@ -1266,6 +1266,9 @@ def forward(
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
                 shared_hidden_states = shared_experts(hidden_states)
+                if not shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
+                    shared_hidden_states = tensor_model_parallel_all_reduce(
+                        shared_hidden_states)
 
         mc2_mask = forward_context.mc2_mask
         tp_size = get_tensor_model_parallel_world_size()
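
For readers skimming the diff: the three lines added to vllm_ascend/ops/fused_moe.py all-reduce the shared experts' output across the tensor-parallel group when the down projection is configured not to reduce its own results and more than one TP rank holds a partial sum, apparently taking over the reduction previously performed in the now-deleted CustomDeepseekDBOMoE.forward. Below is a minimal, self-contained sketch of just that branch; DownProjStub and fake_all_reduce are illustrative stand-ins (not part of vLLM or vllm-ascend) for the attributes and the tensor_model_parallel_all_reduce collective the real code uses.

from dataclasses import dataclass

import torch


@dataclass
class DownProjStub:
    """Illustrative stand-in exposing the two attributes the added branch checks."""
    reduce_results: bool
    tp_size: int


def fake_all_reduce(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for tensor_model_parallel_all_reduce; the real call sums the
    # partial results held by each tensor-parallel rank.
    return t


def maybe_reduce_shared_output(hidden: torch.Tensor,
                               down_proj: DownProjStub) -> torch.Tensor:
    # Mirror of the added branch: only all-reduce when the down projection
    # did not already reduce and more than one TP rank holds a partial sum.
    if not down_proj.reduce_results and down_proj.tp_size > 1:
        hidden = fake_all_reduce(hidden)
    return hidden


if __name__ == "__main__":
    shared_out = torch.randn(4, 8)
    reduced = maybe_reduce_shared_output(
        shared_out, DownProjStub(reduce_results=False, tp_size=2))
    print(reduced.shape)  # torch.Size([4, 8])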