[feature] Make space insertion optional in SentencePiece tokenizer #3664

Closed · shibe2 opened this issue Oct 18, 2023 · 29 comments
Labels: enhancement (New feature or request), stale

Comments

@shibe2 (Contributor) commented Oct 18, 2023

Current Behavior

Since #2810, space is inserted into any non-empty text. This breaks multiple use cases:

  • Infill: Incorrect Tokenization #3503 (comment)
  • Prompt size control: splitting long text into pieces and adding these pieces to the prompt until a certain token limit is reached. It is desirable to know the precise token count of each piece, but space insertion gets in the way here.
  • Prompt formats that use added tokens. Whether a space should be inserted after a special token depends on the particular model used (how it was trained/finetuned and, perhaps, the alignment of stars). The decision should be left to the code that composes/formats prompts.
  • The Yi tokenizer does not insert a space.

Recently, space insertion was disabled in the case where the text representation of special tokens is recognized:

auto raw_text = (special ? "" : " ") + fragment.raw_text.substr(fragment.offset, fragment.length);
This works for toying with the main example when escape processing is enabled, but leaves other scenarios broken. In particular, when special token identifiers are added to the prompt by the client and passed to the server (which is the proper way to handle added tokens), a space is inserted into each piece of text between special tokens.

Space insertion was made to match the original Python implementation, but that behavior is itself not optimal, as evidenced by people having to hack around it.
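To make the splitting/recombining problem concrete, here is a minimal sketch using the reference Python SentencePiece tokenizer rather than llama.cpp itself; the model path is a placeholder for a LLaMA-style tokenizer.model, and the exact tokens depend on the model:

from sentencepiece import SentencePieceProcessor

tokenizer = SentencePieceProcessor(model_file='tokenizer.model')  # placeholder path

whole  = tokenizer.encode('name\nmessage')
pieces = tokenizer.encode('name\n') + tokenizer.encode('message')

# The dummy prefix is added on every encode() call, so `pieces` picks up an
# extra space-prefixed token and typically decodes with a spurious space.
print(repr(tokenizer.decode(whole)))   # expected:  'name\nmessage'
print(repr(tokenizer.decode(pieces)))  # typically: 'name\n message'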

Proposals

Option 1

Insert space only when BOS token is also inserted. There is a clear intersection between cases where BOS and space need to be inserted: when the whole prompt is one chunk of text. All broken cases that I listed here involve splitting and recombining the prompt.

The argument bos could optionally be renamed to something like full_prompt, or its inverse partial, with a meaning that would encompass the behaviors controlled by it.

Option 2

Add a separate argument that controls insertion of space. It would be used by /tokenize in server and in other places.

Option 3

Add another tokenization function with options that would control many aspects of the process. Suggested by @staviq.

@ggerganov (Owner) commented:

Option 1 might actually be a good idea

@staviq (Contributor) commented Oct 18, 2023

In particular, when special token identifiers are added to the prompt by client and passed to server (which is a proper way to handle added tokens), it inserts space into each piece of text between special tokens.

Unless there's a bug (and I don't think there is; I tested it extensively, though it's not impossible), it's the other way around.

Tokenizing with "special=false" will only add a space at the beginning of the entire input string, and special tokens will be naively tokenized as plaintext. That is in line with the old tokenizer behaviour.

Tokenizing with "special=true" will not add any spaces, neither at the beginning of the input string nor to the substrings between special tokens.

EDIT:

Option 1 would leave you with a stray space if special tokens are enabled and add_bos is enabled.

Example:
<|im_start|>user\nHi.<|im_end|>\n

Current behavior with special=false, same as old behavior:
bos, <,|,im,_,start,|,>,user,endl,Hi,.,<,|,im,_,end,|,>,endl

Current (new) behaviour with special=true:
bos,<|im_start|>,user,endl,Hi,.,<|im_end|>,endl

Tying the space prefix to add_bos when special=true would leave you with:
bos, ,<|im_start|>,user,endl,Hi,.,<|im_end|>,endl

@shibe2 (Contributor, Author) commented Oct 18, 2023

@staviq I just tested it.

request: {"prompt":[1,32001,"name\nmessage",32002,"\n",32001,"name\nmessage",32002,"\n",32001,"name\n"]}
tokenized: [1,32001,1024,13,4906,32002,29871,13,32001,1024,13,4906,32002,29871,13,32001,1024,13]
decoded: "<s><|im_start|> name\nmessage<|im_end|> \n<|im_start|> name\nmessage<|im_end|> \n<|im_start|> name\n"

As you can see, there are no spaces at all in the request, but the tokenizer uses tokens that start with a space. That is, it inserts a space into each text fragment.

@staviq (Contributor) commented Oct 18, 2023

@staviq I just tested it.

request: {"prompt":[1,32001,"name\nmessage",32002,"\n",32001,"name\nmessage",32002,"\n",32001,"name\n"]} tokenized: [1,32001,1024,13,4906,32002,29871,13,32001,1024,13,4906,32002,29871,13,32001,1024,13] decoded: "<s><|im_start|> name\nmessage<|im_end|> \n<|im_start|> name\nmessage<|im_end|> \n<|im_start|> name\n"

As you can see, there are no spaces at all in the request, but the tokenizer uses tokens that start with a space. That is, it inserts a space into each text fragment.

Oh, I see, the server has its own implementation of tokenize:

std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const
{
    // If `add_bos` is true, we only add BOS, when json_prompt is a string,
    // or the first element of the json_prompt array is a string.
    std::vector<llama_token> prompt_tokens;
    if (json_prompt.is_array())
    {
        bool first = true;
        for (const auto& p : json_prompt)

I thought you meant the new special token handling broke this, but you meant it didn't fix an already existing problem in the server, correct?

@shibe2 (Contributor, Author) commented Oct 18, 2023

Yes, the compulsory space has been affecting the server's behavior ever since it was introduced in #2810. Maybe we can call it a problem, but I see it as a missing feature. In many cases, inserting the space is desirable, but doing it unconditionally takes control over it away from the client.

If my option 1 is chosen, it may be that nothing needs to be changed in the server code. But it should be documented in which cases the client needs to add spaces.

If my option 2 is chosen, the server can be made to behave like in option 1 by default and have optional parameters that control insertion of spaces. Again, with appropriate documentation.

@grencez (Contributor) commented Oct 25, 2023

Spaces aren't opportunistically inserted within the text, so doing it only when add_bos==true would make sense to me. (Option 1)

A little extra info: the OP links to a hack that prepends a "🙂". Pretty smart! I've been prepending a newline, but I'm not sure whether it will reliably yield a unique newline token for all models in all cases. If tokenizers are really unpredictable, clients that edit the context should be retokenizing at arbitrary boundaries (I think this is called "token healing"), which is also complicated by space insertion.

@shibe2 (Contributor, Author) commented Oct 26, 2023

@grencez I am also using a newline (LF) to work around the space insertion. I'm curious to see a model that uses a SentencePiece tokenizer and has multiple tokens that contain a newline, which could make this usage of newline problematic. One of the things I also do is splitting text on token boundaries, so this will be important even once space insertion becomes controllable.
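For reference, a minimal sketch of this newline workaround using the reference Python tokenizer (not the llama.cpp API); the model path is a placeholder, and it assumes the prepended "\n" tokenizes to a stable prefix that can simply be dropped, which, as discussed above, may not hold for every model:

from sentencepiece import SentencePieceProcessor

tokenizer = SentencePieceProcessor(model_file='tokenizer.model')  # placeholder path

sentinel = tokenizer.encode('\n')               # tokens for the sentinel alone
tokens = tokenizer.encode('\n' + 'some text')   # dummy prefix attaches to the sentinel
tokens = tokens[len(sentinel):]                 # drop the sentinel, keep the real text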

@grencez (Contributor) commented Oct 26, 2023

I'm curious to see a model that uses a SentencePiece tokenizer and has multiple tokens that contain a newline

@shibe2 The CausalLM GGUF models on HuggingFace do this sometimes. " \n\n" becomes 2 tokens (" \n" and "\n"). Those quantizations are still experimental and seem buggy (e.g., llama_token_nl() differs from how "\n" is tokenized), so don't jump to any conclusions just yet.

@shibe2 (Contributor, Author) commented Nov 17, 2023

Given that, BOS and initial space no longer seem to be nicely coupled. Therefore, I prefer my option 2 – adding a new argument to tokenization functions.

@KerfuffleV2 (Collaborator) commented:

Something definitely should happen with this. The current behavior destroys the quality of some models. See #4081 (comment)

The way we tokenize also clearly differs from the official Python tokenizer. (Example in the comment I linked.)

@staviq (Contributor) commented Nov 18, 2023

It seems the discussion is happening in parallel, here and in #4081.

It looks to me like a single llama_tokenize function might not be enough to solve this, because for every possible solution so far, there seems to be a mutually exclusive use case.

So my proposal is to add llama_tokenize_advanced (or something similar), and I think we might want to use a single bitfield parameter for "settings" (an OR-able enum).

This way, we won't break llama.cpp bindings when adding function arguments (we/I did accidentally break llama-cpp-python by adding special before), and we would be able to modify and add functionality to the tokenizer without breaking compatibility in the future.

@KerfuffleV2 (Collaborator) commented:

So my proposal is to add llama_tokenize_advanced

How about doing that but make it take a struct with options? Sort of like how the quantize stuff has struct llama_model_quantize_params. Then you can basically add an unlimited number of options to it without making the parameter passing insane or breaking the API whenever a new option gets added.

@staviq (Contributor) commented Nov 18, 2023

I was about to say that, but I thought it would be too radical a change.

But this is a good idea; Vulkan, for example, uses this approach to solve its problem of an ungodly number of parameters in function calls.

@shibe2 (Contributor, Author) commented Nov 18, 2023

So my proposal is to add llama_tokenize_advanced (or something similar)

I added it as option 3. I think that after adding such a function, we would still need to decide what to do with the existing function, because its API is not in good shape given current needs.

@KerfuffleV2 (Collaborator) commented:

we would still need to decide what to do with the existing function

It could call the advanced function with the options for special and add_bos set appropriately. Possibly also deprecate it if it really just doesn't have a use case anymore. I guess if we were going to say that then we should probably just break the API and replace the existing function with the proposed advanced one.

@cebtenzzre (Collaborator) commented Nov 21, 2023

Once HF supports add_prefix_space for SPM (#3538 (comment)), YiTokenizer should no longer be needed upstream (AFAIK) and we can store the value of add_prefix_space in the GGUF for both SPM and GPT2 tokenizers.

@crasm (Contributor) commented Jan 15, 2024

Would anybody be averse to me implementing Option 1?

My code does incremental tokenization/decoding, and I'm getting weird spaces where they should not be appearing. I'd rather fix this now than work around it.

@ggerganov (Owner) commented:

I believe the original problem discussed in this issue has been resolved:

llama.cpp/llama.cpp, lines 7096 to 7100 in ddb008d:

auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
if (&fragment == &fragment_buffer.front()) {
raw_text = " " + raw_text; // prefix with space if the first token is not special
}

Can you give a specific example of how you use llama.cpp tokenization that does not produce the expected results?

@crasm (Contributor) commented Jan 15, 2024

@ggerganov

The text "I saw" becomes:

INFO   : 106070: LlamaCppService: Adding text: ```I saw```
FINE   : 106071: ContextTokens: Added BOS token to context
INFO   : 106071: LlamaCppService: Added 3 tokens: [
   0:     1 = <s>
   1:   306 = ▁I
   2:  4446 = ▁saw
]

Changing this to "I saw Sam" (adding " Sam") becomes:

INFO   : 163750: LlamaCppService: Adding text: ``` Sam```
INFO   : 163750: LlamaCppService: Added 2 tokens: [
   0: 29871 = ▁
   1:  3685 = ▁Sam
]

I expect " Sam" to become just 3685 = ▁Sam.

I could trim prefix spaces from the added partial text, or throw away the first 29871 token from the result. However, this behavior is surprising and complicates the client.

I think it would improve the API to not add a space in this case.

@ggerganov (Owner) commented:

This tokenization matches what I get when I use the SentencePiece implementation via Python and the LLaMA tokenizer:

from sentencepiece import SentencePieceProcessor
tokenizer = SentencePieceProcessor(dir_tokenizer + '/tokenizer.model')

print(tokenizer.encode('I saw', add_bos=True))
# [1, 306, 4446]
print(tokenizer.encode(' Sam'))
# [29871, 3685]

I think it is important to match the reference implementation's API, so the leading whitespace token 29871 is correct and expected in this case.

For example, in this Python example, how would you use the SentencePiece tokenizer to achieve your goal? If you can provide an example of doing that, we could decide on an approach to extend the llama.cpp tokenization API in a similar way so that it stays consistent. But simply removing the whitespace is not an option, because this simple example would no longer work correctly.

@grencez (Contributor) commented Jan 17, 2024

For example, in this Python example, how would you use the SentencePiece tokenizer to achieve your goal?

You have to set NormalizerSpec's add_dummy_prefix field to false. Conveniently, a recent SentencePiece commit added this in a test (https://github.com/google/sentencepiece/blob/de1747bbd4b4f35c1f0432851a9fcd4def61c0b0/python/test/sentencepiece_test.py#L851-L866). This syntax seems too cumbersome for llama.cpp's API.
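For concreteness, a minimal sketch of that approach, patching the serialized model proto along the lines of the linked SentencePiece test (the model path is a placeholder; this is the reference Python tokenizer, not part of the llama.cpp API, and it requires the protobuf package):

from sentencepiece import SentencePieceProcessor, sentencepiece_model_pb2

# Load the model proto and turn off the dummy prefix in its normalizer spec.
proto = sentencepiece_model_pb2.ModelProto()
with open('tokenizer.model', 'rb') as f:  # placeholder path
    proto.ParseFromString(f.read())
proto.normalizer_spec.add_dummy_prefix = False

# Re-create the tokenizer from the patched proto; encode() should no longer
# prepend "▁" to the input.
tokenizer = SentencePieceProcessor(model_proto=proto.SerializeToString())
print(tokenizer.encode('Sam'))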

I do think the feature is important though, especially when dealing with prompt formats where there's a special token followed by non-space characters. For example, a user's input in the ChatML format starts with a <|im_start|> token, then a string "user" with no preceding space, then a newline, then the user's input text which also doesn't begin with a space. To me, the cleanest/fastest way to append the user's input text would involve a tokenize function that doesn't add a space prefix.

@crasm (Contributor) commented Jan 21, 2024

Here's a relevant issue on SentencePiece: google/sentencepiece#282.

github-actions bot commented Apr 4, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.

github-actions bot closed this as completed Apr 4, 2024
cebtenzzre removed the stale label Apr 4, 2024
@cebtenzzre (Collaborator) commented:

Not stale - spaces are still inserted by default with special=False, and this is undesirable when joining multiple tokenized texts.

cebtenzzre reopened this Apr 4, 2024
github-actions bot added the stale label May 5, 2024
github-actions bot commented:

This issue was closed because it has been inactive for 14 days since being marked as stale.

github-actions bot commented Jul 4, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.

github-actions bot closed this as completed Jul 4, 2024
@shibe2 (Contributor, Author) commented Jul 4, 2024

I believe this is still an issue, although for my builds I disable this space insertion altogether.

shibe2 reopened this Jul 4, 2024
github-actions bot removed the stale label Jul 5, 2024
@prashanthsadasivan commented:

I recently ran into this as well. I'm working on a library to make integrating a local LLaMA model file into an iOS app easier, and one feature is the idea that you can steer the assistant by force-choosing an option based on what it has said so far. The way it works is that sampling grabs the token, converts it to a string using llama_token_to_piece, and then calls a callback that a developer can use to return the next string of words the assistant should say, often just returning the sampled value. But I chose to have the developer return it as a string, as opposed to a llama_token.

Since the whitespace was already unescaped by llama_token_to_piece, I'd like to be able to pass the string value back to llama_tokenize in a way that won't add another preceding space.

There are tokens in the vocab that shouldn't have a space prepended to them. For example, the word "I'm" gets tokenized like so:

from sentencepiece import SentencePieceProcessor
tokenizer = SentencePieceProcessor(dir_tokenizer + '/tokenizer.model')
print(tokenizer.encode("I'm"))
# [306, 29915, 29885]

but if you tokenize it character by character:

print(tokenizer.encode("I"))
# [306]
print(tokenizer.encode("'"))
# [525]
print(tokenizer.encode("m"))
# [286]

you get different values. If you decode those tokens together, you get "I ' m" instead of "I'm".

print(tokenizer.decode([306, 525, 286]))
# "I ' m"

Basically, it would be nice to be able to take the input string "m" and choose 29885 (which does not have prepended whitespace) instead of 286 (which has the U+2581 space prepended).
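As an aside, a partial workaround with the reference Python tokenizer is to look the piece up directly instead of encoding, which bypasses normalization and therefore the dummy prefix; it only helps when the string is a single piece in the vocab, it reuses the dir_tokenizer placeholder from the snippet above, and the ids in the comments are the ones quoted above:

from sentencepiece import SentencePieceProcessor
tokenizer = SentencePieceProcessor(dir_tokenizer + '/tokenizer.model')

# Direct vocab lookup, no normalization or dummy prefix involved.
print(tokenizer.piece_to_id('m'))   # the plain "m" piece (29885 in the example above)
print(tokenizer.piece_to_id('▁m'))  # the space-prefixed piece (286 in the example above)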

For my issue, I can sorta work around it for sampled tokens, and possibly design the library around it, but I was hoping not to need to expose library users to the intricacies of tokenization (part of the inspiration for the library :P). Does that make sense as motivation for an option to not prepend a space?

github-actions bot commented Sep 2, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.

github-actions bot closed this as completed Sep 2, 2024