Commit

Fix weight loading for tied word embedding when TP > 1 (sgl-project#2009)

merrymercy authored Nov 12, 2024
1 parent befc6be commit 530ae1b
Showing 1 changed file with 11 additions and 6 deletions.
python/sglang/srt/models/llama.py: 17 changes (11 additions & 6 deletions)
@@ -380,6 +380,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
         params_dict = dict(self.named_parameters())
 
+        load_tie_word_embeddings = (
+            hasattr(self.config, "tie_word_embeddings")
+            and self.config.tie_word_embeddings
+            and "lm_head.weight" in params_dict
+        )
+
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
@@ -412,15 +418,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
-        if (
-            hasattr(self.config, "tie_word_embeddings")
-            and self.config.tie_word_embeddings
-            and "lm_head.weight" in params_dict
-        ):
+            if load_tie_word_embeddings and name == "model.embed_tokens.weight":
+                embed_tokens_weight = loaded_weight
+
+        if load_tie_word_embeddings:
             # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
             param = self.lm_head.weight
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, self.model.embed_tokens.weight)
+            weight_loader(param, embed_tokens_weight)
 
         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
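A note on why the change helps (reasoning inferred from the diff, not stated in the commit message): under tensor parallelism each rank's `model.embed_tokens.weight` holds only its shard of the vocabulary, while a parameter's `weight_loader` expects the full tensor from the checkpoint and carves out the local slice itself. Re-feeding the already-sharded local embedding into `lm_head`'s loader therefore only happens to work when TP == 1; caching the full `loaded_weight` as `embed_tokens_weight` and loading that instead works for any TP size. A minimal standalone sketch of the idea, with a made-up helper `shard_rows` and toy shapes (not sglang code):

```python
# Self-contained sketch of tying lm_head to embed_tokens under tensor parallelism.
# `shard_rows`, the toy shapes, and the explicit rank loop are illustrative only;
# in sglang the real weight_loader lives on the parameter object.
import torch


def shard_rows(param: torch.Tensor, full_weight: torch.Tensor,
               tp_rank: int, tp_size: int) -> None:
    """Mimic a vocab-parallel weight loader: copy this rank's row slice of the
    full checkpoint tensor into the local, already-sharded parameter."""
    rows = full_weight.shape[0] // tp_size
    param.copy_(full_weight[tp_rank * rows:(tp_rank + 1) * rows])


vocab_size, hidden_size, tp_size = 8, 4, 2
full_embed = torch.randn(vocab_size, hidden_size)  # tensor as stored in the checkpoint

for tp_rank in range(tp_size):
    # Each rank holds only its slice of embed_tokens and lm_head.
    embed_tokens = torch.empty(vocab_size // tp_size, hidden_size)
    lm_head = torch.empty(vocab_size // tp_size, hidden_size)

    # Normal loading path: the loader receives the *full* checkpoint tensor.
    shard_rows(embed_tokens, full_embed, tp_rank, tp_size)

    # Tying lm_head by re-feeding the local embed_tokens shard would pass a
    # (vocab/tp, hidden) tensor where a (vocab, hidden) tensor is expected,
    # which only works when tp_size == 1. Feeding the cached full checkpoint
    # tensor (embed_tokens_weight in the commit) works for any TP size.
    shard_rows(lm_head, full_embed, tp_rank, tp_size)
    assert torch.equal(lm_head, embed_tokens)  # tied weights agree on each rank
```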
