
Generate: Add assisted generation #22211

Merged
gante merged 13 commits from assisted_generate_2 into huggingface:main on Apr 18, 2023

Conversation

gante (Member) commented Mar 16, 2023

What does this PR do?

Here it is, the PR for assisted generation 🙌 In a nutshell, it uses an assistant model (which should be a smaller model with the same tokenizer) to speed up generation, taking advantage of the reduced need for memory transfers in the main model forward pass. It leverages the same property that makes batched inference faster per token.

Since this is meant to be a reference implementation, the code is written to be clear and well commented. If you come across any non-obvious steps, let me know so I can clarify them!

Follow-up steps after this PR:

  1. Add support for a sample version of assisted generation (many cool apps rely on sampling, including chatbots/assistants)
  2. Write a blog post and prepare strong communications about the feature

To see the potential speedup visually, consider the following script and the two videos. They correspond to greedy search using a 6.9B GPTNeoX model on an NVIDIA 3090 🚀

Script
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import torch
import time

model_id = "EleutherAI/pythia-6.9b-deduped"
assistant_id = "EleutherAI/pythia-160m-deduped"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# the assistant: a much smaller model that shares the tokenizer with the main model
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_id)
assistant_model = assistant_model.to("cuda")

# the main model, loaded in fp16 and dispatched across the available devices/memory
model_kwargs = {
    "pretrained_model_name_or_path": model_id,
    "device_map": "auto",
    "max_memory": {0: "20GiB", "cpu": "50GiB"},
    "torch_dtype": torch.float16,
}
model = AutoModelForCausalLM.from_pretrained(**model_kwargs)

inputs = tokenizer("Here's how to cook a good ramen:", return_tensors="pt").to("cuda")

streamer = TextStreamer(tokenizer=tokenizer)

print("Without assistance:")
start = time.time()
model.generate(**inputs, streamer=streamer, max_new_tokens=128)
print(f"Elapsed time: {time.time() - start:.2f} seconds")

print("With assistance:")
start = time.time()
model.generate(**inputs, assistant_model=assistant_model, streamer=streamer, max_new_tokens=128)
print(f"Elapsed time: {time.time() - start:.2f} seconds")
(videos: generation without the assistant vs. with the assistant)

(focus on the speed and the fact that the output is the same, not on the output itself)

HuggingFaceDocBuilderDev commented Mar 16, 2023

The documentation is not available anymore as the PR was closed or merged.

@gante gante marked this pull request as ready for review April 17, 2023 18:47
gante (Member, Author) commented Apr 17, 2023

@amyeroberts @sgugger -- since this PR is a bit more complex than most, I've decided to request a review from you two 🤗

# may fix in the future: the following models fail to pass this test, and need model-specific fixes
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text"]
gante (Member, Author):

The fix for these models is non-obvious, so I've decided to prioritize shipping the feature instead of aiming for 100% coverage :)


@gante, what issues did they have? (For example, the gptbigcode model)

sgugger (Collaborator) left a comment

Thanks for working on this! LGTM apart from the change of default of synced_gpus.

amyeroberts (Collaborator) left a comment

Very nice! 🚀

Just left some very small nits 🤏 The only thing I'd say is that the generation method is pretty large, so it might be good to split it up - but the numbered comment sections help navigate it a lot 🔍

# 1.1. use the assistant model to obtain the next candidate logits
if "assistant_past_key_values" in model_kwargs:
prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
Reviewer comment (Collaborator):

Perhaps we want to add an assert here to check this?

gante (Member, Author):

I'd rather not add the assert -- the tests will fail if something is wrong here, and the assert itself would cause slowdowns (which is very undesirable since this is a performance-oriented generation method)

gante (Member, Author) commented Apr 18, 2023

@amyeroberts regarding splitting up, I totally agree! And not only on this method but on most parts of GenerationMixin. Not only are the functions long, but they reuse a significant part of the logic. I want to address that in the near future, by designing a .generate() that can be somehow composed of a sequence of smaller functional blocks. I haven't figured out the deets, but I'd expect that a good implementation would get us better readability, less code duplication, and higher flexibility for HW/model/decoding-specific implementations! 💅
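Purely as a hypothetical illustration of that "smaller functional blocks" idea (none of these names exist in transformers; it is just a sketch of one possible composition):

# Hypothetical sketch: a greedy decoding loop composed of swappable blocks.
# Each block could be replaced independently (e.g. "select_token" swapped for
# an assisted or sampling variant) without touching the rest of the loop.
def decode(input_ids, blocks, max_new_tokens):
    for _ in range(max_new_tokens):
        logits = blocks["forward"](input_ids)                  # model-specific forward pass
        logits = blocks["process_logits"](input_ids, logits)   # logits processors
        next_token = blocks["select_token"](logits)            # greedy / sample / assisted
        input_ids = blocks["append"](input_ids, next_token)    # extend the sequence
        if blocks["should_stop"](input_ids):                   # stopping criteria
            break
    return input_ids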

Before merging, I'm going to double-check that the current code keeps the performance numbers I got a few weeks ago. If everything goes well, it will be merged today 🙏

@gante gante merged commit 78cda46 into huggingface:main Apr 18, 2023
@gante gante deleted the assisted_generate_2 branch April 18, 2023 16:37
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
* working mvp
* remove breakpoint
* fix commit
* standardize outputs
* tmp commit
* tests almost ready
* tmp commit
* skip a few models
* Add streaming; Docs and examples
* document limitations
* PR commits
* Amy PR comments

zhaoyang-star commented Oct 31, 2023

@gante Excellent work! I've been diving into the code these days and found that the implementation only supports batch size 1. Speculative decoding itself is not tied to batch size. I guess supporting bs > 1 is harder to implement, so you supported bs = 1 first?

Another question is about the decision of whether the candidate tokens generated by the draft model are accepted or not. The way n_matches is computed is not the same as in Google's or DeepMind's paper. I found an implementation of DeepMind's algorithm. Could you please explain it in more detail? Thanks in advance.

gante (Member, Author) commented Oct 31, 2023

@zhaoyang-star thank you for the kind words :)

Re batch size 1: it was a mix of implementation simplicity and diminishing returns. Since transformers works with batched inputs of fixed length, efficiently applying assisted generation/speculative decoding would necessarily mean adding extra logic to realign the tensors (e.g. row 1 might get 5 speculated tokens, but row 2 only gets 2 -- row 2 would need to be left-padded to continue). Moving to nested tensors will rid us of this limitation :)
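A hypothetical illustration of that realignment problem (not code from the PR; the shapes and pad value are made up for the example):

import torch

# After one assisted round, row 0 had 5 candidate tokens accepted but row 1 only 2.
pad_token_id = 0
row_0 = torch.tensor([11, 12, 13, 14, 15])
row_1 = torch.tensor([21, 22])

# transformers works with rectangular, fixed-length batches, so the shorter row
# must be re-padded (left-padded here) before the next forward pass -- and the
# attention mask would have to be extended accordingly.
max_len = max(row_0.shape[0], row_1.shape[0])

def left_pad(row, length, value):
    return torch.cat([torch.full((length - row.shape[0],), value), row])

realigned = torch.stack([left_pad(row_0, max_len, pad_token_id),
                         left_pad(row_1, max_len, pad_token_id)])
print(realigned)
# tensor([[11, 12, 13, 14, 15],
#         [ 0,  0,  0, 21, 22]])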

Re implementation differences: the two techniques were developed independently, despite relying on the same principle (saving GPU memory bandwidth with the aid of a smaller model). To put it plainly:

  1. Speculative Decoding is better when sampling is active with temperatures above 0.3-0.4 -- it employs a clever mathematical trick to handle decoding mismatches. However, you must define how many tokens you want to fetch from the smaller model.
  2. Assisted Generation (our implementation) is better in the other scenarios because it has a dynamic heuristic to decide how many tokens to fetch from the assistant model, based on the assistant hit ratio. This means it can adapt to the difficulty of the prompt, with no additional user input (a sketch of this kind of heuristic follows below).

For the record, we will be adding the sampling trick to our implementation soon, so it will be the best of both worlds :)
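A minimal sketch of the kind of dynamic heuristic mentioned in point 2 (the exact update rule and constants in transformers may differ; this only illustrates the idea):

def update_num_candidates(num_candidates: int, num_matches: int) -> int:
    # If every candidate token was accepted, the assistant is tracking the main
    # model well on this prompt, so speculate further next round; otherwise be
    # more conservative. The adjustment happens automatically, with no user input.
    if num_matches == num_candidates:
        return num_candidates + 2
    return max(1, num_candidates - 1)

For example, starting from 5 candidates, a streak of full matches grows the window to 7, 9, ..., while mismatches shrink it back toward 1.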

zhaoyang-star commented Nov 15, 2023

@gante Thanks for your reply.

Speculative Decoding is better when sampling is active with temperatures above 0.3-0.4 -- it employs a clever mathematical trick to handle decoding mismatches. However, you must define how many tokens you want to fetch from the smaller model.

How did you reach the conclusion that speculative decoding is better when sampling is active with temperatures above 0.3-0.4, and that assisted generation is better in the other scenarios? If that conclusion holds, wouldn't it be better to implement both methods and pick between them according to the value of temperature?

BTW, assisted generation is much easier to understand than speculative decoding, so I prefer to use assisted generation.

gante (Member, Author) commented Nov 15, 2023

@zhaoyang-star The conclusion is empirical, with the 0.3-0.4 being a personal rule of thumb based on my assisted generation tests and the values reported in the speculative decoding paper 🤗 It certainly depends on the model and on the task itself.

After we merge the mathematical trick from speculative decoding, calling assisted_generation will actually be the best of both worlds -- it will use the mathematical trick from speculative decoding AND apply the heuristic to determine the number of candidate tokens from assisted generation, all without additional parameterization!
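For reference, a hedged sketch of the acceptance rule described in the speculative decoding papers (this is not the transformers code; p_target and p_draft stand for the two models' probability distributions at the candidate position):

import torch

def accept_candidate(p_target: torch.Tensor, p_draft: torch.Tensor, token: int) -> bool:
    # Accept the draft token with probability min(1, p_target[token] / p_draft[token]).
    # On rejection, the papers resample from the normalized residual
    # max(0, p_target - p_draft), which keeps the overall output distributed
    # exactly like the target model's.
    ratio = p_target[token] / p_draft[token]
    return bool(torch.rand(()) < torch.clamp(ratio, max=1.0))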

zhaoyang-star:

@gante Thanks a lot. Can't wait to try the merged version. I saw that #27270 is related to speculative decoding.

Dev-hestabit:

@gante Have you thought of a solution or approach to implement assisted generation on transformers-neuronx?

alvations commented Mar 13, 2024

Thanks @gante for the feature!

I was trying out the following snippet and couldn't figure out which model pairs are supported by the feature, and I have a couple of questions on how to use it.

  1. What model-pairings are known to be supported by the model.generate(..., assistant_model='') feature?
  2. Does it work for decoder-only models too? Has anyone tried any pairs of decoder-only models available on the Hugging Face Hub?

I suppose the assumptions are that:

  • the tokenizer must be the same for assistant and main model
  • the model is supported by AutoModelForCausalLM

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = 'EleutherAI/pythia-1.4b-deduped'
assistant = 'EleutherAI/pythia-160m-deduped'

tokenizer = AutoTokenizer.from_pretrained(checkpoint) #, bos_token_id=101, eos_token_id=102)
model = AutoModelForCausalLM.from_pretrained(checkpoint) #, bos_token_id=101, eos_token_id=102)

assistant_model = AutoModelForCausalLM.from_pretrained(assistant)

tokenized_inputs = tokenizer("Alice and Bob", return_tensors="pt")

outputs = model.generate(**tokenized_inputs, assistant_model=assistant_model)

tokenizer.batch_decode(outputs, skip_special_tokens=True)

What I've tried

This works:

  • EleutherAI/pythia-1.4b-deduped + EleutherAI/pythia-160m-deduped

These didn't:

  • google-bert/bert-large-uncased + google-bert/bert-base-uncased (I also had to add bos_token_id=101, eos_token_id=102 to the model and/or tokenizer initialization to avoid a None type error when the assistant model is scoping down the vocabulary)
  • FacebookAI/xlm-roberta-large + FacebookAI/xlm-roberta-base (ended up with a TypeError: object of type 'NoneType' has no len() error when looking for candidate generation)

gante (Member, Author) commented Mar 14, 2024

@alvations 👋

It also works with encoder-decoder models, i.e. models supported by AutoModelForSeq2SeqLM. I'm definitely unable to list all working cases, but feel free to open a new issue if you think something should be working and isn't :)
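For completeness, a sketch of the encoder-decoder case (the model choices here are illustrative; any AutoModelForSeq2SeqLM pair sharing a tokenizer should follow the same pattern):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint = "google/flan-t5-large"  # illustrative main model
assistant = "google/flan-t5-small"   # illustrative assistant with the same tokenizer

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
assistant_model = AutoModelForSeq2SeqLM.from_pretrained(assistant)

inputs = tokenizer("Translate to German: How old are you?", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant_model)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))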
