 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
 from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
-from .utils import WeightsMapper, maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 
 class BertEmbedding(nn.Module):
@@ -44,9 +43,13 @@ def __init__(self, config: BertConfig):
             config.type_vocab_size, config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
-        self.position_ids = nn.Parameter(
-            torch.empty((1, config.max_position_embeddings)), )
 
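+        # position_ids holds the fixed indices [0, max_position_embeddings);
+        # a registered buffer follows the module across devices without being
+        # treated as a trainable parameter.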
+        self.register_buffer(
+            "position_ids",
+            torch.arange(config.max_position_embeddings).unsqueeze(0),
+        )
         self.position_embedding_type = config.position_embedding_type
         if self.position_embedding_type != "absolute":
             raise ValueError("Only 'absolute' position_embedding_type" +
@@ -358,45 +361,49 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("qkv_proj", "value", "v"),
         ]
 
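+        # Weights that map onto fused (stacked) layers such as qkv_proj need
+        # shard-aware loading and are handled here; all remaining weights are
+        # collected and delegated to AutoWeightsLoader below.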
+        loaded_stacked_params = []
+        other_weights = []
         params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
         for name, loaded_weight in weights:
-            if self.pooler is None and "pooler" in name:
-                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+
                 name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
+                loaded_stacked_params.append(name)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
+                if name in params_dict:
+                    other_weights.append((name, loaded_weight))
+
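+        # skip_prefixes drops checkpoint pooler.* weights when this model was
+        # built without a pooler.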
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["pooler."] if self.pooler is None else []),
+        )
+        loaded_params = loader.load_weights(other_weights)
+        loaded_params.update(loaded_stacked_params)
         return loaded_params
 
 
 class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
 
-   This class encapsulates the BertModel and provides an interface for
-   embedding operations and customized pooling functions.
+    This class encapsulates the BertModel and provides an interface for
+    embedding operations and customized pooling functions.
 
-   Attributes:
-       model: An instance of BertModel used for forward operations.
-       _pooler: An instance of Pooler used for pooling operations.
-   """
-    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
+    Attributes:
+        model: An instance of BertModel used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -425,10 +432,18 @@ def pooler(
         return self._pooler(hidden_states, pooling_metadata)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        weights = self.hf_to_vllm_mapper.apply(weights)
-        weights = ((name, data) for name, data in weights
-                   if not name.startswith("lm_head."))
-        self.model.load_weights(weights)
+        weights_list = list(weights)
+
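+        # Checkpoints may or may not carry the top-level "model." prefix;
+        # remap bare names so both layouts resolve to this wrapper's submodule.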
+        mapper = None
+        has_model_prefix = any(
+            name.startswith("model.") for name, _ in weights_list)
+        if not has_model_prefix:
+            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})
+
+        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
+        return loader.load_weights(weights_list, mapper=mapper)
 
     def _build_model(self,
                      vllm_config: VllmConfig,
@@ -470,26 +485,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                         self.classifier, self.bert.pooler)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-
-        self_weights = []
-
-        def weight_filter():
-            for name, weight in weights:
-                if name.startswith("bert."):
-                    yield (name[len("bert."):], weight)
-                else:
-                    self_weights.append((name, weight))
-
-        self.bert.load_weights(weight_filter())
-
-        params_dict = dict(self.named_parameters())
-
-        for name, loaded_weight in self_weights:
-            if name.startswith("classifier"):
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
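+        # AutoWeightsLoader dispatches weights to submodules by name prefix,
+        # so bert.* and classifier.* entries are routed without manual
+        # filtering.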
+        loader = AutoWeightsLoader(self)
+        loaded_params = loader.load_weights(weights)
+        return loaded_params
 
     def pooler(
         self,