From ef39d524c50eefe429dd6d9f37df723ad4c0e9bd Mon Sep 17 00:00:00 2001
From: Yossi Synett
Date: Mon, 26 Oct 2020 22:37:46 +0200
Subject: [PATCH] Output cross-attention with decoder attention output

---
 src/transformers/modeling_bert.py            |  36 ++--
 src/transformers/modeling_bert_generation.py |  12 +-
 src/transformers/modeling_electra.py         |  24 ++-
 src/transformers/modeling_encoder_decoder.py |   8 +-
 src/transformers/modeling_gpt2.py            |  31 ++-
 src/transformers/modeling_layoutlm.py        |  32 ++-
 src/transformers/modeling_outputs.py         | 202 ++++++++++++++++++-
 src/transformers/modeling_prophetnet.py      |   1 +
 src/transformers/modeling_roberta.py         |  36 ++--
 tests/test_modeling_encoder_decoder.py       |  60 ++++++
 10 files changed, 384 insertions(+), 58 deletions(-)

diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py
index 587e33695b6a13..1da07777e9218a 100755
--- a/src/transformers/modeling_bert.py
+++ b/src/transformers/modeling_bert.py
@@ -37,9 +37,9 @@
     replace_return_docstrings,
 )
 from .modeling_outputs import (
-    BaseModelOutput,
-    BaseModelOutputWithPooling,
-    CausalLMOutput,
+    BaseModelOutputWithCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
     MaskedLMOutput,
     MultipleChoiceModelOutput,
     NextSentencePredictorOutput,
@@ -449,7 +449,8 @@ def forward(
         return_dict=False,
     ):
         all_hidden_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -483,15 +484,24 @@ def custom_forward(*inputs):
                 )
             hidden_states = layer_outputs[0]
             if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
 
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+            return tuple(
+                v
+                for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithCrossAttentions(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
         )
 
 
@@ -752,7 +762,7 @@ class PreTrainedModel
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="bert-base-uncased",
-        output_type=BaseModelOutputWithPooling,
+        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
@@ -843,11 +853,12 @@ def forward(
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
 
-        return BaseModelOutputWithPooling(
+        return BaseModelOutputWithPoolingAndCrossAttentions(
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
         )
 
 
@@ -984,7 +995,7 @@ def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids=None,
@@ -1063,11 +1074,12 @@ def forward(
             output = (prediction_scores,) + outputs[2:]
             return ((lm_loss,) + output) if lm_loss is not None else output
 
-        return CausalLMOutput(
+        return CausalLMOutputWithCrossAttentions(
             loss=lm_loss,
             logits=prediction_scores,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
         )
 
     def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
diff --git a/src/transformers/modeling_bert_generation.py b/src/transformers/modeling_bert_generation.py
index f201c1bd85563f..8366f182bd748c 100755
--- a/src/transformers/modeling_bert_generation.py
+++ b/src/transformers/modeling_bert_generation.py
@@ -28,7 +28,7 @@
     replace_return_docstrings,
 )
 from .modeling_bert import BertEncoder
-from .modeling_outputs import BaseModelOutput, CausalLMOutput
+from .modeling_outputs import BaseModelOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions
 from .modeling_utils import PreTrainedModel
 from .utils import logging
 
@@ -297,7 +297,7 @@ class PreTrainedModel
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="google/bert_for_seq_generation_L-24_bbc_encoder",
-        output_type=BaseModelOutput,
+        output_type=BaseModelOutputWithCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
@@ -381,10 +381,11 @@ def forward(
         if not return_dict:
             return (sequence_output,) + encoder_outputs[1:]
 
-        return BaseModelOutput(
+        return BaseModelOutputWithCrossAttentions(
             last_hidden_state=sequence_output,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
         )
 
 
@@ -422,7 +423,7 @@ def get_output_embeddings(self):
         return self.lm_head.decoder
 
     @add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids=None,
@@ -499,11 +500,12 @@ def forward(
             output = (prediction_scores,) + outputs[1:]
             return ((lm_loss,) + output) if lm_loss is not None else output
 
-        return CausalLMOutput(
+        return CausalLMOutputWithCrossAttentions(
             loss=lm_loss,
             logits=prediction_scores,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
         )
 
     def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
diff --git a/src/transformers/modeling_electra.py b/src/transformers/modeling_electra.py
index 3d5a161460b0e0..e244ac1c55e39b 100644
--- a/src/transformers/modeling_electra.py
+++ b/src/transformers/modeling_electra.py
@@ -34,7 +34,7 @@
     replace_return_docstrings,
 )
 from .modeling_outputs import (
-    BaseModelOutput,
+    BaseModelOutputWithCrossAttentions,
     MaskedLMOutput,
     MultipleChoiceModelOutput,
     QuestionAnsweringModelOutput,
@@ -445,7 +445,8 @@ def forward(
         return_dict=False,
     ):
         all_hidden_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -479,15 +480,24 @@ def custom_forward(*inputs):
                 )
             hidden_states = layer_outputs[0]
             if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
 
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+            return tuple(
+                v
+                for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithCrossAttentions(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
         )
 
 
@@ -697,7 +707,7 @@ class PreTrainedModel
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="google/electra-small-discriminator",
-        output_type=BaseModelOutput,
+        output_type=BaseModelOutputWithCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py
index 8efd43f555262b..e4e073aaf22f3c 100644
--- a/src/transformers/modeling_encoder_decoder.py
+++ b/src/transformers/modeling_encoder_decoder.py
@@ -110,8 +110,9 @@
             If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
             decoding (see :obj:`past_key_values`).
         output_attentions (:obj:`bool`, `optional`):
-            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
-            tensors for more detail.
+            Whether or not to return the attentions tensors of all attention layers of both the encoder and the
+            decoder, including the decoder cross-attention layers. See ``encoder_attentions``, ``decoder_attentions``
+            and ``cross_attentions`` under returned tensors for more detail.
         output_hidden_states (:obj:`bool`, `optional`):
             Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
             more detail.
@@ -426,6 +427,7 @@ def forward(
             past_key_values=None,  # TODO(PVP) - need to implement cache for BERT, etc... before this works
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py
index 030ac24edbae5b..838ea248f1722e 100644
--- a/src/transformers/modeling_gpt2.py
+++ b/src/transformers/modeling_gpt2.py
@@ -33,7 +33,11 @@
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from .modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    CausalLMOutputWithPastAndCrossAttentions,
+    SequenceClassifierOutputWithPast,
+)
 from .modeling_utils import (
     Conv1D,
     PreTrainedModel,
@@ -311,14 +315,14 @@ def forward(
             attn_output = cross_attn_outputs[0]
             # residual connection
             hidden_states = hidden_states + attn_output
-            outputs = outputs + cross_attn_outputs[1:]  # add cross attentions if we output attention weights
+            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights
 
         feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
         # residual connection
         hidden_states = hidden_states + feed_forward_hidden_states
 
         outputs = [hidden_states] + outputs
-        return outputs  # hidden_states, present, (cross_attentions, attentions)
+        return outputs  # hidden_states, present, (attentions, cross_attentions)
 
 
 class GPT2PreTrainedModel(PreTrainedModel):
@@ -506,7 +510,7 @@ def _prune_heads(self, heads_to_prune):
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="gpt2",
-        output_type=BaseModelOutputWithPast,
+        output_type=BaseModelOutputWithPastAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
@@ -618,7 +622,8 @@ def forward(
         output_shape = input_shape + (hidden_states.size(-1),)
 
         presents = () if use_cache else None
-        all_attentions = () if output_attentions else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         all_hidden_states = () if output_hidden_states else None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
             if output_hidden_states:
@@ -659,7 +664,9 @@ def custom_forward(*inputs):
                 presents = presents + (present,)
 
             if output_attentions:
-                all_attentions = all_attentions + (outputs[2],)
+                all_self_attentions = all_self_attentions + (outputs[2],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (outputs[3],)
 
         hidden_states = self.ln_f(hidden_states)
 
@@ -669,13 +676,14 @@ def custom_forward(*inputs):
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
+            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
 
-        return BaseModelOutputWithPast(
+        return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
             past_key_values=presents,
             hidden_states=all_hidden_states,
-            attentions=all_attentions,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
         )
 
 
@@ -727,7 +735,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="gpt2",
-        output_type=CausalLMOutputWithPast,
+        output_type=CausalLMOutputWithPastAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
@@ -795,12 +803,13 @@ def forward(
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
 
-        return CausalLMOutputWithPast(
+        return CausalLMOutputWithPastAndCrossAttentions(
             loss=loss,
             logits=lm_logits,
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
+            cross_attentions=transformer_outputs.cross_attentions,
         )
 
 
diff --git a/src/transformers/modeling_layoutlm.py b/src/transformers/modeling_layoutlm.py
index 29ff2ce77c0e02..24126c0c00746d 100644
--- a/src/transformers/modeling_layoutlm.py
+++ b/src/transformers/modeling_layoutlm.py
@@ -24,7 +24,12 @@
 from .activations import ACT2FN
 from .configuration_layoutlm import LayoutLMConfig
 from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
-from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, TokenClassifierOutput
+from .modeling_outputs import (
+    BaseModelOutputWithCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    MaskedLMOutput,
+    TokenClassifierOutput,
+)
 from .modeling_utils import (
     PreTrainedModel,
     apply_chunking_to_forward,
@@ -374,7 +379,8 @@ def forward(
         return_dict=False,
     ):
         all_hidden_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -408,15 +414,24 @@ def custom_forward(*inputs):
                 )
             hidden_states = layer_outputs[0]
             if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
 
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+            return tuple(
+                v
+                for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithCrossAttentions(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
         )
 
 
@@ -611,7 +626,7 @@ class PreTrainedModel
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="layoutlm-base-uncased",
-        output_type=BaseModelOutputWithPooling,
+        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
@@ -716,11 +731,12 @@ def forward(
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
 
-        return BaseModelOutputWithPooling(
+        return BaseModelOutputWithPoolingAndCrossAttentions(
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
         )
 
 
diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py
index 26a41fc9d1a520..a61190bd9e9cb1 100644
--- a/src/transformers/modeling_outputs.py
+++ b/src/transformers/modeling_outputs.py
@@ -99,6 +99,120 @@ class BaseModelOutputWithPast(ModelOutput):
     attentions: Optional[Tuple[torch.FloatTensor]] = None
 
 
+@dataclass
+class BaseModelOutputWithCrossAttentions(ModelOutput):
+    """
+    Base class for model's outputs, with potential hidden states and attentions.
+
+    Args:
+        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Cross attentions weights after the attention softmax, used to compute the weighted average in the
+            cross-attention heads.
+    """
+
+    last_hidden_state: torch.FloatTensor
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
+    """
+    Base class for model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
+            Last layer hidden-state of the first token of the sequence (classification token) further processed by a
+            Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
+            prediction (classification) objective during pretraining.
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Cross attentions weights after the attention softmax, used to compute the weighted average in the
+            cross-attention heads.
+    """
+
+    last_hidden_state: torch.FloatTensor
+    pooler_output: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
+    """
+    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+
+            If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
+            1, hidden_size)` is output.
+        past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+            List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
+            batch_size, num_heads, sequence_length, embed_size_per_head)`).
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+            :obj:`past_key_values` input) to speed up sequential decoding.
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Cross attentions weights after the attention softmax, used to compute the weighted average in the
+            cross-attention heads.
+ """ + + last_hidden_state: torch.FloatTensor + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + @dataclass class Seq2SeqModelOutput(ModelOutput): """ @@ -217,6 +331,85 @@ class CausalLMOutputWithPast(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None +@dataclass +class CausalLMOutputWithCrossAttentions(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Language modeling loss (for next-token prediction). + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Cross attentions weights after the attention softmax, used to compute the weighted average in the + cross-attention heads. + """ + + loss: Optional[torch.FloatTensor] + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class CausalLMOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Language modeling loss (for next-token prediction). + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, + batch_size, num_heads, sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + :obj:`past_key_values` input) to speed up sequential decoding. 
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Cross attentions weights after the attention softmax, used to compute the weighted average in the
+            cross-attention heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    past_key_values: Optional[List[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
 @dataclass
 class SequenceClassifierOutputWithPast(ModelOutput):
     """
@@ -303,12 +496,18 @@ class Seq2SeqLMOutput(ModelOutput):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
 
             Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
         decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
             sequence_length, sequence_length)`.
 
             Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
             self-attention heads.
+        cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Cross-attentions weights of the decoder, after the attention softmax, used to compute the weighted average
+            in the cross-attention heads.
         encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
             Sequence of hidden-states at the output of the last layer of the encoder of the model.
         encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
@@ -329,6 +528,7 @@ class Seq2SeqLMOutput(ModelOutput):
     past_key_values: Optional[List[torch.FloatTensor]] = None
     decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
     encoder_last_hidden_state: Optional[torch.FloatTensor] = None
     encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
diff --git a/src/transformers/modeling_prophetnet.py b/src/transformers/modeling_prophetnet.py
index 417111c409bdf0..9d7b23eb5e237f 100644
--- a/src/transformers/modeling_prophetnet.py
+++ b/src/transformers/modeling_prophetnet.py
@@ -1986,6 +1986,7 @@ def forward(
             hidden_states_ngram=outputs.hidden_states_ngram,
             attentions=outputs.attentions,
             ngram_attentions=outputs.ngram_attentions,
+            cross_attentions=outputs.cross_attentions,
         )
 
     def _compute_loss(self, logits, labels):
diff --git a/src/transformers/modeling_roberta.py b/src/transformers/modeling_roberta.py
index 1f676b9fefcaaa..3bb3a79a23263c 100644
--- a/src/transformers/modeling_roberta.py
+++ b/src/transformers/modeling_roberta.py
@@ -31,9 +31,9 @@
     replace_return_docstrings,
 )
 from .modeling_outputs import (
-    BaseModelOutput,
-    BaseModelOutputWithPooling,
-    CausalLMOutput,
+    BaseModelOutputWithCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
     MaskedLMOutput,
     MultipleChoiceModelOutput,
     QuestionAnsweringModelOutput,
@@ -393,7 +393,8 @@ def forward(
         return_dict=False,
     ):
         all_hidden_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -427,15 +428,24 @@ def custom_forward(*inputs):
                 )
             hidden_states = layer_outputs[0]
             if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
 
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+            return tuple(
+                v
+                for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithCrossAttentions(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
         )
 
 
@@ -599,7 +609,7 @@ class PreTrainedModel
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="roberta-base",
-        output_type=BaseModelOutputWithPooling,
+        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
     # Copied from transformers.modeling_bert.BertModel.forward
@@ -689,11 +699,12 @@ def forward(
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
 
-        return BaseModelOutputWithPooling(
+        return BaseModelOutputWithPoolingAndCrossAttentions(
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
         )
 
 
@@ -719,7 +730,7 @@ def get_output_embeddings(self):
         return self.lm_head.decoder
 
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids=None,
@@ -799,11 +810,12 @@ def forward(
             output = (prediction_scores,) + outputs[2:]
             return ((lm_loss,) + output) if lm_loss is not None else output
 
-        return CausalLMOutput(
+        return CausalLMOutputWithCrossAttentions(
             loss=lm_loss,
             logits=prediction_scores,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
         )
 
     def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py
index 118e40761a718f..0dbab4634d040d 100644
--- a/tests/test_modeling_encoder_decoder.py
+++ b/tests/test_modeling_encoder_decoder.py
@@ -292,6 +292,62 @@ def check_encoder_decoder_model_labels(
             outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
         )
 
+    def check_encoder_decoder_model_output_attentions(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        labels,
+        **kwargs
+    ):
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        enc_dec_model.to(torch_device)
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            output_attentions=True,
+            return_dict=True,
+        )
+
+        encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
+        self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
+
+        self.assertListEqual(
+            list(encoder_attentions[0].shape[-3:]),
+            [config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]],
+        )
+
+        decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
+        num_decoder_layers = (
+            decoder_config.num_decoder_layers
+            if hasattr(decoder_config, "num_decoder_layers")
+            else decoder_config.num_hidden_layers
+        )
+        self.assertEqual(len(decoder_attentions), num_decoder_layers)
+
+        self.assertListEqual(
+            list(decoder_attentions[0].shape[-3:]),
+            [decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]],
+        )
+
+        cross_attentions = outputs_encoder_decoder["cross_attentions"]
+        self.assertEqual(len(cross_attentions), num_decoder_layers)
+
+        cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
+            1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
+        )
+        self.assertListEqual(
+            list(cross_attentions[0].shape[-3:]),
+            [decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]],
+        )
+
     def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
         encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
         enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@@ -413,6 +469,10 @@ def test_encoder_decoder_model_labels(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_labels(**input_ids_dict)
 
+    def test_encoder_decoder_model_output_attentions(self):
+        input_ids_dict = self.prepare_config_and_inputs()
+        self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
+
     def test_encoder_decoder_model_generate(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_generate(**input_ids_dict)
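
Usage sketch (not part of the patch): once a decoder is configured with `add_cross_attention=True` — as `EncoderDecoderModel.from_encoder_decoder_pretrained` does for its decoder — the new `cross_attentions` field can be read directly off the returned output object, mirroring what the new test asserts. A minimal example, assuming a transformers build with this patch applied; the checkpoint names are illustrative only.

    from transformers import BertTokenizer, EncoderDecoderModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # The decoder half is instantiated with is_decoder=True and add_cross_attention=True,
    # so its layers produce cross-attention weights over the encoder output.
    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

    input_ids = tokenizer("Encoder input text.", return_tensors="pt").input_ids
    decoder_input_ids = tokenizer("Decoder input.", return_tensors="pt").input_ids

    outputs = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        output_attentions=True,
        return_dict=True,
    )

    # One tuple entry per decoder layer; each tensor has shape
    # (batch_size, num_heads, decoder_sequence_length, encoder_sequence_length),
    # i.e. queries come from the decoder sequence and keys from the encoder sequence.
    print(len(outputs.cross_attentions))
    print(outputs.cross_attentions[0].shape)

The same field is exposed on the plain decoder outputs (`CausalLMOutputWithCrossAttentions`, `BaseModelOutputWithCrossAttentions`, and their `WithPast`/`WithPooling` variants) when the underlying config has `add_cross_attention=True` and `output_attentions=True` is passed.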