Support for incremental decoding #1141

Closed
njhill opened this issue Jan 4, 2023 · 16 comments

njhill commented Jan 4, 2023

I would like to be able to decode a sequence of token ids incrementally in a decoder-agnostic manner. I haven't found a straightforward way to do this with the current API - the first token is treated differently by some decoders which means that in general

decode([1,2,3]) != decode([1]) + decode([2]) + decode([3])

It would be really nice to have some kind of "continuation" flag to indicate that the result is intended to be appended to an already-decoded prefix, so that you could have

decode([1,2,3]) == decode([1]) + decode'([2]) + decode'([3])
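
To make the two identities above concrete, here is a minimal sketch using a toy WordPiece-style decoder. It is not the `tokenizers` API: `toy_wordpiece_decode` and its `continuation` parameter are hypothetical names used only to illustrate the proposed flag.

```rust
/// Toy WordPiece-style decoder, for illustration only (not the `tokenizers`
/// API). `continuation` is the hypothetical flag proposed above: when true,
/// the first token is treated like any later token.
fn toy_wordpiece_decode(tokens: &[&str], continuation: bool) -> String {
    let mut out = String::new();
    for (i, token) in tokens.iter().enumerate() {
        let is_first = i == 0 && !continuation;
        match token.strip_prefix("##") {
            // "##" marks a word-internal piece: glue it to the previous text.
            Some(piece) if !is_first => out.push_str(piece),
            // The very first token of a fresh sequence is emitted as-is.
            _ if is_first => out.push_str(token),
            // Any other token starts a new word, so it gets a leading space.
            _ => {
                out.push(' ');
                out.push_str(token);
            }
        }
    }
    out
}
```

With the tokens `["play", "##ing", "now"]`, decoding everything at once gives `playing now`, decoding each token separately without the flag gives `play##ingnow`, and decoding the first token normally and the rest with `continuation = true` gives `playing now` again.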

It would also be nice to have a variant of this that takes either a single u32 id or a string token rather than a vec, for related reasons (the latter could be used with id_to_token).

I'd love to know if there is another way to achieve this than my current ugly workaround :)

Current workaround
pub(crate) struct Decoder {
    pub(crate) tokenizer: Tokenizer,
    prefix_id: u32,
    prefix: String,
}

impl Decoder {
    pub(crate) fn new(tokenizer: Tokenizer) -> Decoder {
        let prefix_id = tokenizer.token_to_id("A").unwrap();
        Decoder {
            prefix_id,
            prefix: tokenizer.decode(vec![prefix_id], false).unwrap(),
            tokenizer,
        }
    }

    /// Decode continuation tokens to be added to some existing text
    pub(crate) fn decode_continuation(&self, mut ids: Vec<u32>) -> tokenizers::Result<String> {
        // How we handle this depends on the specific decoder's behaviour,
        // see each one's implementation of decode_chain in the tokenizers library.
        match self.tokenizer.get_decoder() {
            Some(ByteLevel(_)) => {
                // Lossless - call standard decode function
                self.tokenizer.decode(ids, true)
            },
            Some(Metaspace(_)) | Some(WordPiece(_)) | Some(BPE(_)) => {
                // For these, the first token in the sequence is treated differently,
                // so we add and then strip a placeholder token.
                ids.insert(0, self.prefix_id);
                let result = self.tokenizer.decode(ids, true)?;
                Ok(result.strip_prefix(&self.prefix).ok_or(DecodingError)?.to_string())
            },
            None => {
                // Just prepend a space
                Ok(format!(" {}", self.tokenizer.decode(ids, true)?))
            },
            _ => Err(UnsupportedTokenizerError.into())
        }
    }
}
Collaborator

Narsil commented Jan 4, 2023

@njhill there's no way to correctly handle what you are asking for.

By definition, decoding is a one-step operation. Decoders are allowed to (and do) look at previous and next tokens to decide how to decode a given token, so making it iterative is not really feasible. Your code kind of works, but expect some weird things.

Decoding is a best-effort attempt to go from tokens to a string. For some tokenizers it's trivial (ByteLevel), but for others there is more construction involved, and the primary goal is to make the end result as human-readable as possible (ideally such that encode(decode(ids)) == ids, but even that is not entirely guaranteed).

Author

njhill commented Jan 4, 2023

Thanks @Narsil, I understand what you mean w.r.t. this technically not being possible in the general case.

But in practice it appears this can be achieved easily for the most prominent tokenizers (at least all of the built-in ones excluding CTC) with information on whether some provided subsequence is at the start and/or end (or neither) of the full sequence.

More generally I can imagine an abstract decoder function that returns a stateful "incremental decoder" object. This in turn has a method like fn next(token_ids: Vec<u32>) -> String, where e.g. passing an empty vector could indicate that the sequence has finished.

The default implementation would just accumulate the raw token strings and return an empty string until the finished signal is sent at which point it calls decode_chain and returns the results.
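
A minimal sketch of that interface and its buffering default follows. Nothing like this exists in the `tokenizers` crate: the trait name, the struct, and the `decode` closure (standing in for a full, non-incremental decode) are all assumptions made for illustration.

```rust
/// Hypothetical interface from the comment above; not part of `tokenizers`.
/// An empty `ids` slice signals that the sequence has finished.
trait IncrementalDecoder {
    fn next(&mut self, ids: &[u32]) -> String;
}

/// Fallback behaviour: buffer every id and decode once at the end.
/// `decode` stands in for a full, one-shot decode of all ids.
struct BufferingDecoder<F: Fn(&[u32]) -> String> {
    decode: F,
    pending: Vec<u32>,
}

impl<F: Fn(&[u32]) -> String> IncrementalDecoder for BufferingDecoder<F> {
    fn next(&mut self, ids: &[u32]) -> String {
        if ids.is_empty() {
            // Finished: decode the whole buffer in one go, then reset.
            let text = (self.decode)(&self.pending);
            self.pending.clear();
            text
        } else {
            // Not finished yet: remember the ids and emit nothing.
            self.pending.extend_from_slice(ids);
            String::new()
        }
    }
}
```

A decoder-specific override would implement the same trait but return text earlier, whenever its own rules make that safe.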

But I think overrides for all of the standard tokenizers to be more "incremental" would be pretty straightforward.

In case it isn't obvious the motivation here is streaming text generation :) Maybe that's considered too niche to worry about.

Collaborator

Narsil commented Jan 4, 2023

> excluding CTC

That's the start of my worry.

Tbh I don't think it's worth it for now. Feel free to open a PR, track a branch. If there's interest within the broader community we're always up to add useful features.

But please bear in mind that maintaining them has a cost; that's why I'm not jumping on adding such new features (which indeed seem a bit niche, in addition to not necessarily being generally feasible).

Author

njhill commented Jan 6, 2023

Thanks @Narsil, completely understand about being selective and the additional maintenance cost. I may open a PR when I get a chance to work on it, or possibly propose a simpler change that would at least make it easier to implement externally.

philpax commented May 24, 2023

Hi there! I'm not sure if this is the same issue, but we're trying to integrate tokenizers in llm and have found that calling Tokenizer::decode on single tokens, as they are produced by the LLM, removes the spaces between tokens.

This can be seen here:

llm # cargo run --release llama infer -m models/llama/7B/WizardLM-7B-uncensored.ggml.qnt1.q4_0.bin -p "I am testing out prompt tokenization and" -r ehartford/WizardLM-7B-Uncensored

[...]

Iamtestingoutprompttokenizationandwouldliketoknowifyouhaveanysuggestionsonhowtoimproveit.

I believe that this is because decoding an individual token runs through the entire pipeline, including the following decoder from the config:

  "decoder": {
    "type": "Sequence",
    "decoders": [
      {
        "type": "Replace",
        "pattern": {
          "String": "▁"
        },
        "content": " "
      },
      {
        "type": "ByteFallback"
      },
      {
        "type": "Fuse"
      },
      {
        "type": "Strip",
        "content": " ",
        "start": 1,
        "stop": 0
      }
    ]
  },
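
To see why this pipeline drops the spaces when it is run on one token at a time, here is a toy re-implementation of its steps. It is only a sketch, not the `tokenizers` crate API: `toy_llama_decode` is a hypothetical name, tokens are assumed to carry a leading "▁" word marker, and ByteFallback is left out because the example tokens contain no byte tokens.

```rust
/// Toy version of the decoder steps listed above, for illustration only.
fn toy_llama_decode(tokens: &[&str]) -> String {
    // Replace: turn the "▁" marker into a plain space in every token.
    let replaced: Vec<String> = tokens.iter().map(|t| t.replace('▁', " ")).collect();
    // Fuse: concatenate all tokens into a single string.
    let fused: String = replaced.concat();
    // Strip (content " ", start 1, stop 0): drop at most one leading space.
    fused.strip_prefix(' ').unwrap_or(&fused).to_string()
}
```

Decoding `["▁I", "▁am", "▁testing"]` in one call yields `I am testing`, because Strip only removes the space at the very start. Decoding each token in its own call strips the leading space from every token, which gives `Iamtesting`.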

The pseudocode for our logic looks something like this:

tokenizer = Tokenizer::from_pretrained(...);
encoded = tokenizer.encode(prompt);

inference.feed(encoded, |token_id| print(tokenizer.decode(token_id)))

loop:
  let token_id = inference.evaluate_and_sample()
  if token_id == eot:
    bail
  print(tokenizer.decode(token_id))

This issue seems to imply that we can't use HF Tokenizers like this for streaming output, which has me a little confused. I was hoping to find some answers in the smelte implementations of BERT/GPT-2, but it doesn't look like they attempt to retrieve token strings.

How do we correctly handle streaming inference with decoded output? We'd really prefer not to decode all of the tokens on each token inference - that would lead to O(N^2) behaviour at a glance.

Author

njhill commented May 25, 2023

@philpax I rolled my own incremental decoder implementation here, which takes into account differences between the different types of tokenizers, but I'm not sure it will work with arbitrary custom/Sequence type decoders.

Collaborator

Narsil commented May 25, 2023

> the tokens on each token inference - that would lead to O(N^2) behaviour at a glance.

It's O(N) but still not great indeed.

The core of the issue is that you're decoding tokens one by one, and llama (like most tokenizers in general) has quirky logic that makes decode behave differently depending on neighbouring tokens. For instance, most tokens start with a prefix space (like " hello").
Decoded on its own, such a token is normalized to "hello", since " hello" would look odd (that prefix space is also added during the normalizing phase, so it must not be there when decoded). This is indeed what breaks decode-one-token-at-a-time logic that needs to keep the spaces.

In a nutshell, what you can do is:

  • keep all the ids;
  • remember which ids were the last ones decoded (some tokens cannot be decoded on their own: because of byte_fallback, some tokens are invalid UTF-8 bytes in isolation);
  • keep an extra token in the decode so that you preserve the space in between, if there is one.

Whenever the output of decode is valid (it doesn't end with the unknown UTF-8 character), you can advance all the offsets.

full code: https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/models/model.py#L51-L76
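
The steps above can be sketched in Rust roughly as follows. This is an illustrative port under stated assumptions, not code from the linked file: `StreamState`, `push`, and `toy_decode` are hypothetical names, and `toy_decode` is a stand-in for a real tokenizer decode that strips one leading space and uses lossy UTF-8 conversion.

```rust
/// `ids` holds every token so far; `prefix_offset..read_offset` is the
/// already-emitted context kept for spacing, `read_offset..` is not emitted yet.
struct StreamState {
    ids: Vec<u32>,
    prefix_offset: usize,
    read_offset: usize,
}

impl StreamState {
    fn new() -> Self {
        StreamState { ids: Vec::new(), prefix_offset: 0, read_offset: 0 }
    }

    /// Push one token and return the text that is now safe to emit, if any.
    /// `decode` stands in for a full tokenizer decode of a slice of ids.
    fn push(&mut self, id: u32, decode: impl Fn(&[u32]) -> String) -> Option<String> {
        self.ids.push(id);
        let prefix_text = decode(&self.ids[self.prefix_offset..self.read_offset]);
        let new_text = decode(&self.ids[self.prefix_offset..]);
        // Hold back while nothing new appeared or the text ends with U+FFFD.
        if new_text.len() > prefix_text.len() && !new_text.ends_with('\u{FFFD}') {
            // `get` returns None instead of panicking on a bad char boundary.
            let emitted = new_text.get(prefix_text.len()..)?.to_string();
            self.prefix_offset = self.read_offset;
            self.read_offset = self.ids.len();
            Some(emitted)
        } else {
            None
        }
    }
}

/// Toy decoder (assumption): ids 0 and 1 are words, ids 2 and 3 are the two
/// bytes of "ќ"; one leading space is stripped, invalid bytes become U+FFFD.
fn toy_decode(ids: &[u32]) -> String {
    let mut bytes: Vec<u8> = Vec::new();
    for id in ids {
        match *id {
            0 => bytes.extend_from_slice(b" Hello"),
            1 => bytes.extend_from_slice(b" world"),
            2 => bytes.push(0xD1),
            3 => bytes.push(0x9C),
            _ => {}
        }
    }
    let text = String::from_utf8_lossy(&bytes).into_owned();
    text.strip_prefix(' ').unwrap_or(&text).to_string()
}
```

Feeding ids 0, 1, 2, 3 one at a time through `push` with `toy_decode` emits `Hello`, then ` world` (the extra context token keeps the space), then nothing for the lone byte token, then `ќ` once the second byte arrives.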

This works for the llama tokenizer. It should work for all generative tokenizers. It may work for all tokenizers (fortunately the weirdest tokenizers are usually not for generative models).

For other readers: decoding tokens should really only be necessary for generative work. Please use offsets for anything non-generative; they work in 100% of cases and don't have any of the flaws of this approach.

In text-generation-inference, the last value returned by the client is a decode called on all ids; that's the only output we 100% guarantee is good on all tokenizers.

philpax commented May 27, 2023

Thanks for clarifying! That's hugely helpful. It's unfortunate that's necessary, but it makes sense given what you've mentioned.

My O(N^2) fear was that you'd have to decode the entire string again with each incoming token, but the prefix/read offset solution should make that much less catastrophic.

We'll try to implement this soon and report back if there are any issues!

philpax commented May 29, 2023

We just merged in our integration of HF Tokenizers - thanks for the help, very much appreciated 🙏

Collaborator

Narsil commented Jun 5, 2023

Glad it worked!

philpax commented Jun 28, 2023

Hi there! While debugging rustformers/llm#298, I noticed that the incremental decoding approach of removing previous bytes breaks down as the resulting bytes can actually change with additional tokens:

fn main() {
    let tokenizer = tokenizers::Tokenizer::from_pretrained("allenai/tulu-7b", None).unwrap();
    let tokens = tokenizer.encode("да гледам ќе идев", true).unwrap();
    let tokens = tokens.get_ids();
    for i in 0..tokens.len() {
        let slice = &tokens[0..i];
        let token = &tokenizer.decode(slice.to_vec(), true).unwrap();
        let bytes = token.as_bytes();
        println!("{bytes:?} ({slice:?})");
    }
}

results in

[] ([])
[] ([1])
[208, 180, 208, 176] ([1, 3574])
[208, 180, 208, 176, 32, 208, 179] ([1, 3574, 1214])
[208, 180, 208, 176, 32, 208, 179, 208, 187, 208, 181] ([1, 3574, 1214, 753])
[208, 180, 208, 176, 32, 208, 179, 208, 187, 208, 181, 208, 180, 208, 176] ([1, 3574, 1214, 753, 840])
[208, 180, 208, 176, 32, 208, 179, 208, 187, 208, 181, 208, 180, 208, 176, 208, 188] ([1, 3574, 1214, 753, 840, 29959])
[208, 180, 208, 176, 32, 208, 179, 208, 187, 208, 181, 208, 180, 208, 176, 208, 188, 32] ([1, 3574, 1214, 753, 840, 29959, 29871])
[208, 180, 208, 176, 32, 208, 179, 208, 187, 208, 181, 208, 180, 208, 176, 208, 188, 32, 239, 191, 189] ([1, 3574, 1214, 753, 840, 29959, 29871, 212])
[208, 180, 208, 176, 32, 208, 179, 208, 187, 208, 181, 208, 180, 208, 176, 208, 188, 32, 209, 156] ([1, 3574, 1214, 753, 840, 29959, 29871, 212, 159])
[208, 180, 208, 176, 32, 208, 179, 208, 187, 208, 181, 208, 180, 208, 176, 208, 188, 32, 209, 156, 208, 181] ([1, 3574, 1214, 753, 840, 29959, 29871, 212, 159, 29919])
[208, 180, 208, 176, 32, 208, 179, 208, 187, 208, 181, 208, 180, 208, 176, 208, 188, 32, 209, 156, 208, 181, 32, 208, 184, 208, 180, 208, 181] ([1, 3574, 1214, 753, 840, 29959, 29871, 212, 159, 29919, 28866])

or, with the relevant lines focused and previously decoded bytes removed:

[..., 32] ([1, 3574, 1214, 753, 840, 29959, 29871])
[..., 32, 239, 191, 189] ([1, 3574, 1214, 753, 840, 29959, 29871, 212])
[..., 32, 209, 156] ([1, 3574, 1214, 753, 840, 29959, 29871, 212, 159])
[..., 32, 209, 156, 208, 181] ([1, 3574, 1214, 753, 840, 29959, 29871, 212, 159, 29919])

Our current decoding logic will send the 32, then the 239, 191, 189, then fail because the next decoding switches to 209, 156 for the last few bytes, and the resulting array is shorter than the previously decoded bytes. You can hack around this to return an empty slice if the bounds are exceeded, but the result is already wrong as we've sent incorrect bytes.

Unfortunately, checking for UTF-8 correctness doesn't fix this; [..., 32, 239, 191, 189] is a valid UTF-8 string.

What would be the best way to proceed here? I'm thinking that we should potentially maintain a one-token buffer before reporting new incremental decodings (i.e. don't report new decodings until a common prefix has appeared in both the past and current decodings, then report the common prefix). I'm concerned that the same issue might occur with more tokens, however.
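
As a rough sketch of that common-prefix idea (the function names are hypothetical, and this does not address the concern that the same issue might span more tokens):

```rust
/// Length of the longest shared prefix of two byte strings.
fn common_prefix_len(a: &[u8], b: &[u8]) -> usize {
    a.iter().zip(b).take_while(|(x, y)| x == y).count()
}

/// Given the previous and current full decodes and the number of bytes
/// already reported, return the bytes that both decodes now agree on.
fn stable_delta<'a>(prev: &[u8], cur: &'a [u8], reported: usize) -> &'a [u8] {
    let stable = common_prefix_len(prev, cur);
    if stable > reported {
        &cur[reported..stable]
    } else {
        &[]
    }
}
```

Using the tail of the byte arrays above: comparing `[32]` with `[32, 239, 191, 189]` reports only `[32]`, comparing `[32, 239, 191, 189]` with `[32, 209, 156]` reports nothing new, and comparing `[32, 209, 156]` with `[32, 209, 156, 208, 181]` reports `[209, 156]`. The output always lags one decode behind.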

Is this something we can account for?

Collaborator

Narsil commented Jun 29, 2023

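This doesn't seem to be valid UTF-8:
token 212 means byte 209, which is incomplete on its own, so the bytes you're getting back (239, 191, 189) are valid but encode the unknown UTF-8 glyph (the replacement character):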
https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=3e3695bb8b0f8d9318751691a6a2c85f

This is because tokenizers uses String::from_utf8_lossy(..) to prevent crashing within Python.
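
The effect can be reproduced with the standard library alone. This small sketch (the helper name `lossy_vs_strict` is made up for illustration) decodes the same bytes both lossily and strictly:

```rust
/// Returns the lossy decoding and, if the bytes are valid UTF-8, the strict one.
fn lossy_vs_strict(bytes: &[u8]) -> (String, Option<String>) {
    // Lossy: every invalid or incomplete sequence becomes U+FFFD.
    let lossy = String::from_utf8_lossy(bytes).into_owned();
    // Strict: fails instead of substituting.
    let strict = std::str::from_utf8(bytes).ok().map(String::from);
    (lossy, strict)
}
```

For `[32, 209]` the strict decode fails and the lossy one yields the bytes `[32, 239, 191, 189]`, which is exactly the tail seen in the output above. For `[32, 209, 156]` both succeed and give a space followed by "ќ".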

Maybe we need to expose something lower level which would make more sense for Rust users of tokenizers actually.

philpax commented Jun 29, 2023

Ah, that makes sense! I was wondering about tokenizers returning a String, actually - it makes sense that we're seeing the lossy conversion here.

I assume that we might need to wait until the next release cycle for a fix - is it safe to assume that the only time the lossy replacement character will occur is during partial decodes like this? If so, we can try ignoring the decoding until there are no replacement characters.

Collaborator

Narsil commented Jun 29, 2023

> If so, we can try ignoring the decoding until there are no replacement characters.

This is what is done in the Python version; you just need to check the last char for it.

> is it safe to assume that the only time the lossy replacement character will occur is during partial decodes like this?

No. Models will output whatever IDs they want, in whatever order; if the output results in invalid UTF-8, then you will get the invalid UTF-8 glyph from the current decode.
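
If raw bytes were available (an assumption: the decode call discussed in this thread returns a String, not bytes), the standard library can tell "incomplete so far" apart from "definitely invalid". A minimal sketch, with hypothetical names `Utf8Status` and `classify`:

```rust
#[derive(Debug, PartialEq)]
enum Utf8Status {
    /// All bytes form valid UTF-8.
    Complete,
    /// The input ends in the middle of a character; more bytes may fix it.
    Incomplete,
    /// An unexpected byte was found; more input will not fix it.
    Invalid,
}

fn classify(bytes: &[u8]) -> Utf8Status {
    match std::str::from_utf8(bytes) {
        Ok(_) => Utf8Status::Complete,
        // `error_len() == None` means the input ended unexpectedly.
        Err(e) if e.error_len().is_none() => Utf8Status::Incomplete,
        Err(_) => Utf8Status::Invalid,
    }
}
```

Here `[32, 209]` is `Incomplete`, `[32, 209, 156]` is `Complete`, and `[32, 156]` is `Invalid` because 156 is a continuation byte with no lead byte before it. With only the lossy String, all of the failing cases look the same.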

philpax commented Jun 29, 2023

Gotcha, thanks!

> This is what is done in the Python version, you just need to check the last char for it.

I'll try this soon - hopefully it sorts this out for us :)

philpax commented Jun 29, 2023

Can confirm that works. Much obliged!
