Skip to content

Commit

Permalink
[All Seq2Seq model + CLM models that can be used with EncoderDecoder]…
Browse files Browse the repository at this point in the history
… Add cross-attention weights to outputs (huggingface#8071)

* Output cross-attention with decoder attention output

* Update src/transformers/modeling_bert.py

* add cross-attention for t5 and bart as well

* fix tests

* correct typo in docs

* add Sylvain's and Sam's comments

* correct typo

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
  • Loading branch information
2 people authored and fabiocapsouza committed Nov 15, 2020
1 parent 239bfab commit b390ebd
Show file tree
Hide file tree
Showing 16 changed files with 653 additions and 85 deletions.
36 changes: 36 additions & 0 deletions docs/source/main_classes/output.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,34 @@ BaseModelOutputWithPooling
:members:


BaseModelOutputWithCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithCrossAttentions
:members:


BaseModelOutputWithPoolingAndCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions
:members:


BaseModelOutputWithPast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPast
:members:


BaseModelOutputWithPastAndCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions
:members:


Seq2SeqModelOutput
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand All @@ -85,6 +107,20 @@ CausalLMOutput
:members:


CausalLMOutputWithCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
:members:


CausalLMOutputWithPastAndCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_outputs.CausalLMOutputWithPastAndCrossAttentions
:members:


CausalLMOutputWithPast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
28 changes: 21 additions & 7 deletions src/transformers/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
Expand Down Expand Up @@ -451,11 +451,12 @@ def forward(
assert self.encoder_attn.cache_key != self.self_attn.cache_key
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
x, _ = self.encoder_attn(
x, cross_attn_weights = self.encoder_attn(
query=x,
key=encoder_hidden_states,
key_padding_mask=encoder_attn_mask,
layer_state=layer_state, # mutates layer state
output_attentions=output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
Expand All @@ -477,7 +478,8 @@ def forward(
x,
self_attn_weights,
layer_state,
) # just self_attn weights for now, following t5, layer_state = cache for decoding
cross_attn_weights,
) # layer_state = cache for decoding


class BartDecoder(nn.Module):
Expand Down Expand Up @@ -590,6 +592,7 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if output_attentions else None
next_decoder_cache: List[Dict] = []
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
Expand All @@ -601,7 +604,7 @@ def forward(

layer_state = past_key_values[idx] if past_key_values is not None else None

x, layer_self_attn, layer_past = decoder_layer(
x, layer_self_attn, layer_past, layer_cross_attn = decoder_layer(
x,
encoder_hidden_states,
encoder_attn_mask=encoder_padding_mask,
Expand All @@ -616,6 +619,7 @@ def forward(

if output_attentions:
all_self_attns += (layer_self_attn,)
all_cross_attentions += (layer_cross_attn,)

if self.layer_norm: # if config.add_final_layer_norm (mBART)
x = self.layer_norm(x)
Expand All @@ -628,9 +632,15 @@ def forward(

next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns
return tuple(
v for v in [x, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=x,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)


Expand Down Expand Up @@ -934,6 +944,7 @@ def forward(
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
Expand Down Expand Up @@ -1078,6 +1089,7 @@ def forward(
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
Expand Down Expand Up @@ -1207,6 +1219,7 @@ def forward(
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
Expand Down Expand Up @@ -1317,6 +1330,7 @@ def forward(
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
Expand Down
36 changes: 24 additions & 12 deletions src/transformers/modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
replace_return_docstrings,
)
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
CausalLMOutput,
BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
Expand Down Expand Up @@ -449,7 +449,8 @@ def forward(
return_dict=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
Expand Down Expand Up @@ -483,15 +484,24 @@ def custom_forward(*inputs):
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)


Expand Down Expand Up @@ -752,7 +762,7 @@ class PreTrainedModel
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=BaseModelOutputWithPooling,
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
Expand Down Expand Up @@ -843,11 +853,12 @@ def forward(
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]

return BaseModelOutputWithPooling(
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)


Expand Down Expand Up @@ -984,7 +995,7 @@ def get_output_embeddings(self):
return self.cls.predictions.decoder

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
Expand Down Expand Up @@ -1063,11 +1074,12 @@ def forward(
output = (prediction_scores,) + outputs[2:]
return ((lm_loss,) + output) if lm_loss is not None else output

return CausalLMOutput(
return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
Expand Down
12 changes: 7 additions & 5 deletions src/transformers/modeling_bert_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
replace_return_docstrings,
)
from .modeling_bert import BertEncoder
from .modeling_outputs import BaseModelOutput, CausalLMOutput
from .modeling_outputs import BaseModelOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions
from .modeling_utils import PreTrainedModel
from .utils import logging

Expand Down Expand Up @@ -297,7 +297,7 @@ class PreTrainedModel
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/bert_for_seq_generation_L-24_bbc_encoder",
output_type=BaseModelOutput,
output_type=BaseModelOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
Expand Down Expand Up @@ -381,10 +381,11 @@ def forward(
if not return_dict:
return (sequence_output,) + encoder_outputs[1:]

return BaseModelOutput(
return BaseModelOutputWithCrossAttentions(
last_hidden_state=sequence_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)


Expand Down Expand Up @@ -422,7 +423,7 @@ def get_output_embeddings(self):
return self.lm_head.decoder

@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
Expand Down Expand Up @@ -499,11 +500,12 @@ def forward(
output = (prediction_scores,) + outputs[1:]
return ((lm_loss,) + output) if lm_loss is not None else output

return CausalLMOutput(
return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
Expand Down
24 changes: 17 additions & 7 deletions src/transformers/modeling_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
replace_return_docstrings,
)
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
Expand Down Expand Up @@ -445,7 +445,8 @@ def forward(
return_dict=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
Expand Down Expand Up @@ -479,15 +480,24 @@ def custom_forward(*inputs):
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)


Expand Down Expand Up @@ -697,7 +707,7 @@ class PreTrainedModel
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/electra-small-discriminator",
output_type=BaseModelOutput,
output_type=BaseModelOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
Expand Down
1 change: 1 addition & 0 deletions src/transformers/modeling_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ def forward(
past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
Expand Down
Loading

0 comments on commit b390ebd

Please sign in to comment.