
Commit ac12a5a
Fix EncoderDecoderModel classes to be more like BART and T5 (#14139)
* First draft

* Make tuple output more readable

* Replace assertions by value errors

* Make it possible to use predict_with_generate for vision and speech models

* Adapt Seq2SeqTrainer to work with VisionEncoderDecoder/SpeechEncoderDecoder

* Add deprecation warning

* Add copied from statements to vision and speech encoder decoders

* Fix failing test

* Apply @patrickvonplaten's suggestion

* Use reshape instead of view for consistency
  • Loading branch information
NielsRogge authored Oct 28, 2021
1 parent 1251072 commit ac12a5a
Showing 4 changed files with 151 additions and 20 deletions.
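
The practical effect of this commit: an EncoderDecoderModel can now be fine-tuned the same way as BART or T5, by passing only `labels`; the decoder_input_ids are derived from the labels and the loss is computed inside the framework rather than in the decoder. A minimal sketch (the bert-base-uncased checkpoints are only illustrative):

    import torch
    from transformers import BertTokenizerFast, EncoderDecoderModel

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
    # Both ids must be set, otherwise shift_tokens_right raises a ValueError.
    model.config.decoder_start_token_id = tokenizer.cls_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    enc = tokenizer("An input sentence.", return_tensors="pt")
    labels = tokenizer("A target sentence.", return_tensors="pt").input_ids

    # No decoder_input_ids needed anymore: they are created from the labels.
    loss = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask, labels=labels).loss
    loss.backward()
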
src/transformers/models/encoder_decoder/modeling_encoder_decoder.py

@@ -14,9 +14,12 @@
# limitations under the License.
""" Classes to support Encoder-Decoder architectures """


import warnings
from typing import Optional

import torch
from torch.nn import CrossEntropyLoss

from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import Seq2SeqLMOutput
@@ -29,6 +32,13 @@

_CONFIG_FOR_DOC = "EncoderDecoderConfig"

DEPRECATION_WARNING = (
    "Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
    "encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
    "a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the labels, no "
    "need to pass them yourself anymore."
)

ENCODER_DECODER_START_DOCSTRING = r"""
This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
@@ -136,6 +146,24 @@
"""


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    if decoder_start_token_id is None:
        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids


@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
class EncoderDecoderModel(PreTrainedModel):
r"""
@@ -434,14 +462,18 @@ def forward(

        encoder_hidden_states = encoder_outputs[0]

        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            inputs_embeds=decoder_inputs_embeds,
-           labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
@@ -450,11 +482,22 @@
            **kwargs_decoder,
        )

        # Compute loss independent from decoder (as some shift the logits inside them)
        loss = None
        if labels is not None:
            warnings.warn(DEPRECATION_WARNING, FutureWarning)
            logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

        if not return_dict:
-           return decoder_outputs + encoder_outputs
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
-           loss=decoder_outputs.loss,
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
@@ -465,6 +508,9 @@
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
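
For intuition, here is what the shift_tokens_right helper added above does to a batch of labels. The token ids are illustrative (pad_token_id=0, decoder_start_token_id=2); -100 is the usual "ignore" index for the cross-entropy loss:

    import torch

    labels = torch.tensor([[5, 6, -100, -100]])
    shifted = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2)
    # The start token is prepended, every label moves one step right, the last
    # label drops off, and any -100 that was shifted in becomes the pad id:
    print(shifted)  # tensor([[2, 5, 6, 0]])

The new prepare_decoder_input_ids_from_labels method wraps the same helper; it is the hook that utilities such as DataCollatorForSeq2Seq look for when they need decoder inputs built from labels outside of forward.
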
src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py

@@ -17,7 +17,9 @@

from typing import Optional

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
@@ -149,6 +151,25 @@
"""


# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    if decoder_start_token_id is None:
        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids


@add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
class SpeechEncoderDecoderModel(PreTrainedModel):
r"""
@@ -467,6 +488,11 @@ def forward(
        else:
            encoder_attention_mask = None

        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
@@ -482,20 +508,34 @@
            **kwargs_decoder,
        )

        # Compute loss independent from decoder (as some shift the logits inside them)
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

        if not return_dict:
-           return decoder_outputs + encoder_outputs
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
-           encoder_last_hidden_state=encoder_hidden_states,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
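
With the same mechanism in place here, a SpeechEncoderDecoderModel can be trained from labels alone as well. A rough sketch, with illustrative checkpoint names and random audio standing in for real features:

    import torch
    from transformers import AutoTokenizer, SpeechEncoderDecoderModel

    model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
        "facebook/wav2vec2-base-960h", "bert-base-uncased"
    )
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model.config.decoder_start_token_id = tokenizer.cls_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    input_values = torch.randn(1, 16000)  # one second of fake 16 kHz audio
    labels = tokenizer("a transcription", return_tensors="pt").input_ids

    # decoder_input_ids are created internally from the labels.
    loss = model(input_values=input_values, labels=labels).loss
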
src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py

@@ -17,7 +17,9 @@

from typing import Optional

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
@@ -29,6 +31,25 @@
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig


# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    if decoder_start_token_id is None:
        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
@@ -448,14 +469,18 @@ def forward(
        # else:
        encoder_attention_mask = None

        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
-           labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
@@ -464,21 +489,34 @@
            **kwargs_decoder,
        )

        # Compute loss independent from decoder (as some shift the logits inside them)
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

        if not return_dict:
-           return decoder_outputs + encoder_outputs
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
-           loss=decoder_outputs.loss,
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
-           encoder_last_hidden_state=encoder_hidden_states,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
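
The vision model gets the same treatment, which is what makes image-captioning-style fine-tuning work with plain labels. A rough sketch (illustrative ViT encoder + GPT-2 decoder, random pixels in place of a real image):

    import torch
    from transformers import AutoTokenizer, VisionEncoderDecoderModel

    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        "google/vit-base-patch16-224-in21k", "gpt2"
    )
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.eos_token_id  # GPT-2 has no pad token

    pixel_values = torch.randn(1, 3, 224, 224)
    labels = tokenizer("a caption", return_tensors="pt").input_ids

    loss = model(pixel_values=pixel_values, labels=labels).loss
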
src/transformers/trainer_seq2seq.py (17 additions, 10 deletions)
@@ -164,9 +164,15 @@ def prediction_step(
"synced_gpus": True if is_deepspeed_zero3_enabled() else False,
}

        if self.tokenizer is not None:
            generation_inputs = {k: v for k, v in inputs.items() if k in self.tokenizer.model_input_names}
            # very ugly hack to make it work
            generation_inputs["input_ids"] = generation_inputs.pop(self.tokenizer.model_input_names[0])
        else:
            generation_inputs = inputs["input_ids"]

        generated_tokens = self.model.generate(
-           inputs["input_ids"],
-           attention_mask=inputs["attention_mask"],
            **generation_inputs,
            **gen_kwargs,
        )
        # in case the batch is shorter than max length, the output should be padded
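
The effect of this branch: when the Trainer is constructed with a tokenizer-like object (for speech or vision models this is typically the processor or feature extractor), generate() receives whatever inputs that object declares in model_input_names, instead of a hard-coded input_ids/attention_mask pair. A hedged sketch of how this is meant to be used; the model, datasets, and collator are assumed to exist:

    from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

    args = Seq2SeqTrainingArguments(output_dir="out", predict_with_generate=True)
    trainer = Seq2SeqTrainer(
        model=model,                  # e.g. a VisionEncoderDecoderModel
        args=args,
        train_dataset=train_ds,       # hypothetical datasets
        eval_dataset=eval_ds,
        tokenizer=feature_extractor,  # exposes model_input_names, e.g. ["pixel_values"]
        data_collator=collate_fn,     # hypothetical collator
    )
    metrics = trainer.evaluate()      # prediction_step routes pixel_values into generate()
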
@@ -197,15 +203,16 @@ def prediction_step(
        return (loss, generated_tokens, labels)

    def _pad_tensors_to_max_len(self, tensor, max_length):
-       if self.tokenizer is None:
-           raise ValueError(
-               f"Tensor need to be padded to `max_length={max_length}` but no tokenizer was passed when creating "
-               "this `Trainer`. Make sure to create your `Trainer` with the appropriate tokenizer."
-           )
        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
            # If PAD token is not defined at least EOS token has to be defined
            pad_token_id = (
                self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
            )
-       # If PAD token is not defined at least EOS token has to be defined
-       pad_token_id = (
-           self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
-       )
        else:
            if self.model.config.pad_token_id is not None:
                pad_token_id = self.model.config.pad_token_id
            else:
                raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")

        padded_tensor = pad_token_id * torch.ones(
            (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
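
For reference, a standalone illustration of the padded buffer being constructed here, assuming pad_token_id resolves to 0 and max_length=5 (the helper then copies the generated tokens into the front of this buffer):

    import torch

    pad_token_id, max_length = 0, 5
    tensor = torch.tensor([[5, 6, 7]])
    padded = pad_token_id * torch.ones((tensor.shape[0], max_length), dtype=tensor.dtype)
    padded[:, : tensor.shape[-1]] = tensor  # copy the short batch in front
    print(padded)  # tensor([[5, 6, 7, 0, 0]])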
