-
Notifications
You must be signed in to change notification settings - Fork 27.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generate: Add assisted generation #22211
Conversation
The documentation is not available anymore as the PR was closed or merged. |
8ab1a4c
to
0b5a8ea
Compare
@amyeroberts @sgugger -- since this PR is a bit more complex than most, I've decided to request a review from you two 🤗 |
# may fix in the future: the following models fail to pass this test, and need model-specific fixes | ||
if any( | ||
model_name in model_class.__name__.lower() | ||
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The fix for these models is non-obvious, so I've decided to prioritize shipping the feature instead of aiming for 100% coverage :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gante, what issues did they have? (For example, the gptbigcode
model)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this! LGTM apart from the change of default of synced_gpus
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice! 🚀
Just left some very small nits 🤏 Only thing I'd say is the generation method is pretty large, so might be good to split up - but the numbered comment sections help navigate it a lot 🔍
# 1.1. use the assistant model to obtain the next candidate logits | ||
if "assistant_past_key_values" in model_kwargs: | ||
prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2] | ||
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps we want to add an assert here to check this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd rather not add the assert -- the tests will fail if something is wrong here and it will cause slowdowns (which is very undesirable since this is a performace-oriented generation method)
@amyeroberts regarding splitting up, I totally agree! And not only on this method but on most parts of Before merging, I'm going to double-check that the current code keeps the performance numbers I got a few weeks ago. If everything goes well, it will be merged today 🙏 |
* working mvp * remove breakpoint * fix commit * standardize outputs * tmp commit * tests almost ready * tmp commit * skip a few models * Add streaming; Docs and examples * document limitations * PR commits * Amy PR comments
@gante Excellent work! I just dive into the code these days and found that the impl only support batchsize 1. Speculative Decoding have no relative to batchsize. I guess supporting bs >1 will be more hard to impl so you just support bs=1 firstly? Another question is about the decision of whether the candidate tokens generated by draft model be accepted or not. The process |
@zhaoyang-star thank you for the kind words :) Re batch size 1: it was a mix of implementation simplicity and diminishing returns. Since Re implementation differences: the two techniques were developed independently, despite relying on the same principle (saving GPU memory bandwidth with the aid of a smaller model). To put it plainly:
For the record, we will be adding the sampling trick to our implementation soon, so it will be the best of both worlds :) |
@gante Thanks for your reply.
How to get the conclusion that Speculative Decoding is better when sampling is active with temperatures above 0.3-0.4, and Assisted Generation is better in other scenarios? If the conclusion is right, is it better that we implement both the two methods and decide to execute it according to the vaule of temperature? BTW, Assisted Generation is much easier to understand than Speculative Decoding. So I perfer to use Assisted Generation. |
@zhaoyang-star The conclusion is empirical, with the After we merge the mathematical trick from speculative decoding, calling |
@gante Have you thought of any solution and approach to implement assisted generation on transformer-nueronx? |
Thanks @gante for the feature! I was trying out the following snippets and couldn't figure out which model-pairs are supported by the feature. And I've a couple of questions on how to use it.
I suppose the assumption are that
from transformers import AutoModelForCausalLM, AutoTokenizer
checkpoint = 'EleutherAI/pythia-1.4b-deduped'
assistant = 'EleutherAI/pythia-160m-deduped'
tokenizer = AutoTokenizer.from_pretrained(checkpoint) #, bos_token_id=101, eos_token_id=102)
model = AutoModelForCausalLM.from_pretrained(checkpoint) #, bos_token_id=101, eos_token_id=102)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant)
tokenized_inputs = tokenizer("Alice and Bob", return_tensors="pt")
outputs = model.generate(**tokenized_inputs, assistant_model=assistant_model)
tokenizer.batch_decode(outputs, skip_special_tokens=True) What I've triedThis works:
These didn't:
|
It also works with encoder-decoder models, i.e. models supported by |
What does this PR do?
Here it is, the PR for assisted generation 🙌 In a nutshell, it uses an assistant model (which should be a smaller model with the same tokenizer) to speed up generation, taking advantage of the reduced need for memory transfers in the main model forward pass. It leverages the same property that makes batched inference faster per token.
Since it is meant to be a reference implementation, the code is meant to be clear and well-commented. If you come across any non-obvious steps, let me know so I can clarify them!
Follow-up steps after this PR:
sample
version of assisted generation (many cool apps rely on sampling, including chatbots/assistants)To process the potential speedup visually, consider the following script and the two videos. They correspond to greedy search using a 6.9B GPTNeoX model on an nvidia 3090 🚀
Script
(focus on the speed and the fact that the output is the same, not on the output itself)