
proposal: Move token decoding and stopping evaluation to router #138

Closed
wants to merge 4 commits

Conversation

@njhill (Contributor) commented Mar 26, 2023

Benefits:

  • Centralizes this logic that's on the critical inference loop path and does it in rust instead of python
  • Simplifies python side of the code, decoupling next-token generation and batch pruning operations
  • Allows for cancelling requests when client closes connection/stream
  • Allows time-limit based stopping criteria (which is otherwise tricky in multi-shard case)

I've included an implementation of IncrementalDecoder logic that addresses the various problems associated with incremental decoding when streaming, for the various kinds of tokenizers.
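
For concreteness, here is a minimal Python sketch of one common approach to this problem (re-decode a small sliding window of token ids and only emit text once it is unambiguous). The class and field names are illustrative, not the PR's actual IncrementalDecoder API; `tokenizer` is assumed to be any object with a Hugging Face-style `decode(ids)` method.

```python
class IncrementalDecoder:
    """Illustrative sketch: stream text by re-decoding a sliding window of
    token ids and emitting only the newly completed portion."""

    def __init__(self, tokenizer, prompt_len=0):
        self.tokenizer = tokenizer
        # In practice both offsets start at the end of the prompt so that
        # only generated text is streamed back.
        self.prefix_offset = prompt_len  # start of the last emitted window
        self.read_offset = prompt_len    # end of the last emitted window

    def step(self, all_token_ids):
        """Return the text that became safe to emit after the latest token."""
        prefix_text = self.tokenizer.decode(
            all_token_ids[self.prefix_offset:self.read_offset])
        new_text = self.tokenizer.decode(all_token_ids[self.prefix_offset:])
        # A trailing U+FFFD usually means a multi-byte character is still
        # incomplete, so hold the text back until the next token arrives.
        if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
            delta = new_text[len(prefix_text):]
            self.prefix_offset = self.read_offset
            self.read_offset = len(all_token_ids)
            return delta
        return ""
```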

The main change is removing the stopping-criteria logic from the Python models' generate_token methods and adding a separate prune method to the batch classes that prunes the batch based on a list of completed request ids. These lists are passed in via a second parameter added to the internal Decode RPC.
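
A minimal, self-contained sketch of what that server-side contract could look like (names and fields are illustrative; a real batch also carries model tensors that must be sliced alongside the request list):

```python
from dataclasses import dataclass

@dataclass
class Request:
    id: int

@dataclass
class Batch:
    requests: list  # a real batch also holds input_ids, attention masks, past KVs, ...

    def prune(self, completed_ids):
        """Drop the given request ids; return the smaller batch, or None if empty."""
        done = set(completed_ids)
        kept = [r for r in self.requests if r.id not in done]
        # A real implementation would also slice the model tensors to match `kept`.
        return Batch(kept) if kept else None

batch = Batch([Request(0), Request(1), Request(2)])
batch = batch.prune([1])  # the router reports request 1 as finished/cancelled
assert [r.id for r in batch.requests] == [0, 2]
```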

A new optional time_limit_ms parameter is added to the external API.

Given that this changes the internal python APIs a bit, the python-based tests need some adjustments. I've only partially done that, and set most of them to be skipped for now.

@OlivierDehaene (Member)

Hello! Thanks for the PR.

First a disclaimer: I only skimmed through the PR but I want to get this discussion started.

Benefits:

  • Centralizes this logic that's on the critical inference loop path and does it in rust instead of python

I would argue that the decoding logic is already "in rust" since right now text-generation-inference only supports models with fast tokenizers aka rust tokenizers.

  • Simplifies python side of the code, decoupling next-token generation and batch pruning operations

I'm not sure I see what you mean here. From what I saw, you moved part of the decode_token logic to the Batch class and the pruning logic execution from the end of the Decode method to the beginning. The decoding is removed from Python and moved to Rust, but it is now more complex.

  • Allows for cancelling requests when client closes connection/stream
  • Allows time-limit based stopping criteria (which is otherwise tricky in multi-shard case)

True. But this could have been implemented with the current system (rough sketch after this list):

  1. Add next_batch_keep_indices to the batch metadata
  2. move the pruning from the end to the beginning of decode
  3. Add drop_indices to the Batch proto
  4. Prune using both next_batch_keep_indices and drop_indices
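
A rough, self-contained sketch of that alternative (illustrative names only: next_batch_keep_indices mirrors bookkeeping the server already does, while drop_indices would be the new proto field carrying router-side cancellations):

```python
def surviving_indices(next_batch_keep_indices, drop_indices):
    """Combine server-side completions with router-side cancellations before
    pruning at the beginning of Decode."""
    dropped = set(drop_indices)
    return [i for i in next_batch_keep_indices if i not in dropped]

# Request 1 finished last step (not in the keep list); request 2 was cancelled
# by the router, e.g. because the client disconnected or a time limit expired.
assert surviving_indices([0, 2, 3], drop_indices=[2]) == [0, 3]
```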

Do you have an example use case for time-limit based stopping criteria?

I've included an implementation of IncrementalDecoder logic that addresses the various problems associated with incremental decoding when streaming, for the various kinds of tokenizers.

Shouldn't this be added directly to tokenizers?

The main claim of this PR is that moving the code from Python to Rust is faster.
Here are my thoughts on this subject:

  1. I need to see benchmarks in different scenarios: is the end-to-end latency over a single request lower? What about batching?
  2. I like having the tokenizers in the Python code. It allows for greater flexibility. A lot of new models don't have fast tokenizers yet (llama, ChatGLM, even OPT). I'm not sure if relying more on tokenizers is the correct way to go. I was about to make the validation/truncation optional if no fast tokenizer is found and rely entirely on Python to support these kinds of models.

@njhill (Contributor, Author) commented Mar 26, 2023

Thanks @OlivierDehaene, and sorry for the PR being quite large. For context - I have been making many changes/additions on an internal fork for some time now with the intention of contributing (or at least offering) most of them back, but you've also been making changes quickly including implementing some of the same things before I had a chance to port PRs over!

This could probably be separated into two logical parts - moving the stop criteria handling could be done while leaving the detokenization on the python side, so perhaps it's worth discussing them separately.

I would argue that the decoding logic is already "in rust" since right now text-generation-inference only supports models with fast tokenizers aka rust tokenizers.

I know what you mean, but the lookup of individual tokens is probably not very significant either way (just a hashtable index). I was thinking more about how that can be inlined with surrounding conditional logic, including that related to incremental decoding. It seemed nice for the python shards to just stream token ids back and not worry about the text manipulation side of things.

I'm not sure I see what you mean here. From what I saw, you moved part of the decode_token logic to the Batch class and the pruning logic execution from the end of the Decode method to the beginning.

Yes, exactly. I did this originally in support of moving the stopping criteria logic out but then found that this decomposition of generate_token into simpler operations ends up cleaner (imho). Most executions of Decode don't involve pruning, so it's nice to have it separate. Then the contract of these operations just becomes:

  • generate_token - generate the next token for every member of the batch (updates the batch in-place)
  • prune - remove specified subset of requests from a batch

and the resulting simpler code for each of these should be easier to maintain or implement for new kinds of models (again just imho).

But this could have been implemented with the current system.

  1. Add next_batch_keep_indices to the batch metadata
  2. move the pruning from the end to the beginning of decode
  3. Add drop_indices to the Batch proto
  4. Prune using both next_batch_keep_indices and drop_indices

Sure, but this seems a bit more complicated - the stopping decision logic is split between two different places, and each of the separate model subclasses has to be involved. Performance/simplicity-wise, what's the downside of having the stop criteria logic done in one place on the rust side? It also means the stop sequence evaluation can be more efficient... it's all on the infer loop critical path.
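
To make the "one place" idea concrete, here is an illustrative sketch of a centralized stopping check (written in Python for brevity, even though under this proposal it would live in the Rust router; all names are made up for the example):

```python
import time

def should_stop(generated_text, n_generated, max_new_tokens,
                stop_sequences, start_time, time_limit_ms=None):
    """Evaluate all stopping criteria in one place, once per generated token."""
    if n_generated >= max_new_tokens:
        return True, "length"
    for seq in stop_sequences:
        if generated_text.endswith(seq):
            return True, "stop_sequence"
    if time_limit_ms is not None and (time.monotonic() - start_time) * 1000 >= time_limit_ms:
        return True, "time_limit"
    return False, None

start = time.monotonic()
print(should_stop("Hello world.\n\n", 4, 128, ["\n\n"], start))  # (True, 'stop_sequence')
```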

Do you have an example use case for time-limit based stopping criteria?

Depending on the model/hardware, the generation might not be super fast. So it can be useful to say, for example, "give me as much as you can in 2 seconds" rather than guessing a max token limit.
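
As an illustration of that use case, a request could look roughly like this (hypothetical payload: the endpoint and existing fields follow text-generation-inference's /generate API, while time_limit_ms is the new parameter proposed here and its exact placement is an assumption):

```python
import requests

resp = requests.post(
    "http://localhost:3000/generate",  # assumed local router address
    json={
        "inputs": "Summarize the plot of Moby-Dick.",
        "parameters": {
            "max_new_tokens": 512,   # generous upper bound
            "time_limit_ms": 2000,   # hypothetical: stop after ~2 seconds
        },
    },
)
print(resp.json())
```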

Shouldn't this be added directly to tokenizers?

See the discussion with @Narsil here. I agree with you that it would be nice for it to be added there, but it sounds like streaming functionality may be too niche. In any case I thought perhaps this project could be a good place to incubate it given it's the main application for it. Kind of like how you've included various custom/specialized changes to transformers.

The main claim of this PR is that moving the code from Python to Rust is faster.

I wouldn't say that's the main claim, more like just one of the potential benefits. I think more significant is the separation/movement of concerns.

  1. I need to see benchmarks in different scenarios: is the end-to-end latency over a single request lower? What about batching?

We are working on benchmarking, so I hope we can get some data on this soon. There are other variations too, like lots of stop sequences being used, etc.

  1. I like having the tokenizers in the Python code. It allows for greater flexibility. A lot of new models don't have fast tokenizers yet (llama, ChatGLM, even OPT). I'm not sure if relying more on tokenizers is the correct way to go. I was about to make the validation/truncation optional if no fast tokenizer is found and rely entirely on Python to support these kinds of models.

I thought I saw that @Narsil was working hard on filling these gaps :)

Also:
- Fixes incremental decoding inconsistency issues
- Adds time limit based stopping option
@njhill (Contributor, Author) commented Mar 31, 2023

@OlivierDehaene I was going to rebase this but then realized your benchmark stuff talks directly to the internal proto interface and so would need to be adjusted too.

@OlivierDehaene (Member)

@njhill,
I don't think this will be merged anytime soon: it is too big of a change.
I love the upstream PR though. I think it's a very important feedback mechanism.

@drbh (Collaborator) commented Jan 29, 2024

Closing in favor of more recent work that implements large parts of this PR: #202

@drbh closed this Jan 29, 2024
@Narsil (Collaborator) commented Jan 29, 2024

Wrong link?

@drbh (Collaborator) commented Jan 29, 2024

Whoops, sorry about that bad copy/paste 😅 I've updated the comment above.

Successfully merging this pull request may close these issues: How to process Unicode such as 🤗 😄?