Remove max length beam scorer (#11378)
* removed max_len

* removed max_length from BeamSearchScorer

* correct max length

* finish

* del vim

* finish & add test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
GeetDsa and patrickvonplaten authored Apr 26, 2021
1 parent bc2571e · commit 741d48f
Showing 6 changed files with 91 additions and 38 deletions.
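In short: `max_length` moves off `BeamSearchScorer` and onto the generation call itself. A minimal before/after sketch of the new convention (not part of the commit; the CPU device and the length value of 20 are illustrative, and `MaxLengthCriteria`/`StoppingCriteriaList` are assumed importable from the top-level `transformers` package, as the tests below use them):

```python
import torch
from transformers import BeamSearchScorer, MaxLengthCriteria, StoppingCriteriaList

device = torch.device("cpu")  # illustrative; any device works

# Before this commit, the scorer owned the maximum length:
# beam_scorer = BeamSearchScorer(batch_size=1, max_length=20, num_beams=3, device=device)

# After this commit, the scorer takes no `max_length` ...
beam_scorer = BeamSearchScorer(batch_size=1, num_beams=3, device=device)

# ... and length is controlled by the stopping criteria handed to
# `beam_search(...)`, which forwards it to `beam_scorer.finalize(...)`.
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
```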
21 changes: 14 additions & 7 deletions src/transformers/generation_beam_search.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from abc import ABC, abstractmethod
 from collections import UserDict
 from typing import Optional, Tuple
@@ -110,6 +111,7 @@ def finalize(
         next_scores: torch.FloatTensor,
         next_tokens: torch.LongTensor,
         next_indices: torch.LongTensor,
+        max_length: int,
         **kwargs
     ) -> torch.LongTensor:
         raise NotImplementedError("This is an abstract method.")
@@ -152,15 +154,14 @@ class BeamSearchScorer(BeamScorer):
     def __init__(
         self,
         batch_size: int,
-        max_length: int,
         num_beams: int,
         device: torch.device,
         length_penalty: Optional[float] = 1.0,
         do_early_stopping: Optional[bool] = False,
         num_beam_hyps_to_keep: Optional[int] = 1,
         num_beam_groups: Optional[int] = 1,
+        **kwargs,
     ):
-        self.max_length = max_length
         self.num_beams = num_beams
         self.device = device
         self.length_penalty = length_penalty
@@ -173,7 +174,6 @@ def __init__(
         self._beam_hyps = [
             BeamHypotheses(
                 num_beams=self.num_beams,
-                max_length=self.max_length,
                 length_penalty=self.length_penalty,
                 early_stopping=self.do_early_stopping,
             )
@@ -192,6 +192,13 @@ def __init__(
                 f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
             )
 
+        if "max_length" in kwargs:
+            warnings.warn(
+                "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
+                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`, "
+                "or `group_beam_search(...)`."
+            )
+
     @property
     def is_done(self) -> bool:
         return self._done.all()
@@ -279,6 +286,7 @@ def finalize(
         final_beam_scores: torch.FloatTensor,
         final_beam_tokens: torch.LongTensor,
         final_beam_indices: torch.LongTensor,
+        max_length: int,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
     ) -> Tuple[torch.LongTensor]:
@@ -316,7 +324,7 @@ def finalize(
             best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
 
         # prepare for adding eos
-        sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
+        sent_max_len = min(sent_lengths.max().item() + 1, max_length)
         decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
         # shorter batches are padded if needed
         if sent_lengths.min().item() != sent_lengths.max().item():
@@ -326,7 +334,7 @@ def finalize(
         # fill with hypotheses and eos_token_id if the latter fits in
         for i, hypo in enumerate(best):
             decoded[i, : sent_lengths[i]] = hypo
-            if sent_lengths[i] < self.max_length:
+            if sent_lengths[i] < max_length:
                 decoded[i, sent_lengths[i]] = eos_token_id
         return UserDict(
             {
@@ -337,11 +345,10 @@
 
 
 class BeamHypotheses:
-    def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
+    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
         """
         Initialize n-best list of hypotheses.
         """
-        self.max_length = max_length - 1  # ignoring bos_token
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
         self.num_beams = num_beams
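A quick sanity check on the deprecation path above (a sketch, not part of the commit): the constructor now only warns — `warnings.warn(...)` defaults to `UserWarning` — and no longer stores the value.

```python
import warnings

import torch
from transformers import BeamSearchScorer

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    scorer = BeamSearchScorer(
        batch_size=1,
        num_beams=3,
        device=torch.device("cpu"),
        max_length=10,  # swallowed by **kwargs; only triggers the warning
    )

# The warning fires as a UserWarning, and the attribute is gone entirely.
assert any(issubclass(w.category, UserWarning) for w in caught)
assert not hasattr(scorer, "max_length")
```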
44 changes: 25 additions & 19 deletions src/transformers/generation_utils.py
@@ -1027,7 +1027,6 @@ def generate(
 
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
-                max_length=stopping_criteria.max_length,
                 num_beams=num_beams,
                 device=self.device,
                 length_penalty=length_penalty,
@@ -1063,7 +1062,6 @@ def generate(
                 raise ValueError("`max_length` needs to be a stopping_criteria for now.")
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
-                max_length=stopping_criteria.max_length,
                 num_beams=num_beams,
                 device=self.device,
                 length_penalty=length_penalty,
@@ -1700,7 +1698,6 @@ def beam_search(
             >>> # instantiate beam scorer
             >>> beam_scorer = BeamSearchScorer(
             ...     batch_size=1,
-            ...     max_length=model.config.max_length,
             ...     num_beams=num_beams,
             ...     device=model.device,
             ... )
@@ -1756,7 +1753,7 @@ def beam_search(
 
         assert (
             num_beams * batch_size == batch_beam_size
-        ), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+        ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
 
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
         beam_scores[:, 1:] = -1e9
@@ -1792,10 +1789,7 @@ def beam_search(
 
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `F.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(
-                next_token_logits, cur_len=cur_len, max_length=None
-            )
-
+            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
             next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
 
             next_token_scores = logits_processor(input_ids, next_token_scores)
@@ -1861,7 +1855,13 @@ def beam_search(
                 this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
-            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
+            input_ids,
+            beam_scores,
+            next_tokens,
+            next_indices,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            max_length=stopping_criteria.max_length,
         )
 
         if return_dict_in_generate:
@@ -2086,10 +2086,7 @@ def beam_sample(
 
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
             # cannot be generated both before and after the `F.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(
-                next_token_logits, cur_len=cur_len, max_length=None
-            )
-
+            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
             next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
 
             next_token_scores = logits_processor(input_ids, next_token_scores)
@@ -2160,7 +2157,13 @@
                 this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
-            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
+            input_ids,
+            beam_scores,
+            next_tokens,
+            next_indices,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            max_length=stopping_criteria.max_length,
         )
 
         if return_dict_in_generate:
@@ -2411,10 +2414,7 @@ def group_beam_search(
 
                 # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
                 # cannot be generated both before and after the `F.log_softmax` operation.
-                next_token_logits = self.adjust_logits_during_generation(
-                    next_token_logits, cur_len=cur_len, max_length=None
-                )
-
+                next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
                 next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * group_size, vocab_size)
                 vocab_size = next_token_scores.shape[-1]
 
@@ -2497,7 +2497,13 @@
                 this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
-            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
+            input_ids,
+            beam_scores,
+            next_tokens,
+            next_indices,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            max_length=stopping_criteria.max_length,
        )
 
         if return_dict_in_generate:
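End to end, the new flow mirrors the test added at the bottom of this commit: `max_length` enters through the stopping criteria, and `beam_search` forwards `stopping_criteria.max_length` to `finalize`. A runnable sketch against the same tiny BART checkpoint the test uses (the prompt text and length value are illustrative):

```python
import torch
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    BeamSearchScorer,
    MaxLengthCriteria,
    StoppingCriteriaList,
)

tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random")

num_beams = 3
input_ids = tokenizer("Welcome to parenthood.", return_tensors="pt").input_ids
input_ids = input_ids.expand(num_beams, -1)  # one row per beam
model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})

beam_scorer = BeamSearchScorer(batch_size=1, num_beams=num_beams, device=model.device)
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=18)])

# Internally, beam_search now ends with:
#   beam_scorer.finalize(..., max_length=stopping_criteria.max_length)
generated = model.beam_search(
    input_ids,
    beam_scorer=beam_scorer,
    stopping_criteria=stopping_criteria,
    **model_kwargs,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```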
2 changes: 1 addition & 1 deletion src/transformers/models/marian/modeling_marian.py
@@ -1335,7 +1335,7 @@ def prepare_inputs_for_generation(
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    def adjust_logits_during_generation(self, logits, cur_len, max_length):
+    def adjust_logits_during_generation(self, logits, cur_len):
         logits[:, self.config.pad_token_id] = float("-inf")  # never predict pad token.
         return logits
 
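This hook change also affects any model that overrides `adjust_logits_during_generation`: only Marian does in this repo, but custom subclasses must drop the `max_length` parameter as well. A hypothetical subclass (the class name is made up) under the new two-argument signature:

```python
import torch
from transformers import MarianMTModel

class PadMaskedMarian(MarianMTModel):
    """Hypothetical subclass; the body matches the diff, only the signature is new."""

    def adjust_logits_during_generation(self, logits: torch.Tensor, cur_len: int) -> torch.Tensor:
        # Same behavior as MarianMTModel above: the pad token is never predicted.
        # Note: there is no third `max_length` parameter anymore.
        logits[:, self.config.pad_token_id] = float("-inf")
        return logits
```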
1 change: 0 additions & 1 deletion src/transformers/models/rag/modeling_rag.py
@@ -1543,7 +1543,6 @@ def extend_enc_output(tensor, num_beams=None):
             raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=self.device,
             length_penalty=length_penalty,
7 changes: 3 additions & 4 deletions tests/test_generation_beam_search.py
@@ -59,7 +59,6 @@ def __init__(
     def prepare_beam_scorer(self, **kwargs):
         return BeamSearchScorer(
             batch_size=kwargs.get("batch_size", self.batch_size),
-            max_length=kwargs.get("max_length", self.max_length),
             num_beams=kwargs.get("num_beams", self.num_beams),
             device=torch_device,
             length_penalty=kwargs.get("length_penalty", self.length_penalty),
@@ -170,9 +169,7 @@ def cut_expected_tensor(tensor):
     def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
         # max_length should be only one more than current input_ids to check that eos is correctly appended
         max_length = self.sequence_length + 1
-        beam_scorer = self.prepare_beam_scorer(
-            num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False
-        )
+        beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False)
 
         # update beams and append to input_ids
         tokens = next_tokens.clone()
@@ -197,6 +194,7 @@ def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
             output_indices,
             pad_token_id=self.pad_token_id,
             eos_token_id=self.eos_token_id,
+            max_length=max_length,
         )
 
         sequences = sequence_output["sequences"]
@@ -225,6 +223,7 @@ def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
             output_indices,
             pad_token_id=self.pad_token_id,
             eos_token_id=self.eos_token_id,
+            max_length=max_length,
         )
         sequences = sequence_output["sequences"]
         sequence_scores = sequence_output["sequence_scores"]
54 changes: 48 additions & 6 deletions tests/test_generation_utils.py
@@ -148,7 +148,6 @@ def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
     }
     beam_scorer = BeamSearchScorer(
         batch_size=batch_size,
-        max_length=max_length,
         num_beams=beam_kwargs["num_beams"],
         device=torch_device,
         length_penalty=beam_kwargs["length_penalty"],
@@ -169,7 +168,6 @@ def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
     }
     beam_scorer = BeamSearchScorer(
         batch_size=batch_size,
-        max_length=max_length,
         num_beams=beam_kwargs["num_beams"],
         device=torch_device,
         length_penalty=beam_kwargs["length_penalty"],
@@ -1411,7 +1409,6 @@ def test_max_length_backward_compat_beam_search(self):
 
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
         )
@@ -1442,7 +1439,6 @@ def test_max_length_backward_compat_group_beam_search(self):
 
         diverse_beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
             num_beam_hyps_to_keep=num_return_sequences,
@@ -1502,7 +1498,6 @@ def test_max_length_warning_if_different(self):
         # Beam
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
         )
@@ -1520,7 +1515,6 @@ def test_max_length_warning_if_different(self):
         # Grouped beam search
         diverse_beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
             num_beam_hyps_to_keep=num_return_sequences,
@@ -1535,3 +1529,51 @@ def test_max_length_warning_if_different(self):
             max_length=max_length,
             **model_kwargs,
         )
+
+    def test_beam_search_warning_if_max_length_is_passed(self):
+        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
+        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
+
+        batch_size = 1
+        num_beams = 3
+
+        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        input_ids = input_ids.expand(num_beams, -1)
+        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
+
+        stopping_criteria_max_length = 18
+        stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
+
+        with self.assertWarns(UserWarning):
+            beam_scorer = BeamSearchScorer(
+                batch_size=batch_size,
+                num_beams=num_beams,
+                device=torch_device,
+                max_length=10,
+            )
+
+        generated_ids = bart_model.beam_search(
+            input_ids,
+            num_beams=num_beams,
+            stopping_criteria=stopping_criteria,
+            beam_scorer=beam_scorer,
+            **model_kwargs,
+        )
+
+        beam_scorer_no_max_len = BeamSearchScorer(
+            batch_size=batch_size,
+            num_beams=num_beams,
+            device=torch_device,
+        )
+
+        generated_ids_no_max_len = bart_model.beam_search(
+            input_ids,
+            num_beams=num_beams,
+            stopping_criteria=stopping_criteria,
+            beam_scorer=beam_scorer_no_max_len,
+            **model_kwargs,
+        )
+
+        # BeamSearchScorer max_length should not influence "real" max_length
+        self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
