Commit 61c6a5a

[VLM] Merged multi-modal processor for Pixtral (#12211)
Signed-off-by: remi <remi@mistral.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 74bc397 commit 61c6a5a

File tree

9 files changed: +622, -360 lines

examples/offline_inference/pixtral.py

Lines changed: 18 additions & 6 deletions
@@ -43,12 +43,18 @@
 # python demo.py advanced
 
 
-def run_simple_demo():
+def run_simple_demo(args: argparse.Namespace):
     model_name = "mistralai/Pixtral-12B-2409"
     sampling_params = SamplingParams(max_tokens=8192)
 
-    # Lower max_num_seqs or max_model_len on low-VRAM GPUs.
-    llm = LLM(model=model_name, tokenizer_mode="mistral")
+    # Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
+    llm = LLM(
+        model=model_name,
+        tokenizer_mode="mistral",
+        max_model_len=4096,
+        max_num_seqs=2,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+    )
 
     prompt = "Describe this image in one sentence."
     image_url = "https://picsum.photos/id/237/200/300"
@@ -76,7 +82,7 @@ def run_simple_demo():
     print(outputs[0].outputs[0].text)
 
 
-def run_advanced_demo():
+def run_advanced_demo(args: argparse.Namespace):
     model_name = "mistralai/Pixtral-12B-2409"
     max_img_per_msg = 5
     max_tokens_per_img = 4096
@@ -87,6 +93,7 @@ def run_advanced_demo():
         tokenizer_mode="mistral",
         limit_mm_per_prompt={"image": max_img_per_msg},
         max_model_len=max_img_per_msg * max_tokens_per_img,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
 
     prompt = "Describe the following image."
@@ -153,14 +160,19 @@ def main():
         help="Specify the demo mode: 'simple' or 'advanced'",
     )
 
+    parser.add_argument(
+        '--disable-mm-preprocessor-cache',
+        action='store_true',
+        help='If True, disables caching of multi-modal preprocessor/mapper.')
+
     args = parser.parse_args()
 
     if args.mode == "simple":
         print("Running simple demo...")
-        run_simple_demo()
+        run_simple_demo(args)
     elif args.mode == "advanced":
         print("Running advanced demo...")
-        run_advanced_demo()
+        run_advanced_demo(args)
 
 
 if __name__ == "__main__":
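
Note: the snippet below is a minimal sketch, not part of this commit, of how the updated simple demo exercises the engine once it is built with the new flag. The constructor arguments mirror the diff above; the llm.chat() message layout is assumed from vLLM's standard multi-modal chat interface rather than taken from this diff.

# Sketch only: constructor arguments mirror the diff above; the llm.chat()
# message schema is an assumption based on vLLM's standard chat interface.
from vllm import LLM, SamplingParams

llm = LLM(
    model="mistralai/Pixtral-12B-2409",
    tokenizer_mode="mistral",
    max_model_len=4096,
    max_num_seqs=2,
    disable_mm_preprocessor_cache=True,  # same effect as --disable-mm-preprocessor-cache
)
sampling_params = SamplingParams(max_tokens=8192)

messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this image in one sentence."},
        {
            "type": "image_url",
            "image_url": {"url": "https://picsum.photos/id/237/200/300"},
        },
    ],
}]

outputs = llm.chat(messages, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)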

tests/models/multimodal/processing/test_common.py

Lines changed: 149 additions & 51 deletions
@@ -2,17 +2,23 @@
 
 import copy
 from functools import partial
-from typing import Optional
+from typing import Optional, Union
 
 import numpy as np
 import pytest
+from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
+                                                        UserMessage)
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from PIL import Image
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from vllm.config import ModelConfig
 from vllm.inputs import InputProcessingContext
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.processing import ProcessingCache
-from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
+from vllm.multimodal.inputs import MultiModalInputs
+from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
+from vllm.transformers_utils.tokenizer import (MistralTokenizer,
+                                               cached_tokenizer_from_config)
 
 from ....multimodal.utils import random_audio, random_image, random_video
 from ...registry import HF_EXAMPLE_MODELS
@@ -85,14 +91,6 @@ def _test_processing_correctness(
         partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
     }
 
-    tokenizer_encode_kwargs = {}
-    if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
-        # For some multimodal models, tokenizer will always add bos_token
-        # at the beginning of prompt by default, causing hf_processor outputs
-        # incorrect token ids. So we need use `add_special_tokens=False` here
-        # to leave bos_token to be added by the processor.
-        tokenizer_encode_kwargs = {"add_special_tokens": False}
-
     for batch_idx in range(num_batches):
         mm_data = {
             k:
@@ -115,43 +113,131 @@ def _test_processing_correctness(
                 elif len(mm_data[k]) == 1:
                     mm_data[k] = mm_data[k][0]
 
-        baseline_result = baseline_processor.apply(
-            prompt,
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-        cached_result = cached_processor.apply(
-            prompt,
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-
-        assert _drop_mm_kwargs_keys(
-            baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                cached_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
-
-        baseline_tokenized_result = baseline_processor.apply(
-            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-
-        assert _drop_mm_kwargs_keys(
-            baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                baseline_tokenized_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
-
-        cached_tokenized_result = cached_processor.apply(
-            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-
-        assert _drop_mm_kwargs_keys(
-            cached_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                cached_tokenized_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
+        if isinstance(tokenizer, MistralTokenizer):
+            _test_processing_correctness_mistral(
+                model_config,
+                tokenizer,
+                prompt,
+                mm_data,
+                baseline_processor,
+                cached_processor,
+                batch_idx,
+                ignore_mm_keys=ignore_mm_keys,
+            )
+        else:
+            _test_processing_correctness_hf(
+                model_config,
+                tokenizer,
+                prompt,
+                mm_data,
+                baseline_processor,
+                cached_processor,
+                batch_idx,
+                ignore_mm_keys=ignore_mm_keys,
+            )
+
+
+def _test_processing_correctness_hf(
+    model_config: ModelConfig,
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prompt: str,
+    mm_data: MultiModalDataDict,
+    baseline_processor: BaseMultiModalProcessor,
+    cached_processor: BaseMultiModalProcessor,
+    batch_idx: int,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
+    if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
+        # For some multimodal models, tokenizer will always add bos_token
+        # at the beginning of prompt by default, causing hf_processor outputs
+        # incorrect token ids. So we need use `add_special_tokens=False` here
+        # to leave bos_token to be added by the processor.
+        token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
+    else:
+        token_prompt = tokenizer.encode(prompt)
+
+    baseline_result = baseline_processor.apply(
+        prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+    cached_result = cached_processor.apply(
+        prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        baseline_result,
+        cached_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+
+    baseline_tokenized_result = baseline_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        baseline_result,
+        baseline_tokenized_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+
+    cached_tokenized_result = cached_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        cached_result,
+        cached_tokenized_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+
+
+def _test_processing_correctness_mistral(
+    model_config: ModelConfig,
+    tokenizer: MistralTokenizer,
+    prompt: str,
+    mm_data: MultiModalDataDict,
+    baseline_processor: BaseMultiModalProcessor,
+    cached_processor: BaseMultiModalProcessor,
+    batch_idx: int,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
+    images = mm_data.get("image", [])
+    if not isinstance(images, list):
+        images = [images]
+
+    request = ChatCompletionRequest(messages=[
+        UserMessage(content=[
+            TextChunk(text=prompt),
+            *(ImageChunk(image=image) for image in images),
+        ]),
+    ])
+    res = tokenizer.mistral.encode_chat_completion(request)
+    token_prompt = res.tokens
+
+    # Mistral chat outputs tokens directly, rather than text prompts
+    baseline_tokenized_result = baseline_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+    cached_tokenized_result = cached_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        baseline_tokenized_result,
+        cached_tokenized_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
 
 
 # yapf: disable
@@ -173,6 +259,7 @@ def _test_processing_correctness(
     "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
     "meta-llama/Llama-3.2-11B-Vision-Instruct",
     "TIGER-Lab/Mantis-8B-siglip-llama3",
+    "mistralai/Pixtral-12B-2409",
    "mistral-community/pixtral-12b",
     "openbmb/MiniCPM-o-2_6",
     "openbmb/MiniCPM-V-2_6",
@@ -241,8 +328,19 @@ def test_processing_correctness_phi3v(
     )
 
 
-def _drop_mm_kwargs_keys(result: dict,
-                         ignore_mm_keys: Optional[list[str]] = None) -> dict:
+def _inputs_equal(
+    a: MultiModalInputs,
+    b: MultiModalInputs,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
+    return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
+        b, ignore_mm_keys)
+
+
+def _drop_mm_kwargs_keys(
+    result: MultiModalInputs,
+    ignore_mm_keys: Optional[list[str]] = None,
+) -> MultiModalInputs:
     """Drop specified keys from result['mm_kwargs'].
 
     This is mainly to avoid doing exact match of audio_features in ultravox.
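
For reference, a minimal sketch of the token-prompt construction used by the new Mistral test path above: mistral_common encodes the chat (text plus images) straight to token ids, which are then passed to the multi-modal processor as a token prompt instead of a text prompt. The helper name build_pixtral_token_prompt is illustrative only.

# Sketch distilled from _test_processing_correctness_mistral above; the helper
# name is hypothetical. `tokenizer` is a vLLM MistralTokenizer, whose
# `.mistral` attribute exposes the underlying mistral_common tokenizer.
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
                                                        UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image


def build_pixtral_token_prompt(tokenizer, prompt: str,
                               images: list[Image.Image]) -> list[int]:
    request = ChatCompletionRequest(messages=[
        UserMessage(content=[
            TextChunk(text=prompt),
            *(ImageChunk(image=image) for image in images),
        ]),
    ])
    # Mistral chat encoding yields token ids directly, not a text prompt
    return tokenizer.mistral.encode_chat_completion(request).tokens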

0 commit comments
