Skip to content

Commit c0d7622

Browse files
hmellor and charlifu
authored and committed
Improve weight loading for encoder models in Transformers backend (vllm-project#25289)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: charlifu <charlifu@amd.com>
1 parent b8eefb7 commit c0d7622

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

vllm/model_executor/models/transformers.py

Lines changed: 26 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -702,21 +702,45 @@ def load_weights(self, weights: Iterable[tuple[str,
702702
class TransformersModel(TransformersBase):
703703
hf_to_vllm_mapper = WeightsMapper(
704704
orig_to_new_prefix={
705+
# Handle BERT-like models
706+
"bert": "model",
705707
# Add `model.` prefix for base model checkpoints
706708
"": "model.",
707-
# Remove `model.` from places it should not be
709+
# Remove `model.` prefix if it was already there
708710
"model.model.": "model.",
711+
# Pooling adapters will be adjacent to `model`
712+
"model.pooler": "pooler",
709713
"model.score": "score",
714+
# Classifier adapter's classifier layer is renamed to score
715+
"model.classifier": "score",
716+
},
717+
orig_to_new_suffix={
718+
# Replace legacy suffixes used for norms
719+
".gamma": ".weight",
720+
".beta": ".bias",
710721
})
711722

712723
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
713724
super().__init__(vllm_config=vllm_config, prefix=prefix)
714725

715-
# Some encoder models have the position_ids buffer in the checkpoint
726+
# After creating a pooling model, `pooler` will be duplicated.
727+
# The one inside `model` comes from the Transformers modelling code.
728+
# The one after `model` is an adapter from vLLM.
729+
# We want to use the adapter so we nullify the original pooler.
730+
if getattr(self.model, "pooler", None) is not None:
731+
self.skip_prefixes.append("pooler.")
732+
self.model.pooler = torch.nn.Identity()
733+
734+
# Some encoder models have the position_ids buffer in the checkpoint.
716735
# vLLM will always pass position_ids as an argument, so we skip loading
717736
# the buffer if it exists
718737
self.skip_substrs.append("position_ids")
719738

739+
# Some encoder models have the bias of the final classifier layer
740+
# in the checkpoint. vLLM does not use this bias, so we skip loading
741+
# it if it exists
742+
self.skip_substrs.append("score.bias")
743+
720744
def create_attention_instances(
721745
self, attn_type: AttentionType = AttentionType.DECODER):
722746
# TODO(hmellor): Better way to detect encoder models

0 commit comments

Comments (0)