@@ -755,61 +755,44 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
755755                layer_norm_eps = getattr (cfg , "layer_norm_eps" , 1e-12 ),
756756            )
757757
758-         weights_list  =  list (weights )
759-         loaded : set [str ] =  set ()
758+         def  _strip (name : str ) ->  str :
759+             for  p  in  ("model." , "bert." ):
760+                 if  name .startswith (p ):
761+                     name  =  name [len (p ) :]
762+             return  name 
760763
764+         weights_list  =  list (weights )
761765        model_side : list [tuple [str , torch .Tensor ]] =  []
766+         mlm_side : list [tuple [str , torch .Tensor ]] =  []
767+ 
762768        for  k , w  in  weights_list :
763-             if  k .startswith ("cls.predictions." ):
764-                 continue 
765-             name  =  k 
766-             if  name .startswith ("model." ):
767-                 name  =  name [len ("model." ) :]
768-             if  name .startswith ("bert." ):
769-                 name  =  name [len ("bert." ) :]
770-             model_side .append ((name , w ))
771- 
772-         other , stacked  =  self .model ._load_weights (model_side )
773-         loaded .update ({"model."  +  n  for  n  in  stacked })
774- 
775-         other_prefixed  =  [("model."  +  n , w ) for  (n , w ) in  other ]
776-         loader_top  =  AutoWeightsLoader (
777-             self , skip_prefixes = ["pooler." , "mlm_head." , "lm_head." ]
778-         )
779-         loaded_other  =  loader_top .load_weights (other_prefixed )
780-         loaded .update (loaded_other )
781- 
782-         name_map  =  {
783-             "cls.predictions.transform.dense.weight" : "mlm_head.dense.weight" ,
784-             "cls.predictions.transform.dense.bias" : "mlm_head.dense.bias" ,
785-             "cls.predictions.transform.LayerNorm.weight" : "mlm_head.layer_norm.weight" ,
786-             "cls.predictions.transform.LayerNorm.bias" : "mlm_head.layer_norm.bias" ,
787-             "cls.predictions.decoder.weight" : "mlm_head.decoder.weight" ,
788-             "cls.predictions.decoder.bias" : "mlm_head.decoder.bias" ,
789-         }
790-         extras : list [tuple [str , torch .Tensor ]] =  []
791-         for  k , w  in  weights_list :
792-             name  =  k 
793-             if  name .startswith ("model." ):
794-                 name  =  name [len ("model." ) :]
795-             if  name .startswith ("bert." ):
796-                 name  =  name [len ("bert." ) :]
797-             tgt  =  name_map .get (name )
798-             if  tgt  is  not None :
799-                 extras .append ((tgt , w ))
800- 
801-         if  extras :
802-             mlm_loader  =  AutoWeightsLoader (self )
803-             loaded_mlm  =  mlm_loader .load_weights (extras )
804-             loaded .update (loaded_mlm )
805- 
806-         try :
807-             emb_w  =  self .model .embeddings .word_embeddings .weight 
808-             dec_w  =  self .mlm_head .decoder .weight 
809-             if  dec_w .shape  ==  emb_w .shape  and  dec_w .data_ptr () !=  emb_w .data_ptr ():
810-                 self .mlm_head .decoder .weight  =  emb_w 
811-         except  Exception :
812-             pass 
769+             name  =  _strip (k )
770+             if  name .startswith ("cls.predictions." ):
771+                 mlm_side .append ((name , w ))
772+             else :
773+                 model_side .append ((name , w ))
774+ 
775+         loaded : set [str ] =  set ()
776+         loaded_model  =  self .model .load_weights (model_side )
777+         loaded .update ({"model."  +  n  for  n  in  loaded_model })
778+ 
779+         if  mlm_side :
780+             name_map  =  {
781+                 "cls.predictions.transform.dense.weight" : "mlm_head.dense.weight" ,
782+                 "cls.predictions.transform.dense.bias" : "mlm_head.dense.bias" ,
783+                 ("cls.predictions.transform.LayerNorm.weight" ): (
784+                     "mlm_head.layer_norm.weight" 
785+                 ),
786+                 ("cls.predictions.transform.LayerNorm.bias" ): (
787+                     "mlm_head.layer_norm.bias" 
788+                 ),
789+                 "cls.predictions.decoder.weight" : "mlm_head.decoder.weight" ,
790+                 "cls.predictions.decoder.bias" : "mlm_head.decoder.bias" ,
791+             }
792+             remapped  =  [(name_map [n ], w ) for  n , w  in  mlm_side  if  n  in  name_map ]
793+             if  remapped :
794+                 loaded_mlm  =  AutoWeightsLoader (self ).load_weights (remapped )
795+                 loaded .update (loaded_mlm )
813796
814797        return  loaded 
815798
0 commit comments