reflect max_new_tokens in Seq2SeqTrainer (huggingface#18786)
* reflect max_new_tokens in gen_kwargs to `trainer.generate()`

* reflect max_new_tokens in `Seq2SeqTrainer`

* remove unnecessary variable

* Trigger CI

* fix style
kumapo authored and oneraghavan committed Sep 26, 2022
1 parent e8eec21 commit f435cb3
1 changed file: src/transformers/trainer_seq2seq.py (16 additions, 11 deletions)
@@ -68,9 +68,8 @@ def evaluate(
         """
 
         gen_kwargs = gen_kwargs.copy()
-        gen_kwargs["max_length"] = (
-            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
-        )
+        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+            gen_kwargs["max_length"] = self.args.generation_max_length
         gen_kwargs["num_beams"] = (
             gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
         )
@@ -126,9 +125,8 @@ def predict(
         """
 
         gen_kwargs = gen_kwargs.copy()
-        gen_kwargs["max_length"] = (
-            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
-        )
+        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+            gen_kwargs["max_length"] = self.args.generation_max_length
         gen_kwargs["num_beams"] = (
             gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
         )
@@ -174,9 +172,8 @@ def prediction_step(
 
         # XXX: adapt synced_gpus for fairscale as well
         gen_kwargs = self._gen_kwargs.copy()
-        gen_kwargs["max_length"] = (
-            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.model.config.max_length
-        )
+        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+            gen_kwargs["max_length"] = self.model.config.max_length
         gen_kwargs["num_beams"] = (
             gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
         )
@@ -203,8 +200,12 @@
             **gen_kwargs,
         )
         # in case the batch is shorter than max length, the output should be padded
-        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
+        if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
+        elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
+            gen_kwargs["max_new_tokens"] + 1
+        ):
+            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
 
         with torch.no_grad():
             with self.compute_loss_context_manager():
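The `max_new_tokens + 1` bound accounts for the decoder start token: sequences returned by `generate()` for encoder-decoder models begin with `decoder_start_token_id`, so a run capped at `max_new_tokens` new tokens can be up to `max_new_tokens + 1` ids wide. Padding to that width keeps batches concatenatable across evaluation steps. A minimal sketch of the padding step, assuming a known `pad_token_id` (`pad_to_max_len` is a stand-in for the trainer's `_pad_tensors_to_max_len`):

import torch

def pad_to_max_len(tensor: torch.Tensor, max_len: int, pad_token_id: int) -> torch.Tensor:
    # Right-pad each row with pad_token_id so every batch has the same width.
    padded = tensor.new_full((tensor.shape[0], max_len), pad_token_id)
    padded[:, : tensor.shape[-1]] = tensor
    return padded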
@@ -222,8 +223,12 @@
 
         if has_labels:
             labels = inputs["labels"]
-            if labels.shape[-1] < gen_kwargs["max_length"]:
+            if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
                 labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
+            elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
+                gen_kwargs["max_new_tokens"] + 1
+            ):
+                labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
         else:
             labels = None
 
