                                                       UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
- from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
- from vllm.transformers_utils.tokenizer import (MistralTokenizer,
-                                                cached_tokenizer_from_config)
+ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
+                                                cached_tokenizer_from_config,
+                                                encode_tokens)

from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS
@@ -28,7 +28,6 @@ def _test_processing_correctness(
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
-     ignore_mm_keys: Optional[set[str]] = None,
):
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_available_online(on_fail="skip")
@@ -99,10 +98,23 @@ def _test_processing_correctness(
        }

        mm_counts = {k: len(vs) for k, vs in mm_data.items()}
-         prompt = dummy_inputs.get_dummy_processor_inputs(
-             model_config.max_model_len,
-             mm_counts,
-         ).prompt_text
+ 
+         # Mistral chat outputs tokens directly, rather than text prompts
+         if isinstance(tokenizer, MistralTokenizer):
+             images = mm_data.get("image", [])
+             request = ChatCompletionRequest(messages=[
+                 UserMessage(content=[
+                     TextChunk(text=""),
+                     *(ImageChunk(image=image) for image in images),
+                 ]),
+             ])
+             res = tokenizer.mistral.encode_chat_completion(request)
+             prompt = res.tokens
+         else:
+             prompt = dummy_inputs.get_dummy_processor_inputs(
+                 model_config.max_model_len,
+                 mm_counts,
+             ).prompt

        # Drop unnecessary keys and test single -> multi conversion
        if rng.rand() < simplify_rate:
@@ -112,124 +124,66 @@ def _test_processing_correctness(
                elif len(mm_data[k]) == 1:
                    mm_data[k] = mm_data[k][0]

-         if isinstance(tokenizer, MistralTokenizer):
-             _test_processing_correctness_mistral(
-                 model_config,
-                 tokenizer,
-                 prompt,
-                 mm_data,
-                 baseline_processor,
-                 cached_processor,
-                 batch_idx,
-                 ignore_mm_keys=ignore_mm_keys,
-             )
-         else:
-             _test_processing_correctness_hf(
-                 model_config,
-                 tokenizer,
-                 prompt,
-                 mm_data,
-                 baseline_processor,
-                 cached_processor,
-                 batch_idx,
-                 ignore_mm_keys=ignore_mm_keys,
-             )
- 
- 
- def _test_processing_correctness_hf(
+         _test_processing_correctness_one(
+             model_config,
+             tokenizer,
+             prompt,
+             mm_data,
+             baseline_processor,
+             cached_processor,
+             batch_idx,
+         )
+ 
+ 
+ # For some multimodal models, the tokenizer always adds bos_token at the
+ # beginning of the prompt by default, which makes hf_processor output
+ # incorrect token ids. So we need to use `add_special_tokens=False` here
+ # to leave the bos_token to be added by the processor.
+ _ADD_SPECIAL_TOKENS_OVERRIDES = {
+     "mllama": False,
+     "ovis": False,
+     "ultravox": False,
+     "whisper": False,
+ }
+ 
+ _IGNORE_MM_KEYS = {
+     # In Ultravox, the audio_features can be different depending on padding.
+     # The slight difference should not be a problem though, since
+     # attention_mask lets us ignore the difference.
+     "ultravox": {"audio_features"},
+ }
+ 
+ 
+ def _test_processing_correctness_one(
    model_config: ModelConfig,
-     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-     prompt: str,
+     tokenizer: AnyTokenizer,
+     prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
    baseline_processor: BaseMultiModalProcessor,
    cached_processor: BaseMultiModalProcessor,
    batch_idx: int,
-     ignore_mm_keys: Optional[set[str]] = None,
):
-     if model_config.hf_config.model_type in ("mllama", "ovis", "ultravox",
-                                              "whisper"):
-         # For some multimodal models, tokenizer will always add bos_token
-         # at the beginning of prompt by default, causing hf_processor outputs
-         # incorrect token ids. So we need use `add_special_tokens=False` here
-         # to leave bos_token to be added by the processor.
-         token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
+     model_type = model_config.hf_config.model_type
+     ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())
+ 
+     if isinstance(prompt, str):
+         text_prompt = prompt
+         token_prompt = encode_tokens(
+             tokenizer,
+             prompt,
+             add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type),
+         )
    else:
-         token_prompt = tokenizer.encode(prompt)
- 
-     baseline_result = baseline_processor.apply(
-         prompt,
-         mm_data=mm_data,
-         hf_processor_mm_kwargs={},
-     )
-     cached_result = cached_processor.apply(
-         prompt,
-         mm_data=mm_data,
-         hf_processor_mm_kwargs={},
-     )
- 
-     _assert_inputs_equal(
-         baseline_result,
-         cached_result,
-         ignore_mm_keys=ignore_mm_keys,
-         msg=f"Failed ({batch_idx=} {prompt=} {mm_data=})",
-     )
+         # Mistral does not support decode_tokens with skip_special_tokens=False
+         text_prompt = None
+         token_prompt = prompt

    baseline_tokenized_result = baseline_processor.apply(
        token_prompt,
        mm_data=mm_data,
        hf_processor_mm_kwargs={},
    )

-     _assert_inputs_equal(
-         baseline_result,
-         baseline_tokenized_result,
-         ignore_mm_keys=ignore_mm_keys,
-         msg=f"Failed ({batch_idx=} {prompt=} {mm_data=})",
-     )
- 
-     cached_tokenized_result = cached_processor.apply(
-         token_prompt,
-         mm_data=mm_data,
-         hf_processor_mm_kwargs={},
-     )
- 
-     _assert_inputs_equal(
-         cached_result,
-         cached_tokenized_result,
-         ignore_mm_keys=ignore_mm_keys,
-         msg=f"Failed ({batch_idx=} {prompt=} {mm_data=})",
-     )
- 
- 
- def _test_processing_correctness_mistral(
-     model_config: ModelConfig,
-     tokenizer: MistralTokenizer,
-     prompt: str,
-     mm_data: MultiModalDataDict,
-     baseline_processor: BaseMultiModalProcessor,
-     cached_processor: BaseMultiModalProcessor,
-     batch_idx: int,
-     ignore_mm_keys: Optional[set[str]] = None,
- ):
-     images = mm_data.get("image", [])
-     if not isinstance(images, list):
-         images = [images]
- 
-     request = ChatCompletionRequest(messages=[
-         UserMessage(content=[
-             TextChunk(text=prompt),
-             *(ImageChunk(image=image) for image in images),
-         ]),
-     ])
-     res = tokenizer.mistral.encode_chat_completion(request)
-     token_prompt = res.tokens
- 
-     # Mistral chat outputs tokens directly, rather than text prompts
-     baseline_tokenized_result = baseline_processor.apply(
-         token_prompt,
-         mm_data=mm_data,
-         hf_processor_mm_kwargs={},
-     )
    cached_tokenized_result = cached_processor.apply(
        token_prompt,
        mm_data=mm_data,
@@ -240,9 +194,44 @@ def _test_processing_correctness_mistral(
        baseline_tokenized_result,
        cached_tokenized_result,
        ignore_mm_keys=ignore_mm_keys,
-         msg=f"Failed ({batch_idx=} {prompt=} {mm_data=})",
+         msg=f"Failed ({batch_idx=} {token_prompt=} {mm_data=})",
    )

+     if text_prompt is not None:
+         baseline_text_result = baseline_processor.apply(
+             text_prompt,
+             mm_data=mm_data,
+             hf_processor_mm_kwargs={},
+         )
+         cached_text_result = cached_processor.apply(
+             text_prompt,
+             mm_data=mm_data,
+             hf_processor_mm_kwargs={},
+         )
+ 
+         _assert_inputs_equal(
+             baseline_text_result,
+             cached_text_result,
+             ignore_mm_keys=ignore_mm_keys,
+             msg=f"Failed ({batch_idx=} {text_prompt=} {mm_data=})",
+         )
+ 
+         _assert_inputs_equal(
+             baseline_text_result,
+             baseline_tokenized_result,
+             ignore_mm_keys=ignore_mm_keys,
+             msg=f"Failed ({batch_idx=} {text_prompt=} "
+             f"{token_prompt=} {mm_data=})",
+         )
+ 
+         _assert_inputs_equal(
+             cached_text_result,
+             cached_tokenized_result,
+             ignore_mm_keys=ignore_mm_keys,
+             msg=f"Failed ({batch_idx=} {text_prompt=} "
+             f"{token_prompt=} {mm_data=})",
+         )
+ 

# yapf: disable
@pytest.mark.parametrize("model_id", [
@@ -281,6 +270,7 @@ def _test_processing_correctness_mistral(
    "AIDC-AI/Ovis2-1B",
    "google/paligemma-3b-mix-224",
    "google/paligemma2-3b-ft-docci-448",
+     "microsoft/Phi-3.5-vision-instruct",
    "microsoft/Phi-4-multimodal-instruct",
    "mistralai/Pixtral-12B-2409",
    "mistral-community/pixtral-12b",
@@ -303,41 +293,6 @@ def test_processing_correctness(
    num_batches: int,
    simplify_rate: float,
):
-     ignore_mm_keys = None
-     if 'ultravox' in model_id:
-         # In Ultravox, the audio_features can be different depending on padding
-         # The slight difference should not be a problem though, since
-         # attention_mask lets us ignore the difference.
-         ignore_mm_keys = {"audio_features"}
- 
-     _test_processing_correctness(
-         model_id,
-         hit_rate=hit_rate,
-         num_batches=num_batches,
-         simplify_rate=simplify_rate,
-         ignore_mm_keys=ignore_mm_keys,
-     )
- 
- 
- # yapf: disable
- @pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
- @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
- @pytest.mark.parametrize("num_batches", [32])
- @pytest.mark.parametrize("simplify_rate", [1.0])
- # yapf: enable
- def test_processing_correctness_phi3v(
-     model_id: str,
-     hit_rate: float,
-     num_batches: int,
-     simplify_rate: float,
- ):
-     # HACK - this is an attempted workaround for the following bug
-     # https://github.com/huggingface/transformers/issues/34307
-     from transformers import AutoImageProcessor  # noqa: F401
-     from transformers import AutoProcessor  # noqa: F401
- 
-     AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
- 
    _test_processing_correctness(
        model_id,
        hit_rate=hit_rate,
@@ -356,16 +311,10 @@ def _assert_inputs_equal(
    if ignore_mm_keys is None:
        ignore_mm_keys = set()

-     if msg is None:
-         assert "mm_kwargs" in a and "mm_kwargs" in b
-     else:
-         assert "mm_kwargs" in a and "mm_kwargs" in b, msg
+     assert "mm_kwargs" in a and "mm_kwargs" in b, msg

    for key in ignore_mm_keys:
        a["mm_kwargs"].pop(key, None)
        b["mm_kwargs"].pop(key, None)

-     if msg is None:
-         assert a == b
-     else:
-         assert a == b, msg
+     assert a == b, msg
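
The `_ADD_SPECIAL_TOKENS_OVERRIDES` table introduced above exists because some HF tokenizers prepend `bos_token` during `encode`, and the matching HF processor then prepends it again, duplicating the token. A minimal sketch of the effect, using the public `hf-internal-testing/llama-tokenizer` repo purely as an example of a BOS-adding tokenizer (it is not one of the models under test):

```python
from transformers import AutoTokenizer

# Any tokenizer that prepends BOS by default illustrates the point.
tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

text = "a photo of a cat"
with_specials = tok.encode(text)                               # [BOS, ...]
without_specials = tok.encode(text, add_special_tokens=False)  # [...]

assert with_specials[0] == tok.bos_token_id
assert with_specials[1:] == without_specials
# Feeding the already-BOS-prefixed ids to a processor that adds BOS itself
# would produce [BOS, BOS, ...], which is why the overrides disable special
# tokens for mllama/ovis/ultravox/whisper.
```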
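For the Mistral branch, the tokenizer produces prompt token ids directly from a chat request instead of a text prompt, which is why `text_prompt` is `None` in that case. A rough standalone sketch of that path, assuming `mistral_common` is installed and that vLLM's `MistralTokenizer.from_pretrained` accepts a HF repo id; treat the constructor and model name as illustrative rather than what the test itself does:

```python
from PIL import Image
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
                                                        UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import MistralTokenizer

# Illustrative: load a Mistral-format tokenizer for a multimodal model.
tokenizer = MistralTokenizer.from_pretrained("mistralai/Pixtral-12B-2409")

request = ChatCompletionRequest(messages=[
    UserMessage(content=[
        TextChunk(text=""),
        ImageChunk(image=Image.new("RGB", (64, 64))),
    ]),
])
encoded = tokenizer.mistral.encode_chat_completion(request)
token_prompt = encoded.tokens  # list[int], passed to processor.apply() as-is
```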