@@ -366,11 +366,67 @@ def process_video(video: Any) -> Mapping[str, Any]:
366366 f"Invalid video input { video } . Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501
367367 )
368368
+
+def gen_prompt_decode_to_target_len(
+    tokenizer: PreTrainedTokenizerBase,
+    token_sequence: list[int],
+    target_token_len: int,
+    max_retry: int = 10,
+    add_special_tokens: bool = False,
+    rng: Optional[np.random.Generator] = None,
+) -> tuple[str, list[int], int]:
378+ """
379+ Ensure decoded-then-encoded prompt length matches the target token length.
380+
381+ This function decodes an initial token sequence to text and re-encodes it
382+ , iteratively adjusting the token sequence length to match a target.
383+ This is necessary because some tokenizers do not guarantee a 1:1 mapping
384+ between consecutive tokens and the decoded-then-encoded sequence length.
385+ For example, for GPT2Tokenizer:
386+ [6880, 6881] -> ['Ġcalls', 'here'] ->
387+ [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
388+
389+ Returns a tuple of the final prompt string and the adjusted token sequence.
390+ """
+    remain_num_try = max_retry
+    token_mismatch = 0
+    while True:
+        prompt = tokenizer.decode(token_sequence)
+        token_sequence = tokenizer.encode(
+            prompt, add_special_tokens=add_special_tokens
+        )
+        if remain_num_try <= 0:
+            # Out of retries: record any residual mismatch and give up.
+            if len(token_sequence) != target_token_len:
+                token_mismatch = len(token_sequence) - target_token_len
+            break
+
+        if len(token_sequence) == target_token_len:
+            break
+        elif len(token_sequence) < target_token_len:
+            # Pad with random token ids, preferring the caller's Generator
+            # for reproducibility.
+            if rng is not None:
+                extra_tokens = rng.integers(
+                    0,
+                    tokenizer.vocab_size,
+                    size=target_token_len - len(token_sequence),
+                ).tolist()
+            else:
+                extra_tokens = np.random.randint(
+                    0,
+                    tokenizer.vocab_size,
+                    size=target_token_len - len(token_sequence),
+                ).tolist()
+            token_sequence.extend(extra_tokens)
+        elif len(token_sequence) > target_token_len:
+            token_sequence = token_sequence[:target_token_len]
+
+        remain_num_try -= 1
+
+    return prompt, token_sequence, token_mismatch
+
 # -----------------------------------------------------------------------------
 # Random Dataset Implementation (Synthetic Data)
 # -----------------------------------------------------------------------------

-
 class RandomDataset(BenchmarkDataset):
     """
     Synthetic text-only dataset for serving/throughput benchmarks.
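For context on the helper introduced above: a decode-then-encode round trip does
not always preserve token count, which is exactly the drift the retry loop
compensates for. A minimal standalone sketch, assuming `transformers` is
installed, reproducing the GPT-2 pair from the docstring:

    # Illustrative only: reproduces the docstring's GPT2Tokenizer example.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    ids = [6880, 6881]  # ['Ġcalls', 'here']
    text = tok.decode(ids)
    round_trip = tok.encode(text, add_special_tokens=False)
    print(round_trip)  # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'], 2 -> 3 tokens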
@@ -420,8 +476,9 @@ def sample(
         vocab_size = tokenizer.vocab_size

         requests = []
+        token_mismatch_total = 0
         for i in range(num_requests):
-            prompt, total_input_len = self.generate_token_sequence(
+            prompt, total_input_len, token_mismatch = self.generate_token_sequence(  # noqa: E501
                 tokenizer=tokenizer,
                 prefix_token_ids=prefix_token_ids,
                 prefix_len=prefix_len,
@@ -430,6 +487,7 @@ def sample(
                 offset=int(offsets[i]),
                 index=i,
             )
+            token_mismatch_total += token_mismatch
             requests.append(
                 SampleRequest(
                     prompt=prompt,
@@ -453,6 +511,18 @@ def sample(
                 )
             )
             requests = batch_requests
+
+        if token_mismatch_total != 0:
+            sign = "more" if token_mismatch_total > 0 else "fewer"
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. This is "
+                "expected due to the imperfect nature of the sampling "
+                "procedure.",
+                abs(token_mismatch_total),
+                sign,
+            )
+
         return requests

     def get_prefix(
@@ -530,7 +600,7 @@ def generate_token_sequence(
         input_len: int,
         offset: int,
         index: int,
-    ) -> tuple[str, int]:
+    ) -> tuple[str, int, int]:
         """
-        Returns (prompt, total_input_len).
+        Returns (prompt, total_input_len, token_mismatch).

@@ -549,15 +619,16 @@ def generate_token_sequence(
         token_sequence = prefix_token_ids + inner_seq

-        # Decode, then re-encode and truncate to preserve token count invariants
-        prompt = tokenizer.decode(token_sequence)
+        # Decode, then re-encode and adjust to preserve token count invariants
         total_input_len = prefix_len + int(input_len)
-
-        re_encoded_sequence = tokenizer.encode(
-            prompt, add_special_tokens=False)[:total_input_len]
-        prompt = tokenizer.decode(re_encoded_sequence)
-        total_input_len = len(re_encoded_sequence)
-
-        return prompt, total_input_len
+        prompt, adjusted_token_sequence, token_mismatch = gen_prompt_decode_to_target_len(  # noqa: E501
+            tokenizer=tokenizer,
+            token_sequence=token_sequence,
+            target_token_len=total_input_len,
+            add_special_tokens=False,
+            rng=self._rng,
+        )
+        total_input_len = len(adjusted_token_sequence)
+        return prompt, total_input_len, token_mismatch


# -----------------------------------------------------------------------------
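A hedged usage sketch of the refactored path above (names and the 128-token
budget are illustrative; assumes `transformers` plus a seeded numpy Generator,
mirroring the `rng=self._rng` call in `generate_token_sequence`):

    import numpy as np
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    rng = np.random.default_rng(0)

    # Draw 128 random token ids, then force the decode/encode round trip
    # to land on the same 128-token budget.
    seq = rng.integers(0, tok.vocab_size, size=128).tolist()
    prompt, adjusted, mismatch = gen_prompt_decode_to_target_len(
        tokenizer=tok,
        token_sequence=seq,
        target_token_len=128,
        rng=rng,
    )
    # Invariant: `mismatch` is exactly the length error left after retries,
    # so the adjusted length is always target + mismatch.
    assert len(adjusted) == 128 + mismatch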
@@ -873,8 +944,9 @@ def sample(
         vocab_size = tokenizer.vocab_size
         # Add synthetic multimodal items to each request
         mm_requests = []
+        token_mismatch_total = 0
         for i in range(num_requests):
-            prompt, total_input_len = self.generate_token_sequence(
+            prompt, total_input_len, token_mismatch = self.generate_token_sequence(  # noqa: E501
                 tokenizer=tokenizer,
                 prefix_token_ids=prefix_token_ids,
                 prefix_len=prefix_len,
@@ -883,6 +955,7 @@ def sample(
                 offset=int(offsets[i]),
                 index=i,
             )
+            token_mismatch_total += token_mismatch
             # Get multimodal item iterator for a given request
             mm_item_iterator = self.get_mm_item_iterator(
                 min_num_mm_items,
@@ -918,6 +991,18 @@ def sample(
                 request_id=request_id_prefix + str(i),
             )
             mm_requests.append(sample_request)
+
+        if token_mismatch_total != 0:
+            sign = "more" if token_mismatch_total > 0 else "fewer"
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. This is "
+                "expected due to the imperfect nature of the sampling "
+                "procedure.",
+                abs(token_mismatch_total),
+                sign,
+            )
+
         return mm_requests

# -----------------------------------------------------------------------------
@@ -2694,27 +2779,23 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]:
             # Generate random tokens
             tokens = np.random.randint(
                 0, vocab_size, size=target_length).tolist()
-            text = tokenizer.decode(tokens)
-            re_encoded = tokenizer.encode(text, add_special_tokens=False)
-
-            if len(re_encoded) == target_length:
-                return re_encoded
-            elif len(re_encoded) < target_length:
-                # Recursively generate additional consistent tokens
-                needed = target_length - len(re_encoded)
-                extra_tokens = _generate_exact_length_tokens(needed)
-                return re_encoded + extra_tokens
-            else:
-                # Truncate to target length
-                return re_encoded[:target_length]
+
+            _, adjusted_tokens, token_mismatch = gen_prompt_decode_to_target_len(  # noqa: E501
+                tokenizer=tokenizer,
+                token_sequence=tokens,
+                target_token_len=target_length,
+                add_special_tokens=False,
+            )
+            return adjusted_tokens, token_mismatch

         requests = []
+        token_mismatch_total = 0
         for _ in range(num_prefixes):
-            prefix_tokens = _generate_exact_length_tokens(prefix_len)
+            prefix_tokens, token_mismatch = _generate_exact_length_tokens(prefix_len)  # noqa: E501
+            token_mismatch_total += token_mismatch

             for _ in range(prompts_per_prefix):
-                suffix_tokens = _generate_exact_length_tokens(suffix_len)
-
+                suffix_tokens, token_mismatch = _generate_exact_length_tokens(suffix_len)  # noqa: E501
+                token_mismatch_total += token_mismatch
                 combined_tokens = prefix_tokens + suffix_tokens
                 prompt = tokenizer.decode(combined_tokens)
                 prompt_len = len(combined_tokens)
@@ -2726,6 +2807,16 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]:
                 )
             )

+        if token_mismatch_total != 0:
+            sign = "more" if token_mismatch_total > 0 else "fewer"
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. This is "
+                "expected due to the imperfect nature of the sampling "
+                "procedure.",
+                abs(token_mismatch_total),
+                sign,
+            )
         random.shuffle(requests)
         return requests

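To illustrate what the prefix-caching change above relies on: every prompt
built from the same exact-length prefix shares a verbatim token prefix. A
sketch under the same assumptions as before (GPT-2 tokenizer; the 32/16 sizes
and the `exact_length_tokens` helper name are illustrative, not part of the
diff):

    import numpy as np
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")

    def exact_length_tokens(n: int) -> list[int]:
        # Same recipe as _generate_exact_length_tokens, minus bookkeeping.
        raw = np.random.randint(0, tok.vocab_size, size=n).tolist()
        _, adjusted, _ = gen_prompt_decode_to_target_len(
            tokenizer=tok, token_sequence=raw, target_token_len=n
        )
        return adjusted

    prefix = exact_length_tokens(32)
    prompts = [tok.decode(prefix + exact_length_tokens(16)) for _ in range(4)]
    # All four prompts start from the same 32 prefix tokens.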