Commit
make style
patrickvonplaten committed Aug 26, 2020
1 parent efddf0c commit ae3bbe2
Showing 1 changed file with 128 additions and 40 deletions.
168 changes: 128 additions & 40 deletions src/transformers/modeling_tf_longformer.py
@@ -97,24 +97,36 @@ def __init__(self, config, layer_id, **kwargs):
self.embed_dim = config.hidden_size

self.query = tf.keras.layers.Dense(
- self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query",
+ self.embed_dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="query",
)
self.key = tf.keras.layers.Dense(
- self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key",
+ self.embed_dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="key",
)
self.value = tf.keras.layers.Dense(
- self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value",
+ self.embed_dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="value",
)

# separate projection layers for tokens with global attention
self.query_global = tf.keras.layers.Dense(
- self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query_global",
+ self.embed_dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="query_global",
)
self.key_global = tf.keras.layers.Dense(
- self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key_global",
+ self.embed_dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="key_global",
)
self.value_global = tf.keras.layers.Dense(
- self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value_global",
+ self.embed_dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="value_global",
)

self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
@@ -182,7 +194,9 @@ def call(

# diagonal mask with zeros everywhere and -inf inplace of padding
diagonal_mask = self._sliding_chunks_query_key_matmul(
- tf.ones(shape_list(attention_mask), dtype=tf.float32), attention_mask, self.one_sided_attn_window_size,
+ tf.ones(shape_list(attention_mask), dtype=tf.float32),
+ attention_mask,
+ self.one_sided_attn_window_size,
)

# pad local attention probs
@@ -222,7 +236,9 @@ def call(

# softmax sometimes inserts NaN if all positions are masked, replace them with 0
attn_probs = tf.where(
- tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), 0.0, attn_probs,
+ tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
+ 0.0,
+ attn_probs,
)

# apply dropout
@@ -246,7 +262,9 @@ def call(
)

tf.debugging.assert_equal(
- shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size",
+ shape_list(attn_output),
+ [batch_size, seq_len, self.num_heads, self.head_dim],
+ message="Unexpected size",
)
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))

@@ -318,7 +336,10 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
chunks_count = seq_len // window_overlap - 1

# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
- query = tf.reshape(tf.transpose(query, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim),)
+ query = tf.reshape(
+ tf.transpose(query, (0, 2, 1, 3)),
+ (batch_size * num_heads, seq_len, head_dim),
+ )
key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))

chunked_query = self._chunk(query, window_overlap)
@@ -360,9 +381,11 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
)
diagonal_attn_scores_first_chunk = tf.concat(
[
- tf.roll(diagonal_chunked_attention_scores, shift=[1, window_overlap], axis=[2, 3],)[
- :, :, :window_overlap, :window_overlap
- ],
+ tf.roll(
+ diagonal_chunked_attention_scores,
+ shift=[1, window_overlap],
+ axis=[2, 3],
+ )[:, :, :window_overlap, :window_overlap],
tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
],
axis=1,
@@ -371,13 +394,20 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
first_chunk_mask = (
tf.broadcast_to(
tf.range(chunks_count + 1)[None, :, None, None],
- shape=(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap,),
+ shape=(
+ batch_size * num_heads,
+ chunks_count + 1,
+ window_overlap,
+ window_overlap,
+ ),
)
< 1
)

diagonal_attn_scores_low_triang = tf.where(
- first_chunk_mask, diagonal_attn_scores_first_chunk, diagonal_attn_scores_low_triang,
+ first_chunk_mask,
+ diagonal_attn_scores_first_chunk,
+ diagonal_attn_scores_low_triang,
)

# merging upper and lower triangle
@@ -387,7 +417,10 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):

# separate batch_size and num_heads dimensions again
diagonal_attention_scores = tf.transpose(
- tf.reshape(diagonal_attention_scores, (batch_size, num_heads, seq_len, 2 * window_overlap + 1),),
+ tf.reshape(
+ diagonal_attention_scores,
+ (batch_size, num_heads, seq_len, 2 * window_overlap + 1),
+ ),
(0, 2, 1, 3),
)

@@ -398,7 +431,8 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
def _mask_invalid_locations(input_tensor, window_overlap):
# create correct upper triangle bool mask
mask_2d_upper = tf.reverse(
- tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), axis=[0],
+ tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
+ axis=[0],
)
# pad to full matrix
padding = tf.constant(
@@ -429,7 +463,9 @@ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_over
batch_size, seq_len, num_heads, head_dim = shape_list(value)

tf.debugging.assert_equal(
- seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap",
+ seq_len % (window_overlap * 2),
+ 0,
+ message="Seq_len has to be multiple of 2 * window_overlap",
)
tf.debugging.assert_equal(
shape_list(attn_probs)[:3],
@@ -447,11 +483,19 @@ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_over

chunked_attn_probs = tf.reshape(
tf.transpose(attn_probs, (0, 2, 1, 3)),
- (batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1,),
+ (
+ batch_size * num_heads,
+ seq_len // window_overlap,
+ window_overlap,
+ 2 * window_overlap + 1,
+ ),
)

# group batch_size and num_heads dimensions into one
- value = tf.reshape(tf.transpose(value, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim),)
+ value = tf.reshape(
+ tf.transpose(value, (0, 2, 1, 3)),
+ (batch_size * num_heads, seq_len, head_dim),
+ )

# pad seq_len with w at the beginning of the sequence and another window overlap at the end

@@ -464,10 +508,13 @@ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_over
frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count

chunked_value = tf.signal.frame(
- tf.reshape(padded_value, (batch_size * num_heads, -1)), frame_size, frame_hop_size,
+ tf.reshape(padded_value, (batch_size * num_heads, -1)),
+ frame_size,
+ frame_hop_size,
)
chunked_value = tf.reshape(
- chunked_value, (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
+ chunked_value,
+ (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
)

tf.debugging.assert_equal(
@@ -479,7 +526,10 @@ def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_over
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)

context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
- context = tf.transpose(tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), (0, 2, 1, 3),)
+ context = tf.transpose(
+ tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
+ (0, 2, 1, 3),
+ )
return context

@staticmethod
@@ -523,7 +573,8 @@ def _pad_and_diagonalize(chunked_hidden_states):
:, :, :-window_overlap
] # total_num_heads x num_chunks x window_overlap*L+window_overlap*window_overlap
chunked_hidden_states = tf.reshape(
- chunked_hidden_states, (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim),
+ chunked_hidden_states,
+ (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim),
) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap
chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
return chunked_hidden_states
@@ -550,7 +601,8 @@ def _chunk(hidden_states, window_overlap):
)

chunked_hidden_states = tf.reshape(
- chunked_hidden_states, (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
+ chunked_hidden_states,
+ (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
)

return chunked_hidden_states
@@ -603,7 +655,12 @@ def _concat_with_global_key_attn_probs(
key_vectors_only_global = tf.scatter_nd(
is_local_index_global_attn_nonzero,
global_key_vectors,
- shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim,),
+ shape=(
+ batch_size,
+ max_num_global_attn_indices,
+ self.num_heads,
+ self.head_dim,
+ ),
)

# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
@@ -617,7 +674,9 @@ def _concat_with_global_key_attn_probs(

# scatter mask
attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
- attn_probs_from_global_key_trans, is_local_index_no_global_attn_nonzero, mask,
+ attn_probs_from_global_key_trans,
+ is_local_index_no_global_attn_nonzero,
+ mask,
)

# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
@@ -648,7 +707,12 @@ def _compute_attn_output_with_global_indices(
value_vectors_only_global = tf.scatter_nd(
is_local_index_global_attn_nonzero,
global_value_vectors,
- shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim,),
+ shape=(
+ batch_size,
+ max_num_global_attn_indices,
+ self.num_heads,
+ self.head_dim,
+ ),
)

# compute attn output only global
@@ -706,7 +770,8 @@ def _compute_global_attn_output_from_hidden(
)

global_attn_scores = tf.reshape(
- global_attn_scores, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len),
+ global_attn_scores,
+ (batch_size, self.num_heads, max_num_global_attn_indices, seq_len),
)

global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
@@ -717,7 +782,9 @@ def _compute_global_attn_output_from_hidden(

# scatter mask
global_attn_scores_trans = tf.tensor_scatter_nd_update(
- global_attn_scores_trans, is_local_index_no_global_attn_nonzero, global_attn_mask,
+ global_attn_scores_trans,
+ is_local_index_no_global_attn_nonzero,
+ global_attn_mask,
)
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))

@@ -726,7 +793,8 @@ def _compute_global_attn_output_from_hidden(
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)

global_attn_scores = tf.reshape(
- global_attn_scores, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len),
+ global_attn_scores,
+ (batch_size * self.num_heads, max_num_global_attn_indices, seq_len),
)

# compute global attn probs
@@ -745,15 +813,18 @@ def _compute_global_attn_output_from_hidden(
)

global_attn_output = tf.reshape(
- global_attn_output, (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim),
+ global_attn_output,
+ (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim),

# get only non zero global attn output
nonzero_global_attn_output = tf.gather_nd(
tf.transpose(global_attn_output, (0, 2, 1, 3)), is_local_index_global_attn_nonzero,
- tf.transpose(global_attn_output, (0, 2, 1, 3)), is_local_index_global_attn_nonzero,
+ tf.transpose(global_attn_output, (0, 2, 1, 3)),
+ is_local_index_global_attn_nonzero,
nonzero_global_attn_output = tf.reshape(
nonzero_global_attn_output, (shape_list(is_local_index_global_attn_nonzero)[0], -1),
nonzero_global_attn_output,
(shape_list(is_local_index_global_attn_nonzero)[0], -1),
)

# overwrite values with global attention
@@ -765,7 +836,10 @@ def _compute_global_attn_output_from_hidden(

def reshape_and_transpose(self, vector, batch_size):
return tf.reshape(
- tf.transpose(tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)), (0, 2, 1, 3),),
+ tf.transpose(
+ tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)),
+ (0, 2, 1, 3),
+ ),
(batch_size * self.num_heads, -1, self.head_dim),
)

@@ -881,7 +955,9 @@ def call(
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput(
- last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions,
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
)


@@ -1058,7 +1134,13 @@ def call(
)

def _pad_to_window_size(
- self, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds, pad_token_id,
+ self,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ inputs_embeds,
+ pad_token_id,
):
"""A helper function to pad tokens and mask to work with implementation of Longformer selfattention."""
# padding
@@ -1251,7 +1333,8 @@ def call(self, inputs, **kwargs):


@add_start_docstrings(
"""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING,
"""Longformer Model with a `language modeling` head on top. """,
LONGFORMER_START_DOCSTRING,
)
class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
def __init__(self, config, *inputs, **kwargs):
@@ -1322,7 +1405,10 @@ def call(
return ((loss,) + output) if loss is not None else output

return TFMaskedLMOutput(
- loss=loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
+ loss=loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
)


@@ -1338,7 +1424,9 @@ def __init__(self, config, *inputs, **kwargs):

self.longformer = TFLongformerMainLayer(config, name="longformer")
self.qa_outputs = tf.keras.layers.Dense(
- config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs",
+ config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="qa_outputs",
)

@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))