@@ -341,7 +341,10 @@ def forward(
             # text and vision modals input
             visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool()
             text_token_mask = ~visual_token_mask
-            final_hidden_states = torch.zeros_like(hidden_states)
+            final_experts_hidden_states = torch.zeros_like(hidden_states)
+            final_shared_ouput = (
+                torch.zeros_like(hidden_states) if self.has_shared_experts else None
+            )

             text_hidden_states = hidden_states[text_token_mask].reshape(
                 -1, self.hidden_size
@@ -353,16 +356,26 @@ def forward(
             text_router_logits, _ = self.text_experts_gate(
                 text_hidden_states.to(dtype=torch.float32)
             )
-            final_hidden_states[text_token_mask] = self.text_experts(
+            text_shared_ouput, text_experts_output = self.text_experts(
                 hidden_states=text_hidden_states, router_logits=text_router_logits
-            ).flatten()
+            )
+            final_experts_hidden_states[text_token_mask] = text_experts_output.flatten()
+            if self.has_shared_experts:
+                final_shared_ouput[text_token_mask] = text_shared_ouput.flatten()

             vision_router_logits, _ = self.vision_experts_gate(
                 vision_hidden_states.to(dtype=torch.float32)
             )
-            final_hidden_states[visual_token_mask] = self.vision_experts(
+            vision_shared_ouput, vision_experts_output = self.vision_experts(
                 hidden_states=vision_hidden_states, router_logits=vision_router_logits
-            ).flatten()
+            )
+            final_experts_hidden_states[visual_token_mask] = (
+                vision_experts_output.flatten()
+            )
+            if self.has_shared_experts:
+                final_shared_ouput[visual_token_mask] = vision_shared_ouput.flatten()
+
+            final_hidden_states = (final_shared_ouput, final_experts_hidden_states)
         else:
             # only text modal input
             text_router_logits, _ = self.text_experts_gate(
@@ -374,7 +387,11 @@ def forward(
             )

         if self.has_shared_experts:
+            # model with shared experts
             final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
+        else:
+            # model without shared experts
+            final_hidden_states = final_hidden_states[1]

         if self.tp_size > 1:
             final_hidden_states = (
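
For context, a minimal, self-contained sketch of the routing pattern the new code follows: the routed-expert output and the optional shared-expert output are scattered back into their per-modality token positions separately and only summed at the end. The `moe_forward_sketch` function, the `run_experts` helper, the toy gating math, and the tensor shapes below are hypothetical placeholders for illustration, not the real `text_experts` / `vision_experts` modules.

```python
import torch


def moe_forward_sketch(
    hidden_states: torch.Tensor,        # (num_tokens, hidden_size)
    visual_token_mask: torch.Tensor,    # (num_tokens, 1), 1 for vision tokens
    has_shared_experts: bool = True,
) -> torch.Tensor:
    """Toy version of the mixed text/vision MoE dispatch in this diff."""
    hidden_size = hidden_states.size(-1)
    visual_mask = visual_token_mask.repeat(1, hidden_size).bool()
    text_mask = ~visual_mask

    # separate accumulators for routed-expert and shared-expert outputs
    final_experts = torch.zeros_like(hidden_states)
    final_shared = torch.zeros_like(hidden_states) if has_shared_experts else None

    def run_experts(tokens: torch.Tensor):
        # stand-in for the real MoE call; returns (shared_output, experts_output)
        shared = tokens * 0.5 if has_shared_experts else None
        experts = tokens * 2.0
        return shared, experts

    # dispatch each modality through its own experts and scatter results back
    for mask in (text_mask, visual_mask):
        tokens = hidden_states[mask].reshape(-1, hidden_size)
        shared_out, experts_out = run_experts(tokens)
        final_experts[mask] = experts_out.flatten()
        if has_shared_experts:
            final_shared[mask] = shared_out.flatten()

    # recombine as in the diff's tail: add the shared output only if present
    return final_shared + final_experts if has_shared_experts else final_experts


# tiny usage example with made-up shapes
x = torch.randn(4, 8)                        # 4 tokens, hidden size 8
vis = torch.tensor([[0], [1], [1], [0]])     # tokens 1 and 2 are vision tokens
print(moe_forward_sketch(x, vis).shape)      # torch.Size([4, 8])
```

Keeping the two accumulators apart until the final sum mirrors why the diff replaces the single `final_hidden_states` buffer with `final_experts_hidden_states` plus `final_shared_ouput`: the shared-expert contribution exists only when `has_shared_experts` is true.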