From 741d48f5c7bf0acdf9b40d0deb8560b997761f3a Mon Sep 17 00:00:00 2001 From: Ashwin Geet D'Sa Date: Tue, 27 Apr 2021 00:28:40 +0200 Subject: [PATCH] Remove max length beam scorer (#11378) * removed max_len * removed max_length from BeamSearchScorer * correct max length * finish * del vim * finish & add test Co-authored-by: Patrick von Platen --- src/transformers/generation_beam_search.py | 21 +++++--- src/transformers/generation_utils.py | 44 ++++++++------- .../models/marian/modeling_marian.py | 2 +- src/transformers/models/rag/modeling_rag.py | 1 - tests/test_generation_beam_search.py | 7 ++- tests/test_generation_utils.py | 54 ++++++++++++++++--- 6 files changed, 91 insertions(+), 38 deletions(-) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 1fea43e1d7e503..cebe754af23ee9 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from abc import ABC, abstractmethod from collections import UserDict from typing import Optional, Tuple @@ -110,6 +111,7 @@ def finalize( next_scores: torch.FloatTensor, next_tokens: torch.LongTensor, next_indices: torch.LongTensor, + max_length: int, **kwargs ) -> torch.LongTensor: raise NotImplementedError("This is an abstract method.") @@ -152,15 +154,14 @@ class BeamSearchScorer(BeamScorer): def __init__( self, batch_size: int, - max_length: int, num_beams: int, device: torch.device, length_penalty: Optional[float] = 1.0, do_early_stopping: Optional[bool] = False, num_beam_hyps_to_keep: Optional[int] = 1, num_beam_groups: Optional[int] = 1, + **kwargs, ): - self.max_length = max_length self.num_beams = num_beams self.device = device self.length_penalty = length_penalty @@ -173,7 +174,6 @@ def __init__( self._beam_hyps = [ BeamHypotheses( num_beams=self.num_beams, - max_length=self.max_length, length_penalty=self.length_penalty, early_stopping=self.do_early_stopping, ) @@ -192,6 +192,13 @@ def __init__( f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." ) + if "max_length" in kwargs: + warnings.warn( + "Passing `max_length` to BeamSearchScorer is deprecated and has no effect." + "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`" + ",or `group_beam_search(...)`." + ) + @property def is_done(self) -> bool: return self._done.all() @@ -279,6 +286,7 @@ def finalize( final_beam_scores: torch.FloatTensor, final_beam_tokens: torch.LongTensor, final_beam_indices: torch.LongTensor, + max_length: int, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, ) -> Tuple[torch.LongTensor]: @@ -316,7 +324,7 @@ def finalize( best_scores[i * self.num_beam_hyps_to_keep + j] = best_score # prepare for adding eos - sent_max_len = min(sent_lengths.max().item() + 1, self.max_length) + sent_max_len = min(sent_lengths.max().item() + 1, max_length) decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) # shorter batches are padded if needed if sent_lengths.min().item() != sent_lengths.max().item(): @@ -326,7 +334,7 @@ def finalize( # fill with hypotheses and eos_token_id if the latter fits in for i, hypo in enumerate(best): decoded[i, : sent_lengths[i]] = hypo - if sent_lengths[i] < self.max_length: + if sent_lengths[i] < max_length: decoded[i, sent_lengths[i]] = eos_token_id return UserDict( { @@ -337,11 +345,10 @@ def finalize( class BeamHypotheses: - def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool): + def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool): """ Initialize n-best list of hypotheses. """ - self.max_length = max_length - 1 # ignoring bos_token self.length_penalty = length_penalty self.early_stopping = early_stopping self.num_beams = num_beams diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 17a7ae1dec7be3..9f21ee104a60fa 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1027,7 +1027,6 @@ def generate( beam_scorer = BeamSearchScorer( batch_size=batch_size, - max_length=stopping_criteria.max_length, num_beams=num_beams, device=self.device, length_penalty=length_penalty, @@ -1063,7 +1062,6 @@ def generate( raise ValueError("`max_length` needs to be a stopping_criteria for now.") beam_scorer = BeamSearchScorer( batch_size=batch_size, - max_length=stopping_criteria.max_length, num_beams=num_beams, device=self.device, length_penalty=length_penalty, @@ -1700,7 +1698,6 @@ def beam_search( >>> # instantiate beam scorer >>> beam_scorer = BeamSearchScorer( ... batch_size=1, - ... max_length=model.config.max_length, ... num_beams=num_beams, ... device=model.device, ... ) @@ -1756,7 +1753,7 @@ def beam_search( assert ( num_beams * batch_size == batch_beam_size - ), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores[:, 1:] = -1e9 @@ -1792,10 +1789,7 @@ def beam_search( # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # cannot be generated both before and after the `F.log_softmax` operation. - next_token_logits = self.adjust_logits_during_generation( - next_token_logits, cur_len=cur_len, max_length=None - ) - + next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) next_token_scores = logits_processor(input_ids, next_token_scores) @@ -1861,7 +1855,13 @@ def beam_search( this_peer_finished = True sequence_outputs = beam_scorer.finalize( - input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, ) if return_dict_in_generate: @@ -2086,10 +2086,7 @@ def beam_sample( # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # cannot be generated both before and after the `F.log_softmax` operation. - next_token_logits = self.adjust_logits_during_generation( - next_token_logits, cur_len=cur_len, max_length=None - ) - + next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) next_token_scores = logits_processor(input_ids, next_token_scores) @@ -2160,7 +2157,13 @@ def beam_sample( this_peer_finished = True sequence_outputs = beam_scorer.finalize( - input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, ) if return_dict_in_generate: @@ -2411,10 +2414,7 @@ def group_beam_search( # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # cannot be generated both before and after the `F.log_softmax` operation. - next_token_logits = self.adjust_logits_during_generation( - next_token_logits, cur_len=cur_len, max_length=None - ) - + next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size) vocab_size = next_token_scores.shape[-1] @@ -2497,7 +2497,13 @@ def group_beam_search( this_peer_finished = True sequence_outputs = beam_scorer.finalize( - input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, ) if return_dict_in_generate: diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index baa872843d2575..c56e498e2b0548 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1335,7 +1335,7 @@ def prepare_inputs_for_generation( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - def adjust_logits_during_generation(self, logits, cur_len, max_length): + def adjust_logits_during_generation(self, logits, cur_len): logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token. return logits diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 7975361749f7d7..42c2e16d6ca795 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1543,7 +1543,6 @@ def extend_enc_output(tensor, num_beams=None): raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") beam_scorer = BeamSearchScorer( batch_size=batch_size, - max_length=max_length, num_beams=num_beams, device=self.device, length_penalty=length_penalty, diff --git a/tests/test_generation_beam_search.py b/tests/test_generation_beam_search.py index aa8270c31f2c0a..fdbe35eafaa449 100644 --- a/tests/test_generation_beam_search.py +++ b/tests/test_generation_beam_search.py @@ -59,7 +59,6 @@ def __init__( def prepare_beam_scorer(self, **kwargs): return BeamSearchScorer( batch_size=kwargs.get("batch_size", self.batch_size), - max_length=kwargs.get("max_length", self.max_length), num_beams=kwargs.get("num_beams", self.num_beams), device=torch_device, length_penalty=kwargs.get("length_penalty", self.length_penalty), @@ -170,9 +169,7 @@ def cut_expected_tensor(tensor): def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores): # max_length should be only one more than current input_ids to check that eos is correctly appended max_length = self.sequence_length + 1 - beam_scorer = self.prepare_beam_scorer( - num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False - ) + beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False) # update beams and append to input_ids tokens = next_tokens.clone() @@ -197,6 +194,7 @@ def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_ output_indices, pad_token_id=self.pad_token_id, eos_token_id=self.eos_token_id, + max_length=max_length, ) sequences = sequence_output["sequences"] @@ -225,6 +223,7 @@ def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_ output_indices, pad_token_id=self.pad_token_id, eos_token_id=self.eos_token_id, + max_length=max_length, ) sequences = sequence_output["sequences"] sequence_scores = sequence_output["sequence_scores"] diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index 42c44b8c54e83d..4a7140d2ca3e50 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -148,7 +148,6 @@ def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1): } beam_scorer = BeamSearchScorer( batch_size=batch_size, - max_length=max_length, num_beams=beam_kwargs["num_beams"], device=torch_device, length_penalty=beam_kwargs["length_penalty"], @@ -169,7 +168,6 @@ def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_seque } beam_scorer = BeamSearchScorer( batch_size=batch_size, - max_length=max_length, num_beams=beam_kwargs["num_beams"], device=torch_device, length_penalty=beam_kwargs["length_penalty"], @@ -1411,7 +1409,6 @@ def test_max_length_backward_compat_beam_search(self): beam_scorer = BeamSearchScorer( batch_size=batch_size, - max_length=max_length, num_beams=num_beams, device=torch_device, ) @@ -1442,7 +1439,6 @@ def test_max_length_backward_compat_group_beam_search(self): diverse_beam_scorer = BeamSearchScorer( batch_size=batch_size, - max_length=max_length, num_beams=num_beams, device=torch_device, num_beam_hyps_to_keep=num_return_sequences, @@ -1502,7 +1498,6 @@ def test_max_length_warning_if_different(self): # Beam beam_scorer = BeamSearchScorer( batch_size=batch_size, - max_length=max_length, num_beams=num_beams, device=torch_device, ) @@ -1520,7 +1515,6 @@ def test_max_length_warning_if_different(self): # Grouped beam search diverse_beam_scorer = BeamSearchScorer( batch_size=batch_size, - max_length=max_length, num_beams=num_beams, device=torch_device, num_beam_hyps_to_keep=num_return_sequences, @@ -1535,3 +1529,51 @@ def test_max_length_warning_if_different(self): max_length=max_length, **model_kwargs, ) + + def test_beam_search_warning_if_max_length_is_passed(self): + article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" + bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") + bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) + + batch_size = 1 + num_beams = 3 + + input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + input_ids = input_ids.expand(num_beams, -1) + model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) + + stopping_criteria_max_length = 18 + stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)]) + + with self.assertWarns(UserWarning): + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=torch_device, + max_length=10, + ) + + generated_ids = bart_model.beam_search( + input_ids, + num_beams=num_beams, + stopping_criteria=stopping_criteria, + beam_scorer=beam_scorer, + **model_kwargs, + ) + + beam_scorer_no_max_len = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=torch_device, + ) + + generated_ids_no_max_len = bart_model.beam_search( + input_ids, + num_beams=num_beams, + stopping_criteria=stopping_criteria, + beam_scorer=beam_scorer_no_max_len, + **model_kwargs, + ) + + # BeamSearchScorer max_length should not influence "real" max_length + self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())