Remove max length beam scorer #11378

Merged

Changes from 3 commits
12 changes: 5 additions & 7 deletions src/transformers/generation_beam_search.py

@@ -110,6 +110,7 @@ def finalize(
         next_scores: torch.FloatTensor,
         next_tokens: torch.LongTensor,
         next_indices: torch.LongTensor,
+        max_length: int,
         **kwargs
     ) -> torch.LongTensor:
         raise NotImplementedError("This is an abstract method.")
@@ -152,15 +153,13 @@ class BeamSearchScorer(BeamScorer):
     def __init__(
         self,
         batch_size: int,
-        max_length: int,
         num_beams: int,
         device: torch.device,
         length_penalty: Optional[float] = 1.0,
         do_early_stopping: Optional[bool] = False,
         num_beam_hyps_to_keep: Optional[int] = 1,
         num_beam_groups: Optional[int] = 1,
     ):
-        self.max_length = max_length
         self.num_beams = num_beams
         self.device = device
         self.length_penalty = length_penalty
@@ -173,7 +172,6 @@ def __init__(
         self._beam_hyps = [
             BeamHypotheses(
                 num_beams=self.num_beams,
-                max_length=self.max_length,
                 length_penalty=self.length_penalty,
                 early_stopping=self.do_early_stopping,
             )
@@ -279,6 +277,7 @@ def finalize(
         final_beam_scores: torch.FloatTensor,
         final_beam_tokens: torch.LongTensor,
         final_beam_indices: torch.LongTensor,
+        max_length: int,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
     ) -> Tuple[torch.LongTensor]:
@@ -316,7 +315,7 @@ def finalize(
                 best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

         # prepare for adding eos
-        sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
+        sent_max_len = min(sent_lengths.max().item() + 1, max_length)
         decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
         # shorter batches are padded if needed
         if sent_lengths.min().item() != sent_lengths.max().item():
@@ -326,7 +325,7 @@ def finalize(
         # fill with hypotheses and eos_token_id if the latter fits in
         for i, hypo in enumerate(best):
             decoded[i, : sent_lengths[i]] = hypo
-            if sent_lengths[i] < self.max_length:
+            if sent_lengths[i] < max_length:
                 decoded[i, sent_lengths[i]] = eos_token_id
         return UserDict(
             {
@@ -337,11 +336,10 @@ def finalize(


 class BeamHypotheses:
-    def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
+    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
         """
         Initialize n-best list of hypotheses.
         """
-        self.max_length = max_length - 1  # ignoring bos_token
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
         self.num_beams = num_beams
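Taken together, the changes in this file drop the stored `max_length` from `BeamSearchScorer` and `BeamHypotheses`; callers now supply `max_length` when calling `finalize()`. A minimal sketch of the updated calling convention — the tensor values and special-token ids below are illustrative, not taken from the PR:

```python
import torch
from transformers import BeamSearchScorer

batch_size, num_beams, cur_len = 1, 2, 4

# No `max_length` argument at construction time any more.
beam_scorer = BeamSearchScorer(
    batch_size=batch_size,
    num_beams=num_beams,
    device=torch.device("cpu"),
    length_penalty=1.0,
    do_early_stopping=False,
)

# Toy running beams, shaped as they would be mid-generation:
# (batch_size * num_beams, cur_len) token ids and one score per beam.
input_ids = torch.zeros((batch_size * num_beams, cur_len), dtype=torch.long)
beam_scores = torch.tensor([0.0, -1.0])
next_tokens = torch.tensor([[3, 4]])
next_indices = torch.tensor([[0, 1]])

# `max_length` is now passed here instead of to __init__.
sequence_outputs = beam_scorer.finalize(
    input_ids,
    beam_scores,
    next_tokens,
    next_indices,
    pad_token_id=0,
    eos_token_id=1,
    max_length=cur_len + 1,
)
print(sequence_outputs["sequences"].shape)  # torch.Size([1, 5]) for these toy inputs
```

Deferring the length cap to finalization time also means the scorer no longer has to be kept in sync with the `max_length` carried by the stopping criteria, which `generate()` previously forwarded as `stopping_criteria.max_length`.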
27 changes: 21 additions & 6 deletions src/transformers/generation_utils.py

@@ -1027,7 +1027,6 @@ def generate(

             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
-                max_length=stopping_criteria.max_length,
                 num_beams=num_beams,
                 device=self.device,
                 length_penalty=length_penalty,
@@ -1063,7 +1062,6 @@ def generate(
                 raise ValueError("`max_length` needs to be a stopping_criteria for now.")
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
-                max_length=stopping_criteria.max_length,
                 num_beams=num_beams,
                 device=self.device,
                 length_penalty=length_penalty,
@@ -1700,7 +1698,6 @@ def beam_search(
         >>> # instantiate beam scorer
         >>> beam_scorer = BeamSearchScorer(
         ...     batch_size=1,
-        ...     max_length=model.config.max_length,
         ...     num_beams=num_beams,
         ...     device=model.device,
         ... )
@@ -1861,7 +1858,13 @@ def beam_search(
                     this_peer_finished = True

         sequence_outputs = beam_scorer.finalize(
-            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
+            input_ids,
+            beam_scores,
+            next_tokens,
+            next_indices,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            max_length=max_length,
         )

         if return_dict_in_generate:
@@ -2160,7 +2163,13 @@ def beam_sample(
                     this_peer_finished = True

         sequence_outputs = beam_scorer.finalize(
-            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
+            input_ids,
+            beam_scores,
+            next_tokens,
+            next_indices,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            max_length=max_length,
         )

         if return_dict_in_generate:
@@ -2497,7 +2506,13 @@ def group_beam_search(
                     this_peer_finished = True

         sequence_outputs = beam_scorer.finalize(
-            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
+            input_ids,
+            beam_scores,
+            next_tokens,
+            next_indices,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            max_length=max_length,
         )

         if return_dict_in_generate:
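The `beam_search` docstring example touched above now builds the scorer without `max_length` as well. A sketch of end-to-end usage in that style — the checkpoint name and logits processor are illustrative, and the model weights are downloaded on first use:

```python
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BeamSearchScorer,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
)

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

num_beams = 3
# decoder start tokens, one row per beam
input_ids = torch.ones((num_beams, 1), dtype=torch.long) * model.config.decoder_start_token_id

# pre-compute the encoder outputs once and tile them across beams
model_kwargs = {
    "encoder_outputs": model.get_encoder()(
        encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
    )
}

# note: no `max_length` argument here after this change
beam_scorer = BeamSearchScorer(
    batch_size=1,
    num_beams=num_beams,
    device=model.device,
)

logits_processor = LogitsProcessorList(
    [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)]
)

outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```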
1 change: 0 additions & 1 deletion src/transformers/models/rag/modeling_rag.py

@@ -1543,7 +1543,6 @@ def extend_enc_output(tensor, num_beams=None):
                 raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
-                max_length=max_length,
                 num_beams=num_beams,
                 device=self.device,
                 length_penalty=length_penalty,
7 changes: 3 additions & 4 deletions tests/test_generation_beam_search.py

@@ -59,7 +59,6 @@ def __init__(
     def prepare_beam_scorer(self, **kwargs):
         return BeamSearchScorer(
             batch_size=kwargs.get("batch_size", self.batch_size),
-            max_length=kwargs.get("max_length", self.max_length),
             num_beams=kwargs.get("num_beams", self.num_beams),
             device=torch_device,
             length_penalty=kwargs.get("length_penalty", self.length_penalty),
@@ -170,9 +169,7 @@ def cut_expected_tensor(tensor):
     def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
         # max_length should be only one more than current input_ids to check that eos is correctly appended
         max_length = self.sequence_length + 1
-        beam_scorer = self.prepare_beam_scorer(
-            num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False
-        )
+        beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False)

         # update beams and append to input_ids
         tokens = next_tokens.clone()
@@ -197,6 +194,7 @@ def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_
             output_indices,
             pad_token_id=self.pad_token_id,
             eos_token_id=self.eos_token_id,
+            max_length=max_length,
         )

         sequences = sequence_output["sequences"]
@@ -225,6 +223,7 @@ def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_
             output_indices,
             pad_token_id=self.pad_token_id,
             eos_token_id=self.eos_token_id,
+            max_length=max_length,
         )
         sequences = sequence_output["sequences"]
         sequence_scores = sequence_output["sequence_scores"]
6 changes: 0 additions & 6 deletions tests/test_generation_utils.py

@@ -148,7 +148,6 @@ def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
         }
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=beam_kwargs["num_beams"],
             device=torch_device,
             length_penalty=beam_kwargs["length_penalty"],
@@ -169,7 +168,6 @@ def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_seque
         }
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=beam_kwargs["num_beams"],
             device=torch_device,
             length_penalty=beam_kwargs["length_penalty"],
@@ -1411,7 +1409,6 @@ def test_max_length_backward_compat_beam_search(self):

         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
         )
@@ -1442,7 +1439,6 @@ def test_max_length_backward_compat_group_beam_search(self):

         diverse_beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
             num_beam_hyps_to_keep=num_return_sequences,
@@ -1502,7 +1498,6 @@ def test_max_length_warning_if_different(self):
         # Beam
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
         )
@@ -1520,7 +1515,6 @@ def test_max_length_warning_if_different(self):
         # Grouped beam search
         diverse_beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
             num_beam_hyps_to_keep=num_return_sequences,