Self-speculation (Layer-Skip Llama) #34240

Merged (16 commits) on Nov 19, 2024
70 changes: 47 additions & 23 deletions docs/source/en/generation_strategies.md
@@ -416,16 +416,6 @@ Assisted decoding assumes the main and assistant models have the same tokenizer,
Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).

#### Universal Assisted Decoding
Member: (nesting was not fully right -- normal "speculative decoding" examples were under "Universal Assisted Decoding". Moved a few things around)


Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below).
Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
to ensure the new tokens include the correct prompt suffix.

To enable assisted decoding, set the `assistant_model` argument with a model.

```python
@@ -445,7 +435,36 @@
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```
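
For reference, a minimal sketch of the plain (greedy) assisted-decoding call, reusing the checkpoints from the sampling example below:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> # Without sampling, validation keeps the output identical to what the main model alone would generate
>>> outputs = model.generate(**inputs, assistant_model=assistant_model)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```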

If the main and assistant models have different tokenizers, use Universal Assisted Decoding.
When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
>>> set_seed(42) # For reproducibility

>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob, a couple of friends of mine, who are both in the same office as']
```

#### Universal Assisted Decoding

Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below).
Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
to ensure the new tokens include the correct prompt suffix.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -465,30 +484,35 @@
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```
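
As a reference for the call pattern, a minimal UAD sketch is shown below; the checkpoint names are only illustrative stand-ins for any main/assistant pair with different tokenizers:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "Alice and Bob"
>>> checkpoint = "google/gemma-2-9b"             # illustrative main model
>>> assistant_checkpoint = "double7/vicuna-68m"  # illustrative assistant with a different tokenizer

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> # Passing both tokenizers is what enables the re-encoding described above
>>> outputs = model.generate(
...     **inputs,
...     assistant_model=assistant_model,
...     tokenizer=tokenizer,
...     assistant_tokenizer=assistant_tokenizer,
... )
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```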

When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
#### Prompt Lookup

Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
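
A minimal sketch of what that looks like, reusing the main checkpoint from the examples above (no assistant model is needed, since candidates come from n-grams in the prompt):

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)

>>> inputs = tokenizer("Alice and Bob", return_tensors="pt")
>>> # Candidate continuations are looked up as n-grams in the prompt itself
>>> outputs = model.generate(**inputs, prompt_lookup_num_tokens=3)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```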

#### Self-Speculative Decoding

An LLM can be trained to also use its language modeling head with earlier hidden states as input, effectively
skipping layers to yield a lower-quality output -- a technique called early exiting.
We use the lower-quality early exit output as an assistant output, and apply self-speculation to fix the output using the remaining layers. The final generation of that self-speculative solution is the same (or has the same distribution) as the original model's generation.
If the model you're using was trained to do early exit, you can pass
`assistant_early_exit` (an integer, the number of layers after which the draft exits). In this case, the assistant model will be the same model but exiting early, hence the
"self-speculative" name. Because the assistant model is a portion of the target model, caches and weights can be shared, which results in lower memory requirements. As in other assisted generation methods, the final generated result has the same quality as if no assistant had been used.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
>>> set_seed(42) # For reproducibility
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
>>> checkpoint = "facebook/layerskip-llama3.2-1B"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)
>>> outputs = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=20)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob, a couple of friends of mine, who are both in the same office as']
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```

Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).

### DoLa Decoding

**D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy to improve the factuality and reduce the
29 changes: 16 additions & 13 deletions src/transformers/cache_utils.py
@@ -433,19 +433,22 @@ def update(
self._seen_tokens += key_states.shape[-2]

# Update the cache
if len(self.key_cache) <= layer_idx:
# There may be skipped layers, fill them with empty lists
for _ in range(len(self.key_cache), layer_idx):
self.key_cache.append([])
self.value_cache.append([])
self.key_cache.append(key_states)
self.value_cache.append(value_states)
elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
if key_states is not None:
if len(self.key_cache) <= layer_idx:
# There may be skipped layers, fill them with empty lists
for _ in range(len(self.key_cache), layer_idx):
self.key_cache.append([])
self.value_cache.append([])
self.key_cache.append(key_states)
self.value_cache.append(value_states)
elif (
len(self.key_cache[layer_idx]) == 0
): # fills previously skipped layers; checking for tensor causes errors
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
Comment on lines +450 to +451
Contributor: Here `torch.cat` will only be correct if `min(new_positions) == previous_length + 1`? If that's correct, should we also add an assert statement for that?

Member: yes, that is correct!

I'm not going to add any check for now, though, and rely on internal tests to detect issues: adding a check here would hurt throughput in the forward pass, and a test can immediately detect issues :)


return self.key_cache[layer_idx], self.value_cache[layer_idx]
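
To illustrate the skipped-layer handling above in isolation, a small sketch against the public `DynamicCache` API (shapes are arbitrary; the behavior assumes the implementation in this diff):

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
k = torch.randn(1, 2, 4, 8)  # (batch, num_heads, seq_len, head_dim)
v = torch.randn(1, 2, 4, 8)

cache.update(k, v, layer_idx=0)  # regular update
cache.update(k, v, layer_idx=2)  # layer 1 is skipped and padded with an empty list
print(cache.key_cache[1])        # [] -> placeholder for the skipped layer

cache.update(k, v, layer_idx=1)  # a later pass fills the previously skipped layer
print(cache.key_cache[1].shape)  # torch.Size([1, 2, 4, 8])
```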

8 changes: 7 additions & 1 deletion src/transformers/generation/__init__.py
@@ -49,6 +49,7 @@
_import_structure["candidate_generator"] = [
"AssistedCandidateGenerator",
"CandidateGenerator",
"EarlyExitCandidateGenerator",
"PromptLookupCandidateGenerator",
]
_import_structure["logits_process"] = [
@@ -206,7 +207,12 @@
else:
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .candidate_generator import AssistedCandidateGenerator, CandidateGenerator, PromptLookupCandidateGenerator
from .candidate_generator import (
AssistedCandidateGenerator,
CandidateGenerator,
EarlyExitCandidateGenerator,
PromptLookupCandidateGenerator,
)
from .logits_process import (
AlternatingCodebooksLogitsProcessor,
ClassifierFreeGuidanceLogitsProcessor,
55 changes: 55 additions & 0 deletions src/transformers/generation/candidate_generator.py
@@ -670,6 +670,61 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
return


class EarlyExitCandidateGenerator(AssistedCandidateGenerator):
"""
`CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
candidates through the use of **the model itself**, exiting early. Can only be used with models that support early
exit.
Member: Maybe mention a specific model as an example here? (there aren't many models that currently support it)

Contributor: Adding an example of a model that supports early exit as per the suggestion of @pcuenca . Not sure if it is a good idea to add link to a model collection in the docstring but feel free to remove it.

Suggested change:
- exit.
+ exit, e.g., `facebook/layerskip-llama3.2-1B` or any of the models listed in this [collection](https://huggingface.co/collections/facebook/layerskip-666b25c50c8ae90e1965727a).

Member: Just a single model would be enough for me, a collection could give the impression that we are maintaining a list of compatible models there, which is not the case.

Member: (added a single model :) )

Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
assistant_model (`PreTrainedModel`):
The original model. This model must support early exit (i.e. is trained to compute logits in earlier
layers).
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
model_kwargs (`Dict`):
The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
model as well.
inputs_tensor (`torch.Tensor`, *optional*):
The model input tensor. In encoder-decoder models, this is the encoder input.
"""

def __init__(
self,
input_ids: torch.LongTensor,
assistant_model: "PreTrainedModel",
generation_config: "GenerationConfig",
model_kwargs: Dict,
inputs_tensor: Optional[torch.Tensor] = None,
logits_processor: "LogitsProcessorList" = None,
Contributor: FYI, I recently also added stopping_criteria as well to support integration with Eleuther LM Eval Harness:
facebookresearch/LayerSkip@e38784d

Contributor: I think you can ignore my comment about supporting StoppingCriteria. I checked out the PR and integrated with LM Eval Harness and found out that we don't need it.
I think I needed it in my custom implementation, but the native HF implementation doesn't.
):
super().__init__(
input_ids=input_ids,
assistant_model=assistant_model,
generation_config=generation_config,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
logits_processor=logits_processor,
)
# We have to move early exit out of the generation config, otherwise the assistant will also call `generate`
# with early exit
self.assistant_early_exit = self.generation_config.assistant_early_exit
self.generation_config.assistant_early_exit = None

def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
# Temporarily sets the number of hidden layers to the early exit value
base_model = getattr(self.assistant_model, self.assistant_model.base_model_prefix)
base_model.num_hidden_layers = self.assistant_early_exit
candidate_ids, candidate_logits = super().get_candidates(input_ids)
base_model.num_hidden_layers = base_model.config.num_hidden_layers
return candidate_ids, candidate_logits


def _crop_past_key_values(model, past_key_values, max_length):
"""Crops the past key values up to a certain maximum length."""
new_past = []
Expand Down
16 changes: 11 additions & 5 deletions src/transformers/generation/configuration_utils.py
@@ -353,10 +353,13 @@ class GenerationConfig(PushToHubMixin):
than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
(defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead
from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models <https://arxiv.org/abs/2405.04304>.
prompt_lookup_num_tokens (`int`, *optional*, default to `None`):
prompt_lookup_num_tokens (`int`, *optional*):
The number of tokens to be output as candidate tokens.
max_matching_ngram_size (`int`, *optional*, default to `None`):
max_matching_ngram_size (`int`, *optional*):
The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.
assistant_early_exit (`int`, *optional*):
If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with
models that support early exit (i.e. models trained to compute logits in earlier layers).

> Wild card

@@ -454,10 +457,9 @@ def __init__(self, **kwargs):
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 20)
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "constant")
self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0.4)

# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
self.assistant_early_exit = kwargs.pop("assistant_early_exit", None)

# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
@@ -534,7 +536,11 @@ def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = Non
generation_mode = GenerationMode.BEAM_SEARCH

# Assisted generation may extend some generation modes
if assistant_model is not None or self.prompt_lookup_num_tokens is not None:
if (
assistant_model is not None
or self.prompt_lookup_num_tokens is not None
or self.assistant_early_exit is not None
):
if generation_mode in ("greedy_search", "sample"):
generation_mode = GenerationMode.ASSISTED_GENERATION
else:
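A sketch of driving the new flag through a `GenerationConfig` (equivalent to passing `assistant_early_exit` directly to `generate`); the checkpoint is the LayerSkip model used in the docs above:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

checkpoint = "facebook/layerskip-llama3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# With assistant_early_exit set, get_generation_mode() resolves to ASSISTED_GENERATION
generation_config = GenerationConfig(assistant_early_exit=4, do_sample=False, max_new_tokens=20)

inputs = tokenizer("Alice and Bob", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```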
12 changes: 11 additions & 1 deletion src/transformers/generation/utils.py
@@ -54,6 +54,7 @@
AssistedCandidateGenerator,
AssistedCandidateGeneratorDifferentTokenizers,
CandidateGenerator,
EarlyExitCandidateGenerator,
PromptLookupCandidateGenerator,
_crop_past_key_values,
_prepare_attention_mask,
@@ -822,7 +823,16 @@ def _get_candidate_generator(
"""
different_tokenizers = all(v is not None for v in (assistant_model, target_tokenizer, assistant_tokenizer))

if generation_config.prompt_lookup_num_tokens is not None:
if generation_config.assistant_early_exit is not None:
candidate_generator = EarlyExitCandidateGenerator(
input_ids=input_ids,
assistant_model=self,
generation_config=generation_config,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
logits_processor=logits_processor,
)
elif generation_config.prompt_lookup_num_tokens is not None:
candidate_generator = PromptLookupCandidateGenerator(
eos_token_id=generation_config._eos_token_tensor,
num_output_tokens=generation_config.prompt_lookup_num_tokens,
3 changes: 2 additions & 1 deletion src/transformers/models/cohere/modeling_cohere.py
@@ -803,6 +803,7 @@ def __init__(self, config: CohereConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
@@ -890,7 +891,7 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.num_hidden_layers]:
Contributor: Smart! I like that simple change that enables flexibility.

if output_hidden_states:
all_hidden_states += (hidden_states,)

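To illustrate the change above in isolation: the per-instance `num_hidden_layers` attribute now caps the decoder loop, which is exactly what `EarlyExitCandidateGenerator.get_candidates` toggles. A rough sketch, assuming the same slicing is applied in the Llama modeling code (not shown in this excerpt):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "facebook/layerskip-llama3.2-1B"  # trained to produce usable logits from early layers
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
inputs = tokenizer("Alice and Bob", return_tensors="pt")

base_model = getattr(model, model.base_model_prefix)  # the decoder stack, e.g. `model.model`
base_model.num_hidden_layers = 4                      # only the first 4 decoder layers run
early_logits = model(**inputs).logits                 # cheap "assistant" forward pass

base_model.num_hidden_layers = base_model.config.num_hidden_layers  # restore full depth
full_logits = model(**inputs).logits                  # regular forward pass
```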
3 changes: 2 additions & 1 deletion src/transformers/models/gemma/modeling_gemma.py
@@ -714,6 +714,7 @@ def __init__(self, config: GemmaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
@@ -805,7 +806,7 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

2 changes: 1 addition & 1 deletion src/transformers/models/gemma/modular_gemma.py
@@ -886,7 +886,7 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

3 changes: 2 additions & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
@@ -734,6 +734,7 @@ def __init__(self, config: Gemma2Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
@@ -820,7 +821,7 @@ def forward(
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

2 changes: 1 addition & 1 deletion src/transformers/models/gemma2/modular_gemma2.py
@@ -653,7 +653,7 @@ def forward(
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

3 changes: 2 additions & 1 deletion src/transformers/models/glm/modeling_glm.py
@@ -698,6 +698,7 @@ def __init__(self, config: GlmConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
@@ -787,7 +788,7 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
for decoder_layer in self.layers[: self.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
