11from typing import cast
22
33import pytest
4- from transformers import BatchFeature
54
6- from vllm .multimodal .processing import (PromptReplacement , _PlaceholderInfo ,
7- find_text_matches , find_token_matches ,
8- iter_placeholders , iter_token_matches ,
5+ from vllm .multimodal .processing import (MultiModalDataItems , PromptReplacement ,
6+ _PlaceholderInfo , find_text_matches ,
7+ find_token_matches , iter_placeholders ,
8+ iter_token_matches ,
99 replace_text_matches ,
1010 replace_token_matches )
1111from vllm .transformers_utils .tokenizer import AnyTokenizer
1616@pytest .mark .parametrize (
1717 ("token_ids" , "match_ids" , "expected" ),
1818 [
19- ([], [], [{ "start_idx" : 0 , "end_idx" : 0 } ]),
19+ ([], [], []),
2020 ([], [32000 ], []),
2121 (
2222 [32000 , 32000 , 32000 ],
@@ -83,7 +83,7 @@ def test_iter_token_matches(token_ids, match_ids, expected):
8383 "pattern_2" : [32000 ],
8484 },
8585 {
86- "pattern_1" : [{ "start_idx" : 0 , "end_idx" : 0 } ],
86+ "pattern_1" : [],
8787 "pattern_2" : [],
8888 }
8989 ),
@@ -136,7 +136,7 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
136136 mock_tokenizer = cast (AnyTokenizer , object ())
137137
138138 prompt_repls = [
139- PromptReplacement (target , [], 0 ).bind (key , mock_tokenizer )
139+ PromptReplacement (key , target , []).bind (mock_tokenizer )
140140 for key , target in target_by_key .items ()
141141 ]
142142 result = find_token_matches (prompt , prompt_repls )
@@ -243,7 +243,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
243243 mock_tokenizer = cast (AnyTokenizer , object ())
244244
245245 prompt_repls = [
246- PromptReplacement (target , [], 0 ).bind (key , mock_tokenizer )
246+ PromptReplacement (key , target , []).bind (mock_tokenizer )
247247 for key , target in target_by_key .items ()
248248 ]
249249 result = find_text_matches (prompt , prompt_repls )
@@ -276,12 +276,12 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
276276 "pattern_3" : "!" ,
277277 },
278278 {
279- # Test whether target is confused with repl_unit
280- "pattern_1" : ( "<image><image>" , 1 ) ,
281- # Test empty repl_unit
282- "pattern_2" : ( "" , 1 ) ,
283- # Test multiple repl_count
284- "pattern_3" : ( "?" , 2 ) ,
279+ # Test whether target is confused with replacement
280+ "pattern_1" : "<image><image>" ,
281+ # Test empty replacement
282+ "pattern_2" : "" ,
283+ # Test dynamic replacement (beyond the form of `unit * count`)
284+ "pattern_3" : "?!?" ,
285285 },
286286 ),
287287 ]
@@ -290,8 +290,8 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
290290 ("mm_count" , "expected" ),
291291 [
292292 (0 , "Image:<image>Image:<image><image>!" ),
293- (1 , "<image><image>Image:<image><image>??" ),
294- (2 , "<image><image><image><image><image>??" ),
293+ (1 , "<image><image>Image:<image><image>?!?" ),
294+ (2 , "<image><image><image><image><image>?!?" ),
295295 ]
296296)
297297# yapf: enable
@@ -306,17 +306,16 @@ def test_find_replace_text(
306306 mock_tokenizer = cast (AnyTokenizer , object ())
307307
308308 prompt_repls = [
309- PromptReplacement (target , * repl_by_key [key ]).bind (key , mock_tokenizer )
309+ PromptReplacement (key , target , repl_by_key [key ]).bind (mock_tokenizer )
310310 for key , target in target_by_key .items ()
311311 ]
312312 matches = find_text_matches (prompt , prompt_repls )
313313
314314 result = replace_text_matches (
315315 prompt ,
316316 matches ,
317- {key : list (range (mm_count ))
318- for key in repl_by_key },
319- BatchFeature (),
317+ MultiModalDataItems ({key : [None ] * mm_count
318+ for key in repl_by_key }),
320319 )
321320
322321 # Only displayed on error
@@ -343,12 +342,12 @@ def test_find_replace_text(
343342 "pattern_3" : [918 ],
344343 },
345344 {
346- # Test whether target is confused with repl_unit
347- "pattern_1" : ( [32000 , 32000 ], 1 ) ,
348- # Test empty repl_unit
349- "pattern_2" : ([], 1 ) ,
350- # Test multiple repl_count
351- "pattern_3" : ( [1550 ], 2 ) ,
345+ # Test whether target is confused with replacement
346+ "pattern_1" : [32000 , 32000 ],
347+ # Test empty replacement
348+ "pattern_2" : [] ,
349+ # Test dynamic replacement (beyond the form of `unit * count`)
350+ "pattern_3" : [1550 , 918 , 1550 ] ,
352351 },
353352 ),
354353 ]
@@ -357,8 +356,8 @@ def test_find_replace_text(
357356 ("mm_count" , "expected" ),
358357 [
359358 (0 , [1 , 9833 , 28747 , 32000 , 9833 , 28747 , 32000 , 32000 , 918 ]),
360- (1 , [1 , 32000 , 32000 , 9833 , 28747 , 32000 , 32000 , 1550 , 1550 ]),
361- (2 , [1 , 32000 , 32000 , 32000 , 32000 , 32000 , 1550 , 1550 ]),
359+ (1 , [1 , 32000 , 32000 , 9833 , 28747 , 32000 , 32000 , 1550 , 918 , 1550 ]),
360+ (2 , [1 , 32000 , 32000 , 32000 , 32000 , 32000 , 1550 , 918 , 1550 ]),
362361 ]
363362)
364363# yapf: enable
@@ -373,17 +372,16 @@ def test_find_replace_tokens(
373372 mock_tokenizer = cast (AnyTokenizer , object ())
374373
375374 prompt_repls = [
376- PromptReplacement (target , * repl_by_key [key ]).bind (key , mock_tokenizer )
375+ PromptReplacement (key , target , repl_by_key [key ]).bind (mock_tokenizer )
377376 for key , target in target_by_key .items ()
378377 ]
379378 matches = find_token_matches (prompt , prompt_repls )
380379
381380 result = replace_token_matches (
382381 prompt ,
383382 matches ,
384- {key : list (range (mm_count ))
385- for key in repl_by_key },
386- BatchFeature (),
383+ MultiModalDataItems ({key : [None ] * mm_count
384+ for key in repl_by_key }),
387385 )
388386
389387 # Only displayed on error
@@ -399,9 +397,9 @@ def test_find_replace_tokens(
399397 "repl_by_key" ,
400398 [
401399 {
402- "pattern_1" : ( [32000 , 32000 ], 1 ) ,
403- "pattern_2" : ([], 1 ) ,
404- "pattern_3" : ( [1550 ], 2 ) ,
400+ "pattern_1" : [32000 , 32000 ],
401+ "pattern_2" : [] ,
402+ "pattern_3" : [1550 , 918 , 1550 ] ,
405403 },
406404 ],
407405)
@@ -414,48 +412,47 @@ def test_find_replace_tokens(
414412 _PlaceholderInfo (
415413 modality = "pattern_1" ,
416414 start_idx = 6 ,
417- unit = [32000 , 32000 ],
418- unit_count = 1 ,
415+ replacement = [32000 , 32000 ],
419416 ),
420417 ],
421418 ),
422419 (
423- [1 , 32000 , 32000 , 9833 , 28747 , 32000 , 32000 , 1550 , 1550 ],
420+ [1 , 32000 , 32000 , 9833 , 28747 , 32000 , 32000 , 1550 , 918 , 1550 ],
424421 [
425422 _PlaceholderInfo (
426423 modality = "pattern_1" ,
427424 start_idx = 1 ,
428- unit = [32000 , 32000 ],
429- unit_count = 1 ,
425+ replacement = [32000 , 32000 ],
430426 ),
431427 _PlaceholderInfo (
432428 modality = "pattern_1" ,
433429 start_idx = 5 ,
434- unit = [32000 , 32000 ],
435- unit_count = 1 ,
430+ replacement = [32000 , 32000 ],
436431 ),
437432 _PlaceholderInfo (
438433 modality = "pattern_3" ,
439434 start_idx = 7 ,
440- unit = [1550 ],
441- unit_count = 2 ,
435+ replacement = [1550 , 918 , 1550 ],
442436 ),
443437 ],
444438 ),
445439 (
446- [1 , 32000 , 32000 , 32000 , 32000 , 32000 , 1550 , 1550 ],
440+ [1 , 32000 , 32000 , 32000 , 32000 , 32000 , 1550 , 918 , 1550 ],
447441 [
448442 _PlaceholderInfo (
449443 modality = "pattern_1" ,
450444 start_idx = 1 ,
451- unit = [32000 , 32000 ],
452- unit_count = 2 ,
445+ replacement = [32000 , 32000 ],
446+ ),
447+ _PlaceholderInfo (
448+ modality = "pattern_1" ,
449+ start_idx = 3 ,
450+ replacement = [32000 , 32000 ],
453451 ),
454452 _PlaceholderInfo (
455453 modality = "pattern_3" ,
456454 start_idx = 6 ,
457- unit = [1550 ],
458- unit_count = 2 ,
455+ replacement = [1550 , 918 , 1550 ],
459456 ),
460457 ],
461458 ),
@@ -470,11 +467,17 @@ def test_iter_placeholders(
470467 mock_tokenizer = cast (AnyTokenizer , object ())
471468
472469 prompt_repls = [
473- PromptReplacement ([], * repl ).bind (key , mock_tokenizer )
470+ PromptReplacement (key , [], repl ).bind (mock_tokenizer )
474471 for key , repl in repl_by_key .items ()
475472 ]
476473
477- result = list (iter_placeholders (prompt_repls , prompt ))
474+ result = list (
475+ iter_placeholders (
476+ prompt_repls ,
477+ prompt ,
478+ # Effectively match all occurrences in the prompt
479+ MultiModalDataItems ({key : [None ] * 3 for key in repl_by_key }),
480+ ))
478481
479482 # Only displayed on error
480483 print ("result:" , result )
0 commit comments