Support for incremental decoding #1141
@njhill there's no way to handle what you are asking for correctly. By definition, decoding ids one chunk at a time is not guaranteed to give the same result as decoding the full sequence at once.
Thanks @Narsil, I understand what you mean w.r.t. this technically not being possible in the general case. But in practice it appears this can be achieved easily for the most prominent tokenizers (at least all of the built-in ones excluding CTC) given information on whether the provided subsequence is at the start and/or end (or neither) of the full sequence. More generally, I can imagine an abstract decoder function that returns a stateful "incremental decoder" object, which in turn has a method for feeding in the next token(s) and getting back any newly decoded text. The default implementation would just accumulate the raw token strings and return an empty string until the finished signal is sent, at which point it calls the regular decode on everything accumulated. But I think overrides for all of the standard tokenizers to be more "incremental" would be pretty straightforward. In case it isn't obvious, the motivation here is streaming text generation :) Maybe that's considered too niche to worry about.
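The gist of that idea as a minimal sketch (the `IncrementalDecoder`/`BufferedDecoder` names and method signatures here are hypothetical, not part of the tokenizers API):

```rust
// Hypothetical sketch only: nothing here exists in the `tokenizers` crate,
// it just illustrates the stateful "incremental decoder" idea above.
trait IncrementalDecoder {
    /// Feed the raw string of the next token; return any newly finalized text.
    fn step(&mut self, token_str: String) -> String;
    /// Signal that the sequence is finished; return whatever is still pending.
    fn finish(&mut self) -> String;
}

/// Default behaviour described above: accumulate raw token strings and emit
/// nothing until the end, where a real implementation would run the full
/// decode over everything accumulated (joining is a stand-in here).
struct BufferedDecoder {
    pending: Vec<String>,
}

impl IncrementalDecoder for BufferedDecoder {
    fn step(&mut self, token_str: String) -> String {
        self.pending.push(token_str);
        String::new()
    }

    fn finish(&mut self) -> String {
        self.pending.drain(..).collect()
    }
}
```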
That's part of my worry. Tbh I don't think it's worth it for now. Feel free to open a PR and track a branch. If there's interest within the broader community we're always up for adding useful features. But please bear in mind that maintaining them has a cost; that's why I'm not jumping on adding such new features (which indeed seem a bit niche, in addition to not necessarily being generally feasible).
Thanks @Narsil, completely understand about being selective and the additional maintenance cost. I may open a PR when I get a chance to work on it, or possibly propose a simpler change that would at least make it easier to implement externally.
Hi there! I'm not sure if this is the same issue, but we're trying to integrate HF Tokenizers into rustformers/llm and are running into issues when decoding tokens one at a time. This can be seen here:
I believe that this is because decoding an individual token runs through the entire pipeline, including the following decoder:

```json
"decoder": {
  "type": "Sequence",
  "decoders": [
    {
      "type": "Replace",
      "pattern": {
        "String": "▁"
      },
      "content": " "
    },
    {
      "type": "ByteFallback"
    },
    {
      "type": "Fuse"
    },
    {
      "type": "Strip",
      "content": " ",
      "start": 1,
      "stop": 0
    }
  ]
},
```

The pseudocode for our logic looks something like this:
This issue seems to imply that we can't use HF Tokenizers like this for streaming output, which has me a little confused. I was hoping to find some answers in the smelte implementations of BERT/GPT-2, but it doesn't look like they attempt to retrieve token strings. How do we correctly handle streaming inference with decoded output? We'd really prefer not to decode all of the tokens on each token inference - that would lead to O(N^2) behaviour at a glance.
It's O(N), but still not great indeed. The core of the issue is that you're decoding tokens one by one, and llama (and most tokenizers in general) has weird logic that forces decode to behave differently based on neighbouring tokens. For instance, most tokens will start with a prefix space (the `▁` that the decoder above replaces with a space). In a gist, here is what you can do:
Whenever the output of decode is valid (i.e. it doesn't end with an unknown utf-8 byte), then you can increment everything. This works for the llama tokenizer. It should work for all generative tokenizers, and it may work for all tokenizers (fortunately the weirdest tokenizers are usually not used for generative models). For other readers: decoding tokens incrementally like this should really only be necessary for streaming text generation.
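A minimal sketch of that approach with the Rust tokenizers crate (the `StreamDecoder` type and the `prefix_offset`/`read_offset` bookkeeping are mine, not library API, and error handling is simplified; depending on the crate version, decode takes a `Vec<u32>` or a slice):

```rust
use tokenizers::Tokenizer;

// Sketch of windowed incremental decoding: keep all generated ids, re-decode a
// small window each step, and emit only the new suffix once it is valid utf-8.
struct StreamDecoder {
    ids: Vec<u32>,
    // Start of the window whose text has already been emitted.
    prefix_offset: usize,
    // End of the emitted window / start of the ids not yet turned into text.
    read_offset: usize,
}

impl StreamDecoder {
    fn new() -> Self {
        Self { ids: Vec::new(), prefix_offset: 0, read_offset: 0 }
    }

    /// Feed one generated id; return any newly decoded text (possibly empty).
    fn push(&mut self, tokenizer: &Tokenizer, id: u32) -> String {
        self.ids.push(id);
        // Text of the window we have already emitted...
        let prefix_text = tokenizer
            .decode(self.ids[self.prefix_offset..self.read_offset].to_vec(), false)
            .unwrap_or_default();
        // ...and text of that same window plus the new ids.
        let new_text = tokenizer
            .decode(self.ids[self.prefix_offset..].to_vec(), false)
            .unwrap_or_default();
        // Only emit when the decode grew and does not end in an incomplete
        // utf-8 sequence (rendered as the replacement character).
        if new_text.len() > prefix_text.len() && !new_text.ends_with('\u{FFFD}') {
            let emitted = new_text[prefix_text.len()..].to_string();
            self.prefix_offset = self.read_offset;
            self.read_offset = self.ids.len();
            emitted
        } else {
            String::new()
        }
    }
}
```

Calling `push` for each generated id then yields text chunks only once they decode to complete utf-8, which is the "increment everything" step described above.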
Thanks for clarifying! That's hugely helpful. It's unfortunate that's necessary, but it makes sense given what you've mentioned. My O(N^2) fear was that you'd have to decode the entire string again with each incoming token, but the prefix/read offset solution should make that much less catastrophic. We'll try to implement this soon and report back if there are any issues!
We just merged in our integration of HF Tokenizers - thanks for the help, very much appreciated 🙏
Glad it worked!
Hi there! While debugging rustformers/llm#298, I noticed that the incremental decoding approach of removing previous bytes breaks down, as the resulting bytes can actually change with additional tokens:

```rust
fn main() {
    let tokenizer = tokenizers::Tokenizer::from_pretrained("allenai/tulu-7b", None).unwrap();
    let tokens = tokenizer.encode("да гледам ќе идев", true).unwrap();
    let tokens = tokens.get_ids();
    // Decode every prefix of the ids and print the raw bytes, showing how
    // earlier bytes can change as more tokens are appended.
    for i in 0..tokens.len() {
        let slice = &tokens[0..i];
        let token = &tokenizer.decode(slice.to_vec(), true).unwrap();
        let bytes = token.as_bytes();
        println!("{bytes:?} ({slice:?})");
    }
}
```

results in
or, with the relevant lines focused and previously decoded bytes removed:
Our current decoding logic will send the bytes added since the last decode, which breaks here. Unfortunately, checking for UTF-8 correctness doesn't fix this. What would be the best way to proceed? I'm thinking that we should potentially maintain a one-token buffer before reporting new incremental decodings (i.e. don't report new decodings until a common prefix has appeared in both the past and current decodings, then report the common prefix). I'm concerned that the same issue might occur with more tokens, however. Is this something we can account for?
That output doesn't seem to be valid utf-8. This is because decode has to return a valid String, so incomplete utf-8 sequences get replaced lossily. Maybe we need to expose something lower level which would make more sense for Rust users of tokenizers.
Ah, that makes sense! I was wondering about that. I assume that we might need to wait until the next release cycle for a fix - is it safe to assume that the only time the lossy replacement character will occur is during partial decodes like this? If so, we can try ignoring the decoding until there are no replacement characters.
This is what is done in the Python version; you just need to check the last character.
No, models will output whatever ID they want, in whatever order. If the output results in invalid utf-8, then you will get the invalid utf-8 glyph from the current decode.
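In code, that check is just a look at the last character of the decoded string; `U+FFFD` is the replacement character inserted by the lossy conversion (the helper name here is made up):

```rust
/// Returns true when the decoded text ends in the utf-8 replacement character,
/// i.e. the last ids decoded to an incomplete byte sequence and we should wait
/// for more tokens before emitting anything.
fn ends_in_incomplete_utf8(decoded: &str) -> bool {
    decoded.ends_with('\u{FFFD}')
}
```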
Gotcha, thanks!
I'll try this soon - hopefully it sorts this out for us :)
Can confirm that works. Much obliged!
I would like to be able to decode a sequence of token ids incrementally in a decoder-agnostic manner. I haven't found a straightforward way to do this with the current API - the first token is treated differently by some decoders, which means that in general you can't just decode each new token on its own and append the result to the already-decoded text.
It would be really nice to have some kind of "continuation" flag to indicate that the result is intended to be appended to an already-decoded prefix, so that you could pass that flag on each decode call after the first (a hypothetical sketch follows below).
It would also be nice to have a variant of this that takes either a single `u32` id or string token rather than a vec, for related reasons (the latter could be used with `id_to_token`). I'd love to know if there is another way to achieve this than my current ugly workaround :)
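For illustration, the kind of signature such a flag might translate to; `decode_continuation` is hypothetical and does not exist in the tokenizers crate:

```rust
/// Hypothetical extension sketching the "continuation" flag proposal above;
/// not real tokenizers API.
trait DecodeContinuation {
    /// Decode `ids` as the continuation of already-decoded text, so that for
    /// any split point k, decode(ids[..k]) + decode_continuation(ids[k..])
    /// equals decode(ids).
    fn decode_continuation(
        &self,
        ids: Vec<u32>,
        skip_special_tokens: bool,
    ) -> tokenizers::Result<String>;
}
```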
Current workaround