@@ -640,13 +640,12 @@ def forward(
         max_len: int = int(lens_tensor.max().item())
 
         if isinstance(hidden_states, list):
-            hidden_states = torch.cat(hidden_states, dim=0)
-            assert isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 2
-
-            device = hidden_states.device
-            H = int(self.mlm_head.dense.in_features)
+            hs_list = hidden_states
+        else:
+            hs_list = torch.split(hidden_states, lens, dim=0)
 
-        hs_list = torch.split(hidden_states, lens, dim=0)
+        device = hs_list[0].device
+        H = hs_list[0].size(-1)
 
         padded = hidden_states.new_zeros((B, max_len, H))
         valid_mask = torch.zeros((B, max_len), dtype=torch.bool, device=device)
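Note: for readers skimming this hunk, here is a minimal standalone sketch of the padding path it refactors, assuming per-sequence hidden states of shape (len_i, H). The helper name pad_hidden_states and the bare lens list are illustrative, not part of this file.

# Minimal sketch (assumed names/shapes): pad a batch of per-sequence hidden
# states into (B, max_len, H) plus a validity mask, the same layout the
# refactored branch above produces via hs_list.
from __future__ import annotations

import torch


def pad_hidden_states(
    hidden_states: torch.Tensor | list[torch.Tensor], lens: list[int]
) -> tuple[torch.Tensor, torch.Tensor]:
    # Accept either a pre-split list or one flat (sum(lens), H) tensor.
    if isinstance(hidden_states, list):
        hs_list = hidden_states
    else:
        hs_list = list(torch.split(hidden_states, lens, dim=0))

    device = hs_list[0].device
    H = hs_list[0].size(-1)
    B = len(hs_list)
    max_len = max(lens)

    padded = hs_list[0].new_zeros((B, max_len, H))
    valid_mask = torch.zeros((B, max_len), dtype=torch.bool, device=device)
    for i, (h, n) in enumerate(zip(hs_list, lens)):
        padded[i, :n] = h
        valid_mask[i, :n] = True
    return padded, valid_mask

The refactor keeps the per-sequence tensors as a list (hs_list) instead of concatenating and re-splitting, and reads device and H from the first chunk rather than from the MLM head.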
@@ -755,61 +754,44 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
         )
 
-        weights_list = list(weights)
-        loaded: set[str] = set()
+        def _strip(name: str) -> str:
+            for p in ("model.", "bert."):
+                if name.startswith(p):
+                    name = name[len(p) :]
+            return name
 
+        weights_list = list(weights)
         model_side: list[tuple[str, torch.Tensor]] = []
+        mlm_side: list[tuple[str, torch.Tensor]] = []
+
         for k, w in weights_list:
-            if k.startswith("cls.predictions."):
-                continue
-            name = k
-            if name.startswith("model."):
-                name = name[len("model.") :]
-            if name.startswith("bert."):
-                name = name[len("bert.") :]
-            model_side.append((name, w))
-
-        other, stacked = self.model._load_weights(model_side)
-        loaded.update({"model." + n for n in stacked})
-
-        other_prefixed = [("model." + n, w) for (n, w) in other]
-        loader_top = AutoWeightsLoader(
-            self, skip_prefixes=["pooler.", "mlm_head.", "lm_head."]
-        )
-        loaded_other = loader_top.load_weights(other_prefixed)
-        loaded.update(loaded_other)
-
-        name_map = {
-            "cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
-            "cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
-            "cls.predictions.transform.LayerNorm.weight": "mlm_head.layer_norm.weight",
-            "cls.predictions.transform.LayerNorm.bias": "mlm_head.layer_norm.bias",
-            "cls.predictions.decoder.weight": "mlm_head.decoder.weight",
-            "cls.predictions.decoder.bias": "mlm_head.decoder.bias",
-        }
-        extras: list[tuple[str, torch.Tensor]] = []
-        for k, w in weights_list:
-            name = k
-            if name.startswith("model."):
-                name = name[len("model.") :]
-            if name.startswith("bert."):
-                name = name[len("bert.") :]
-            tgt = name_map.get(name)
-            if tgt is not None:
-                extras.append((tgt, w))
-
-        if extras:
-            mlm_loader = AutoWeightsLoader(self)
-            loaded_mlm = mlm_loader.load_weights(extras)
-            loaded.update(loaded_mlm)
-
-        try:
-            emb_w = self.model.embeddings.word_embeddings.weight
-            dec_w = self.mlm_head.decoder.weight
-            if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr():
-                self.mlm_head.decoder.weight = emb_w
-        except Exception:
-            pass
+            name = _strip(k)
+            if name.startswith("cls.predictions."):
+                mlm_side.append((name, w))
+            else:
+                model_side.append((name, w))
+
+        loaded: set[str] = set()
+        loaded_model = self.model.load_weights(model_side)
+        loaded.update({"model." + n for n in loaded_model})
+
+        if mlm_side:
+            name_map = {
+                "cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
+                "cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
+                "cls.predictions.transform.LayerNorm.weight": (
+                    "mlm_head.layer_norm.weight"
+                ),
+                "cls.predictions.transform.LayerNorm.bias": (
+                    "mlm_head.layer_norm.bias"
+                ),
+                "cls.predictions.decoder.weight": "mlm_head.decoder.weight",
+                "cls.predictions.decoder.bias": "mlm_head.decoder.bias",
+            }
+            remapped = [(name_map[n], w) for n, w in mlm_side if n in name_map]
+            if remapped:
+                loaded_mlm = AutoWeightsLoader(self).load_weights(remapped)
+                loaded.update(loaded_mlm)
 
         return loaded
 
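Note: a self-contained sketch of the checkpoint-name routing the new load_weights performs (strip "model."/"bert." prefixes, send "cls.predictions.*" entries to the MLM head via a name map, everything else to the backbone). split_checkpoint and MLM_NAME_MAP are illustrative names only; in the code above the two sides are then fed to self.model.load_weights and AutoWeightsLoader respectively.

# Minimal sketch (assumed names, loaders omitted): split an HF-style BERT
# checkpoint into backbone weights and remapped MLM-head weights.
import torch

MLM_NAME_MAP = {
    "cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
    "cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
    "cls.predictions.transform.LayerNorm.weight": "mlm_head.layer_norm.weight",
    "cls.predictions.transform.LayerNorm.bias": "mlm_head.layer_norm.bias",
    "cls.predictions.decoder.weight": "mlm_head.decoder.weight",
    "cls.predictions.decoder.bias": "mlm_head.decoder.bias",
}


def split_checkpoint(
    weights: list[tuple[str, torch.Tensor]],
) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]]]:
    def _strip(name: str) -> str:
        # Drop a leading "model." or "bert." prefix if present.
        for p in ("model.", "bert."):
            if name.startswith(p):
                name = name[len(p) :]
        return name

    model_side: list[tuple[str, torch.Tensor]] = []
    mlm_side: list[tuple[str, torch.Tensor]] = []
    for k, w in weights:
        name = _strip(k)
        if name.startswith("cls.predictions."):
            # Remap HF MLM-head names onto this module's attribute names;
            # unmapped entries (e.g. a bare cls.predictions.bias) are dropped.
            if name in MLM_NAME_MAP:
                mlm_side.append((MLM_NAME_MAP[name], w))
        else:
            model_side.append((name, w))
    return model_side, mlm_side

Any "cls.predictions.*" tensor without a mapping is silently skipped, which mirrors the `if n in name_map` filter in the hunk above.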