diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 086ad26992fefd..1d8cb78e6f1d5a 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1687,7 +1687,7 @@ def beam_search( # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - max_length = max_length if max_length is not None else self.config.max_length + max_length = max_length if max_length is not None else beam_scorer.max_length validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id @@ -2222,7 +2222,7 @@ def group_beam_search( # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - max_length = max_length if max_length is not None else self.config.max_length + max_length = max_length if max_length is not None else beam_scorer.max_length validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id