From 367296c0305379a36e311e75efbafe2a94b9c55f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 16 Mar 2023 18:12:57 +0000 Subject: [PATCH 01/12] working mvp --- src/transformers/generation/utils.py | 451 +++++++++++++++++- .../models/bloom/modeling_bloom.py | 1 + 2 files changed, 449 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6140f4cb400..e81684a04e5 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -18,7 +18,7 @@ import inspect import warnings from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -72,6 +72,9 @@ ) +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + logger = logging.get_logger(__name__) @@ -1115,7 +1118,12 @@ def generate( logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, +<<<<<<< HEAD synced_gpus: Optional[bool] = None, +======= + synced_gpus: Optional[bool] = False, + assistant_model: Optional["PreTrainedModel"] = None, +>>>>>>> ca8162da4 (working mvp) **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: r""" @@ -1161,11 +1169,15 @@ def generate( on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful for constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). - synced_gpus (`bool`, *optional*): + synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length. Unless overridden this flag will be set to `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished generating before other GPUs. Otherwise it'll be set to `False`. - + assistant_model (`PreTrainedModel`, *optional*): + An assistant model that can be used to accelerate generation. The assistant model must have the exact + same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model + is much faster than running generation with the model you're calling generate from. As such, the + assistant model should be much smaller. kwargs: Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder @@ -1372,6 +1384,14 @@ def generate( and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) + is_assisted_greedy_gen_mode = False + if assistant_model is not None: + if not is_greedy_gen_mode: + raise ValueError( + "You've set `assistant_model`, which triggers assisted generation. Currently, assisted generation " + "is only supported with Greedy Search." + ) + is_assisted_greedy_gen_mode = True if generation_config.num_beam_groups > generation_config.num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") @@ -1405,6 +1425,39 @@ def generate( generation_config=generation_config, stopping_criteria=stopping_criteria ) # 10. go into different generation modes + if is_assisted_greedy_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " assisted greedy search." + ) + if batch_size > 1: + raise ValueError("Assisted generation is only supported for batch_size = 1") + + # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs + if assistant_model.config.is_encoder_decoder: + assistant_model_kwargs = copy.deepcopy(model_kwargs) + inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs( + inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs + ) + assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, assistant_model_kwargs, model_input_name + ) + model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"] + + # 12. run assisted greedy search + return self.assisted_greedy_search( + input_ids, + assistant_model=assistant_model, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) if is_greedy_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( @@ -3920,6 +3973,398 @@ def constrained_beam_search( else: return sequence_outputs["sequences"] + def assisted_greedy_search( + self, + input_ids: torch.LongTensor, + assistant_model: "PreTrainedModel", + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ): + r""" + Generates sequences of token ids for models with a language modeling head using **greedy decoding**, assisted + by a smaller model. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + + + In most cases, you do not need to call [`~generation.GenerationMixin.assisted_greedy_search`] directly. Use + generate() instead. For an overview of generation strategies and code examples, check the [following + guide](../generation_strategies). + + + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + assistant_model (`PreTrainedModel`, *optional*): + An assistant model that can be used to accelerate generation. The assistant model must have the exact + same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model + is much faster than running generation with the model you're calling generate from. As such, the + assistant model should be much smaller. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + Return: + [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + Examples: + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + update the example before committing + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token + >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id + >>> input_prompt = "It might be possible to" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id), + ... ] + ... ) + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + >>> outputs = model.greedy_search( + ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria + ... ) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ["It might be possible to get a better understanding of the nature of the problem, but it's not"] + ```""" + # NOTE: the code here is copy/paste from greedy search, except when clearly stated in the comments + # Assistant: initialize assistant-related variables + if not hasattr(assistant_model, "max_assistant_tokens"): + assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls + + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + this_peer_finished = False # used by synced_gpus only + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # Assistant: main logic start + cur_len = input_ids.shape[-1] + max_len = stopping_criteria[0].max_length + + # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a + # `.generate()` call if we decide to add `past_key_values` as a possible output of the method, as we + # need access to the assistant cache to secure strong speedups. + candidate_input_ids = input_ids + for _ in range(int(assistant_model.max_assistant_tokens)): + # 1.1. use the assistant model to obtain the next candidate logits + if "assistant_past_key_values" in model_kwargs: + prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2] + # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model) + new_token_len = candidate_input_ids.shape[1] - prev_seq_len + tmp_inputs = candidate_input_ids[:, -new_token_len:] + tmp_attn = torch.ones_like(candidate_input_ids) + # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2 + if assistant_model.config.is_encoder_decoder: + assistant_model_outputs = assistant_model( + decoder_input_ids=tmp_inputs, + decoder_attention_mask=tmp_attn, + past_key_values=model_kwargs["assistant_past_key_values"], + encoder_outputs=model_kwargs["assistant_encoder_outputs"], + ) + else: + assistant_model_outputs = assistant_model( + tmp_inputs, + attention_mask=tmp_attn, + past_key_values=model_kwargs["assistant_past_key_values"], + ) + else: + if assistant_model.config.is_encoder_decoder: + assistant_model_outputs = assistant_model( + decoder_input_ids=candidate_input_ids, + encoder_outputs=model_kwargs["assistant_encoder_outputs"], + ) + else: + assistant_model_outputs = assistant_model(candidate_input_ids) + + # 1.2. greedily select the next candidate token + model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values + if len(logits_processor) > 0: + assistant_model_outputs.logits[:, -1, :] = logits_processor( + candidate_input_ids, assistant_model_outputs.logits[:, -1, :] + ) + new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) + candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) + + # 1.3. stop assistant generation on EOS + if eos_token_id_tensor is not None: + last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1) + last_assistant_token_is_eos = ( + ~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool() + ) + if last_assistant_token_is_eos: + break + + candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] + + # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain + # `candidate_length + 1` relevant logits from this process (see step 7 on why the +1) + if "past_key_values" in model_kwargs: + og_model_attn = torch.ones_like(candidate_input_ids) + og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :] + if self.config.is_encoder_decoder: + outputs = self( + decoder_input_ids=og_model_input_ids, + decoder_attention_mask=og_model_attn, + past_key_values=model_kwargs["past_key_values"], + encoder_outputs=model_kwargs["encoder_outputs"], + ) + else: + outputs = self( + og_model_input_ids, + attention_mask=og_model_attn, + past_key_values=model_kwargs["past_key_values"], + ) + else: + if self.config.is_encoder_decoder: + outputs = self( + decoder_input_ids=candidate_input_ids, encoder_outputs=model_kwargs["encoder_outputs"] + ) + else: + outputs = self(candidate_input_ids) + + # 3. Obtain the argmax from the original model logits. + if len(logits_processor) > 0: + for i in range(candidate_length): + outputs.logits[:, i, :] = logits_processor( + candidate_input_ids[:, : cur_len + i], outputs.logits[:, i, :] + ) + max_logits = outputs.logits.argmax(dim=-1)[:, -candidate_length - 1 : -1] + + # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep + # the assistant forecasted tokens until the first mismatch, or until the max length is reached. + candidate_new_tokens = candidate_input_ids[:, -candidate_length:] + n_matches = ((~(candidate_new_tokens == max_logits)).cumsum(dim=-1) < 1).sum() + + # 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, + # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the + # cost of forecasting incorrect assistant tokens. + if n_matches == int(assistant_model.max_assistant_tokens): + assistant_model.max_assistant_tokens += 2.0 + else: + assistant_model.max_assistant_tokens -= 1.0 + if assistant_model.max_assistant_tokens < 1.0: + assistant_model.max_assistant_tokens = 1.0 + + # 6. Update variables according to the number of matching assistant tokens. + n_matches = min(n_matches, max_len - cur_len) + input_ids = candidate_input_ids[:, 0 : cur_len + n_matches] + + # check stopping criteria here + if (last_assistant_token_is_eos and n_matches == candidate_length) or stopping_criteria(input_ids, None): + break + + new_cur_len = input_ids.shape[-1] + + # 6.1. Discard past key values relative to unused assistant tokens + outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cur_len) + model_kwargs["assistant_past_key_values"] = _crop_past_key_values( + assistant_model, model_kwargs["assistant_past_key_values"], new_cur_len + ) + + # 6.2. Extract the logits for the next token + if outputs.logits.shape[1] > candidate_length + 1: + logits_idx = new_cur_len - 1 + else: + logits_idx = n_matches + next_token_scores = outputs.logits[:, logits_idx, :] + + # 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that, + # because of this step, assisted greedy search degenerates to a normal greedy search if there is no match. + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # Assistant: main logic end; Compared to greedy search, the following (redundant) blocks were removed + # below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model cache + # update. + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, None): + if not synced_gpus: + break + else: + this_peer_finished = True + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GreedySearchEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return GreedySearchDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return input_ids + + +def _crop_past_key_values(model, past_key_values, maximum_length): + new_past = [] + if model.config.is_encoder_decoder: + for idx in range(len(past_key_values)): + new_past.append( + ( + past_key_values[idx][0][:, :, :maximum_length, :], + past_key_values[idx][1][:, :, :maximum_length, :], + past_key_values[idx][2], + past_key_values[idx][3], + ) + ) + past_key_values = tuple(new_past) + elif "bloom" in model.__class__.__name__.lower(): # bloom is special + for idx in range(len(past_key_values)): + new_past.append( + ( + past_key_values[idx][0][:, :, :maximum_length], + past_key_values[idx][1][:, :maximum_length, :], + ) + ) + past_key_values = tuple(new_past) + else: + for idx in range(len(past_key_values)): + new_past.append( + ( + past_key_values[idx][0][:, :, :maximum_length, :], + past_key_values[idx][1][:, :, :maximum_length, :], + ) + ) + past_key_values = tuple(new_past) + return past_key_values + def top_k_top_p_filtering( logits: torch.FloatTensor, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index f598c8299d1..b259ab5921f 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -513,6 +513,7 @@ def _convert_to_standard_cache( Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, num_heads, ...])) """ + breakpoint() batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape num_heads = batch_size_times_num_heads // batch_size # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] From 0b5a8eabe2ceebde3a30bd650f809632c4bd3e32 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 16 Mar 2023 18:37:52 +0000 Subject: [PATCH 02/12] remove breakpoint --- src/transformers/models/bloom/modeling_bloom.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index b259ab5921f..f598c8299d1 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -513,7 +513,6 @@ def _convert_to_standard_cache( Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, num_heads, ...])) """ - breakpoint() batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape num_heads = batch_size_times_num_heads // batch_size # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] From 76faabf2757e13a95a1b3d8ac8447020549726bb Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 28 Mar 2023 10:12:18 +0000 Subject: [PATCH 03/12] fix commit --- src/transformers/generation/utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e81684a04e5..78942f3434f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1118,12 +1118,8 @@ def generate( logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, -<<<<<<< HEAD - synced_gpus: Optional[bool] = None, -======= synced_gpus: Optional[bool] = False, assistant_model: Optional["PreTrainedModel"] = None, ->>>>>>> ca8162da4 (working mvp) **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: r""" From af41bcbc357e03700d61b2d7a81b20b1931c4be8 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 28 Mar 2023 10:53:09 +0000 Subject: [PATCH 04/12] standardize outputs --- src/transformers/generation/utils.py | 47 ++++++++++++++++++---------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 78942f3434f..58d26e779e1 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4126,7 +4126,7 @@ def assisted_greedy_search( max_len = stopping_criteria[0].max_length # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a - # `.generate()` call if we decide to add `past_key_values` as a possible output of the method, as we + # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we # need access to the assistant cache to secure strong speedups. candidate_input_ids = input_ids for _ in range(int(assistant_model.max_assistant_tokens)): @@ -4247,13 +4247,13 @@ def assisted_greedy_search( # 6.2. Extract the logits for the next token if outputs.logits.shape[1] > candidate_length + 1: - logits_idx = new_cur_len - 1 + last_valid_output_idx = new_cur_len - 1 else: - logits_idx = n_matches - next_token_scores = outputs.logits[:, logits_idx, :] + last_valid_output_idx = n_matches + next_token_scores = outputs.logits[:, last_valid_output_idx, :] # 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that, - # because of this step, assisted greedy search degenerates to a normal greedy search if there is no match. + # because of this step, assisted greedy search reduces to a normal greedy search if there is no match. next_tokens = torch.argmax(next_token_scores, dim=-1) # Assistant: main logic end; Compared to greedy search, the following (redundant) blocks were removed @@ -4264,22 +4264,37 @@ def assisted_greedy_search( continue # don't waste resources running the code we don't need # Store scores, attentions and hidden_states when required + # Assistant: modified to append one tuple element per token, as in the other generation methods. if return_dict_in_generate: if output_scores: - scores += (next_token_scores,) + scores += tuple(outputs.logits[:, i, :] for i in range(last_valid_output_idx)) if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - + cross_attentions += tuple( + layer[..., i, i] + for layer in outputs.cross_attentions + for i in range(last_valid_output_idx) + ) + decoder_attentions += tuple( + layer[..., i, i] + for layer in outputs.decoder_attentions + for i in range(last_valid_output_idx) + ) + else: + decoder_attentions += tuple( + layer[..., i, i] for layer in outputs.attentions for i in range(last_valid_output_idx) + ) if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) + if self.config.is_encoder_decoder: + decoder_hidden_states += tuple( + layer[:, i, :] + for layer in outputs.decoder_hidden_states + for i in range(last_valid_output_idx) + ) + else: + decoder_hidden_states += tuple( + layer[:, i, :] for layer in outputs.hidden_states for i in range(last_valid_output_idx) + ) # finished sentences should have their next token be a padding token if eos_token_id is not None: From b68262bfe25dd24c380ee7d23061e9027b9bb689 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 28 Mar 2023 14:47:35 +0000 Subject: [PATCH 05/12] tmp commit --- src/transformers/generation/utils.py | 46 +++++++++++++-------- tests/generation/test_utils.py | 61 ++++++++++++++++++++++++++-- 2 files changed, 86 insertions(+), 21 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 58d26e779e1..b6b1d90a8e4 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1429,6 +1429,8 @@ def generate( ) if batch_size > 1: raise ValueError("Assisted generation is only supported for batch_size = 1") + if not model_kwargs["use_cache"]: + raise ValueError("Assisted generation requires `use_cache=True`") # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs if assistant_model.config.is_encoder_decoder: @@ -1480,6 +1482,8 @@ def generate( f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" " contrastive search." ) + if not model_kwargs["use_cache"]: + raise ValueError("Contrastive search requires `use_cache=True`") return self.contrastive_search( input_ids, @@ -4191,20 +4195,31 @@ def assisted_greedy_search( decoder_attention_mask=og_model_attn, past_key_values=model_kwargs["past_key_values"], encoder_outputs=model_kwargs["encoder_outputs"], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, ) else: outputs = self( og_model_input_ids, attention_mask=og_model_attn, past_key_values=model_kwargs["past_key_values"], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, ) else: if self.config.is_encoder_decoder: outputs = self( - decoder_input_ids=candidate_input_ids, encoder_outputs=model_kwargs["encoder_outputs"] + decoder_input_ids=candidate_input_ids, + encoder_outputs=model_kwargs["encoder_outputs"], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, ) else: - outputs = self(candidate_input_ids) + outputs = self( + candidate_input_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) # 3. Obtain the argmax from the original model logits. if len(logits_processor) > 0: @@ -4231,12 +4246,9 @@ def assisted_greedy_search( # 6. Update variables according to the number of matching assistant tokens. n_matches = min(n_matches, max_len - cur_len) + if (last_assistant_token_is_eos and n_matches == candidate_length): # don't go beyond an EOS token + n_matches -= 1 input_ids = candidate_input_ids[:, 0 : cur_len + n_matches] - - # check stopping criteria here - if (last_assistant_token_is_eos and n_matches == candidate_length) or stopping_criteria(input_ids, None): - break - new_cur_len = input_ids.shape[-1] # 6.1. Discard past key values relative to unused assistant tokens @@ -4267,33 +4279,33 @@ def assisted_greedy_search( # Assistant: modified to append one tuple element per token, as in the other generation methods. if return_dict_in_generate: if output_scores: - scores += tuple(outputs.logits[:, i, :] for i in range(last_valid_output_idx)) + scores += tuple(outputs.logits[:, i, :] for i in range(last_valid_output_idx + 1)) if output_attentions: if self.config.is_encoder_decoder: cross_attentions += tuple( - layer[..., i, i] + layer[..., i:i + 1, :] for layer in outputs.cross_attentions - for i in range(last_valid_output_idx) + for i in range(last_valid_output_idx + 1) ) decoder_attentions += tuple( - layer[..., i, i] + layer[..., i:i + 1, -(last_valid_output_idx - i):] for layer in outputs.decoder_attentions - for i in range(last_valid_output_idx) + for i in range(last_valid_output_idx + 1) ) else: decoder_attentions += tuple( - layer[..., i, i] for layer in outputs.attentions for i in range(last_valid_output_idx) + layer[..., i:i + 1, -(last_valid_output_idx - i):] for layer in outputs.attentions for i in range(last_valid_output_idx + 1) ) if output_hidden_states: if self.config.is_encoder_decoder: decoder_hidden_states += tuple( - layer[:, i, :] + layer[:, i:i + 1, :] for layer in outputs.decoder_hidden_states - for i in range(last_valid_output_idx) + for i in range(last_valid_output_idx + 1) ) else: decoder_hidden_states += tuple( - layer[:, i, :] for layer in outputs.hidden_states for i in range(last_valid_output_idx) + layer[:, i:i + 1, :] for layer in outputs.hidden_states for i in range(last_valid_output_idx + 1) ) # finished sentences should have their next token be a padding token @@ -4315,7 +4327,7 @@ def assisted_greedy_search( ) # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, None): + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): if not synced_gpus: break else: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index c0278f6ae46..49faccb097a 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -79,14 +79,13 @@ class GenerationTesterMixin: all_generative_model_classes = () input_name = "input_ids" - def _get_input_ids_and_config(self): + def _get_input_ids_and_config(self, batch_size=2): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() input_ids = inputs_dict[self.input_name] # cut to half length & take max batch_size 3 - max_batch_size = 2 sequence_length = input_ids.shape[-1] // 2 - input_ids = input_ids[:max_batch_size, :sequence_length] + input_ids = input_ids[:batch_size, :sequence_length] # generate max 3 tokens max_length = input_ids.shape[-1] + 3 @@ -99,7 +98,7 @@ def _get_input_ids_and_config(self): if "transfoxl" in config.__class__.__name__.lower(): attention_mask = None else: - attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:max_batch_size, :sequence_length] + attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :sequence_length] return config, input_ids, attention_mask, max_length @@ -1458,6 +1457,60 @@ def test_contrastive_generate_dict_outputs_use_cache(self): for output in (output_contrastive, output_generate): self._check_outputs(output, input_ids, model.config, use_cache=True) + def test_assisted_greedy_search_matches_greedy_search(self): + # This test breaks the pattern above, for multiple reasons: + # - assisted_greedy_search, contrarily to the other methods, can't be called on its own (e.g. needs to + # prepare the assistant encoder outputs in the main generate body); + # - assisted_greedy_search does not support `use_cache = False` + # - assisted_greedy_search does not support `batch_size > 1` + # As such, this test ensures that the assisted generation does not introduce output changes over greedy search. + + for model_class in self.all_generative_model_classes: + # won't fix: FSMT and Reformer have a different cache variable type (and format). + if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): + return + + # enable cache + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) + + # NOTE: assisted generation only works with cache on at the moment. + if not hasattr(config, "use_cache"): + return + + config.use_cache = True + model = model_class(config).to(torch_device).eval() + output_greedy = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + num_beams=1, + do_sample=False, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + # Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will + # match + output_assisted = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + num_beams=1, + do_sample=False, + assistant_model=model, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + breakpoint() + + self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist()) + + for output in (output_greedy, output_assisted): + self._check_outputs(output, input_ids, model.config, use_cache=True) + def test_generate_with_head_masking(self): """Test designed for encoder-decoder models to ensure the attention head masking is used.""" attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] From 00637d8f6fc4bbff1d00b3adea59641347441489 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 28 Mar 2023 19:13:26 +0000 Subject: [PATCH 06/12] tests almost ready --- src/transformers/generation/utils.py | 85 +++++++++++-------- tests/generation/test_utils.py | 6 +- .../test_modeling_bigbird_pegasus.py | 7 +- .../whisper/test_modeling_tf_whisper.py | 2 +- tests/models/whisper/test_modeling_whisper.py | 9 +- 5 files changed, 62 insertions(+), 47 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b6b1d90a8e4..2d6c5381c04 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4181,6 +4181,8 @@ def assisted_greedy_search( ) if last_assistant_token_is_eos: break + else: + last_assistant_token_is_eos = False candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] @@ -4222,12 +4224,13 @@ def assisted_greedy_search( ) # 3. Obtain the argmax from the original model logits. + new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present if len(logits_processor) > 0: for i in range(candidate_length): - outputs.logits[:, i, :] = logits_processor( - candidate_input_ids[:, : cur_len + i], outputs.logits[:, i, :] + new_logits[:, i, :] = logits_processor( + candidate_input_ids[:, : cur_len + i], new_logits[:, i, :] ) - max_logits = outputs.logits.argmax(dim=-1)[:, -candidate_length - 1 : -1] + max_logits = new_logits.argmax(dim=-1)[:, -candidate_length - 1 : -1] # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep # the assistant forecasted tokens until the first mismatch, or until the max length is reached. @@ -4245,24 +4248,22 @@ def assisted_greedy_search( assistant_model.max_assistant_tokens = 1.0 # 6. Update variables according to the number of matching assistant tokens. - n_matches = min(n_matches, max_len - cur_len) - if (last_assistant_token_is_eos and n_matches == candidate_length): # don't go beyond an EOS token + # 6.1. Ensure we don't generate beyond max_len or an EOS token (remember: one token will be added below) + if n_matches >= max_len - cur_len: + n_matches = max_len - cur_len - 1 + if (last_assistant_token_is_eos and n_matches == candidate_length): n_matches -= 1 input_ids = candidate_input_ids[:, 0 : cur_len + n_matches] new_cur_len = input_ids.shape[-1] - # 6.1. Discard past key values relative to unused assistant tokens + # 6.2. Discard past key values relative to unused assistant tokens outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cur_len) model_kwargs["assistant_past_key_values"] = _crop_past_key_values( assistant_model, model_kwargs["assistant_past_key_values"], new_cur_len ) - # 6.2. Extract the logits for the next token - if outputs.logits.shape[1] > candidate_length + 1: - last_valid_output_idx = new_cur_len - 1 - else: - last_valid_output_idx = n_matches - next_token_scores = outputs.logits[:, last_valid_output_idx, :] + # 6.3. Extract the logits for the next token + next_token_scores = new_logits[:, n_matches, :] # 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that, # because of this step, assisted greedy search reduces to a normal greedy search if there is no match. @@ -4279,34 +4280,26 @@ def assisted_greedy_search( # Assistant: modified to append one tuple element per token, as in the other generation methods. if return_dict_in_generate: if output_scores: - scores += tuple(outputs.logits[:, i, :] for i in range(last_valid_output_idx + 1)) + scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1)) + + if "past_key_values" not in model_kwargs: + last_matching_idx = new_cur_len - 1 + prompt_length = cur_len + else: + last_matching_idx = n_matches + prompt_length = 0 + if output_attentions: if self.config.is_encoder_decoder: - cross_attentions += tuple( - layer[..., i:i + 1, :] - for layer in outputs.cross_attentions - for i in range(last_valid_output_idx + 1) - ) - decoder_attentions += tuple( - layer[..., i:i + 1, -(last_valid_output_idx - i):] - for layer in outputs.decoder_attentions - for i in range(last_valid_output_idx + 1) - ) + cross_attentions = _split_model_outputs(cross_attentions, outputs.cross_attentions, prompt_length, last_matching_idx) + decoder_attentions = _split_model_outputs(decoder_attentions, outputs.decoder_attentions, prompt_length, last_matching_idx, is_decoder_attention=True) else: - decoder_attentions += tuple( - layer[..., i:i + 1, -(last_valid_output_idx - i):] for layer in outputs.attentions for i in range(last_valid_output_idx + 1) - ) + decoder_attentions = _split_model_outputs(decoder_attentions, outputs.attentions, prompt_length, last_matching_idx, is_decoder_attention=True) if output_hidden_states: if self.config.is_encoder_decoder: - decoder_hidden_states += tuple( - layer[:, i:i + 1, :] - for layer in outputs.decoder_hidden_states - for i in range(last_valid_output_idx + 1) - ) + decoder_hidden_states = _split_model_outputs(decoder_hidden_states, outputs.decoder_hidden_states, prompt_length, last_matching_idx) else: - decoder_hidden_states += tuple( - layer[:, i:i + 1, :] for layer in outputs.hidden_states for i in range(last_valid_output_idx + 1) - ) + decoder_hidden_states = _split_model_outputs(decoder_hidden_states, outputs.hidden_states, prompt_length, last_matching_idx) # finished sentences should have their next token be a padding token if eos_token_id is not None: @@ -4356,6 +4349,7 @@ def assisted_greedy_search( def _crop_past_key_values(model, past_key_values, maximum_length): + """Crops the past key values up to a certain maximum length.""" new_past = [] if model.config.is_encoder_decoder: for idx in range(len(past_key_values)): @@ -4389,6 +4383,29 @@ def _crop_past_key_values(model, past_key_values, maximum_length): return past_key_values +def _split_model_outputs(outputs, new_outputs, prompt_length, last_matching_idx, is_decoder_attention=False): + """ + Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple + where each member corresponds to a single generated token. + """ + # Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the + # prompt. + if prompt_length > 0: + new_tuple = () + for layer in new_outputs: + last_dim_size = prompt_length if is_decoder_attention else layer.shape[-1] + new_tuple += (layer[..., :prompt_length, :last_dim_size],) + outputs += (new_tuple,) + + for i in range(prompt_length, last_matching_idx + 1): + new_tuple = () + for layer in new_outputs: + last_dim_size = i + 1 if is_decoder_attention else layer.shape[-1] + new_tuple += (layer[..., i:i + 1, :last_dim_size],) + outputs += (new_tuple,) + return outputs + + def top_k_top_p_filtering( logits: torch.FloatTensor, top_k: int = 0, diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 49faccb097a..9de141fa927 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1478,9 +1478,10 @@ def test_assisted_greedy_search_matches_greedy_search(self): return config.use_cache = True + config.is_decoder = True model = model_class(config).to(torch_device).eval() output_greedy = model.generate( - input_ids=input_ids, + input_ids, attention_mask=attention_mask, max_length=max_length, num_beams=1, @@ -1493,7 +1494,7 @@ def test_assisted_greedy_search_matches_greedy_search(self): # Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will # match output_assisted = model.generate( - input_ids=input_ids, + input_ids, attention_mask=attention_mask, max_length=max_length, num_beams=1, @@ -1504,7 +1505,6 @@ def test_assisted_greedy_search_matches_greedy_search(self): output_attentions=True, return_dict_in_generate=True, ) - breakpoint() self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist()) diff --git a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py index d7a8e6302d8..7588e0ee05f 100644 --- a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py +++ b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py @@ -270,7 +270,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT # overwrite from GenerationTesterMixin to solve problem # with conflicting random seeds - def _get_input_ids_and_config(self): + def _get_input_ids_and_config(self, batch_size=2): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.attention_type = "original_full" @@ -278,10 +278,9 @@ def _get_input_ids_and_config(self): attention_mask = torch.ones_like(input_ids, dtype=torch.long) # cut to half length & take max batch_size 3 - max_batch_size = 2 sequence_length = input_ids.shape[-1] // 2 - input_ids = input_ids[:max_batch_size, :sequence_length] - attention_mask = attention_mask[:max_batch_size, :sequence_length] + input_ids = input_ids[:batch_size, :sequence_length] + attention_mask = attention_mask[:batch_size, :sequence_length] # generate max 3 tokens max_length = input_ids.shape[-1] + 3 diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py index 2ef3cdcee02..d4abd8f5f03 100644 --- a/tests/models/whisper/test_modeling_tf_whisper.py +++ b/tests/models/whisper/test_modeling_tf_whisper.py @@ -303,7 +303,7 @@ def _get_input_ids_and_config(self): input_ids = input_ids[:max_batch_size, :, :] # generate max 3 tokens - max_length = input_ids.shape[-1] + 3 + max_length = 4 if config.eos_token_id is not None and config.pad_token_id is None: # hack to allow generate for models such as GPT2 as is done in `generate()` config.pad_token_id = config.eos_token_id diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index f0ba1a00f59..0fc9268c402 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -352,16 +352,15 @@ def test_encoder_decoder_model_standalone(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs) - def _get_input_ids_and_config(self): + def _get_input_ids_and_config(self, batch_size=3): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() input_ids = inputs_dict[self.input_name] - # cut to half length & take max batch_size 3 - max_batch_size = 3 - input_ids = input_ids[:max_batch_size, :, :] + # cut to half length & take max batch_size=batch_size + input_ids = input_ids[:batch_size, :, :] # generate max 3 tokens - max_length = input_ids.shape[-1] + 3 + max_length = 4 if config.eos_token_id is not None and config.pad_token_id is None: # hack to allow generate for models such as GPT2 as is done in `generate()` config.pad_token_id = config.eos_token_id From 6872d1d1e53e53299d939478e6e93b65e203156b Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Sat, 15 Apr 2023 16:44:53 +0000 Subject: [PATCH 07/12] tmp commit --- src/transformers/generation/utils.py | 36 ++++++++++++++++++++-------- tests/generation/test_utils.py | 10 ++++---- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index bc1e28be305..8bfb04d327e 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4298,9 +4298,7 @@ def assisted_greedy_search( new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present if len(logits_processor) > 0: for i in range(candidate_length): - new_logits[:, i, :] = logits_processor( - candidate_input_ids[:, : cur_len + i], new_logits[:, i, :] - ) + new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) max_logits = new_logits.argmax(dim=-1)[:, -candidate_length - 1 : -1] # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep @@ -4322,7 +4320,7 @@ def assisted_greedy_search( # 6.1. Ensure we don't generate beyond max_len or an EOS token (remember: one token will be added below) if n_matches >= max_len - cur_len: n_matches = max_len - cur_len - 1 - if (last_assistant_token_is_eos and n_matches == candidate_length): + if last_assistant_token_is_eos and n_matches == candidate_length: n_matches -= 1 input_ids = candidate_input_ids[:, 0 : cur_len + n_matches] new_cur_len = input_ids.shape[-1] @@ -4362,15 +4360,33 @@ def assisted_greedy_search( if output_attentions: if self.config.is_encoder_decoder: - cross_attentions = _split_model_outputs(cross_attentions, outputs.cross_attentions, prompt_length, last_matching_idx) - decoder_attentions = _split_model_outputs(decoder_attentions, outputs.decoder_attentions, prompt_length, last_matching_idx, is_decoder_attention=True) + cross_attentions = _split_model_outputs( + cross_attentions, outputs.cross_attentions, prompt_length, last_matching_idx + ) + decoder_attentions = _split_model_outputs( + decoder_attentions, + outputs.decoder_attentions, + prompt_length, + last_matching_idx, + is_decoder_attention=True, + ) else: - decoder_attentions = _split_model_outputs(decoder_attentions, outputs.attentions, prompt_length, last_matching_idx, is_decoder_attention=True) + decoder_attentions = _split_model_outputs( + decoder_attentions, + outputs.attentions, + prompt_length, + last_matching_idx, + is_decoder_attention=True, + ) if output_hidden_states: if self.config.is_encoder_decoder: - decoder_hidden_states = _split_model_outputs(decoder_hidden_states, outputs.decoder_hidden_states, prompt_length, last_matching_idx) + decoder_hidden_states = _split_model_outputs( + decoder_hidden_states, outputs.decoder_hidden_states, prompt_length, last_matching_idx + ) else: - decoder_hidden_states = _split_model_outputs(decoder_hidden_states, outputs.hidden_states, prompt_length, last_matching_idx) + decoder_hidden_states = _split_model_outputs( + decoder_hidden_states, outputs.hidden_states, prompt_length, last_matching_idx + ) # finished sentences should have their next token be a padding token if eos_token_id is not None: @@ -4472,7 +4488,7 @@ def _split_model_outputs(outputs, new_outputs, prompt_length, last_matching_idx, new_tuple = () for layer in new_outputs: last_dim_size = i + 1 if is_decoder_attention else layer.shape[-1] - new_tuple += (layer[..., i:i + 1, :last_dim_size],) + new_tuple += (layer[..., i : i + 1, :last_dim_size],) outputs += (new_tuple,) return outputs diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 43b8a1f9e3f..45f7674770a 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1458,17 +1458,19 @@ def test_contrastive_generate_dict_outputs_use_cache(self): self._check_outputs(output, input_ids, model.config, use_cache=True) def test_assisted_greedy_search_matches_greedy_search(self): - # This test breaks the pattern above, for multiple reasons: + # This test ensures that the assisted generation does not introduce output changes over greedy search. + # It breaks the pattern in the tests above, for multiple reasons: # - assisted_greedy_search, contrarily to the other methods, can't be called on its own (e.g. needs to - # prepare the assistant encoder outputs in the main generate body); + # prepare the assistant encoder outputs in the main generate body); # - assisted_greedy_search does not support `use_cache = False` # - assisted_greedy_search does not support `batch_size > 1` - # As such, this test ensures that the assisted generation does not introduce output changes over greedy search. for model_class in self.all_generative_model_classes: # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): return + # may fix in the future: the following models fail to pass this test + # enable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) @@ -1492,7 +1494,7 @@ def test_assisted_greedy_search_matches_greedy_search(self): return_dict_in_generate=True, ) # Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will - # match + # be correct output_assisted = model.generate( input_ids, attention_mask=attention_mask, From 8937b89d1056325f0c8ff5f59492dda0d73e353f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 17 Apr 2023 15:09:31 +0000 Subject: [PATCH 08/12] skip a few models --- tests/generation/test_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 45f7674770a..4246154adba 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1469,8 +1469,12 @@ def test_assisted_greedy_search_matches_greedy_search(self): # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): return - # may fix in the future: the following models fail to pass this test - + # may fix in the future: the following models fail to pass this test, and need model-specific fixes + if any( + model_name in model_class.__name__.lower() + for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text"] + ): + return # enable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) From 3d91f91cd6b272ac7df82ba7c2b7b21e4c807709 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 17 Apr 2023 17:58:24 +0000 Subject: [PATCH 09/12] Add streaming; Docs and examples --- docs/source/en/generation_strategies.mdx | 27 ++++++++++++++++++++++++ src/transformers/generation/utils.py | 26 +++++++++++++++++++---- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/docs/source/en/generation_strategies.mdx b/docs/source/en/generation_strategies.mdx index ced19762f07..9c5d11af10c 100644 --- a/docs/source/en/generation_strategies.mdx +++ b/docs/source/en/generation_strategies.mdx @@ -332,3 +332,30 @@ The groups are selected to ensure they are distinct enough compared to the other This guide illustrates the main parameters that enable various decoding strategies. More advanced parameters exist for the [`generate`] method, which gives you even further control over the [`generate`] method's behavior. For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.mdx). + +### Assisted Generation + +Assisted generation is a modification of the decoding strategies above that uses an assistant model with the same +tokenizer (ideally a much smaller model) to speed up the decoding process. Currently, only assisted greedy search is +supported. + + + +To enable assisted generation, set the `assistant_model` argument with a model. + +```python +>>> from transformers import AutoModelForCausalLM, AutoTokenizer + +>>> prompt = "I look forward to" +>>> checkpoint = "gpt2" +>>> assistant_checkpoint = "distilgpt2" + +>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint) +>>> inputs = tokenizer(prompt, return_tensors="pt") + +>>> model = AutoModelForCausalLM.from_pretrained(checkpoint) +>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint) +>>> outputs = model.generate(**inputs, assistant_model=assistant_model) +>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) +['I look forward to seeing you in the future."\n\n"I\'m sure you\'ll be fine'] +``` diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8bfb04d327e..65aab2a06e7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1495,6 +1495,7 @@ def generate( output_scores=generation_config.output_scores, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, + streamer=streamer, **model_kwargs, ) if is_greedy_gen_mode: @@ -4057,6 +4058,7 @@ def assisted_greedy_search( output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: Optional[bool] = False, + streamer: Optional["BaseStreamer"] = None, **model_kwargs, ): r""" @@ -4101,18 +4103,22 @@ def assisted_greedy_search( Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. model_kwargs: Additional model specific keyword arguments will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + Return: [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. + Examples: - ////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - update the example before committing + ```python >>> from transformers import ( ... AutoTokenizer, @@ -4125,6 +4131,7 @@ def assisted_greedy_search( >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2") >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id >>> input_prompt = "It might be possible to" @@ -4136,8 +4143,11 @@ def assisted_greedy_search( ... ] ... ) >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) - >>> outputs = model.greedy_search( - ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria + >>> outputs = model.assisted_greedy_search( + ... input_ids, + ... assistant_model=assistant_model, + ... logits_processor=logits_processor, + ... stopping_criteria=stopping_criteria, ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ["It might be possible to get a better understanding of the nature of the problem, but it's not"] @@ -4324,6 +4334,8 @@ def assisted_greedy_search( n_matches -= 1 input_ids = candidate_input_ids[:, 0 : cur_len + n_matches] new_cur_len = input_ids.shape[-1] + if streamer is not None: + streamer.put(candidate_input_ids[:, cur_len : cur_len + n_matches]) # 6.2. Discard past key values relative to unused assistant tokens outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cur_len) @@ -4396,6 +4408,9 @@ def assisted_greedy_search( # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) @@ -4413,6 +4428,9 @@ def assisted_greedy_search( else: this_peer_finished = True + if streamer is not None: + streamer.end() + if return_dict_in_generate: if self.config.is_encoder_decoder: return GreedySearchEncoderDecoderOutput( From d8d3a1d384605dadc31b7ae5ba206167d888f71a Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 17 Apr 2023 18:46:18 +0000 Subject: [PATCH 10/12] document limitations --- docs/source/en/generation_strategies.mdx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/generation_strategies.mdx b/docs/source/en/generation_strategies.mdx index 9c5d11af10c..d091287433c 100644 --- a/docs/source/en/generation_strategies.mdx +++ b/docs/source/en/generation_strategies.mdx @@ -336,8 +336,8 @@ For the complete list of the available parameters, refer to the [API documentati ### Assisted Generation Assisted generation is a modification of the decoding strategies above that uses an assistant model with the same -tokenizer (ideally a much smaller model) to speed up the decoding process. Currently, only assisted greedy search is -supported. +tokenizer (ideally a much smaller model) to speed up the decoding process. Currently only assisted greedy search is +supported, and doesn't support batched inputs. From 3f16cdf3f1ae3ed0861b547db741e3e6bee7bd8c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 18 Apr 2023 14:44:24 +0000 Subject: [PATCH 11/12] PR commits --- docs/source/en/generation_strategies.mdx | 8 ++++---- src/transformers/generation/utils.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/en/generation_strategies.mdx b/docs/source/en/generation_strategies.mdx index d091287433c..c3d1f953bfa 100644 --- a/docs/source/en/generation_strategies.mdx +++ b/docs/source/en/generation_strategies.mdx @@ -346,9 +346,9 @@ To enable assisted generation, set the `assistant_model` argument with a model. ```python >>> from transformers import AutoModelForCausalLM, AutoTokenizer ->>> prompt = "I look forward to" ->>> checkpoint = "gpt2" ->>> assistant_checkpoint = "distilgpt2" +>>> prompt = "Alice and Bob" +>>> checkpoint = "EleutherAI/pythia-1.4b-deduped" +>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped" >>> tokenizer = AutoTokenizer.from_pretrained(checkpoint) >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -357,5 +357,5 @@ To enable assisted generation, set the `assistant_model` argument with a model. >>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint) >>> outputs = model.generate(**inputs, assistant_model=assistant_model) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) -['I look forward to seeing you in the future."\n\n"I\'m sure you\'ll be fine'] +['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a'] ``` diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 65aab2a06e7..8411449a579 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1145,7 +1145,7 @@ def generate( logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: Optional[bool] = None, assistant_model: Optional["PreTrainedModel"] = None, streamer: Optional["BaseStreamer"] = None, **kwargs, @@ -1193,7 +1193,7 @@ def generate( on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful for constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). - synced_gpus (`bool`, *optional*, defaults to `False`): + synced_gpus (`bool`, *optional*): Whether to continue running the while loop until max_length. Unless overridden this flag will be set to `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished generating before other GPUs. Otherwise it'll be set to `False`. From a4c164a5f80c6f79c72bb5143c38143b6f269536 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 18 Apr 2023 15:05:32 +0000 Subject: [PATCH 12/12] Amy PR comments --- src/transformers/generation/utils.py | 37 +++++++++++++--------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8411449a579..ef4068439a2 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1465,8 +1465,8 @@ def generate( if is_assisted_greedy_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" - " assisted greedy search." + "num_return_sequences has to be 1 when doing assisted greedy search, " + f"but is {generation_config.num_return_sequences}." ) if batch_size > 1: raise ValueError("Assisted generation is only supported for batch_size = 1") @@ -1501,8 +1501,8 @@ def generate( if is_greedy_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" - " greedy search." + "num_return_sequences has to be 1 when doing greedy search, " + f"but is {generation_config.num_return_sequences}." ) # 11. run greedy search @@ -1522,8 +1522,8 @@ def generate( elif is_contrastive_search_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" - " contrastive search." + "num_return_sequences has to be 1 when doing contrastive search, " + f"but is {generation_config.num_return_sequences}." ) if not model_kwargs["use_cache"]: raise ValueError("Contrastive search requires `use_cache=True`") @@ -1796,7 +1796,7 @@ def contrastive_search( output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, **model_kwargs, ) -> Union[ContrastiveSearchOutput, torch.LongTensor]: @@ -2163,7 +2163,7 @@ def greedy_search( output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, **model_kwargs, ) -> Union[GreedySearchOutput, torch.LongTensor]: @@ -2419,7 +2419,7 @@ def sample( output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, **model_kwargs, ) -> Union[SampleOutput, torch.LongTensor]: @@ -2697,7 +2697,7 @@ def beam_search( output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, **model_kwargs, ) -> Union[BeamSearchOutput, torch.LongTensor]: r""" @@ -3021,7 +3021,7 @@ def beam_sample( output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, **model_kwargs, ) -> Union[BeamSampleOutput, torch.LongTensor]: r""" @@ -3353,7 +3353,7 @@ def group_beam_search( output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, **model_kwargs, ): r""" @@ -4057,7 +4057,7 @@ def assisted_greedy_search( output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, **model_kwargs, ): @@ -4162,6 +4162,8 @@ def assisted_greedy_search( stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if eos_token_id is not None and pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None @@ -4322,14 +4324,11 @@ def assisted_greedy_search( if n_matches == int(assistant_model.max_assistant_tokens): assistant_model.max_assistant_tokens += 2.0 else: - assistant_model.max_assistant_tokens -= 1.0 - if assistant_model.max_assistant_tokens < 1.0: - assistant_model.max_assistant_tokens = 1.0 + assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0) # 6. Update variables according to the number of matching assistant tokens. # 6.1. Ensure we don't generate beyond max_len or an EOS token (remember: one token will be added below) - if n_matches >= max_len - cur_len: - n_matches = max_len - cur_len - 1 + n_matches = min(n_matches, max_len - cur_len - 1) if last_assistant_token_is_eos and n_matches == candidate_length: n_matches -= 1 input_ids = candidate_input_ids[:, 0 : cur_len + n_matches] @@ -4402,8 +4401,6 @@ def assisted_greedy_search( # finished sentences should have their next token be a padding token if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step