
Commit 23f922b

Isotr0py authored and Yuqi Zhang committed
[Misc] Allow AutoWeightsLoader to skip loading weights with specific substr in name (vllm-project#18358)
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
1 parent 50bcde5 commit 23f922b

File tree

18 files changed: +116 -109 lines
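The tests below exercise the new skip_substrs option alongside the existing skip_prefixes. As a quick orientation, here is a minimal usage sketch (not part of the diff): ToyModel and its fields are made-up placeholders, and the import path assumes AutoWeightsLoader lives in vllm/model_executor/models/utils.py.

# Illustrative only -- not part of this commit. ToyModel is a placeholder
# module; the import path for AutoWeightsLoader is an assumption.
import torch
from torch import nn

from vllm.model_executor.models.utils import AutoWeightsLoader


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)

    def load_weights(self, weights):
        loader = AutoWeightsLoader(
            self,
            # Skip whole subtrees whose checkpoint name starts with a prefix.
            skip_prefixes=["lm_head."],
            # New in this commit: skip any weight whose name merely *contains*
            # one of these substrings, wherever it occurs in the hierarchy.
            skip_substrs=["rotary_emb.inv_freq"],
        )
        return loader.load_weights(weights)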

tests/models/test_utils.py

Lines changed: 70 additions & 0 deletions
@@ -77,3 +77,73 @@ def weight_generator():
     assert torch.all(
         new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
     assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
+
+
+def test_module_skip_prefix():
+    """Ensure the auto weights loader can skip weights by prefix."""
+    mod = ModuleWithNestedBatchNorm()
+    # Run some data through the module with batchnorm
+    mod(torch.Tensor([[1, 2], [3, 4]]))
+
+    # Try to load the weights to a new instance
+    def weight_generator():
+        # weights that need to be filtered out
+        redundant_weights = {
+            "prefix.bn.weight": torch.Tensor([1, 2]),
+            "prefix.bn.bias": torch.Tensor([3, 4]),
+        }
+        yield from (mod.state_dict() | redundant_weights).items()
+
+    new_mod = ModuleWithNestedBatchNorm()
+
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
+
+    loader = AutoWeightsLoader(new_mod, skip_prefixes=["prefix."])
+    loader.load_weights(weight_generator())
+
+    # Ensure the stats are updated
+    assert torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
+
+
+def test_module_skip_substr():
+    """Ensure the auto weights loader can skip weights by substring."""
+    mod = ModuleWithNestedBatchNorm()
+    # Run some data through the module with batchnorm
+    mod(torch.Tensor([[1, 2], [3, 4]]))
+
+    # Try to load the weights to a new instance
+    def weight_generator():
+        # weights that need to be filtered out
+        redundant_weights = {
+            "nested_mod.0.substr.weight": torch.Tensor([1, 2]),
+            "nested_mod.0.substr.bias": torch.Tensor([3, 4]),
+            "nested_mod.substr.weight": torch.Tensor([1, 2]),
+            "nested_mod.substr.bias": torch.Tensor([3, 4]),
+        }
+        yield from (mod.state_dict() | redundant_weights).items()
+
+    new_mod = ModuleWithNestedBatchNorm()
+
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
+
+    loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."])
+    loader.load_weights(weight_generator())
+
+    # Ensure the stats are updated
+    assert torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
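The loader change itself (in vllm/model_executor/models/utils.py) is not included in this excerpt. As a rough illustration of the behaviour these tests pin down, the filtering presumably amounts to something like the sketch below; the helper names _should_skip and filter_weights are hypothetical, not the actual AutoWeightsLoader code.

# Hypothetical sketch of the skip logic the tests above verify; names and
# structure are illustrative, not copied from vLLM.
from collections.abc import Iterable

import torch


def _should_skip(name: str, skip_prefixes: list[str],
                 skip_substrs: list[str]) -> bool:
    # Prefixes must match the start of the checkpoint weight name,
    # while substrings may match anywhere inside it.
    return (any(name.startswith(p) for p in skip_prefixes)
            or any(s in name for s in skip_substrs))


def filter_weights(weights: Iterable[tuple[str, torch.Tensor]],
                   skip_prefixes: list[str],
                   skip_substrs: list[str]):
    # Yield only the weights that survive both filters.
    for name, tensor in weights:
        if not _should_skip(name, skip_prefixes, skip_substrs):
            yield name, tensor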

vllm/model_executor/models/granite.py

Lines changed: 6 additions & 10 deletions
@@ -478,18 +478,14 @@ def make_empty_intermediate_tensors(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        skip_prefixes = [
-            "rotary_emb.inv_freq",
-            # Models trained using ColossalAI may include these tensors in
-            # the checkpoint. Skip them.
-            "rotary_emb.cos_cached",
-            "rotary_emb.sin_cached",
-        ]
         # With tie_word_embeddings, we can skip lm_head.weight
         # The weight might appear unnecessarily in the files if the model is
         # processed with quantization, LoRA, fine-tuning, etc.
-        if self.config.tie_word_embeddings:
-            skip_prefixes.append("lm_head.weight")
+        skip_prefixes = (["lm_head."]
+                         if self.config.tie_word_embeddings else None)
 
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=skip_prefixes,
+        )
         return loader.load_weights(weights)
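Note the pattern that repeats through the remaining model files: the explicit rotary_emb.inv_freq / rotary_emb.cos_cached / rotary_emb.sin_cached entries disappear from every per-model skip_prefixes list, which suggests the loader now drops these checkpoint-only buffers on its own, presumably via a default substring list along the lines of the assumed, illustrative constant below.

# Assumed default inside AutoWeightsLoader -- illustrative only; the loader
# change is not shown in this excerpt. Rotary-embedding caches serialized by
# some training frameworks (e.g. ColossalAI) are recomputed by vLLM, so they
# no longer need a per-model skip_prefixes entry.
DEFAULT_ROTARY_SKIP_SUBSTRS = [
    "rotary_emb.inv_freq",
    "rotary_emb.cos_cached",
    "rotary_emb.sin_cached",
]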

vllm/model_executor/models/grok1.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -550,10 +550,12 @@ def compute_logits(
550550

551551
def load_weights(self, weights: Iterable[tuple[str,
552552
torch.Tensor]]) -> set[str]:
553-
skip_prefixes = ["rotary_emb.inv_freq"]
554553
# Skip lm_head when tie_word_embeddings is True
555-
if self.config.tie_word_embeddings:
556-
skip_prefixes.append("lm_head")
554+
skip_prefixes = (["lm_head"]
555+
if self.config.tie_word_embeddings else None)
557556

558-
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
557+
loader = AutoWeightsLoader(
558+
self,
559+
skip_prefixes=skip_prefixes,
560+
)
559561
return loader.load_weights(weights)

vllm/model_executor/models/mixtral.py

Lines changed: 1 addition & 1 deletion
@@ -482,5 +482,5 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)

vllm/model_executor/models/mixtral_quant.py

Lines changed: 1 addition & 4 deletions
@@ -447,8 +447,5 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["rotary_emb.inv_freq"]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)

vllm/model_executor/models/nemotron.py

Lines changed: 1 addition & 10 deletions
@@ -502,14 +502,5 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=([
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)

vllm/model_executor/models/olmo.py

Lines changed: 2 additions & 14 deletions
@@ -382,19 +382,7 @@ def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=([
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached",
-                "lm_head.weight"
-            ] if self.config.tie_word_embeddings else [
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ]),
+            skip_prefixes=(["lm_head.weight"]
+                           if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)

vllm/model_executor/models/olmo2.py

Lines changed: 2 additions & 14 deletions
@@ -403,19 +403,7 @@ def compute_logits(
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=([
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached",
-                "lm_head.weight"
-            ] if self.config.tie_word_embeddings else [
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ]),
+            skip_prefixes=(["lm_head.weight"]
+                           if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)

vllm/model_executor/models/olmoe.py

Lines changed: 1 addition & 4 deletions
@@ -442,8 +442,5 @@ def compute_logits(self, hidden_states: torch.Tensor,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=["rotary_emb.inv_freq"],
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)

vllm/model_executor/models/orion.py

Lines changed: 1 addition & 10 deletions
@@ -344,14 +344,5 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=([
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
