Skip to content

Commit 706a735

Browse files
committed
feat(bert): simplify load_weights
Signed-off-by: gjgjos <gjgjos@naver.com>
1 parent b220766 commit 706a735

File tree

1 file changed

+40
-58
lines changed

1 file changed

+40
-58
lines changed

vllm/model_executor/models/bert.py

Lines changed: 40 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -640,13 +640,12 @@ def forward(
640640
max_len: int = int(lens_tensor.max().item())
641641

642642
if isinstance(hidden_states, list):
643-
hidden_states = torch.cat(hidden_states, dim=0)
644-
assert isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 2
645-
646-
device = hidden_states.device
647-
H = int(self.mlm_head.dense.in_features)
643+
hs_list = hidden_states
644+
else:
645+
hs_list = torch.split(hidden_states, lens, dim=0)
648646

649-
hs_list = torch.split(hidden_states, lens, dim=0)
647+
device = hs_list[0].device
648+
H = hs_list[0].size(-1)
650649

651650
padded = hidden_states.new_zeros((B, max_len, H))
652651
valid_mask = torch.zeros((B, max_len), dtype=torch.bool, device=device)
@@ -755,61 +754,44 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
755754
layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
756755
)
757756

758-
weights_list = list(weights)
759-
loaded: set[str] = set()
757+
def _strip(name: str) -> str:
758+
for p in ("model.", "bert."):
759+
if name.startswith(p):
760+
name = name[len(p) :]
761+
return name
760762

763+
weights_list = list(weights)
761764
model_side: list[tuple[str, torch.Tensor]] = []
765+
mlm_side: list[tuple[str, torch.Tensor]] = []
766+
762767
for k, w in weights_list:
763-
if k.startswith("cls.predictions."):
764-
continue
765-
name = k
766-
if name.startswith("model."):
767-
name = name[len("model.") :]
768-
if name.startswith("bert."):
769-
name = name[len("bert.") :]
770-
model_side.append((name, w))
771-
772-
other, stacked = self.model._load_weights(model_side)
773-
loaded.update({"model." + n for n in stacked})
774-
775-
other_prefixed = [("model." + n, w) for (n, w) in other]
776-
loader_top = AutoWeightsLoader(
777-
self, skip_prefixes=["pooler.", "mlm_head.", "lm_head."]
778-
)
779-
loaded_other = loader_top.load_weights(other_prefixed)
780-
loaded.update(loaded_other)
781-
782-
name_map = {
783-
"cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
784-
"cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
785-
"cls.predictions.transform.LayerNorm.weight": "mlm_head.layer_norm.weight",
786-
"cls.predictions.transform.LayerNorm.bias": "mlm_head.layer_norm.bias",
787-
"cls.predictions.decoder.weight": "mlm_head.decoder.weight",
788-
"cls.predictions.decoder.bias": "mlm_head.decoder.bias",
789-
}
790-
extras: list[tuple[str, torch.Tensor]] = []
791-
for k, w in weights_list:
792-
name = k
793-
if name.startswith("model."):
794-
name = name[len("model.") :]
795-
if name.startswith("bert."):
796-
name = name[len("bert.") :]
797-
tgt = name_map.get(name)
798-
if tgt is not None:
799-
extras.append((tgt, w))
800-
801-
if extras:
802-
mlm_loader = AutoWeightsLoader(self)
803-
loaded_mlm = mlm_loader.load_weights(extras)
804-
loaded.update(loaded_mlm)
805-
806-
try:
807-
emb_w = self.model.embeddings.word_embeddings.weight
808-
dec_w = self.mlm_head.decoder.weight
809-
if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr():
810-
self.mlm_head.decoder.weight = emb_w
811-
except Exception:
812-
pass
768+
name = _strip(k)
769+
if name.startswith("cls.predictions."):
770+
mlm_side.append((name, w))
771+
else:
772+
model_side.append((name, w))
773+
774+
loaded: set[str] = set()
775+
loaded_model = self.model.load_weights(model_side)
776+
loaded.update({"model." + n for n in loaded_model})
777+
778+
if mlm_side:
779+
name_map = {
780+
"cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
781+
"cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
782+
("cls.predictions.transform.LayerNorm.weight"): (
783+
"mlm_head.layer_norm.weight"
784+
),
785+
("cls.predictions.transform.LayerNorm.bias"): (
786+
"mlm_head.layer_norm.bias"
787+
),
788+
"cls.predictions.decoder.weight": "mlm_head.decoder.weight",
789+
"cls.predictions.decoder.bias": "mlm_head.decoder.bias",
790+
}
791+
remapped = [(name_map[n], w) for n, w in mlm_side if n in name_map]
792+
if remapped:
793+
loaded_mlm = AutoWeightsLoader(self).load_weights(remapped)
794+
loaded.update(loaded_mlm)
813795

814796
return loaded
815797

0 commit comments

Comments (0)