
[Seq2Seq Trainer] Make sure padding is implemented for models without pad_token #8043

Merged
51 changes: 31 additions & 20 deletions examples/seq2seq/seq2seq_trainer.py
@@ -1,4 +1,3 @@
import copy
from typing import Any, Dict, Optional, Tuple, Union

import torch
@@ -60,6 +59,11 @@ def __init__(self, config=None, data_args=None, *args, **kwargs):
self.config.pad_token_id is not None
), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing."

if self.config.pad_token_id is None and self.config.eos_token_id is not None:
logger.warn(
f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
)

Contributor:
what if eos_token_id is None? Should we raise?

Contributor (Author):
Might be a bit too edge-casy, but eos_token_id could be None, in which case padding would never take place.

Contributor:

should we raise early in that case?

Contributor (Author):

What I meant is that there are models, like openai-gpt or ctrl, that have neither an eos_token_id nor a pad_token_id. The way it is implemented now, these models can still make use of Seq2SeqTrainer because they never require padding (they never finish early). So I'd just leave it as it is, or, if you think that models without an EOS token should not use Seq2SeqTrainer, we could raise as well. Up to you!

Contributor:

I hadn't understood that they always generate up to max_length; your implementation makes total sense. Thanks for clarifying.
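
For reference, a minimal sketch of the fallback being discussed, assuming a config object that exposes optional `pad_token_id` and `eos_token_id` attributes (the helper name is hypothetical, not part of the PR):

```python
def resolve_pad_token_id(config):
    """Prefer pad_token_id, fall back to eos_token_id, and only error out
    once padding is actually required."""
    if config.pad_token_id is not None:
        return config.pad_token_id
    if config.eos_token_id is not None:
        # the same fallback the trainer warns about in __init__
        return config.eos_token_id
    # Models such as openai-gpt or ctrl define neither token, but they also
    # cannot stop early, so generation always runs to max_length and padding
    # is never needed for them.
    raise ValueError(
        "Neither `config.pad_token_id` nor `config.eos_token_id` is defined, "
        "so sequences shorter than `max_length` cannot be padded."
    )
```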


def create_optimizer_and_scheduler(self, num_training_steps: int):
"""
Setup the optimizer and the learning rate scheduler.
@@ -126,22 +130,19 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
else DistributedSampler(self.train_dataset)
)

def _compute_loss(self, model, inputs):
inputs = copy.deepcopy(inputs)
def _compute_loss(self, model, inputs, labels):
if self.args.label_smoothing == 0:
if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
# force training to ignore pad token
labels = inputs.pop("labels")
logits = model(**inputs, use_cache=False)[0]

loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
else:
# compute usual loss via models
loss, logits = model(**inputs, use_cache=False)[:2]
loss, logits = model(**inputs, labels=labels, use_cache=False)[:2]
else:
# compute label smoothed loss
labels = inputs.pop("labels")
logits = model(**inputs, use_cache=False)[0]
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
loss, _ = label_smoothed_nll_loss(
@@ -150,7 +151,8 @@ def _compute_loss(self, model, inputs):
return loss, logits

def compute_loss(self, model, inputs):
loss, _ = self._compute_loss(model, inputs)
labels = inputs.pop("labels")
loss, _ = self._compute_loss(model, inputs, labels)
return loss

def prediction_step(
@@ -178,40 +180,49 @@ def prediction_step(
"""
inputs = self._prepare_inputs(inputs)

gen_kwargs = {
"max_length": self.data_args.val_max_target_length
if self.data_args is not None
else self.config.max_length,
"num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
}

if self.args.predict_with_generate and not self.args.prediction_loss_only:
gen_kwargs = {
"max_length": self.data_args.val_max_target_length
if self.data_args is not None
else self.config.max_length,
"num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
}
generated_tokens = model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
**gen_kwargs,
)
# in case the batch is shorter than max length, the output should be padded
if self.config.pad_token_id is not None:
if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

# compute loss on predict data
labels = inputs.pop("labels")
with torch.no_grad():
loss, logits = self._compute_loss(model, inputs)
# compute loss on predict data
loss, logits = self._compute_loss(model, inputs, labels)

loss = loss.mean().detach()
if self.args.prediction_loss_only:
return (loss, None, None)

logits = generated_tokens if self.args.predict_with_generate else logits

labels = inputs["labels"]
if self.config.pad_token_id is not None:
labels = self._pad_tensors_to_max_len(labels, self.config.max_length)
if labels.shape[-1] < gen_kwargs["max_length"]:
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])

return (loss, logits, labels)

def _pad_tensors_to_max_len(self, tensor, max_length):
padded_tensor = self.config.pad_token_id * torch.ones(
# If PAD token is not defined at least EOS token has to be defined
pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.eos_token_id

if pad_token_id is None:
raise ValueError(
f"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be padded to `max_length`={max_length}"
)

Contributor (on lines +218 to +223):
Should we check in __init__ for faster failures?
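
For illustration, a sketch of what such an early check could look like (hypothetical, not what was merged). Per the thread above, failing in `__init__` would also reject models like openai-gpt or ctrl that never need padding because they cannot stop early, which is why the error is deferred until padding is actually required.

```python
class Seq2SeqTrainerSketch:
    """Hypothetical fail-fast variant: validate the padding token in __init__
    instead of waiting until _pad_tensors_to_max_len is called."""

    def __init__(self, config):
        self.config = config
        if config.pad_token_id is None and config.eos_token_id is None:
            raise ValueError(
                "Neither `config.pad_token_id` nor `config.eos_token_id` is "
                "defined; generated sequences could not be padded to `max_length`."
            )
```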


padded_tensor = pad_token_id * torch.ones(
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
)
padded_tensor[:, : tensor.shape[-1]] = tensor
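A short usage sketch of what the padding helper above does with a toy batch (the ids and lengths are illustrative):

```python
import torch

# Right-pad a batch of generated ids to max_length, mirroring _pad_tensors_to_max_len.
pad_token_id = 0                                  # resolved from pad or eos token
generated = torch.tensor([[5, 6, 7], [8, 9, 2]])  # (batch_size=2, seq_len=3)
max_length = 5

padded = pad_token_id * torch.ones(
    (generated.shape[0], max_length), dtype=generated.dtype, device=generated.device
)
padded[:, : generated.shape[-1]] = generated
print(padded)  # tensor([[5, 6, 7, 0, 0], [8, 9, 2, 0, 0]])
```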
2 changes: 2 additions & 0 deletions examples/seq2seq/test_finetune_trainer.py
@@ -63,7 +63,9 @@ def test_finetune_bert2bert(self):
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
bert2bert.config.eos_token_id = tokenizer.sep_token_id
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.max_length = 128

train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
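
For context, a rough sketch of how the two added config fields feed into the trainer's `prediction_step` above; the stand-in objects and most values are illustrative, only `max_length=128` comes from the test itself.

```python
from types import SimpleNamespace

# Stand-ins for the config/data_args consumed in prediction_step (illustrative values).
config = SimpleNamespace(max_length=128, num_beams=4, eos_token_id=102, pad_token_id=0)
data_args = None  # assume no data_args, so the config fallbacks apply

gen_kwargs = {
    "max_length": data_args.val_max_target_length if data_args is not None else config.max_length,
    "num_beams": data_args.eval_beams if data_args is not None else config.num_beams,
}
print(gen_kwargs)  # {'max_length': 128, 'num_beams': 4}
# Without an eos_token_id the decoder could never stop before max_length, and the
# eos id also serves as the padding fallback when pad_token_id is missing.
```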