From 4501254c67ddf11e77582f8b127203da797f57b2 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 19 Nov 2024 13:20:07 +0100 Subject: [PATCH] Self-speculation (Layer-Skip Llama) (#34240) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 😅 * early exit (#34244) * mvp * docs and tests * a few fixes * no shared cache * Apply suggestions from code review Co-authored-by: Mostafa Elhoushi * docs * make fix-copies * cohere fix * [test all] * [test all] consistent model code copies * [test all] make fix-copies :D * Apply suggestions from code review Co-authored-by: Pedro Cuenca Co-authored-by: Mostafa Elhoushi * Update src/transformers/generation/candidate_generator.py * Update src/transformers/generation/configuration_utils.py Co-authored-by: Pedro Cuenca * [test all] don't use a stand-alone attribute; fix test --------- Co-authored-by: Joao Gante Co-authored-by: Joao Gante Co-authored-by: Mostafa Elhoushi Co-authored-by: Pedro Cuenca --- docs/source/en/generation_strategies.md | 70 +++++++++++++------ src/transformers/cache_utils.py | 29 ++++---- src/transformers/generation/__init__.py | 8 ++- .../generation/candidate_generator.py | 56 +++++++++++++++ .../generation/configuration_utils.py | 16 +++-- src/transformers/generation/utils.py | 12 +++- .../models/cohere/modeling_cohere.py | 2 +- .../models/gemma/modeling_gemma.py | 2 +- .../models/gemma/modular_gemma.py | 2 +- .../models/gemma2/modeling_gemma2.py | 2 +- .../models/gemma2/modular_gemma2.py | 2 +- src/transformers/models/glm/modeling_glm.py | 2 +- .../models/llama/modeling_llama.py | 2 +- .../models/olmoe/modeling_olmoe.py | 2 +- tests/generation/test_utils.py | 22 ++++++ 15 files changed, 178 insertions(+), 51 deletions(-) diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md index 621edeb20e8ea3..380b39fe62acdf 100644 --- a/docs/source/en/generation_strategies.md +++ b/docs/source/en/generation_strategies.md @@ -416,16 +416,6 @@ Assisted decoding assumes the main and assistant models have the same tokenizer, Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs. To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation). -#### Universal Assisted Decoding - -Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers. -To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below). -Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are -in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above. -The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer. -Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings, -to ensure the new tokens include the correct prompt suffix. - To enable assisted decoding, set the `assistant_model` argument with a model. ```python @@ -445,7 +435,36 @@ To enable assisted decoding, set the `assistant_model` argument with a model. ['Alice and Bob are sitting in a bar. 
Alice is drinking a beer and Bob is drinking a'] ``` -If the main and assistant models have different tokenizers, use Universal Assisted Decoding. +When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness, +just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency. + +```python +>>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed +>>> set_seed(42) # For reproducibility + +>>> prompt = "Alice and Bob" +>>> checkpoint = "EleutherAI/pythia-1.4b-deduped" +>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped" + +>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint) +>>> inputs = tokenizer(prompt, return_tensors="pt") + +>>> model = AutoModelForCausalLM.from_pretrained(checkpoint) +>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint) +>>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5) +>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) +['Alice and Bob, a couple of friends of mine, who are both in the same office as'] +``` + +#### Universal Assisted Decoding + +Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers. +To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below). +Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are +in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above. +The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer. +Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings, +to ensure the new tokens include the correct prompt suffix. ```python >>> from transformers import AutoModelForCausalLM, AutoTokenizer @@ -465,30 +484,35 @@ If the main and assistant models have different tokenizers, use Universal Assist ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a'] ``` -When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness, -just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency. +#### Prompt Lookup + +Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed +to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259). + +#### Self-Speculative Decoding + +An LLM can be trained to also use its language modeling head with earlier hidden states as input, effectively +skipping layers to yield a lower-quality output -- a technique called early exiting. +We use the lower-quality early exit output as an assistant output, and apply self-speculation to fix the output using the remaining layers. The final generation of that self-speculative solution is the same (or has the same distribution) as the original model's generation. +If the model you're using was trained to do early exit, you can pass +`assistant_early_exit` (integer). In this case, the assistant model will be the same model but exiting early, hence the +"self-speculative" name. 
Because the assistant model is a portion of the target model, caches and weights can be shared, which results in lower memory requirements. As in other assisted generation methods, the final generated result has the same quality as if no assistant had been used. ```python ->>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed ->>> set_seed(42) # For reproducibility +>>> from transformers import AutoModelForCausalLM, AutoTokenizer >>> prompt = "Alice and Bob" ->>> checkpoint = "EleutherAI/pythia-1.4b-deduped" ->>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped" +>>> checkpoint = "facebook/layerskip-llama3.2-1B" >>> tokenizer = AutoTokenizer.from_pretrained(checkpoint) >>> inputs = tokenizer(prompt, return_tensors="pt") >>> model = AutoModelForCausalLM.from_pretrained(checkpoint) ->>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint) ->>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5) +>>> outputs = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=20) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) -['Alice and Bob, a couple of friends of mine, who are both in the same office as'] +['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a'] ``` -Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed -to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259). - ### DoLa Decoding **D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy to improve the factuality and reduce the diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 0f696cc3ac6a4d..aeb184f7400ce7 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -433,19 +433,22 @@ def update( self._seen_tokens += key_states.shape[-2] # Update the cache - if len(self.key_cache) <= layer_idx: - # There may be skipped layers, fill them with empty lists - for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append([]) - self.value_cache.append([]) - self.key_cache.append(key_states) - self.value_cache.append(value_states) - elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + if key_states is not None: + if len(self.key_cache) <= layer_idx: + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append([]) + self.value_cache.append([]) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + elif ( + len(self.key_cache[layer_idx]) == 0 + ): # fills previously skipped layers; checking for tensor causes errors + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) return self.key_cache[layer_idx], self.value_cache[layer_idx] diff --git a/src/transformers/generation/__init__.py 
b/src/transformers/generation/__init__.py index b487fa3c7fe6ec..e2ed48433b1639 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -49,6 +49,7 @@ _import_structure["candidate_generator"] = [ "AssistedCandidateGenerator", "CandidateGenerator", + "EarlyExitCandidateGenerator", "PromptLookupCandidateGenerator", ] _import_structure["logits_process"] = [ @@ -206,7 +207,12 @@ else: from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer - from .candidate_generator import AssistedCandidateGenerator, CandidateGenerator, PromptLookupCandidateGenerator + from .candidate_generator import ( + AssistedCandidateGenerator, + CandidateGenerator, + EarlyExitCandidateGenerator, + PromptLookupCandidateGenerator, + ) from .logits_process import ( AlternatingCodebooksLogitsProcessor, ClassifierFreeGuidanceLogitsProcessor, diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 1e4d7a4702453a..d8344c25a6526a 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -670,6 +670,62 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F return +class EarlyExitCandidateGenerator(AssistedCandidateGenerator): + """ + `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates + candidates through the use of **the model itself**, exiting early. Can only be used with models that support early + exit, e.g., `facebook/layerskip-llama3.2-1B`. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + assistant_model (`PreTrainedModel`): + The original model. This model must support early exit (i.e. is trained to compute logits in earlier + layers). + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + model_kwargs (`Dict`): + The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant + model as well. + inputs_tensor (`torch.Tensor`, *optional*): + The model input tensor. In encoder-decoder models, this is the encoder input. 
+ """ + + def __init__( + self, + input_ids: torch.LongTensor, + assistant_model: "PreTrainedModel", + generation_config: "GenerationConfig", + model_kwargs: Dict, + inputs_tensor: Optional[torch.Tensor] = None, + logits_processor: "LogitsProcessorList" = None, + ): + super().__init__( + input_ids=input_ids, + assistant_model=assistant_model, + generation_config=generation_config, + model_kwargs=model_kwargs, + inputs_tensor=inputs_tensor, + logits_processor=logits_processor, + ) + # We have to move early exit out of the generation config, otherwise the assistant will also call `generate` + # with early exit + self.assistant_early_exit = self.generation_config.assistant_early_exit + self.generation_config.assistant_early_exit = None + + def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: + # Temporarily sets the number of hidden layers to the early exit value + base_model = getattr(self.assistant_model, self.assistant_model.base_model_prefix) + original_num_hidden_layers = base_model.config.num_hidden_layers + base_model.config.num_hidden_layers = self.assistant_early_exit + candidate_ids, candidate_logits = super().get_candidates(input_ids) + base_model.config.num_hidden_layers = original_num_hidden_layers + return candidate_ids, candidate_logits + + def _crop_past_key_values(model, past_key_values, max_length): """Crops the past key values up to a certain maximum length.""" new_past = [] diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 9b543f6c35711d..de62ee767aeda0 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -353,10 +353,13 @@ class GenerationConfig(PushToHubMixin): than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_ (defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models . - prompt_lookup_num_tokens (`int`, *optional*, default to `None`): + prompt_lookup_num_tokens (`int`, *optional*): The number of tokens to be output as candidate tokens. - max_matching_ngram_size (`int`, *optional*, default to `None`): + max_matching_ngram_size (`int`, *optional*): The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided. + assistant_early_exit(`int`, *optional*): + If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with + models that support early exit (i.e. models where logits from intermediate layers can be interpreted by the LM head). 
> Wild card @@ -454,10 +457,9 @@ def __init__(self, **kwargs): self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 20) self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "constant") self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0.4) - - # Prompt lookup decoding self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None) self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None) + self.assistant_early_exit = kwargs.pop("assistant_early_exit", None) # Wild card self.generation_kwargs = kwargs.pop("generation_kwargs", {}) @@ -534,7 +536,11 @@ def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = Non generation_mode = GenerationMode.BEAM_SEARCH # Assisted generation may extend some generation modes - if assistant_model is not None or self.prompt_lookup_num_tokens is not None: + if ( + assistant_model is not None + or self.prompt_lookup_num_tokens is not None + or self.assistant_early_exit is not None + ): if generation_mode in ("greedy_search", "sample"): generation_mode = GenerationMode.ASSISTED_GENERATION else: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 97a294fd427987..432b3142873d39 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -54,6 +54,7 @@ AssistedCandidateGenerator, AssistedCandidateGeneratorDifferentTokenizers, CandidateGenerator, + EarlyExitCandidateGenerator, PromptLookupCandidateGenerator, _crop_past_key_values, _prepare_attention_mask, @@ -822,7 +823,16 @@ def _get_candidate_generator( """ different_tokenizers = all(v is not None for v in (assistant_model, target_tokenizer, assistant_tokenizer)) - if generation_config.prompt_lookup_num_tokens is not None: + if generation_config.assistant_early_exit is not None: + candidate_generator = EarlyExitCandidateGenerator( + input_ids=input_ids, + assistant_model=self, + generation_config=generation_config, + model_kwargs=model_kwargs, + inputs_tensor=inputs_tensor, + logits_processor=logits_processor, + ) + elif generation_config.prompt_lookup_num_tokens is not None: candidate_generator = PromptLookupCandidateGenerator( eos_token_id=generation_config._eos_token_tensor, num_output_tokens=generation_config.prompt_lookup_num_tokens, diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 0261f997da1110..d481d87e7ab8ed 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -890,7 +890,7 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 6fead73eced704..52d02995016167 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -808,7 +808,7 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py 
index 807f91ff9e6baa..ad1348ae5e3163 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -886,7 +886,7 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 6a3d8f27fb177d..c439ec069f7a93 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -823,7 +823,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index dacaca1c7ef4a9..ff2d42d671c3c4 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -653,7 +653,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 58a89d90b44ff5..9080b5b9cc7c39 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -789,7 +789,7 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 679296648a9135..0408bb73c7f2da 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -893,7 +893,7 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 47cb0964eca8b6..169827ffd75777 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -995,7 +995,7 @@ def forward( all_router_logits = () if output_router_logits else None next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index cbe851e97e9aed..6630fc2ba9d152 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -4108,6 +4108,28 @@ def test_generate_compile_fullgraph_tiny(self): gen_out = compiled_generate(**model_inputs, generation_config=generation_config) 
self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated + def test_assisted_generation_early_exit(self): + """ + Tests that assisted generation with early exit works as expected. Under the hood, this has complex cache + manipulation, which will cause the test to fail if something goes wrong there. + """ + expected_output = "Alice and Bob are playing a game of poker. Alice has a pair of 8s and Bob has a pair" + + prompt = "Alice and Bob" + checkpoint = "facebook/layerskip-llama3.2-1B" + + tokenizer = AutoTokenizer.from_pretrained(checkpoint) + inputs = tokenizer(prompt, return_tensors="pt").to(torch_device) + + model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device) + original_outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20) + original_decoded = tokenizer.batch_decode(original_outputs, skip_special_tokens=True) + self.assertEqual(original_decoded, [expected_output]) + + outputs_assisted = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=20) + decoded_assisted = tokenizer.batch_decode(outputs_assisted, skip_special_tokens=True) + self.assertEqual(decoded_assisted, [expected_output]) + @require_torch class TokenHealingTestCase(unittest.TestCase):
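
---

For readers skimming the patch, the sketch below pulls the pieces together: it runs the same prompt once with plain greedy decoding and once with the new `assistant_early_exit` argument, which is the user-facing entry point this PR adds. The checkpoint name, `assistant_early_exit=4`, and the generation settings are taken from the docs and test changes above; everything else is a minimal illustration, not part of the patch.

```python
# Minimal sketch of the self-speculative decoding path added in this PR.
# Assumes a transformers build that includes these changes and access to the
# facebook/layerskip-llama3.2-1B checkpoint referenced in the patch.
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "facebook/layerskip-llama3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

inputs = tokenizer("Alice and Bob", return_tensors="pt")

# Baseline: greedy decoding through the full layer stack.
baseline = model.generate(**inputs, do_sample=False, max_new_tokens=20)

# Self-speculative decoding: the same model drafts candidates using only its
# first 4 decoder layers, then verifies them with the full stack. With greedy
# decoding the final text should match the baseline; only latency differs.
assisted = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=20)

print(tokenizer.batch_decode(baseline, skip_special_tokens=True))
print(tokenizer.batch_decode(assisted, skip_special_tokens=True))
```

Design-wise, no separate assistant checkpoint is needed: `EarlyExitCandidateGenerator.get_candidates` temporarily lowers `config.num_hidden_layers` to the early-exit value, and the updated decoder loops (`for decoder_layer in self.layers[: self.config.num_hidden_layers]`) honor that override, so the drafting pass reuses the target model's own weights and, per the docs change above, its cache.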