@@ -160,14 +160,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                 hidden_states=hidden_states,
                 router_logits=router_logits) * self.routed_scaling_factor
         else:
-            # This is a special case to avoid FP16 overflow
+            # Fix FP16 overflow
+            # See DeepseekV2DecoderLayer for more details.
             final_hidden_states = self.experts(hidden_states=hidden_states,
                                                router_logits=router_logits)
         if shared_output is not None:
             if hidden_states.dtype != torch.float16:
                 final_hidden_states = final_hidden_states + shared_output
             else:
-                # This is a special case to avoid FP16 overflow
+                # Fix FP16 overflow
+                # See DeepseekV2DecoderLayer for more details.
                 final_hidden_states = final_hidden_states + shared_output \
                     * (1. / self.routed_scaling_factor)
         if self.tp_size > 1:
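A quick illustration of why the `* self.routed_scaling_factor` multiply is skipped on the FP16 path above: FP16 tops out at 65504, so scaling a large expert output in place can overflow to `inf`. A minimal sketch (the activation value and the factor 2.5 are stand-ins, not values from the model config):

```python
import torch

# FP16 max is ~65504, so multiplying a large activation by the routed
# scaling factor can overflow to inf; the code above therefore defers
# the scaling on the FP16 path and compensates for it elsewhere.
x = torch.tensor([40000.0], dtype=torch.float16)
print(torch.finfo(torch.float16).max)  # 65504.0
print(x * 2.5)                         # tensor([inf], dtype=torch.float16)
```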
@@ -499,6 +501,7 @@ def __init__(
         # DecoderLayers are created with `make_layers` which passes the prefix
         # with the layer's index.
         layer_idx = int(prefix.split(sep='.')[-1])
+        self.layer_idx = layer_idx
         if model_config.use_mla:
             attn_cls = DeepseekV2MLAAttention
         else:
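For reference, the `layer_idx` recovered here (and now stored on `self` so `forward` can check for the first layer) comes from the dotted name that `make_layers` builds. A minimal sketch with a hypothetical prefix:

```python
# Hypothetical prefix of the form make_layers passes in, e.g.
# "model.layers.7"; the trailing component is the layer's index.
prefix = "model.layers.7"
layer_idx = int(prefix.split(sep='.')[-1])
print(layer_idx)  # 7
```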
@@ -561,19 +564,30 @@ def forward(
             hidden_states=hidden_states,
         )

-        # Fully Connected
-        if isinstance(self.mlp, DeepseekV2MoE) and \
-                hidden_states.dtype == torch.float16:
-            # This is a special case to avoid FP16 overflow
+        if hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # We scale both hidden_states and residual before rmsnorm;
+            # the rmsnorm output is not affected by the scale.
             hidden_states *= 1. / self.routed_scaling_factor
+            if self.layer_idx == 0:
+                # The residual is shared across all layers, so we only
+                # scale it on the first layer.
+                residual *= 1. / self.routed_scaling_factor
+
+        # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        if isinstance(self.mlp, DeepseekV2MLP) and \
-                hidden_states.dtype == torch.float16:
-            # This is a special case to avoid FP16 overflow
+
+        if isinstance(self.mlp,
+                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # Scale the DeepseekV2MLP output; it is the input to the
+            # input_layernorm of the next decoder layer.
+            # The scaling of the DeepseekV2MoE output is done in
+            # DeepseekV2MoE's forward.
             hidden_states *= 1. / self.routed_scaling_factor
-        residual *= 1. / self.routed_scaling_factor
+
         return hidden_states, residual

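The comments above lean on RMSNorm being invariant to a uniform positive rescaling of its input: dividing both `hidden_states` and `residual` by the same factor before the fused add + norm leaves the normalized output unchanged, while the scale rides along in the returned residual. A minimal sketch checking this with a plain, standalone RMSNorm (`s` is a stand-in for `routed_scaling_factor`):

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Plain RMSNorm: x / rms(x) * weight. Multiplying x by any positive
    # constant cancels out, up to the small eps term.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

torch.manual_seed(0)
h, r = torch.randn(4, 8), torch.randn(4, 8)
w = torch.ones(8)
s = 16.0  # stand-in for routed_scaling_factor

# Fused add + norm, as post_attention_layernorm computes: norm(h + r).
out = rms_norm(h + r, w)
# Same computation with both inputs pre-scaled by 1/s.
out_scaled = rms_norm(h / s + r / s, w)

print(torch.allclose(out, out_scaled, atol=1e-3))  # True (up to eps)
```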