 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
@@ -44,9 +43,14 @@ def __init__(self, config: BertConfig):
             config.type_vocab_size, config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
-        self.register_buffer(
-            "position_ids",
-            torch.arange(config.max_position_embeddings).expand((1, -1)))
+
+        # Use nn.Parameter with requires_grad=False to maintain compatibility
+        # with existing HF checkpoints while ensuring position_ids are
+        # non-trainable.
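+        # torch.empty is fine: the checkpoint should fill position_ids.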
+        self.position_ids = nn.Parameter(torch.empty(
+            (1, config.max_position_embeddings)),
+                                         requires_grad=False)
 
         self.position_embedding_type = config.position_embedding_type
         if self.position_embedding_type != "absolute":
@@ -359,45 +362,48 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("qkv_proj", "value", "v"),
         ]
 
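+        # Fuse the checkpoint's separate q/k/v shards into qkv_proj by hand;
+        # all other weights are handed to AutoWeightsLoader below.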
+        loaded_stacked_params = []
+        other_weights = []
         params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
         for name, loaded_weight in weights:
-            if self.pooler is None and "pooler" in name:
-                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+
                 name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
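+                # Skip weights absent from the model (e.g. GPTQ extra bias).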
+                if name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
+                loaded_stacked_params.append(name)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
+                other_weights.append((name, loaded_weight))
+
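+        # Skip the checkpoint's pooler weights when the model has no pooler.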
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["pooler."] if self.pooler is None else []),
+        )
+        loaded_params = loader.load_weights(other_weights)
+        loaded_params.update(loaded_stacked_params)
         return loaded_params
 
 
 class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
 
-    This class encapsulates the BertModel and provides an interface for
-    embedding operations and customized pooling functions.
+    This class encapsulates the BertModel and provides an interface for
+    embedding operations and customized pooling functions.
 
-    Attributes:
-        model: An instance of BertModel used for forward operations.
-        _pooler: An instance of Pooler used for pooling operations.
-    """
-    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
+    Attributes:
+        model: An instance of BertModel used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -426,10 +428,17 @@ def pooler(
         return self._pooler(hidden_states, pooling_metadata)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        weights = self.hf_to_vllm_mapper.apply(weights)
-        weights = ((name, data) for name, data in weights
-                   if not name.startswith("lm_head."))
-        self.model.load_weights(weights)
+        weights_list = list(weights)
+
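+        # Raw HF checkpoints name weights without the "model." prefix used
+        # here, so remap bare names onto the inner model when needed.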
+        has_model_prefix = any(
+            name.startswith("model.") for name, _ in weights_list)
+        mapper = (WeightsMapper(orig_to_new_prefix={"": "model."})
+                  if not has_model_prefix else None)
+
+        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
+        return loader.load_weights(weights_list, mapper=mapper)
 
     def _build_model(self,
                      vllm_config: VllmConfig,
@@ -471,27 +478,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                         self.classifier, self.bert.pooler)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        bert_weights = []
-        classifier_weights = []
-
-        for name, weight in weights:
-            if name.startswith("bert."):
-                bert_weights.append((name, weight))
-            else:
-                classifier_weights.append((name, weight))
-
         loader = AutoWeightsLoader(self)
-        loaded_params = loader.load_weights(bert_weights)
-
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in classifier_weights:
-            if name in params_dict:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
-
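+        # AutoWeightsLoader resolves both the bert.* backbone and the
+        # classifier head, so the manual weight split is no longer needed.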
+        loaded_params = loader.load_weights(weights)
         return loaded_params
 
     def pooler(