Skip to content

Commit 83c2b7d

Browse files
committed
feat(bert): simplify load_weights
Signed-off-by: gjgjos <gjgjos@naver.com>
1 parent b220766 commit 83c2b7d

File tree

1 file changed: +35 additions, −52 deletions

vllm/model_executor/models/bert.py

Lines changed: 35 additions & 52 deletions
Original file line number · Diff line number · Diff line change
@@ -755,61 +755,44 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
755755
layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
756756
)
757757

758-
weights_list = list(weights)
759-
loaded: set[str] = set()
758+
def _strip(name: str) -> str:
759+
for p in ("model.", "bert."):
760+
if name.startswith(p):
761+
name = name[len(p) :]
762+
return name
760763

764+
weights_list = list(weights)
761765
model_side: list[tuple[str, torch.Tensor]] = []
766+
mlm_side: list[tuple[str, torch.Tensor]] = []
767+
762768
for k, w in weights_list:
763-
if k.startswith("cls.predictions."):
764-
continue
765-
name = k
766-
if name.startswith("model."):
767-
name = name[len("model.") :]
768-
if name.startswith("bert."):
769-
name = name[len("bert.") :]
770-
model_side.append((name, w))
771-
772-
other, stacked = self.model._load_weights(model_side)
773-
loaded.update({"model." + n for n in stacked})
774-
775-
other_prefixed = [("model." + n, w) for (n, w) in other]
776-
loader_top = AutoWeightsLoader(
777-
self, skip_prefixes=["pooler.", "mlm_head.", "lm_head."]
778-
)
779-
loaded_other = loader_top.load_weights(other_prefixed)
780-
loaded.update(loaded_other)
781-
782-
name_map = {
783-
"cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
784-
"cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
785-
"cls.predictions.transform.LayerNorm.weight": "mlm_head.layer_norm.weight",
786-
"cls.predictions.transform.LayerNorm.bias": "mlm_head.layer_norm.bias",
787-
"cls.predictions.decoder.weight": "mlm_head.decoder.weight",
788-
"cls.predictions.decoder.bias": "mlm_head.decoder.bias",
789-
}
790-
extras: list[tuple[str, torch.Tensor]] = []
791-
for k, w in weights_list:
792-
name = k
793-
if name.startswith("model."):
794-
name = name[len("model.") :]
795-
if name.startswith("bert."):
796-
name = name[len("bert.") :]
797-
tgt = name_map.get(name)
798-
if tgt is not None:
799-
extras.append((tgt, w))
800-
801-
if extras:
802-
mlm_loader = AutoWeightsLoader(self)
803-
loaded_mlm = mlm_loader.load_weights(extras)
804-
loaded.update(loaded_mlm)
805-
806-
try:
807-
emb_w = self.model.embeddings.word_embeddings.weight
808-
dec_w = self.mlm_head.decoder.weight
809-
if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr():
810-
self.mlm_head.decoder.weight = emb_w
811-
except Exception:
812-
pass
769+
name = _strip(k)
770+
if name.startswith("cls.predictions."):
771+
mlm_side.append((name, w))
772+
else:
773+
model_side.append((name, w))
774+
775+
loaded: set[str] = set()
776+
loaded_model = self.model.load_weights(model_side)
777+
loaded.update({"model." + n for n in loaded_model})
778+
779+
if mlm_side:
780+
name_map = {
781+
"cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
782+
"cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
783+
("cls.predictions.transform.LayerNorm.weight"): (
784+
"mlm_head.layer_norm.weight"
785+
),
786+
("cls.predictions.transform.LayerNorm.bias"): (
787+
"mlm_head.layer_norm.bias"
788+
),
789+
"cls.predictions.decoder.weight": "mlm_head.decoder.weight",
790+
"cls.predictions.decoder.bias": "mlm_head.decoder.bias",
791+
}
792+
remapped = [(name_map[n], w) for n, w in mlm_side if n in name_map]
793+
if remapped:
794+
loaded_mlm = AutoWeightsLoader(self).load_weights(remapped)
795+
loaded.update(loaded_mlm)
813796

814797
return loaded
815798

0 commit comments

Comments (0)