 
 import copy
 from functools import partial
-from typing import Optional
+from typing import Optional, Union
 
 import numpy as np
 import pytest
+from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
+                                                        UserMessage)
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from PIL import Image
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from vllm.config import ModelConfig
 from vllm.inputs import InputProcessingContext
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.processing import ProcessingCache
-from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
+from vllm.multimodal.inputs import MultiModalInputs
+from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
+from vllm.transformers_utils.tokenizer import (MistralTokenizer,
+                                               cached_tokenizer_from_config)
 
 from ....multimodal.utils import random_audio, random_image, random_video
 from ...registry import HF_EXAMPLE_MODELS
@@ -85,14 +91,6 @@ def _test_processing_correctness(
         partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
     }
 
-    tokenizer_encode_kwargs = {}
-    if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
-        # For some multimodal models, tokenizer will always add bos_token
-        # at the beginning of prompt by default, causing hf_processor outputs
-        # incorrect token ids. So we need use `add_special_tokens=False` here
-        # to leave bos_token to be added by the processor.
-        tokenizer_encode_kwargs = {"add_special_tokens": False}
-
     for batch_idx in range(num_batches):
         mm_data = {
             k:
@@ -115,43 +113,131 @@ def _test_processing_correctness(
                 elif len(mm_data[k]) == 1:
                     mm_data[k] = mm_data[k][0]
 
-        baseline_result = baseline_processor.apply(
-            prompt,
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-        cached_result = cached_processor.apply(
-            prompt,
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-
-        assert _drop_mm_kwargs_keys(
-            baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                cached_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
-
-        baseline_tokenized_result = baseline_processor.apply(
-            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-
-        assert _drop_mm_kwargs_keys(
-            baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                baseline_tokenized_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
-
-        cached_tokenized_result = cached_processor.apply(
-            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-
-        assert _drop_mm_kwargs_keys(
-            cached_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                cached_tokenized_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
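+        # Mistral chat templates produce token ids directly rather than a
+        # text prompt, so the two tokenizer families are checked separately.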
+        if isinstance(tokenizer, MistralTokenizer):
+            _test_processing_correctness_mistral(
+                model_config,
+                tokenizer,
+                prompt,
+                mm_data,
+                baseline_processor,
+                cached_processor,
+                batch_idx,
+                ignore_mm_keys=ignore_mm_keys,
+            )
+        else:
+            _test_processing_correctness_hf(
+                model_config,
+                tokenizer,
+                prompt,
+                mm_data,
+                baseline_processor,
+                cached_processor,
+                batch_idx,
+                ignore_mm_keys=ignore_mm_keys,
+            )
+
+
+def _test_processing_correctness_hf(
+    model_config: ModelConfig,
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prompt: str,
+    mm_data: MultiModalDataDict,
+    baseline_processor: BaseMultiModalProcessor,
+    cached_processor: BaseMultiModalProcessor,
+    batch_idx: int,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
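+    """Compare baseline and cached outputs for text and tokenized prompts."""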
+    if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
+        # For some multimodal models, the tokenizer always adds a bos_token
+        # at the beginning of the prompt by default, which makes hf_processor
+        # output incorrect token ids. So we use `add_special_tokens=False`
+        # here and leave the bos_token to be added by the processor.
+        token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
+    else:
+        token_prompt = tokenizer.encode(prompt)
+
+    baseline_result = baseline_processor.apply(
+        prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+    cached_result = cached_processor.apply(
+        prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        baseline_result,
+        cached_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+
+    baseline_tokenized_result = baseline_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        baseline_result,
+        baseline_tokenized_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+
+    cached_tokenized_result = cached_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        cached_result,
+        cached_tokenized_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+
+
+def _test_processing_correctness_mistral(
+    model_config: ModelConfig,
+    tokenizer: MistralTokenizer,
+    prompt: str,
+    mm_data: MultiModalDataDict,
+    baseline_processor: BaseMultiModalProcessor,
+    cached_processor: BaseMultiModalProcessor,
+    batch_idx: int,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
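+    """Compare baseline and cached outputs for a Mistral chat prompt."""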
+    images = mm_data.get("image", [])
+    if not isinstance(images, list):
+        images = [images]
+
+    request = ChatCompletionRequest(messages=[
+        UserMessage(content=[
+            TextChunk(text=prompt),
+            *(ImageChunk(image=image) for image in images),
+        ]),
+    ])
+    res = tokenizer.mistral.encode_chat_completion(request)
+    token_prompt = res.tokens
+
+    # Mistral chat outputs tokens directly, rather than text prompts
+    baseline_tokenized_result = baseline_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+    cached_tokenized_result = cached_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        baseline_tokenized_result,
+        cached_tokenized_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
 
 
 # yapf: disable
@@ -173,6 +259,7 @@ def _test_processing_correctness(
     "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
     "meta-llama/Llama-3.2-11B-Vision-Instruct",
     "TIGER-Lab/Mantis-8B-siglip-llama3",
+    "mistralai/Pixtral-12B-2409",
     "mistral-community/pixtral-12b",
     "openbmb/MiniCPM-o-2_6",
     "openbmb/MiniCPM-V-2_6",
@@ -241,8 +328,19 @@ def test_processing_correctness_phi3v(
     )
 
 
-def _drop_mm_kwargs_keys(result: dict,
-                         ignore_mm_keys: Optional[list[str]] = None) -> dict:
+def _inputs_equal(
+    a: MultiModalInputs,
+    b: MultiModalInputs,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
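+    """Return whether two processor outputs match, ignoring the given keys."""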
+    return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
+        b, ignore_mm_keys)
+
+
+def _drop_mm_kwargs_keys(
+    result: MultiModalInputs,
+    ignore_mm_keys: Optional[list[str]] = None,
+) -> MultiModalInputs:
     """Drop specified keys from result['mm_kwargs'].
 
     This is mainly to avoid doing exact match of audio_features in ultravox.