[Design proposal] Fix EncoderDecoderModel classes to be more like BART and T5 #14139

Merged (10 commits) on Oct 28, 2021

src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
@@ -14,9 +14,12 @@
# limitations under the License.
""" Classes to support Encoder-Decoder architectures """


import warnings
from typing import Optional

import torch
from torch.nn import CrossEntropyLoss

from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import Seq2SeqLMOutput
@@ -29,6 +32,13 @@

_CONFIG_FOR_DOC = "EncoderDecoderConfig"

DEPRECATION_WARNING = (
"Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
"encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
"a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the labels, no "
"need to pass them yourself anymore."
)
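
A minimal usage sketch of the behavior this warning describes (checkpoint names and example sentences
are illustrative, not part of the diff): from v4.12 on, passing only `labels` is enough; the
`decoder_input_ids` are derived from them via `shift_tokens_right` and the loss is computed in this class.

    from transformers import BertTokenizerFast, EncoderDecoderModel

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
    # both attributes are required by shift_tokens_right
    model.config.decoder_start_token_id = tokenizer.cls_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    enc = tokenizer("This is the source text.", return_tensors="pt")
    labels = tokenizer("This is the target text.", return_tensors="pt").input_ids
    outputs = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask, labels=labels)
    print(outputs.loss)  # decoder_input_ids were created from the labels internally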

ENCODER_DECODER_START_DOCSTRING = r"""
This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
@@ -136,6 +146,24 @@
"""


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
if decoder_start_token_id is None:
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
shifted_input_ids[:, 0] = decoder_start_token_id

if pad_token_id is None:
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

return shifted_input_ids
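
# Illustration with made-up token ids (decoder_start_token_id=101, pad_token_id=0):
#   labels              = [[5, 6, -100, -100]]
#   shift_tokens_right -> [[101, 5, 6, 0]]
# the start token is prepended, the labels are shifted one position to the right (the last one is
# dropped), and any -100 entries carried over from the labels are replaced by the pad token id.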


@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
class EncoderDecoderModel(PreTrainedModel):
r"""
@@ -434,14 +462,18 @@ def forward(

encoder_hidden_states = encoder_outputs[0]

if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)

# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=attention_mask,
inputs_embeds=decoder_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
@@ -450,11 +482,22 @@
**kwargs_decoder,
)

        # Compute the loss here rather than in the decoder, since some decoders shift the logits internally
loss = None
if labels is not None:
warnings.warn(DEPRECATION_WARNING, FutureWarning)
logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
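            # flatten logits to (batch_size * sequence_length, vocab_size) and labels to (batch_size * sequence_length,);
            # CrossEntropyLoss ignores label positions set to -100 (its default ignore_index)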
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

if not return_dict:
if loss is not None:
return (loss,) + decoder_outputs + encoder_outputs
else:
return decoder_outputs + encoder_outputs

return Seq2SeqLMOutput(
loss=loss,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
@@ -465,6 +508,9 @@
encoder_attentions=encoder_outputs.attentions,
)

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
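
    # Note: data collators such as `DataCollatorForSeq2Seq` call `prepare_decoder_input_ids_from_labels`
    # (when the model defines it) to build `decoder_input_ids` from the padded labels automatically.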

def prepare_inputs_for_generation(
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):

src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
@@ -17,7 +17,9 @@

from typing import Optional

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
@@ -149,6 +151,25 @@
"""


# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
if decoder_start_token_id is None:
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
shifted_input_ids[:, 0] = decoder_start_token_id

if pad_token_id is None:
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

return shifted_input_ids


@add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
class SpeechEncoderDecoderModel(PreTrainedModel):
r"""
@@ -467,6 +488,11 @@ def forward(
else:
encoder_attention_mask = None

if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)

# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
@@ -482,20 +508,34 @@
**kwargs_decoder,
)

        # Compute the loss here rather than in the decoder, since some decoders shift the logits internally
loss = None
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

if not return_dict:
if loss is not None:
return (loss,) + decoder_outputs + encoder_outputs
else:
return decoder_outputs + encoder_outputs

return Seq2SeqLMOutput(
loss=loss,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

def prepare_inputs_for_generation(
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):

src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
@@ -17,7 +17,9 @@

from typing import Optional

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
@@ -29,6 +31,25 @@
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig


# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
if decoder_start_token_id is None:
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
shifted_input_ids[:, 0] = decoder_start_token_id

if pad_token_id is None:
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

return shifted_input_ids


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
@@ -448,14 +469,18 @@ def forward(
# else:
encoder_attention_mask = None

if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)

# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
@@ -464,21 +489,34 @@
**kwargs_decoder,
)

        # Compute the loss here rather than in the decoder, since some decoders shift the logits internally
loss = None
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

if not return_dict:
if loss is not None:
return (loss,) + decoder_outputs + encoder_outputs
else:
return decoder_outputs + encoder_outputs

return Seq2SeqLMOutput(
loss=loss,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

def prepare_inputs_for_generation(
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):

src/transformers/trainer_seq2seq.py (27 changes: 17 additions & 10 deletions)
@@ -164,9 +164,15 @@ def prediction_step(
"synced_gpus": True if is_deepspeed_zero3_enabled() else False,
}

if self.tokenizer is not None:
generation_inputs = {k: v for k, v in inputs.items() if k in self.tokenizer.model_input_names}
            # hack: `generate` only accepts the main model input under the name `input_ids`, so rename the
            # model's first input (e.g. `input_values` for speech models) to that key
generation_inputs["input_ids"] = generation_inputs.pop(self.tokenizer.model_input_names[0])
else:
generation_inputs = inputs["input_ids"]

generated_tokens = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
**generation_inputs,
**gen_kwargs,
)
# in case the batch is shorter than max length, the output should be padded
@@ -197,15 +203,16 @@ def prediction_step(
return (loss, generated_tokens, labels)

def _pad_tensors_to_max_len(self, tensor, max_length):
        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
            # If PAD token is not defined at least EOS token has to be defined
            pad_token_id = (
                self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
            )
else:
if self.model.config.pad_token_id is not None:
pad_token_id = self.model.config.pad_token_id
else:
raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")

padded_tensor = pad_token_id * torch.ones(
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device