From 6880957810143e08c2e773639cc5a74841033115 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Mon, 25 Oct 2021 12:18:01 +0200 Subject: [PATCH 01/10] First draft --- .../modeling_encoder_decoder.py | 42 +++++++++++++++++- .../modeling_speech_encoder_decoder.py | 42 +++++++++++++++++- .../modeling_vision_encoder_decoder.py | 44 +++++++++++++++++-- 3 files changed, 120 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index f3473a53a802..5e83dbde5dff 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -17,6 +17,9 @@ from typing import Optional +import torch +from torch.nn import CrossEntropyLoss + from ...configuration_utils import PretrainedConfig from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from ...modeling_outputs import Seq2SeqLMOutput @@ -136,6 +139,21 @@ """ +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + @add_start_docstrings(ENCODER_DECODER_START_DOCSTRING) class EncoderDecoderModel(PreTrainedModel): r""" @@ -434,6 +452,12 @@ def forward( encoder_hidden_states = encoder_outputs[0] + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, @@ -450,11 +474,22 @@ def forward( **kwargs_decoder, ) + # Compute loss independent from decoder (as some shift the logits inside them) + loss = None + if labels is not None: + logits = decoder_outputs.logits if return_dict else decoder_outputs[1] + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.decoder.config.vocab_size), labels.view(-1)) + if not return_dict: - return decoder_outputs + encoder_outputs + return ( + (loss,) + decoder_outputs[1:] + encoder_outputs + if loss is not None + else decoder_outputs + encoder_outputs + ) return Seq2SeqLMOutput( - loss=decoder_outputs.loss, + loss=loss, logits=decoder_outputs.logits, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, @@ -465,6 +500,9 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + def prepare_inputs_for_generation( self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs ): diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 4207067a4659..954b3e948625 100644 --- 
a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -17,7 +17,9 @@ from typing import Optional +import torch from torch import nn +from torch.nn import CrossEntropyLoss from ...configuration_utils import PretrainedConfig from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings @@ -149,6 +151,21 @@ """ +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + @add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING) class SpeechEncoderDecoderModel(PreTrainedModel): r""" @@ -467,6 +484,12 @@ def forward( else: encoder_attention_mask = None + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, @@ -482,20 +505,35 @@ def forward( **kwargs_decoder, ) + # Compute loss independent from decoder (as some shift the logits inside them) + loss = None + if labels is not None: + logits = decoder_outputs.logits if return_dict else decoder_outputs[1] + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.decoder.config.vocab_size), labels.view(-1)) + if not return_dict: - return decoder_outputs + encoder_outputs + return ( + (loss,) + decoder_outputs[1:] + encoder_outputs + if loss is not None + else decoder_outputs + encoder_outputs + ) return Seq2SeqLMOutput( + loss=loss, logits=decoder_outputs.logits, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_hidden_states, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, ) + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + def prepare_inputs_for_generation( self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs ): diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 8476e5c74b7b..14cc75ef2af1 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -17,7 +17,9 @@ from typing import Optional +import torch from torch import nn +from torch.nn import CrossEntropyLoss from ...configuration_utils import PretrainedConfig from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings @@ 
-29,6 +31,21 @@ from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "VisionEncoderDecoderConfig" @@ -448,6 +465,12 @@ def forward( # else: encoder_attention_mask = None + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, @@ -455,7 +478,6 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, inputs_embeds=decoder_inputs_embeds, - labels=labels, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, @@ -464,21 +486,35 @@ def forward( **kwargs_decoder, ) + # Compute loss independent from decoder (as some shift the logits inside them) + loss = None + if labels is not None: + logits = decoder_outputs.logits if return_dict else decoder_outputs[1] + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.decoder.config.vocab_size), labels.view(-1)) + if not return_dict: - return decoder_outputs + encoder_outputs + return ( + (loss,) + decoder_outputs[1:] + encoder_outputs + if loss is not None + else decoder_outputs + encoder_outputs + ) return Seq2SeqLMOutput( - loss=decoder_outputs.loss, + loss=loss, logits=decoder_outputs.logits, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_hidden_states, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, ) + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + def prepare_inputs_for_generation( self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs ): From 599ea2e3be1bfd102068e576832fc2ac5bde4214 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Mon, 25 Oct 2021 15:40:45 +0200 Subject: [PATCH 02/10] Make tuple output more readable --- .../models/encoder_decoder/modeling_encoder_decoder.py | 10 ++++------ .../modeling_speech_encoder_decoder.py | 9 ++++----- .../modeling_vision_encoder_decoder.py | 9 ++++----- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 5e83dbde5dff..8796b1ce1ee3 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -465,7 +465,6 @@ def forward( 
encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=attention_mask, inputs_embeds=decoder_inputs_embeds, - labels=labels, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, @@ -482,11 +481,10 @@ def forward( loss = loss_fct(logits.view(-1, self.decoder.config.vocab_size), labels.view(-1)) if not return_dict: - return ( - (loss,) + decoder_outputs[1:] + encoder_outputs - if loss is not None - else decoder_outputs + encoder_outputs - ) + if loss is not None: + return (loss,) + decoder_outputs + encoder_outputs + else: + return decoder_outputs + encoder_outputs return Seq2SeqLMOutput( loss=loss, diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 954b3e948625..48d5afcd2778 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -513,11 +513,10 @@ def forward( loss = loss_fct(logits.view(-1, self.decoder.config.vocab_size), labels.view(-1)) if not return_dict: - return ( - (loss,) + decoder_outputs[1:] + encoder_outputs - if loss is not None - else decoder_outputs + encoder_outputs - ) + if loss is not None: + return (loss,) + decoder_outputs + encoder_outputs + else: + return decoder_outputs + encoder_outputs return Seq2SeqLMOutput( loss=loss, diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 14cc75ef2af1..dc328723ee69 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -494,11 +494,10 @@ def forward( loss = loss_fct(logits.view(-1, self.decoder.config.vocab_size), labels.view(-1)) if not return_dict: - return ( - (loss,) + decoder_outputs[1:] + encoder_outputs - if loss is not None - else decoder_outputs + encoder_outputs - ) + if loss is not None: + return (loss,) + decoder_outputs + encoder_outputs + else: + return decoder_outputs + encoder_outputs return Seq2SeqLMOutput( loss=loss, From 7f7ec3f916c195a2f52d5297f26cdb8e2952980c Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Mon, 25 Oct 2021 16:37:12 +0200 Subject: [PATCH 03/10] Replace assertions by value errors --- .../models/encoder_decoder/modeling_encoder_decoder.py | 3 ++- .../speech_encoder_decoder/modeling_speech_encoder_decoder.py | 3 ++- .../vision_encoder_decoder/modeling_vision_encoder_decoder.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 8796b1ce1ee3..5827db083acb 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -147,7 +147,8 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() shifted_input_ids[:, 0] = decoder_start_token_id - assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 
+ if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 48d5afcd2778..5253d2cbaa6b 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -159,7 +159,8 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() shifted_input_ids[:, 0] = decoder_start_token_id - assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index dc328723ee69..5857f6143a12 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -39,7 +39,8 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() shifted_input_ids[:, 0] = decoder_start_token_id - assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 
+ if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) From 8ce50c5d4a5f05a193b355103943ca65139ead9a Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 26 Oct 2021 13:17:20 +0200 Subject: [PATCH 04/10] Make it possible to predict_with_generate for vision and speech models --- src/transformers/trainer_seq2seq.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py index 1995677801d1..8cd2c1dd5d73 100644 --- a/src/transformers/trainer_seq2seq.py +++ b/src/transformers/trainer_seq2seq.py @@ -164,9 +164,13 @@ def prediction_step( "synced_gpus": True if is_deepspeed_zero3_enabled() else False, } + if self.tokenizer is not None: + generation_inputs = {k: v for k, v in inputs.items() if k in self.tokenizer.model_input_names} + else: + generation_inputs = inputs["input_ids"] + generated_tokens = self.model.generate( - inputs["input_ids"], - attention_mask=inputs["attention_mask"], + **generation_inputs, **gen_kwargs, ) # in case the batch is shorter than max length, the output should be padded From 9937637c0592a4cdc9ed9205a12ccabefc32ef25 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 26 Oct 2021 15:59:13 +0200 Subject: [PATCH 05/10] Adapt Seq2SeqTrainer to work with VisionEncoderDecoder/SpeechEncoderDecoder --- src/transformers/trainer_seq2seq.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py index 8cd2c1dd5d73..5e533a62a659 100644 --- a/src/transformers/trainer_seq2seq.py +++ b/src/transformers/trainer_seq2seq.py @@ -166,6 +166,8 @@ def prediction_step( if self.tokenizer is not None: generation_inputs = {k: v for k, v in inputs.items() if k in self.tokenizer.model_input_names} + # very ugly hack to make it work + generation_inputs["input_ids"] = generation_inputs.pop(self.tokenizer.model_input_names[0]) else: generation_inputs = inputs["input_ids"] @@ -201,15 +203,16 @@ def prediction_step( return (loss, generated_tokens, labels) def _pad_tensors_to_max_len(self, tensor, max_length): - if self.tokenizer is None: - raise ValueError( - f"Tensor need to be padded to `max_length={max_length}` but no tokenizer was passed when creating " - "this `Trainer`. Make sure to create your `Trainer` with the appropriate tokenizer." 
+ if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): + # If PAD token is not defined at least EOS token has to be defined + pad_token_id = ( + self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id ) - # If PAD token is not defined at least EOS token has to be defined - pad_token_id = ( - self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id - ) + else: + if self.model.config.pad_token_id is not None: + pad_token_id = self.model.config.pad_token_id + else: + raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") padded_tensor = pad_token_id * torch.ones( (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device From 2c21151f42735f34296be3a63b6d3b743d0b9699 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Wed, 27 Oct 2021 10:20:14 +0200 Subject: [PATCH 06/10] Add deprecation warning --- .../encoder_decoder/modeling_encoder_decoder.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 5827db083acb..86e73c292700 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -14,7 +14,7 @@ # limitations under the License. """ Classes to support Encoder-Decoder architectures """ - +import warnings from typing import Optional import torch @@ -32,6 +32,13 @@ _CONFIG_FOR_DOC = "EncoderDecoderConfig" +DEPRECATION_WARNING = ( + "Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the " + "encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning " + "a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the labels, no " + "need to pass them yourself anymore." +) + ENCODER_DECODER_START_DOCSTRING = r""" This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the encoder and any pretrained autoregressive model as the decoder. 
The encoder is loaded via @@ -145,10 +152,12 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start """ shifted_input_ids = input_ids.new_zeros(input_ids.shape) shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + if decoder_start_token_id is None: + raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") shifted_input_ids[:, 0] = decoder_start_token_id if pad_token_id is None: - raise ValueError("self.model.config.pad_token_id has to be defined.") + raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) @@ -477,6 +486,7 @@ def forward( # Compute loss independent from decoder (as some shift the logits inside them) loss = None if labels is not None: + warnings.warn(DEPRECATION_WARNING, FutureWarning) logits = decoder_outputs.logits if return_dict else decoder_outputs[1] loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.decoder.config.vocab_size), labels.view(-1)) From 554b47fbc7f426af5cf04ce7c2d27019bc07636f Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Wed, 27 Oct 2021 10:46:09 +0200 Subject: [PATCH 07/10] Add copied from statements to vision and speech encoder decoders --- .../modeling_speech_encoder_decoder.py | 5 ++++- .../modeling_vision_encoder_decoder.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 5253d2cbaa6b..7026786c8a5b 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -151,16 +151,19 @@ """ +# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ Shift input ids one token to the right. """ shifted_input_ids = input_ids.new_zeros(input_ids.shape) shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + if decoder_start_token_id is None: + raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") shifted_input_ids[:, 0] = decoder_start_token_id if pad_token_id is None: - raise ValueError("self.model.config.pad_token_id has to be defined.") + raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 5857f6143a12..976d805acffa 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -31,16 +31,19 @@ from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig +# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ Shift input ids one token to the right. 
""" shifted_input_ids = input_ids.new_zeros(input_ids.shape) shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + if decoder_start_token_id is None: + raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") shifted_input_ids[:, 0] = decoder_start_token_id if pad_token_id is None: - raise ValueError("self.model.config.pad_token_id has to be defined.") + raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) From dfbbceb82d3db3da1dfb5c5993e4e1cc50d368ec Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Wed, 27 Oct 2021 18:48:43 +0200 Subject: [PATCH 08/10] Fix failing test --- .../models/encoder_decoder/modeling_encoder_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 86e73c292700..04033472c0e7 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -489,7 +489,7 @@ def forward( warnings.warn(DEPRECATION_WARNING, FutureWarning) logits = decoder_outputs.logits if return_dict else decoder_outputs[1] loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.decoder.config.vocab_size), labels.view(-1)) + loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1)) if not return_dict: if loss is not None: From 263f54de9d813cdf560877f74b65d5e900947e2a Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Thu, 28 Oct 2021 14:37:46 +0200 Subject: [PATCH 09/10] Apply @patrickvonplaten's suggestion --- .../models/encoder_decoder/modeling_encoder_decoder.py | 9 ++++----- .../modeling_speech_encoder_decoder.py | 9 ++++----- .../modeling_vision_encoder_decoder.py | 9 ++++----- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 04033472c0e7..461501737a52 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -462,11 +462,10 @@ def forward( encoder_hidden_states = encoder_outputs[0] - if labels is not None: - if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id - ) + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) # Decode decoder_outputs = self.decoder( diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 7026786c8a5b..fa3f195d92b2 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -488,11 +488,10 @@ def forward( else: encoder_attention_mask = None - if labels is not None: - if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right( - labels, 
self.config.pad_token_id, self.config.decoder_start_token_id - ) + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) # Decode decoder_outputs = self.decoder( diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 976d805acffa..20a2d0ad2644 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -469,11 +469,10 @@ def forward( # else: encoder_attention_mask = None - if labels is not None: - if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id - ) + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) # Decode decoder_outputs = self.decoder( From 8aa8ce66e26ce2b619b299634f1a267e68605a59 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Thu, 28 Oct 2021 14:47:50 +0200 Subject: [PATCH 10/10] Use reshape instead of view for consistency --- .../speech_encoder_decoder/modeling_speech_encoder_decoder.py | 2 +- .../vision_encoder_decoder/modeling_vision_encoder_decoder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index fa3f195d92b2..7863e3681045 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -513,7 +513,7 @@ def forward( if labels is not None: logits = decoder_outputs.logits if return_dict else decoder_outputs[1] loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.decoder.config.vocab_size), labels.view(-1)) + loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1)) if not return_dict: if loss is not None: diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 20a2d0ad2644..4355b232113c 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -494,7 +494,7 @@ def forward( if labels is not None: logits = decoder_outputs.logits if return_dict else decoder_outputs[1] loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.decoder.config.vocab_size), labels.view(-1)) + loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1)) if not return_dict: if loss is not None:
                 return (loss,) + decoder_outputs + encoder_outputs
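Taken together, this series makes the composite EncoderDecoderModel, SpeechEncoderDecoderModel, and VisionEncoderDecoderModel compute the cross-entropy loss themselves and derive decoder_input_ids from labels via shift_tokens_right. A minimal sketch of training-time usage under the new behavior follows; the checkpoint name and example strings are illustrative assumptions, not part of the patches:

    from transformers import BertTokenizer, EncoderDecoderModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-uncased", "bert-base-uncased"
    )

    # Both ids must be set, or shift_tokens_right raises the
    # ValueErrors introduced in patches 03 and 06.
    model.config.decoder_start_token_id = tokenizer.cls_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    inputs = tokenizer("an example source text", return_tensors="pt")
    labels = tokenizer("an example target text", return_tensors="pt").input_ids

    # decoder_input_ids no longer need to be passed: they are created by
    # shifting the labels one token to the right, and the loss is computed
    # in the composite model rather than inside the decoder.
    outputs = model(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        labels=labels,
    )
    outputs.loss.backward()

Computing the loss in the composite model, rather than delegating to the decoder's own labels handling, keeps behavior uniform across decoder architectures, since some decoders shift the logits internally before computing their loss.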