
"/v1/chat/completions" tokenization issue #2012

Closed
SaulLu opened this issue Dec 11, 2023 · 7 comments
Closed

"/v1/chat/completions" tokenization issue #2012

SaulLu opened this issue Dec 11, 2023 · 7 comments

Comments

@SaulLu
Copy link

SaulLu commented Dec 11, 2023

Context

The "/v1/chat/completions" endpoint uses the apply_chat_template method of the HF tokenizers. It seems to us that these templates take care of adding special tokens (cf. this line from Llama's default template). However, tokenization in vLLM also seems to add special token(s) if this is the tokenizer's default behavior - in particular, the Llama tokenizer adds a BOS token at the start of its tokenization.

There are therefore configurations in which the final tokenization will contain more special tokens than necessary.
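
As an illustration of the mismatch, here is a minimal sketch using the HF tokenizer directly (outside vLLM; the model is the one from the repro below):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-7B-Chat-AWQ")
messages = [{"role": "user", "content": "Tell me a joke."}]

# The Llama chat template already prepends the BOS token to the rendered string.
prompt = tokenizer.apply_chat_template(messages, tokenize=False)

# Re-encoding with the default add_special_tokens=True adds a second BOS id,
# which is effectively what the endpoint does today.
print(tokenizer.encode(prompt)[:3])                            # e.g. [1, 1, 518]
print(tokenizer.encode(prompt, add_special_tokens=False)[:3])  # e.g. [1, 518, 25580]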

Repro

In a terminal, launch a vLLM server. For example:

python -m vllm.entrypoints.openai.api_server --model TheBloke/Llama-2-7B-Chat-AWQ

In another terminal, send a request to this server:

from openai import OpenAI

# Set OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "None"
openai_api_base = f"http://{FILL_ME}/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

chat_response = client.chat.completions.create(
    model="TheBloke/Llama-2-7B-Chat-AWQ",
    messages=[
        {"role": "user", "content": "Tell me a joke."},
        {
            "role": "assistant",
            "content": " Ah, a moment of levity you seek! Very well. Pray, allow me to regale you with this humorous anecdote:\n\nWhy don't historians play cricket?\n\nBecause they prefer to leave their past in the archives!\n\nAnd now, if you'll excuse me, I must return to my scholarly pursuits. Although, I must admit, it is rather refreshing to engage in such frivolous banter from time to time.",
        },
        {"role": "user", "content": "Another one."},
    ],
)
print("Chat response:", chat_response)

Output:

async_llm_engine.py:379] 
Received request cmpl-cca85113d5af4178b3c93fb2c2b72578: 
prompt: "<s>[INST] Tell me a joke. [/INST] Ah, a moment of levity you seek! Very well. Pray, allow me to regale you with this humorous anecdote:\n\nWhy don't historians play cricket?\n\nBecause they prefer to leave their past in the archives!\n\nAnd now, if you'll excuse me, I must return to my scholarly pursuits. Although, I must admit, it is rather refreshing to engage in such frivolous banter from time to time. </s><s>[INST] Another one. [/INST]", 
sampling params: SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.7, top_p=1.0, top_k=-1, min_p=0.0, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[], ignore_eos=False, max_tokens=3959, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True), 
prompt token ids: [1, 1, 518, 25580, 29962, 24948, 592, 263, 2958, 446, 29889, 518, 29914, 25580, 29962, 9070, 29892, 263, 3256, 310, 14453, 537, 366, 16508, 29991, 18064, 1532, 29889, 349, 764, 29892, 2758, 592, 304, 1072, 744, 366, 411, 445, 3165, 20657, 385, 687, 29881, 866, 29901, 13, 13, 11008, 1016, 29915, 29873, 3603, 5834, 1708, 2181, 8522, 29973, 13, 13, 29933, 5658, 896, 5821, 304, 5967, 1009, 4940, 297, 278, 3190, 3145, 29991, 13, 13, 2855, 1286, 29892, 565, 366, 29915, 645, 5566, 1509, 592, 29892, 306, 1818, 736, 304, 590, 21344, 368, 12359, 19544, 29889, 8512, 29892, 306, 1818, 20000, 29892, 372, 338, 3265, 2143, 690, 2790, 304, 3033, 482, 297, 1316, 285, 1150, 324, 681, 9892, 357, 515, 931, 304, 931, 29889, 29871, 2, 1, 518, 25580, 29962, 7280, 697, 29889, 518, 29914, 25580, 29962].

We can see that the prompt token ids start with two 1s (the Llama BOS token id) instead of one.

This issue also impacts the new mistralai/Mixtral-8x7B-Instruct-v0.1 model added in PR #2011.

@Tostino
Contributor

Tostino commented Dec 11, 2023

I did document this issue with the Mistral template in one of the PRs: #1493 (comment)

@SemMulder

The issue seems to be that add_special_tokens=False is not passed below, since add_special_tokens=True is the default.

prompt_token_ids = self.tokenizer.encode(prompt)

However, as mentioned above, the apply_chat_template call below already inserts those tokens:

prompt = tokenizer.apply_chat_template(

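A minimal sketch of the change being discussed (hypothetical, not the actual vLLM patch): when the prompt string has already been rendered by apply_chat_template, encode it without letting the tokenizer add its own special tokens again.

# Hypothetical sketch, not the actual vLLM code path:
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False)
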
@SemMulder

I don't see any use in having add_special_tokens=True there, or am I missing something?

@SemMulder

Note that this not only impacts the "/v1/chat/completions" endpoint, but also the case where an embedded LLMEngine is given prompts that have already been formatted using apply_chat_template.

@Tostino
Contributor

Tostino commented Dec 13, 2023

I just want to note that whether the special tokens are in the template or not is entirely up to the author of the model/template. Most templates I've seen don't add special tokens. But maybe that is a mistake, and templates really should be required to handle the special tokens. I wonder what the best way forward is...

That obviously conflicts with the HF default of adding the tokens though.

@SemMulder

I just want to note that whether the special tokens are in the template or not is entirely up to the author of the model/template.

True! For the model I was working with (OpenChat 3.5), they had add_bos_token = True but also had the bos_token in the prompt template explicitly. Same with mistralai/Mistral-7B-Instruct-v0.1, and a few others I checked. So this is most likely also an error on their side.
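
A quick way to check this in a sketch (the Hugging Face repo ids below are my assumption for the models mentioned; chat templates reference bos_token as a Jinja variable rather than the literal <s>):

from transformers import AutoTokenizer

# Does the tokenizer both auto-add BOS and reference bos_token in its chat template?
for name in ["openchat/openchat_3.5", "mistralai/Mistral-7B-Instruct-v0.1"]:
    tok = AutoTokenizer.from_pretrained(name)
    template = tok.chat_template or ""
    print(name, getattr(tok, "add_bos_token", None), "bos_token" in template)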

However, the reason I think "it just works" with vanilla HF is that they explicitly set add_special_tokens=False when tokenizing via apply_chat_template:

https://github.com/huggingface/transformers/blob/17506d1256c1780efc9e2a5898a828c10ad4ea69/src/transformers/tokenization_utils_base.py#L1750

So to be compatible with HF we would have to do something similar...
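
In other words, a hedged sketch of the equivalence being described: tokenizing through apply_chat_template directly should give the same ids as encoding the rendered string with add_special_tokens=False, which is what vLLM would need to mirror for already-rendered prompts.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-7B-Chat-AWQ")
messages = [{"role": "user", "content": "Tell me a joke."}]

# Both paths should yield identical ids, with the single BOS coming
# only from the chat template itself.
ids_hf = tokenizer.apply_chat_template(messages, tokenize=True)
rendered = tokenizer.apply_chat_template(messages, tokenize=False)
assert ids_hf == tokenizer.encode(rendered, add_special_tokens=False)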

@freckletonj

It appears that the completion endpoint, /v1/completions, also adds BOS and EOS. I'm applying a chat template to an instruct-tuned model myself and using this endpoint, and I was getting fishy results. The reason turned out to be that this endpoint automatically adds <s> ... </s>.

The remedy for me was to use the tokenize endpoint myself and pass the tokens directly.
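
A sketch of that kind of workaround (it assumes the /v1/completions endpoint accepts a list of token ids as the prompt, as the legacy OpenAI completions API does, and it tokenizes locally with the HF tokenizer instead of a server-side tokenize endpoint; FILL_ME is the same placeholder as in the repro above):

from openai import OpenAI
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-7B-Chat-AWQ")
client = OpenAI(api_key="None", base_url=f"http://{FILL_ME}/v1")

messages = [{"role": "user", "content": "Tell me a joke."}]

# Tokenize once, client-side: the chat template supplies the single BOS.
token_ids = tokenizer.apply_chat_template(messages, tokenize=True)

# Send the token ids directly so the server does not re-tokenize the prompt string.
completion = client.completions.create(
    model="TheBloke/Llama-2-7B-Chat-AWQ",
    prompt=token_ids,
    max_tokens=256,
)
print(completion.choices[0].text)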
