From b4f050f118c6b6037b26cf155fbbba99794e75b6 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Tue, 20 May 2025 11:20:12 +0800
Subject: [PATCH] [Misc] Allow `AutoWeightsLoader` to skip loading weights with specific substr in name (#18358)

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: kwisniewski98
---
 tests/models/test_utils.py               | 70 ++++++++++++++++++++++++
 vllm/model_executor/models/granite.py    | 20 +++----
 vllm/model_executor/models/grok1.py      | 14 +++--
 vllm/model_executor/models/olmoe.py      |  9 +--
 vllm/model_executor/models/orion.py      | 15 +----
 vllm/model_executor/models/phi4mm.py     |  4 +-
 vllm/model_executor/models/phimoe.py     |  9 +--
 vllm/model_executor/models/qwen2_moe.py  |  9 +--
 vllm/model_executor/models/qwen3_moe.py  |  9 +--
 vllm/model_executor/models/stablelm.py   | 14 +----
 vllm/model_executor/models/starcoder2.py |  5 +-
 vllm/model_executor/models/utils.py      | 22 +++++++-
 12 files changed, 126 insertions(+), 74 deletions(-)

diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py
index d61c7d2d5000..a16384efe195 100644
--- a/tests/models/test_utils.py
+++ b/tests/models/test_utils.py
@@ -77,3 +77,73 @@ def weight_generator():
     assert torch.all(
         new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
     assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
+
+
+def test_module_skip_prefix():
+    """Ensure the auto weight loader can skip prefixes."""
+    mod = ModuleWithNestedBatchNorm()
+    # Run some data through the module with batchnorm
+    mod(torch.Tensor([[1, 2], [3, 4]]))
+
+    # Try to load the weights to a new instance
+    def weight_generator():
+        # weights that need to be filtered out
+        redundant_weights = {
+            "prefix.bn.weight": torch.Tensor([1, 2]),
+            "prefix.bn.bias": torch.Tensor([3, 4]),
+        }
+        yield from (mod.state_dict() | redundant_weights).items()
+
+    new_mod = ModuleWithNestedBatchNorm()
+
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
+
+    loader = AutoWeightsLoader(new_mod, skip_prefixes=["prefix."])
+    loader.load_weights(weight_generator())
+
+    # Ensure the stats are updated
+    assert torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
+
+
+def test_module_skip_substr():
+    """Ensure the auto weight loader can skip substrings."""
+    mod = ModuleWithNestedBatchNorm()
+    # Run some data through the module with batchnorm
+    mod(torch.Tensor([[1, 2], [3, 4]]))
+
+    # Try to load the weights to a new instance
+    def weight_generator():
+        # weights that need to be filtered out
+        redundant_weights = {
+            "nested_mod.0.substr.weight": torch.Tensor([1, 2]),
+            "nested_mod.0.substr.bias": torch.Tensor([3, 4]),
+            "nested_mod.substr.weight": torch.Tensor([1, 2]),
+            "nested_mod.substr.bias": torch.Tensor([3, 4]),
+        }
+        yield from (mod.state_dict() | redundant_weights).items()
+
+    new_mod = ModuleWithNestedBatchNorm()
+
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
+
+    loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."])
+    loader.load_weights(weight_generator())
+
+    # Ensure the stats are updated
+    assert torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index 1ec6ffa66215..6625afa99e00 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -509,20 +509,16 @@ def make_empty_intermediate_tensors(
                         device=device),
         })
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        skip_prefixes = [
-            "rotary_emb.inv_freq",
-            # Models trained using ColossalAI may include these tensors in
-            # the checkpoint. Skip them.
-            "rotary_emb.cos_cached",
-            "rotary_emb.sin_cached",
-        ]
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         # With tie_word_embeddings, we can skip lm_head.weight
         # The weight might appear unnecessarily in the files if the model is
         # processed with quantization, LoRA, fine-tuning, etc.
-        if self.config.tie_word_embeddings:
-            skip_prefixes.append("lm_head.weight")
+        skip_prefixes = (["lm_head."]
+                         if self.config.tie_word_embeddings else None)
 
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=skip_prefixes,
+        )
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py
index c48cb157084d..bcff09109c25 100644
--- a/vllm/model_executor/models/grok1.py
+++ b/vllm/model_executor/models/grok1.py
@@ -549,12 +549,14 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        skip_prefixes = ["rotary_emb.inv_freq"]
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         # Skip lm_head when tie_word_embeddings is True
-        if self.config.tie_word_embeddings:
-            skip_prefixes.append("lm_head")
+        skip_prefixes = (["lm_head"]
+                         if self.config.tie_word_embeddings else None)
 
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=skip_prefixes,
+        )
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py
index e6925e125690..c11df60e5441 100644
--- a/vllm/model_executor/models/olmoe.py
+++ b/vllm/model_executor/models/olmoe.py
@@ -439,10 +439,7 @@ def compute_logits(self, hidden_states: torch.Tensor,
                                        sampling_metadata)
         return logits
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=["rotary_emb.inv_freq"],
-        )
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py
index 8d9c000750d7..3c2e85e6c608 100644
--- a/vllm/model_executor/models/orion.py
+++ b/vllm/model_executor/models/orion.py
@@ -341,16 +341,7 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=([
-                "rotary_emb.inv_freq",
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ]),
-        )
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py
index 6035994f4336..55c36100c723 100644
--- a/vllm/model_executor/models/phi4mm.py
+++ b/vllm/model_executor/models/phi4mm.py
@@ -1244,9 +1244,7 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> None:
-        weights = ((name, data) for name, data in weights
-                   if "lora" not in name)
-        loader = AutoWeightsLoader(self)
+        loader = AutoWeightsLoader(self, skip_substrs=["lora"])
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py
index 2dc55e4c352e..d46a9ba77b13 100644
--- a/vllm/model_executor/models/phimoe.py
+++ b/vllm/model_executor/models/phimoe.py
@@ -657,10 +657,7 @@ def compute_logits(self, hidden_states: torch.Tensor,
                                        sampling_metadata)
         return logits
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["rotary_emb.inv_freq"]),
-        )
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index 47d90919ed8f..17bb6bd3ed7d 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -521,10 +521,7 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["rotary_emb.inv_freq"]),
-        )
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 97acbaa2ac34..706abcb8477e 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -518,10 +518,7 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["rotary_emb.inv_freq"]),
-        )
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index 1cbda7267e4c..aadce04f647b 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -335,15 +335,7 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            # Models trained using ColossalAI may include these tensors in
-            # the checkpoint. Skip them.
-            skip_prefixes=[
-                "rotary_emb.inv_freq", "rotary_emb.cos_cached",
-                "rotary_emb.sin_cached"
-            ],
-        )
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 6eebe4c4d614..1c2a1abc1101 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -348,8 +348,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
             self,
             # Models trained using ColossalAI may include these tensors in
             # the checkpoint. Skip them.
-            skip_prefixes=([
-                "rotary_emb.inv_freq", "lm_head.weight"
-            ] if self.config.tie_word_embeddings else ["rotary_emb.inv_freq"]),
+            skip_prefixes=(["lm_head.weight"]
+                           if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 0e97e93a3e0f..5ff61767dbd6 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -82,18 +82,30 @@ class AutoWeightsLoader:
     environment variable ``VLLM_LOGGING_LEVEL=DEBUG``.
     """
 
+    # Models trained using early versions of ColossalAI
+    # may include these tensors in the checkpoint. Skip them.
+    ROTARY_EMBEDS_UNUSED_WEIGHTS = [
+        "rotary_emb.inv_freq",
+        "rotary_emb.cos_cached",
+        "rotary_emb.sin_cached",
+    ]
+
     def __init__(
         self,
        module: nn.Module,
         *,
-        skip_prefixes: Optional[List[str]] = None,
-        ignore_unexpected_prefixes: Optional[List[str]] = None,
+        skip_prefixes: Optional[list[str]] = None,
+        skip_substrs: Optional[list[str]] = None,
+        ignore_unexpected_prefixes: Optional[list[str]] = None,
     ) -> None:
         super().__init__()
 
         self.module = module
         self.skip_prefixes = skip_prefixes or []
+        self.skip_substrs = skip_substrs or []
         self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
+        # always skip these unused rotary embedding weights by default
+        self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS
 
     def _groupby_prefix(
         self,
@@ -121,7 +133,8 @@ def _get_qualname(self, prefix: str, rest: str) -> str:
         return ".".join((prefix, rest))
 
     def _can_skip(self, qualname: str) -> bool:
-        return any(qualname.startswith(p) for p in self.skip_prefixes)
+        return (any(qualname.startswith(p) for p in self.skip_prefixes)
+                or any(substr in qualname for substr in self.skip_substrs))
 
     def _can_ignore_unexpected(self, qualname: str) -> bool:
         return any(
@@ -259,6 +272,9 @@ def load_weights(
     ) -> Set[str]:
         if mapper is not None:
             weights = mapper.apply(weights)
+        # drop weights whose names match a skipped prefix or substring
+        weights = ((name, weight) for name, weight in weights
+                   if not self._can_skip(name))
 
         autoloaded_weights = set(self._load_module("", self.module, weights))
         return autoloaded_weights
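
Reviewer note (illustrative, not part of the patch): the new skip_substrs option replaces
the hand-written pre-filtering of the weights iterable that phi4mm.py used to do. A minimal
sketch of the pattern, assuming a plain torch module is acceptable to the loader (as in
tests/models/test_utils.py); the "lora_A.weight" name below is made up for illustration:

    import torch
    from torch import nn

    from vllm.model_executor.models.utils import AutoWeightsLoader

    # Toy target module and a checkpoint carrying an extra LoRA tensor.
    model = nn.Linear(2, 2)
    checkpoint = dict(model.state_dict())
    checkpoint["lora_A.weight"] = torch.zeros(2, 2)  # should not be loaded

    # Names containing "lora" are dropped before loading; the rotary_emb.*
    # names in ROTARY_EMBEDS_UNUSED_WEIGHTS are skipped by default as well.
    loader = AutoWeightsLoader(model, skip_substrs=["lora"])
    loaded = loader.load_weights(checkpoint.items())
    assert "lora_A.weight" not in loaded  # only the real parameters were loaded

Previously each model's load_weights had to filter such tensors itself (see the removed
generator expression in the phi4mm.py hunk above).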