Skip to content

Commit e43027c

Browse files
CSWYF3634076 authored and
albertoperdomo2 committed
[Model][Bugfix] fix ernie45 vl run failed from shared experts optimization (vllm-project#26885)
Signed-off-by: wangyafeng <wangyafeng@baidu.com> Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
1 parent 3bdd404 commit e43027c

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

vllm/model_executor/models/ernie45_vl_moe.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,10 @@ def forward(
341341
# text and vision modals input
342342
visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool()
343343
text_token_mask = ~visual_token_mask
344-
final_hidden_states = torch.zeros_like(hidden_states)
344+
final_experts_hidden_states = torch.zeros_like(hidden_states)
345+
final_shared_ouput = (
346+
torch.zeros_like(hidden_states) if self.has_shared_experts else None
347+
)
345348

346349
text_hidden_states = hidden_states[text_token_mask].reshape(
347350
-1, self.hidden_size
@@ -353,16 +356,26 @@ def forward(
353356
text_router_logits, _ = self.text_experts_gate(
354357
text_hidden_states.to(dtype=torch.float32)
355358
)
356-
final_hidden_states[text_token_mask] = self.text_experts(
359+
text_shared_ouput, text_experts_output = self.text_experts(
357360
hidden_states=text_hidden_states, router_logits=text_router_logits
358-
).flatten()
361+
)
362+
final_experts_hidden_states[text_token_mask] = text_experts_output.flatten()
363+
if self.has_shared_experts:
364+
final_shared_ouput[text_token_mask] = text_shared_ouput.flatten()
359365

360366
vision_router_logits, _ = self.vision_experts_gate(
361367
vision_hidden_states.to(dtype=torch.float32)
362368
)
363-
final_hidden_states[visual_token_mask] = self.vision_experts(
369+
vision_shared_ouput, vision_experts_output = self.vision_experts(
364370
hidden_states=vision_hidden_states, router_logits=vision_router_logits
365-
).flatten()
371+
)
372+
final_experts_hidden_states[visual_token_mask] = (
373+
vision_experts_output.flatten()
374+
)
375+
if self.has_shared_experts:
376+
final_shared_ouput[visual_token_mask] = vision_shared_ouput.flatten()
377+
378+
final_hidden_states = (final_shared_ouput, final_experts_hidden_states)
366379
else:
367380
# only text modal input
368381
text_router_logits, _ = self.text_experts_gate(
@@ -374,7 +387,11 @@ def forward(
374387
)
375388

376389
if self.has_shared_experts:
390+
# for shared_experts model
377391
final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
392+
else:
393+
# for not shared_experts model
394+
final_hidden_states = final_hidden_states[1]
378395

379396
if self.tp_size > 1:
380397
final_hidden_states = (

0 commit comments

Comments
 (0)