Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Refactor composite weight loading logic #8656

Merged
merged 9 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions vllm/model_executor/models/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import itertools
import re
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
Expand Down Expand Up @@ -33,8 +32,8 @@
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches)
from .interfaces import SupportsMultiModal
from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings)
from .utils import (flatten_bn, group_weights_with_prefix,
init_vllm_registered_model, merge_multimodal_embeddings)

IMG_START = '<img>'
IMG_END = '</img>'
Expand Down Expand Up @@ -518,21 +517,18 @@ def sample(

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
weights_group = group_weights_with_prefix(weights)

# load vision encoder
vit_weights = filter_weights(vit_weights, "vision_model")
self.vision_model.load_weights(vit_weights)
self.vision_model.load_weights(weights_group["vision_model"])

# load mlp projector
mlp_weights = filter_weights(mlp_weights, "mlp1")
mlp_params_dict = dict(self.mlp1.named_parameters())
for name, loaded_weight in mlp_weights:
for name, loaded_weight in weights_group["mlp1"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
self.language_model.load_weights(weights_group["language_model"])
16 changes: 6 additions & 10 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)

Expand Down Expand Up @@ -26,8 +25,8 @@
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip)
from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings)
from .utils import (flatten_bn, group_weights_with_prefix,
init_vllm_registered_model, merge_multimodal_embeddings)


class LlavaImagePixelInputs(TypedDict):
Expand Down Expand Up @@ -393,21 +392,18 @@ def sample(

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
weights_group = group_weights_with_prefix(weights)

# load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower")
self.vision_tower.load_weights(vit_weights)
self.vision_tower.load_weights(weights_group["vision_tower"])

# load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights:
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
self.language_model.load_weights(weights_group["language_model"])
20 changes: 7 additions & 13 deletions vllm/model_executor/models/llava_next.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)

Expand Down Expand Up @@ -30,8 +29,8 @@
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings)
from .utils import (flatten_bn, group_weights_with_prefix,
init_vllm_registered_model, merge_multimodal_embeddings)

logger = init_logger(__name__)

Expand Down Expand Up @@ -637,31 +636,26 @@ def sample(

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
weights, 4)
weights_group = group_weights_with_prefix(weights)

# load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower")
self.vision_tower.load_weights(vit_weights)
self.vision_tower.load_weights(weights_group["vision_tower"])

# load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights:
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# load newline
newline_weights = filter_weights(newline_weights, "image_newline")
for name, loaded_weight in newline_weights:
for name, loaded_weight in weights_group["image_newline"]:
assert name == ""
param = self.image_newline
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
self.language_model.load_weights(weights_group["language_model"])
17 changes: 6 additions & 11 deletions vllm/model_executor/models/llava_next_video.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
Expand Down Expand Up @@ -30,7 +29,7 @@
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
from .utils import (group_weights_with_prefix, init_vllm_registered_model,
merge_multimodal_embeddings)

logger = init_logger(__name__)
Expand Down Expand Up @@ -449,23 +448,19 @@ def sample(
return self.language_model.sample(logits, sampling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators
vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
weights, 4)
# prepare weight iterators for components
weights_group = group_weights_with_prefix(weights)

# load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower")
self.vision_tower.load_weights(vit_weights)
self.vision_tower.load_weights(weights_group["vision_tower"])

# load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights:
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
self.language_model.load_weights(weights_group["language_model"])
14 changes: 5 additions & 9 deletions vllm/model_executor/models/paligemma.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)

Expand All @@ -23,7 +22,7 @@
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import filter_weights, merge_multimodal_embeddings
from .utils import group_weights_with_prefix, merge_multimodal_embeddings

logger = init_logger(__name__)

Expand Down Expand Up @@ -286,21 +285,18 @@ def sample(

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
weights_group = group_weights_with_prefix(weights)

# load vision tower
vit_weights = filter_weights(vit_weights, "vision_tower")
self.vision_tower.load_weights(vit_weights)
self.vision_tower.load_weights(weights_group["vision_tower"])

# load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights:
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
self.language_model.load_weights(weights_group["language_model"])
12 changes: 5 additions & 7 deletions vllm/model_executor/models/ultravox.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""

import itertools
import math
from array import array
from functools import lru_cache
Expand Down Expand Up @@ -29,7 +28,8 @@
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import (filter_weights, flatten_bn,
from vllm.model_executor.models.utils import (flatten_bn,
group_weights_with_prefix,
init_vllm_registered_model,
merge_multimodal_embeddings)
from vllm.model_executor.sampling_metadata import SamplingMetadata
Expand Down Expand Up @@ -467,11 +467,10 @@ def sample(

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
projector_weights, llm_weights = itertools.tee(weights, 2)
weights_group = group_weights_with_prefix(weights)

# load projector weights
projector_weights = filter_weights(projector_weights,
"multi_modal_projector")
projector_weights = weights_group["multi_modal_projector"]
projector_params_dict = dict(
self.multi_modal_projector.named_parameters())
for name, loaded_weight in projector_weights:
Expand All @@ -481,5 +480,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader(param, loaded_weight)

# load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
self.language_model.load_weights(weights_group["language_model"])
36 changes: 35 additions & 1 deletion vllm/model_executor/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import itertools
from collections import UserDict
from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
Union, overload)

Expand All @@ -16,7 +18,23 @@
from vllm.utils import is_pin_memory_available


def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
class WeightsGroup(UserDict):
"""
Wraps grouped weights dictionary for a more informative error message
when attempting to access a weight component that does not exist.
"""

def __getitem__(self, key: str) -> int:
try:
return super().__getitem__(key)
except KeyError as exc:
msg = (f"There is no weights named with the prefix: {key}. "
f"Available prefix: {set(self.keys())}")
raise KeyError(msg) from exc


def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
"""
Helper function to load weights for inner vLLM models.

Expand All @@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
yield name, loaded_weight


def group_weights_with_prefix(
weights: Iterable[Tuple[str, torch.Tensor]]
) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
"""
Helper function to group weights with prefix
"""
init_weights, repeated_weights = itertools.tee(weights, 2)
weights_prefix = {name.split(".")[0] for name, _ in init_weights}
repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))

return WeightsGroup({
prefix: filter_weights(component, prefix)
for component, prefix in zip(repeated_weights, weights_prefix)
})


def init_vllm_registered_model(
hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig],
Expand Down
Loading