From 4380ae5c74fd35640368289df4e46916d1b5a15a Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Tue, 2 Feb 2021 19:20:38 +0100 Subject: [PATCH 1/6] Add head_mask & decoder_head_mask + some corrections --- .../models/prophetnet/modeling_prophetnet.py | 135 +++++++++++++++++- tests/test_modeling_prophetnet.py | 1 - 2 files changed, 130 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index d473e8758a6258..415f184a8bb606 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -107,6 +107,18 @@ If you want to change padding behavior, you should read :func:`modeling_bart._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the paper `__ for more information on the default strategy. + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -149,6 +161,12 @@ - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. 
@@ -632,6 +650,7 @@ def forward( hidden_states, key_value_states: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None, + layer_head_mask: Optional[Tensor] = None, layer_state: Optional[Dict[str, Optional[Tensor]]] = None, ) -> Tuple[Tensor, Optional[Tensor]]: @@ -706,6 +725,17 @@ def forward( ) attn_weights = F.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_attn_heads, + ), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}" + attn_weights = ( + layer_head_mask.view(1, -1, 1, 1) + * attn_weights.view(batch_size, self.num_attn_heads, sequence_length, key_sequence_length) + ) + attn_weights = attn_weights.view(batch_size * self.num_attn_heads, sequence_length, key_sequence_length) + attn_probs = F.dropout( attn_weights, p=self.attention_dropout, @@ -790,6 +820,7 @@ def forward( hidden_states, layer_state=None, attention_mask=None, + layer_head_mask=None, extended_predict_attention_mask=None, main_relative_position_buckets=None, predict_relative_position_buckets=None, @@ -870,6 +901,18 @@ def forward( onnx_trace=self.onnx_trace, ).type_as(main_attn_weights) + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_attn_heads, + ), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}" + main_attn_probs = ( + layer_head_mask.view(1, -1, 1, 1) + * main_attn_probs.view(batch_size, self.num_attn_heads, -1, main_sequence_length) + ) + main_attn_probs = main_attn_probs.view( + batch_size * self.num_attn_heads, -1, main_sequence_length + ) + main_attn_probs = F.dropout(main_attn_probs, p=self.attention_dropout, training=self.training) # project to attn_output @@ -918,6 +961,21 @@ def forward( dim=-1, onnx_trace=self.onnx_trace, ).type_as(predict_attn_weights) + + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_attn_heads, + ), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}" + predict_attn_probs = ( + layer_head_mask.view(1, 1, -1, 1, 1) + * predict_attn_probs.view( + self.ngram, batch_size, self.num_attn_heads, main_sequence_length, 2 * main_sequence_length + ) + ) + predict_attn_probs = predict_attn_probs.view( + self.ngram, batch_size * self.num_attn_heads, main_sequence_length, 2 * main_sequence_length + ) + predict_attn_probs = F.dropout(predict_attn_probs, p=self.attention_dropout, training=self.training) # project to attention output @@ -1053,11 +1111,12 @@ def __init__(self, config: ProphetNetConfig): self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim) self.feed_forward_layer_norm = LayerNorm(config.hidden_size) - def forward(self, hidden_states, attention_mask): + def forward(self, hidden_states, attention_mask, layer_head_mask): # 1st residual block attention_output, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, ) hidden_states = self.self_attn_layer_norm(attention_output + hidden_states) @@ -1094,6 +1153,8 @@ def forward( encoder_attn_mask=None, layer_state=None, attention_mask=None, + layer_head_mask=None, + encoder_layer_head_mask=None, extended_predict_attention_mask=None, main_relative_position_buckets=None, predict_relative_position_buckets=None, @@ -1106,6 +1167,7 @@ def forward( hidden_states=hidden_states, layer_state=layer_state, attention_mask=attention_mask, + 
layer_head_mask=layer_head_mask, extended_predict_attention_mask=extended_predict_attention_mask, main_relative_position_buckets=main_relative_position_buckets, predict_relative_position_buckets=predict_relative_position_buckets, @@ -1120,6 +1182,7 @@ def forward( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attn_mask, + layer_head_mask=encoder_layer_head_mask, layer_state=layer_state, # mutates layer state ) hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states) @@ -1175,6 +1238,7 @@ def forward( self, input_ids=None, attention_mask=None, + head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, @@ -1228,12 +1292,21 @@ def forward( encoder_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - for encoder_layer in self.layers: + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: hidden_states = hidden_states.transpose(0, 1) encoder_hidden_states = encoder_hidden_states + (hidden_states,) hidden_states = hidden_states.transpose(0, 1) - hidden_states, attn_probs = encoder_layer(hidden_states, attention_mask=extended_attention_mask) + hidden_states, attn_probs = encoder_layer( + hidden_states, + attention_mask=extended_attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + ) if output_attentions: all_attentions = all_attentions + (attn_probs,) @@ -1295,6 +1368,8 @@ def forward( attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -1309,6 +1384,13 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -1422,6 +1504,10 @@ def forward( all_cross_attns = () if output_attentions and self.config.add_cross_attention else None present_key_values = () if use_cache else None + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: # grad cannot be kept because tensor is sliced @@ -1442,6 +1528,8 @@ def forward( encoder_attn_mask=extended_encoder_attention_mask, layer_state=layer_state, attention_mask=extended_attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + encoder_layer_head_mask=encoder_head_mask[idx] if head_mask is not None else None, extended_predict_attention_mask=extended_predict_attention_mask, main_relative_position_buckets=main_relative_position_buckets, predict_relative_position_buckets=predict_relative_position_buckets, @@ -1612,6 +1700,8 @@ def forward( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[Tuple] = None, past_key_values=None, inputs_embeds=None, @@ -1650,6 +1740,7 @@ def forward( encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1662,6 +1753,8 @@ def forward( attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, output_attentions=output_attentions, @@ -1719,6 +1812,8 @@ def forward( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1762,6 +1857,8 @@ def forward( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1836,7 +1933,14 @@ def _compute_loss(self, logits, labels, ignore_index=-100): return loss def prepare_inputs_for_generation( - self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, ): assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation." @@ -1849,6 +1953,7 @@ def prepare_inputs_for_generation( "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, } @@ -1923,6 +2028,8 @@ def forward( attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, inputs_embeds=None, labels=None, @@ -1938,6 +2045,13 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. 
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -1998,6 +2112,8 @@ def forward( attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -2061,7 +2177,15 @@ def _compute_loss(self, logits, labels, ignore_index=-100): return loss - def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + use_cache=None, + **kwargs, + ): # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: attention_mask = input_ids.new_ones(input_ids.shape) @@ -2072,6 +2196,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed "attention_mask": attention_mask, + "head_mask": head_mask, "past_key_values": past, "use_cache": use_cache, } diff --git a/tests/test_modeling_prophetnet.py b/tests/test_modeling_prophetnet.py index c9ba56396e1070..00614a3447ff21 100644 --- a/tests/test_modeling_prophetnet.py +++ b/tests/test_modeling_prophetnet.py @@ -891,7 +891,6 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test test_pruning = False test_torchscript = False test_resize_embeddings = False - test_headmasking = False is_encoder_decoder = True def setUp(self): From 7bb01a039106181ae604bd65178daecdb1af32d9 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Tue, 2 Feb 2021 19:40:24 +0100 Subject: [PATCH 2/6] Fix head masking for N-grams --- src/transformers/models/prophetnet/modeling_prophetnet.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 415f184a8bb606..c72ebd2865eb42 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -736,6 +736,9 @@ def forward( ) attn_weights = attn_weights.view(batch_size * self.num_attn_heads, sequence_length, key_sequence_length) + # apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model + attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped + attn_probs = F.dropout( attn_weights, p=self.attention_dropout, From 53b3d182e783f62687484146e74cb89d9dd566a4 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Tue, 2 Feb 2021 19:48:12 +0100 Subject: [PATCH 3/6] Enable test_headmasking for encoder and decoder * Fix one typo in modeling_prophetnet.py * Enable test_headmasking for ProphetNetStandaloneDecoderModelTest and ProphetNetStandaloneEncoderModelTest in test_modeling_prophetnet.py --- src/transformers/models/prophetnet/modeling_prophetnet.py | 2 +- tests/test_modeling_prophetnet.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py 
b/src/transformers/models/prophetnet/modeling_prophetnet.py index c72ebd2865eb42..62166f495a54d9 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1532,7 +1532,7 @@ def forward( layer_state=layer_state, attention_mask=extended_attention_mask, layer_head_mask=head_mask[idx] if head_mask is not None else None, - encoder_layer_head_mask=encoder_head_mask[idx] if head_mask is not None else None, + encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None, extended_predict_attention_mask=extended_predict_attention_mask, main_relative_position_buckets=main_relative_position_buckets, predict_relative_position_buckets=predict_relative_position_buckets, diff --git a/tests/test_modeling_prophetnet.py b/tests/test_modeling_prophetnet.py index 00614a3447ff21..3a8a70665fda56 100644 --- a/tests/test_modeling_prophetnet.py +++ b/tests/test_modeling_prophetnet.py @@ -1096,7 +1096,6 @@ class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix test_pruning = False test_torchscript = False test_resize_embeddings = False - test_headmasking = False is_encoder_decoder = False def setUp(self): @@ -1125,7 +1124,6 @@ class ProphetNetStandaloneEncoderModelTest(ModelTesterMixin, unittest.TestCase): test_pruning = False test_torchscript = False test_resize_embeddings = False - test_headmasking = False is_encoder_decoder = False def setUp(self): From df8407d3c22187c1c160098de216275727845740 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Tue, 2 Feb 2021 20:07:53 +0100 Subject: [PATCH 4/6] make style --- .../models/prophetnet/modeling_prophetnet.py | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 62166f495a54d9..125120ca2ecd84 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -730,9 +730,8 @@ def forward( assert layer_head_mask.size() == ( self.num_attn_heads, ), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}" - attn_weights = ( - layer_head_mask.view(1, -1, 1, 1) - * attn_weights.view(batch_size, self.num_attn_heads, sequence_length, key_sequence_length) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( + batch_size, self.num_attn_heads, sequence_length, key_sequence_length ) attn_weights = attn_weights.view(batch_size * self.num_attn_heads, sequence_length, key_sequence_length) @@ -908,13 +907,10 @@ def forward( assert layer_head_mask.size() == ( self.num_attn_heads, ), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}" - main_attn_probs = ( - layer_head_mask.view(1, -1, 1, 1) - * main_attn_probs.view(batch_size, self.num_attn_heads, -1, main_sequence_length) - ) - main_attn_probs = main_attn_probs.view( - batch_size * self.num_attn_heads, -1, main_sequence_length + main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view( + batch_size, self.num_attn_heads, -1, main_sequence_length ) + main_attn_probs = main_attn_probs.view(batch_size * self.num_attn_heads, -1, main_sequence_length) main_attn_probs = F.dropout(main_attn_probs, p=self.attention_dropout, training=self.training) @@ -969,11 +965,8 @@ def forward( assert layer_head_mask.size() == ( self.num_attn_heads, ), 
f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}" - predict_attn_probs = ( - layer_head_mask.view(1, 1, -1, 1, 1) - * predict_attn_probs.view( - self.ngram, batch_size, self.num_attn_heads, main_sequence_length, 2 * main_sequence_length - ) + predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view( + self.ngram, batch_size, self.num_attn_heads, main_sequence_length, 2 * main_sequence_length ) predict_attn_probs = predict_attn_probs.view( self.ngram, batch_size * self.num_attn_heads, main_sequence_length, 2 * main_sequence_length @@ -2054,7 +2047,7 @@ def forward( - 1 indicates the head is **not masked**, - 0 indicates the heas is **masked**. - + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. From 3d6153a225775df32170bc0f5c201f938c02fc7d Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Thu, 11 Mar 2021 00:38:02 +0100 Subject: [PATCH 5/6] Fix cross_head_mask --- .../models/prophetnet/modeling_prophetnet.py | 61 +++++++++++-------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 125120ca2ecd84..cc21608fc11a62 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -107,18 +107,24 @@ If you want to change padding behavior, you should read :func:`modeling_bart._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the paper `__ for more information on the default strategy. - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -161,11 +167,11 @@ - 0 for tokens that are **masked**. `What are attention masks? 
<../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned @@ -1150,7 +1156,7 @@ def forward( layer_state=None, attention_mask=None, layer_head_mask=None, - encoder_layer_head_mask=None, + cross_layer_head_mask=None, extended_predict_attention_mask=None, main_relative_position_buckets=None, predict_relative_position_buckets=None, @@ -1178,7 +1184,7 @@ def forward( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attn_mask, - layer_head_mask=encoder_layer_head_mask, + layer_head_mask=cross_layer_head_mask, layer_state=layer_state, # mutates layer state ) hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states) @@ -1365,7 +1371,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -1380,12 +1386,11 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -1500,10 +1505,12 @@ def forward( all_cross_attns = () if output_attentions and self.config.add_cross_attention else None present_key_values = () if use_cache else None - if head_mask is not None: - assert head_mask.size()[0] == ( - len(self.layers) - ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + # check if head_mask/cross_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_head_mask], ["head_mask", "cross_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == ( + len(self.layers) + ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: # grad cannot be kept because tensor is sliced @@ -1525,7 +1532,7 @@ def forward( layer_state=layer_state, attention_mask=extended_attention_mask, layer_head_mask=head_mask[idx] if head_mask is not None else None, - encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None, + cross_layer_head_mask=cross_head_mask[idx] if cross_head_mask is not None else None, extended_predict_attention_mask=extended_predict_attention_mask, main_relative_position_buckets=main_relative_position_buckets, predict_relative_position_buckets=predict_relative_position_buckets, @@ -1698,6 +1705,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_head_mask=None, encoder_outputs: Optional[Tuple] = None, past_key_values=None, inputs_embeds=None, @@ -1750,7 +1758,7 @@ def forward( encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, - encoder_head_mask=head_mask, + cross_head_mask=cross_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, output_attentions=output_attentions, @@ -1810,6 +1818,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1855,6 +1864,7 @@ def forward( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_head_mask=cross_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -2025,7 +2035,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_head_mask=None, past_key_values=None, inputs_embeds=None, labels=None, @@ -2041,12 +2051,11 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. 
@@ -2109,7 +2118,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, - encoder_head_mask=encoder_head_mask, + cross_head_mask=cross_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, From 84986014c974a49944523ad8aa241bc433b1dfb6 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Fri, 12 Mar 2021 22:04:51 +0100 Subject: [PATCH 6/6] Fix attention head mask naming * `cross_head_mask` -> `cross_attn_head_mask` * `cross_layer_head_mask` -> `cross_attn_layer_head_mask` * Still need to merge #10605 to master to pass the tests --- .../models/prophetnet/modeling_prophetnet.py | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index df43588e3bb29a..59ee1bb2434b14 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -119,7 +119,7 @@ - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - cross_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, @@ -1181,7 +1181,7 @@ def forward( encoder_hidden_states=None, encoder_attn_mask=None, layer_head_mask=None, - cross_layer_head_mask=None, + cross_attn_layer_head_mask=None, extended_predict_attention_mask=None, main_relative_position_buckets=None, predict_relative_position_buckets=None, @@ -1214,7 +1214,7 @@ def forward( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attn_mask, - layer_head_mask=cross_layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -1423,7 +1423,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - cross_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -1438,7 +1438,7 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - cross_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the cross-attention modules. 
Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, @@ -1552,8 +1552,8 @@ def forward( all_cross_attns = () if output_attentions and self.config.add_cross_attention else None present_key_values = () if use_cache else None - # check if head_mask/cross_head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip([head_mask, cross_head_mask], ["head_mask", "cross_head_mask"]): + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): if attn_mask is not None: assert attn_mask.size()[0] == ( len(self.layers) @@ -1590,7 +1590,7 @@ def custom_forward(*inputs): encoder_hidden_states, extended_encoder_attention_mask, (head_mask[idx] if head_mask is not None else None), - (cross_head_mask[idx] if cross_head_mask is not None else None), + (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), extended_predict_attention_mask, main_relative_position_buckets, predict_relative_position_buckets, @@ -1604,7 +1604,9 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, encoder_attn_mask=extended_encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_layer_head_mask=(cross_head_mask[idx] if cross_head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), extended_predict_attention_mask=extended_predict_attention_mask, main_relative_position_buckets=main_relative_position_buckets, predict_relative_position_buckets=predict_relative_position_buckets, @@ -1782,7 +1784,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, - cross_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Tuple] = None, past_key_values=None, inputs_embeds=None, @@ -1835,7 +1837,7 @@ def forward( encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, - cross_head_mask=cross_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, output_attentions=output_attentions, @@ -1895,7 +1897,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, - cross_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1941,7 +1943,7 @@ def forward( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, - cross_head_mask=cross_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -2108,7 +2110,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - cross_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, labels=None, @@ -2124,7 +2126,7 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. 
Mask values selected in ``[0, 1]``: - cross_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, @@ -2191,7 +2193,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, - cross_head_mask=cross_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache,
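
Note: the snippet below is a minimal sketch of how the arguments introduced by this series could be exercised once the patches are applied; it is not part of the diffs above. The checkpoint name and the `ProphetNetConfig` attribute names used to size the masks (`num_encoder_layers`, `num_encoder_attention_heads`, `num_decoder_layers`, `num_decoder_attention_heads`) are assumptions for illustration; only `head_mask`, `decoder_head_mask`, and `cross_attn_head_mask` (the final naming from PATCH 6/6) come from the series itself.

```python
import torch
from transformers import ProphetNetForConditionalGeneration, ProphetNetTokenizer

# Illustrative checkpoint; any ProphetNet seq2seq checkpoint should work the same way.
name = "microsoft/prophetnet-large-uncased"
tokenizer = ProphetNetTokenizer.from_pretrained(name)
model = ProphetNetForConditionalGeneration.from_pretrained(name)

inputs = tokenizer("ProphetNet predicts future n-grams.", return_tensors="pt")
decoder_inputs = tokenizer("It predicts", return_tensors="pt")

cfg = model.config
# One row per layer, one column per head; 1.0 keeps a head, 0.0 nullifies it.
head_mask = torch.ones(cfg.num_encoder_layers, cfg.num_encoder_attention_heads)
head_mask[0, 0] = 0.0  # e.g. silence the first head of the first encoder layer
decoder_head_mask = torch.ones(cfg.num_decoder_layers, cfg.num_decoder_attention_heads)
cross_attn_head_mask = decoder_head_mask.clone()

outputs = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    decoder_input_ids=decoder_inputs.input_ids,
    head_mask=head_mask,                        # encoder self-attention heads
    decoder_head_mask=decoder_head_mask,        # decoder self-attention heads
    cross_attn_head_mask=cross_attn_head_mask,  # decoder cross-attention heads
    output_attentions=True,
)

# Inside each attention module, the per-layer slice of the mask is broadcast over the
# attention weights, i.e. layer_head_mask.view(1, -1, 1, 1) * attn_weights reshaped to
# (batch_size, num_attn_heads, tgt_len, src_len), so a zeroed head contributes nothing.
```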