
[Seq2Seq] Allow EncoderDecoderModels to be trained with Seq2Seq #7809

Merged

Conversation

Contributor

@patrickvonplaten patrickvonplaten commented Oct 15, 2020

What does this PR do?

This PR changes the Seq2Seq Trainer a bit to:

  1. Make it work with EncoderDecoder
  2. Align its API more with the general Trainer

@sshleifer @patil-suraj @sgugger - it would be great if you could take a look and give your general opinion on it :-)
If this is OK with you, I will fix the examples test.

@@ -41,12 +41,13 @@


class Seq2SeqTrainer(Trainer):
def __init__(self, config, data_args, *args, **kwargs):
def __init__(self, *args, **kwargs):
Contributor Author

@patil-suraj @sshleifer - I think it would be better to align the init of Seq2SeqTrainer 100% with Trainer.
Is there a reason why we would pass in config instead of using the model's config?

Also, I don't really think the variable data_args is necessary. Both max_length and num_beams can be defined in the config and don't have to be force-passed to the generate() method.

Contributor

@patil-suraj patil-suraj Oct 15, 2020

model.config breaks under DistributedDataParallel, so we decided to pass it explicitly. See #7461 and #7460.

If the default num_beams and max_length are too high, they'll slow down evaluation, so we allow the user to control them during training. And we're not overriding the config, since the defaults will be needed for inference after training.

Contributor Author

Okay, I see! I'm a bit confused why Trainer does not break with DistributedDataParallel when only using model.config..., but Seq2SeqTrainer does. Do you guys know why?

Contributor Author

eval_beams/eval_max_gen_length reasoning:
@patil-suraj said exactly this LOL, but in my words:
users are not good at modifying configs locally. We want to have a way to run num_beams=2 during the generation step, but then end up with a trained model with the default # beams. In general, we try not to manipulate config attributes that would only be desired during training.

I mean, modifying the config locally is as simple as config.num_beams = 4, and I would think one wants to evaluate a model during training with exactly the beam size and max_length that are stored in the config (changing the beam size and max_length does not simply reduce time, it also changes the output...). But I can see the use case where people want to tweak max_length and num_beams without changing the config. Would it be fine to make data_args optional and call them generation_args, so that they are just passed as **generation_args to the generate function?

Contributor Author

Oops, that was supposed to land further below, not here. @sshleifer for reference.

Contributor

@sshleifer sshleifer Oct 16, 2020

data_args -> generation_kwargs seems like a good change (at least in seq2seq_trainer.py), but the CLI naming has a purpose:
It wouldn't have been obvious to me that passing --min_length 32 would affect generation, rather than truncating source docs. That's why the eval_ prefix was added.

if self.args.label_smoothing == 0:
# Same behavior as modeling_bart.py
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
Contributor Author

This does not seem to work for all models (EncoderDecoderModel does not work with it) -> Let's instead use the loss function of each model here.
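
For context, a minimal sketch of what "use the loss function of each model" could look like inside compute_loss, assuming inputs already contains labels; this is an illustration, not the exact code of this PR:

def compute_loss(model, inputs):
    # Let the model compute its own loss by passing `labels`; when labels are
    # provided, transformers models return the loss as the first output.
    outputs = model(**inputs)  # `inputs` is assumed to contain "labels"
    return outputs[0]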

Contributor

The models' loss functions use -100 as ignore_index, so we will also need to replace pad tokens in the labels with -100.

Contributor Author

I usually do this manually before -> should that be the role of the Seq2SeqTrainer? Trainer also does not have this feature

Contributor

Ignoring pad_token_id confused lots of people and helps metrics, so we automated it.
Related: #7828

Contributor

I usually do this manually before

We could do this in the collator, but we won't need to if #7828 is merged.

Contributor

@sshleifer sshleifer Oct 16, 2020

We will still need to cover FSMT/T5.
I would definitely not make this change right now; it works as is and is much easier than checking that every model ignores padding.

Contributor Author

PyTorch's CE loss function has -100 as the default ignore_index, and from what I understood it is the default behavior of the library to ignore tokens when they have the index -100, not when they are equal to the padding token (often we set padding token == -100): https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html

It would require models to manually replace tokens with -100, but I think that's how it should be done in general in the library. How would we handle models that don't have a padding_token or want to disregard the loss of more than just the padding token? For such cases I think it can be quite handy if the user overwrites all labels they do not want to consider with -100.
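
A minimal PyTorch illustration of that point (values are made up; pad_token_id is an assumption for the example):

import torch

pad_token_id = 0                                  # assumed for this example
logits = torch.randn(2, 5, 100)                   # (batch, seq_len, vocab_size)
labels = torch.randint(1, 100, (2, 5))
labels[:, -2:] = pad_token_id                     # pretend the last positions are padding

labels[labels == pad_token_id] = -100             # mask padding so the loss ignores it
loss_fct = torch.nn.CrossEntropyLoss()            # ignore_index defaults to -100
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))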

Contributor

we will discuss on zoom!

# in case the batch is shorter than max length, the output should be padded
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, self.model.config.max_length)

# compute loss on predict data
with torch.no_grad():
Contributor Author

generate() is always in torch.no_grad() context.

Collaborator

@sgugger sgugger left a comment

Looks good to me, but I'm not an expert on seq2seq.
I'm all for aligning the signatures with the general Trainer, though; thanks for doing that! We can have a Seq2SeqTrainingArguments that subclasses TrainingArguments if that helps.


attention_mask=inputs["attention_mask"],
use_cache=True,
num_beams=self.data_args.eval_beams,
max_length=self.max_gen_length,
Contributor

We need this if eval_beams and max_length are different from the defaults.
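
A sketch of what passing these overrides to generate() looks like; the attribute names follow the data args discussed in this thread, and the fallbacks to the model config are an assumption, not the verbatim PR code:

generated_tokens = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    num_beams=data_args.eval_beams or model.config.num_beams,
    max_length=data_args.eval_max_gen_length or model.config.max_length,
)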

@patil-suraj
Contributor

LGTM, thanks for aligning it! We just need some way to pass eval_beams and max_gen_length.

We can have a Seq2SeqTrainingArguments that subclasses TrainingArguments if that helps.

@sgugger we do have Seq2SeqTrainingArguments class

class Seq2SeqTrainingArguments(TrainingArguments):

@sgugger
Collaborator

sgugger commented Oct 15, 2020

@sgugger we do have Seq2SeqTrainingArguments class

Ah, had forgotten about that :-)

@sshleifer
Contributor

sshleifer commented Oct 15, 2020

eval_beams/eval_max_gen_length reasoning:
@patil-suraj said exactly this LOL, but in my words:
users are not good at modifying configs locally. We want to have a way to run num_beams=2 during the generation step, but then end up with a trained model with the default # beams. In general, we try not to manipulate config attributes that would only be desired during training.

@sshleifer
Contributor

Also would <3 an encoder decoder test in examples/seq2seq/test_finetune_trainer.py.

@patrickvonplaten patrickvonplaten changed the title [Examples] Align Seq2Seq Trainer with Trainer [Examples] Allow EncoderDecoderModels to be trained with Seq2Seq Oct 18, 2020
@@ -230,7 +233,7 @@ def main():
freeze_params(model.get_encoder())
assert_all_frozen(model.get_encoder())

dataset_class = Seq2SeqDataset if hasattr(tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
Contributor Author

prepare_seq2seq_batch is now defined as a method on PreTrainedTokenizer, so this can never be False.
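
Rough usage of prepare_seq2seq_batch around the time of this PR; the exact signature and returned keys may differ between versions, so treat this as a sketch:

batch = tokenizer.prepare_seq2seq_batch(
    src_texts=["A long news article ..."],
    tgt_texts=["A short summary ..."],
    max_length=512,
    max_target_length=64,
    return_tensors="pt",
)
# returns a BatchEncoding with the tokenized source inputs and target ids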

Contributor

great catch

@@ -137,6 +136,10 @@ class DataTrainingArguments:
src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."})
ignore_pad_token_for_loss: bool = field(
Contributor Author

Set to True for backward compatibility.

Contributor

thanks!

@patrickvonplaten
Contributor Author

After discussion with @sshleifer - changed the Seq2SeqTrainer to be fully backwards compatible and to work with EncoderDecoder.
@sshleifer - I cannot add an EncDec test yet because the complete command line setup is too constrained (it requires prepare_seq2seq_batch to be defined for all tokenizers, etc...) => will see how to add this in the future.

@sshleifer , @patil-suraj - could you do another review please? :-)

Contributor

@sshleifer sshleifer left a comment

I think if we are trying to show that config.pad_token_id is not mandatory, we should add a test, even if that test does not use the command line interface. Sorry for being difficult.

**gen_kwargs,
)
# in case the batch is shorter than max length, the output should be padded
if self.config.pad_token_id is not None:
Contributor

I would expect this case to break. _pad_tensors_to_max_len is needed for some sort of Trainer/consistent shapes reason @patil-suraj.

Contributor

Yes, Trainer expects all returned preds to be of the same shape, since it concatenates them across eval batches.
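
For reference, a minimal sketch of the idea behind _pad_tensors_to_max_len (not the exact implementation): right-pad every batch of generated ids to the same width so Trainer can concatenate them.

import torch

def pad_to_max_len(tensor: torch.Tensor, max_length: int, pad_token_id: int) -> torch.Tensor:
    # Create a (batch_size, max_length) tensor filled with the pad id and copy
    # the generated ids into its left part.
    padded = tensor.new_full((tensor.shape[0], max_length), pad_token_id)
    padded[:, : tensor.shape[-1]] = tensor
    return padded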

Contributor Author

I don't get it -> if config.pad_token_id is not defined we cannot run _pad_tensors_to_max_len. How is this breaking anything? I am running all my experiments with no pad_token_id defined, so this case works.

Contributor

Since Trainer concatenates the preds, I assume they should be of the same length across batches. It was breaking in my last experiment when not using _pad_tensors_to_max_len.

Contributor Author

I think this is fine -> see the test I added for bert2bert. Such a model does not have a self.config.pad_token_id defined and still works.

@@ -41,12 +41,21 @@


class Seq2SeqTrainer(Trainer):
def __init__(self, config, data_args, *args, **kwargs):
def __init__(self, config=None, data_args=None, *args, **kwargs):
Contributor Author

Make those variables optional to align better with Trainer and to keep 100% backwards compatibility
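
A simplified sketch of how such an optional signature can stay backwards compatible while falling back to the model's config (assumed, not the verbatim PR code):

from transformers import Trainer

class Seq2SeqTrainer(Trainer):
    def __init__(self, config=None, data_args=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fall back to the model's config when no config is passed explicitly.
        self.config = config if config is not None else self.model.config
        self.data_args = data_args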

# set all ids to -100 to be ignored
if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
assert self.config.pad_token_id >= 0, "Make sure that `config.pad_token_id` is correctly defined"
inputs["labels"][inputs["labels"] == self.config.pad_token_id] = -100
Contributor Author

This keeps 100% backwards compatibility

Contributor

Will this cause a TPU issue @LysandreJik ?

Contributor Author

Why would this cause a TPU issue? All of our models work with -100 to ignore CE loss

Contributor

I think some tensor manipulations/assignments on TPU require sending the tensor back to the CPU to do the op and then returning it to the TPU. Lys told me it was bad to assign -100 into inputs['labels'] for that reason. In that case we could do this in the collator, I guess.

Contributor

@sshleifer sshleifer left a comment

Thanks for the test!
Two comments about moving asserts to __init__ for quicker failure.
Most important is whether the line I tagged Lysandre on causes a TPU slowdown.



@patrickvonplaten
Contributor Author

Should be good - I don't really see how -100 would slow down the TPU, but let's wait for @LysandreJik's opinion here.

@sgugger
Collaborator

sgugger commented Oct 22, 2020

Can't seem to reply to the comment, but yes, the line @sshleifer is pointing at will slow down on TPU since it's probably using a torch.where behind the scenes, which does not have an XLA operation AFAIK.

@patrickvonplaten
Contributor Author

Can't seem to reply to the comment, but yes, the line @sshleifer is pointing at will slow down on TPU since it's probably using a torch.where behind the scenes, which does not have an XLA operation AFAIK.

Okay, I see -> let's move back to the old CE loss function then to keep backward compatibility!

@sshleifer - one last review please :-)

else:
# compute label smoothed loss
labels = inputs.pop("labels")
logits = model(**inputs, use_cache=False)[0]
Contributor

I think use_cache=False everywhere or nowhere

Contributor Author

Removed it - I think it's better this way, so as not to give the false impression that use_cache=True will break training. All models have use_cache=True by default and training works by default. It's all about whether past_key_values are inserted or not.

Contributor Author

Oh this actually breaks a test - it shouldn't. This is related to this Bart bug we never solved: #6353 :-/

Contributor Author

@patrickvonplaten patrickvonplaten Oct 23, 2020

will add use_cache=False again for now and remove it when fixing the bug in Bart.

bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id

train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
Contributor

This is cool.
cc @stas00 if you ever want to add more training data to a unit test.
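
For reference, the datasets split-slicing syntax makes it easy to keep such a test small; an absolute slice works just as well as a percentage (illustrative values):

import datasets

# Both percentage and absolute slices are valid split syntax.
train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:32]")
val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:8]")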


return batch

def _compute_metrics(pred):
Contributor

FYI, by default you will get rouge1, rouge2 and rougeL (if you don't override compute_metrics).
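
A sketch of a custom _compute_metrics using the rouge metric from the datasets library; tokenizer is assumed to be in scope and to have a pad token, and this is not the exact test code:

import datasets

rouge = datasets.load_metric("rouge")

def _compute_metrics(pred):
    labels = pred.label_ids
    labels[labels == -100] = tokenizer.pad_token_id  # undo ignore-index masking, if any
    pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
    scores = rouge.compute(predictions=pred_str, references=label_str)
    return {"rouge2_fmeasure": round(scores["rouge2"].mid.fmeasure, 4)}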

@patrickvonplaten patrickvonplaten merged commit 3c682ea into huggingface:master Oct 23, 2020
@patrickvonplaten patrickvonplaten deleted the adapt_seq2seq_trainer branch October 23, 2020 21:06
@patrickvonplaten patrickvonplaten changed the title [Examples] Allow EncoderDecoderModels to be trained with Seq2Seq [Seq2Seq] Allow EncoderDecoderModels to be trained with Seq2Seq Oct 26, 2020