From b9a5c7fdc46b45a30ee0ae285eeb51321f72c865 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Mon, 31 Mar 2025 10:53:14 +0000
Subject: [PATCH 1/2] Add weightloader tests for bn

Signed-off-by: Alex-Brooks
---
 tests/models/test_utils.py | 79 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 79 insertions(+)
 create mode 100644 tests/models/test_utils.py

diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py
new file mode 100644
index 000000000000..d61c7d2d5000
--- /dev/null
+++ b/tests/models/test_utils.py
@@ -0,0 +1,79 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from vllm.model_executor.models.utils import AutoWeightsLoader
+
+
+class ModuleWithBatchNorm(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.bn = torch.nn.BatchNorm1d(2)
+
+    def forward(self, x):
+        return self.bn(x)
+
+
+class ModuleWithNestedBatchNorm(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.nested_mod = ModuleWithBatchNorm()
+
+    def forward(self, x):
+        return self.nested_mod(x)
+
+
+def test_module_with_batchnorm_can_load():
+    """Ensure the auto weight loader can load batchnorm stats."""
+    mod = ModuleWithBatchNorm()
+    # Run some data through the module with batchnorm
+    mod(torch.Tensor([[1, 2], [3, 4]]))
+
+    # Try to load the weights to a new instance
+    def weight_generator():
+        yield from mod.state_dict().items()
+
+    new_mod = ModuleWithBatchNorm()
+
+    assert not torch.all(new_mod.bn.running_mean == mod.bn.running_mean)
+    assert not torch.all(new_mod.bn.running_var == mod.bn.running_var)
+    assert new_mod.bn.num_batches_tracked.item() == 0
+
+    loader = AutoWeightsLoader(new_mod)
+    loader.load_weights(weight_generator())
+
+    # Ensure the stats are updated
+    assert torch.all(new_mod.bn.running_mean == mod.bn.running_mean)
+    assert torch.all(new_mod.bn.running_var == mod.bn.running_var)
+    assert new_mod.bn.num_batches_tracked.item() == 1
+
+
+def test_module_with_child_containing_batchnorm_can_autoload():
+    """Ensure the auto weight loader can load nested modules batchnorm stats."""
+    mod = ModuleWithNestedBatchNorm()
+    # Run some data through the module with batchnorm
+    mod(torch.Tensor([[1, 2], [3, 4]]))
+
+    # Try to load the weights to a new instance
+    def weight_generator():
+        yield from mod.state_dict().items()
+
+    new_mod = ModuleWithNestedBatchNorm()
+
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert not torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
+
+    loader = AutoWeightsLoader(new_mod)
+    loader.load_weights(weight_generator())
+
+    # Ensure the stats are updated
+    assert torch.all(
+        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
+    assert torch.all(
+        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
+    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1

From f8523b5801b078d034c2bc294600462476ca6f1f Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Mon, 31 Mar 2025 10:55:13 +0000
Subject: [PATCH 2/2] Fix batchnorm stat loading

Signed-off-by: Alex-Brooks
---
 vllm/model_executor/models/utils.py | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 1e3d78c7f6fd..0fce807d4d25 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -156,6 +156,34 @@ def _load_param(
 
             yield weight_qualname
 
+    def _add_loadable_non_param_tensors(self, module: nn.Module,
+                                        child_params: Dict[str, torch.Tensor]):
+        """
+        Add tensor names that are not in the model params that may be in the
+        safetensors, e.g., batch normalization stats.
+
+        Args:
+            module: Module whose non-parameter tensors should be exposed.
+            child_params: Mapping of loadable tensor names for ``module``;
+                updated in place.
+        """
+        if isinstance(module, (
+                nn.BatchNorm1d,
+                nn.BatchNorm2d,
+                nn.BatchNorm3d,
+                nn.LazyBatchNorm1d,
+                nn.LazyBatchNorm2d,
+                nn.LazyBatchNorm3d,
+                nn.SyncBatchNorm,
+        )):
+            module_state_dict = module.state_dict()
+            for stat_name in ("running_mean", "running_var",
+                              "num_batches_tracked"):
+                # With track_running_stats=False the stats are registered
+                # as None and are omitted from the state dict.
+                if stat_name in module_state_dict:
+                    child_params[stat_name] = module_state_dict[stat_name]
+
     def _load_module(
         self,
         base_prefix: str,
@@ -184,6 +212,10 @@ def _load_module(
         child_modules = dict(module.named_children())
         child_params = dict(module.named_parameters(recurse=False))
 
+        # Add missing tensors the weight loader needs to be able to load
+        # that aren't registered as params, e.g., batchnorm statistics.
+        self._add_loadable_non_param_tensors(module, child_params)
+
         for child_prefix, child_weights in self._groupby_prefix(weights):
             prefix = self._get_qualname(base_prefix, child_prefix)
 