From e325a49ae6d498d63acb5933d2deb00ea914db13 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Fri, 12 Mar 2021 22:53:29 +0100 Subject: [PATCH 1/5] Add cross_attn_head_mask to BART --- .../models/bart/modeling_tf_bart.py | 56 +++++++++++-------- tests/test_modeling_tf_bart.py | 6 +- tests/test_modeling_tf_common.py | 9 ++- 3 files changed, 46 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index ce67fc6541ff..5b3b1ecfb939 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -365,7 +365,7 @@ def call( encoder_hidden_states: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None, - encoder_layer_head_mask: Optional[tf.Tensor] = None, + cross_attn_layer_head_mask: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None, training=False, ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: @@ -379,8 +379,8 @@ def call( `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size `(decoder_attention_heads,)` - encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size - `(encoder_attention_heads,)` + cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module. + `(decoder_attention_heads,)` past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states """ residual = hidden_states @@ -410,7 +410,7 @@ def call( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=encoder_layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, ) hidden_states = self.dropout(hidden_states, training=training) @@ -572,7 +572,7 @@ def serving(self, inputs): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: @@ -580,6 +580,12 @@ def serving(self, inputs): - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tf.FloatTensor`, `optional`): hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of @@ -814,7 +820,7 @@ def call( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -858,12 +864,11 @@ def call( - 1 indicates the head is **not masked**, - 0 indicates the heas is **masked**. 
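A minimal usage sketch (not part of the diff) of the new `cross_attn_head_mask` argument documented above. `BartConfig` and `TFBartModel` are the existing transformers classes; the tiny config sizes and token ids below are arbitrary, and only the argument and output names come from this patch.

    import tensorflow as tf
    from transformers import BartConfig, TFBartModel

    # Small randomly initialized model; the sizes here are arbitrary.
    config = BartConfig(
        vocab_size=99,
        d_model=16,
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=4,
        decoder_attention_heads=4,
        encoder_ffn_dim=32,
        decoder_ffn_dim=32,
    )
    model = TFBartModel(config)

    input_ids = tf.constant([[5, 6, 7, 8, 2]])
    decoder_input_ids = tf.constant([[2, 5, 6]])

    # Shape (decoder_layers, decoder_attention_heads): 1 keeps a head, 0 masks it.
    cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
    # e.g. mask head 0 of the first decoder layer's cross-attention module
    cross_attn_head_mask = tf.tensor_scatter_nd_update(cross_attn_head_mask, [[0, 0]], [0.0])

    outputs = model(
        input_ids,
        decoder_input_ids=decoder_input_ids,
        cross_attn_head_mask=cross_attn_head_mask,
        output_attentions=True,
    )
    # One tensor per decoder layer, of shape (batch, num_heads, target_len, source_len).
    print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)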
- encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -894,7 +899,7 @@ def call( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, - encoder_head_mask=encoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, @@ -954,12 +959,13 @@ def call( # check if head_mask has a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA then they # have to be disabled in other modes than eager. - if inputs["head_mask"] is not None and tf.executing_eagerly(): - tf.debugging.assert_equal( - shape_list(inputs["head_mask"])[0], - len(self.layers), - message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", - ) + for attn_mask in ["head_mask", "cross_attn_head_mask"]: + if inputs[attn_mask] is not None and tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(inputs[attn_mask])[0], + len(self.layers), + message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.", + ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -979,8 +985,8 @@ def call( encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=inputs["encoder_attention_mask"], layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, - encoder_layer_head_mask=inputs["encoder_head_mask"][idx] - if inputs["encoder_head_mask"] is not None + cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx] + if inputs["cross_attn_head_mask"] is not None else None, past_key_value=past_key_value, ) @@ -1054,6 +1060,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1074,6 +1081,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1128,7 +1136,7 @@ def call( encoder_hidden_states=inputs["encoder_outputs"][0], encoder_attention_mask=inputs["attention_mask"], head_mask=inputs["decoder_head_mask"], - encoder_head_mask=inputs["head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], past_key_values=inputs["past_key_values"], 
inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], @@ -1183,6 +1191,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1202,7 +1211,7 @@ def call( decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, - decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1221,7 +1230,7 @@ def call( decoder_input_ids=inputs["decoder_input_ids"], decoder_attention_mask=inputs["decoder_attention_mask"], head_mask=inputs["head_mask"], - decoder_head_mask=inputs["decoder_head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], encoder_outputs=inputs["encoder_outputs"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], @@ -1301,6 +1310,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[TFBaseModelOutput] = None, past_key_values=None, inputs_embeds=None, @@ -1331,6 +1341,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1364,6 +1375,7 @@ def call( decoder_attention_mask=inputs["decoder_attention_mask"], head_mask=inputs["head_mask"], decoder_head_mask=inputs["decoder_head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"], diff --git a/tests/test_modeling_tf_bart.py b/tests/test_modeling_tf_bart.py index 3aef4c03f947..642fb4ab9d26 100644 --- a/tests/test_modeling_tf_bart.py +++ b/tests/test_modeling_tf_bart.py @@ -147,6 +147,7 @@ def prepare_bart_inputs_dict( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, ): if attention_mask is None: attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) @@ -162,13 +163,16 @@ def prepare_bart_inputs_dict( head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) if decoder_head_mask is None: decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) + if cross_attn_head_mask is None: + cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, "head_mask": head_mask, - "decoder_head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, } diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 0405192a6aaa..2c2eb1e546ac 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -172,8 +172,8 @@ def test_forward_signature(self): "decoder_attention_mask", ] expected_arg_names.extend( - ["head_mask", "decoder_head_mask", "encoder_outputs"] - if "head_mask" and "decoder_head_mask" in arg_names + ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"] + if "head_mask" and 
"decoder_head_mask" and "cross_attn_head_mask" in arg_names else ["encoder_outputs"] ) self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) @@ -491,6 +491,8 @@ def test_train_pipeline_custom_model(self): del inputs_dict["head_mask"] if "decoder_head_mask" in inputs_dict: del inputs_dict["decoder_head_mask"] + if "cross_attn_head_mask" in inputs_dict: + del inputs_dict["cross_attn_head_mask"] tf_main_layer_classes = set( module_member for model_class in self.all_model_classes @@ -712,6 +714,8 @@ def prepare_layer_head_mask(i, attention_heads, num_hidden_layers): arg_names = [*signature.parameters.keys()] if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model inputs["decoder_head_mask"] = head_mask + if "cross_attn_head_mask" in arg_names: + inputs["cross_attn_head_mask"] = head_mask outputs = model(**inputs, return_dict=True) @@ -736,6 +740,7 @@ def check_attentions_validity(attentions): if model.config.is_encoder_decoder: check_attentions_validity(outputs.encoder_attentions) check_attentions_validity(outputs.decoder_attentions) + check_attentions_validity(outputs.cross_attentions) else: check_attentions_validity(outputs.attentions) From 896bc461ef547f3667e1a0f1c533ddc6a1f7f2be Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 13 Mar 2021 12:02:51 +0100 Subject: [PATCH 2/5] Fix cross_attentions in TFBart-like models * This commit enables returning of `cross_attentions` for TFBart-like models * It also fixes attention head masking in cross-attenion module --- src/transformers/modeling_tf_outputs.py | 90 ++++++++++++++++ .../models/bart/modeling_tf_bart.py | 34 ++++-- .../blenderbot/modeling_tf_blenderbot.py | 100 ++++++++++++------ .../modeling_tf_blenderbot_small.py | 100 ++++++++++++------ .../models/marian/modeling_tf_marian.py | 100 ++++++++++++------ .../models/mbart/modeling_tf_mbart.py | 98 +++++++++++------ .../models/pegasus/modeling_tf_pegasus.py | 98 +++++++++++------ tests/test_modeling_tf_blenderbot.py | 4 + tests/test_modeling_tf_blenderbot_small.py | 4 + tests/test_modeling_tf_common.py | 13 ++- tests/test_modeling_tf_marian.py | 4 + tests/test_modeling_tf_mbart.py | 6 +- tests/test_modeling_tf_pegasus.py | 4 + 13 files changed, 469 insertions(+), 186 deletions(-) diff --git a/src/transformers/modeling_tf_outputs.py b/src/transformers/modeling_tf_outputs.py index 4c98106e3045..fefc65ec9b07 100644 --- a/src/transformers/modeling_tf_outputs.py +++ b/src/transformers/modeling_tf_outputs.py @@ -116,6 +116,82 @@ class TFBaseModelOutputWithPast(ModelOutput): attentions: Optional[Tuple[tf.Tensor]] = None +@dataclass +class TFBaseModelOutputWithCrossAttentions(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (:obj:`tuple(tf.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+ attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + attentions: Optional[Tuple[tf.Tensor]] = None + cross_attentions: Optional[Tuple[tf.Tensor]] = None + + +@dataclass +class TFBaseModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, + 1, hidden_size)` is output. + past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, + num_heads, sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + :obj:`past_key_values` input) to speed up sequential decoding. + hidden_states (:obj:`tuple(tf.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. 
+ """ + + last_hidden_state: tf.Tensor = None + past_key_values: Optional[List[tf.Tensor]] = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + attentions: Optional[Tuple[tf.Tensor]] = None + cross_attentions: Optional[Tuple[tf.Tensor]] = None + + @dataclass class TFSeq2SeqModelOutput(ModelOutput): """ @@ -145,6 +221,12 @@ class TFSeq2SeqModelOutput(ModelOutput): Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. + cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. encoder_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder of the model. encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): @@ -164,6 +246,7 @@ class TFSeq2SeqModelOutput(ModelOutput): past_key_values: Optional[List[tf.Tensor]] = None decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None decoder_attentions: Optional[Tuple[tf.Tensor]] = None + cross_attentions: Optional[Tuple[tf.Tensor]] = None encoder_last_hidden_state: Optional[tf.Tensor] = None encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None encoder_attentions: Optional[Tuple[tf.Tensor]] = None @@ -290,6 +373,12 @@ class TFSeq2SeqLMOutput(ModelOutput): Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. + cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. encoder_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder of the model. 
encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): @@ -310,6 +399,7 @@ class TFSeq2SeqLMOutput(ModelOutput): past_key_values: Optional[List[tf.Tensor]] = None decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None decoder_attentions: Optional[Tuple[tf.Tensor]] = None + cross_attentions: Optional[Tuple[tf.Tensor]] = None encoder_last_hidden_state: Optional[tf.Tensor] = None encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None encoder_attentions: Optional[Tuple[tf.Tensor]] = None diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 5b3b1ecfb939..846557aba8dd 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -30,7 +30,7 @@ ) from ...modeling_tf_outputs import ( TFBaseModelOutput, - TFBaseModelOutputWithPast, + TFBaseModelOutputWithPastAndCrossAttentions, TFSeq2SeqLMOutput, TFSeq2SeqModelOutput, ) @@ -401,12 +401,13 @@ def call( # Cross-Attention Block cross_attn_present_key_value = None + cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, _, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -432,6 +433,7 @@ def call( return ( hidden_states, self_attn_weights, + cross_attn_weights, present_key_value, ) @@ -683,7 +685,7 @@ def call( Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded @@ -862,7 +864,7 @@ def call( Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: @@ -954,9 +956,10 @@ def call( # decoder layers all_hidden_states = () if inputs["output_hidden_states"] else None all_self_attns = () if inputs["output_attentions"] else None + all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None present_key_values = () if inputs["use_cache"] else None - # check if head_mask has a correct number of layers specified if desired + # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA then they # have to be disabled in other modes than eager. 
for attn_mask in ["head_mask", "cross_attn_head_mask"]: @@ -979,7 +982,7 @@ def call( past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None - hidden_states, layer_self_attn, present_key_value = decoder_layer( + hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( hidden_states, attention_mask=combined_attention_mask, encoder_hidden_states=inputs["encoder_hidden_states"], @@ -997,23 +1000,30 @@ def call( if inputs["output_attentions"]: all_self_attns += (layer_self_attn,) + if inputs["encoder_hidden_states"] is not None: + all_cross_attns += (layer_cross_attn,) + if inputs["output_hidden_states"]: all_hidden_states += (hidden_states,) if inputs["output_attentions"]: all_self_attns = list(all_self_attns) + if inputs["encoder_hidden_states"] is not None: + all_cross_attns = list(all_cross_attns) + if inputs["use_cache"]: present_key_values = (inputs["encoder_hidden_states"], present_key_values) if not inputs["return_dict"]: - return hidden_states, present_key_values, all_hidden_states, all_self_attns + return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: - return TFBaseModelOutputWithPast( + return TFBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=present_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, + cross_attentions=all_cross_attns, ) @@ -1154,6 +1164,7 @@ def call( past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_attentions=inputs["encoder_outputs"].attentions, @@ -1211,6 +1222,7 @@ def call( decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, + decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, @@ -1230,6 +1242,7 @@ def call( decoder_input_ids=inputs["decoder_input_ids"], decoder_attention_mask=inputs["decoder_attention_mask"], head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], cross_attn_head_mask=inputs["cross_attn_head_mask"], encoder_outputs=inputs["encoder_outputs"], past_key_values=inputs["past_key_values"], @@ -1248,6 +1261,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.congig.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -1256,6 +1270,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, @@ -1398,6 +1413,7 @@ def call( past_key_values=outputs.past_key_values, # index 1 of d outputs 
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs + cross_attentions=outputs.cross_attentions, # index 4 of d outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out @@ -1407,6 +1423,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -1415,6 +1432,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index 42bcad541121..e9e07df2981d 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -32,7 +32,7 @@ ) from ...modeling_tf_outputs import ( TFBaseModelOutput, - TFBaseModelOutputWithPast, + TFBaseModelOutputWithPastAndCrossAttentions, TFSeq2SeqLMOutput, TFSeq2SeqModelOutput, ) @@ -370,7 +370,7 @@ def call( encoder_hidden_states: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None, - encoder_layer_head_mask: Optional[tf.Tensor] = None, + cross_attn_layer_head_mask: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None, training=False, ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: @@ -384,8 +384,8 @@ def call( `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size `(decoder_attention_heads,)` - encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size - `(encoder_attention_heads,)` + cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module. 
+ `(decoder_attention_heads,)` past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states """ residual = hidden_states @@ -406,17 +406,18 @@ def call( # Cross-Attention Block cross_attn_present_key_value = None + cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, _, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=encoder_layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, ) hidden_states = self.dropout(hidden_states, training=training) @@ -437,6 +438,7 @@ def call( return ( hidden_states, self_attn_weights, + cross_attn_weights, present_key_value, ) @@ -569,7 +571,7 @@ def serving(self, inputs): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: @@ -577,6 +579,12 @@ def serving(self, inputs): - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tf.FloatTensor`, `optional`): hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of @@ -674,7 +682,7 @@ def call( Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded @@ -818,7 +826,7 @@ def call( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -860,14 +868,13 @@ def call( Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. 
Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -904,7 +911,7 @@ def call( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, - encoder_head_mask=encoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, @@ -957,19 +964,21 @@ def call( hidden_states = self.dropout(hidden_states, training=inputs["training"]) # decoder layers - all_hidden_states = () - all_self_attns = () - present_key_values = () + all_hidden_states = () if inputs["output_hidden_states"] else None + all_self_attns = () if inputs["output_attentions"] else None + all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None + present_key_values = () if inputs["use_cache"] else None - # check if head_mask has a correct number of layers specified if desired + # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA then they # have to be disabled in other modes than eager. 
- if inputs["head_mask"] is not None and tf.executing_eagerly(): - tf.debugging.assert_equal( - shape_list(inputs["head_mask"])[0], - len(self.layers), - message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", - ) + for attn_mask in ["head_mask", "cross_attn_head_mask"]: + if inputs[attn_mask] is not None and tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(inputs[attn_mask])[0], + len(self.layers), + message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.", + ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -982,14 +991,14 @@ def call( past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None - hidden_states, layer_self_attn, present_key_value = decoder_layer( + hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( hidden_states, attention_mask=combined_attention_mask, encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=inputs["encoder_attention_mask"], layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, - encoder_layer_head_mask=inputs["encoder_head_mask"][idx] - if inputs["encoder_head_mask"] is not None + cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx] + if inputs["cross_attn_head_mask"] is not None else None, past_key_value=past_key_value, ) @@ -1000,25 +1009,32 @@ def call( if inputs["output_attentions"]: all_self_attns += (layer_self_attn,) + if inputs["encoder_hidden_states"] is not None: + all_cross_attns += (layer_cross_attn,) + hidden_states = self.layer_norm(hidden_states) if inputs["output_hidden_states"]: all_hidden_states += (hidden_states,) - else: - all_hidden_states = None - all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None + if inputs["output_attentions"]: + all_self_attns = list(all_self_attns) + + if inputs["encoder_hidden_states"] is not None: + all_cross_attns = list(all_cross_attns) - present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None + if inputs["use_cache"]: + present_key_values = (inputs["encoder_hidden_states"], present_key_values) if not inputs["return_dict"]: - return hidden_states, present_key_values, all_hidden_states, all_self_attns + return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: - return TFBaseModelOutputWithPast( + return TFBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=present_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, + cross_attentions=all_cross_attns, ) @@ -1065,6 +1081,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1085,6 +1102,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1131,7 +1149,7 @@ def call( encoder_hidden_states=inputs["encoder_outputs"][0], encoder_attention_mask=inputs["attention_mask"], head_mask=inputs["decoder_head_mask"], - encoder_head_mask=inputs["head_mask"], + 
cross_attn_head_mask=inputs["cross_attn_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], @@ -1149,6 +1167,7 @@ def call( past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_attentions=inputs["encoder_outputs"].attentions, @@ -1199,6 +1218,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1219,6 +1239,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1238,6 +1259,7 @@ def call( decoder_attention_mask=inputs["decoder_attention_mask"], head_mask=inputs["head_mask"], decoder_head_mask=inputs["decoder_head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], encoder_outputs=inputs["encoder_outputs"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], @@ -1256,6 +1278,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.congig.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -1264,6 +1287,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, @@ -1331,6 +1355,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[TFBaseModelOutput] = None, past_key_values=None, inputs_embeds=None, @@ -1361,6 +1386,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1394,6 +1420,7 @@ def call( decoder_attention_mask=inputs["decoder_attention_mask"], head_mask=inputs["head_mask"], decoder_head_mask=inputs["decoder_head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"], @@ -1416,6 +1443,7 @@ def call( past_key_values=outputs.past_key_values, # index 1 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs + cross_attentions=outputs.cross_attentions, 
# index 4 of d outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out @@ -1426,6 +1454,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -1434,6 +1463,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index 85ae9e9a4a12..6ca8ec6243d3 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -30,7 +30,7 @@ ) from ...modeling_tf_outputs import ( TFBaseModelOutput, - TFBaseModelOutputWithPast, + TFBaseModelOutputWithPastAndCrossAttentions, TFSeq2SeqLMOutput, TFSeq2SeqModelOutput, ) @@ -369,7 +369,7 @@ def call( encoder_hidden_states: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None, - encoder_layer_head_mask: Optional[tf.Tensor] = None, + cross_attn_layer_head_mask: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None, training=False, ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: @@ -383,8 +383,8 @@ def call( `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size `(decoder_attention_heads,)` - encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size - `(encoder_attention_heads,)` + cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module. 
+ `(decoder_attention_heads,)` past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states """ residual = hidden_states @@ -405,16 +405,17 @@ def call( # Cross-Attention Block cross_attn_present_key_value = None + cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, _, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=encoder_layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, ) hidden_states = self.dropout(hidden_states, training=training) @@ -436,6 +437,7 @@ def call( return ( hidden_states, self_attn_weights, + cross_attn_weights, present_key_value, ) @@ -574,7 +576,7 @@ def serving(self, inputs): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: @@ -582,6 +584,12 @@ def serving(self, inputs): - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tf.FloatTensor`, `optional`): hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of @@ -679,7 +687,7 @@ def call( Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded @@ -823,7 +831,7 @@ def call( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -865,14 +873,13 @@ def call( Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. 
Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -909,7 +916,7 @@ def call( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, - encoder_head_mask=encoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, @@ -960,19 +967,21 @@ def call( hidden_states = self.dropout(hidden_states, training=inputs["training"]) # decoder layers - all_hidden_states = () - all_self_attns = () - present_key_values = () + all_hidden_states = () if inputs["output_hidden_states"] else None + all_self_attns = () if inputs["output_attentions"] else None + all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None + present_key_values = () if inputs["use_cache"] else None - # check if head_mask has a correct number of layers specified if desired + # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA then they # have to be disabled in other modes than eager. 
- if inputs["head_mask"] is not None and tf.executing_eagerly(): - tf.debugging.assert_equal( - shape_list(inputs["head_mask"])[0], - len(self.layers), - message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", - ) + for attn_mask in ["head_mask", "cross_attn_head_mask"]: + if inputs[attn_mask] is not None and tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(inputs[attn_mask])[0], + len(self.layers), + message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.", + ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -985,14 +994,14 @@ def call( past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None - hidden_states, layer_self_attn, present_key_value = decoder_layer( + hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( hidden_states, attention_mask=combined_attention_mask, encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=inputs["encoder_attention_mask"], layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, - encoder_layer_head_mask=inputs["encoder_head_mask"][idx] - if inputs["encoder_head_mask"] is not None + cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx] + if inputs["cross_attn_head_mask"] is not None else None, past_key_value=past_key_value, ) @@ -1003,23 +1012,30 @@ def call( if inputs["output_attentions"]: all_self_attns += (layer_self_attn,) + if inputs["encoder_hidden_states"] is not None: + all_cross_attns += (layer_cross_attn,) + if inputs["output_hidden_states"]: all_hidden_states += (hidden_states,) - else: - all_hidden_states = None - all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None + if inputs["output_attentions"]: + all_self_attns = list(all_self_attns) + + if inputs["encoder_hidden_states"] is not None: + all_cross_attns = list(all_cross_attns) - present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None + if inputs["use_cache"]: + present_key_values = (inputs["encoder_hidden_states"], present_key_values) if not inputs["return_dict"]: - return hidden_states, present_key_values, all_hidden_states, all_self_attns + return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: - return TFBaseModelOutputWithPast( + return TFBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=present_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, + cross_attentions=all_cross_attns, ) @@ -1066,6 +1082,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1086,6 +1103,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1132,7 +1150,7 @@ def call( encoder_hidden_states=inputs["encoder_outputs"][0], encoder_attention_mask=inputs["attention_mask"], head_mask=inputs["decoder_head_mask"], - encoder_head_mask=inputs["head_mask"], + 
cross_attn_head_mask=inputs["cross_attn_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], @@ -1150,6 +1168,7 @@ def call( past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_attentions=inputs["encoder_outputs"].attentions, @@ -1187,6 +1206,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1207,6 +1227,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1226,6 +1247,7 @@ def call( decoder_attention_mask=inputs["decoder_attention_mask"], head_mask=inputs["head_mask"], decoder_head_mask=inputs["decoder_head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], encoder_outputs=inputs["encoder_outputs"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], @@ -1244,6 +1266,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.congig.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -1252,6 +1275,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, @@ -1306,6 +1330,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[TFBaseModelOutput] = None, past_key_values=None, inputs_embeds=None, @@ -1336,6 +1361,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1369,6 +1395,7 @@ def call( decoder_attention_mask=inputs["decoder_attention_mask"], head_mask=inputs["head_mask"], decoder_head_mask=inputs["decoder_head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"], @@ -1391,6 +1418,7 @@ def call( past_key_values=outputs.past_key_values, # index 1 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs + cross_attentions=outputs.cross_attentions, 
# index 4 of d outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out @@ -1401,6 +1429,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -1409,6 +1438,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py index 15271f8b22bd..64c4f8b31220 100644 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -31,7 +31,7 @@ ) from ...modeling_tf_outputs import ( TFBaseModelOutput, - TFBaseModelOutputWithPast, + TFBaseModelOutputWithPastAndCrossAttentions, TFSeq2SeqLMOutput, TFSeq2SeqModelOutput, ) @@ -408,7 +408,7 @@ def call( encoder_hidden_states: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None, - encoder_layer_head_mask: Optional[tf.Tensor] = None, + cross_attn_layer_head_mask: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None, training=False, ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: @@ -422,8 +422,8 @@ def call( `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size `(decoder_attention_heads,)` - encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size - `(encoder_attention_heads,)` + cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module. 
+ `(decoder_attention_heads,)` past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states """ residual = hidden_states @@ -444,16 +444,17 @@ def call( # Cross-Attention Block cross_attn_present_key_value = None + cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, _, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=encoder_layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, ) hidden_states = self.dropout(hidden_states, training=training) @@ -475,6 +476,7 @@ def call( return ( hidden_states, self_attn_weights, + cross_attn_weights, present_key_value, ) @@ -603,7 +605,7 @@ def serving(self, inputs): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: @@ -611,6 +613,12 @@ def serving(self, inputs): - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tf.FloatTensor`, `optional`): hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of @@ -707,7 +715,7 @@ def call( Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded @@ -848,7 +856,7 @@ def call( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -890,14 +898,13 @@ def call( Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. 
Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -934,7 +941,7 @@ def call( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, - encoder_head_mask=encoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, @@ -986,19 +993,21 @@ def call( hidden_states = self.dropout(hidden_states + positions, training=inputs["training"]) # decoder layers - all_hidden_states = () - all_self_attns = () - present_key_values = () + all_hidden_states = () if inputs["output_hidden_states"] else None + all_self_attns = () if inputs["output_attentions"] else None + all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None + present_key_values = () if inputs["use_cache"] else None - # check if head_mask has a correct number of layers specified if desired + # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA then they # have to be disabled in other modes than eager. 
- if inputs["head_mask"] is not None and tf.executing_eagerly(): - tf.debugging.assert_equal( - shape_list(inputs["head_mask"])[0], - len(self.layers), - message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", - ) + for attn_mask in ["head_mask", "cross_attn_head_mask"]: + if inputs[attn_mask] is not None and tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(inputs[attn_mask])[0], + len(self.layers), + message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.", + ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -1011,14 +1020,14 @@ def call( past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None - hidden_states, layer_self_attn, present_key_value = decoder_layer( + hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( hidden_states, attention_mask=combined_attention_mask, encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=inputs["encoder_attention_mask"], layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, - encoder_layer_head_mask=inputs["encoder_head_mask"][idx] - if inputs["encoder_head_mask"] is not None + cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx] + if inputs["cross_attn_head_mask"] is not None else None, past_key_value=past_key_value, ) @@ -1029,23 +1038,30 @@ def call( if inputs["output_attentions"]: all_self_attns += (layer_self_attn,) + if inputs["encoder_hidden_states"] is not None: + all_cross_attns += (layer_cross_attn,) + if inputs["output_hidden_states"]: all_hidden_states += (hidden_states,) - else: - all_hidden_states = None - all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None + if inputs["output_attentions"]: + all_self_attns = list(all_self_attns) + + if inputs["encoder_hidden_states"] is not None: + all_cross_attns = list(all_cross_attns) - present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None + if inputs["use_cache"]: + present_key_values = (inputs["encoder_hidden_states"], present_key_values) if not inputs["return_dict"]: - return hidden_states, present_key_values, all_hidden_states, all_self_attns + return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: - return TFBaseModelOutputWithPast( + return TFBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=present_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, + cross_attentions=all_cross_attns, ) @@ -1092,6 +1108,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1112,6 +1129,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1161,7 +1179,7 @@ def call( encoder_hidden_states=inputs["encoder_outputs"][0], encoder_attention_mask=inputs["attention_mask"], head_mask=inputs["decoder_head_mask"], - encoder_head_mask=inputs["head_mask"], + 
cross_attn_head_mask=inputs["cross_attn_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], @@ -1179,6 +1197,7 @@ def call( past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_attentions=inputs["encoder_outputs"].attentions, @@ -1216,6 +1235,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1235,6 +1255,7 @@ def call( decoder_input_ids=decoder_input_ids, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, decoder_attention_mask=decoder_attention_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, @@ -1255,6 +1276,7 @@ def call( decoder_attention_mask=inputs["decoder_attention_mask"], head_mask=inputs["head_mask"], decoder_head_mask=inputs["decoder_head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], encoder_outputs=inputs["encoder_outputs"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], @@ -1273,6 +1295,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -1281,6 +1304,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, @@ -1335,6 +1359,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[TFBaseModelOutput] = None, past_key_values=None, inputs_embeds=None, @@ -1365,6 +1390,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1398,6 +1424,7 @@ def call( decoder_attention_mask=inputs["decoder_attention_mask"], head_mask=inputs["head_mask"], decoder_head_mask=inputs["decoder_head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"], @@ -1420,6 +1447,7 @@ def call( past_key_values=outputs.past_key_values, # index 1 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs + 
cross_attentions=outputs.cross_attentions, # index 4 of d outputs encoder_last_hidden_state=outputs.last_hidden_state, # index 0 of encoder outputs encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out @@ -1430,6 +1458,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -1438,6 +1467,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index ea3294aa5a5d..b5e8a8b426c5 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -30,7 +30,7 @@ ) from ...modeling_tf_outputs import ( TFBaseModelOutput, - TFBaseModelOutputWithPast, + TFBaseModelOutputWithPastAndCrossAttentions, TFSeq2SeqLMOutput, TFSeq2SeqModelOutput, ) @@ -368,7 +368,7 @@ def call( encoder_hidden_states: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None, - encoder_layer_head_mask: Optional[tf.Tensor] = None, + cross_attn_layer_head_mask: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None, training=False, ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: @@ -382,8 +382,8 @@ def call( `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size `(decoder_attention_heads,)` - encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size - `(encoder_attention_heads,)` + cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module. 
+ `(decoder_attention_heads,)` past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states """ residual = hidden_states @@ -404,17 +404,18 @@ def call( # Cross-Attention Block cross_attn_present_key_value = None + cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, _, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=encoder_layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, ) hidden_states = self.dropout(hidden_states, training=training) @@ -435,6 +436,7 @@ def call( return ( hidden_states, self_attn_weights, + cross_attn_weights, present_key_value, ) @@ -547,7 +549,7 @@ def serving(self, inputs): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: @@ -555,6 +557,12 @@ def serving(self, inputs): - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tf.FloatTensor`, `optional`): hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of @@ -828,7 +836,7 @@ def call( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -870,14 +878,13 @@ def call( Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. 
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -914,7 +921,7 @@ def call( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, - encoder_head_mask=encoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, @@ -967,19 +974,21 @@ def call( hidden_states = self.dropout(hidden_states, training=inputs["training"]) # decoder layers - all_hidden_states = () - all_self_attns = () - present_key_values = () + all_hidden_states = () if inputs["output_hidden_states"] else None + all_self_attns = () if inputs["output_attentions"] else None + all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None + present_key_values = () if inputs["use_cache"] else None - # check if head_mask has a correct number of layers specified if desired + # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA then they # have to be disabled in other modes than eager. - if inputs["head_mask"] is not None and tf.executing_eagerly(): - tf.debugging.assert_equal( - shape_list(inputs["head_mask"])[0], - len(self.layers), - message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", - ) + for attn_mask in ["head_mask", "cross_attn_head_mask"]: + if inputs[attn_mask] is not None and tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(inputs[attn_mask])[0], + len(self.layers), + message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.", + ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -992,14 +1001,14 @@ def call( past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None - hidden_states, layer_self_attn, present_key_value = decoder_layer( + hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( hidden_states, attention_mask=combined_attention_mask, encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=inputs["encoder_attention_mask"], layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, - encoder_layer_head_mask=inputs["encoder_head_mask"][idx] - if inputs["encoder_head_mask"] is not None + cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx] + if inputs["cross_attn_head_mask"] is not None else None, past_key_value=past_key_value, ) @@ -1010,25 +1019,32 @@ def call( if inputs["output_attentions"]: all_self_attns += (layer_self_attn,) + if inputs["encoder_hidden_states"] is not None: + all_cross_attns += (layer_cross_attn,) + hidden_states = self.layer_norm(hidden_states) if inputs["output_hidden_states"]: all_hidden_states += (hidden_states,) - else: - all_hidden_states = None - all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None + if inputs["output_attentions"]: + all_self_attns = list(all_self_attns) + + if inputs["encoder_hidden_states"] is not None: + all_cross_attns = 
list(all_cross_attns) - present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None + if inputs["use_cache"]: + present_key_values = (inputs["encoder_hidden_states"], present_key_values) if not inputs["return_dict"]: - return hidden_states, present_key_values, all_hidden_states, all_self_attns + return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: - return TFBaseModelOutputWithPast( + return TFBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=present_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, + cross_attentions=all_cross_attns, ) @@ -1075,6 +1091,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1095,6 +1112,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1147,7 +1165,7 @@ def call( encoder_hidden_states=inputs["encoder_outputs"][0], encoder_attention_mask=inputs["attention_mask"], head_mask=inputs["decoder_head_mask"], - encoder_head_mask=inputs["head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], @@ -1165,6 +1183,7 @@ def call( past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_attentions=inputs["encoder_outputs"].attentions, @@ -1202,6 +1221,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1222,6 +1242,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1241,6 +1262,7 @@ def call( decoder_attention_mask=inputs["decoder_attention_mask"], head_mask=inputs["head_mask"], decoder_head_mask=inputs["decoder_head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], encoder_outputs=inputs["encoder_outputs"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], @@ -1259,6 +1281,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -1267,6
+1290,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, @@ -1321,6 +1345,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[TFBaseModelOutput] = None, past_key_values=None, inputs_embeds=None, @@ -1351,6 +1376,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1382,6 +1408,7 @@ def call( decoder_attention_mask=inputs["decoder_attention_mask"], head_mask=inputs["head_mask"], decoder_head_mask=inputs["decoder_head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"], @@ -1404,6 +1431,7 @@ def call( past_key_values=outputs.past_key_values, # index 1 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs + cross_attentions=outputs.cross_attentions, # index 4 of d outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out @@ -1414,6 +1442,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -1422,6 +1451,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index 504c7d23affb..dc65da296f12 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -31,7 +31,7 @@ ) from ...modeling_tf_outputs import ( TFBaseModelOutput, - TFBaseModelOutputWithPast, + TFBaseModelOutputWithPastAndCrossAttentions, TFSeq2SeqLMOutput, TFSeq2SeqModelOutput, ) @@ -409,7 +409,7 @@ def call( encoder_hidden_states: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None, - encoder_layer_head_mask: Optional[tf.Tensor] = None, + cross_attn_layer_head_mask: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None, training=False, ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: @@ -423,8 +423,8 @@ def call( `(batch, 1, 
tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size `(decoder_attention_heads,)` - encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size - `(encoder_attention_heads,)` + cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module. + `(decoder_attention_heads,)` past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states """ residual = hidden_states @@ -445,17 +445,18 @@ def call( # Cross-Attention Block cross_attn_present_key_value = None + cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, _, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=encoder_layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, ) hidden_states = self.dropout(hidden_states, training=training) @@ -476,6 +477,7 @@ def call( return ( hidden_states, self_attn_weights, + cross_attn_weights, present_key_value, ) @@ -607,7 +609,7 @@ def serving(self, inputs): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: @@ -615,6 +617,12 @@ def serving(self, inputs): - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tf.FloatTensor`, `optional`): hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of @@ -859,7 +867,7 @@ def call( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -901,14 +909,13 @@ def call( Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. 
Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -945,7 +952,7 @@ def call( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, - encoder_head_mask=encoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, @@ -997,19 +1004,21 @@ def call( hidden_states = self.dropout(hidden_states + positions, training=inputs["training"]) # decoder layers - all_hidden_states = () - all_self_attns = () - present_key_values = () + all_hidden_states = () if inputs["output_hidden_states"] else None + all_self_attns = () if inputs["output_attentions"] else None + all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None + present_key_values = () if inputs["use_cache"] else None - # check if head_mask has a correct number of layers specified if desired + # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA then they # have to be disabled in other modes than eager. 
- if inputs["head_mask"] is not None and tf.executing_eagerly(): - tf.debugging.assert_equal( - shape_list(inputs["head_mask"])[0], - len(self.layers), - message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", - ) + for attn_mask in ["head_mask", "cross_attn_head_mask"]: + if inputs[attn_mask] is not None and tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(inputs[attn_mask])[0], + len(self.layers), + message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.", + ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -1022,14 +1031,14 @@ def call( past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None - hidden_states, layer_self_attn, present_key_value = decoder_layer( + hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( hidden_states, attention_mask=combined_attention_mask, encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=inputs["encoder_attention_mask"], layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, - encoder_layer_head_mask=inputs["encoder_head_mask"][idx] - if inputs["encoder_head_mask"] is not None + cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx] + if inputs["cross_attn_head_mask"] is not None else None, past_key_value=past_key_value, ) @@ -1040,25 +1049,32 @@ def call( if inputs["output_attentions"]: all_self_attns += (layer_self_attn,) + if inputs["encoder_hidden_states"] is not None: + all_cross_attns += (layer_cross_attn,) + hidden_states = self.layer_norm(hidden_states) if inputs["output_hidden_states"]: all_hidden_states += (hidden_states,) - else: - all_hidden_states = None - all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None + if inputs["output_attentions"]: + all_self_attns = list(all_self_attns) + + if inputs["encoder_hidden_states"] is not None: + all_cross_attns = list(all_cross_attns) - present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None + if inputs["use_cache"]: + present_key_values = (inputs["encoder_hidden_states"], present_key_values) if not inputs["return_dict"]: - return hidden_states, present_key_values, all_hidden_states, all_self_attns + return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: - return TFBaseModelOutputWithPast( + return TFBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=present_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, + cross_attentions=all_cross_attns, ) @@ -1105,6 +1121,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1125,6 +1142,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1174,7 +1192,7 @@ def call( encoder_hidden_states=inputs["encoder_outputs"][0], encoder_attention_mask=inputs["attention_mask"], head_mask=inputs["decoder_head_mask"], - encoder_head_mask=inputs["head_mask"], + 
cross_attn_head_mask=inputs["cross_attn_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], @@ -1192,6 +1210,7 @@ def call( past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_attentions=inputs["encoder_outputs"].attentions, @@ -1229,6 +1248,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1249,6 +1269,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1268,6 +1289,7 @@ def call( decoder_attention_mask=inputs["decoder_attention_mask"], head_mask=inputs["head_mask"], decoder_head_mask=inputs["decoder_head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], encoder_outputs=inputs["encoder_outputs"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], @@ -1286,6 +1308,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -1294,6 +1317,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, @@ -1348,6 +1372,7 @@ def call( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[TFBaseModelOutput] = None, past_key_values=None, inputs_embeds=None, @@ -1378,6 +1403,7 @@ def call( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1411,6 +1437,7 @@ def call( decoder_attention_mask=inputs["decoder_attention_mask"], head_mask=inputs["head_mask"], decoder_head_mask=inputs["decoder_head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"], @@ -1433,6 +1460,7 @@ def call( past_key_values=outputs.past_key_values, # index 1 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs + cross_attentions=outputs.cross_attentions, 
# index 4 of d outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out @@ -1443,6 +1471,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -1451,6 +1480,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, diff --git a/tests/test_modeling_tf_blenderbot.py b/tests/test_modeling_tf_blenderbot.py index 050a223f0e05..7e00144a02fc 100644 --- a/tests/test_modeling_tf_blenderbot.py +++ b/tests/test_modeling_tf_blenderbot.py @@ -146,6 +146,7 @@ def prepare_blenderbot_inputs_dict( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, ): if attention_mask is None: attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) @@ -161,6 +162,8 @@ def prepare_blenderbot_inputs_dict( head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) if decoder_head_mask is None: decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) + if cross_attn_head_mask is None: + cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, @@ -168,6 +171,7 @@ def prepare_blenderbot_inputs_dict( "decoder_attention_mask": decoder_attention_mask, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, } diff --git a/tests/test_modeling_tf_blenderbot_small.py b/tests/test_modeling_tf_blenderbot_small.py index 850fb3357ba8..ed5ed42eef2f 100644 --- a/tests/test_modeling_tf_blenderbot_small.py +++ b/tests/test_modeling_tf_blenderbot_small.py @@ -146,6 +146,7 @@ def prepare_blenderbot_small_inputs_dict( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, ): if attention_mask is None: attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) @@ -161,6 +162,8 @@ def prepare_blenderbot_small_inputs_dict( head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) if decoder_head_mask is None: decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) + if cross_attn_head_mask is None: + cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, @@ -168,6 +171,7 @@ def prepare_blenderbot_small_inputs_dict( "decoder_attention_mask": decoder_attention_mask, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, } diff --git 
a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 2c2eb1e546ac..fe340dc23675 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -172,8 +172,12 @@ def test_forward_signature(self): "decoder_attention_mask", ] expected_arg_names.extend( - ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"] - if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names + ["head_mask", "decoder_head_mask"] if "head_mask" and "decoder_head_mask" in arg_names else [] + ) + # Necessary to handle BART with newly added cross_attn_head_mask + expected_arg_names.extend( + ["cross_attn_head_mask", "encoder_outputs"] + if "cross_attn_head_mask" in arg_names else ["encoder_outputs"] ) self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) @@ -620,7 +624,7 @@ def test_attention_outputs(self): def check_decoder_attentions_output(outputs): out_len = len(outputs) - self.assertEqual(out_len % 2, 0) + self.assertEqual(min(out_len % 2, out_len % 5), 0) # differentiation due to newly added cross_attentions decoder_attentions = outputs.decoder_attentions self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) self.assertListEqual( @@ -740,7 +744,8 @@ def check_attentions_validity(attentions): if model.config.is_encoder_decoder: check_attentions_validity(outputs.encoder_attentions) check_attentions_validity(outputs.decoder_attentions) - check_attentions_validity(outputs.cross_attentions) + if "cross_attn_head_mask" in arg_names: + check_attentions_validity(outputs.cross_attentions) else: check_attentions_validity(outputs.attentions) diff --git a/tests/test_modeling_tf_marian.py b/tests/test_modeling_tf_marian.py index 8000e41b5fe2..bd91fec06926 100644 --- a/tests/test_modeling_tf_marian.py +++ b/tests/test_modeling_tf_marian.py @@ -148,6 +148,7 @@ def prepare_marian_inputs_dict( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, ): if attention_mask is None: attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) @@ -163,6 +164,8 @@ def prepare_marian_inputs_dict( head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) if decoder_head_mask is None: decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) + if cross_attn_head_mask is None: + cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, @@ -170,6 +173,7 @@ def prepare_marian_inputs_dict( "decoder_attention_mask": decoder_attention_mask, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, } diff --git a/tests/test_modeling_tf_mbart.py b/tests/test_modeling_tf_mbart.py index 228fe6a57b4b..d21b08228ba0 100644 --- a/tests/test_modeling_tf_mbart.py +++ b/tests/test_modeling_tf_mbart.py @@ -150,6 +150,7 @@ def prepare_mbart_inputs_dict( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, ): if attention_mask is None: attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) @@ -165,13 +166,16 @@ def prepare_mbart_inputs_dict( head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) if decoder_head_mask is None: decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) + if cross_attn_head_mask is None: + 
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, "head_mask": head_mask, - "decoder_head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, } diff --git a/tests/test_modeling_tf_pegasus.py b/tests/test_modeling_tf_pegasus.py index adbd618859b3..136907ff0e38 100644 --- a/tests/test_modeling_tf_pegasus.py +++ b/tests/test_modeling_tf_pegasus.py @@ -146,6 +146,7 @@ def prepare_pegasus_inputs_dict( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, ): if attention_mask is None: attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) @@ -161,6 +162,8 @@ def prepare_pegasus_inputs_dict( head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) if decoder_head_mask is None: decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) + if cross_attn_head_mask is None: + cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, @@ -168,6 +171,7 @@ def prepare_pegasus_inputs_dict( "decoder_attention_mask": decoder_attention_mask, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, } From 995f4033193248fcbb9d6081251f3747310711a5 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 13 Mar 2021 13:42:35 +0100 Subject: [PATCH 3/5] Update TF model templates --- ...tf_{{cookiecutter.lowercase_modelname}}.py | 138 +++++++++++++++++- 1 file changed, 131 insertions(+), 7 deletions(-) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index 7d977ae84731..fb814b2db8ce 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -1726,16 +1726,18 @@ def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs) self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False): + def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False): """ Args: hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`tf.Tensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
+ layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)` """ residual = hidden_states hidden_states, self_attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask + hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask ) # The tf.debugging asserts are not compliant with XLA then they @@ -1796,6 +1798,8 @@ def call( attention_mask: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None, + layer_head_mask: Optional[tf.Tensor] = None, + cross_attn_layer_head_mask: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None, training=False, ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: @@ -1807,6 +1811,10 @@ def call( encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(decoder_attention_heads,)` + cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module. + `(decoder_attention_heads,)` past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states """ residual = hidden_states @@ -1819,6 +1827,7 @@ def call( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, ) hidden_states = self.dropout(hidden_states, training=training) hidden_states = residual + hidden_states @@ -1826,15 +1835,17 @@ def call( # Cross-Attention Block cross_attn_present_key_value = None + cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, _, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, ) hidden_states = self.dropout(hidden_states, training=training) @@ -1856,6 +1867,7 @@ def call( return ( hidden_states, self_attn_weights, + cross_attn_weights, present_key_value, ) @@ -1963,6 +1975,24 @@ def serving(self, inputs): the right for denoising pre-training following the paper. decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. + head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. 
Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tf.FloatTensor`, `optional`): hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of @@ -2032,6 +2062,7 @@ def call( input_ids=None, inputs_embeds=None, attention_mask=None, + head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -2056,6 +2087,12 @@ def call( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -2080,6 +2117,7 @@ def call( config=self.config, input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -2113,6 +2151,16 @@ def call( encoder_states = () if inputs["output_hidden_states"] else None all_attentions = () if inputs["output_attentions"] else None + # check if head_mask has a correct number of layers specified if desired + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if inputs["head_mask"] is not None and tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(inputs["head_mask"])[0], + len(self.layers), + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + ) + # encoder layers - for encoder_layer in self.layers: + for idx, encoder_layer in enumerate(self.layers): @@ -2123,7 +2171,11 @@ def call( if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer continue - hidden_states, attn = encoder_layer(hidden_states, inputs["attention_mask"]) + hidden_states, attn = encoder_layer( + hidden_states, + inputs["attention_mask"], + inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + ) if inputs["output_attentions"]: all_attentions += (attn,) @@ -2179,6 +2231,8 @@ def call( attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -2216,6 +2270,18 @@ def call( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -2250,6 +2316,8 @@ def call( attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, @@ -2295,8 +2363,20 @@ def call( # decoder layers all_hidden_states = () if inputs["output_hidden_states"] else None all_self_attns = () if inputs["output_attentions"] else None + all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None present_key_values = () if inputs["use_cache"] else None + # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + for attn_mask in ["head_mask", "cross_attn_head_mask"]: + if inputs[attn_mask] is not None and tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(inputs[attn_mask])[0], + len(self.layers), + message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.", + ) + for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if inputs["output_hidden_states"]: @@ -2309,11 +2389,15 @@ def call( past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None - hidden_states, layer_self_attn, present_key_value = decoder_layer( + hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( hidden_states, attention_mask=combined_attention_mask, encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=inputs["encoder_attention_mask"], + layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx] + if inputs["cross_attn_head_mask"] is not None + else None, past_key_value=past_key_value, ) @@ -2323,23 +2407,30 @@ def call( if inputs["output_attentions"]: all_self_attns += (layer_self_attn,) + if inputs["encoder_hidden_states"] is not None: + all_cross_attns += (layer_cross_attn,) + if inputs["output_hidden_states"]: all_hidden_states += (hidden_states,) if inputs["output_attentions"]: all_self_attns = list(all_self_attns) + if inputs["encoder_hidden_states"] is not None: + all_cross_attns = list(all_cross_attns) + if inputs["use_cache"]: present_key_values = (inputs["encoder_hidden_states"], present_key_values) if not inputs["return_dict"]: - return hidden_states, present_key_values, all_hidden_states, all_self_attns + return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: - return TFBaseModelOutputWithPast( + return TFBaseModelOutputWithPastAndCrossAttentions( 
last_hidden_state=hidden_states, past_key_values=present_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, + cross_attentions=all_cross_attns, ) @tf.function @@ -2411,6 +2502,9 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -2429,6 +2523,9 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -2448,6 +2545,7 @@ def call( inputs["encoder_outputs"] = self.encoder( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + head_mask=inputs["head_mask"], inputs_embeds=inputs["inputs_embeds"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], @@ -2470,6 +2568,8 @@ def call( attention_mask=inputs["decoder_attention_mask"], encoder_hidden_states=inputs["encoder_outputs"][0], encoder_attention_mask=inputs["attention_mask"], + head_mask=inputs["decoder_head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], @@ -2487,6 +2587,7 @@ def call( past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_attentions=inputs["encoder_outputs"].attentions, @@ -2522,6 +2623,9 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -2540,6 +2644,9 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -2557,6 +2664,9 @@ def call( attention_mask=inputs["attention_mask"], decoder_input_ids=inputs["decoder_input_ids"], decoder_attention_mask=inputs["decoder_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], encoder_outputs=inputs["encoder_outputs"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], @@ -2575,6 +2685,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if 
self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -2583,6 +2694,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, @@ -2635,6 +2747,9 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[TFBaseModelOutput] = None, past_key_values=None, inputs_embeds=None, @@ -2670,6 +2785,9 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -2696,6 +2814,9 @@ def call( decoder_input_ids=inputs["decoder_input_ids"], encoder_outputs=inputs["encoder_outputs"], decoder_attention_mask=inputs["decoder_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], + cross_attn_head_mask=inputs["cross_attn_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"], @@ -2718,6 +2839,7 @@ def call( past_key_values=outputs.past_key_values, # index 1 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs + cross_attentions=outputs.cross_attentions, # index 4 of d outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out @@ -2728,6 +2850,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -2736,6 +2859,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, From a91eb7c8253d6ddd24c25084802c225f1efbb0bc Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 13 Mar 2021 14:30:14 +0100 Subject: [PATCH 4/5] Fix missing , in TF model templates --- .../modeling_tf_{{cookiecutter.lowercase_modelname}}.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py 
b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index fb814b2db8ce..dd4b8bfc65fb 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -146,7 +146,6 @@ def call( return final_embeddings - # Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}} class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer): def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs): @@ -351,6 +350,7 @@ def call( return outputs + # Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->{{cookiecutter.camelcase_modelname}} class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer): def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs): @@ -624,7 +624,6 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel): base_model_prefix = "{{cookiecutter.lowercase_modelname}}" - {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the @@ -884,6 +883,7 @@ def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput: return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns) + @add_start_docstrings( """{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING ) @@ -2041,7 +2041,6 @@ def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, embed_tok self.max_source_positions = config.max_position_embeddings self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 - self.embed_tokens = embed_tokens self.embed_positions = TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding( config.max_position_embeddings, @@ -2117,7 +2116,7 @@ def call( config=self.config, input_ids=input_ids, attention_mask=attention_mask, - head_mask=head_mask + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, From ac570503c78395bc042494c3243c26e8fe4247c1 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 15 Mar 2021 14:55:16 +0100 Subject: [PATCH 5/5] Fix typo: congig -> config --- src/transformers/models/bart/modeling_tf_bart.py | 2 +- src/transformers/models/blenderbot/modeling_tf_blenderbot.py | 2 +- .../models/blenderbot_small/modeling_tf_blenderbot_small.py | 2 +- src/transformers/models/marian/modeling_tf_marian.py | 2 +- src/transformers/models/mbart/modeling_tf_mbart.py | 2 +- src/transformers/models/pegasus/modeling_tf_pegasus.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 1f245f4e5964..961e71a4bbb8 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -1267,7 +1267,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if 
self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.congig.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index e9e07df2981d..40d8b556a4bc 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -1278,7 +1278,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.congig.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index 6ca8ec6243d3..8b2ae82df138 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -1266,7 +1266,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.congig.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py index 64c4f8b31220..73dd87d913d0 100644 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -1295,7 +1295,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.congig.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if 
self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index b5e8a8b426c5..2207c00fed55 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -1281,7 +1281,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.congig.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index dc65da296f12..7fbd5bffddfa 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -1308,7 +1308,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.congig.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
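
Editor's note (not part of the patches above): a minimal schematic of the head-mask mechanics these patches wire through the TF models. Inside each attention layer, layer_head_mask (and, for cross-attention, cross_attn_layer_head_mask) is broadcast over the attention probabilities so that a masked head contributes nothing. The helper name apply_layer_head_mask is illustrative, not a function in the library; shapes follow the docstrings added in PATCH 1/5.

    import tensorflow as tf

    def apply_layer_head_mask(attn_probs, layer_head_mask):
        """attn_probs: (batch, num_heads, tgt_len, src_len); layer_head_mask: (num_heads,)."""
        if layer_head_mask is None:
            return attn_probs
        # Broadcast the per-head mask over batch and sequence dimensions;
        # a 0.0 entry zeroes out the corresponding head's attention weights.
        return tf.reshape(layer_head_mask, (1, -1, 1, 1)) * attn_probs

    # Example: mask head 0 of a 4-head layer.
    probs = tf.nn.softmax(tf.random.normal((2, 4, 5, 5)), axis=-1)
    mask = tf.constant([0.0, 1.0, 1.0, 1.0])
    print(apply_layer_head_mask(probs, mask)[0, 0])  # all zeros for the masked head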
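
A usage sketch, assuming the whole series is applied: the new decoder-side mask is passed as cross_attn_head_mask and the new cross_attentions output can be inspected when output_attentions=True. The checkpoint name, input sentence, and all-ones masks are illustrative only; the argument names and shapes come from the docstrings added above.

    import tensorflow as tf
    from transformers import BartTokenizer, TFBartForConditionalGeneration

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")

    inputs = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="tf")

    # One row per layer, one column per head; 1.0 keeps a head, 0.0 masks it.
    encoder_head_mask = tf.ones((model.config.encoder_layers, model.config.encoder_attention_heads))
    decoder_head_mask = tf.ones((model.config.decoder_layers, model.config.decoder_attention_heads))
    cross_attn_head_mask = tf.ones((model.config.decoder_layers, model.config.decoder_attention_heads))

    outputs = model(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        decoder_input_ids=inputs["input_ids"],
        head_mask=encoder_head_mask,
        decoder_head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        output_attentions=True,
    )

    # cross_attentions holds one tensor per decoder layer of shape
    # (batch_size, decoder_attention_heads, target_seq_len, source_seq_len).
    print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)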