2727from vllm .sequence import IntermediateTensors , PoolerOutput
2828
2929from .interfaces import SupportsCrossEncoding , SupportsQuant , SupportsV0Only
30- from .utils import WeightsMapper , maybe_prefix
30+ from .utils import AutoWeightsLoader , WeightsMapper , maybe_prefix
3131
3232
3333class BertEmbedding (nn .Module ):
@@ -44,8 +44,9 @@ def __init__(self, config: BertConfig):
4444 config .type_vocab_size , config .hidden_size )
4545 self .LayerNorm = nn .LayerNorm (config .hidden_size ,
4646 eps = config .layer_norm_eps )
47- self .position_ids = nn .Parameter (
48- torch .empty ((1 , config .max_position_embeddings )), )
47+ self .register_buffer (
48+ "position_ids" ,
49+ torch .arange (config .max_position_embeddings ).expand ((1 , - 1 )))
4950
5051 self .position_embedding_type = config .position_embedding_type
5152 if self .position_embedding_type != "absolute" :
@@ -470,26 +471,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
470471 self .classifier , self .bert .pooler )
471472
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load checkpoint tensors into the BERT submodule and the classifier.

    Entries whose name starts with ``"bert."`` are handed, with their
    prefix intact, to an ``AutoWeightsLoader`` built around this module.
    Every other entry is applied directly, but only when its name matches
    one of this module's ``named_parameters()``; unmatched names are
    skipped.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs. It is traversed
            exactly once, so a one-shot generator is acceptable.

    Returns:
        The collection returned by ``AutoWeightsLoader.load_weights``,
        extended with the names of the directly loaded parameters.
    """
    encoder_entries = []
    head_entries = []

    # Single pass: route each entry to the bucket matching its prefix.
    for key, tensor in weights:
        is_encoder = key.startswith("bert.")
        bucket = encoder_entries if is_encoder else head_entries
        bucket.append((key, tensor))

    loaded_params = AutoWeightsLoader(self).load_weights(encoder_entries)

    own_params = dict(self.named_parameters())
    for key, tensor in head_entries:
        if key not in own_params:
            continue
        target = own_params[key]
        # Use the parameter's own loader when it defines one.
        apply_weight = getattr(target, "weight_loader",
                               default_weight_loader)
        apply_weight(target, tensor)
        loaded_params.add(key)

    return loaded_params
493496
494497 def pooler (
495498 self ,
0 commit comments