Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions tests/models/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0

import torch

from vllm.model_executor.models.utils import AutoWeightsLoader


class ModuleWithBatchNorm(torch.nn.Module):

def __init__(self):
super().__init__()
self.bn = torch.nn.BatchNorm1d(2)

def forward(self, x):
return self.bn(x)


class ModuleWithNestedBatchNorm(torch.nn.Module):

def __init__(self):
super().__init__()
self.nested_mod = ModuleWithBatchNorm()

def forward(self, x):
return self.nested_mod(x)


def test_module_with_batchnorm_can_load():
"""Ensure the auto weight loader can load batchnorm stats."""
mod = ModuleWithBatchNorm()
# Run some data through the module with batchnorm
mod(torch.Tensor([[1, 2], [3, 4]]))

# Try to load the weights to a new instance
def weight_generator():
yield from mod.state_dict().items()

new_mod = ModuleWithBatchNorm()

assert not torch.all(new_mod.bn.running_mean == mod.bn.running_mean)
assert not torch.all(new_mod.bn.running_var == mod.bn.running_var)
assert new_mod.bn.num_batches_tracked.item() == 0

loader = AutoWeightsLoader(new_mod)
loader.load_weights(weight_generator())

# Ensure the stats are updated
assert torch.all(new_mod.bn.running_mean == mod.bn.running_mean)
assert torch.all(new_mod.bn.running_var == mod.bn.running_var)
assert new_mod.bn.num_batches_tracked.item() == 1


def test_module_with_child_containing_batchnorm_can_autoload():
"""Ensure the auto weight loader can load nested modules batchnorm stats."""
mod = ModuleWithNestedBatchNorm()
# Run some data through the module with batchnorm
mod(torch.Tensor([[1, 2], [3, 4]]))

# Try to load the weights to a new instance
def weight_generator():
yield from mod.state_dict().items()

new_mod = ModuleWithNestedBatchNorm()

assert not torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
assert not torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0

loader = AutoWeightsLoader(new_mod)
loader.load_weights(weight_generator())

# Ensure the stats are updated
assert torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
assert torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
24 changes: 24 additions & 0 deletions vllm/model_executor/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,26 @@ def _load_param(

yield weight_qualname

def _add_loadable_non_param_tensors(self, module: nn.Module,
child_params: Dict[str, torch.Tensor]):
"""
Add tensor names that are not in the model params that may be in the
safetensors, e.g., batch normalization stats.
"""
if isinstance(module, (
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.LazyBatchNorm1d,
nn.LazyBatchNorm2d,
nn.LazyBatchNorm3d,
nn.SyncBatchNorm,
)):
module_state_dict = module.state_dict()
for stat_name in ("running_mean", "running_var",
"num_batches_tracked"):
child_params[stat_name] = module_state_dict[stat_name]

def _load_module(
self,
base_prefix: str,
Expand Down Expand Up @@ -184,6 +204,10 @@ def _load_module(
child_modules = dict(module.named_children())
child_params = dict(module.named_parameters(recurse=False))

# Add missing tensors the weight loader needs to be able to load
# that aren't registered as params, e.g., batchnorm statistics.
self._add_loadable_non_param_tensors(module, child_params)

for child_prefix, child_weights in self._groupby_prefix(weights):
prefix = self._get_qualname(base_prefix, child_prefix)

Expand Down