
Commit d9d94d2

yewentao256 authored and rtourgeman committed
[Bug] Fix DeepEP low latency assert self.batched_router_logits.size(-1) == full_router_logits.size(-1) Bug (vllm-project#27682)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 8f50c0c · commit d9d94d2

1 file changed, +3 −3 lines: vllm/model_executor/layers/fused_moe/layer.py
@@ -1135,6 +1135,7 @@ def __init__(
         )

         self.global_num_experts = num_experts + num_redundant_experts
+        self.logical_num_experts = num_experts
         self.zero_expert_num = zero_expert_num
         self.zero_expert_type = zero_expert_type

@@ -1998,13 +1999,12 @@ def ensure_dp_chunking_init(self):

         moe = self.moe_config

-        # Note here we use `num_experts` which is logical expert count
         if self.vllm_config.parallel_config.enable_dbo:
             states_shape = (2, moe.max_num_tokens, self.hidden_size)
-            logits_shape = (2, moe.max_num_tokens, moe.num_experts)
+            logits_shape = (2, moe.max_num_tokens, self.logical_num_experts)
         else:
             states_shape = (moe.max_num_tokens, self.hidden_size)
-            logits_shape = (moe.max_num_tokens, moe.num_experts)
+            logits_shape = (moe.max_num_tokens, self.logical_num_experts)

         self.batched_hidden_states = torch.zeros(
             states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()