From 584de1803a836516ea170b5493a526511fb518b8 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 5 Mar 2024 13:10:43 +0100 Subject: [PATCH 01/24] add eos stopping criteria --- src/transformers/generation/__init__.py | 2 + .../generation/stopping_criteria.py | 24 ++- src/transformers/generation/utils.py | 203 ++++++++++++------ tests/generation/test_stopping_criteria.py | 16 ++ tests/generation/test_utils.py | 2 + 5 files changed, 183 insertions(+), 64 deletions(-) diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index e45f546cdc2780..9d10d6a86a957b 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -82,6 +82,7 @@ "MaxNewTokensCriteria", "MaxLengthCriteria", "MaxTimeCriteria", + "EOSTokenCriteria", "StoppingCriteria", "StoppingCriteriaList", "validate_stopping_criteria", @@ -218,6 +219,7 @@ WhisperTimeStampLogitsProcessor, ) from .stopping_criteria import ( + EOSTokenCriteria, MaxLengthCriteria, MaxNewTokensCriteria, MaxTimeCriteria, diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index f4624296d237f7..d9dd597a176cc7 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -2,7 +2,7 @@ import warnings from abc import ABC from copy import deepcopy -from typing import Optional +from typing import List, Optional, Union import torch @@ -129,6 +129,28 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool) +class EOSTokenCriteria(StoppingCriteria): + """ + This class can be used to stop generation whenever the "end-of-sequence" token in generated. + By default, it uses the `EOS` token from model's generation config. + + Args: + eos_token_id (`Union[int, List[int]]`): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. 
+ """ + + def __init__(self, eos_token_id: Union[int, List[int]]): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + self.eos_token_id = eos_token_id + + @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: + eos_token_ids = torch.tensor(self.eos_token_id, dtype=torch.int64, device=input_ids.device) + is_done = (input_ids[:, -1].unsqueeze(1) == eos_token_ids).any(dim=1) + return is_done + + class StoppingCriteriaList(list): @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 09e25958ac97b8..554df3aee78a45 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -75,6 +75,7 @@ UnbatchedClassifierFreeGuidanceLogitsProcessor, ) from .stopping_criteria import ( + EOSTokenCriteria, MaxLengthCriteria, MaxTimeCriteria, StoppingCriteria, @@ -942,6 +943,8 @@ def _get_stopping_criteria( ) if generation_config.max_time is not None: criteria.append(MaxTimeCriteria(max_time=generation_config.max_time)) + if generation_config.eos_token_id is not None: + criteria.append(EOSTokenCriteria(eos_token_id=generation_config.eos_token_id)) criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) return criteria @@ -1922,11 +1925,24 @@ def _contrastive_search( logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - sequential = sequential if sequential is not None else self.generation_config.low_memory + if eos_token_id is not None: + warnings.warn( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + FutureWarning, + ) + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + else: + eos_token_id = [ + criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] + if not eos_token_id and self.generation_config.eos_token_id: + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + eos_token_id = self.generation_config.eos_token_id + if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + sequential = sequential if sequential is not None else self.generation_config.low_memory output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( @@ -2185,10 +2201,11 @@ def _contrastive_search( continue # don't waste resources running the code we don't need # finished sentences should have their next token be a padding token - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + 
pad_token_id * (1 - unfinished_sequences) + for criteria in stopping_criteria: + if hasattr(criteria, "eos_token_id") and criteria.eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) @@ -2198,15 +2215,8 @@ def _contrastive_search( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id_tensor is not None: - unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) - ) - # stop when each sentence is finished unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - if unfinished_sequences.max() == 0: this_peer_finished = True @@ -2383,9 +2393,23 @@ def _greedy_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if eos_token_id is not None: + warnings.warn( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + FutureWarning, + ) + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + else: + eos_token_id = [ + criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] + if not eos_token_id and self.generation_config.eos_token_id: + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + eos_token_id = self.generation_config.eos_token_id + if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -2471,10 +2495,11 @@ def _greedy_search( next_tokens = torch.argmax(next_tokens_scores, dim=-1) # finished sentences should have their next token be a padding token - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + for criteria in stopping_criteria: + if hasattr(criteria, "eos_token_id") and criteria.eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) @@ -2487,14 +2512,7 @@ def _greedy_search( model_inputs=model_inputs, ) - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id_tensor is not None: - unfinished_sequences = 
unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) - ) - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - # stop when each sentence is finished if unfinished_sequences.max() == 0: this_peer_finished = True @@ -2680,10 +2698,23 @@ def _sample( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if eos_token_id is not None: + warnings.warn( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + FutureWarning, + ) + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + else: + eos_token_id = [ + criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] + if not eos_token_id and self.generation_config.eos_token_id: + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + eos_token_id = self.generation_config.eos_token_id + if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( @@ -2773,10 +2804,11 @@ def _sample( next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # finished sentences should have their next token be a padding token - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + for criteria in stopping_criteria: + if hasattr(criteria, "eos_token_id") and criteria.eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) @@ -2786,14 +2818,7 @@ def _sample( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id_tensor is not None: - unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) - ) - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - # stop when each sentence is finished if unfinished_sequences.max() == 0: this_peer_finished = True @@ -3007,7 +3032,21 @@ def _beam_search( if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not 
None else self.generation_config.eos_token_id + if eos_token_id is not None: + warnings.warn( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + FutureWarning, + ) + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + else: + eos_token_id = [ + criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] + if not eos_token_id and self.generation_config.eos_token_id: + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + eos_token_id = self.generation_config.eos_token_id + if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores @@ -3401,7 +3440,21 @@ def _beam_sample( ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if eos_token_id is not None: + warnings.warn( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + FutureWarning, + ) + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + else: + eos_token_id = [ + criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] + if not eos_token_id and self.generation_config.eos_token_id: + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + eos_token_id = self.generation_config.eos_token_id + if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores @@ -3748,7 +3801,21 @@ def _group_beam_search( ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if eos_token_id is not None: + warnings.warn( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + FutureWarning, + ) + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + else: + eos_token_id = [ + criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] + if not eos_token_id and self.generation_config.eos_token_id: + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + eos_token_id = self.generation_config.eos_token_id + if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores @@ -4159,7 +4226,21 @@ def _constrained_beam_search( if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if eos_token_id is not None: + warnings.warn( + 
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + FutureWarning, + ) + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + else: + eos_token_id = [ + criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] + if not eos_token_id and self.generation_config.eos_token_id: + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + eos_token_id = self.generation_config.eos_token_id + if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores @@ -4502,11 +4583,23 @@ def _assisted_decoding( stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - if eos_token_id is not None and pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + if eos_token_id is not None: + warnings.warn( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + FutureWarning, + ) + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + else: + eos_token_id = [ + criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] + if not eos_token_id and self.generation_config.eos_token_id: + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + eos_token_id = self.generation_config.eos_token_id + if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( @@ -4562,13 +4655,7 @@ def _assisted_decoding( candidate_logits = candidate_logits.to(self.device) candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] - last_assistant_token_is_eos = ( - ~candidate_input_ids[:, -1] - .tile(eos_token_id_tensor.shape[0], 1) - .ne(eos_token_id_tensor.unsqueeze(1)) - .prod(dim=0) - .bool() - ) + last_assistant_token_is_eos = stopping_criteria[-1](candidate_input_ids, None) # 2. Use the original model to obtain the next token logits given the candidate sequence. 
We obtain # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, @@ -4701,17 +4788,7 @@ def _assisted_decoding( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id_tensor is not None: - unfinished_sequences = unfinished_sequences.mul( - input_ids[:, -1] - .tile(eos_token_id_tensor.shape[0], 1) - .ne(eos_token_id_tensor.unsqueeze(1)) - .prod(dim=0) - ) - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - # stop when each sentence is finished if unfinished_sequences.max() == 0: this_peer_finished = True diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 7fa118c9e3550d..f051c4d065080a 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -26,6 +26,7 @@ import torch from transformers.generation import ( + EOSTokenCriteria, MaxLengthCriteria, MaxNewTokensCriteria, MaxTimeCriteria, @@ -98,6 +99,21 @@ def test_max_time_criteria(self): criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2) self.assertTrue(all(criteria(input_ids, scores))) + def test_eos_token_criteria(self): + criteria = EOSTokenCriteria(eos_token_id=0) + + input_ids, scores = self._get_tensors(5) + input_ids[:, -1] = 0 + self.assertTrue(all(criteria(input_ids, scores))) + + input_ids, scores = self._get_tensors(5) + input_ids[:2, -1] = 0 + self.assertListEqual(criteria(input_ids, scores).tolist(), [True, True, False]) + + input_ids, scores = self._get_tensors(5) + input_ids[:, -1] = 1 + self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False]) + def test_validate_stopping_criteria(self): validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index cb224c3c6a9d74..4d109d1176549e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3348,11 +3348,13 @@ def test_default_max_length_warning(self): # Explicitly setting max_length to 20 -> no warning with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("ignore", category=FutureWarning) model.generate(input_ids, max_length=20) self.assertEqual(len(warning_list), 0) # Generation config max_length != 20 -> no warning with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("ignore", category=FutureWarning) # generation_config is modified -> legacy mode is disabled = generation_config takes precedence model.generation_config.max_length = 10 model.generate(input_ids) From 79a47c4545ac5127dc42d5a2e7d38b3d21fbcd3a Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 5 Mar 2024 13:19:49 +0100 Subject: [PATCH 02/24] minor fix --- src/transformers/generation/utils.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 554df3aee78a45..941a50a4c677c7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2201,11 +2201,10 @@ def _contrastive_search( continue # don't waste resources running the code we don't need # finished sentences should have their next token be a padding token - for criteria in stopping_criteria: - if hasattr(criteria, "eos_token_id") and criteria.eos_token_id 
is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) @@ -2495,11 +2494,10 @@ def _greedy_search( next_tokens = torch.argmax(next_tokens_scores, dim=-1) # finished sentences should have their next token be a padding token - for criteria in stopping_criteria: - if hasattr(criteria, "eos_token_id") and criteria.eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) @@ -2804,11 +2802,10 @@ def _sample( next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # finished sentences should have their next token be a padding token - for criteria in stopping_criteria: - if hasattr(criteria, "eos_token_id") and criteria.eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) From 6c93f8b6cb4c3037732e5cf9cf1a25acb0aa32e2 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 5 Mar 2024 22:37:45 +0500 Subject: [PATCH 03/24] Update tests/generation/test_stopping_criteria.py Co-authored-by: Joao Gante --- tests/generation/test_stopping_criteria.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index f051c4d065080a..8304c2d2211ea6 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -108,6 +108,7 @@ def test_eos_token_criteria(self): input_ids, scores = self._get_tensors(5) input_ids[:2, -1] = 0 + input_ids[2, -1] = 1 self.assertListEqual(criteria(input_ids, scores).tolist(), [True, True, False]) input_ids, scores = self._get_tensors(5) From f59b83f5b8863fefd8ebc9530ad88eba55b7585c Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 5 Mar 2024 18:56:52 +0100 Subject: [PATCH 04/24] check eos is not None and fix tests --- .../generation/stopping_criteria.py | 5 +- src/transformers/generation/utils.py | 112 ++++++++++-------- 2 files changed, 66 insertions(+), 51 deletions(-) diff --git 
a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index d9dd597a176cc7..8170bb44fbc63a 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -142,12 +142,11 @@ class EOSTokenCriteria(StoppingCriteria): def __init__(self, eos_token_id: Union[int, List[int]]): if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - self.eos_token_id = eos_token_id + self.eos_token_id = torch.tensor(eos_token_id) @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: - eos_token_ids = torch.tensor(self.eos_token_id, dtype=torch.int64, device=input_ids.device) - is_done = (input_ids[:, -1].unsqueeze(1) == eos_token_ids).any(dim=1) + is_done = torch.isin(input_ids, self.eos_token_id.to(input_ids.device))[:, -1] return is_done diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 941a50a4c677c7..8aabbbd342eeb0 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1542,7 +1542,6 @@ def generate( logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, @@ -1557,7 +1556,6 @@ def generate( logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, @@ -1577,7 +1575,6 @@ def generate( logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, @@ -1606,7 +1603,6 @@ def generate( logits_warper=logits_warper, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, @@ -1640,7 +1636,6 @@ def generate( logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, @@ -1680,7 +1675,6 @@ def generate( logits_warper=logits_warper, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, @@ -1714,7 +1708,6 @@ def generate( logits_processor=prepared_logits_processor, 
stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, @@ -1788,7 +1781,6 @@ def typeerror(): logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, @@ -1933,12 +1925,15 @@ def _contrastive_search( ) stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: - eos_token_id = [ - criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") - ] - if not eos_token_id and self.generation_config.eos_token_id: - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + # TODO remove when the method is totally private + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = ( + [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] + ) + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -2400,12 +2395,15 @@ def _greedy_search( ) stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: - eos_token_id = [ - criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") - ] - if not eos_token_id and self.generation_config.eos_token_id: - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + # TODO remove when the method is totally private + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = ( + [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] + ) + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -2704,12 +2702,15 @@ def _sample( ) stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: - eos_token_id = [ - criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") - ] - if not eos_token_id and self.generation_config.eos_token_id: - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + # TODO remove when the method is totally private + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = ( + [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] + ) + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) if 
isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -3037,12 +3038,15 @@ def _beam_search( ) stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: - eos_token_id = [ - criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") - ] - if not eos_token_id and self.generation_config.eos_token_id: - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + # TODO remove when the method is totally private and beam scorer refactored + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = ( + [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] + ) + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -3445,12 +3449,15 @@ def _beam_sample( ) stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: - eos_token_id = [ - criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") - ] - if not eos_token_id and self.generation_config.eos_token_id: - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + # TODO remove when the method is totally private and beam scorer refactored + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = ( + [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] + ) + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -3806,12 +3813,15 @@ def _group_beam_search( ) stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: - eos_token_id = [ - criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") - ] - if not eos_token_id and self.generation_config.eos_token_id: - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + # TODO remove when the method is totally private and beam scorer refactored + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = ( + [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] + ) + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -4231,12 +4241,15 @@ def _constrained_beam_search( ) stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: - eos_token_id = [ - criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") - ] - if not eos_token_id and self.generation_config.eos_token_id: - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + # TODO remove when the method is totally private and beam scorer refactored + # need to get 
`eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = ( + [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] + ) + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -4588,12 +4601,15 @@ def _assisted_decoding( ) stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: - eos_token_id = [ - criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") - ] - if not eos_token_id and self.generation_config.eos_token_id: - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + # TODO remove when the method is totally private and beam scorer refactored + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = ( + [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] + ) + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] From 8ebad2d895b7a156250f988bb5446c774bf8fab3 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 5 Mar 2024 19:46:03 +0100 Subject: [PATCH 05/24] make style and fixup --- src/transformers/generation/utils.py | 64 ++++++++++++++-------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8aabbbd342eeb0..677f7f8974c04f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1926,10 +1926,10 @@ def _contrastive_search( stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the method is totally private - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = ( - [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] - ) + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id @@ -2396,10 +2396,10 @@ def _greedy_search( stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the method is totally private - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = ( - [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] - ) + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id 
is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id @@ -2703,10 +2703,10 @@ def _sample( stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the method is totally private - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = ( - [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] - ) + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id @@ -3039,10 +3039,10 @@ def _beam_search( stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the method is totally private and beam scorer refactored - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = ( - [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] - ) + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id @@ -3450,10 +3450,10 @@ def _beam_sample( stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the method is totally private and beam scorer refactored - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = ( - [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] - ) + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id @@ -3814,10 +3814,10 @@ def _group_beam_search( stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the method is totally private and beam scorer refactored - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = ( - [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] - ) + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id @@ -4242,10 +4242,10 @@ def _constrained_beam_search( stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the 
method is totally private and beam scorer refactored - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = ( - [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] - ) + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id @@ -4602,10 +4602,10 @@ def _assisted_decoding( stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the method is totally private and beam scorer refactored - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = ( - [criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")] - ) + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id From 3e2507b66a04b3b59747de138a4c602993164d7d Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 6 Mar 2024 13:52:23 +0500 Subject: [PATCH 06/24] Update src/transformers/generation/stopping_criteria.py Co-authored-by: Joao Gante --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 8170bb44fbc63a..2be4c11654fba3 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -146,7 +146,7 @@ def __init__(self, eos_token_id: Union[int, List[int]]): @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: - is_done = torch.isin(input_ids, self.eos_token_id.to(input_ids.device))[:, -1] + is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device)) return is_done From b77b6ab45c9d3c4dc0b9adad0b4412153a8767b7 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 6 Mar 2024 13:52:49 +0500 Subject: [PATCH 07/24] Update tests/generation/test_utils.py Co-authored-by: Joao Gante --- tests/generation/test_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 4d109d1176549e..adf3c77a9b15fa 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3348,7 +3348,6 @@ def test_default_max_length_warning(self): # Explicitly setting max_length to 20 -> no warning with warnings.catch_warnings(record=True) as warning_list: - warnings.simplefilter("ignore", category=FutureWarning) model.generate(input_ids, max_length=20) self.assertEqual(len(warning_list), 0) From edc76df302c062a50564d68e3449b4e6a9a2db37 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 6 Mar 2024 13:52:56 +0500 Subject: [PATCH 08/24] Update tests/generation/test_utils.py 
Co-authored-by: Joao Gante --- tests/generation/test_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index adf3c77a9b15fa..cb224c3c6a9d74 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3353,7 +3353,6 @@ def test_default_max_length_warning(self): # Generation config max_length != 20 -> no warning with warnings.catch_warnings(record=True) as warning_list: - warnings.simplefilter("ignore", category=FutureWarning) # generation_config is modified -> legacy mode is disabled = generation_config takes precedence model.generation_config.max_length = 10 model.generate(input_ids) From f71a68786aa745e0396869082f8f41b6e40f3736 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 6 Mar 2024 13:53:49 +0500 Subject: [PATCH 09/24] Update src/transformers/generation/__init__.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index 9d10d6a86a957b..7a14af7d0cacca 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -82,7 +82,7 @@ "MaxNewTokensCriteria", "MaxLengthCriteria", "MaxTimeCriteria", - "EOSTokenCriteria", + "EosTokenCriteria", "StoppingCriteria", "StoppingCriteriaList", "validate_stopping_criteria", From 387be0ed894b312ce3baf83d914264075325f6a0 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 6 Mar 2024 13:53:55 +0500 Subject: [PATCH 10/24] Update src/transformers/generation/stopping_criteria.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 2be4c11654fba3..9b265d2063c1da 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -129,7 +129,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool) -class EOSTokenCriteria(StoppingCriteria): +class EosTokenCriteria(StoppingCriteria): """ This class can be used to stop generation whenever the "end-of-sequence" token in generated. By default, it uses the `EOS` token from model's generation config. From 14ece046a5758020ff35e199980ca060b4103bdf Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 6 Mar 2024 13:54:00 +0500 Subject: [PATCH 11/24] Update src/transformers/generation/stopping_criteria.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 9b265d2063c1da..0ed27c7573d391 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -131,7 +131,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa class EosTokenCriteria(StoppingCriteria): """ - This class can be used to stop generation whenever the "end-of-sequence" token in generated. 
+ This class can be used to stop generation whenever the "end-of-sequence" token is generated. By default, it uses the `EOS` token from model's generation config. Args: From bc3eea99648d9a7f10f4d3ca6da4c031f0acfd46 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 6 Mar 2024 13:54:05 +0500 Subject: [PATCH 12/24] Update src/transformers/generation/stopping_criteria.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 0ed27c7573d391..bac537b71b96ec 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -132,7 +132,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa class EosTokenCriteria(StoppingCriteria): """ This class can be used to stop generation whenever the "end-of-sequence" token is generated. - By default, it uses the `EOS` token from model's generation config. + By default, it uses the `model.generation_config.eos_token_id`. Args: eos_token_id (`Union[int, List[int]]`): From 7518aeac5b17f650c4c6f9c6cd1a79c7affa9b78 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 6 Mar 2024 09:55:48 +0100 Subject: [PATCH 13/24] camel case everywhere --- src/transformers/generation/__init__.py | 2 +- src/transformers/generation/utils.py | 52 +++++++++++----------- tests/generation/test_stopping_criteria.py | 4 +- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index 7a14af7d0cacca..975d3fd79f4d60 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -219,7 +219,7 @@ WhisperTimeStampLogitsProcessor, ) from .stopping_criteria import ( - EOSTokenCriteria, + EosTokenCriteria, MaxLengthCriteria, MaxNewTokensCriteria, MaxTimeCriteria, diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 677f7f8974c04f..254fbc125a2400 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -75,7 +75,7 @@ UnbatchedClassifierFreeGuidanceLogitsProcessor, ) from .stopping_criteria import ( - EOSTokenCriteria, + EosTokenCriteria, MaxLengthCriteria, MaxTimeCriteria, StoppingCriteria, @@ -944,7 +944,7 @@ def _get_stopping_criteria( if generation_config.max_time is not None: criteria.append(MaxTimeCriteria(max_time=generation_config.max_time)) if generation_config.eos_token_id is not None: - criteria.append(EOSTokenCriteria(eos_token_id=generation_config.eos_token_id)) + criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id)) criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) return criteria @@ -1920,10 +1920,10 @@ def _contrastive_search( if eos_token_id is not None: warnings.warn( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", FutureWarning, ) - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the method is totally private # need to get `eos_token_id` and 
add stopping criteria, so that generation does not go forever @@ -1933,7 +1933,7 @@ def _contrastive_search( eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -2390,10 +2390,10 @@ def _greedy_search( if eos_token_id is not None: warnings.warn( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", FutureWarning, ) - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the method is totally private # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever @@ -2403,7 +2403,7 @@ def _greedy_search( eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -2697,10 +2697,10 @@ def _sample( if eos_token_id is not None: warnings.warn( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", FutureWarning, ) - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the method is totally private # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever @@ -2710,7 +2710,7 @@ def _sample( eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -3033,10 +3033,10 @@ def _beam_search( if eos_token_id is not None: warnings.warn( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", FutureWarning, ) - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the method is totally private and beam scorer refactored # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever @@ -3046,7 +3046,7 @@ def _beam_search( eos_token_id = eos_token_id[0] if eos_token_id else None if 
eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -3444,10 +3444,10 @@ def _beam_sample( if eos_token_id is not None: warnings.warn( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", FutureWarning, ) - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the method is totally private and beam scorer refactored # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever @@ -3457,7 +3457,7 @@ def _beam_sample( eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -3808,10 +3808,10 @@ def _group_beam_search( if eos_token_id is not None: warnings.warn( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", FutureWarning, ) - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the method is totally private and beam scorer refactored # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever @@ -3821,7 +3821,7 @@ def _group_beam_search( eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -4236,10 +4236,10 @@ def _constrained_beam_search( if eos_token_id is not None: warnings.warn( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", FutureWarning, ) - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the method is totally private and beam scorer refactored # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever @@ -4249,7 +4249,7 @@ def _constrained_beam_search( eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id is None and self.generation_config.eos_token_id is not None: 
eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -4596,10 +4596,10 @@ def _assisted_decoding( if eos_token_id is not None: warnings.warn( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", FutureWarning, ) - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) else: # TODO remove when the method is totally private and beam scorer refactored # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever @@ -4609,7 +4609,7 @@ def _assisted_decoding( eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 8304c2d2211ea6..0c770972a7fdff 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -26,7 +26,7 @@ import torch from transformers.generation import ( - EOSTokenCriteria, + EosTokenCriteria, MaxLengthCriteria, MaxNewTokensCriteria, MaxTimeCriteria, @@ -100,7 +100,7 @@ def test_max_time_criteria(self): self.assertTrue(all(criteria(input_ids, scores))) def test_eos_token_criteria(self): - criteria = EOSTokenCriteria(eos_token_id=0) + criteria = EosTokenCriteria(eos_token_id=0) input_ids, scores = self._get_tensors(5) input_ids[:, -1] = 0 From 8e5ec5731ba68d894fcbb479b800871a80cdb64d Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 7 Mar 2024 01:25:38 +0100 Subject: [PATCH 14/24] call stopping criteria list for candidate ids --- src/transformers/generation/utils.py | 22 ++++++---------------- tests/generation/test_utils.py | 2 -- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 254fbc125a2400..1847354d5f1a19 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2386,7 +2386,6 @@ def _greedy_search( ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if eos_token_id is not None: warnings.warn( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" @@ -4592,7 +4591,6 @@ def _assisted_decoding( logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else 
self.generation_config.eos_token_id if eos_token_id is not None: warnings.warn( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" @@ -4644,9 +4642,6 @@ def _assisted_decoding( # keep track of which sequences are already finished unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - # other auxiliary variables - max_len = stopping_criteria[0].max_length - this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: @@ -4668,7 +4663,7 @@ def _assisted_decoding( candidate_logits = candidate_logits.to(self.device) candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] - last_assistant_token_is_eos = stopping_criteria[-1](candidate_input_ids, None) + is_done_candidate = stopping_criteria(candidate_input_ids, None) # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, @@ -4703,15 +4698,13 @@ def _assisted_decoding( # 3. Select the accepted tokens. There are two possible cases: # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). - max_matches = max_len - cur_len - 1 if do_sample and candidate_logits is not None: valid_tokens, n_matches = _speculative_sampling( candidate_input_ids, candidate_logits, candidate_length, new_logits, - last_assistant_token_is_eos, - max_matches, + is_done_candidate, ) # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the @@ -4728,9 +4721,8 @@ def _assisted_decoding( n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() # Ensure we don't generate beyond max_len or an EOS token - if last_assistant_token_is_eos and n_matches == candidate_length: + if is_done_candidate and n_matches == candidate_length: n_matches -= 1 - n_matches = min(n_matches, max_matches) valid_tokens = selected_tokens[:, : n_matches + 1] # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated @@ -4850,8 +4842,7 @@ def _speculative_sampling( candidate_logits, candidate_length, new_logits, - last_assistant_token_is_eos, - max_matches, + is_done_candidate, ): """ Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns @@ -4876,16 +4867,15 @@ def _speculative_sampling( n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1 # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior) - if last_assistant_token_is_eos and n_matches == candidate_length: + if is_done_candidate and n_matches == candidate_length: # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model # due to acceptance on EOS we fix `n_matches` n_matches -= 1 valid_tokens = new_candidate_input_ids[:, : n_matches + 1] else: - n_matches = min(n_matches, max_matches) # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling. 
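
The acceptance logic above can be illustrated in isolation. What follows is a simplified, standalone sketch of the speculative-decoding acceptance step (algorithm 1 of https://arxiv.org/pdf/2211.17192.pdf); the function name and toy inputs are invented for the example, and the resampling of the extra token on rejection is omitted.

import torch

def toy_acceptance(candidate_ids, p, q):
    # candidate_ids: (1, candidate_length) tokens proposed by the draft model
    # p: (1, candidate_length) target-model probabilities of those tokens
    # q: (1, candidate_length) draft-model probabilities of those tokens
    # Each candidate token is accepted with probability min(1, p_i / q_i).
    probability_ratio = p / q
    r = torch.rand_like(probability_ratio)
    is_accepted = r <= probability_ratio
    # n_matches counts the accepted tokens before the first rejection,
    # i.e. the length of the leading run of True values in `is_accepted`.
    n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()
    return candidate_ids[:, : n_matches + 1], n_matches

torch.manual_seed(0)
candidate_ids = torch.tensor([[4, 8, 15, 16]])
p = torch.tensor([[0.9, 0.8, 0.1, 0.05]])  # target model agrees with the first two tokens
q = torch.tensor([[0.5, 0.5, 0.9, 0.9]])   # draft model was overconfident about the last two
valid_tokens, n_matches = toy_acceptance(candidate_ids, p, q)
print(n_matches.item(), valid_tokens.tolist())
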
- gamma = min(candidate_logits.shape[1], max_matches) + gamma = min(candidate_logits.shape[1], n_matches) p_n_plus_1 = p[:, n_matches, :] if n_matches < gamma: q_n_plus_1 = q[:, n_matches, :] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index cb224c3c6a9d74..58e8c73f03c84c 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2497,14 +2497,12 @@ def test_speculative_sampling(self): ] ) last_assistant_token_is_eos = False - max_matches = 5 validated_tokens, n_matches = _speculative_sampling( candidate_input_ids, candidate_logits, candidate_length, new_logits, last_assistant_token_is_eos, - max_matches, ) self.assertTrue(n_matches.item() == 2) self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8]) From 2544d12662274ceec980cd5c7787a5eed8daa853 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 7 Mar 2024 01:28:17 +0100 Subject: [PATCH 15/24] make style and fixup --- src/transformers/generation/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1847354d5f1a19..8b565b95e50779 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4873,7 +4873,6 @@ def _speculative_sampling( n_matches -= 1 valid_tokens = new_candidate_input_ids[:, : n_matches + 1] else: - # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling. gamma = min(candidate_logits.shape[1], n_matches) p_n_plus_1 = p[:, n_matches, :] From 12acbc43c9c4c9caff47a71d2239c1067ed64a78 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 7 Mar 2024 10:29:25 +0100 Subject: [PATCH 16/24] Empty commit From 110767367b243434a0392c78f78c54affb4cc824 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 7 Mar 2024 12:02:51 +0100 Subject: [PATCH 17/24] Empty commit to pass flaky test From 1ffc554e3a08929d04f98f3c23d728b35e9f241a Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 7 Mar 2024 12:50:20 +0100 Subject: [PATCH 18/24] set max length in PromptLookupCandidateGenerator --- src/transformers/generation/candidate_generator.py | 11 ++++++++++- src/transformers/generation/utils.py | 1 + 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 4b8fa144f04b6b..11e7a0446b6ece 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -247,15 +247,20 @@ class PromptLookupCandidateGenerator(CandidateGenerator): The maximum ngram size to be considered for matching in the prompt num_output_tokens (`int`): The number of tokens to be output as candidate tokens. + max_length (`int`): + The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length. + Defaults to 20, which is the max length used as default in generation config. 
""" def __init__( self, num_output_tokens: int = 10, max_matching_ngram_size: int = 2, + max_length: int = 20, ): self.num_output_tokens = num_output_tokens self.max_matching_ngram_size = max_matching_ngram_size + self.max_length = max_length if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0: raise ValueError("Invalid max_matching_ngram_size or num_output_tokens") @@ -273,6 +278,10 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, """ input_length = input_ids.size(1) + # Don't generate more than `max_length - 1` candidates since the target model generates one extra token. + if self.max_length == input_length + 1: + return input_ids, None + chosen_ids = None match_found = False for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1): @@ -292,7 +301,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, for idx in match_indices: start_idx = idx + ngram_size end_idx = start_idx + self.num_output_tokens - end_idx = min(end_idx, input_length) + end_idx = min(end_idx, input_length, self.max_length) if start_idx < end_idx: chosen_ids = input_ids[0, start_idx:end_idx] diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8b565b95e50779..022d297c06dc6a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -706,6 +706,7 @@ def _get_candidate_generator( if generation_config.prompt_lookup_num_tokens is not None: candidate_generator = PromptLookupCandidateGenerator( num_output_tokens=generation_config.prompt_lookup_num_tokens, + max_length=generation_config.max_length, ) else: candidate_generator = AssistedCandidateGenerator( From ce093c109eb92334ff375350c0c7f24b6e39253d Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Fri, 8 Mar 2024 18:34:30 +0500 Subject: [PATCH 19/24] Update src/transformers/generation/utils.py Co-authored-by: Joao Gante --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b7b7c6347cb65f..f4ede7846f5387 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4845,7 +4845,7 @@ def _speculative_sampling( valid_tokens = new_candidate_input_ids[:, : n_matches + 1] else: # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling. - gamma = min(candidate_logits.shape[1], n_matches) + gamma = candidate_logits.shape[1] p_n_plus_1 = p[:, n_matches, :] if n_matches < gamma: q_n_plus_1 = q[:, n_matches, :] From 5375d97c5a63b5da03b30d1ba1fc619060b6bfec Mon Sep 17 00:00:00 2001 From: raushan Date: Sat, 9 Mar 2024 11:12:51 +0100 Subject: [PATCH 20/24] lets fix this typo in docs --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f4ede7846f5387..57ccb287cfaa6b 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1304,7 +1304,7 @@ def generate( Return: [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` - or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`. 
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible [`~utils.ModelOutput`] types are: From 48f33fc6ccc71ae02811152ee954e7897e3a0c2e Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 26 Mar 2024 19:31:40 +0500 Subject: [PATCH 21/24] Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f794609782b67c..5eca73779183ef 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1912,7 +1912,7 @@ def _contrastive_search( stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id if eos_token_id is not None: - warnings.warn( + logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", FutureWarning, From 801af0718454185cf0ca0566cf03ebe8552c5ce7 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 26 Mar 2024 19:33:22 +0500 Subject: [PATCH 22/24] Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5eca73779183ef..5291acb20b992f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1915,6 +1915,7 @@ def _contrastive_search( logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " Otherwise make sure to set `model.generation_config.eos_token_id`" FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) From 7c00bb18126212d7ff835cfc8979917204d93fd8 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 26 Mar 2024 15:48:50 +0100 Subject: [PATCH 23/24] update PR --- src/transformers/generation/utils.py | 39 ++++++++++++++++------------ 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 924844ddea76ef..a958c8c86a92b1 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1915,8 +1915,8 @@ def _contrastive_search( if eos_token_id is not None: logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", - " Otherwise make sure to set `model.generation_config.eos_token_id`" + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." 
+ " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) @@ -2373,9 +2373,10 @@ def _greedy_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id if eos_token_id is not None: - warnings.warn( + logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) @@ -2668,9 +2669,10 @@ def _sample( logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id if eos_token_id is not None: - warnings.warn( + logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) @@ -2994,9 +2996,10 @@ def _beam_search( warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id if eos_token_id is not None: - warnings.warn( + logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) @@ -3396,9 +3399,10 @@ def _beam_sample( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id if eos_token_id is not None: - warnings.warn( + logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." 
+ " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) @@ -3750,9 +3754,10 @@ def _group_beam_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id if eos_token_id is not None: - warnings.warn( + logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) @@ -4168,9 +4173,10 @@ def _constrained_beam_search( warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id if eos_token_id is not None: - warnings.warn( + logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) @@ -4517,9 +4523,10 @@ def _assisted_decoding( stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id if eos_token_id is not None: - warnings.warn( + logger.warning_once( "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead.", + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", FutureWarning, ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) From 530064d253b2e02770e1f8ffa8cbae1579c4d801 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 26 Mar 2024 16:04:04 +0100 Subject: [PATCH 24/24] empty commit
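
Taken together, the patches above deprecate passing `eos_token_id` to the private decoding methods in favor of an `EosTokenCriteria` inside a `StoppingCriteriaList`. Below is a short usage sketch of that recommended pattern, assuming a transformers build that already includes this series; the checkpoint name is only an example.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import EosTokenCriteria, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Wrap the EOS id in an explicit stopping criterion instead of passing
# `eos_token_id` to the decoding method directly.
stopping_criteria = StoppingCriteriaList([EosTokenCriteria(eos_token_id=tokenizer.eos_token_id)])

with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=20,
        stopping_criteria=stopping_criteria,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
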