diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py
index d61c7d2d5000..a16384efe195 100644
--- a/tests/models/test_utils.py
+++ b/tests/models/test_utils.py
@@ -77,3 +77,73 @@ def weight_generator():
     assert torch.all(
         new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
     assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
+
+
+def test_module_skip_prefix():
+    """Ensure the auto weight loader can skip prefixes."""
+    mod = ModuleWithNestedBatchNorm()
+    # Run some data through the module with batchnorm
+    mod(torch.Tensor([[1, 2], [3, 4]]))
+
+    # Try to load the weights to a new instance
+    def weight_generator():
+        # weights that need to be filtered out
+        redundant_weights = {
+            "prefix.bn.weight": torch.Tensor([1, 2]),
+            "prefix.bn.bias": torch.Tensor([3, 4]),
+        }
+        yield from (mod.state_dict() | redundant_weights).items()
+
+    new_mod = ModuleWithNestedBatchNorm()
+
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
+
+    loader = AutoWeightsLoader(new_mod, skip_prefixes=["prefix."])
+    loader.load_weights(weight_generator())
+
+    # Ensure the stats are updated
+    assert torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
+
+
+def test_module_skip_substr():
+    """Ensure the auto weight loader can skip substrings."""
+    mod = ModuleWithNestedBatchNorm()
+    # Run some data through the module with batchnorm
+    mod(torch.Tensor([[1, 2], [3, 4]]))
+
+    # Try to load the weights to a new instance
+    def weight_generator():
+        # weights that need to be filtered out
+        redundant_weights = {
+            "nested_mod.0.substr.weight": torch.Tensor([1, 2]),
+            "nested_mod.0.substr.bias": torch.Tensor([3, 4]),
+            "nested_mod.substr.weight": torch.Tensor([1, 2]),
+            "nested_mod.substr.bias": torch.Tensor([3, 4]),
+        }
+        yield from (mod.state_dict() | redundant_weights).items()
+
+    new_mod = ModuleWithNestedBatchNorm()
+
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
+
+    loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."])
+    loader.load_weights(weight_generator())
+
+    # Ensure the stats are updated
+    assert torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index eed0820a5779..c49db653f735 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -478,18 +478,14 @@ def make_empty_intermediate_tensors(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        skip_prefixes = [
-            "rotary_emb.inv_freq",
-            # Models trained using ColossalAI may include these tensors in
-            # the checkpoint. Skip them.
-            "rotary_emb.cos_cached",
-            "rotary_emb.sin_cached",
-        ]
         # With tie_word_embeddings, we can skip lm_head.weight
         # The weight might appear unnecessarily in the files if the model is
         # processed with quantization, LoRA, fine-tuning, etc.
-        if self.config.tie_word_embeddings:
-            skip_prefixes.append("lm_head.weight")
+        skip_prefixes = (["lm_head."]
+                         if self.config.tie_word_embeddings else None)
 
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=skip_prefixes,
+        )
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py
index 6d2d16d098d4..578d31a851a9 100644
--- a/vllm/model_executor/models/grok1.py
+++ b/vllm/model_executor/models/grok1.py
@@ -550,10 +550,12 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        skip_prefixes = ["rotary_emb.inv_freq"]
         # Skip lm_head when tie_word_embeddings is True
-        if self.config.tie_word_embeddings:
-            skip_prefixes.append("lm_head")
+        skip_prefixes = (["lm_head"]
+                         if self.config.tie_word_embeddings else None)
 
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=skip_prefixes,
+        )
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 1968bf9e68af..4823808e8906 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -482,5 +482,5 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py
index b6a0c9ec6fc1..f096f6a7996d 100644
--- a/vllm/model_executor/models/mixtral_quant.py
+++ b/vllm/model_executor/models/mixtral_quant.py
@@ -447,8 +447,5 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["rotary_emb.inv_freq"]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py
index 0b5a102ea1f2..c5c5155a2df5 100644
--- a/vllm/model_executor/models/nemotron.py
+++ b/vllm/model_executor/models/nemotron.py
@@ -502,14 +502,5 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=([
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index 26ca770d8493..fcb7c619a102 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -382,19 +382,7 @@ def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=([
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached",
-                "lm_head.weight"
-            ] if self.config.tie_word_embeddings else [
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ]),
+            skip_prefixes=(["lm_head.weight"]
+                           if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py
index e4dc0e0cc411..0a1fb10c186e 100644
--- a/vllm/model_executor/models/olmo2.py
+++ b/vllm/model_executor/models/olmo2.py
@@ -403,19 +403,7 @@ def compute_logits(
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=([
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached",
-                "lm_head.weight"
-            ] if self.config.tie_word_embeddings else [
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ]),
+            skip_prefixes=(["lm_head.weight"]
+                           if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py
index 9a07f57fd999..6364b89fb837 100644
--- a/vllm/model_executor/models/olmoe.py
+++ b/vllm/model_executor/models/olmoe.py
@@ -442,8 +442,5 @@ def compute_logits(self, hidden_states: torch.Tensor,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=["rotary_emb.inv_freq"],
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py
index 1ccd1fe1f741..da2a194e6bdf 100644
--- a/vllm/model_executor/models/orion.py
+++ b/vllm/model_executor/models/orion.py
@@ -344,14 +344,5 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=([
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py
index b7bb3c45c633..418ff900ffd5 100644
--- a/vllm/model_executor/models/phi4mm.py
+++ b/vllm/model_executor/models/phi4mm.py
@@ -1228,9 +1228,7 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> None:
-        weights = ((name, data) for name, data in weights
-                   if "lora" not in name)
-        loader = AutoWeightsLoader(self)
+        loader = AutoWeightsLoader(self, skip_substrs=["lora"])
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py
index 7f2e9fdf7c4e..d9917c26d1b1 100644
--- a/vllm/model_executor/models/phimoe.py
+++ b/vllm/model_executor/models/phimoe.py
@@ -660,8 +660,5 @@ def compute_logits(self, hidden_states: torch.Tensor,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["rotary_emb.inv_freq"]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index 7cf98dc7a4ea..143b9f98b029 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -535,8 +535,5 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["rotary_emb.inv_freq"]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index aae5401721df..8a4c2850dda3 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -530,8 +530,5 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["rotary_emb.inv_freq"]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py
index 8c78c846302a..53e5274aa574 100644
--- a/vllm/model_executor/models/solar.py
+++ b/vllm/model_executor/models/solar.py
@@ -500,14 +500,5 @@ def compute_logits(self, hidden_states: torch.Tensor,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=([
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ]),
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index 8c2ad6f19251..86ce813ddf3d 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -338,13 +338,5 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            # Models trained using ColossalAI may include these tensors in
-            # the checkpoint. Skip them.
-            skip_prefixes=[
-                "rotary_emb.inv_freq", "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ],
-        )
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 5927afa91f49..f4ba5a8030e5 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -349,8 +349,7 @@ def load_weights(self, weights: Iterable[tuple[str,
             self,
             # Models trained using ColossalAI may include these tensors in
             # the checkpoint. Skip them.
-            skip_prefixes=([
-                "rotary_emb.inv_freq", "lm_head.weight"
-            ] if self.config.tie_word_embeddings else ["rotary_emb.inv_freq"]),
+            skip_prefixes=(["lm_head.weight"]
+                           if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 5cc501622891..027cd748e9de 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -80,18 +80,30 @@ class AutoWeightsLoader:
     environment variable ``VLLM_LOGGING_LEVEL=DEBUG``.
     """
 
+    # Models trained using an early version of ColossalAI
+    # may include these tensors in the checkpoint. Skip them.
+    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
+        "rotary_emb.inv_freq",
+        "rotary_emb.cos_cached",
+        "rotary_emb.sin_cached",
+    ]
+
     def __init__(
         self,
         module: nn.Module,
         *,
         skip_prefixes: Optional[list[str]] = None,
+        skip_substrs: Optional[list[str]] = None,
         ignore_unexpected_prefixes: Optional[list[str]] = None,
     ) -> None:
         super().__init__()
 
         self.module = module
         self.skip_prefixes = skip_prefixes or []
+        self.skip_substrs = skip_substrs or []
         self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
+        # extend skip_substrs with the default unused rotary weights
+        self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS
 
     def _groupby_prefix(
         self,
@@ -119,7 +131,8 @@ def _get_qualname(self, prefix: str, rest: str) -> str:
         return ".".join((prefix, rest))
 
     def _can_skip(self, qualname: str) -> bool:
-        return any(qualname.startswith(p) for p in self.skip_prefixes)
+        return (any(qualname.startswith(p) for p in self.skip_prefixes)
+                or any(substr in qualname for substr in self.skip_substrs))
 
     def _can_ignore_unexpected(self, qualname: str) -> bool:
         return any(
@@ -257,6 +270,9 @@ def load_weights(
     ) -> set[str]:
         if mapper is not None:
             weights = mapper.apply(weights)
+        # filter out weights whose names match skip_prefixes or skip_substrs
+        weights = ((name, weight) for name, weight in weights
+                   if not self._can_skip(name))
 
         autoloaded_weights = set(self._load_module("", self.module, weights))
         return autoloaded_weights
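
Illustrative usage sketch (not part of the patch): once this change is applied, callers no longer need to list the rotary-embedding buffers explicitly and can drop LoRA-style tensors through skip_substrs; filtering happens inside load_weights via _can_skip, which is why call sites such as phi4mm.py no longer pre-filter the weight iterator themselves. ToyModel and the checkpoint names below are made up for illustration; only AutoWeightsLoader is real.

    import torch
    from torch import nn

    from vllm.model_executor.models.utils import AutoWeightsLoader


    class ToyModel(nn.Module):
        """Hypothetical stand-in for a vLLM model."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 2)


    model = ToyModel()
    checkpoint = [
        ("linear.weight", torch.zeros(2, 2)),
        ("linear.bias", torch.zeros(2)),
        # dropped by the default ROTARY_EMBEDS_UNUSED_WEIGHTS substrings
        ("linear.rotary_emb.inv_freq", torch.zeros(2)),
        # dropped by the caller-supplied skip_substrs
        ("linear.lora_A.weight", torch.zeros(2, 2)),
    ]
    loader = AutoWeightsLoader(model, skip_substrs=["lora"])
    loaded = loader.load_weights(iter(checkpoint))
    print(sorted(loaded))  # ['linear.bias', 'linear.weight']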