@@ -199,7 +199,7 @@ def __init__(
 
         assert config.moe_num_experts[0] == config.moe_num_experts[1]
         self.e_score_correction_bias = nn.Parameter(
-            torch.empty(2, config.moe_num_experts[0]))
+            torch.empty(2, config.moe_num_experts[0], dtype=torch.float32))
 
         assert text_moe_layer_start_index <= text_moe_layer_end_index
 
@@ -209,6 +209,7 @@ def __init__(
             config.hidden_size,
             config.moe_num_experts[0],
             bias=False,
+            params_dtype=torch.float32,
             quant_config=quant_config,
             prefix=f"{prefix}.text_experts_gate")
 
@@ -238,6 +239,7 @@ def __init__(
             config.hidden_size,
             config.moe_num_experts[1],
             bias=False,
+            params_dtype=torch.float32,
             quant_config=quant_config,
             prefix=f"{prefix}.vision_experts_gate")
 
@@ -288,7 +290,8 @@ def forward(
 
         if visual_token_mask is not None and visual_token_mask.all():
             # only vision modal input
-            router_logits, _ = self.vision_experts_gate(hidden_states)
+            router_logits, _ = self.vision_experts_gate(
+                hidden_states.to(dtype=torch.float32))
             final_hidden_states = self.vision_experts(
                 hidden_states=hidden_states, router_logits=router_logits)
         elif visual_token_mask is not None and visual_token_mask.any():
@@ -303,19 +306,21 @@ def forward(
             vision_hidden_states = hidden_states[visual_token_mask].reshape(
                 -1, self.hidden_size)
 
-            text_router_logits, _ = self.text_experts_gate(text_hidden_states)
+            text_router_logits, _ = self.text_experts_gate(
+                text_hidden_states.to(dtype=torch.float32))
             final_hidden_states[text_token_mask] = self.text_experts(
                 hidden_states=text_hidden_states,
                 router_logits=text_router_logits).flatten()
 
             vision_router_logits, _ = self.vision_experts_gate(
-                vision_hidden_states)
+                vision_hidden_states.to(dtype=torch.float32))
             final_hidden_states[visual_token_mask] = self.vision_experts(
                 hidden_states=vision_hidden_states,
                 router_logits=vision_router_logits).flatten()
         else:
             # only text modal input
-            text_router_logits, _ = self.text_experts_gate(hidden_states)
+            text_router_logits, _ = self.text_experts_gate(
+                hidden_states.to(dtype=torch.float32))
 
             final_hidden_states = self.text_experts(
                 hidden_states=hidden_states, router_logits=text_router_logits)
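
For readers skimming the diff: the change keeps the expert-routing gates (and `e_score_correction_bias`) in float32 and upcasts the activations before computing router logits, so top-k expert selection is not perturbed by bfloat16 rounding. Below is a minimal, standalone PyTorch sketch of that pattern; it does not use vLLM's `ReplicatedLinear`, and the class and variable names are illustrative only.

```python
# Standalone sketch (not vLLM code): keep the MoE router gate in float32 and
# upcast activations before computing router logits, mirroring the diff above.
import torch
import torch.nn as nn


class Float32Gate(nn.Module):
    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        # Gate weights are created and kept in float32, independent of the
        # dtype used for the rest of the model (e.g. bfloat16).
        self.gate = nn.Linear(hidden_size, num_experts, bias=False,
                              dtype=torch.float32)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Upcast the (possibly bfloat16) activations so the router logits and
        # the subsequent top-k selection are computed in full precision.
        return self.gate(hidden_states.to(dtype=torch.float32))


if __name__ == "__main__":
    torch.manual_seed(0)
    gate = Float32Gate(hidden_size=64, num_experts=8)
    x = torch.randn(4, 64, dtype=torch.bfloat16)   # model-dtype activations
    logits = gate(x)                               # float32 router logits
    topk = torch.topk(logits, k=2, dim=-1)
    print(logits.dtype, topk.indices.shape)
```

The expert MLPs themselves still run in the model dtype; only the routing path is promoted, which is why the diff casts the gate inputs rather than the `hidden_states` passed to the experts.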