@@ -3241,11 +3241,11 @@ def _get_resized_embeddings(
32413241
32423242 with deepspeed .zero .GatheredParameters ([old_embeddings .weight ], modifier_rank = None ):
32433243 self ._init_added_embeddings_weights_with_mean (
3244- old_embeddings , new_embeddings , old_embedding_dim , old_num_tokens , added_num_tokens
3244+ old_embeddings , new_embeddings , old_num_tokens , added_num_tokens
32453245 )
32463246 else :
32473247 self ._init_added_embeddings_weights_with_mean (
3248- old_embeddings , new_embeddings , old_embedding_dim , old_num_tokens , added_num_tokens
3248+ old_embeddings , new_embeddings , old_num_tokens , added_num_tokens
32493249 )
32503250
32513251 # Copy token embeddings from the previous weights
@@ -3415,7 +3415,7 @@ def _get_resized_lm_head(
34153415 return new_lm_head
34163416
34173417 def _init_added_embeddings_weights_with_mean (
3418- self , old_embeddings , new_embeddings , old_embedding_dim , old_num_tokens , added_num_tokens
3418+ self , old_embeddings , new_embeddings , old_num_tokens , added_num_tokens
34193419 ):
34203420 old_embeddings_weight = old_embeddings .weight .data .to (torch .float32 )
34213421 mean_embeddings = torch .mean (old_embeddings_weight , axis = 0 )
@@ -3454,9 +3454,7 @@ def _init_added_lm_head_weights_with_mean(
34543454 old_lm_head .weight .data = old_lm_head .weight .data .T
34553455
34563456 # The same initialization logic as Embeddings.
3457- self ._init_added_embeddings_weights_with_mean (
3458- old_lm_head , new_lm_head , old_lm_head_dim , old_num_tokens , added_num_tokens
3459- )
3457+ self ._init_added_embeddings_weights_with_mean (old_lm_head , new_lm_head , old_num_tokens , added_num_tokens )
34603458
34613459 if transposed :
34623460 # Transpose again to the correct shape.
0 commit comments