@@ -199,7 +199,7 @@ def __init__(
 
         assert config.moe_num_experts[0] == config.moe_num_experts[1]
         self.e_score_correction_bias = nn.Parameter(
-            torch.empty(2, config.moe_num_experts[0]))
+            torch.empty(2, config.moe_num_experts[0], dtype=torch.float32))
 
         assert text_moe_layer_start_index <= text_moe_layer_end_index
 
@@ -209,6 +209,7 @@ def __init__(
                 config.hidden_size,
                 config.moe_num_experts[0],
                 bias=False,
+                params_dtype=torch.float32,
                 quant_config=quant_config,
                 prefix=f"{prefix})
 
@@ -238,6 +239,7 @@ def __init__(
                 config.hidden_size,
                 config.moe_num_experts[1],
                 bias=False,
+                params_dtype=torch.float32,
                 quant_config=quant_config,
                 prefix=f"{prefix})
 
@@ -288,7 +290,8 @@ def forward(
 
         if visual_token_mask is not None and visual_token_mask.all():
             # only vision modal input
-            router_logits, _ = self.vision_experts_gate(hidden_states)
+            router_logits, _ = self.vision_experts_gate(
+                hidden_states.to(dtype=torch.float32))
             final_hidden_states = self.vision_experts(
                 hidden_states=hidden_states, router_logits=router_logits)
         elif visual_token_mask is not None and visual_token_mask.any():
@@ -303,19 +306,21 @@ def forward(
             vision_hidden_states = hidden_states[visual_token_mask].reshape(
                 -1, self.hidden_size)
 
-            text_router_logits, _ = self.text_experts_gate(text_hidden_states)
+            text_router_logits, _ = self.text_experts_gate(
+                text_hidden_states.to(dtype=torch.float32))
             final_hidden_states[text_token_mask] = self.text_experts(
                 hidden_states=text_hidden_states,
                 router_logits=text_router_logits).flatten()
 
             vision_router_logits, _ = self.vision_experts_gate(
-                vision_hidden_states)
+                vision_hidden_states.to(dtype=torch.float32))
             final_hidden_states[visual_token_mask] = self.vision_experts(
                 hidden_states=vision_hidden_states,
                 router_logits=vision_router_logits).flatten()
         else:
             # only text modal input
-            text_router_logits, _ = self.text_experts_gate(hidden_states)
+            text_router_logits, _ = self.text_experts_gate(
+                hidden_states.to(dtype=torch.float32))
 
             final_hidden_states = self.text_experts(
                 hidden_states=hidden_states, router_logits=text_router_logits)
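
Taken together, the patch keeps MoE routing in full precision: each experts gate is created with params_dtype=torch.float32 and the hidden states are cast to float32 before computing router logits, while the experts themselves keep the model's working dtype. A minimal, framework-free sketch of that pattern, assuming a plain nn.Linear gate (SimpleGate, hidden_size, and num_experts are illustrative names, not the vLLM implementation):

import torch
import torch.nn as nn


class SimpleGate(nn.Module):
    """Hypothetical stand-in for an experts gate; not the patched module."""

    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        # Gate weights live in float32, mirroring params_dtype=torch.float32
        # in the diff above.
        self.gate = nn.Linear(hidden_size, num_experts, bias=False,
                              dtype=torch.float32)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Cast activations to float32 before computing router logits,
        # mirroring hidden_states.to(dtype=torch.float32) in the diff above.
        return self.gate(hidden_states.to(dtype=torch.float32))


gate = SimpleGate(hidden_size=64, num_experts=8)
x = torch.randn(4, 64, dtype=torch.bfloat16)   # model activations in bf16
router_logits = gate(x)
assert router_logits.dtype == torch.float32    # routing decided in fp32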