
Commit

Output cross-attention with decoder attention output
Yossi Synett committed Nov 2, 2020
1 parent 93354bc commit ef39d52
Showing 10 changed files with 384 additions and 58 deletions.
36 changes: 24 additions & 12 deletions src/transformers/modeling_bert.py
@@ -37,9 +37,9 @@
replace_return_docstrings,
)
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
CausalLMOutput,
BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
@@ -449,7 +449,8 @@ def forward(
return_dict=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -483,15 +484,24 @@ def custom_forward(*inputs):
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
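For reference, a minimal sketch of the output class this hunk switches to. The tensors and shapes below are made up for illustration; it only shows that cross_attentions is a regular field of BaseModelOutputWithCrossAttentions and that to_tuple() keeps the same ordering the return_dict=False branch builds by hand.

import torch
from transformers.modeling_outputs import BaseModelOutputWithCrossAttentions

# Dummy tensors standing in for a single layer's outputs (hypothetical shapes).
hidden = torch.zeros(1, 4, 8)         # (batch, tgt_len, hidden_size)
self_attn = torch.zeros(1, 2, 4, 4)   # (batch, heads, tgt_len, tgt_len)
cross_attn = torch.zeros(1, 2, 4, 6)  # (batch, heads, tgt_len, src_len)

out = BaseModelOutputWithCrossAttentions(
    last_hidden_state=hidden,
    hidden_states=(hidden,),
    attentions=(self_attn,),
    cross_attentions=(cross_attn,),
)
print(out.cross_attentions[0].shape)  # torch.Size([1, 2, 4, 6])
# Field order matches the hand-built tuple in the return_dict=False branch:
# (last_hidden_state, hidden_states, attentions, cross_attentions)
print(len(out.to_tuple()))            # 4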


@@ -752,7 +762,7 @@ class PreTrainedModel
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=BaseModelOutputWithPooling,
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
@@ -843,11 +853,12 @@ def forward(
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]

return BaseModelOutputWithPooling(
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
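A hedged usage sketch of the new cross_attentions field on BertModel's output. The tiny config, random weights, and shapes are illustrative only, not taken from the commit.

import torch
from transformers import BertConfig, BertModel

config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=64,
    is_decoder=True, add_cross_attention=True,
)
model = BertModel(config)

input_ids = torch.randint(0, 100, (1, 5))      # decoder-side tokens
encoder_hidden_states = torch.rand(1, 7, 32)   # stand-in encoder output

outputs = model(
    input_ids,
    encoder_hidden_states=encoder_hidden_states,
    output_attentions=True,
    return_dict=True,
)
print(len(outputs.attentions))            # 2, one self-attention map per layer
print(outputs.cross_attentions[0].shape)  # torch.Size([1, 2, 5, 7]): (batch, heads, tgt_len, src_len)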


@@ -984,7 +995,7 @@ def get_output_embeddings(self):
return self.cls.predictions.decoder

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@@ -1063,11 +1074,12 @@ def forward(
output = (prediction_scores,) + outputs[2:]
return ((lm_loss,) + output) if lm_loss is not None else output

return CausalLMOutput(
return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
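The LM-head variant exposes the same field through CausalLMOutputWithCrossAttentions. A minimal sketch along the same lines (illustrative sizes, random weights):

import torch
from transformers import BertConfig, BertLMHeadModel

config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=64,
    is_decoder=True, add_cross_attention=True,
)
model = BertLMHeadModel(config)

input_ids = torch.randint(0, 100, (1, 5))
encoder_hidden_states = torch.rand(1, 7, 32)

out = model(
    input_ids,
    encoder_hidden_states=encoder_hidden_states,
    labels=input_ids,              # causal LM loss on the same ids
    output_attentions=True,
    return_dict=True,
)
print(out.logits.shape)            # torch.Size([1, 5, 100])
print(len(out.cross_attentions))   # 2, one cross-attention map per layer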
12 changes: 7 additions & 5 deletions src/transformers/modeling_bert_generation.py
@@ -28,7 +28,7 @@
replace_return_docstrings,
)
from .modeling_bert import BertEncoder
from .modeling_outputs import BaseModelOutput, CausalLMOutput
from .modeling_outputs import BaseModelOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions
from .modeling_utils import PreTrainedModel
from .utils import logging

@@ -297,7 +297,7 @@ class PreTrainedModel
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/bert_for_seq_generation_L-24_bbc_encoder",
output_type=BaseModelOutput,
output_type=BaseModelOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
@@ -381,10 +381,11 @@ def forward(
if not return_dict:
return (sequence_output,) + encoder_outputs[1:]

return BaseModelOutput(
return BaseModelOutputWithCrossAttentions(
last_hidden_state=sequence_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)


@@ -422,7 +423,7 @@ def get_output_embeddings(self):
return self.lm_head.decoder

@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@@ -499,11 +500,12 @@ def forward(
output = (prediction_scores,) + outputs[1:]
return ((lm_loss,) + output) if lm_loss is not None else output

return CausalLMOutput(
return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
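The same pattern for the bert-generation decoder, sketched with a tiny illustrative config (random weights; the sizes are not from the commit):

import torch
from transformers import BertGenerationConfig, BertGenerationDecoder

config = BertGenerationConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=64,
    is_decoder=True, add_cross_attention=True,
)
model = BertGenerationDecoder(config)

input_ids = torch.randint(0, 100, (1, 5))
encoder_hidden_states = torch.rand(1, 7, 32)

out = model(
    input_ids,
    encoder_hidden_states=encoder_hidden_states,
    output_attentions=True,
    return_dict=True,
)
print(out.logits.shape)            # torch.Size([1, 5, 100])
print(len(out.cross_attentions))   # 2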
24 changes: 17 additions & 7 deletions src/transformers/modeling_electra.py
@@ -34,7 +34,7 @@
replace_return_docstrings,
)
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
@@ -445,7 +445,8 @@ def forward(
return_dict=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -479,15 +480,24 @@ def custom_forward(*inputs):
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)


@@ -697,7 +707,7 @@ class PreTrainedModel
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/electra-small-discriminator",
output_type=BaseModelOutput,
output_type=BaseModelOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
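ELECTRA gets the same encoder bookkeeping, and the documented output type becomes BaseModelOutputWithCrossAttentions, so the field should always be present on the returned object, just None when cross-attention is not configured. A hedged sketch with an illustrative config:

import torch
from transformers import ElectraConfig, ElectraModel

config = ElectraConfig(
    vocab_size=100, embedding_size=16, hidden_size=32,
    num_hidden_layers=2, num_attention_heads=2, intermediate_size=64,
)
model = ElectraModel(config)

input_ids = torch.randint(0, 100, (1, 6))
outputs = model(input_ids, output_attentions=True, return_dict=True)

print(len(outputs.attentions))   # 2 self-attention maps
print(outputs.cross_attentions)  # None: add_cross_attention is off by default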
8 changes: 6 additions & 2 deletions src/transformers/modeling_encoder_decoder.py
@@ -110,8 +110,11 @@
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
Whether or not to return the attentions tensors of all encoder attention layers. See ``encoder_attentions``
under returned tensors for more detail.
decoder_output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all decoder attention layers, including cross-attention layers.
See ``decoder_attentions`` and ``cross_attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
@@ -426,6 +429,7 @@ def forward(
past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
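Putting it together for the encoder-decoder wrapper. This sketch builds two tiny random-weight BERT configs and reads cross_attentions off the Seq2SeqLMOutput; it assumes EncoderDecoderConfig.from_encoder_decoder_configs is available and sets output_attentions=True on the sub-configs rather than relying on how the wrapper forwards the flag. All names and sizes are illustrative.

import torch
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

enc_cfg = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                     num_attention_heads=2, intermediate_size=64,
                     output_attentions=True)
dec_cfg = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                     num_attention_heads=2, intermediate_size=64,
                     is_decoder=True, add_cross_attention=True,
                     output_attentions=True)
config = EncoderDecoderConfig.from_encoder_decoder_configs(enc_cfg, dec_cfg)
model = EncoderDecoderModel(config=config)

input_ids = torch.randint(0, 100, (1, 7))          # encoder tokens
decoder_input_ids = torch.randint(0, 100, (1, 5))  # decoder tokens

outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, return_dict=True)
print(len(outputs.decoder_attentions))    # 2, decoder self-attention per layer
print(outputs.cross_attentions[0].shape)  # torch.Size([1, 2, 5, 7])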
31 changes: 20 additions & 11 deletions src/transformers/modeling_gpt2.py
@@ -33,7 +33,11 @@
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from .modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithPastAndCrossAttentions,
SequenceClassifierOutputWithPast,
)
from .modeling_utils import (
Conv1D,
PreTrainedModel,
@@ -311,14 +315,14 @@ def forward(
attn_output = cross_attn_outputs[0]
# residual connection
hidden_states = hidden_states + attn_output
outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights

feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
# residual connection
hidden_states = hidden_states + feed_forward_hidden_states

outputs = [hidden_states] + outputs
return outputs # hidden_states, present, (cross_attentions, attentions)
return outputs # hidden_states, present, (attentions, cross_attentions)


class GPT2PreTrainedModel(PreTrainedModel):
@@ -506,7 +510,7 @@ def _prune_heads(self, heads_to_prune):
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="gpt2",
output_type=BaseModelOutputWithPast,
output_type=BaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
@@ -618,7 +622,8 @@ def forward(
output_shape = input_shape + (hidden_states.size(-1),)

presents = () if use_cache else None
all_attentions = () if output_attentions else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
@@ -659,7 +664,9 @@ def custom_forward(*inputs):
presents = presents + (present,)

if output_attentions:
all_attentions = all_attentions + (outputs[2],)
all_self_attentions = all_self_attentions + (outputs[2],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3],)

hidden_states = self.ln_f(hidden_states)

@@ -669,13 +676,14 @@ def custom_forward(*inputs):
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
return tuple(
v
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)

return BaseModelOutputWithPast(
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_attentions,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)


@@ -727,7 +735,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="gpt2",
output_type=CausalLMOutputWithPast,
output_type=CausalLMOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
Expand Down Expand Up @@ -795,12 +803,13 @@ def forward(
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return CausalLMOutputWithPast(
return CausalLMOutputWithPastAndCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
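With add_cross_attention=True on the config, the GPT-2 stack exposes the same field via CausalLMOutputWithPastAndCrossAttentions. A minimal sketch, again with made-up sizes and random weights:

import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=100, n_embd=32, n_layer=2, n_head=2,
                    add_cross_attention=True)
model = GPT2LMHeadModel(config)

input_ids = torch.randint(0, 100, (1, 5))
encoder_hidden_states = torch.rand(1, 7, 32)

out = model(
    input_ids,
    encoder_hidden_states=encoder_hidden_states,
    output_attentions=True,
    return_dict=True,
)
print(len(out.attentions))            # 2 self-attention maps
print(out.cross_attentions[0].shape)  # torch.Size([1, 2, 5, 7])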


Remaining 5 changed files not shown.
