[Seq2Seq] Allow EncoderDecoderModels to be trained with Seq2Seq #7809
Conversation
examples/seq2seq/seq2seq_trainer.py
Outdated
@@ -41,12 +41,13 @@
 class Seq2SeqTrainer(Trainer):
-    def __init__(self, config, data_args, *args, **kwargs):
+    def __init__(self, *args, **kwargs):
@patil-suraj @sshleifer - I think it would be better to align the init of `Seq2SeqTrainer` 100% with `Trainer`. Is there a reason why we would pass in `config` instead of using the model's config? Also, I don't really think the `data_args` variable is necessary. Both `max_length` and `num_beams` can be defined in the config and don't have to be force-passed to the `generate()` method.
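For context, a minimal sketch of the two options under discussion (the checkpoint name and inputs below are illustrative placeholders, not from this PR); `generate()` falls back to the values stored in the model config whenever `num_beams`/`max_length` are not passed explicitly:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-6-6")
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-6-6")
batch = tokenizer(["A long article to summarize ..."], return_tensors="pt")

# Option 1: rely on the defaults stored in the config
model.config.num_beams = 4
model.config.max_length = 142
summary_ids = model.generate(batch["input_ids"])

# Option 2: "force"-pass evaluation-specific values to generate(),
# leaving the config (and thus the saved model) untouched
summary_ids = model.generate(batch["input_ids"], num_beams=2, max_length=64)
```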
`model.config` breaks under `DistributedDataParallel`, so we decided to pass it explicitly. See #7461 and #7460.
If the default `num_beams` and `max_length` are too high, evaluation slows down, so we allow the user to control them during training. We also don't override `config`, since its defaults will be needed for inference after training.
Okay, I see! I'm a bit confused why `Trainer` does not break with `DistributedDataParallel` when only using `model.config`, but `Seq2SeqTrainer` does. Do you guys know why?
`eval_beams`/`eval_max_gen_length` reasoning:
@patil-suraj said exactly this LOL, but in my words: users are not good at modifying configs locally. We want a way to run `num_beams=2` during the generation step, but still end up with a trained model that has the default number of beams. In general, we try not to manipulate config attributes that are only desired during training.
I mean, modifying the config locally is as simple as `config.num_beams = 4`, and I would think one wants to evaluate a model during training with exactly the beam size and `max_length` stored in the config (changing the beam size and `max_length` does not simply reduce time, it also changes the output...). But I guess I can see the use case where people want to tweak `max_length` and `num_beams` without changing the config. Would it be fine to make `data_args` optional and call it `generation_args`, so that it is simply passed as `**generation_args` to the generate function?
Oops, that was supposed to land further below, not here. @sshleifer for reference.
`data_args` -> `generation_kwargs` seems like a good change (at least in seq2seq_trainer.py), but the CLI naming has a purpose: it wouldn't have been obvious to me that passing `--min_length 32` would affect generation rather than truncating source docs. That's why the `eval_` prefix was added.
if self.args.label_smoothing == 0:
    # Same behavior as modeling_bart.py
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
This does not seem to work for all models (`EncoderDecoderModel` does not work with it) -> let's instead use the loss function of each model here.
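As a rough illustration of relying on each model's own loss (the checkpoint below is just a placeholder, not from this PR): Hugging Face models compute their loss internally when `labels` are passed to the forward call.

```python
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

batch = tokenizer(["a source document"], return_tensors="pt")
labels = tokenizer(["a target summary"], return_tensors="pt").input_ids

# The model computes its own loss from the labels; no external
# CrossEntropyLoss with a hand-picked ignore_index is needed.
outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=labels)
loss = outputs[0]
```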
The models' loss functions use -100 as `ignore_index`, so we will also need to replace pad tokens in `labels` with -100.
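A minimal sketch of that replacement (the pad id and label values below are made up for illustration):

```python
import torch

pad_token_id = 0  # placeholder; in practice this comes from tokenizer/config.pad_token_id
labels = torch.tensor([[101, 7592, 2088, 102, 0, 0]])

# -100 is ignored by the models' internal CrossEntropyLoss
labels[labels == pad_token_id] = -100
```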
I usually do this manually beforehand -> should that be the role of `Seq2SeqTrainer`? `Trainer` also does not have this feature.
Ignoring `pad_token_id` confused lots of people and helps metrics, so we automated it.
Related: #7828
> I usually do this manually beforehand

We could do this in the collator, but we won't need to if #7828 is merged.
We will still need to cover FSMT/T5.
I would definitely not make this change right now; it works as is and is much easier than checking that every model ignores padding.
PyTorch's CE loss function has -100 as the default `ignore_index`, and from what I understand it is the default behavior of the library to ignore tokens when they have the index -100, not when they are equal to the padding token (often we set padding token == -100): https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
It would require manually replacing tokens with -100, but I think that's how it should be done in general in the library. How would we handle models that don't have a padding token, or that want to disregard the loss of more than just the padding token? For such cases I think it can be quite handy if the user overwrites all labels they do not want to consider with -100.
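A tiny check of that default behavior (random logits, purely illustrative):

```python
import torch

loss_fct = torch.nn.CrossEntropyLoss()  # ignore_index defaults to -100
logits = torch.randn(3, 5)              # 3 positions, vocabulary of size 5
labels = torch.tensor([2, -100, 4])     # the middle position contributes nothing
loss = loss_fct(logits, labels)         # averaged over the two non-ignored positions
```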
we will discuss on zoom!
# in case the batch is shorter than max length, the output should be padded
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, self.model.config.max_length)

# compute loss on predict data
with torch.no_grad():
`generate()` is always in a `torch.no_grad()` context.
Looks good to me, but I'm not an expert on seq2seq.
All for aligning the signatures with the general `Trainer` though, thanks for doing that! We can have a `Seq2SeqTrainingArguments` that subclasses `TrainingArguments` if that helps.
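A hedged sketch of what such a subclass could look like (the field names here are illustrative, not a final API):

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments


@dataclass
class Seq2SeqTrainingArguments(TrainingArguments):
    eval_beams: Optional[int] = field(
        default=None, metadata={"help": "Number of beams to use for evaluation."}
    )
    eval_max_gen_length: Optional[int] = field(
        default=None, metadata={"help": "Maximum generation length used during evaluation."}
    )
```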
attention_mask=inputs["attention_mask"],
use_cache=True,
num_beams=self.data_args.eval_beams,
max_length=self.max_gen_length,
We need this if `eval_beams` and `max_length` differ from the defaults.
LGTM, thanks for aligning it! We just need some way to pass
@sgugger we do have
Ah, had forgotten about that :-)
Also would <3 an encoder decoder test in
@@ -230,7 +233,7 @@ def main():
     freeze_params(model.get_encoder())
     assert_all_frozen(model.get_encoder())

     dataset_class = Seq2SeqDataset if hasattr(tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
`prepare_seq2seq_batch` is now a method on `PretrainedTokenizer`, so this can never be `False`.
great catch
@@ -137,6 +136,10 @@ class DataTrainingArguments:
     src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
     tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
     eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."})
+    ignore_pad_token_for_loss: bool = field(
Set the default to `True` for backward compatibility.
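For illustration, a backward-compatible default could look roughly like this (the surrounding class is trimmed and the help text is a placeholder, not the PR's exact wording):

```python
from dataclasses import dataclass, field


@dataclass
class DataTrainingArguments:
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={"help": "Whether to ignore padded label tokens when computing the loss."},
    )
```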
thanks!
After discussion with @sshleifer - changed the
@sshleifer, @patil-suraj - could you do another review please? :-)
I think if we are trying to show that `config.pad_token_id` is not mandatory, we should add a test, even if that test does not use the command line interface. Sorry for being difficult.
    **gen_kwargs,
)
# in case the batch is shorter than max length, the output should be padded
if self.config.pad_token_id is not None:
I would expect this case to break. `_pad_tensors_to_max_len` is needed for some sort of `Trainer`/consistent-shapes reason, @patil-suraj.
Yes, `Trainer` expects all returned preds to be of the same shape, since it concatenates them after every eval batch.
I don't get it -> if `config.pad_token_id` is not defined, we cannot run `_pad_tensors_to_max_len`. How is this breaking anything? I am running all my experiments with no `pad_token_id` defined, and this case works.
Since `Trainer` concatenates the preds, I assume they should be of the same length across batches. It was breaking in my last experiment when not using `_pad_tensors_to_max_len`.
I think this is fine -> see the test I added for bert2bert. Such a model does not have a `self.config.pad_token_id` defined and still works.
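For reference, a rough illustration of what a helper like `_pad_tensors_to_max_len` has to accomplish (not the PR's exact implementation): pad every batch of generated ids to one fixed length so `Trainer` can concatenate predictions across eval batches.

```python
import torch


def pad_tensors_to_max_len(tensor: torch.Tensor, max_length: int, pad_token_id: int) -> torch.Tensor:
    # Allocate a (batch_size, max_length) tensor filled with the pad id and
    # copy the generated ids into its left part.
    padded = tensor.new_full((tensor.shape[0], max_length), pad_token_id)
    padded[:, : tensor.shape[-1]] = tensor
    return padded


generated = torch.tensor([[0, 11, 12, 2]])        # a batch whose generation stopped at length 4
padded = pad_tensors_to_max_len(generated, 8, 1)  # shape (1, 8), right-padded with id 1
```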
@@ -41,12 +41,21 @@
 class Seq2SeqTrainer(Trainer):
-    def __init__(self, config, data_args, *args, **kwargs):
+    def __init__(self, config=None, data_args=None, *args, **kwargs):
Make those variables optional to align better with `Trainer` and to keep 100% backwards compatibility.
examples/seq2seq/seq2seq_trainer.py
Outdated
# set all ids to -100 to be ignored
if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
    assert self.config.pad_token_id >= 0, "Make sure that `config.pad_token_id` is correctly defined"
    inputs["labels"][inputs["labels"] == self.config.pad_token_id] = -100
This keeps 100% backwards compatibility
Will this cause a TPU issue, @LysandreJik?
Why would this cause a TPU issue? All of our models work with -100 to ignore the CE loss.
I think some tensor manipulations/assignments on TPU require sending the tensor back to the CPU to do the op and then returning it to the TPU. Lys told me it was bad to assign -100 in `inputs['labels']` for that reason. In that case we could do this in the collator, I guess.
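A sketch of the collator alternative: mask the pad tokens on CPU when the batch is assembled, instead of inside the training step on the TPU (the function and argument names are made up for illustration):

```python
import torch


def collate_with_ignored_padding(features, pad_token_id):
    # `features` is assumed to be a list of dicts of equally sized 1-D tensors
    batch = {key: torch.stack([f[key] for f in features]) for key in features[0]}
    labels = batch["labels"].clone()
    labels[labels == pad_token_id] = -100  # done on CPU, before the batch reaches the device
    batch["labels"] = labels
    return batch
```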
Thanks for the test!
Two comments about moving asserts to `__init__` for quicker failure.
Most important is whether the line I tagged Lysandre on causes a TPU slowdown.
Should be good - I don't really see how -100 would slow down the TPU, but let's wait for @LysandreJik's opinion here.
Can't seem to reply to the comment, but yes, the line @sshleifer is pointing at will slow down on TPU since it's probably using a
Okay, I see -> let's move back to the old CE loss function then, to keep backward compatibility! @sshleifer - one last review please :-)
else:
    # compute label smoothed loss
    labels = inputs.pop("labels")
    logits = model(**inputs, use_cache=False)[0]
I think `use_cache=False` everywhere or nowhere.
Removed it - I think it's better this way, to not give the false impression that `use_cache=True` will break training. All models have `use_cache=True` by default and training works by default. It's all about whether `past_key_values` are passed in or not.
Oh this actually breaks a test - it shouldn't. This is related to this Bart bug we never solved: #6353 :-/
Will add `use_cache=False` again for now and remove it when fixing the bug in Bart.
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id

train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
this is cool.
cc @stas00 if you ever want to add more training data to a unit-test.
    return batch


def _compute_metrics(pred):
FYI, by default you will get rouge1, rouge2, rougeL (if you don't overwrite `compute_metrics`).
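If you do want a custom one, a sketch along these lines should work (assumes the `rouge_score` package; `pred` follows `Trainer`'s `EvalPrediction` with `predictions` and `label_ids` as token-id arrays, and the tokenizer is passed in separately):

```python
from rouge_score import rouge_scorer


def compute_rouge(pred, tokenizer):
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = tokenizer.pad_token_id  # undo any -100 masking before decoding
    pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    scores = [scorer.score(ref, hyp) for ref, hyp in zip(label_str, pred_str)]
    return {key: sum(s[key].fmeasure for s in scores) / len(scores) for key in ["rouge1", "rouge2", "rougeL"]}
```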
What does this PR do?
This PR changes the Seq2Seq Trainer a bit to:
- allow `EncoderDecoder` models to be trained with it
- align it more closely with the general `Trainer`

@sshleifer @patil-suraj @sgugger - it would be great if you could take a look and give your general opinion on it :-)
If this is ok for you, I will fix the examples test.