Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pixtral] Improve loading #11040

Merged
merged 2 commits into from
Dec 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 25 additions & 31 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from dataclasses import dataclass, fields
from functools import cached_property
from itertools import tee
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union

import numpy
Expand Down Expand Up @@ -359,38 +358,33 @@ def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
return weight[0].startswith("vision_language_adapter")

def is_vision_weights(weight: Tuple[str, torch.Tensor]):
return is_vision_encoder_weights(
weight) or is_vision_lang_adapter_weights(weight)

llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
weights, 3)

# llm
llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
self.language_model.load_weights(llm_weights)

# vision encoder
vision_encoder_weights = filter(is_vision_encoder_weights,
vision_encoder_weights)
# Get references to parameters for direct loading
vision_encoder_dict = dict(self.vision_encoder.named_parameters())
for name, loaded_weight in vision_encoder_weights:
# cut 'vision_encoder.'
name = '.'.join(name.split(".")[1:])
param = vision_encoder_dict[name]

default_weight_loader(param, loaded_weight)

# adapter
vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights,
vision_lang_adapter_weights)
vision_lang_adpter_dict = dict(
vision_lang_adapter_dict = dict(
self.vision_language_adapter.named_parameters())
for name, loaded_weight in vision_lang_adapter_weights:
# cut 'vision_language_adapter.'
name = '.'.join(name.split(".")[1:])
param = vision_lang_adpter_dict[name]
default_weight_loader(param, loaded_weight)

def llm_weights_generator():
# Single pass over weights
for name, w in weights:
if is_vision_encoder_weights((name, w)):
# Load vision encoder weights directly
trimmed_name = '.'.join(name.split(".")[1:])
param = vision_encoder_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_vision_lang_adapter_weights((name, w)):
# Load vision-language adapter weights directly
trimmed_name = '.'.join(name.split(".")[1:])
param = vision_lang_adapter_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
else:
# LLM weights: yield them to be loaded
# by language_model.load_weights
yield (name, w)

# Now we call the language model load with the generator
self.language_model.load_weights(llm_weights_generator())


# Vision encoder
Expand Down
Loading