@@ -755,61 +755,40 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
755755 layer_norm_eps = getattr (cfg , "layer_norm_eps" , 1e-12 ),
756756 )
757757
758- weights_list = list (weights )
759- loaded : set [str ] = set ()
758+ def _strip (name : str ) -> str :
759+ for p in ("model." , "bert." ):
760+ if name .startswith (p ):
761+ name = name [len (p ) :]
762+ return name
760763
764+ weights_list = list (weights )
761765 model_side : list [tuple [str , torch .Tensor ]] = []
766+ mlm_side : list [tuple [str , torch .Tensor ]] = []
767+
762768 for k , w in weights_list :
763- if k .startswith ("cls.predictions." ):
764- continue
765- name = k
766- if name .startswith ("model." ):
767- name = name [len ("model." ) :]
768- if name .startswith ("bert." ):
769- name = name [len ("bert." ) :]
770- model_side .append ((name , w ))
771-
772- other , stacked = self .model ._load_weights (model_side )
773- loaded .update ({"model." + n for n in stacked })
774-
775- other_prefixed = [("model." + n , w ) for (n , w ) in other ]
776- loader_top = AutoWeightsLoader (
777- self , skip_prefixes = ["pooler." , "mlm_head." , "lm_head." ]
778- )
779- loaded_other = loader_top .load_weights (other_prefixed )
780- loaded .update (loaded_other )
781-
782- name_map = {
783- "cls.predictions.transform.dense.weight" : "mlm_head.dense.weight" ,
784- "cls.predictions.transform.dense.bias" : "mlm_head.dense.bias" ,
785- "cls.predictions.transform.LayerNorm.weight" : "mlm_head.layer_norm.weight" ,
786- "cls.predictions.transform.LayerNorm.bias" : "mlm_head.layer_norm.bias" ,
787- "cls.predictions.decoder.weight" : "mlm_head.decoder.weight" ,
788- "cls.predictions.decoder.bias" : "mlm_head.decoder.bias" ,
789- }
790- extras : list [tuple [str , torch .Tensor ]] = []
791- for k , w in weights_list :
792- name = k
793- if name .startswith ("model." ):
794- name = name [len ("model." ) :]
795- if name .startswith ("bert." ):
796- name = name [len ("bert." ) :]
797- tgt = name_map .get (name )
798- if tgt is not None :
799- extras .append ((tgt , w ))
800-
801- if extras :
802- mlm_loader = AutoWeightsLoader (self )
803- loaded_mlm = mlm_loader .load_weights (extras )
804- loaded .update (loaded_mlm )
805-
806- try :
807- emb_w = self .model .embeddings .word_embeddings .weight
808- dec_w = self .mlm_head .decoder .weight
809- if dec_w .shape == emb_w .shape and dec_w .data_ptr () != emb_w .data_ptr ():
810- self .mlm_head .decoder .weight = emb_w
811- except Exception :
812- pass
769+ name = _strip (k )
770+ if name .startswith ("cls.predictions." ):
771+ mlm_side .append ((name , w ))
772+ else :
773+ model_side .append ((name , w ))
774+
775+ loaded : set [str ] = set ()
776+ loaded_model = self .model .load_weights (model_side )
777+ loaded .update ({"model." + n for n in loaded_model })
778+
779+ if mlm_side :
780+ name_map = {
781+ "cls.predictions.transform.dense.weight" : "mlm_head.dense.weight" ,
782+ "cls.predictions.transform.dense.bias" : "mlm_head.dense.bias" ,
783+ "cls.predictions.transform.LayerNorm.weight" : "mlm_head.layer_norm.weight" ,
784+ "cls.predictions.transform.LayerNorm.bias" : "mlm_head.layer_norm.bias" ,
785+ "cls.predictions.decoder.weight" : "mlm_head.decoder.weight" ,
786+ "cls.predictions.decoder.bias" : "mlm_head.decoder.bias" ,
787+ }
788+ remapped = [(name_map [n ], w ) for n , w in mlm_side if n in name_map ]
789+ if remapped :
790+ loaded_mlm = AutoWeightsLoader (self ).load_weights (remapped )
791+ loaded .update (loaded_mlm )
813792
814793 return loaded
815794
0 commit comments