diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index f40414826ad4..1417c3b879f3 100644 --- a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -1088,7 +1088,7 @@ def __init__(self, config, *inputs, **kwargs): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-cased", diff --git a/src/transformers/modeling_tf_electra.py b/src/transformers/modeling_tf_electra.py index 05d7469996be..d13535fb5696 100644 --- a/src/transformers/modeling_tf_electra.py +++ b/src/transformers/modeling_tf_electra.py @@ -677,7 +677,7 @@ def __init__(self, config, *inputs, **kwargs): self.electra = TFElectraMainLayer(config, name="electra") self.classifier = TFElectraClassificationHead(config, name="classifier") - @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) + @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator", diff --git a/src/transformers/modeling_tf_longformer.py b/src/transformers/modeling_tf_longformer.py index 52768b5dd975..698ff02340b6 100644 --- a/src/transformers/modeling_tf_longformer.py +++ b/src/transformers/modeling_tf_longformer.py @@ -97,24 +97,36 @@ def __init__(self, config, layer_id, **kwargs): self.embed_dim = config.hidden_size self.query = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query" + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="query", ) self.key = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key" + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="key", ) self.value = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value" + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="value", ) # separate projection layers for tokens with global attention self.query_global = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query_global" + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="query_global", ) self.key_global = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key_global" + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="key_global", ) self.value_global = tf.keras.layers.Dense( - self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value_global" + self.embed_dim, + kernel_initializer=get_initializer(config.initializer_range), + name="value_global", ) self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) @@ -148,23 +160,21 @@ def call( """ # retrieve input args - hidden_states, attention_mask, output_attentions = inputs - - attention_mask = tf.squeeze(tf.squeeze(attention_mask, axis=2), axis=1) - # is index masked or global attention - - is_index_masked 
= tf.math.less(attention_mask, 0) - is_index_global_attn = tf.math.greater(attention_mask, 0) - is_global_attn = tf.math.reduce_any(is_index_global_attn) - - hidden_states = tf.transpose(hidden_states, (1, 0, 2)) + ( + hidden_states, + attention_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + output_attentions, + ) = inputs # project hidden states query_vectors = self.query(hidden_states) key_vectors = self.key(hidden_states) value_vectors = self.value(hidden_states) - seq_len, batch_size, embed_dim = shape_list(hidden_states) + batch_size, seq_len, embed_dim = shape_list(hidden_states) tf.debugging.assert_equal( embed_dim, self.embed_dim, @@ -174,24 +184,19 @@ def call( # normalize query query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32)) - query_vectors = tf.transpose( - tf.reshape(query_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3) - ) - key_vectors = tf.transpose( - tf.reshape(key_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3) - ) + query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) + key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) # attn_probs = (batch_size, seq_len, num_heads, window*2+1) attn_scores = self._sliding_chunks_query_key_matmul( query_vectors, key_vectors, self.one_sided_attn_window_size ) - # values to pad for attention probs - float_mask = tf.cast((attention_mask != 0)[:, :, None, None], dtype=tf.float32) * -10000.0 - # diagonal mask with zeros everywhere and -inf inplace of padding diagonal_mask = self._sliding_chunks_query_key_matmul( - tf.ones(shape_list(float_mask), dtype=tf.float32), float_mask, self.one_sided_attn_window_size + tf.ones(shape_list(attention_mask), dtype=tf.float32), + attention_mask, + self.one_sided_attn_window_size, ) # pad local attention probs @@ -231,15 +236,15 @@ def call( # softmax sometimes inserts NaN if all positions are masked, replace them with 0 attn_probs = tf.where( - tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), 0.0, attn_probs + tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), + 0.0, + attn_probs, ) # apply dropout attn_probs = self.dropout(attn_probs, training=training) - value_vectors = tf.transpose( - tf.reshape(value_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3) - ) + value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) # if global attention, compute sum of global and local attn attn_output = tf.cond( @@ -257,9 +262,11 @@ def call( ) tf.debugging.assert_equal( - shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size" + shape_list(attn_output), + [batch_size, seq_len, self.num_heads, self.head_dim], + message="Unexpected size", ) - attn_output = tf.reshape(tf.transpose(attn_output, (1, 0, 2, 3)), (seq_len, batch_size, embed_dim)) + attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim)) # compute value for global attention and overwrite to attention output # TODO: remove the redundant computation @@ -278,8 +285,6 @@ def call( lambda: attn_output, ) - attn_output = tf.transpose(attn_output, (1, 0, 2)) - # GLOBAL ATTN: # With global attention, return global attention probabilities only # batch_size x num_heads x max_num_global_attention_tokens x sequence_length @@ -294,7 +299,7 @@ def call( attn_probs = tf.cond( is_global_attn, lambda: 
self._get_global_attn_probs(attn_probs, max_num_global_attn_indices), - lambda: tf.transpose(attn_probs, (0, 2, 1, 3)), + lambda: attn_probs, ) outputs = (attn_output, attn_probs) @@ -310,7 +315,6 @@ def _get_global_attn_probs(attn_probs, max_num_global_attn_indices): ], axis=-1, ) - attn_probs = tf.transpose(attn_probs, (0, 2, 1, 3)) return attn_probs def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): @@ -332,7 +336,10 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): chunks_count = seq_len // window_overlap - 1 # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 - query = tf.reshape(tf.transpose(query, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim)) + query = tf.reshape( + tf.transpose(query, (0, 2, 1, 3)), + (batch_size * num_heads, seq_len, head_dim), + ) key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim)) chunked_query = self._chunk(query, window_overlap) @@ -374,9 +381,11 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): ) diagonal_attn_scores_first_chunk = tf.concat( [ - tf.roll(diagonal_chunked_attention_scores, shift=[1, window_overlap], axis=[2, 3])[ - :, :, :window_overlap, :window_overlap - ], + tf.roll( + diagonal_chunked_attention_scores, + shift=[1, window_overlap], + axis=[2, 3], + )[:, :, :window_overlap, :window_overlap], tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)), ], axis=1, @@ -385,13 +394,20 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): first_chunk_mask = ( tf.broadcast_to( tf.range(chunks_count + 1)[None, :, None, None], - shape=(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap), + shape=( + batch_size * num_heads, + chunks_count + 1, + window_overlap, + window_overlap, + ), ) < 1 ) diagonal_attn_scores_low_triang = tf.where( - first_chunk_mask, diagonal_attn_scores_first_chunk, diagonal_attn_scores_low_triang + first_chunk_mask, + diagonal_attn_scores_first_chunk, + diagonal_attn_scores_low_triang, ) # merging upper and lower triangle @@ -401,7 +417,10 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): # separate batch_size and num_heads dimensions again diagonal_attention_scores = tf.transpose( - tf.reshape(diagonal_attention_scores, (batch_size, num_heads, seq_len, 2 * window_overlap + 1)), + tf.reshape( + diagonal_attention_scores, + (batch_size, num_heads, seq_len, 2 * window_overlap + 1), + ), (0, 2, 1, 3), ) @@ -412,7 +431,8 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): def _mask_invalid_locations(input_tensor, window_overlap): # create correct upper triangle bool mask mask_2d_upper = tf.reverse( - tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), axis=[0] + tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), + axis=[0], ) # pad to full matrix padding = tf.constant( @@ -443,7 +463,9 @@ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_over batch_size, seq_len, num_heads, head_dim = shape_list(value) tf.debugging.assert_equal( - seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap" + seq_len % (window_overlap * 2), + 0, + message="Seq_len has to be multiple of 2 * window_overlap", ) tf.debugging.assert_equal( shape_list(attn_probs)[:3], @@ -461,11 +483,19 @@ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, 
window_over chunked_attn_probs = tf.reshape( tf.transpose(attn_probs, (0, 2, 1, 3)), - (batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1), + ( + batch_size * num_heads, + seq_len // window_overlap, + window_overlap, + 2 * window_overlap + 1, + ), ) # group batch_size and num_heads dimensions into one - value = tf.reshape(tf.transpose(value, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim)) + value = tf.reshape( + tf.transpose(value, (0, 2, 1, 3)), + (batch_size * num_heads, seq_len, head_dim), + ) # pad seq_len with w at the beginning of the sequence and another window overlap at the end @@ -478,10 +508,13 @@ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_over frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count chunked_value = tf.signal.frame( - tf.reshape(padded_value, (batch_size * num_heads, -1)), frame_size, frame_hop_size + tf.reshape(padded_value, (batch_size * num_heads, -1)), + frame_size, + frame_hop_size, ) chunked_value = tf.reshape( - chunked_value, (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim) + chunked_value, + (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim), ) tf.debugging.assert_equal( @@ -493,7 +526,10 @@ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_over chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value) - context = tf.transpose(tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), (0, 2, 1, 3)) + context = tf.transpose( + tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), + (0, 2, 1, 3), + ) return context @staticmethod @@ -502,12 +538,10 @@ def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings): hidden_states_padded = tf.pad( hidden_states_padded, paddings ) # padding value is not important because it will be overwritten - if len(shape_list(hidden_states_padded)) > 3: - batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded) - hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length)) - else: - batch_size, seq_length, hidden_dim = shape_list(hidden_states_padded) - hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, hidden_dim, seq_length)) + + batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded) + hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length)) + return hidden_states_padded @staticmethod @@ -539,7 +573,8 @@ def _pad_and_diagonalize(chunked_hidden_states): :, :, :-window_overlap ] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap chunked_hidden_states = tf.reshape( - chunked_hidden_states, (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim) + chunked_hidden_states, + (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim), ) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] return chunked_hidden_states @@ -566,7 +601,8 @@ def _chunk(hidden_states, window_overlap): ) chunked_hidden_states = tf.reshape( - chunked_hidden_states, (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim) + chunked_hidden_states, + (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim), ) return chunked_hidden_states @@ -619,7 
+655,12 @@ def _concat_with_global_key_attn_probs( key_vectors_only_global = tf.scatter_nd( is_local_index_global_attn_nonzero, global_key_vectors, - shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim), + shape=( + batch_size, + max_num_global_attn_indices, + self.num_heads, + self.head_dim, + ), ) # (batch_size, seq_len, num_heads, max_num_global_attn_indices) @@ -633,7 +674,9 @@ def _concat_with_global_key_attn_probs( # scatter mask attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update( - attn_probs_from_global_key_trans, is_local_index_no_global_attn_nonzero, mask + attn_probs_from_global_key_trans, + is_local_index_no_global_attn_nonzero, + mask, ) # (batch_size, seq_len, num_heads, max_num_global_attn_indices) @@ -664,7 +707,12 @@ def _compute_attn_output_with_global_indices( value_vectors_only_global = tf.scatter_nd( is_local_index_global_attn_nonzero, global_value_vectors, - shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim), + shape=( + batch_size, + max_num_global_attn_indices, + self.num_heads, + self.head_dim, + ), ) # compute attn output only global @@ -690,14 +738,14 @@ def _compute_global_attn_output_from_hidden( is_index_masked, training, ): - seq_len, batch_size = shape_list(hidden_states)[:2] + batch_size, seq_len = shape_list(hidden_states)[:2] # prepare global hidden states - global_attn_hidden_states = tf.gather_nd(hidden_states, tf.reverse(is_index_global_attn_nonzero, axis=[1])) + global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero) global_attn_hidden_states = tf.scatter_nd( - tf.reverse(is_local_index_global_attn_nonzero, axis=[1]), + is_local_index_global_attn_nonzero, global_attn_hidden_states, - shape=(max_num_global_attn_indices, batch_size, self.embed_dim), + shape=(batch_size, max_num_global_attn_indices, self.embed_dim), ) # global key, query, value @@ -708,27 +756,12 @@ def _compute_global_attn_output_from_hidden( # normalize global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32)) - # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim) - global_query_vectors_only_global = tf.transpose( - tf.reshape( - global_query_vectors_only_global, - (max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim), - ), - (1, 0, 2), - ) - - # (..., batch_size * self.num_heads, seq_len, head_dim) - global_key_vectors = tf.transpose( - tf.reshape(global_key_vectors, (-1, batch_size * self.num_heads, self.head_dim)), (1, 0, 2) - ) - - # (..., batch_size * self.num_heads, seq_len, head_dim) - global_value_vectors = tf.transpose( - tf.reshape(global_value_vectors, (-1, batch_size * self.num_heads, self.head_dim)), (1, 0, 2) - ) + global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size) + global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size) + global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size) # compute attn scores - global_attn_scores = tf.matmul(global_query_vectors_only_global, tf.transpose(global_key_vectors, (0, 2, 1))) + global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True) tf.debugging.assert_equal( shape_list(global_attn_scores), @@ -737,7 +770,8 @@ def _compute_global_attn_output_from_hidden( ) global_attn_scores = tf.reshape( - global_attn_scores, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + global_attn_scores, + 
(batch_size, self.num_heads, max_num_global_attn_indices, seq_len), ) global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3)) @@ -748,7 +782,9 @@ def _compute_global_attn_output_from_hidden( # scatter mask global_attn_scores_trans = tf.tensor_scatter_nd_update( - global_attn_scores_trans, is_local_index_no_global_attn_nonzero, global_attn_mask + global_attn_scores_trans, + is_local_index_no_global_attn_nonzero, + global_attn_mask, ) global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3)) @@ -757,7 +793,8 @@ def _compute_global_attn_output_from_hidden( global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores) global_attn_scores = tf.reshape( - global_attn_scores, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len) + global_attn_scores, + (batch_size * self.num_heads, max_num_global_attn_indices, seq_len), ) # compute global attn probs @@ -776,12 +813,14 @@ def _compute_global_attn_output_from_hidden( ) global_attn_output = tf.reshape( - global_attn_output, (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim) + global_attn_output, + (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim), ) # get only non zero global attn output nonzero_global_attn_output = tf.gather_nd( - tf.transpose(global_attn_output, (0, 2, 1, 3)), is_local_index_global_attn_nonzero + tf.transpose(global_attn_output, (0, 2, 1, 3)), + is_local_index_global_attn_nonzero, ) nonzero_global_attn_output = tf.reshape( nonzero_global_attn_output, @@ -789,12 +828,21 @@ def _compute_global_attn_output_from_hidden( ) # overwrite values with global attention + attn_output = tf.tensor_scatter_nd_update( - attn_output, tf.reverse(is_index_global_attn_nonzero, axis=[1]), nonzero_global_attn_output + attn_output, is_index_global_attn_nonzero, nonzero_global_attn_output ) - return attn_output + def reshape_and_transpose(self, vector, batch_size): + return tf.reshape( + tf.transpose( + tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)), + (0, 2, 1, 3), + ), + (batch_size * self.num_heads, -1, self.head_dim), + ) + class TFLongformerAttention(tf.keras.layers.Layer): def __init__(self, config, layer_id=0, **kwargs): @@ -806,10 +854,20 @@ def prune_heads(self, heads): raise NotImplementedError def call(self, inputs, training=False): - input_tensor, attention_mask, output_attentions = inputs + ( + hidden_states, + attention_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + output_attentions, + ) = inputs - self_outputs = self.self_attention([input_tensor, attention_mask, output_attentions], training=training) - attention_output = self.dense_output(self_outputs[0], input_tensor, training=training) + self_outputs = self.self_attention( + [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions], + training=training, + ) + attention_output = self.dense_output(self_outputs[0], hidden_states, training=training) outputs = (attention_output,) + self_outputs[1:] return outputs @@ -823,9 +881,19 @@ def __init__(self, config, layer_id=0, **kwargs): self.longformer_output = TFBertOutput(config, name="output") def call(self, inputs, training=False): - hidden_states, attention_mask, output_attentions = inputs + ( + hidden_states, + attention_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + output_attentions, + ) = inputs - attention_outputs = self.attention([hidden_states, attention_mask, output_attentions], training=training) + attention_outputs = 
self.attention( + [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions], + training=training, + ) attention_output = attention_outputs[0] intermediate_output = self.intermediate(attention_output) layer_output = self.longformer_output(intermediate_output, attention_output, training=training) @@ -848,12 +916,14 @@ def call( attention_mask=None, head_mask=None, padding_len=0, + is_index_masked=None, + is_index_global_attn=None, + is_global_attn=None, output_attentions=None, output_hidden_states=None, return_dict=None, training=False, ): - all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None for i, layer_module in enumerate(self.layer): @@ -861,11 +931,21 @@ def call( hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states all_hidden_states = all_hidden_states + (hidden_states_to_add,) - layer_outputs = layer_module([hidden_states, attention_mask, output_attentions], training=training) + layer_outputs = layer_module( + [ + hidden_states, + attention_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + output_attentions, + ], + training=training, + ) hidden_states = layer_outputs[0] if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),) # Add last layer if output_hidden_states: @@ -875,7 +955,9 @@ def call( if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) return TFBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, ) @@ -982,7 +1064,14 @@ def call( if global_attention_mask is not None: attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask) - padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size( + ( + padding_len, + input_ids, + attention_mask, + token_type_ids, + position_ids, + inputs_embeds, + ) = self._pad_to_window_size( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -991,27 +1080,32 @@ def call( pad_token_id=self.pad_token_id, ) + # is index masked or global attention + is_index_masked = tf.math.less(attention_mask, 1) + is_index_global_attn = tf.math.greater(attention_mask, 1) + is_global_attn = tf.math.reduce_any(is_index_global_attn) # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] + # Sizes are [batch_size, to_seq_length, 1, 1] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] + extended_attention_mask = attention_mask[:, :, tf.newaxis, tf.newaxis] - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for + # Since attention_mask is 1.0 for positions we want to locall attend locally and 0.0 for + # masked and global attn positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. 
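# Illustrative note (not part of the diff): a minimal sketch of the new mask bookkeeping,
# assuming the 0/1/2 convention produced by _merge_to_attention_mask (0 = padding,
# 1 = local attention, 2 = global attention). The boolean flags are computed once in the
# main layer and passed down to every layer; the additive mask built just below zeroes out
# only the local attention scores, since global tokens are handled by a separate path.
import tensorflow as tf

attention_mask = tf.constant([[2.0, 1.0, 1.0, 0.0]])  # toy example: global, local, local, padding

is_index_masked = tf.math.less(attention_mask, 1)          # [[False, False, False, True]]
is_index_global_attn = tf.math.greater(attention_mask, 1)  # [[True, False, False, False]]
is_global_attn = tf.math.reduce_any(is_index_global_attn)  # True

# additive mask for the *local* attention scores: 0.0 where a token attends locally,
# -10000.0 for padding and for global tokens
extended_attention_mask = attention_mask[:, :, tf.newaxis, tf.newaxis]
extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
print(extended_attention_mask[..., 0, 0])  # [[-10000.      0.      0. -10000.]]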
# Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - - extended_attention_mask = tf.cast(extended_attention_mask, tf.float32) - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0 embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, padding_len=padding_len, + is_index_masked=is_index_masked, + is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -1081,7 +1175,14 @@ def _pad_to_window_size( ) # no attention on the padding tokens token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0 - return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds + return ( + padding_len, + input_ids, + attention_mask, + token_type_ids, + position_ids, + inputs_embeds, + ) @staticmethod def _merge_to_attention_mask(attention_mask: tf.Tensor, global_attention_mask: tf.Tensor): @@ -1231,7 +1332,10 @@ def call(self, inputs, **kwargs): return outputs -@add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING) +@add_start_docstrings( + """Longformer Model with a `language modeling` head on top. """, + LONGFORMER_START_DOCSTRING, +) class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss): def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) @@ -1320,7 +1424,9 @@ def __init__(self, config, *inputs, **kwargs): self.longformer = TFLongformerMainLayer(config, name="longformer") self.qa_outputs = tf.keras.layers.Dense( - config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + name="qa_outputs", ) @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 478f8353eba1..298105ffe99b 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -110,15 +110,6 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): def test_initialization(self): pass - # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - # configs_no_init = _config_zero_init(config) - # for model_class in self.all_model_classes: - # model = model_class(config=configs_no_init) - # for name, param in model.named_parameters(): - # if param.requires_grad: - # self.assertIn(param.data.mean().item(), [0.0, 1.0], - # msg="Parameter {} of model {} seems not properly initialized".format(name, model_class)) def test_save_load(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -134,6 +125,19 @@ def test_save_load(self): self.assert_outputs_same(after_outputs, outputs) + def test_graph_mode(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + inputs = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @tf.function + def 
run_in_graph_mode(): + return model(inputs) + + outputs = run_in_graph_mode() + self.assertIsNotNone(outputs) + @slow def test_saved_model_with_hidden_states_output(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_tf_longformer.py b/tests/test_modeling_tf_longformer.py index 090a61c84b36..1282069b0313 100644 --- a/tests/test_modeling_tf_longformer.py +++ b/tests/test_modeling_tf_longformer.py @@ -385,15 +385,16 @@ def test_pad_and_transpose_last_two_dims(self): self.assertTrue(shape_list(hidden_states), [1, 8, 4]) # pad along seq length dim - paddings = tf.constant([[0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32) + paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32) + hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2) padded_hidden_states = TFLongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, paddings) - self.assertTrue(shape_list(padded_hidden_states) == [1, 8, 5]) + self.assertTrue(shape_list(padded_hidden_states) == [1, 1, 8, 5]) expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32) - tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, -1, :], rtol=1e-6) + tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6) tf.debugging.assert_near( - hidden_states[0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6 + hidden_states[0, 0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6 ) def test_mask_invalid_locations(self): @@ -437,10 +438,16 @@ def test_layer_local_attn(self): hidden_states = self._get_hidden_states() batch_size, seq_length, hidden_size = hidden_states.shape - attention_mask = tf.zeros((batch_size, 1, 1, seq_length), dtype=tf.dtypes.float32) - attention_mask = tf.where(tf.range(4)[None, None, None, :] > 1, -10000.0, attention_mask) + attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32) + is_index_global_attn = tf.math.greater(attention_mask, 1) + is_global_attn = tf.math.reduce_any(is_index_global_attn) + + attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None]) + is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0) - output_hidden_states = layer([hidden_states, attention_mask, None])[0] + output_hidden_states = layer( + [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, None] + )[0] expected_slice = tf.convert_to_tensor( [0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.dtypes.float32 @@ -461,12 +468,18 @@ def test_layer_global_attn(self): attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32) attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32) - attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 1, 10000.0, attention_mask_1) - attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 2, -10000.0, attention_mask_1) - attention_mask_2 = tf.where(tf.range(4)[None, None, None, :] > 0, 10000.0, attention_mask_2) + attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 1, 10000.0, attention_mask_1) + attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1) + attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] > 0, 10000.0, attention_mask_2) attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0) - output_hidden_states = layer([hidden_states, 
attention_mask, None])[0] + is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0) + is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0) + is_global_attn = tf.math.reduce_any(is_index_global_attn) + + output_hidden_states = layer( + [hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn, None] + )[0] self.assertTrue(output_hidden_states.shape, (2, 4, 8)) expected_slice_0 = tf.convert_to_tensor(
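# Illustrative note (not part of the diff): the core of this change is moving
# TFLongformerSelfAttention from a seq-first (seq_len, batch_size, ...) layout to the
# batch-first layout used elsewhere, which turns transpose+reshape pairs into a single
# reshape. A minimal sketch with toy dimensions (variable names are ad hoc, not from the PR):
import tensorflow as tf

batch_size, seq_len, num_heads, head_dim = 2, 8, 3, 4
hidden_states = tf.random.normal((batch_size, seq_len, num_heads * head_dim))  # batch-first

# new code path: one reshape
new = tf.reshape(hidden_states, (batch_size, seq_len, num_heads, head_dim))

# old code path: seq-first tensor, reshape, then transpose back to batch-first
seq_first = tf.transpose(hidden_states, (1, 0, 2))
old = tf.transpose(tf.reshape(seq_first, (seq_len, batch_size, num_heads, head_dim)), (1, 0, 2, 3))

tf.debugging.assert_near(new, old)  # identical result, two fewer transposes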
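# Illustrative note (not part of the diff): the new reshape_and_transpose helper folds the
# head dimension into the batch dimension so global attention scores come from a single
# matmul, and transpose_b=True replaces the explicit tf.transpose(..., (0, 2, 1)). Sketch
# with toy sizes (module-level constants stand in for self.num_heads / self.head_dim):
import tensorflow as tf

batch_size, seq_len, num_heads, head_dim = 2, 8, 3, 4

def reshape_and_transpose(vector, batch_size):
    # (batch_size, seq_len, embed_dim) -> (batch_size * num_heads, seq_len, head_dim)
    return tf.reshape(
        tf.transpose(tf.reshape(vector, (batch_size, -1, num_heads, head_dim)), (0, 2, 1, 3)),
        (batch_size * num_heads, -1, head_dim),
    )

query = tf.random.normal((batch_size, seq_len, num_heads * head_dim))
key = tf.random.normal((batch_size, seq_len, num_heads * head_dim))

scores = tf.matmul(
    reshape_and_transpose(query, batch_size),
    reshape_and_transpose(key, batch_size),
    transpose_b=True,
)
print(scores.shape)  # (6, 8, 8) == (batch_size * num_heads, seq_len, seq_len)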
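# Illustrative note (not part of the diff): precomputing is_index_masked /
# is_index_global_attn / is_global_attn as tensors and branching with tf.cond is what lets
# the whole forward pass be traced by tf.function, which the new test_graph_mode test
# exercises. A standalone sketch of that pattern with a toy Keras model (hypothetical
# stand-in, not a Longformer checkpoint):
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
inputs = tf.random.normal((2, 8))

@tf.function
def run_in_graph_mode():
    return model(inputs)

outputs = run_in_graph_mode()
assert outputs is not None and outputs.shape == (2, 4)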