
Commit

Revert "[All Seq2Seq model + CLM models that can be used with EncoderDecoder] Add cross-attention weights to outputs (huggingface#8071)"

This reverts commit b390ebd.
fabiocapsouza authored Nov 15, 2020
1 parent 58bf103 commit f310c07
Showing 16 changed files with 85 additions and 653 deletions.
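In practical terms, this revert removes the cross_attentions field that huggingface#8071 had added to the affected model outputs; only the per-layer self-attention weights remain. A minimal sketch of the caller-visible difference, assuming a checkout at this commit (the checkpoint name and input text below are illustrative and not part of the commit):

import torch
from transformers import BertModel, BertTokenizer

# Illustrative checkpoint and input text; neither is part of the commit itself.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("cross-attention weights in outputs", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True, return_dict=True)

print(type(outputs).__name__)                # BaseModelOutputWithPooling again after the revert
print(len(outputs.attentions))               # one self-attention tensor per layer (12 for bert-base-uncased)
print(hasattr(outputs, "cross_attentions"))  # False: the reverted field is no longer defined
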
36 changes: 0 additions & 36 deletions docs/source/main_classes/output.rst
@@ -65,34 +65,12 @@ BaseModelOutputWithPooling
:members:


BaseModelOutputWithCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithCrossAttentions
:members:


BaseModelOutputWithPoolingAndCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions
:members:


BaseModelOutputWithPast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPast
:members:


BaseModelOutputWithPastAndCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions
:members:


Seq2SeqModelOutput
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -107,20 +85,6 @@ CausalLMOutput
:members:


CausalLMOutputWithCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
:members:


CausalLMOutputWithPastAndCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_outputs.CausalLMOutputWithPastAndCrossAttentions
:members:


CausalLMOutputWithPast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

28 changes: 7 additions & 21 deletions src/transformers/modeling_bart.py
@@ -35,7 +35,7 @@
)
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPast,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
@@ -451,12 +451,11 @@ def forward(
assert self.encoder_attn.cache_key != self.self_attn.cache_key
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
x, cross_attn_weights = self.encoder_attn(
x, _ = self.encoder_attn(
query=x,
key=encoder_hidden_states,
key_padding_mask=encoder_attn_mask,
layer_state=layer_state, # mutates layer state
output_attentions=output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
@@ -478,8 +477,7 @@ def forward(
x,
self_attn_weights,
layer_state,
cross_attn_weights,
) # layer_state = cache for decoding
) # just self_attn weights for now, following t5, layer_state = cache for decoding


class BartDecoder(nn.Module):
@@ -592,7 +590,6 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if output_attentions else None
next_decoder_cache: List[Dict] = []
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -604,7 +601,7 @@

layer_state = past_key_values[idx] if past_key_values is not None else None

x, layer_self_attn, layer_past, layer_cross_attn = decoder_layer(
x, layer_self_attn, layer_past = decoder_layer(
x,
encoder_hidden_states,
encoder_attn_mask=encoder_padding_mask,
@@ -619,7 +616,6 @@

if output_attentions:
all_self_attns += (layer_self_attn,)
all_cross_attentions += (layer_cross_attn,)

if self.layer_norm: # if config.add_final_layer_norm (mBART)
x = self.layer_norm(x)
@@ -632,15 +628,9 @@

next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v for v in [x, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=x,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns
)


@@ -944,7 +934,6 @@ def forward(
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
@@ -1089,7 +1078,6 @@ def forward(
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
@@ -1219,7 +1207,6 @@ def forward(
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
@@ -1330,7 +1317,6 @@ def forward(
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
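The BartModel changes above have the same caller-visible effect: the decoder returns BaseModelOutputWithPast again, and the Seq2SeqModelOutput assembled in BartModel.forward no longer carries a cross_attentions entry. A rough usage sketch under that assumption (the checkpoint name is illustrative):

import torch
from transformers import BartModel, BartTokenizer

# Illustrative checkpoint; any BART checkpoint behaves the same way for this check.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartModel.from_pretrained("facebook/bart-base")

inputs = tokenizer("The decoder still reports self-attention weights.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True, return_dict=True)

# Encoder and decoder self-attentions remain available as per-layer tuples ...
print(len(outputs.encoder_attentions), outputs.encoder_attentions[0].shape)
print(len(outputs.decoder_attentions), outputs.decoder_attentions[0].shape)

# ... while the decoder-over-encoder weights are dropped again by this revert.
print(hasattr(outputs, "cross_attentions"))  # False
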
36 changes: 12 additions & 24 deletions src/transformers/modeling_bert.py
@@ -37,9 +37,9 @@
replace_return_docstrings,
)
from .modeling_outputs import (
BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
BaseModelOutput,
BaseModelOutputWithPooling,
CausalLMOutput,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
@@ -449,8 +448,7 @@ def forward(
return_dict=False,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -484,24 +483,15 @@ def custom_forward(*inputs):
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
all_attentions = all_attentions + (layer_outputs[1],)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)


@@ -762,7 +752,7 @@ class PreTrainedModel
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
)
def forward(
@@ -853,12 +843,11 @@ def forward(
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]

return BaseModelOutputWithPoolingAndCrossAttentions(
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)


@@ -995,7 +984,7 @@ def get_output_embeddings(self):
return self.cls.predictions.decoder

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@@ -1074,12 +1063,11 @@ def forward(
output = (prediction_scores,) + outputs[2:]
return ((lm_loss,) + output) if lm_loss is not None else output

return CausalLMOutputWithCrossAttentions(
return CausalLMOutput(
loss=lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
12 changes: 5 additions & 7 deletions src/transformers/modeling_bert_generation.py
@@ -28,7 +28,7 @@
replace_return_docstrings,
)
from .modeling_bert import BertEncoder
from .modeling_outputs import BaseModelOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions
from .modeling_outputs import BaseModelOutput, CausalLMOutput
from .modeling_utils import PreTrainedModel
from .utils import logging

@@ -297,7 +297,7 @@ class PreTrainedModel
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/bert_for_seq_generation_L-24_bbc_encoder",
output_type=BaseModelOutputWithCrossAttentions,
output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
@@ -381,11 +381,10 @@ def forward(
if not return_dict:
return (sequence_output,) + encoder_outputs[1:]

return BaseModelOutputWithCrossAttentions(
return BaseModelOutput(
last_hidden_state=sequence_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)


@@ -423,7 +422,7 @@ def get_output_embeddings(self):
return self.lm_head.decoder

@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@@ -500,12 +499,11 @@ def forward(
output = (prediction_scores,) + outputs[1:]
return ((lm_loss,) + output) if lm_loss is not None else output

return CausalLMOutputWithCrossAttentions(
return CausalLMOutput(
loss=lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
24 changes: 7 additions & 17 deletions src/transformers/modeling_electra.py
@@ -34,7 +34,7 @@
replace_return_docstrings,
)
from .modeling_outputs import (
BaseModelOutputWithCrossAttentions,
BaseModelOutput,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
@@ -445,8 +445,7 @@ def forward(
return_dict=False,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -480,24 +479,15 @@ def custom_forward(*inputs):
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
all_attentions = all_attentions + (layer_outputs[1],)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)


@@ -707,7 +697,7 @@ class PreTrainedModel
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/electra-small-discriminator",
output_type=BaseModelOutputWithCrossAttentions,
output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
1 change: 0 additions & 1 deletion src/transformers/modeling_encoder_decoder.py
@@ -426,7 +426,6 @@ def forward(
past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
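EncoderDecoderModel simply forwards whatever attention tuples its decoder produced, so the same applies there. A hedged sketch of what remains inspectable after the revert (the BERT-to-BERT pairing and checkpoint names are placeholders):

import torch
from transformers import BertTokenizer, EncoderDecoderModel

# Illustrative BERT-to-BERT encoder-decoder; the checkpoints are placeholders.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

enc = tokenizer("cross attention weights", return_tensors="pt")
dec = tokenizer("are no longer returned", return_tensors="pt")

with torch.no_grad():
    outputs = model(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        decoder_input_ids=dec.input_ids,
        output_attentions=True,
        return_dict=True,
    )

# Encoder and decoder self-attention weights are still exposed per layer ...
print(len(outputs.encoder_attentions), len(outputs.decoder_attentions))
# ... but the cross-attention weights from huggingface#8071 are no longer part of the output.
print(hasattr(outputs, "cross_attentions"))  # False
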
(Diff for the remaining changed files not shown.)
