
Conversation

@Xarbirus Xarbirus (Contributor) commented Apr 29, 2025

I encountered a problem while using benchmark_serving.py with --dataset-name random. The problem is that when the prompt is encoded on the server side, its token count can end up longer than requested in the benchmark (depending on the tokenizer).

For example: running the benchmark with the following parameters

--backend vllm --model meta-llama/Meta-Llama-3.1-8B --dataset-name random --random-prefix-len 0 --random-input-len 4000 --random-output-len 16 --num-prompts 100

and running vllm with --max-model-len 4096 resulted in the following error:

ERROR 04-29 12:49:57 [serving_completion.py:116] ValueError: This model's maximum context length is 4096 tokens. However, you requested 4122 tokens (4106 in the messages, 16 in the completion). Please reduce the length of the messages or completion.

Short repro:

from transformers import AutoTokenizer

# Tokenizer used in the benchmark run above
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")

# Original token IDs
original_token_ids = [6880, 6881]

# Step 1: Decode the token IDs into text
decoded_text = tokenizer.decode(original_token_ids)  # -> "callshere"

# Step 2: Re-encode the text into token IDs
re_encoded_ids = tokenizer.encode(decoded_text, add_special_tokens=False)  # -> [1650, 939, 486]

# Step 3: Compare
print("Original token IDs:", original_token_ids)
print("Re-encoded token IDs:", re_encoded_ids)
# original_token_ids != re_encoded_ids
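
For reference, here is a minimal sketch of the mitigation direction (re-encode the decoded text and truncate to the requested length). It is an illustrative sketch only, not necessarily the exact change made to the benchmark code, and the helper name is made up:

def build_random_prompt(tokenizer, token_ids, target_len):
    """Decode random token IDs, then re-encode the resulting text and truncate
    the re-encoded IDs to target_len; return the prompt text and its
    re-encoded length. Illustrative sketch only."""
    text = tokenizer.decode(token_ids)
    re_encoded = tokenizer.encode(text, add_special_tokens=False)[:target_len]
    # Report the re-encoded count as the effective prompt length; a small drift
    # is still possible when the truncated text is tokenized yet again.
    return tokenizer.decode(re_encoded), len(re_encoded)

prompt, prompt_len = build_random_prompt(tokenizer, [6880, 6881] * 100, target_len=150)
print(prompt_len)  # <= 150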

Signed-off-by: Mikhail Podvitskii <podvitskiymichael@gmail.com>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@wubai wubai commented Apr 30, 2025

I have the same problem. Re-encoding the sequence is a good idea!

@DarkLight1337 DarkLight1337 requested a review from comaniac April 30, 2025 16:51
@ywang96 ywang96 (Member) left a comment


Thanks for finding this interesting bug! Left a comment but otherwise LGTM

Xarbirus added 2 commits May 5, 2025 12:47
Signed-off-by: Mikhail Podvitskii <podvitskiymichael@gmail.com>
@ywang96 ywang96 (Member) left a comment


LGTM

@ywang96 ywang96 added the ready label (ONLY add when PR is ready to merge/full CI is needed) May 6, 2025
@ywang96 ywang96 enabled auto-merge (squash) May 6, 2025 05:30
@ywang96 ywang96 merged commit dc47ba3 into vllm-project:main May 6, 2025
49 of 51 checks passed
robertgshaw2-redhat added a commit to neuralmagic/vllm that referenced this pull request May 6, 2025
* [Model] Add GraniteMoeHybrid 4.0 model (vllm-project#17497)

Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Co-authored-by: Thomas Ortner <boh@zurich.ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>

* [easy] Fix logspam on PiecewiseBackend errors (vllm-project#17138)

Signed-off-by: rzou <zou3519@gmail.com>

* [Bugfix] Fixed prompt length for random dataset (vllm-project#17408)

Signed-off-by: Mikhail Podvitskii <podvitskiymichael@gmail.com>

* [Doc] Update notes for H2O-VL and Gemma3 (vllm-project#17219)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Misc] Fix ScalarType float4 naming  (vllm-project#17690)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* Fix `dockerfilegraph` pre-commit hook (vllm-project#17698)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Bugfix] Fix triton import with local TritonPlaceholder (vllm-project#17446)

Signed-off-by: Mengqing Cao <cmq0113@163.com>

* [V1] Enable TPU V1 backend by default (vllm-project#17673)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [V1][PP] Support PP for MultiprocExecutor (vllm-project#14219)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: jiang.li <jiang1.li@intel.com>

* [v1] AttentionMetadata for each layer (vllm-project#17394)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>

* [Feat] Add deprecated=True to CLI args (vllm-project#17426)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>

* [Docs] Use gh-file to add links to tool_calling.md (vllm-project#17709)

Signed-off-by: windsonsea <haifeng.yao@daocloud.io>

* [v1] Introduce KVCacheBlocks as interface between Scheduler and KVCacheManager (vllm-project#17479)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>

* [doc] Add RAG Integration example (vllm-project#17692)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [Bugfix] Fix modality limits in vision language example (vllm-project#17721)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* Make right sidebar more readable in "Supported Models" (vllm-project#17723)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [TPU] Increase block size and reset block shapes (vllm-project#16458)

* [Misc] Add Next Edit Prediction (NEP) datasets support in `benchmark_serving.py` (vllm-project#16839)

Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Signed-off-by: dtransposed <>
Co-authored-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>

* [Bugfix] Fix for the condition to accept empty encoder inputs for mllama (vllm-project#17732)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

* [Kernel] Unified Triton kernel that doesn't distinguish between prefill + decode (vllm-project#16828)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

---------

Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Signed-off-by: rzou <zou3519@gmail.com>
Signed-off-by: Mikhail Podvitskii <podvitskiymichael@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: jiang.li <jiang1.li@intel.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: windsonsea <haifeng.yao@daocloud.io>
Signed-off-by: reidliu41 <reid201711@gmail.com>
Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Signed-off-by: dtransposed <>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Co-authored-by: Stan Wozniak <77159600+s3woz@users.noreply.github.com>
Co-authored-by: Thomas Ortner <boh@zurich.ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Richard Zou <zou3519@users.noreply.github.com>
Co-authored-by: Mikhail Podvitskii <podvitskiymichael@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Li, Jiang <jiang1.li@intel.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Michael Yao <haifeng.yao@daocloud.io>
Co-authored-by: Reid <61492567+reidliu41@users.noreply.github.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: Jevin Jiang <jevin0change@gmail.com>
Co-authored-by: d.transposed <damian.bogunowicz@gmail.com>
Co-authored-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
@Xarbirus Xarbirus deleted the random-dataset-fix branch May 7, 2025 07:30
@ksmusz ksmusz commented May 7, 2025

Hi @ywang96 @Xarbirus

I've also encountered the issue mentioned here. After seeing this fix, I tested it and found that there are still some cases where it behaves unexpectedly, making the benchmark a bit less stable.

Here's the repro:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", trust_remote_code=True)

# Tokens below came from one of the random token generations and its decode/encode round trip
tokens_1 = [107191, 107192, 107193, 5809, 107195, 107196, 107197]

# Append token 107198 to the end of the list, as it influences how the preceding text is tokenized
tokens_2 = tokens_1 + [107198]

# Decode the tokens
prompt_tokens_1 = tokenizer.decode(tokens_1)
prompt_tokens_2 = tokenizer.decode(tokens_2)

# Encode the tokens again
encoded_tokens_1 = tokenizer.encode(prompt_tokens_1, add_special_tokens=False)
encoded_tokens_2 = tokenizer.encode(prompt_tokens_2, add_special_tokens=False)

# Results
print(f"Original tokens_1: {tokens_1}")
print(f"Original tokens_1 length: {len(tokens_1)}")
print(f"Original tokens_1 as strings: {[tokenizer.decode([x]) for x in tokens_1]}\n")
print(f"Original tokens_2: {tokens_2}")
print(f"Original tokens_2 length: {len(tokens_2)}")
print(f"Original tokens_2 as strings: {[tokenizer.decode([x]) for x in tokens_2]}\n\n")

print(f"Encoded tokens_1, after one decode/encode process: {encoded_tokens_1}")
print(f"Encoded tokens_1 length: {len(encoded_tokens_1)}")
print(f"Encoded tokens_1 as strings: {[tokenizer.decode([x]) for x in encoded_tokens_1]}")
print(f"The encoded tokens_1 are the same as tokens_1: {encoded_tokens_1 == tokens_1}\n\n")

print(f"Encoded tokens_2, after one decode/encode process: {encoded_tokens_2}")
print(f"Encoded tokens_2 length: {len(encoded_tokens_2)}")
print(f"Encoded tokens_2 as strings: {[tokenizer.decode([x]) for x in encoded_tokens_2]}")
print(f"The encoded tokens_2 are the same as tokens_2: {encoded_tokens_2 == tokens_2}\n\n")

Running this code produces output (screenshot omitted) showing that the re-encoded token IDs do not always match the originals.

From the experiments I ran, the issue happens because the tokenizer tokenizes a word differently depending on its neighboring words, even though those neighbors are not part of the token itself. In the case above, depending on what follows "використов", that word either will or will not be split into separate tokens.

When we truncate the encoded tokens to the input length, we change the whole sentence, and the tokenizer may then encode it differently the second time. If we then truncate or pad again, we modify the sentence once more, so we could end up in an endless loop of adjusting the prompt to hit the exact length we expect (see the sketch below).
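
For illustration, here is a small sketch (reusing tokenizer and tokens_2 from the repro above; the target length is arbitrary) of how repeated truncate/re-encode passes may keep missing the requested length:

target_len = 4
token_ids = list(tokens_2)

for attempt in range(5):
    # Truncate to the target length, decode to text, and re-encode
    text = tokenizer.decode(token_ids[:target_len])
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    print(f"attempt {attempt}: {len(token_ids)} tokens")
    if len(token_ids) == target_len:
        break
# Each truncation changes the text, so the tokenizer may merge pieces differently
# on the next pass, and the length may never settle on target_len.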

When running with the default input length (1024) and 256 randomized requests, the results were as follows (output screenshot omitted):
In this example of 10 executions in a loop, the maximum number of mismatches in a single 256-request batch was 5, while the number of tokens generated was sometimes more than 15% lower than expected (e.g. 849 out of an expected 1024).

I'm not yet sure how to overcome this issue if we want the actual text to be passed as the prompt to SampleRequest. The easiest and most robust solution would be to not pass text there, but only the randomly generated tokens. The only drawback is that we would then not be simulating the exact end-to-end scenario of output generation, since the decode (detokenization) step for the input would be skipped.
If we are willing to skip that step, we could pass the generated tokens directly to SampleRequest instead of all the things done in this PR (see the sketch below; the screenshot of the proposed change is omitted):
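
For concreteness, a hypothetical sketch of that alternative: generate the token IDs directly and hand them to the request object. SampleRequest is shown as a stand-in dataclass here, and its field names are assumptions, not the benchmark's actual definition:

import random
from dataclasses import dataclass

@dataclass
class SampleRequest:  # stand-in; the benchmark's actual class and fields may differ
    prompt: list
    prompt_len: int
    expected_output_len: int

def sample_random_request(tokenizer, input_len, output_len):
    # Build the prompt from raw random token IDs, skipping the decode/re-encode
    # round trip, so the benchmark-side prompt length is met exactly.
    token_ids = [random.randint(0, tokenizer.vocab_size - 1) for _ in range(input_len)]
    return SampleRequest(prompt=token_ids,
                         prompt_len=len(token_ids),
                         expected_output_len=output_len)

The trade-off, as noted above, is that the server-side tokenization of text would no longer be exercised.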

Not sure about the right approach here, hence letting you know about my concerns.

RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: Mikhail Podvitskii <podvitskiymichael@gmail.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
Signed-off-by: Mikhail Podvitskii <podvitskiymichael@gmail.com>
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
Signed-off-by: Mikhail Podvitskii <podvitskiymichael@gmail.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
