Commit f4e4088

Fix random dataset mismatched token length with config. (#24937)

weireweireywang96 and Roger Wang authored

Signed-off-by: Weiliang Liu <weiliangl@nvidia.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>

1 parent 0efd540 commit f4e4088
File tree

vllm/benchmarks/datasets.py

1 file changed: 118 additions & 27 deletions
@@ -366,11 +366,67 @@ def process_video(video: Any) -> Mapping[str, Any]:
         f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`."  # noqa: E501
     )
 
+
+def gen_prompt_decode_to_target_len(
+    tokenizer: PreTrainedTokenizerBase,
+    token_sequence: list[int],
+    target_token_len: int,
+    max_retry: int = 10,
+    add_special_tokens: bool = False,
+    rng: Optional[np.random.Generator] = None,
+) -> tuple[str, list[int], int]:
+    """
+    Ensure decoded-then-encoded prompt length matches the target token length.
+
+    This function decodes an initial token sequence to text and re-encodes it,
+    iteratively adjusting the token sequence length to match a target.
+    This is necessary because some tokenizers do not guarantee a 1:1 mapping
+    between consecutive tokens and the decoded-then-encoded sequence length.
+    For example, for GPT2Tokenizer:
+    [6880, 6881] -> ['Ġcalls', 'here'] ->
+    [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
+
+    Returns a tuple of (prompt, adjusted token sequence, token mismatch).
+    """
+    remain_num_try = max_retry
+    token_mismatch = 0
+    while True:
+        prompt = tokenizer.decode(token_sequence)
+        token_sequence = tokenizer.encode(
+            prompt, add_special_tokens=add_special_tokens
+        )
+        if remain_num_try <= 0:
+            if len(token_sequence) != target_token_len:
+                token_mismatch = len(token_sequence) - target_token_len
+            break
+
+        if len(token_sequence) == target_token_len:
+            break
+        elif len(token_sequence) < target_token_len:
+            if rng is not None:
+                extra_tokens = rng.integers(
+                    0,
+                    tokenizer.vocab_size,
+                    size=target_token_len - len(token_sequence),
+                ).tolist()
+            else:
+                extra_tokens = np.random.randint(
+                    0,
+                    tokenizer.vocab_size,
+                    size=target_token_len - len(token_sequence),
+                ).tolist()
+            token_sequence.extend(extra_tokens)
+        elif len(token_sequence) > target_token_len:
+            token_sequence = token_sequence[:target_token_len]
+
+        remain_num_try -= 1
+
+    return prompt, token_sequence, token_mismatch
+
 # -----------------------------------------------------------------------------
 # Random Dataset Implementation (Synthetic Data)
 # -----------------------------------------------------------------------------
 
-
 class RandomDataset(BenchmarkDataset):
     """
     Synthetic text-only dataset for serving/throughput benchmarks.
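A minimal sketch of the round trip the docstring describes and of the new helper's contract. It assumes `transformers` is installed and that `gen_prompt_decode_to_target_len` is importable from `vllm.benchmarks.datasets` as of this commit; the token IDs are the GPT-2 example from the docstring.

import numpy as np
from transformers import GPT2Tokenizer

from vllm.benchmarks.datasets import gen_prompt_decode_to_target_len

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# The round trip from the docstring: 2 tokens re-encode to 3.
tokens = [6880, 6881]  # ['Ġcalls', 'here']
text = tokenizer.decode(tokens)
print(tokenizer.encode(text, add_special_tokens=False))  # [1650, 939, 486]

# The helper retries until the re-encoded length hits the target, or
# reports the residual mismatch after max_retry attempts.
prompt, adjusted, mismatch = gen_prompt_decode_to_target_len(
    tokenizer=tokenizer,
    token_sequence=list(tokens),
    target_token_len=2,
    rng=np.random.default_rng(0),
)
# Either the target was met exactly (mismatch == 0) or the residual is reported.
assert len(adjusted) == 2 + mismatch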
@@ -420,8 +476,9 @@ def sample(
         vocab_size = tokenizer.vocab_size
 
         requests = []
+        token_mismatch_total = 0
         for i in range(num_requests):
-            prompt, total_input_len = self.generate_token_sequence(
+            prompt, total_input_len, token_mismatch = self.generate_token_sequence(  # noqa: E501
                 tokenizer=tokenizer,
                 prefix_token_ids=prefix_token_ids,
                 prefix_len=prefix_len,
@@ -430,6 +487,7 @@ def sample(
                 offset=int(offsets[i]),
                 index=i,
             )
+            token_mismatch_total += token_mismatch
             requests.append(
                 SampleRequest(
                     prompt=prompt,
@@ -453,6 +511,18 @@ def sample(
                     )
                 )
             requests = batch_requests
+
+        if token_mismatch_total != 0:
+            sign = "more" if token_mismatch_total > 0 else "fewer"
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. This is "
+                "expected due to the imperfect nature of the sampling "
+                "procedure.",
+                abs(token_mismatch_total),
+                sign,
+            )
+
         return requests
 
     def get_prefix(
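Note that the warning is aggregated and emitted once per sample() call rather than once per request. A standalone sketch of that accumulate-then-warn pattern; the logger name and mismatch values here are illustrative, not from vLLM:

import logging

logger = logging.getLogger("vllm.benchmarks.datasets")
logging.basicConfig(level=logging.WARNING)

# Hypothetical per-request mismatches collected during one sample() run.
per_request_mismatch = [0, 1, 0, -2, 0, 1]

token_mismatch_total = sum(per_request_mismatch)
if token_mismatch_total != 0:
    sign = "more" if token_mismatch_total > 0 else "fewer"
    logger.warning(
        "Across all generated prompts, there were %d %s tokens "
        "than expected after decoding and re-encoding.",
        abs(token_mismatch_total),
        sign,
    )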
@@ -530,7 +600,7 @@ def generate_token_sequence(
         input_len: int,
         offset: int,
         index: int,
-    ) -> tuple[str, int]:
+    ) -> tuple[str, int, int]:
         """
-        Returns (prompt, total_input_len).
+        Returns (prompt, total_input_len, token_mismatch).
@@ -549,15 +619,16 @@ def generate_token_sequence(
         token_sequence = prefix_token_ids + inner_seq
 
         # Decode, then re-encode and truncate to preserve token count invariants
-        prompt = tokenizer.decode(token_sequence)
         total_input_len = prefix_len + int(input_len)
-
-        re_encoded_sequence = tokenizer.encode(
-            prompt, add_special_tokens=False)[:total_input_len]
-        prompt = tokenizer.decode(re_encoded_sequence)
-        total_input_len = len(re_encoded_sequence)
-
-        return prompt, total_input_len
+        prompt, adjusted_token_sequence, token_mismatch = gen_prompt_decode_to_target_len(  # noqa: E501
+            tokenizer=tokenizer,
+            token_sequence=token_sequence,
+            target_token_len=total_input_len,
+            add_special_tokens=False,
+            rng=self._rng,
+        )
+        total_input_len = len(adjusted_token_sequence)
+        return prompt, total_input_len, token_mismatch
 
 
 # -----------------------------------------------------------------------------
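A sketch of the rewritten generate_token_sequence flow outside the class, assuming a GPT-2 tokenizer; the prefix and inner sequences are built with a local numpy Generator standing in for self._rng, and the lengths are illustrative:

import numpy as np
from transformers import GPT2Tokenizer

from vllm.benchmarks.datasets import gen_prompt_decode_to_target_len

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
rng = np.random.default_rng(7)

prefix_len, input_len = 4, 12
prefix_token_ids = rng.integers(0, tokenizer.vocab_size, size=prefix_len).tolist()
inner_seq = rng.integers(0, tokenizer.vocab_size, size=input_len).tolist()
token_sequence = prefix_token_ids + inner_seq

target = prefix_len + int(input_len)
prompt, adjusted, token_mismatch = gen_prompt_decode_to_target_len(
    tokenizer=tokenizer,
    token_sequence=token_sequence,
    target_token_len=target,
    add_special_tokens=False,
    rng=rng,
)
total_input_len = len(adjusted)
# After the fix, the reported length and mismatch stay consistent:
assert total_input_len == target + token_mismatch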
@@ -873,8 +944,9 @@ def sample(
         vocab_size = tokenizer.vocab_size
         # Add synthetic multimodal items to each request
         mm_requests = []
+        token_mismatch_total = 0
         for i in range(num_requests):
-            prompt, total_input_len = self.generate_token_sequence(
+            prompt, total_input_len, token_mismatch = self.generate_token_sequence(  # noqa: E501
                 tokenizer=tokenizer,
                 prefix_token_ids=prefix_token_ids,
                 prefix_len=prefix_len,
@@ -883,6 +955,7 @@ def sample(
                 offset=int(offsets[i]),
                 index=i,
             )
+            token_mismatch_total += token_mismatch
             # Get multimodal item iterator for a given request
             mm_item_iterator = self.get_mm_item_iterator(
                 min_num_mm_items,
@@ -918,6 +991,18 @@ def sample(
                 request_id=request_id_prefix + str(i),
             )
             mm_requests.append(sample_request)
+
+        if token_mismatch_total != 0:
+            sign = "more" if token_mismatch_total > 0 else "fewer"
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. This is "
+                "expected due to the imperfect nature of the sampling "
+                "procedure.",
+                abs(token_mismatch_total),
+                sign,
+            )
+
         return mm_requests
 
 # -----------------------------------------------------------------------------
@@ -2694,27 +2779,24 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]:
             # Generate random tokens
             tokens = np.random.randint(
                 0, vocab_size, size=target_length).tolist()
-            text = tokenizer.decode(tokens)
-            re_encoded = tokenizer.encode(text, add_special_tokens=False)
-
-            if len(re_encoded) == target_length:
-                return re_encoded
-            elif len(re_encoded) < target_length:
-                # Recursively generate additional consistent tokens
-                needed = target_length - len(re_encoded)
-                extra_tokens = _generate_exact_length_tokens(needed)
-                return re_encoded + extra_tokens
-            else:
-                # Truncate to target length
-                return re_encoded[:target_length]
+
+            _, adjusted_tokens, token_mismatch = gen_prompt_decode_to_target_len(  # noqa: E501
+                tokenizer=tokenizer,
+                token_sequence=tokens,
+                target_token_len=target_length,
+                add_special_tokens=False,
+            )
+            return adjusted_tokens, token_mismatch
 
         requests = []
+        token_mismatch_total = 0
         for _ in range(num_prefixes):
-            prefix_tokens = _generate_exact_length_tokens(prefix_len)
+            prefix_tokens, token_mismatch = _generate_exact_length_tokens(prefix_len)  # noqa: E501
+            token_mismatch_total += token_mismatch
 
             for _ in range(prompts_per_prefix):
-                suffix_tokens = _generate_exact_length_tokens(suffix_len)
-
+                suffix_tokens, token_mismatch = _generate_exact_length_tokens(suffix_len)  # noqa: E501
+                token_mismatch_total += token_mismatch
                 combined_tokens = prefix_tokens + suffix_tokens
                 prompt = tokenizer.decode(combined_tokens)
                 prompt_len = len(combined_tokens)
@@ -2726,6 +2808,16 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]:
                     )
                 )
 
+        if token_mismatch_total != 0:
+            sign = "more" if token_mismatch_total > 0 else "fewer"
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. This is "
+                "expected due to the imperfect nature of the sampling "
+                "procedure.",
+                abs(token_mismatch_total),
+                sign,
+            )
         random.shuffle(requests)
         return requests
 
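A sketch of the prefix-caching sampling path above: one exact-length prefix shared across several prompts, each with its own exact-length suffix, with mismatches totalled across all calls. The standalone helper mirrors the nested _generate_exact_length_tokens; the lengths and counts are illustrative:

import numpy as np
from transformers import GPT2Tokenizer

from vllm.benchmarks.datasets import gen_prompt_decode_to_target_len

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
vocab_size = tokenizer.vocab_size


def exact_length_tokens(target_length: int) -> tuple[list[int], int]:
    # Generate random tokens, then adjust them to survive the
    # decode/re-encode round trip at the requested length.
    tokens = np.random.randint(0, vocab_size, size=target_length).tolist()
    _, adjusted_tokens, token_mismatch = gen_prompt_decode_to_target_len(
        tokenizer=tokenizer,
        token_sequence=tokens,
        target_token_len=target_length,
        add_special_tokens=False,
    )
    return adjusted_tokens, token_mismatch


token_mismatch_total = 0
prefix_tokens, mismatch = exact_length_tokens(32)  # shared prefix
token_mismatch_total += mismatch

prompts = []
for _ in range(4):  # prompts_per_prefix
    suffix_tokens, mismatch = exact_length_tokens(16)
    token_mismatch_total += mismatch
    prompts.append(tokenizer.decode(prefix_tokens + suffix_tokens))

print(len(prompts), "prompts, net token mismatch:", token_mismatch_total)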
