From 11cb9088a20f01179507031d7037a7c356d6a499 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 12 Aug 2020 12:01:20 +0000 Subject: [PATCH 1/9] add tf graph compile tests --- tests/test_modeling_tf_common.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 478f8353eba1f5..3ba4b946602979 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -134,6 +134,19 @@ def test_save_load(self): self.assert_outputs_same(after_outputs, outputs) + def test_graph_mode(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config) + inputs_dict = self._prepare_for_class(inputs_dict, model_class) + + @tf.function + def run_in_graph_mode(): + return model(inputs_dict) + + outputs = run_in_graph_mode() + self.assertIsNotNone(outputs) + @slow def test_saved_model_with_hidden_states_output(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From c8713d45cb6fbd30a94ad49c36d74318a95ac7c3 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Fri, 14 Aug 2020 09:52:06 +0200 Subject: [PATCH 2/9] fix conflict --- src/transformers/modeling_tf_longformer.py | 124 +++++++++++++-------- tests/test_modeling_tf_longformer.py | 35 ++++-- 2 files changed, 103 insertions(+), 56 deletions(-) diff --git a/src/transformers/modeling_tf_longformer.py b/src/transformers/modeling_tf_longformer.py index 52768b5dd97514..ff3d4d80d79ffe 100644 --- a/src/transformers/modeling_tf_longformer.py +++ b/src/transformers/modeling_tf_longformer.py @@ -148,23 +148,21 @@ def call( """ # retrieve input args - hidden_states, attention_mask, output_attentions = inputs - - attention_mask = tf.squeeze(tf.squeeze(attention_mask, axis=2), axis=1) - # is index masked or global attention - - is_index_masked = tf.math.less(attention_mask, 0) - is_index_global_attn = tf.math.greater(attention_mask, 0) - is_global_attn = tf.math.reduce_any(is_index_global_attn) - - hidden_states = tf.transpose(hidden_states, (1, 0, 2)) + ( + hidden_states, + attention_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + output_attentions, + ) = inputs # project hidden states query_vectors = self.query(hidden_states) key_vectors = self.key(hidden_states) value_vectors = self.value(hidden_states) - seq_len, batch_size, embed_dim = shape_list(hidden_states) + batch_size, seq_len, embed_dim = shape_list(hidden_states) tf.debugging.assert_equal( embed_dim, self.embed_dim, @@ -174,12 +172,8 @@ def call( # normalize query query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32)) - query_vectors = tf.transpose( - tf.reshape(query_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3) - ) - key_vectors = tf.transpose( - tf.reshape(key_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3) - ) + query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) + key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) # attn_probs = (batch_size, seq_len, num_heads, window*2+1) attn_scores = self._sliding_chunks_query_key_matmul( @@ -187,11 +181,11 @@ def call( ) # values to pad for attention probs - float_mask = tf.cast((attention_mask != 0)[:, :, None, None], dtype=tf.float32) * -10000.0 + # float_mask = tf.cast((attention_mask != 0)[:, :, None, None], dtype=tf.float32) * -10000.0 # diagonal mask with zeros everywhere and -inf inplace of padding diagonal_mask = self._sliding_chunks_query_key_matmul( - tf.ones(shape_list(float_mask), dtype=tf.float32), float_mask, self.one_sided_attn_window_size + tf.ones(shape_list(attention_mask), dtype=tf.float32), attention_mask, self.one_sided_attn_window_size ) # pad local attention probs @@ -237,9 +231,10 @@ def call( # apply dropout attn_probs = self.dropout(attn_probs, training=training) - value_vectors = tf.transpose( - tf.reshape(value_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3) - ) + # value_vectors = tf.transpose( + # tf.reshape(value_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3) + # ) + value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) # if global attention, compute sum of global and local attn attn_output = tf.cond( @@ -259,7 +254,7 @@ def call( tf.debugging.assert_equal( shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size" ) - attn_output = tf.reshape(tf.transpose(attn_output, (1, 0, 2, 3)), (seq_len, batch_size, embed_dim)) + attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim)) # compute value for global attention and overwrite to attention output # TODO: remove the redundant computation @@ -278,8 +273,6 @@ def call( lambda: attn_output, ) - attn_output = tf.transpose(attn_output, (1, 0, 2)) - # GLOBAL ATTN: # With global attention, return global attention probabilities only # batch_size x num_heads x max_num_global_attention_tokens x sequence_length @@ -291,6 +284,7 @@ def call( # without global attention, return local attention probabilities # batch_size x num_heads x sequence_length x window_size # which is the attention weights of every token attending to its neighbours + # TODO(PVP) - clean up the tf.transpose statements attn_probs = tf.cond( is_global_attn, lambda: self._get_global_attn_probs(attn_probs, max_num_global_attn_indices), @@ -502,12 +496,10 @@ def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings): hidden_states_padded = tf.pad( hidden_states_padded, paddings ) # padding value is not important because it will be overwritten - if len(shape_list(hidden_states_padded)) > 3: - batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded) - hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length)) - else: - batch_size, seq_length, hidden_dim = shape_list(hidden_states_padded) - hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, hidden_dim, seq_length)) + + batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded) + hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length)) + return hidden_states_padded @staticmethod @@ -690,6 +682,10 @@ def _compute_global_attn_output_from_hidden( is_index_masked, training, ): + # TODO (PVP): clean up all those tf.transpose statements + hidden_states = tf.transpose(hidden_states, (1, 0, 2)) + attn_output = tf.transpose(attn_output, (1, 0, 2)) + seq_len, batch_size = shape_list(hidden_states)[:2] # prepare global hidden states @@ -792,6 +788,7 @@ def _compute_global_attn_output_from_hidden( attn_output = tf.tensor_scatter_nd_update( attn_output, tf.reverse(is_index_global_attn_nonzero, axis=[1]), nonzero_global_attn_output ) + attn_output = tf.transpose(attn_output, (1, 0, 2)) return attn_output @@ -806,10 +803,20 @@ def prune_heads(self, heads): raise NotImplementedError def call(self, inputs, training=False): - input_tensor, attention_mask, output_attentions = inputs + ( + hidden_states, + attention_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + output_attentions, + ) = inputs - self_outputs = self.self_attention([input_tensor, attention_mask, output_attentions], training=training) - attention_output = self.dense_output(self_outputs[0], input_tensor, training=training) + self_outputs = self.self_attention( + [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions], + training=training, + ) + attention_output = self.dense_output(self_outputs[0], hidden_states, training=training) outputs = (attention_output,) + self_outputs[1:] return outputs @@ -823,9 +830,19 @@ def __init__(self, config, layer_id=0, **kwargs): self.longformer_output = TFBertOutput(config, name="output") def call(self, inputs, training=False): - hidden_states, attention_mask, output_attentions = inputs + ( + hidden_states, + attention_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + output_attentions, + ) = inputs - attention_outputs = self.attention([hidden_states, attention_mask, output_attentions], training=training) + attention_outputs = self.attention( + [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions], + training=training, + ) attention_output = attention_outputs[0] intermediate_output = self.intermediate(attention_output) layer_output = self.longformer_output(intermediate_output, attention_output, training=training) @@ -848,12 +865,14 @@ def call( attention_mask=None, head_mask=None, padding_len=0, + is_index_masked=None, + is_index_global_attn=None, + is_global_attn=None, output_attentions=None, output_hidden_states=None, return_dict=None, training=False, ): - all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None for i, layer_module in enumerate(self.layer): @@ -861,7 +880,17 @@ def call( hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states all_hidden_states = all_hidden_states + (hidden_states_to_add,) - layer_outputs = layer_module([hidden_states, attention_mask, output_attentions], training=training) + layer_outputs = layer_module( + [ + hidden_states, + attention_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + output_attentions, + ], + training=training, + ) hidden_states = layer_outputs[0] if output_attentions: @@ -991,27 +1020,32 @@ def call( pad_token_id=self.pad_token_id, ) + # is index masked or global attention + is_index_masked = tf.math.less(attention_mask, 1) + is_index_global_attn = tf.math.greater(attention_mask, 1) + is_global_attn = tf.math.reduce_any(is_index_global_attn) # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] + # Sizes are [batch_size, to_seq_length, 1, 1] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] + extended_attention_mask = attention_mask[:, :, tf.newaxis, tf.newaxis] - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for + # Since attention_mask is 1.0 for positions we want to locall attend locally and 0.0 for + # masked and global attn positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - - extended_attention_mask = tf.cast(extended_attention_mask, tf.float32) - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0 embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, padding_len=padding_len, + is_index_masked=is_index_masked, + is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, diff --git a/tests/test_modeling_tf_longformer.py b/tests/test_modeling_tf_longformer.py index 090a61c84b364f..1282069b031318 100644 --- a/tests/test_modeling_tf_longformer.py +++ b/tests/test_modeling_tf_longformer.py @@ -385,15 +385,16 @@ def test_pad_and_transpose_last_two_dims(self): self.assertTrue(shape_list(hidden_states), [1, 8, 4]) # pad along seq length dim - paddings = tf.constant([[0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32) + paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32) + hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2) padded_hidden_states = TFLongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, paddings) - self.assertTrue(shape_list(padded_hidden_states) == [1, 8, 5]) + self.assertTrue(shape_list(padded_hidden_states) == [1, 1, 8, 5]) expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32) - tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, -1, :], rtol=1e-6) + tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6) tf.debugging.assert_near( - hidden_states[0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6 + hidden_states[0, 0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6 ) def test_mask_invalid_locations(self): @@ -437,10 +438,16 @@ def test_layer_local_attn(self): hidden_states = self._get_hidden_states() batch_size, seq_length, hidden_size = hidden_states.shape - attention_mask = tf.zeros((batch_size, 1, 1, seq_length), dtype=tf.dtypes.float32) - attention_mask = tf.where(tf.range(4)[None, None, None, :] > 1, -10000.0, attention_mask) + attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32) + is_index_global_attn = tf.math.greater(attention_mask, 1) + is_global_attn = tf.math.reduce_any(is_index_global_attn) + + attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None]) + is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0) - output_hidden_states = layer([hidden_states, attention_mask, None])[0] + output_hidden_states = layer( + [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, None] + )[0] expected_slice = tf.convert_to_tensor( [0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.dtypes.float32 @@ -461,12 +468,18 @@ def test_layer_global_attn(self): attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32) attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32) - attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 1, 10000.0, attention_mask_1) - attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 2, -10000.0, attention_mask_1) - attention_mask_2 = tf.where(tf.range(4)[None, None, None, :] > 0, 10000.0, attention_mask_2) + attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 1, 10000.0, attention_mask_1) + attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1) + attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] > 0, 10000.0, attention_mask_2) attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0) - output_hidden_states = layer([hidden_states, attention_mask, None])[0] + is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0) + is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0) + is_global_attn = tf.math.reduce_any(is_index_global_attn) + + output_hidden_states = layer( + [hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn, None] + )[0] self.assertTrue(output_hidden_states.shape, (2, 4, 8)) expected_slice_0 = tf.convert_to_tensor( From 659ee08e1f0e2a2cab577472b56324b2d8b93d71 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 13 Aug 2020 19:58:21 +0200 Subject: [PATCH 3/9] remove more tf transpose statements --- src/transformers/modeling_tf_bert.py | 2 +- src/transformers/modeling_tf_electra.py | 2 +- src/transformers/modeling_tf_longformer.py | 47 +++++++--------------- tests/test_modeling_tf_common.py | 13 +----- 4 files changed, 19 insertions(+), 45 deletions(-) diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index f40414826ad4c5..1417c3b879f322 100644 --- a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -1088,7 +1088,7 @@ def __init__(self, config, *inputs, **kwargs): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-cased", diff --git a/src/transformers/modeling_tf_electra.py b/src/transformers/modeling_tf_electra.py index 05d7469996be01..d13535fb569643 100644 --- a/src/transformers/modeling_tf_electra.py +++ b/src/transformers/modeling_tf_electra.py @@ -677,7 +677,7 @@ def __init__(self, config, *inputs, **kwargs): self.electra = TFElectraMainLayer(config, name="electra") self.classifier = TFElectraClassificationHead(config, name="classifier") - @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) + @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator", diff --git a/src/transformers/modeling_tf_longformer.py b/src/transformers/modeling_tf_longformer.py index ff3d4d80d79ffe..78f3158336dd3e 100644 --- a/src/transformers/modeling_tf_longformer.py +++ b/src/transformers/modeling_tf_longformer.py @@ -231,9 +231,6 @@ def call( # apply dropout attn_probs = self.dropout(attn_probs, training=training) - # value_vectors = tf.transpose( - # tf.reshape(value_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3) - # ) value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) # if global attention, compute sum of global and local attn @@ -682,18 +679,14 @@ def _compute_global_attn_output_from_hidden( is_index_masked, training, ): - # TODO (PVP): clean up all those tf.transpose statements - hidden_states = tf.transpose(hidden_states, (1, 0, 2)) - attn_output = tf.transpose(attn_output, (1, 0, 2)) - - seq_len, batch_size = shape_list(hidden_states)[:2] + batch_size, seq_len = shape_list(hidden_states)[:2] # prepare global hidden states - global_attn_hidden_states = tf.gather_nd(hidden_states, tf.reverse(is_index_global_attn_nonzero, axis=[1])) + global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero) global_attn_hidden_states = tf.scatter_nd( - tf.reverse(is_local_index_global_attn_nonzero, axis=[1]), + is_local_index_global_attn_nonzero, global_attn_hidden_states, - shape=(max_num_global_attn_indices, batch_size, self.embed_dim), + shape=(batch_size, max_num_global_attn_indices, self.embed_dim), ) # global key, query, value @@ -704,27 +697,18 @@ def _compute_global_attn_output_from_hidden( # normalize global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32)) - # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim) - global_query_vectors_only_global = tf.transpose( - tf.reshape( - global_query_vectors_only_global, - (max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim), - ), - (1, 0, 2), - ) - - # (..., batch_size * self.num_heads, seq_len, head_dim) - global_key_vectors = tf.transpose( - tf.reshape(global_key_vectors, (-1, batch_size * self.num_heads, self.head_dim)), (1, 0, 2) - ) + def reshape_and_transpose(vector): + return tf.reshape( + tf.transpose(tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)), (0, 2, 1, 3)), + (batch_size * self.num_heads, -1, self.head_dim), + ) - # (..., batch_size * self.num_heads, seq_len, head_dim) - global_value_vectors = tf.transpose( - tf.reshape(global_value_vectors, (-1, batch_size * self.num_heads, self.head_dim)), (1, 0, 2) - ) + global_query_vectors_only_global = reshape_and_transpose(global_query_vectors_only_global) + global_key_vectors = reshape_and_transpose(global_key_vectors) + global_value_vectors = reshape_and_transpose(global_value_vectors) # compute attn scores - global_attn_scores = tf.matmul(global_query_vectors_only_global, tf.transpose(global_key_vectors, (0, 2, 1))) + global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True) tf.debugging.assert_equal( shape_list(global_attn_scores), @@ -785,11 +769,10 @@ def _compute_global_attn_output_from_hidden( ) # overwrite values with global attention + attn_output = tf.tensor_scatter_nd_update( - attn_output, tf.reverse(is_index_global_attn_nonzero, axis=[1]), nonzero_global_attn_output + attn_output, is_index_global_attn_nonzero, nonzero_global_attn_output ) - attn_output = tf.transpose(attn_output, (1, 0, 2)) - return attn_output diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 3ba4b946602979..298105ffe99bd7 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -110,15 +110,6 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): def test_initialization(self): pass - # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - # configs_no_init = _config_zero_init(config) - # for model_class in self.all_model_classes: - # model = model_class(config=configs_no_init) - # for name, param in model.named_parameters(): - # if param.requires_grad: - # self.assertIn(param.data.mean().item(), [0.0, 1.0], - # msg="Parameter {} of model {} seems not properly initialized".format(name, model_class)) def test_save_load(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -137,12 +128,12 @@ def test_save_load(self): def test_graph_mode(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: + inputs = self._prepare_for_class(inputs_dict, model_class) model = model_class(config) - inputs_dict = self._prepare_for_class(inputs_dict, model_class) @tf.function def run_in_graph_mode(): - return model(inputs_dict) + return model(inputs) outputs = run_in_graph_mode() self.assertIsNotNone(outputs) From 3563c37db04ababdeb89e4621999bf30ee665d39 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Fri, 14 Aug 2020 09:52:53 +0200 Subject: [PATCH 4/9] fix conflicts --- src/transformers/modeling_tf_longformer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_tf_longformer.py b/src/transformers/modeling_tf_longformer.py index 78f3158336dd3e..22f89f417e9df0 100644 --- a/src/transformers/modeling_tf_longformer.py +++ b/src/transformers/modeling_tf_longformer.py @@ -285,7 +285,7 @@ def call( attn_probs = tf.cond( is_global_attn, lambda: self._get_global_attn_probs(attn_probs, max_num_global_attn_indices), - lambda: tf.transpose(attn_probs, (0, 2, 1, 3)), + lambda: attn_probs, ) outputs = (attn_output, attn_probs) @@ -301,7 +301,6 @@ def _get_global_attn_probs(attn_probs, max_num_global_attn_indices): ], axis=-1, ) - attn_probs = tf.transpose(attn_probs, (0, 2, 1, 3)) return attn_probs def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): @@ -877,7 +876,7 @@ def call( hidden_states = layer_outputs[0] if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),) # Add last layer if output_hidden_states: From 6ab38ab890e80e461e15900fb1ada13eb3e26b83 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Fri, 14 Aug 2020 09:58:45 +0200 Subject: [PATCH 5/9] fix comment typos --- src/transformers/modeling_tf_longformer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/transformers/modeling_tf_longformer.py b/src/transformers/modeling_tf_longformer.py index 22f89f417e9df0..1fe98ca7d162e6 100644 --- a/src/transformers/modeling_tf_longformer.py +++ b/src/transformers/modeling_tf_longformer.py @@ -180,9 +180,6 @@ def call( query_vectors, key_vectors, self.one_sided_attn_window_size ) - # values to pad for attention probs - # float_mask = tf.cast((attention_mask != 0)[:, :, None, None], dtype=tf.float32) * -10000.0 - # diagonal mask with zeros everywhere and -inf inplace of padding diagonal_mask = self._sliding_chunks_query_key_matmul( tf.ones(shape_list(attention_mask), dtype=tf.float32), attention_mask, self.one_sided_attn_window_size @@ -281,7 +278,6 @@ def call( # without global attention, return local attention probabilities # batch_size x num_heads x sequence_length x window_size # which is the attention weights of every token attending to its neighbours - # TODO(PVP) - clean up the tf.transpose statements attn_probs = tf.cond( is_global_attn, lambda: self._get_global_attn_probs(attn_probs, max_num_global_attn_indices), From c24d8944f5aef09ab6cc961124f4b1a2ad96e52a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 14 Aug 2020 10:45:02 +0000 Subject: [PATCH 6/9] move function to class function --- src/transformers/modeling_tf_longformer.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/transformers/modeling_tf_longformer.py b/src/transformers/modeling_tf_longformer.py index 1fe98ca7d162e6..630af56fd0cf19 100644 --- a/src/transformers/modeling_tf_longformer.py +++ b/src/transformers/modeling_tf_longformer.py @@ -692,15 +692,9 @@ def _compute_global_attn_output_from_hidden( # normalize global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32)) - def reshape_and_transpose(vector): - return tf.reshape( - tf.transpose(tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)), (0, 2, 1, 3)), - (batch_size * self.num_heads, -1, self.head_dim), - ) - - global_query_vectors_only_global = reshape_and_transpose(global_query_vectors_only_global) - global_key_vectors = reshape_and_transpose(global_key_vectors) - global_value_vectors = reshape_and_transpose(global_value_vectors) + global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size) + global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size) + global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size) # compute attn scores global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True) @@ -770,6 +764,13 @@ def reshape_and_transpose(vector): ) return attn_output + def reshape_and_transpose(self, vector, batch_size): + return tf.reshape( + tf.transpose(tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)), (0, 2, 1, 3)), + (batch_size * self.num_heads, -1, self.head_dim), + ) + + class TFLongformerAttention(tf.keras.layers.Layer): def __init__(self, config, layer_id=0, **kwargs): From 15d8841c8e93c6cdc156b41b15b40cd0be976f29 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 26 Aug 2020 18:24:09 +0200 Subject: [PATCH 7/9] fix black --- src/transformers/modeling_tf_longformer.py | 546 ++++++++++++++++----- 1 file changed, 419 insertions(+), 127 deletions(-) diff --git a/src/transformers/modeling_tf_longformer.py b/src/transformers/modeling_tf_longformer.py index 630af56fd0cf19..fc147751f56e34 100644 --- a/src/transformers/modeling_tf_longformer.py +++ b/src/transformers/modeling_tf_longformer.py @@ -17,8 +17,17 @@ import tensorflow as tf from .configuration_longformer import LongformerConfig -from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_bert import TFBertIntermediate, TFBertOutput, TFBertPooler, TFBertSelfOutput +from .file_utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_callable, +) +from .modeling_tf_bert import ( + TFBertIntermediate, + TFBertOutput, + TFBertPooler, + TFBertSelfOutput, +) from .modeling_tf_outputs import ( TFBaseModelOutput, TFBaseModelOutputWithPooling, @@ -53,7 +62,9 @@ ] -def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_sep_token=True): +def _compute_global_attention_mask( + input_ids_shape, sep_token_indices, before_sep_token=True +): """ Computes global attention mask by putting attention on all tokens before `sep_token_id` if `before_sep_token is True` else after @@ -61,13 +72,18 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se """ assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions" - question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1] - question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32) # size: batch_size x 1 + question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[ + :, 0, 1 + ] + question_end_index = tf.cast( + question_end_index[:, None], tf.dtypes.int32 + ) # size: batch_size x 1 # bool attention mask with True in locations of global attention attention_mask = tf.range(input_ids_shape[1]) if before_sep_token is True: attention_mask = tf.cast( - tf.broadcast_to(attention_mask, input_ids_shape) < tf.broadcast_to(question_end_index, input_ids_shape), + tf.broadcast_to(attention_mask, input_ids_shape) + < tf.broadcast_to(question_end_index, input_ids_shape), tf.dtypes.int32, ) else: @@ -97,28 +113,42 @@ def __init__(self, config, layer_id, **kwargs): self.embed_dim = config.hidden_size self.query = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query" + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="query", ) self.key = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key" + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="key", ) self.value = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value" + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="value", ) # separate projection layers for tokens with global attention self.query_global = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query_global" + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="query_global", ) self.key_global = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key_global" + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="key_global", ) self.value_global = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value_global" + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="value_global", ) self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) - self.global_dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) + self.global_dropout = tf.keras.layers.Dropout( + config.attention_probs_dropout_prob + ) self.layer_id = layer_id @@ -170,10 +200,16 @@ def call( ) # normalize query - query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32)) + query_vectors /= tf.math.sqrt( + tf.constant(self.head_dim, dtype=tf.dtypes.float32) + ) - query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) - key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) + query_vectors = tf.reshape( + query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim) + ) + key_vectors = tf.reshape( + key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim) + ) # attn_probs = (batch_size, seq_len, num_heads, window*2+1) attn_scores = self._sliding_chunks_query_key_matmul( @@ -182,7 +218,9 @@ def call( # diagonal mask with zeros everywhere and -inf inplace of padding diagonal_mask = self._sliding_chunks_query_key_matmul( - tf.ones(shape_list(attention_mask), dtype=tf.float32), attention_mask, self.one_sided_attn_window_size + tf.ones(shape_list(attention_mask), dtype=tf.float32), + attention_mask, + self.one_sided_attn_window_size, ) # pad local attention probs @@ -190,7 +228,12 @@ def call( tf.debugging.assert_equal( shape_list(attn_scores), - [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1], + [ + batch_size, + seq_len, + self.num_heads, + self.one_sided_attn_window_size * 2 + 1, + ], message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}", ) @@ -222,13 +265,17 @@ def call( # softmax sometimes inserts NaN if all positions are masked, replace them with 0 attn_probs = tf.where( - tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), 0.0, attn_probs + tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), + 0.0, + attn_probs, ) # apply dropout attn_probs = self.dropout(attn_probs, training=training) - value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) + value_vectors = tf.reshape( + value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim) + ) # if global attention, compute sum of global and local attn attn_output = tf.cond( @@ -246,7 +293,9 @@ def call( ) tf.debugging.assert_equal( - shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size" + shape_list(attn_output), + [batch_size, seq_len, self.num_heads, self.head_dim], + message="Unexpected size", ) attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim)) @@ -280,7 +329,9 @@ def call( # which is the attention weights of every token attending to its neighbours attn_probs = tf.cond( is_global_attn, - lambda: self._get_global_attn_probs(attn_probs, max_num_global_attn_indices), + lambda: self._get_global_attn_probs( + attn_probs, max_num_global_attn_indices + ), lambda: attn_probs, ) @@ -318,8 +369,13 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): chunks_count = seq_len // window_overlap - 1 # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 - query = tf.reshape(tf.transpose(query, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim)) - key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim)) + query = tf.reshape( + tf.transpose(query, (0, 2, 1, 3)), + (batch_size * num_heads, seq_len, head_dim), + ) + key = tf.reshape( + tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim) + ) chunked_query = self._chunk(query, window_overlap) chunked_key = self._chunk(key, window_overlap) @@ -328,11 +384,15 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim # bcxy: batch_size * num_heads x chunks x 2window_overlap x window_overlap - chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key) # multiply + chunked_attention_scores = tf.einsum( + "bcxd,bcyd->bcxy", chunked_query, chunked_key + ) # multiply # convert diagonals into columns paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32) - diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings) + diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims( + chunked_attention_scores, paddings + ) # allocate space for the overall attention matrix where the chunks are combined. The last dimension # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to @@ -344,8 +404,12 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): # TODO: This code is most likely not very efficient and should be improved diagonal_attn_scores_up_triang = tf.concat( [ - diagonal_chunked_attention_scores[:, :, :window_overlap, : window_overlap + 1], - diagonal_chunked_attention_scores[:, -1:, window_overlap:, : window_overlap + 1], + diagonal_chunked_attention_scores[ + :, :, :window_overlap, : window_overlap + 1 + ], + diagonal_chunked_attention_scores[ + :, -1:, window_overlap:, : window_overlap + 1 + ], ], axis=1, ) @@ -354,15 +418,19 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): diagonal_attn_scores_low_triang = tf.concat( [ tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)), - diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :], + diagonal_chunked_attention_scores[ + :, :, -(window_overlap + 1) : -1, window_overlap + 1 : + ], ], axis=1, ) diagonal_attn_scores_first_chunk = tf.concat( [ - tf.roll(diagonal_chunked_attention_scores, shift=[1, window_overlap], axis=[2, 3])[ - :, :, :window_overlap, :window_overlap - ], + tf.roll( + diagonal_chunked_attention_scores, + shift=[1, window_overlap], + axis=[2, 3], + )[:, :, :window_overlap, :window_overlap], tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)), ], axis=1, @@ -371,13 +439,20 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): first_chunk_mask = ( tf.broadcast_to( tf.range(chunks_count + 1)[None, :, None, None], - shape=(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap), + shape=( + batch_size * num_heads, + chunks_count + 1, + window_overlap, + window_overlap, + ), ) < 1 ) diagonal_attn_scores_low_triang = tf.where( - first_chunk_mask, diagonal_attn_scores_first_chunk, diagonal_attn_scores_low_triang + first_chunk_mask, + diagonal_attn_scores_first_chunk, + diagonal_attn_scores_low_triang, ) # merging upper and lower triangle @@ -387,22 +462,33 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): # separate batch_size and num_heads dimensions again diagonal_attention_scores = tf.transpose( - tf.reshape(diagonal_attention_scores, (batch_size, num_heads, seq_len, 2 * window_overlap + 1)), + tf.reshape( + diagonal_attention_scores, + (batch_size, num_heads, seq_len, 2 * window_overlap + 1), + ), (0, 2, 1, 3), ) - diagonal_attention_scores = self._mask_invalid_locations(diagonal_attention_scores, window_overlap) + diagonal_attention_scores = self._mask_invalid_locations( + diagonal_attention_scores, window_overlap + ) return diagonal_attention_scores @staticmethod def _mask_invalid_locations(input_tensor, window_overlap): # create correct upper triangle bool mask mask_2d_upper = tf.reverse( - tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), axis=[0] + tf.linalg.band_part( + tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0 + ), + axis=[0], ) # pad to full matrix padding = tf.constant( - [[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]] + [ + [0, shape_list(input_tensor)[1] - window_overlap], + [0, shape_list(input_tensor)[3] - window_overlap - 1], + ] ) # create lower mask @@ -421,7 +507,9 @@ def _mask_invalid_locations(input_tensor, window_overlap): return input_tensor - def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_overlap): + def _sliding_chunks_matmul_attn_probs_value( + self, attn_probs, value, window_overlap + ): """Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. Returned tensor will be of the same shape as `attn_probs`""" @@ -429,7 +517,9 @@ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_over batch_size, seq_len, num_heads, head_dim = shape_list(value) tf.debugging.assert_equal( - seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap" + seq_len % (window_overlap * 2), + 0, + message="Seq_len has to be multiple of 2 * window_overlap", ) tf.debugging.assert_equal( shape_list(attn_probs)[:3], @@ -447,27 +537,42 @@ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_over chunked_attn_probs = tf.reshape( tf.transpose(attn_probs, (0, 2, 1, 3)), - (batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1), + ( + batch_size * num_heads, + seq_len // window_overlap, + window_overlap, + 2 * window_overlap + 1, + ), ) # group batch_size and num_heads dimensions into one - value = tf.reshape(tf.transpose(value, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim)) + value = tf.reshape( + tf.transpose(value, (0, 2, 1, 3)), + (batch_size * num_heads, seq_len, head_dim), + ) # pad seq_len with w at the beginning of the sequence and another window overlap at the end - paddings = tf.constant([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32) + paddings = tf.constant( + [[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32 + ) padded_value = tf.pad(value, paddings, constant_values=-1) # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap frame_size = 3 * window_overlap * head_dim - frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count + frame_hop_size = ( + shape_list(padded_value)[1] * head_dim - frame_size + ) // chunks_count chunked_value = tf.signal.frame( - tf.reshape(padded_value, (batch_size * num_heads, -1)), frame_size, frame_hop_size + tf.reshape(padded_value, (batch_size * num_heads, -1)), + frame_size, + frame_hop_size, ) chunked_value = tf.reshape( - chunked_value, (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim) + chunked_value, + (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim), ) tf.debugging.assert_equal( @@ -479,7 +584,10 @@ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_over chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value) - context = tf.transpose(tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), (0, 2, 1, 3)) + context = tf.transpose( + tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), + (0, 2, 1, 3), + ) return context @staticmethod @@ -489,8 +597,12 @@ def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings): hidden_states_padded, paddings ) # padding value is not important because it will be overwritten - batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded) - hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length)) + batch_size, chunk_size, seq_length, hidden_dim = shape_list( + hidden_states_padded + ) + hidden_states_padded = tf.reshape( + hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length) + ) return hidden_states_padded @@ -509,7 +621,9 @@ def _pad_and_diagonalize(chunked_hidden_states): 0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ] """ - total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states) + total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list( + chunked_hidden_states + ) paddings = tf.constant([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]]) chunked_hidden_states = tf.pad( @@ -523,7 +637,8 @@ def _pad_and_diagonalize(chunked_hidden_states): :, :, :-window_overlap ] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap chunked_hidden_states = tf.reshape( - chunked_hidden_states, (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim) + chunked_hidden_states, + (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim), ) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] return chunked_hidden_states @@ -541,7 +656,9 @@ def _chunk(hidden_states, window_overlap): hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim)) # chunk with overlap - chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size) + chunked_hidden_states = tf.signal.frame( + hidden_states, frame_size, frame_hop_size + ) tf.debugging.assert_equal( shape_list(chunked_hidden_states), @@ -550,7 +667,8 @@ def _chunk(hidden_states, window_overlap): ) chunked_hidden_states = tf.reshape( - chunked_hidden_states, (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim) + chunked_hidden_states, + (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim), ) return chunked_hidden_states @@ -559,7 +677,9 @@ def _chunk(hidden_states, window_overlap): def _get_global_attn_indices(is_index_global_attn): """ compute global attn indices required throughout forward pass """ # helper variable - num_global_attn_indices = tf.reduce_sum(tf.cast(is_index_global_attn, dtype=tf.dtypes.int32), axis=1) + num_global_attn_indices = tf.reduce_sum( + tf.cast(is_index_global_attn, dtype=tf.dtypes.int32), axis=1 + ) # max number of global attn indices in batch max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices) @@ -568,15 +688,17 @@ def _get_global_attn_indices(is_index_global_attn): is_index_global_attn_nonzero = tf.where(is_index_global_attn) # helper variable - is_local_index_global_attn = tf.range(max_num_global_attn_indices) < tf.expand_dims( - num_global_attn_indices, axis=-1 - ) + is_local_index_global_attn = tf.range( + max_num_global_attn_indices + ) < tf.expand_dims(num_global_attn_indices, axis=-1) # location of the non-padding values within global attention indices is_local_index_global_attn_nonzero = tf.where(is_local_index_global_attn) # location of the padding values within global attention indices - is_local_index_no_global_attn_nonzero = tf.where(tf.math.logical_not(is_local_index_global_attn)) + is_local_index_no_global_attn_nonzero = tf.where( + tf.math.logical_not(is_local_index_global_attn) + ) return ( max_num_global_attn_indices, @@ -603,13 +725,22 @@ def _concat_with_global_key_attn_probs( key_vectors_only_global = tf.scatter_nd( is_local_index_global_attn_nonzero, global_key_vectors, - shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim), + shape=( + batch_size, + max_num_global_attn_indices, + self.num_heads, + self.head_dim, + ), ) # (batch_size, seq_len, num_heads, max_num_global_attn_indices) - attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global) + attn_probs_from_global_key = tf.einsum( + "blhd,bshd->blhs", query_vectors, key_vectors_only_global + ) # (batch_size, max_num_global_attn_indices, seq_len, num_heads) - attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2)) + attn_probs_from_global_key_trans = tf.transpose( + attn_probs_from_global_key, (0, 3, 1, 2) + ) mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple( shape_list(attn_probs_from_global_key_trans)[-2:] ) @@ -617,11 +748,15 @@ def _concat_with_global_key_attn_probs( # scatter mask attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update( - attn_probs_from_global_key_trans, is_local_index_no_global_attn_nonzero, mask + attn_probs_from_global_key_trans, + is_local_index_no_global_attn_nonzero, + mask, ) # (batch_size, seq_len, num_heads, max_num_global_attn_indices) - attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans, (0, 2, 3, 1)) + attn_probs_from_global_key = tf.transpose( + attn_probs_from_global_key_trans, (0, 2, 3, 1) + ) # concat to attn_probs # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) @@ -648,11 +783,18 @@ def _compute_attn_output_with_global_indices( value_vectors_only_global = tf.scatter_nd( is_local_index_global_attn_nonzero, global_value_vectors, - shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim), + shape=( + batch_size, + max_num_global_attn_indices, + self.num_heads, + self.head_dim, + ), ) # compute attn output only global - attn_output_only_global = tf.einsum("blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global) + attn_output_only_global = tf.einsum( + "blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global + ) # reshape attn probs attn_probs_without_global = attn_probs[:, :, :, max_num_global_attn_indices:] @@ -677,7 +819,9 @@ def _compute_global_attn_output_from_hidden( batch_size, seq_len = shape_list(hidden_states)[:2] # prepare global hidden states - global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero) + global_attn_hidden_states = tf.gather_nd( + hidden_states, is_index_global_attn_nonzero + ) global_attn_hidden_states = tf.scatter_nd( is_local_index_global_attn_nonzero, global_attn_hidden_states, @@ -690,14 +834,22 @@ def _compute_global_attn_output_from_hidden( global_value_vectors = self.value_global(hidden_states) # normalize - global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32)) + global_query_vectors_only_global /= tf.math.sqrt( + tf.constant(self.head_dim, dtype=tf.dtypes.float32) + ) - global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size) + global_query_vectors_only_global = self.reshape_and_transpose( + global_query_vectors_only_global, batch_size + ) global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size) - global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size) + global_value_vectors = self.reshape_and_transpose( + global_value_vectors, batch_size + ) # compute attn scores - global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True) + global_attn_scores = tf.matmul( + global_query_vectors_only_global, global_key_vectors, transpose_b=True + ) tf.debugging.assert_equal( shape_list(global_attn_scores), @@ -706,7 +858,8 @@ def _compute_global_attn_output_from_hidden( ) global_attn_scores = tf.reshape( - global_attn_scores, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + global_attn_scores, + (batch_size, self.num_heads, max_num_global_attn_indices, seq_len), ) global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3)) @@ -717,23 +870,30 @@ def _compute_global_attn_output_from_hidden( # scatter mask global_attn_scores_trans = tf.tensor_scatter_nd_update( - global_attn_scores_trans, is_local_index_no_global_attn_nonzero, global_attn_mask + global_attn_scores_trans, + is_local_index_no_global_attn_nonzero, + global_attn_mask, ) global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3)) # mask global attn scores - attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores)) + attn_mask = tf.broadcast_to( + is_index_masked[:, None, None, :], shape_list(global_attn_scores) + ) global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores) global_attn_scores = tf.reshape( - global_attn_scores, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len) + global_attn_scores, + (batch_size * self.num_heads, max_num_global_attn_indices, seq_len), ) # compute global attn probs global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1) # dropout - global_attn_probs = self.global_dropout(global_attn_probs_float, training=training) + global_attn_probs = self.global_dropout( + global_attn_probs_float, training=training + ) # global attn output global_attn_output = tf.matmul(global_attn_probs, global_value_vectors) @@ -745,12 +905,14 @@ def _compute_global_attn_output_from_hidden( ) global_attn_output = tf.reshape( - global_attn_output, (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim) + global_attn_output, + (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim), ) # get only non zero global attn output nonzero_global_attn_output = tf.gather_nd( - tf.transpose(global_attn_output, (0, 2, 1, 3)), is_local_index_global_attn_nonzero + tf.transpose(global_attn_output, (0, 2, 1, 3)), + is_local_index_global_attn_nonzero, ) nonzero_global_attn_output = tf.reshape( nonzero_global_attn_output, @@ -766,12 +928,14 @@ def _compute_global_attn_output_from_hidden( def reshape_and_transpose(self, vector, batch_size): return tf.reshape( - tf.transpose(tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)), (0, 2, 1, 3)), + tf.transpose( + tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)), + (0, 2, 1, 3), + ), (batch_size * self.num_heads, -1, self.head_dim), ) - class TFLongformerAttention(tf.keras.layers.Layer): def __init__(self, config, layer_id=0, **kwargs): super().__init__(**kwargs) @@ -792,10 +956,19 @@ def call(self, inputs, training=False): ) = inputs self_outputs = self.self_attention( - [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions], + [ + hidden_states, + attention_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + output_attentions, + ], training=training, ) - attention_output = self.dense_output(self_outputs[0], hidden_states, training=training) + attention_output = self.dense_output( + self_outputs[0], hidden_states, training=training + ) outputs = (attention_output,) + self_outputs[1:] return outputs @@ -819,13 +992,24 @@ def call(self, inputs, training=False): ) = inputs attention_outputs = self.attention( - [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions], + [ + hidden_states, + attention_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + output_attentions, + ], training=training, ) attention_output = attention_outputs[0] intermediate_output = self.intermediate(attention_output) - layer_output = self.longformer_output(intermediate_output, attention_output, training=training) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + layer_output = self.longformer_output( + intermediate_output, attention_output, training=training + ) + outputs = (layer_output,) + attention_outputs[ + 1: + ] # add attentions if we output them return outputs @@ -835,7 +1019,8 @@ def __init__(self, config, **kwargs): self.output_hidden_states = config.output_hidden_states self.output_attentions = config.output_attentions self.layer = [ - TFLongformerLayer(config, i, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers) + TFLongformerLayer(config, i, name="layer_._{}".format(i)) + for i in range(config.num_hidden_layers) ] def call( @@ -856,7 +1041,11 @@ def call( all_attentions = () if output_attentions else None for i, layer_module in enumerate(self.layer): if output_hidden_states: - hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states + hidden_states_to_add = ( + hidden_states[:, :-padding_len] + if padding_len > 0 + else hidden_states + ) all_hidden_states = all_hidden_states + (hidden_states_to_add,) layer_outputs = layer_module( @@ -873,17 +1062,27 @@ def call( hidden_states = layer_outputs[0] if output_attentions: - all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),) + all_attentions = all_attentions + ( + tf.transpose(layer_outputs[1], (0, 2, 1, 3)), + ) # Add last layer if output_hidden_states: - hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states + hidden_states_to_add = ( + hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states + ) all_hidden_states = all_hidden_states + (hidden_states_to_add,) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return tuple( + v + for v in [hidden_states, all_hidden_states, all_attentions] + if v is not None + ) return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, ) @@ -895,9 +1094,15 @@ def __init__(self, config, **kwargs): super().__init__(**kwargs) if isinstance(config.attention_window, int): - assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value" - assert config.attention_window > 0, "`config.attention_window` has to be positive" - config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer + assert ( + config.attention_window % 2 == 0 + ), "`config.attention_window` has to be an even value" + assert ( + config.attention_window > 0 + ), "`config.attention_window` has to be positive" + config.attention_window = [ + config.attention_window + ] * config.num_hidden_layers # one value per layer else: assert len(config.attention_window) == config.num_hidden_layers, ( "`len(config.attention_window)` should equal `config.num_hidden_layers`. " @@ -951,29 +1156,45 @@ def call( position_ids = inputs[4] if len(inputs) > 4 else position_ids inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds output_attentions = inputs[6] if len(inputs) > 6 else output_attentions - output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states + output_hidden_states = ( + inputs[7] if len(inputs) > 7 else output_hidden_states + ) return_dict = inputs[8] if len(inputs) > 8 else return_dict assert len(inputs) <= 9, "Too many inputs." elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) - global_attention_mask = inputs.get("global_attention_mask", global_attention_mask) + global_attention_mask = inputs.get( + "global_attention_mask", global_attention_mask + ) token_type_ids = inputs.get("token_type_ids", token_type_ids) position_ids = inputs.get("position_ids", position_ids) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) output_attentions = inputs.get("output_attentions", output_attentions) - output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) + output_hidden_states = inputs.get( + "output_hidden_states", output_hidden_states + ) return_dict = inputs.get("return_dict", return_dict) assert len(inputs) <= 9, "Too many inputs." else: input_ids = inputs - output_attentions = output_attentions if output_attentions is not None else self.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states + output_attentions = ( + output_attentions + if output_attentions is not None + else self.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.return_dict if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) elif input_ids is not None: input_shape = shape_list(input_ids) elif inputs_embeds is not None: @@ -988,9 +1209,18 @@ def call( # merge `global_attention_mask` and `attention_mask` if global_attention_mask is not None: - attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask) + attention_mask = self._merge_to_attention_mask( + attention_mask, global_attention_mask + ) - padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size( + ( + padding_len, + input_ids, + attention_mask, + token_type_ids, + position_ids, + inputs_embeds, + ) = self._pad_to_window_size( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1015,9 +1245,14 @@ def call( # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0 + extended_attention_mask = ( + tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) + * -10000.0 + ) - embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) + embedding_output = self.embeddings( + input_ids, position_ids, token_type_ids, inputs_embeds, training=training + ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, @@ -1064,11 +1299,19 @@ def _pad_to_window_size( """A helper function to pad tokens and mask to work with implementation of Longformer selfattention.""" # padding attention_window = ( - self.attention_window if isinstance(self.attention_window, int) else max(self.attention_window) + self.attention_window + if isinstance(self.attention_window, int) + else max(self.attention_window) ) - assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}" - input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds) + assert ( + attention_window % 2 == 0 + ), f"`attention_window` should be an even value. Given {attention_window}" + input_shape = ( + shape_list(input_ids) + if input_ids is not None + else shape_list(inputs_embeds) + ) batch_size, seq_len = input_shape[:2] padding_len = (attention_window - seq_len % attention_window) % attention_window @@ -1083,21 +1326,38 @@ def _pad_to_window_size( input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id) if position_ids is not None: # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings - position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id) + position_ids = tf.pad( + position_ids, paddings, constant_values=pad_token_id + ) if inputs_embeds is not None: - input_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id) + input_ids_padding = tf.fill( + (batch_size, padding_len), self.pad_token_id + ) inputs_embeds_padding = self.embeddings(input_ids_padding) - inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) + inputs_embeds = tf.concat( + [inputs_embeds, inputs_embeds_padding], axis=-2 + ) attention_mask = tf.pad( attention_mask, paddings, constant_values=False ) # no attention on the padding tokens - token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0 + token_type_ids = tf.pad( + token_type_ids, paddings, constant_values=0 + ) # pad with token_type_id = 0 - return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds + return ( + padding_len, + input_ids, + attention_mask, + token_type_ids, + position_ids, + inputs_embeds, + ) @staticmethod - def _merge_to_attention_mask(attention_mask: tf.Tensor, global_attention_mask: tf.Tensor): + def _merge_to_attention_mask( + attention_mask: tf.Tensor, global_attention_mask: tf.Tensor + ): # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) # (global_attention_mask + 1) => 1 for local attention, 2 for global attention # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention @@ -1122,8 +1382,12 @@ class TFLongformerPreTrainedModel(TFPreTrainedModel): def dummy_inputs(self): input_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]) # make sure global layers are initialized - attention_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]) - global_attention_mask = tf.constant([[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]]) + attention_mask = tf.constant( + [[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]] + ) + global_attention_mask = tf.constant( + [[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]] + ) return { "input_ids": input_ids, "attention_mask": attention_mask, @@ -1238,24 +1502,35 @@ def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.longformer = TFLongformerMainLayer(config, name="longformer") - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_start_docstrings_to_callable( + LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)") + ) def call(self, inputs, **kwargs): outputs = self.longformer(inputs, **kwargs) return outputs -@add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING) -class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss): +@add_start_docstrings( + """Longformer Model with a `language modeling` head on top. """, + LONGFORMER_START_DOCSTRING, +) +class TFLongformerForMaskedLM( + TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss +): def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.longformer = TFLongformerMainLayer(config, name="longformer") - self.lm_head = TFRobertaLMHead(config, self.longformer.embeddings, name="lm_head") + self.lm_head = TFRobertaLMHead( + config, self.longformer.embeddings, name="lm_head" + ) def get_output_embeddings(self): return self.lm_head.decoder - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_start_docstrings_to_callable( + LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)") + ) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="allenai/longformer-base-4096", @@ -1283,7 +1558,9 @@ def call( Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` """ - return_dict = return_dict if return_dict is not None else self.longformer.return_dict + return_dict = ( + return_dict if return_dict is not None else self.longformer.return_dict + ) if isinstance(inputs, (tuple, list)): labels = inputs[9] if len(inputs) > 9 else labels if len(inputs) > 9: @@ -1326,17 +1603,23 @@ def call( the hidden-states output to compute `span start logits` and `span end logits`). """, LONGFORMER_START_DOCSTRING, ) -class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss): +class TFLongformerForQuestionAnswering( + TFLongformerPreTrainedModel, TFQuestionAnsweringLoss +): def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.num_labels = config.num_labels self.longformer = TFLongformerMainLayer(config, name="longformer") self.qa_outputs = tf.keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + name="qa_outputs", ) - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_start_docstrings_to_callable( + LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)") + ) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="allenai/longformer-large-4096-finetuned-triviaqa", @@ -1368,7 +1651,9 @@ def call( Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - return_dict = return_dict if return_dict is not None else self.longformer.return_dict + return_dict = ( + return_dict if return_dict is not None else self.longformer.return_dict + ) if isinstance(inputs, (tuple, list)): input_ids = inputs[0] global_attention_mask = inputs[2] @@ -1378,7 +1663,9 @@ def call( inputs = inputs[:9] elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids", inputs) - global_attention_mask = inputs.get("global_attention_mask", global_attention_mask) + global_attention_mask = inputs.get( + "global_attention_mask", global_attention_mask + ) start_positions = inputs.pop("start_positions", start_positions) end_positions = inputs.pop("end_positions", start_positions) else: @@ -1390,7 +1677,10 @@ def call( logger.warning( "It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set." ) - elif tf.where(input_ids == self.config.sep_token_id).shape[0] != 3 * input_ids.shape[0]: + elif ( + tf.where(input_ids == self.config.sep_token_id).shape[0] + != 3 * input_ids.shape[0] + ): logger.warning( f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error." ) @@ -1398,7 +1688,9 @@ def call( logger.info("Initializing global attention on question tokens...") # put global attention on all tokens until `config.sep_token_id` is reached sep_token_indices = tf.where(input_ids == self.config.sep_token_id) - global_attention_mask = _compute_global_attention_mask(shape_list(input_ids), sep_token_indices) + global_attention_mask = _compute_global_attention_mask( + shape_list(input_ids), sep_token_indices + ) outputs = self.longformer( inputs, From efddf0cec692f096db537a42d58b1f9a3f92d51f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 26 Aug 2020 18:24:51 +0200 Subject: [PATCH 8/9] fix black --- src/transformers/modeling_tf_longformer.py | 541 +++++---------------- 1 file changed, 127 insertions(+), 414 deletions(-) diff --git a/src/transformers/modeling_tf_longformer.py b/src/transformers/modeling_tf_longformer.py index fc147751f56e34..92be8cf6afbb92 100644 --- a/src/transformers/modeling_tf_longformer.py +++ b/src/transformers/modeling_tf_longformer.py @@ -17,17 +17,8 @@ import tensorflow as tf from .configuration_longformer import LongformerConfig -from .file_utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_callable, -) -from .modeling_tf_bert import ( - TFBertIntermediate, - TFBertOutput, - TFBertPooler, - TFBertSelfOutput, -) +from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable +from .modeling_tf_bert import TFBertIntermediate, TFBertOutput, TFBertPooler, TFBertSelfOutput from .modeling_tf_outputs import ( TFBaseModelOutput, TFBaseModelOutputWithPooling, @@ -62,9 +53,7 @@ ] -def _compute_global_attention_mask( - input_ids_shape, sep_token_indices, before_sep_token=True -): +def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_sep_token=True): """ Computes global attention mask by putting attention on all tokens before `sep_token_id` if `before_sep_token is True` else after @@ -72,18 +61,13 @@ def _compute_global_attention_mask( """ assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions" - question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[ - :, 0, 1 - ] - question_end_index = tf.cast( - question_end_index[:, None], tf.dtypes.int32 - ) # size: batch_size x 1 + question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1] + question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32) # size: batch_size x 1 # bool attention mask with True in locations of global attention attention_mask = tf.range(input_ids_shape[1]) if before_sep_token is True: attention_mask = tf.cast( - tf.broadcast_to(attention_mask, input_ids_shape) - < tf.broadcast_to(question_end_index, input_ids_shape), + tf.broadcast_to(attention_mask, input_ids_shape) < tf.broadcast_to(question_end_index, input_ids_shape), tf.dtypes.int32, ) else: @@ -113,42 +97,28 @@ def __init__(self, config, layer_id, **kwargs): self.embed_dim = config.hidden_size self.query = tf.keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="query", + self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query", ) self.key = tf.keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="key", + self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key", ) self.value = tf.keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="value", + self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value", ) # separate projection layers for tokens with global attention self.query_global = tf.keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="query_global", + self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query_global", ) self.key_global = tf.keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="key_global", + self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key_global", ) self.value_global = tf.keras.layers.Dense( - self.embed_dim, - kernel_initializer=get_initializer(config.initializer_range), - name="value_global", + self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value_global", ) self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) - self.global_dropout = tf.keras.layers.Dropout( - config.attention_probs_dropout_prob - ) + self.global_dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) self.layer_id = layer_id @@ -200,16 +170,10 @@ def call( ) # normalize query - query_vectors /= tf.math.sqrt( - tf.constant(self.head_dim, dtype=tf.dtypes.float32) - ) + query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32)) - query_vectors = tf.reshape( - query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim) - ) - key_vectors = tf.reshape( - key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim) - ) + query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) + key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) # attn_probs = (batch_size, seq_len, num_heads, window*2+1) attn_scores = self._sliding_chunks_query_key_matmul( @@ -218,9 +182,7 @@ def call( # diagonal mask with zeros everywhere and -inf inplace of padding diagonal_mask = self._sliding_chunks_query_key_matmul( - tf.ones(shape_list(attention_mask), dtype=tf.float32), - attention_mask, - self.one_sided_attn_window_size, + tf.ones(shape_list(attention_mask), dtype=tf.float32), attention_mask, self.one_sided_attn_window_size, ) # pad local attention probs @@ -228,12 +190,7 @@ def call( tf.debugging.assert_equal( shape_list(attn_scores), - [ - batch_size, - seq_len, - self.num_heads, - self.one_sided_attn_window_size * 2 + 1, - ], + [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1], message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}", ) @@ -265,17 +222,13 @@ def call( # softmax sometimes inserts NaN if all positions are masked, replace them with 0 attn_probs = tf.where( - tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), - 0.0, - attn_probs, + tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), 0.0, attn_probs, ) # apply dropout attn_probs = self.dropout(attn_probs, training=training) - value_vectors = tf.reshape( - value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim) - ) + value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) # if global attention, compute sum of global and local attn attn_output = tf.cond( @@ -293,9 +246,7 @@ def call( ) tf.debugging.assert_equal( - shape_list(attn_output), - [batch_size, seq_len, self.num_heads, self.head_dim], - message="Unexpected size", + shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size", ) attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim)) @@ -329,9 +280,7 @@ def call( # which is the attention weights of every token attending to its neighbours attn_probs = tf.cond( is_global_attn, - lambda: self._get_global_attn_probs( - attn_probs, max_num_global_attn_indices - ), + lambda: self._get_global_attn_probs(attn_probs, max_num_global_attn_indices), lambda: attn_probs, ) @@ -369,13 +318,8 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): chunks_count = seq_len // window_overlap - 1 # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 - query = tf.reshape( - tf.transpose(query, (0, 2, 1, 3)), - (batch_size * num_heads, seq_len, head_dim), - ) - key = tf.reshape( - tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim) - ) + query = tf.reshape(tf.transpose(query, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim),) + key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim)) chunked_query = self._chunk(query, window_overlap) chunked_key = self._chunk(key, window_overlap) @@ -384,15 +328,11 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim # bcxy: batch_size * num_heads x chunks x 2window_overlap x window_overlap - chunked_attention_scores = tf.einsum( - "bcxd,bcyd->bcxy", chunked_query, chunked_key - ) # multiply + chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key) # multiply # convert diagonals into columns paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32) - diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims( - chunked_attention_scores, paddings - ) + diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings) # allocate space for the overall attention matrix where the chunks are combined. The last dimension # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to @@ -404,12 +344,8 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): # TODO: This code is most likely not very efficient and should be improved diagonal_attn_scores_up_triang = tf.concat( [ - diagonal_chunked_attention_scores[ - :, :, :window_overlap, : window_overlap + 1 - ], - diagonal_chunked_attention_scores[ - :, -1:, window_overlap:, : window_overlap + 1 - ], + diagonal_chunked_attention_scores[:, :, :window_overlap, : window_overlap + 1], + diagonal_chunked_attention_scores[:, -1:, window_overlap:, : window_overlap + 1], ], axis=1, ) @@ -418,19 +354,15 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): diagonal_attn_scores_low_triang = tf.concat( [ tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)), - diagonal_chunked_attention_scores[ - :, :, -(window_overlap + 1) : -1, window_overlap + 1 : - ], + diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :], ], axis=1, ) diagonal_attn_scores_first_chunk = tf.concat( [ - tf.roll( - diagonal_chunked_attention_scores, - shift=[1, window_overlap], - axis=[2, 3], - )[:, :, :window_overlap, :window_overlap], + tf.roll(diagonal_chunked_attention_scores, shift=[1, window_overlap], axis=[2, 3],)[ + :, :, :window_overlap, :window_overlap + ], tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)), ], axis=1, @@ -439,20 +371,13 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): first_chunk_mask = ( tf.broadcast_to( tf.range(chunks_count + 1)[None, :, None, None], - shape=( - batch_size * num_heads, - chunks_count + 1, - window_overlap, - window_overlap, - ), + shape=(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap,), ) < 1 ) diagonal_attn_scores_low_triang = tf.where( - first_chunk_mask, - diagonal_attn_scores_first_chunk, - diagonal_attn_scores_low_triang, + first_chunk_mask, diagonal_attn_scores_first_chunk, diagonal_attn_scores_low_triang, ) # merging upper and lower triangle @@ -462,33 +387,22 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): # separate batch_size and num_heads dimensions again diagonal_attention_scores = tf.transpose( - tf.reshape( - diagonal_attention_scores, - (batch_size, num_heads, seq_len, 2 * window_overlap + 1), - ), + tf.reshape(diagonal_attention_scores, (batch_size, num_heads, seq_len, 2 * window_overlap + 1),), (0, 2, 1, 3), ) - diagonal_attention_scores = self._mask_invalid_locations( - diagonal_attention_scores, window_overlap - ) + diagonal_attention_scores = self._mask_invalid_locations(diagonal_attention_scores, window_overlap) return diagonal_attention_scores @staticmethod def _mask_invalid_locations(input_tensor, window_overlap): # create correct upper triangle bool mask mask_2d_upper = tf.reverse( - tf.linalg.band_part( - tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0 - ), - axis=[0], + tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), axis=[0], ) # pad to full matrix padding = tf.constant( - [ - [0, shape_list(input_tensor)[1] - window_overlap], - [0, shape_list(input_tensor)[3] - window_overlap - 1], - ] + [[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]] ) # create lower mask @@ -507,9 +421,7 @@ def _mask_invalid_locations(input_tensor, window_overlap): return input_tensor - def _sliding_chunks_matmul_attn_probs_value( - self, attn_probs, value, window_overlap - ): + def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_overlap): """Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. Returned tensor will be of the same shape as `attn_probs`""" @@ -517,9 +429,7 @@ def _sliding_chunks_matmul_attn_probs_value( batch_size, seq_len, num_heads, head_dim = shape_list(value) tf.debugging.assert_equal( - seq_len % (window_overlap * 2), - 0, - message="Seq_len has to be multiple of 2 * window_overlap", + seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap", ) tf.debugging.assert_equal( shape_list(attn_probs)[:3], @@ -537,42 +447,27 @@ def _sliding_chunks_matmul_attn_probs_value( chunked_attn_probs = tf.reshape( tf.transpose(attn_probs, (0, 2, 1, 3)), - ( - batch_size * num_heads, - seq_len // window_overlap, - window_overlap, - 2 * window_overlap + 1, - ), + (batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1,), ) # group batch_size and num_heads dimensions into one - value = tf.reshape( - tf.transpose(value, (0, 2, 1, 3)), - (batch_size * num_heads, seq_len, head_dim), - ) + value = tf.reshape(tf.transpose(value, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim),) # pad seq_len with w at the beginning of the sequence and another window overlap at the end - paddings = tf.constant( - [[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32 - ) + paddings = tf.constant([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32) padded_value = tf.pad(value, paddings, constant_values=-1) # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap frame_size = 3 * window_overlap * head_dim - frame_hop_size = ( - shape_list(padded_value)[1] * head_dim - frame_size - ) // chunks_count + frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count chunked_value = tf.signal.frame( - tf.reshape(padded_value, (batch_size * num_heads, -1)), - frame_size, - frame_hop_size, + tf.reshape(padded_value, (batch_size * num_heads, -1)), frame_size, frame_hop_size, ) chunked_value = tf.reshape( - chunked_value, - (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim), + chunked_value, (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim), ) tf.debugging.assert_equal( @@ -584,10 +479,7 @@ def _sliding_chunks_matmul_attn_probs_value( chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value) - context = tf.transpose( - tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), - (0, 2, 1, 3), - ) + context = tf.transpose(tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), (0, 2, 1, 3),) return context @staticmethod @@ -597,12 +489,8 @@ def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings): hidden_states_padded, paddings ) # padding value is not important because it will be overwritten - batch_size, chunk_size, seq_length, hidden_dim = shape_list( - hidden_states_padded - ) - hidden_states_padded = tf.reshape( - hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length) - ) + batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded) + hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length)) return hidden_states_padded @@ -621,9 +509,7 @@ def _pad_and_diagonalize(chunked_hidden_states): 0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ] """ - total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list( - chunked_hidden_states - ) + total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states) paddings = tf.constant([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]]) chunked_hidden_states = tf.pad( @@ -637,8 +523,7 @@ def _pad_and_diagonalize(chunked_hidden_states): :, :, :-window_overlap ] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap chunked_hidden_states = tf.reshape( - chunked_hidden_states, - (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim), + chunked_hidden_states, (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim), ) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] return chunked_hidden_states @@ -656,9 +541,7 @@ def _chunk(hidden_states, window_overlap): hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim)) # chunk with overlap - chunked_hidden_states = tf.signal.frame( - hidden_states, frame_size, frame_hop_size - ) + chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size) tf.debugging.assert_equal( shape_list(chunked_hidden_states), @@ -667,8 +550,7 @@ def _chunk(hidden_states, window_overlap): ) chunked_hidden_states = tf.reshape( - chunked_hidden_states, - (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim), + chunked_hidden_states, (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim), ) return chunked_hidden_states @@ -677,9 +559,7 @@ def _chunk(hidden_states, window_overlap): def _get_global_attn_indices(is_index_global_attn): """ compute global attn indices required throughout forward pass """ # helper variable - num_global_attn_indices = tf.reduce_sum( - tf.cast(is_index_global_attn, dtype=tf.dtypes.int32), axis=1 - ) + num_global_attn_indices = tf.reduce_sum(tf.cast(is_index_global_attn, dtype=tf.dtypes.int32), axis=1) # max number of global attn indices in batch max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices) @@ -688,17 +568,15 @@ def _get_global_attn_indices(is_index_global_attn): is_index_global_attn_nonzero = tf.where(is_index_global_attn) # helper variable - is_local_index_global_attn = tf.range( - max_num_global_attn_indices - ) < tf.expand_dims(num_global_attn_indices, axis=-1) + is_local_index_global_attn = tf.range(max_num_global_attn_indices) < tf.expand_dims( + num_global_attn_indices, axis=-1 + ) # location of the non-padding values within global attention indices is_local_index_global_attn_nonzero = tf.where(is_local_index_global_attn) # location of the padding values within global attention indices - is_local_index_no_global_attn_nonzero = tf.where( - tf.math.logical_not(is_local_index_global_attn) - ) + is_local_index_no_global_attn_nonzero = tf.where(tf.math.logical_not(is_local_index_global_attn)) return ( max_num_global_attn_indices, @@ -725,22 +603,13 @@ def _concat_with_global_key_attn_probs( key_vectors_only_global = tf.scatter_nd( is_local_index_global_attn_nonzero, global_key_vectors, - shape=( - batch_size, - max_num_global_attn_indices, - self.num_heads, - self.head_dim, - ), + shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim,), ) # (batch_size, seq_len, num_heads, max_num_global_attn_indices) - attn_probs_from_global_key = tf.einsum( - "blhd,bshd->blhs", query_vectors, key_vectors_only_global - ) + attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global) # (batch_size, max_num_global_attn_indices, seq_len, num_heads) - attn_probs_from_global_key_trans = tf.transpose( - attn_probs_from_global_key, (0, 3, 1, 2) - ) + attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2)) mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple( shape_list(attn_probs_from_global_key_trans)[-2:] ) @@ -748,15 +617,11 @@ def _concat_with_global_key_attn_probs( # scatter mask attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update( - attn_probs_from_global_key_trans, - is_local_index_no_global_attn_nonzero, - mask, + attn_probs_from_global_key_trans, is_local_index_no_global_attn_nonzero, mask, ) # (batch_size, seq_len, num_heads, max_num_global_attn_indices) - attn_probs_from_global_key = tf.transpose( - attn_probs_from_global_key_trans, (0, 2, 3, 1) - ) + attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans, (0, 2, 3, 1)) # concat to attn_probs # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) @@ -783,18 +648,11 @@ def _compute_attn_output_with_global_indices( value_vectors_only_global = tf.scatter_nd( is_local_index_global_attn_nonzero, global_value_vectors, - shape=( - batch_size, - max_num_global_attn_indices, - self.num_heads, - self.head_dim, - ), + shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim,), ) # compute attn output only global - attn_output_only_global = tf.einsum( - "blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global - ) + attn_output_only_global = tf.einsum("blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global) # reshape attn probs attn_probs_without_global = attn_probs[:, :, :, max_num_global_attn_indices:] @@ -819,9 +677,7 @@ def _compute_global_attn_output_from_hidden( batch_size, seq_len = shape_list(hidden_states)[:2] # prepare global hidden states - global_attn_hidden_states = tf.gather_nd( - hidden_states, is_index_global_attn_nonzero - ) + global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero) global_attn_hidden_states = tf.scatter_nd( is_local_index_global_attn_nonzero, global_attn_hidden_states, @@ -834,22 +690,14 @@ def _compute_global_attn_output_from_hidden( global_value_vectors = self.value_global(hidden_states) # normalize - global_query_vectors_only_global /= tf.math.sqrt( - tf.constant(self.head_dim, dtype=tf.dtypes.float32) - ) + global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32)) - global_query_vectors_only_global = self.reshape_and_transpose( - global_query_vectors_only_global, batch_size - ) + global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size) global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size) - global_value_vectors = self.reshape_and_transpose( - global_value_vectors, batch_size - ) + global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size) # compute attn scores - global_attn_scores = tf.matmul( - global_query_vectors_only_global, global_key_vectors, transpose_b=True - ) + global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True) tf.debugging.assert_equal( shape_list(global_attn_scores), @@ -858,8 +706,7 @@ def _compute_global_attn_output_from_hidden( ) global_attn_scores = tf.reshape( - global_attn_scores, - (batch_size, self.num_heads, max_num_global_attn_indices, seq_len), + global_attn_scores, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len), ) global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3)) @@ -870,30 +717,23 @@ def _compute_global_attn_output_from_hidden( # scatter mask global_attn_scores_trans = tf.tensor_scatter_nd_update( - global_attn_scores_trans, - is_local_index_no_global_attn_nonzero, - global_attn_mask, + global_attn_scores_trans, is_local_index_no_global_attn_nonzero, global_attn_mask, ) global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3)) # mask global attn scores - attn_mask = tf.broadcast_to( - is_index_masked[:, None, None, :], shape_list(global_attn_scores) - ) + attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores)) global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores) global_attn_scores = tf.reshape( - global_attn_scores, - (batch_size * self.num_heads, max_num_global_attn_indices, seq_len), + global_attn_scores, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len), ) # compute global attn probs global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1) # dropout - global_attn_probs = self.global_dropout( - global_attn_probs_float, training=training - ) + global_attn_probs = self.global_dropout(global_attn_probs_float, training=training) # global attn output global_attn_output = tf.matmul(global_attn_probs, global_value_vectors) @@ -905,18 +745,15 @@ def _compute_global_attn_output_from_hidden( ) global_attn_output = tf.reshape( - global_attn_output, - (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim), + global_attn_output, (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim), ) # get only non zero global attn output nonzero_global_attn_output = tf.gather_nd( - tf.transpose(global_attn_output, (0, 2, 1, 3)), - is_local_index_global_attn_nonzero, + tf.transpose(global_attn_output, (0, 2, 1, 3)), is_local_index_global_attn_nonzero, ) nonzero_global_attn_output = tf.reshape( - nonzero_global_attn_output, - (shape_list(is_local_index_global_attn_nonzero)[0], -1), + nonzero_global_attn_output, (shape_list(is_local_index_global_attn_nonzero)[0], -1), ) # overwrite values with global attention @@ -928,10 +765,7 @@ def _compute_global_attn_output_from_hidden( def reshape_and_transpose(self, vector, batch_size): return tf.reshape( - tf.transpose( - tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)), - (0, 2, 1, 3), - ), + tf.transpose(tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)), (0, 2, 1, 3),), (batch_size * self.num_heads, -1, self.head_dim), ) @@ -956,19 +790,10 @@ def call(self, inputs, training=False): ) = inputs self_outputs = self.self_attention( - [ - hidden_states, - attention_mask, - is_index_masked, - is_index_global_attn, - is_global_attn, - output_attentions, - ], + [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions], training=training, ) - attention_output = self.dense_output( - self_outputs[0], hidden_states, training=training - ) + attention_output = self.dense_output(self_outputs[0], hidden_states, training=training) outputs = (attention_output,) + self_outputs[1:] return outputs @@ -992,24 +817,13 @@ def call(self, inputs, training=False): ) = inputs attention_outputs = self.attention( - [ - hidden_states, - attention_mask, - is_index_masked, - is_index_global_attn, - is_global_attn, - output_attentions, - ], + [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions], training=training, ) attention_output = attention_outputs[0] intermediate_output = self.intermediate(attention_output) - layer_output = self.longformer_output( - intermediate_output, attention_output, training=training - ) - outputs = (layer_output,) + attention_outputs[ - 1: - ] # add attentions if we output them + layer_output = self.longformer_output(intermediate_output, attention_output, training=training) + outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them return outputs @@ -1019,8 +833,7 @@ def __init__(self, config, **kwargs): self.output_hidden_states = config.output_hidden_states self.output_attentions = config.output_attentions self.layer = [ - TFLongformerLayer(config, i, name="layer_._{}".format(i)) - for i in range(config.num_hidden_layers) + TFLongformerLayer(config, i, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers) ] def call( @@ -1041,11 +854,7 @@ def call( all_attentions = () if output_attentions else None for i, layer_module in enumerate(self.layer): if output_hidden_states: - hidden_states_to_add = ( - hidden_states[:, :-padding_len] - if padding_len > 0 - else hidden_states - ) + hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states all_hidden_states = all_hidden_states + (hidden_states_to_add,) layer_outputs = layer_module( @@ -1062,27 +871,17 @@ def call( hidden_states = layer_outputs[0] if output_attentions: - all_attentions = all_attentions + ( - tf.transpose(layer_outputs[1], (0, 2, 1, 3)), - ) + all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),) # Add last layer if output_hidden_states: - hidden_states_to_add = ( - hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states - ) + hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states all_hidden_states = all_hidden_states + (hidden_states_to_add,) if not return_dict: - return tuple( - v - for v in [hidden_states, all_hidden_states, all_attentions] - if v is not None - ) + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) return TFBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions, ) @@ -1094,15 +893,9 @@ def __init__(self, config, **kwargs): super().__init__(**kwargs) if isinstance(config.attention_window, int): - assert ( - config.attention_window % 2 == 0 - ), "`config.attention_window` has to be an even value" - assert ( - config.attention_window > 0 - ), "`config.attention_window` has to be positive" - config.attention_window = [ - config.attention_window - ] * config.num_hidden_layers # one value per layer + assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value" + assert config.attention_window > 0, "`config.attention_window` has to be positive" + config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer else: assert len(config.attention_window) == config.num_hidden_layers, ( "`len(config.attention_window)` should equal `config.num_hidden_layers`. " @@ -1156,45 +949,29 @@ def call( position_ids = inputs[4] if len(inputs) > 4 else position_ids inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds output_attentions = inputs[6] if len(inputs) > 6 else output_attentions - output_hidden_states = ( - inputs[7] if len(inputs) > 7 else output_hidden_states - ) + output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states return_dict = inputs[8] if len(inputs) > 8 else return_dict assert len(inputs) <= 9, "Too many inputs." elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) - global_attention_mask = inputs.get( - "global_attention_mask", global_attention_mask - ) + global_attention_mask = inputs.get("global_attention_mask", global_attention_mask) token_type_ids = inputs.get("token_type_ids", token_type_ids) position_ids = inputs.get("position_ids", position_ids) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) output_attentions = inputs.get("output_attentions", output_attentions) - output_hidden_states = inputs.get( - "output_hidden_states", output_hidden_states - ) + output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) return_dict = inputs.get("return_dict", return_dict) assert len(inputs) <= 9, "Too many inputs." else: input_ids = inputs - output_attentions = ( - output_attentions - if output_attentions is not None - else self.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.output_hidden_states - ) + output_attentions = output_attentions if output_attentions is not None else self.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states return_dict = return_dict if return_dict is not None else self.return_dict if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = shape_list(input_ids) elif inputs_embeds is not None: @@ -1209,9 +986,7 @@ def call( # merge `global_attention_mask` and `attention_mask` if global_attention_mask is not None: - attention_mask = self._merge_to_attention_mask( - attention_mask, global_attention_mask - ) + attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask) ( padding_len, @@ -1245,14 +1020,9 @@ def call( # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - extended_attention_mask = ( - tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) - * -10000.0 - ) + extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0 - embedding_output = self.embeddings( - input_ids, position_ids, token_type_ids, inputs_embeds, training=training - ) + embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, @@ -1288,30 +1058,16 @@ def call( ) def _pad_to_window_size( - self, - input_ids, - attention_mask, - token_type_ids, - position_ids, - inputs_embeds, - pad_token_id, + self, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds, pad_token_id, ): """A helper function to pad tokens and mask to work with implementation of Longformer selfattention.""" # padding attention_window = ( - self.attention_window - if isinstance(self.attention_window, int) - else max(self.attention_window) + self.attention_window if isinstance(self.attention_window, int) else max(self.attention_window) ) - assert ( - attention_window % 2 == 0 - ), f"`attention_window` should be an even value. Given {attention_window}" - input_shape = ( - shape_list(input_ids) - if input_ids is not None - else shape_list(inputs_embeds) - ) + assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}" + input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds) batch_size, seq_len = input_shape[:2] padding_len = (attention_window - seq_len % attention_window) % attention_window @@ -1326,24 +1082,16 @@ def _pad_to_window_size( input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id) if position_ids is not None: # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings - position_ids = tf.pad( - position_ids, paddings, constant_values=pad_token_id - ) + position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id) if inputs_embeds is not None: - input_ids_padding = tf.fill( - (batch_size, padding_len), self.pad_token_id - ) + input_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id) inputs_embeds_padding = self.embeddings(input_ids_padding) - inputs_embeds = tf.concat( - [inputs_embeds, inputs_embeds_padding], axis=-2 - ) + inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) attention_mask = tf.pad( attention_mask, paddings, constant_values=False ) # no attention on the padding tokens - token_type_ids = tf.pad( - token_type_ids, paddings, constant_values=0 - ) # pad with token_type_id = 0 + token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0 return ( padding_len, @@ -1355,9 +1103,7 @@ def _pad_to_window_size( ) @staticmethod - def _merge_to_attention_mask( - attention_mask: tf.Tensor, global_attention_mask: tf.Tensor - ): + def _merge_to_attention_mask(attention_mask: tf.Tensor, global_attention_mask: tf.Tensor): # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) # (global_attention_mask + 1) => 1 for local attention, 2 for global attention # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention @@ -1382,12 +1128,8 @@ class TFLongformerPreTrainedModel(TFPreTrainedModel): def dummy_inputs(self): input_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]) # make sure global layers are initialized - attention_mask = tf.constant( - [[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]] - ) - global_attention_mask = tf.constant( - [[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]] - ) + attention_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]) + global_attention_mask = tf.constant([[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]]) return { "input_ids": input_ids, "attention_mask": attention_mask, @@ -1502,35 +1244,26 @@ def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.longformer = TFLongformerMainLayer(config, name="longformer") - @add_start_docstrings_to_callable( - LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)") - ) + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): outputs = self.longformer(inputs, **kwargs) return outputs @add_start_docstrings( - """Longformer Model with a `language modeling` head on top. """, - LONGFORMER_START_DOCSTRING, + """Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING, ) -class TFLongformerForMaskedLM( - TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss -): +class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss): def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.longformer = TFLongformerMainLayer(config, name="longformer") - self.lm_head = TFRobertaLMHead( - config, self.longformer.embeddings, name="lm_head" - ) + self.lm_head = TFRobertaLMHead(config, self.longformer.embeddings, name="lm_head") def get_output_embeddings(self): return self.lm_head.decoder - @add_start_docstrings_to_callable( - LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)") - ) + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="allenai/longformer-base-4096", @@ -1558,9 +1291,7 @@ def call( Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` """ - return_dict = ( - return_dict if return_dict is not None else self.longformer.return_dict - ) + return_dict = return_dict if return_dict is not None else self.longformer.return_dict if isinstance(inputs, (tuple, list)): labels = inputs[9] if len(inputs) > 9 else labels if len(inputs) > 9: @@ -1591,10 +1322,7 @@ def call( return ((loss,) + output) if loss is not None else output return TFMaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + loss=loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @@ -1603,23 +1331,17 @@ def call( the hidden-states output to compute `span start logits` and `span end logits`). """, LONGFORMER_START_DOCSTRING, ) -class TFLongformerForQuestionAnswering( - TFLongformerPreTrainedModel, TFQuestionAnsweringLoss -): +class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss): def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.num_labels = config.num_labels self.longformer = TFLongformerMainLayer(config, name="longformer") self.qa_outputs = tf.keras.layers.Dense( - config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="qa_outputs", + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs", ) - @add_start_docstrings_to_callable( - LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)") - ) + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="allenai/longformer-large-4096-finetuned-triviaqa", @@ -1651,9 +1373,7 @@ def call( Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - return_dict = ( - return_dict if return_dict is not None else self.longformer.return_dict - ) + return_dict = return_dict if return_dict is not None else self.longformer.return_dict if isinstance(inputs, (tuple, list)): input_ids = inputs[0] global_attention_mask = inputs[2] @@ -1663,9 +1383,7 @@ def call( inputs = inputs[:9] elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids", inputs) - global_attention_mask = inputs.get( - "global_attention_mask", global_attention_mask - ) + global_attention_mask = inputs.get("global_attention_mask", global_attention_mask) start_positions = inputs.pop("start_positions", start_positions) end_positions = inputs.pop("end_positions", start_positions) else: @@ -1677,10 +1395,7 @@ def call( logger.warning( "It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set." ) - elif ( - tf.where(input_ids == self.config.sep_token_id).shape[0] - != 3 * input_ids.shape[0] - ): + elif tf.where(input_ids == self.config.sep_token_id).shape[0] != 3 * input_ids.shape[0]: logger.warning( f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error." ) @@ -1688,9 +1403,7 @@ def call( logger.info("Initializing global attention on question tokens...") # put global attention on all tokens until `config.sep_token_id` is reached sep_token_indices = tf.where(input_ids == self.config.sep_token_id) - global_attention_mask = _compute_global_attention_mask( - shape_list(input_ids), sep_token_indices - ) + global_attention_mask = _compute_global_attention_mask(shape_list(input_ids), sep_token_indices) outputs = self.longformer( inputs, From ae3bbe24a2d3c30c0294ea9dc616e8c55357b963 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 26 Aug 2020 18:25:40 +0200 Subject: [PATCH 9/9] make style --- src/transformers/modeling_tf_longformer.py | 168 ++++++++++++++++----- 1 file changed, 128 insertions(+), 40 deletions(-) diff --git a/src/transformers/modeling_tf_longformer.py b/src/transformers/modeling_tf_longformer.py index 92be8cf6afbb92..698ff02340b673 100644 --- a/src/transformers/modeling_tf_longformer.py +++ b/src/transformers/modeling_tf_longformer.py @@ -97,24 +97,36 @@ def __init__(self, config, layer_id, **kwargs): self.embed_dim = config.hidden_size self.query = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query", + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="query", ) self.key = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key", + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="key", ) self.value = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value", + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="value", ) # separate projection layers for tokens with global attention self.query_global = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query_global", + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="query_global", ) self.key_global = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key_global", + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="key_global", ) self.value_global = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value_global", + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="value_global", ) self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) @@ -182,7 +194,9 @@ def call( # diagonal mask with zeros everywhere and -inf inplace of padding diagonal_mask = self._sliding_chunks_query_key_matmul( - tf.ones(shape_list(attention_mask), dtype=tf.float32), attention_mask, self.one_sided_attn_window_size, + tf.ones(shape_list(attention_mask), dtype=tf.float32), + attention_mask, + self.one_sided_attn_window_size, ) # pad local attention probs @@ -222,7 +236,9 @@ def call( # softmax sometimes inserts NaN if all positions are masked, replace them with 0 attn_probs = tf.where( - tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), 0.0, attn_probs, + tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), + 0.0, + attn_probs, ) # apply dropout @@ -246,7 +262,9 @@ def call( ) tf.debugging.assert_equal( - shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size", + shape_list(attn_output), + [batch_size, seq_len, self.num_heads, self.head_dim], + message="Unexpected size", ) attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim)) @@ -318,7 +336,10 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): chunks_count = seq_len // window_overlap - 1 # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 - query = tf.reshape(tf.transpose(query, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim),) + query = tf.reshape( + tf.transpose(query, (0, 2, 1, 3)), + (batch_size * num_heads, seq_len, head_dim), + ) key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim)) chunked_query = self._chunk(query, window_overlap) @@ -360,9 +381,11 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): ) diagonal_attn_scores_first_chunk = tf.concat( [ - tf.roll(diagonal_chunked_attention_scores, shift=[1, window_overlap], axis=[2, 3],)[ - :, :, :window_overlap, :window_overlap - ], + tf.roll( + diagonal_chunked_attention_scores, + shift=[1, window_overlap], + axis=[2, 3], + )[:, :, :window_overlap, :window_overlap], tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)), ], axis=1, @@ -371,13 +394,20 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): first_chunk_mask = ( tf.broadcast_to( tf.range(chunks_count + 1)[None, :, None, None], - shape=(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap,), + shape=( + batch_size * num_heads, + chunks_count + 1, + window_overlap, + window_overlap, + ), ) < 1 ) diagonal_attn_scores_low_triang = tf.where( - first_chunk_mask, diagonal_attn_scores_first_chunk, diagonal_attn_scores_low_triang, + first_chunk_mask, + diagonal_attn_scores_first_chunk, + diagonal_attn_scores_low_triang, ) # merging upper and lower triangle @@ -387,7 +417,10 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): # separate batch_size and num_heads dimensions again diagonal_attention_scores = tf.transpose( - tf.reshape(diagonal_attention_scores, (batch_size, num_heads, seq_len, 2 * window_overlap + 1),), + tf.reshape( + diagonal_attention_scores, + (batch_size, num_heads, seq_len, 2 * window_overlap + 1), + ), (0, 2, 1, 3), ) @@ -398,7 +431,8 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): def _mask_invalid_locations(input_tensor, window_overlap): # create correct upper triangle bool mask mask_2d_upper = tf.reverse( - tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), axis=[0], + tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), + axis=[0], ) # pad to full matrix padding = tf.constant( @@ -429,7 +463,9 @@ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_over batch_size, seq_len, num_heads, head_dim = shape_list(value) tf.debugging.assert_equal( - seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap", + seq_len % (window_overlap * 2), + 0, + message="Seq_len has to be multiple of 2 * window_overlap", ) tf.debugging.assert_equal( shape_list(attn_probs)[:3], @@ -447,11 +483,19 @@ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_over chunked_attn_probs = tf.reshape( tf.transpose(attn_probs, (0, 2, 1, 3)), - (batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1,), + ( + batch_size * num_heads, + seq_len // window_overlap, + window_overlap, + 2 * window_overlap + 1, + ), ) # group batch_size and num_heads dimensions into one - value = tf.reshape(tf.transpose(value, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim),) + value = tf.reshape( + tf.transpose(value, (0, 2, 1, 3)), + (batch_size * num_heads, seq_len, head_dim), + ) # pad seq_len with w at the beginning of the sequence and another window overlap at the end @@ -464,10 +508,13 @@ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_over frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count chunked_value = tf.signal.frame( - tf.reshape(padded_value, (batch_size * num_heads, -1)), frame_size, frame_hop_size, + tf.reshape(padded_value, (batch_size * num_heads, -1)), + frame_size, + frame_hop_size, ) chunked_value = tf.reshape( - chunked_value, (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim), + chunked_value, + (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim), ) tf.debugging.assert_equal( @@ -479,7 +526,10 @@ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_over chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value) - context = tf.transpose(tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), (0, 2, 1, 3),) + context = tf.transpose( + tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), + (0, 2, 1, 3), + ) return context @staticmethod @@ -523,7 +573,8 @@ def _pad_and_diagonalize(chunked_hidden_states): :, :, :-window_overlap ] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap chunked_hidden_states = tf.reshape( - chunked_hidden_states, (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim), + chunked_hidden_states, + (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim), ) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] return chunked_hidden_states @@ -550,7 +601,8 @@ def _chunk(hidden_states, window_overlap): ) chunked_hidden_states = tf.reshape( - chunked_hidden_states, (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim), + chunked_hidden_states, + (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim), ) return chunked_hidden_states @@ -603,7 +655,12 @@ def _concat_with_global_key_attn_probs( key_vectors_only_global = tf.scatter_nd( is_local_index_global_attn_nonzero, global_key_vectors, - shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim,), + shape=( + batch_size, + max_num_global_attn_indices, + self.num_heads, + self.head_dim, + ), ) # (batch_size, seq_len, num_heads, max_num_global_attn_indices) @@ -617,7 +674,9 @@ def _concat_with_global_key_attn_probs( # scatter mask attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update( - attn_probs_from_global_key_trans, is_local_index_no_global_attn_nonzero, mask, + attn_probs_from_global_key_trans, + is_local_index_no_global_attn_nonzero, + mask, ) # (batch_size, seq_len, num_heads, max_num_global_attn_indices) @@ -648,7 +707,12 @@ def _compute_attn_output_with_global_indices( value_vectors_only_global = tf.scatter_nd( is_local_index_global_attn_nonzero, global_value_vectors, - shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim,), + shape=( + batch_size, + max_num_global_attn_indices, + self.num_heads, + self.head_dim, + ), ) # compute attn output only global @@ -706,7 +770,8 @@ def _compute_global_attn_output_from_hidden( ) global_attn_scores = tf.reshape( - global_attn_scores, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len), + global_attn_scores, + (batch_size, self.num_heads, max_num_global_attn_indices, seq_len), ) global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3)) @@ -717,7 +782,9 @@ def _compute_global_attn_output_from_hidden( # scatter mask global_attn_scores_trans = tf.tensor_scatter_nd_update( - global_attn_scores_trans, is_local_index_no_global_attn_nonzero, global_attn_mask, + global_attn_scores_trans, + is_local_index_no_global_attn_nonzero, + global_attn_mask, ) global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3)) @@ -726,7 +793,8 @@ def _compute_global_attn_output_from_hidden( global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores) global_attn_scores = tf.reshape( - global_attn_scores, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len), + global_attn_scores, + (batch_size * self.num_heads, max_num_global_attn_indices, seq_len), ) # compute global attn probs @@ -745,15 +813,18 @@ def _compute_global_attn_output_from_hidden( ) global_attn_output = tf.reshape( - global_attn_output, (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim), + global_attn_output, + (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim), ) # get only non zero global attn output nonzero_global_attn_output = tf.gather_nd( - tf.transpose(global_attn_output, (0, 2, 1, 3)), is_local_index_global_attn_nonzero, + tf.transpose(global_attn_output, (0, 2, 1, 3)), + is_local_index_global_attn_nonzero, ) nonzero_global_attn_output = tf.reshape( - nonzero_global_attn_output, (shape_list(is_local_index_global_attn_nonzero)[0], -1), + nonzero_global_attn_output, + (shape_list(is_local_index_global_attn_nonzero)[0], -1), ) # overwrite values with global attention @@ -765,7 +836,10 @@ def _compute_global_attn_output_from_hidden( def reshape_and_transpose(self, vector, batch_size): return tf.reshape( - tf.transpose(tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)), (0, 2, 1, 3),), + tf.transpose( + tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)), + (0, 2, 1, 3), + ), (batch_size * self.num_heads, -1, self.head_dim), ) @@ -881,7 +955,9 @@ def call( if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions, + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, ) @@ -1058,7 +1134,13 @@ def call( ) def _pad_to_window_size( - self, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds, pad_token_id, + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + inputs_embeds, + pad_token_id, ): """A helper function to pad tokens and mask to work with implementation of Longformer selfattention.""" # padding @@ -1251,7 +1333,8 @@ def call(self, inputs, **kwargs): @add_start_docstrings( - """Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING, + """Longformer Model with a `language modeling` head on top. """, + LONGFORMER_START_DOCSTRING, ) class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss): def __init__(self, config, *inputs, **kwargs): @@ -1322,7 +1405,10 @@ def call( return ((loss,) + output) if loss is not None else output return TFMaskedLMOutput( - loss=loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + loss=loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) @@ -1338,7 +1424,9 @@ def __init__(self, config, *inputs, **kwargs): self.longformer = TFLongformerMainLayer(config, name="longformer") self.qa_outputs = tf.keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs", + config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + name="qa_outputs", ) @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))