@@ -702,21 +702,45 @@ def load_weights(self, weights: Iterable[tuple[str,
class TransformersModel(TransformersBase):
    # Maps Hugging Face checkpoint parameter names onto vLLM's module tree.
    # NOTE(review): entry order presumably matters for prefix resolution
    # (e.g. "" vs "model.model.") — confirm against WeightsMapper semantics.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Handle BERT-like models whose checkpoints are rooted at `bert`
            "bert": "model",
            # Add `model.` prefix for base model checkpoints
            "": "model.",
            # Remove `model.` prefix if it was already there
            "model.model.": "model.",
            # Pooling adapters will be adjacent to `model`
            "model.pooler": "pooler",
            "model.score": "score",
            # Classifier adapter's classifier layer is renamed to score
            "model.classifier": "score",
        },
        orig_to_new_suffix={
            # Replace legacy suffixes used for norms (old TF-style
            # checkpoints name LayerNorm params gamma/beta)
            ".gamma": ".weight",
            ".beta": ".bias",
        })
711722
712723 def __init__ (self , * , vllm_config : VllmConfig , prefix : str = "" ):
713724 super ().__init__ (vllm_config = vllm_config , prefix = prefix )
714725
715- # Some encoder models have the position_ids buffer in the checkpoint
726+ # After creating a pooling model, `pooler` will be duplicated.
727+ # The one inside `model` comes from the Transformers modelling code.
728+ # The one after `model` is an adapter from vLLM.
729+ # We want to use the adapter so we nullify the original pooler.
730+ if getattr (self .model , "pooler" , None ) is not None :
731+ self .skip_prefixes .append ("pooler." )
732+ self .model .pooler = torch .nn .Identity ()
733+
734+ # Some encoder models have the position_ids buffer in the checkpoint.
716735 # vLLM will always pass position_ids as an argument, so we skip loading
717736 # the buffer if it exists
718737 self .skip_substrs .append ("position_ids" )
719738
739+ # Some encoder models have the bias of the final classifier layer
740+ # in the checkpoint. vLLM does not use this bias, so we skip loading
741+ # it if it exists
742+ self .skip_substrs .append ("score.bias" )
743+
720744 def create_attention_instances (
721745 self , attn_type : AttentionType = AttentionType .DECODER ):
722746 # TODO(hmellor): Better way to detect encoder models
0 commit comments