diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py
index 02ce3d393b9e75..7689998c051b8f 100644
--- a/src/transformers/trainer_seq2seq.py
+++ b/src/transformers/trainer_seq2seq.py
@@ -68,9 +68,8 @@ def evaluate(
         """
         gen_kwargs = gen_kwargs.copy()
 
-        gen_kwargs["max_length"] = (
-            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
-        )
+        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+            gen_kwargs["max_length"] = self.args.generation_max_length
         gen_kwargs["num_beams"] = (
             gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
         )
@@ -126,9 +125,8 @@ def predict(
         """
         gen_kwargs = gen_kwargs.copy()
 
-        gen_kwargs["max_length"] = (
-            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
-        )
+        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+            gen_kwargs["max_length"] = self.args.generation_max_length
         gen_kwargs["num_beams"] = (
             gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
         )
@@ -174,9 +172,8 @@ def prediction_step(
         # XXX: adapt synced_gpus for fairscale as well
         gen_kwargs = self._gen_kwargs.copy()
 
-        gen_kwargs["max_length"] = (
-            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.model.config.max_length
-        )
+        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+            gen_kwargs["max_length"] = self.model.config.max_length
         gen_kwargs["num_beams"] = (
             gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
         )
@@ -203,8 +200,12 @@ def prediction_step(
             **gen_kwargs,
         )
         # in case the batch is shorter than max length, the output should be padded
-        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
+        if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
+        elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
+            gen_kwargs["max_new_tokens"] + 1
+        ):
+            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
 
         with torch.no_grad():
             with self.compute_loss_context_manager():
@@ -222,8 +223,12 @@ def prediction_step(
 
         if has_labels:
             labels = inputs["labels"]
-            if labels.shape[-1] < gen_kwargs["max_length"]:
+            if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
                 labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
+            elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
+                gen_kwargs["max_new_tokens"] + 1
+            ):
+                labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
         else:
             labels = None
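
For reference, a minimal usage sketch of the behavior this patch enables. The checkpoint name ("t5-small") and the toy dataset below are illustrative placeholders, not part of the patch. With `predict_with_generate=True`, a caller can now pass `max_new_tokens` to `evaluate()` or `predict()` and the trainer no longer forces a competing `max_length` default into the generation kwargs; the `max_new_tokens + 1` padding bound reflects that an encoder-decoder's `generate()` output is at most `max_new_tokens` new tokens plus the decoder start token.

# Usage sketch (assumed setup; checkpoint and toy data are hypothetical examples).
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Pre-tokenize a tiny eval set; the collator below handles padding, including
# padding labels with -100 so they are ignored by the loss.
sources = ["translate English to German: Hello.", "translate English to German: Thank you."]
targets = ["Hallo.", "Danke."]
enc = tokenizer(sources, truncation=True)
label_ids = tokenizer(targets, truncation=True)["input_ids"]
eval_dataset = [
    {"input_ids": i, "attention_mask": m, "labels": l}
    for i, m, l in zip(enc["input_ids"], enc["attention_mask"], label_ids)
]

args = Seq2SeqTrainingArguments(
    output_dir="tmp_seq2seq_eval",
    predict_with_generate=True,  # route evaluation through model.generate()
    per_device_eval_batch_size=2,
)
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)

# Previously, gen_kwargs["max_length"] was always filled in (from args or model.config),
# so a user-supplied max_new_tokens reached generate() alongside a competing max_length.
# With this patch, max_length is only defaulted when neither limit is given:
metrics = trainer.evaluate(max_new_tokens=16)
print(metrics)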