Commit 3921a2f

[Model] Support Pixtral models in the HF Transformers format (#9036)

1 parent 67a7e5e commit 3921a2f

7 files changed: +503 -12 lines changed

docs/source/models/supported_models.rst

Lines changed: 1 addition & 1 deletion

@@ -437,7 +437,7 @@ Text Generation
   * - :code:`PixtralForConditionalGeneration`
     - Pixtral
     - T + I\ :sup:`+`
-    - :code:`mistralai/Pixtral-12B-2409`
+    - :code:`mistralai/Pixtral-12B-2409`, :code:`mistral-community/pixtral-12b` etc.
     -
     - ✅︎
   * - :code:`QWenLMHeadModel`

examples/offline_inference_vision_language.py

Lines changed: 17 additions & 0 deletions

@@ -277,6 +277,22 @@ def run_qwen2_vl(question: str, modality: str):
     return llm, prompt, stop_token_ids
 
 
+# Pixtral HF-format
+def run_pixtral_hf(question: str, modality: str):
+    assert modality == "image"
+
+    model_name = "mistral-community/pixtral-12b"
+
+    llm = LLM(
+        model=model_name,
+        max_model_len=8192,
+    )
+
+    prompt = f"<s>[INST]{question}\n[IMG][/INST]"
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
 # LLama 3.2
 def run_mllama(question: str, modality: str):
     assert modality == "image"
@@ -347,6 +363,7 @@ def run_glm4v(question: str, modality: str):
     "NVLM_D": run_nvlm_d,
     "qwen_vl": run_qwen_vl,
     "qwen2_vl": run_qwen2_vl,
+    "pixtral_hf": run_pixtral_hf,
     "mllama": run_mllama,
     "molmo": run_molmo,
     "glm4v": run_glm4v,

vllm/model_executor/layers/activation.py

Lines changed: 2 additions & 0 deletions

@@ -264,6 +264,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
     lambda: nn.ReLU(),
     "relu2":
     lambda: ReLUSquaredActivation(),
+    "silu":
+    lambda: nn.SiLU(),
     "quick_gelu":
     lambda: QuickGELU(),
 })
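
The two lines above register "silu" in this module's activation registry; the Pixtral HF support added in this commit resolves that name through the registry. Below is a minimal standalone sketch of the lazy-registry pattern; get_activation and _ACTIVATIONS are illustrative names, not vLLM's internals.

import torch
import torch.nn as nn

# Map activation names to zero-argument factories so a module is only
# constructed when its name is actually requested.
_ACTIVATIONS = {
    "relu": lambda: nn.ReLU(),
    "silu": lambda: nn.SiLU(),  # the entry this commit adds
}


def get_activation(name: str) -> nn.Module:
    if name not in _ACTIVATIONS:
        raise ValueError(f"Unknown activation: {name!r}")
    return _ACTIVATIONS[name]()


act = get_activation("silu")
print(act(torch.tensor([-1.0, 0.0, 1.0])))  # SiLU(x) = x * sigmoid(x)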

vllm/model_executor/models/llava.py

Lines changed: 70 additions & 4 deletions

@@ -5,7 +5,8 @@
 import torch
 import torch.nn as nn
 from PIL import Image
-from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
+from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
+                          SiglipVisionConfig)
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -22,6 +23,10 @@
                    dummy_seq_data_for_clip, get_max_clip_image_tokens,
                    input_processor_for_clip)
 from .interfaces import SupportsMultiModal, SupportsPP
+from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
+                      dummy_seq_data_for_pixtral_hf,
+                      get_max_pixtral_hf_image_tokens,
+                      input_processor_for_pixtral_hf)
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
@@ -31,8 +36,13 @@
 
 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    Shape: `(batch_size * num_images, num_channels, height, width)`
+
+    Note that `height` or `width` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
+    """
 
 
 class LlavaImageEmbeddingInputs(TypedDict):
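
With the widened data field, downstream code must accept either a single batched tensor (all images share one resolution) or a list of per-image tensors (Pixtral HF images may differ in height and width). A short illustrative sketch of the two shapes, with made-up sizes:

import torch

# Uniform case: one batched tensor of shape
# (batch_size * num_images, num_channels, height, width).
uniform = torch.rand(2, 3, 336, 336)

# Non-uniform case (Pixtral HF): per-image tensors whose height/width
# differ, so they cannot be stacked into a single tensor.
non_uniform = [torch.rand(3, 512, 768), torch.rand(3, 1024, 640)]

for data in (uniform, non_uniform):
    if isinstance(data, torch.Tensor):
        print("batched:", tuple(data.shape))
    else:
        print("per-image:", [tuple(t.shape) for t in data])
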
@@ -77,6 +87,8 @@ def get_max_llava_image_tokens(ctx: InputContext):
         num_image_tokens = get_max_clip_image_tokens(vision_config)
     elif isinstance(vision_config, SiglipVisionConfig):
         num_image_tokens = get_max_siglip_image_tokens(vision_config)
+    elif isinstance(vision_config, PixtralVisionConfig):
+        num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
     else:
         msg = f"Unsupported vision config: {type(vision_config)}"
         raise NotImplementedError(msg)
@@ -120,6 +132,17 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
 
         mm_data = dummy_image_for_siglip(vision_config, num_images)
         return seq_data, mm_data
+    elif isinstance(vision_config, PixtralVisionConfig):
+        seq_data = dummy_seq_data_for_pixtral_hf(
+            vision_config,
+            seq_len,
+            num_images,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+        mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
+        return seq_data, mm_data
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -163,6 +186,15 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
             image_token_id=hf_config.image_token_index,
             image_feature_size_override=image_feature_size,
         )
+    elif isinstance(vision_config, PixtralVisionConfig):
+        # We ignore image_feature_size_override since we have non-uniform
+        # image sizes for Pixtral
+        return input_processor_for_pixtral_hf(
+            model_config,
+            vision_config,
+            inputs,
+            image_token_id=hf_config.image_token_index,
+        )
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
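
The override is skipped because, unlike the fixed-resolution CLIP/SigLIP towers, the Pixtral HF tower produces a patch grid whose size depends on each image. The back-of-the-envelope sketch below assumes a 16-pixel patch size and a 1024-pixel long-side cap; the real computation in pixtral.py may differ (for example, it can also account for image-break tokens), so treat this only as an illustration of why a single override cannot work.

import math

# Assumed values for illustration only.
PATCH_SIZE = 16
MAX_IMAGE_SIZE = 1024


def approx_image_tokens(width: int, height: int) -> int:
    # Downscale so the longer side fits the cap, keeping the aspect ratio,
    # then count one token per patch in the resulting grid.
    scale = min(1.0, MAX_IMAGE_SIZE / max(width, height))
    return (math.ceil(width * scale / PATCH_SIZE) *
            math.ceil(height * scale / PATCH_SIZE))


print(approx_image_tokens(512, 512))    # smaller image -> fewer tokens
print(approx_image_tokens(2048, 1536))  # large image is scaled down first
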
@@ -189,6 +221,9 @@ def _init_vision_tower(hf_config: LlavaConfig):
             vision_config,
             num_hidden_layers_override=num_hidden_layers,
         )
+    elif isinstance(vision_config, PixtralVisionConfig):
+        # TODO: allow layer override?
+        return PixtralHFVisionModel(vision_config)
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -210,6 +245,15 @@ def __init__(self,
         self.config = config
         self.multimodal_config = multimodal_config
 
+        # NOTE: These are special cases for Pixtral-12B in the HF-format
+        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
+        if (config.text_config.architectures is None
+                and config.text_config.model_type == "mistral"):
+            config.text_config.architectures = ["MistralForCausalLM"]
+        if (config.projector_hidden_act is None
+                and config.vision_config.hidden_act == "gelu"):
+            config.projector_hidden_act = "gelu"
+
         # TODO: Optionally initializes this for supporting embeddings.
         self.vision_tower = _init_vision_tower(config)
         self.multi_modal_projector = LlavaMultiModalProjector(
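
These defaults exist because the referenced mistral-community/pixtral-12b config leaves text_config.architectures and projector_hidden_act unset; filling them lets the generic LLaVA path pick MistralForCausalLM as the language backbone and a GELU projector. The standalone check below mirrors the same logic with transformers directly (illustrative only; it needs a transformers release that ships Pixtral support, and vLLM applies these defaults inside the model constructor):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("mistral-community/pixtral-12b")

# Apply the same defaults as the constructor above.
if (config.text_config.architectures is None
        and config.text_config.model_type == "mistral"):
    config.text_config.architectures = ["MistralForCausalLM"]
if (config.projector_hidden_act is None
        and config.vision_config.hidden_act == "gelu"):
    config.projector_hidden_act = "gelu"

print(config.text_config.architectures, config.projector_hidden_act)
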
@@ -246,6 +290,7 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_sizes = kwargs.pop("image_sizes", None)
         image_embeds = kwargs.pop("image_embeds", None)
 
         if pixel_values is None and image_embeds is None:
@@ -256,6 +301,26 @@
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            # Case for models like PixtralHF that have dynamic image sizes
+            # so we need to produce a list of tensors
+            if image_sizes is not None:
+                images = pixel_values
+                if isinstance(images, torch.Tensor):
+                    # if passed as batch take all images
+                    NN, N, B, C, W, H = images.shape
+                    images = images.reshape(NN * N * B, C, W, H)
+                    images = [images[i] for i in range(images.size(0))]
+                elif isinstance(images, list):
+                    # if passed as list flatten lists of tensors
+                    while isinstance(images, list) and len(images) == 1:
+                        images = images[0]
+
+                # TODO: Add validation based on image_sizes
+                return LlavaImagePixelInputs(
+                    type="pixel_values",
+                    data=images,
+                )
+
             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(
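
Because each Pixtral image can have its own resolution, the branch above normalizes whatever the multimodal plumbing hands in (a batched tensor with extra leading dimensions, or nested lists) into a flat list of per-image tensors. A simplified standalone sketch of that normalization:

from typing import List, Union

import torch


def flatten_images(
        pixel_values: Union[torch.Tensor, list]) -> List[torch.Tensor]:
    """Simplified take on the branch above: always return a flat list of
    per-image (C, H, W) tensors."""
    if isinstance(pixel_values, torch.Tensor):
        # Batched input: collapse all leading batch dims, keep (C, H, W).
        c, h, w = pixel_values.shape[-3:]
        flat = pixel_values.reshape(-1, c, h, w)
        return [flat[i] for i in range(flat.size(0))]
    # List input: unwrap singleton nesting, as the code above does.
    images = pixel_values
    while isinstance(images, list) and len(images) == 1:
        images = images[0]
    return images if isinstance(images, list) else [images]


print([tuple(t.shape) for t in flatten_images(torch.rand(1, 2, 3, 64, 48))])
print([tuple(t.shape) for t in flatten_images([[torch.rand(3, 32, 80)]])])
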
@@ -286,7 +351,8 @@ def _select_image_features(self, image_features: torch.Tensor, *,
 
     def _image_pixels_to_features(
         self,
-        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
+        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
+                            PixtralHFVisionModel],
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
 