11# SPDX-License-Identifier: Apache-2.0
22
3- import copy
43from functools import partial
54from typing import Optional , Union
65
@@ -29,7 +28,7 @@ def _test_processing_correctness(
2928 hit_rate : float ,
3029 num_batches : int ,
3130 simplify_rate : float ,
32- ignore_mm_keys : Optional [list [str ]] = None ,
31+ ignore_mm_keys : Optional [set [str ]] = None ,
3332):
3433 model_info = HF_EXAMPLE_MODELS .find_hf_info (model_id )
3534 model_info .check_available_online (on_fail = "skip" )
@@ -145,7 +144,7 @@ def _test_processing_correctness_hf(
145144 baseline_processor : BaseMultiModalProcessor ,
146145 cached_processor : BaseMultiModalProcessor ,
147146 batch_idx : int ,
148- ignore_mm_keys : Optional [list [str ]] = None ,
147+ ignore_mm_keys : Optional [set [str ]] = None ,
149148):
150149 if model_config .hf_config .model_type in ("mllama" , "whisper" , "ultravox" ):
151150 # For some multimodal models, tokenizer will always add bos_token
@@ -167,35 +166,38 @@ def _test_processing_correctness_hf(
167166 hf_processor_mm_kwargs = {},
168167 )
169168
170- assert _inputs_equal (
169+ _assert_inputs_equal (
171170 baseline_result ,
172171 cached_result ,
173- ignore_mm_keys ,
174- ), f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )"
172+ ignore_mm_keys = ignore_mm_keys ,
173+ msg = f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )" ,
174+ )
175175
176176 baseline_tokenized_result = baseline_processor .apply (
177177 token_prompt ,
178178 mm_data = mm_data ,
179179 hf_processor_mm_kwargs = {},
180180 )
181181
182- assert _inputs_equal (
182+ _assert_inputs_equal (
183183 baseline_result ,
184184 baseline_tokenized_result ,
185- ignore_mm_keys ,
186- ), f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )"
185+ ignore_mm_keys = ignore_mm_keys ,
186+ msg = f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )" ,
187+ )
187188
188189 cached_tokenized_result = cached_processor .apply (
189190 token_prompt ,
190191 mm_data = mm_data ,
191192 hf_processor_mm_kwargs = {},
192193 )
193194
194- assert _inputs_equal (
195+ _assert_inputs_equal (
195196 cached_result ,
196197 cached_tokenized_result ,
197- ignore_mm_keys ,
198- ), f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )"
198+ ignore_mm_keys = ignore_mm_keys ,
199+ msg = f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )" ,
200+ )
199201
200202
201203def _test_processing_correctness_mistral (
@@ -206,7 +208,7 @@ def _test_processing_correctness_mistral(
206208 baseline_processor : BaseMultiModalProcessor ,
207209 cached_processor : BaseMultiModalProcessor ,
208210 batch_idx : int ,
209- ignore_mm_keys : Optional [list [str ]] = None ,
211+ ignore_mm_keys : Optional [set [str ]] = None ,
210212):
211213 images = mm_data .get ("image" , [])
212214 if not isinstance (images , list ):
@@ -233,11 +235,12 @@ def _test_processing_correctness_mistral(
233235 hf_processor_mm_kwargs = {},
234236 )
235237
236- assert _inputs_equal (
238+ _assert_inputs_equal (
237239 baseline_tokenized_result ,
238240 cached_tokenized_result ,
239- ignore_mm_keys ,
240- ), f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )"
241+ ignore_mm_keys = ignore_mm_keys ,
242+ msg = f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )" ,
243+ )
241244
242245
243246# yapf: disable
@@ -261,6 +264,7 @@ def _test_processing_correctness_mistral(
261264 "TIGER-Lab/Mantis-8B-siglip-llama3" ,
262265 "mistralai/Pixtral-12B-2409" ,
263266 "mistral-community/pixtral-12b" ,
267+ "openbmb/MiniCPM-Llama3-V-2_5" ,
264268 "openbmb/MiniCPM-o-2_6" ,
265269 "openbmb/MiniCPM-V-2_6" ,
266270 "allenai/Molmo-7B-D-0924" ,
@@ -290,7 +294,7 @@ def test_processing_correctness(
290294 # In Ultravox, the audio_features can be different depending on padding
291295 # The slight difference should not be a problem though, since
292296 # attention_mask lets us ignore the difference.
293- ignore_mm_keys = [ ' audio_features' ]
297+ ignore_mm_keys = { " audio_features" }
294298
295299 _test_processing_correctness (
296300 model_id ,
@@ -328,38 +332,26 @@ def test_processing_correctness_phi3v(
328332 )
329333
330334
331- def _inputs_equal (
335+ def _assert_inputs_equal (
332336 a : MultiModalInputs ,
333337 b : MultiModalInputs ,
334- ignore_mm_keys : Optional [list [str ]] = None ,
338+ * ,
339+ ignore_mm_keys : Optional [set [str ]] = None ,
340+ msg : str = "" ,
335341):
336- return _drop_mm_kwargs_keys (a , ignore_mm_keys ) == _drop_mm_kwargs_keys (
337- b , ignore_mm_keys )
338-
339-
340- def _drop_mm_kwargs_keys (
341- result : MultiModalInputs ,
342- ignore_mm_keys : Optional [list [str ]] = None ,
343- ) -> MultiModalInputs :
344- """Drop specified keys from result['mm_kwargs'].
345-
346- This is mainly to avoid doing exact match of audio_features in ultravox.
347-
348- Args:
349- result: Result to drop keys from
350- ignore_mm_keys: List of keys to ignore, e.g. ['audio_features']
351- """
352- if not ignore_mm_keys :
353- return result
354-
355- if 'mm_kwargs' in result :
356- result = copy .deepcopy (result )
357- mm_kwargs = result ['mm_kwargs' ]
358- for key in ignore_mm_keys :
359- mm_kwargs .pop (key , None )
360- for items in mm_kwargs ._items_by_modality .values ():
361- for item in items :
362- for key in ignore_mm_keys :
363- item .pop (key , None )
364-
365- return result
342+ if ignore_mm_keys is None :
343+ ignore_mm_keys = set ()
344+
345+ if msg is None :
346+ assert "mm_kwargs" in a and "mm_kwargs" in b
347+ else :
348+ assert "mm_kwargs" in a and "mm_kwargs" in b , msg
349+
350+ for key in ignore_mm_keys :
351+ a ["mm_kwargs" ].pop (key , None )
352+ b ["mm_kwargs" ].pop (key , None )
353+
354+ if msg is None :
355+ assert a == b
356+ else :
357+ assert a == b , msg
0 commit comments