[Seq2Seq Trainer] Make sure padding is implemented for models without pad_token #8043

Merged: patrickvonplaten merged 9 commits into huggingface:master from patrickvonplaten:fix_seq2seq_trainer on Oct 26, 2020.
Commits (9):

- f8f49d7: make sure padding is implemented for non-padding tokens models as well
- c5b6ab0: add better error message
- bd4a2fd: add better warning
- 8667b9c: remove results files
- d305d66: Update examples/seq2seq/seq2seq_trainer.py
- 04c533a: remove unnecessary copy line
- 1f8c26f: Merge branch 'fix_seq2seq_trainer' of https://github.com/patrickvonpl…
- c1bde00: correct usage of labels
- ade3d54: delete test files
Diff: examples/seq2seq/seq2seq_trainer.py

@@ -1,4 +1,3 @@
-import copy
 from typing import Any, Dict, Optional, Tuple, Union

 import torch
@@ -60,6 +59,11 @@ def __init__(self, config=None, data_args=None, *args, **kwargs):
                 self.config.pad_token_id is not None
             ), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing."

+        if self.config.pad_token_id is None and self.config.eos_token_id is not None:
+            logger.warn(
+                f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
+            )
+
     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
         Setup the optimizer and the learning rate scheduler.
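The warning added above covers configs like GPT-2's, which ship an `eos_token_id` but no `pad_token_id`. A minimal sketch of the fallback it announces (the ids below are GPT-2's actual values; the variable names are illustrative, not from the PR):

# GPT-2 defines <|endoftext|> (id 50256) as EOS but has no pad token,
# so the trainer falls back to the EOS id whenever it needs to pad.
pad_token_id = None
eos_token_id = 50256

effective_pad_id = pad_token_id if pad_token_id is not None else eos_token_id
print(effective_pad_id)  # 50256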
@@ -126,22 +130,19 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
             else DistributedSampler(self.train_dataset)
         )

-    def _compute_loss(self, model, inputs):
-        inputs = copy.deepcopy(inputs)
+    def _compute_loss(self, model, inputs, labels):
         if self.args.label_smoothing == 0:
             if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
                 # force training to ignore pad token
-                labels = inputs.pop("labels")
                 logits = model(**inputs, use_cache=False)[0]
                 loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
                 loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
             else:
                 # compute usual loss via models
-                loss, logits = model(**inputs, use_cache=False)[:2]
+                loss, logits = model(**inputs, labels=labels, use_cache=False)[:2]
         else:
             # compute label smoothed loss
-            labels = inputs.pop("labels")
             logits = model(**inputs, use_cache=False)[0]
             lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
             loss, _ = label_smoothed_nll_loss(
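The `ignore_pad_token_for_loss` branch above relies on `CrossEntropyLoss(ignore_index=...)` to drop pad positions from the loss. A standalone sketch with toy tensors (shapes and ids are made up for illustration):

import torch

pad_token_id = 0
logits = torch.randn(2, 5, 32)          # (batch, seq_len, vocab_size)
labels = torch.randint(1, 32, (2, 5))   # real label tokens
labels[:, 3:] = pad_token_id            # last two positions are padding

# Positions whose label equals ignore_index contribute nothing to the loss.
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
print(loss.item())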
@@ -150,7 +151,8 @@ def _compute_loss(self, model, inputs):
         return loss, logits

     def compute_loss(self, model, inputs):
-        loss, _ = self._compute_loss(model, inputs)
+        labels = inputs.pop("labels")
+        loss, _ = self._compute_loss(model, inputs, labels)
         return loss

     def prediction_step(
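`compute_loss` now pops the labels out of the batch before delegating, so the forward pass inside `_compute_loss` never receives `labels` and therefore never computes its own loss. A toy illustration of that calling convention (plain dicts, not a real batch):

batch = {
    "input_ids": [[5, 6, 7]],
    "attention_mask": [[1, 1, 1]],
    "labels": [[6, 7, 8]],
}
labels = batch.pop("labels")   # batch now holds only forward-pass inputs
assert "labels" not in batch   # model(**batch) returns logits, not a loss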
@@ -178,40 +180,49 @@ def prediction_step(
         """
         inputs = self._prepare_inputs(inputs)

+        gen_kwargs = {
+            "max_length": self.data_args.val_max_target_length
+            if self.data_args is not None
+            else self.config.max_length,
+            "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
+        }
+
         if self.args.predict_with_generate and not self.args.prediction_loss_only:
-            gen_kwargs = {
-                "max_length": self.data_args.val_max_target_length
-                if self.data_args is not None
-                else self.config.max_length,
-                "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
-            }
             generated_tokens = model.generate(
                 inputs["input_ids"],
                 attention_mask=inputs["attention_mask"],
                 **gen_kwargs,
             )
             # in case the batch is shorter than max length, the output should be padded
-            if self.config.pad_token_id is not None:
-                if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
-                    generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
+            if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
+                generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

-        # compute loss on predict data
+        labels = inputs.pop("labels")
         with torch.no_grad():
-            loss, logits = self._compute_loss(model, inputs)
+            # compute loss on predict data
+            loss, logits = self._compute_loss(model, inputs, labels)

         loss = loss.mean().detach()
         if self.args.prediction_loss_only:
             return (loss, None, None)

         logits = generated_tokens if self.args.predict_with_generate else logits

-        labels = inputs["labels"]
-        if self.config.pad_token_id is not None:
-            labels = self._pad_tensors_to_max_len(labels, self.config.max_length)
+        if labels.shape[-1] < gen_kwargs["max_length"]:
+            labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])

         return (loss, logits, labels)
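`prediction_step` pads both the generated tokens and the labels up to `gen_kwargs["max_length"]` because evaluation concatenates batches along the batch dimension, which requires a uniform sequence length. A toy demonstration (the tensors and the `pad_to` helper are illustrative, not the trainer's code):

import torch

max_length = 8
pad_token_id = 0

def pad_to(t, length):
    # Same idea as _pad_tensors_to_max_len: fill with the pad id, copy the prefix.
    out = pad_token_id * torch.ones((t.shape[0], length), dtype=t.dtype)
    out[:, : t.shape[-1]] = t
    return out

short = torch.tensor([[5, 6, 7]])                # generation stopped early
full = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])  # ran to max_length
stacked = torch.cat([pad_to(short, max_length), pad_to(full, max_length)], dim=0)
print(stacked.shape)  # torch.Size([2, 8])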

     def _pad_tensors_to_max_len(self, tensor, max_length):
-        padded_tensor = self.config.pad_token_id * torch.ones(
+        # If PAD token is not defined at least EOS token has to be defined
+        pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.eos_token_id
+
+        if pad_token_id is None:
+            raise ValueError(
+                f"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be padded to `max_length`={max_length}"
+            )
+
+        padded_tensor = pad_token_id * torch.ones(
             (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
         )
         padded_tensor[:, : tensor.shape[-1]] = tensor
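For reference, a self-contained version of the helper as it ends up after this PR, with the config lookups replaced by plain arguments (a sketch under those assumptions, not the exact trainer method):

import torch

def pad_tensors_to_max_len(tensor, max_length, pad_token_id=None, eos_token_id=None):
    # If the PAD token is not defined, at least the EOS token has to be defined.
    pad_id = pad_token_id if pad_token_id is not None else eos_token_id
    if pad_id is None:
        raise ValueError(
            f"Make sure that either pad_token_id or eos_token_id is defined "
            f"if the tensor has to be padded to max_length={max_length}"
        )
    padded = pad_id * torch.ones((tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device)
    padded[:, : tensor.shape[-1]] = tensor
    return padded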
Review conversation on the new `pad_token_id` fallback in `_pad_tensors_to_max_len`:

Reviewer: What if `eos_token_id` is `None` as well? Should we raise?

patrickvonplaten: Might be a bit too edge-casy, but `eos_token_id` could be `None`, in which case padding would never take place.

Reviewer: Should we raise early in that case?

patrickvonplaten: What I meant is that there are models, like `openai-gpt` or `ctrl`, that have neither an `eos_token_id` nor a `pad_token_id` => the way it is implemented now, these models can still make use of Seq2SeqTrainer because they never require padding (they never finish early). So I'd just leave it as it is - or, if you think that models without an EOS token should not use Seq2SeqTrainer, we could raise as well - up to you!

Reviewer: I didn't understand that they always generate up to `max_length`; your implementation makes total sense. Thanks for clarifying.
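The point in that last exchange - a model with no EOS token never stops early, so the padding branch never triggers - can be seen with a toy stand-in for `generate` (a pure-Python simulation, not transformers code):

import random

def toy_generate(max_length, eos_token_id=None, vocab_size=10, seed=0):
    rng = random.Random(seed)
    tokens = []
    for _ in range(max_length):
        tok = rng.randrange(vocab_size)
        tokens.append(tok)
        if eos_token_id is not None and tok == eos_token_id:
            break  # early stop is only possible when an EOS token exists
    return tokens

print(len(toy_generate(20, eos_token_id=3)))     # can be shorter than 20
print(len(toy_generate(20, eos_token_id=None)))  # always exactly 20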