Fix TF Longformer #9348
Changes from 6 commits
```diff
@@ -390,7 +390,7 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
     True` else after `sep_token_id`.
     """

-    assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
+    assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions"
     question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
     question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32)  # size: batch_size x 1
     # bool attention mask with True in locations of global attention
```
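Context for the recurring `.shape[...]` → `shape_list(...)` swap throughout this diff: `tensor.shape` is TensorFlow's static shape and can contain `None` when the model is traced with dynamic dimensions (e.g. for SavedModel export), while `tf.shape(tensor)` is always fully defined at runtime. A minimal sketch of what a `shape_list`-style helper does (the actual helper lives in `transformers.modeling_tf_utils`; the function name below is illustrative):

```python
import tensorflow as tf

def shape_list_sketch(tensor):
    # Static shape: Python ints, with None where a dimension is unknown.
    static = tensor.shape.as_list()
    # Dynamic shape: an int32 tensor, always fully defined at runtime.
    dynamic = tf.shape(tensor)
    # Prefer the static int when known, fall back to the dynamic tensor.
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

@tf.function(input_signature=[tf.TensorSpec([None, 2], tf.int64)])
def count_rows(sep_token_indices):
    # sep_token_indices.shape[0] is None during tracing, so `.shape[0]`-style
    # code breaks here; the helper transparently falls back to tf.shape.
    return shape_list_sketch(sep_token_indices)[0]
```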
```diff
@@ -1028,7 +1028,7 @@ def _mask_invalid_locations(input_tensor, window_overlap):
     )

     # pad to full matrix
-    padding = tf.constant(
+    padding = tf.convert_to_tensor(
         [[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]]
     )
```
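Presumably the motivation for `tf.constant` → `tf.convert_to_tensor` here: `tf.constant` needs values that are concrete at trace time, while `tf.convert_to_tensor` also accepts and auto-packs symbolic entries such as the dynamic dimensions `shape_list` can return. A small sketch with made-up shapes:

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, None], tf.float32)])
def pad_to_ten(x):
    n = tf.shape(x)[1]                 # symbolic: unknown while tracing
    pad = tf.maximum(10 - n, 0)
    # tf.constant([[0, 0], [0, pad]]) would fail: `pad` is not a concrete value.
    paddings = tf.convert_to_tensor([[0, 0], [0, pad]])  # auto-packs the tensor
    return tf.pad(x, paddings)
```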
```diff
@@ -1523,8 +1523,7 @@ def call(
         training=False,
     ):
         all_hidden_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
-        all_global_attentions = () if (output_attentions and is_global_attn) else None
+        all_attentions = all_global_attentions = () if output_attentions else None

         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
```
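A plausible reason (not spelled out in the diff itself) for initializing both accumulators unconditionally: in graph mode `is_global_attn` is a symbolic boolean, and a Python conditional expression over it cannot be evaluated at trace time. Branching on a symbolic bool has to go through `tf.cond` with matching branch structures, e.g.:

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([], tf.bool)])
def attn_buffers(is_global_attn):
    # `() if is_global_attn else None` would need bool(is_global_attn) at trace
    # time, and would also give the two branches different structures.
    # A symbolic bool must go through tf.cond with matching outputs:
    return tf.cond(is_global_attn, lambda: tf.ones(3), lambda: tf.zeros(3))

print(attn_buffers(tf.constant(True)).numpy())   # [1. 1. 1.]
```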
```diff
@@ -1547,9 +1546,8 @@ def call(
                 # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
                 all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)

-                if is_global_attn:
-                    # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
-                    all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)))
+                # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
+                all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)))

         # Add last layer
         if output_hidden_states:
```
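The axis-permutation comments above are easy to sanity-check: `tf.transpose` with `perm=(0, 2, 1, 3)` swaps axes 1 and 2, and `perm=(0, 1, 3, 2)` swaps the last two. A toy check with invented sizes (batch 2, 128 tokens, 12 heads, window 65, one global token):

```python
import tensorflow as tf

x = tf.zeros((2, 128, 12, 65))          # bzs x seq_len x heads x window
print(tf.transpose(x, perm=(0, 2, 1, 3)).shape)   # (2, 12, 128, 65)

g = tf.zeros((2, 12, 1, 128))           # bzs x heads x num_global_attn x seq_len
print(tf.transpose(g, perm=(0, 1, 3, 2)).shape)   # (2, 12, 128, 1)
```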
```diff
@@ -1766,24 +1764,26 @@ def _pad_to_window_size(
             )
         )

-        paddings = tf.constant([[0, 0], [0, padding_len]])
+        paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])

         if input_ids is not None:
             input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)

         if position_ids is not None:
             # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings
             position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id)

         if inputs_embeds is not None:

             def pad_embeddings():
                 input_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id)
                 inputs_embeds_padding = self.embeddings(input_ids_padding)
-                inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
+                return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)

+            inputs_embeds = tf.cond(padding_len > 0, pad_embeddings, lambda: inputs_embeds)

-        attention_mask = tf.pad(
-            attention_mask, paddings, constant_values=False
-        )  # no attention on the padding tokens
+        attention_mask = tf.pad(attention_mask, paddings, constant_values=False)  # no attention on the padding tokens
         token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0)  # pad with token_type_id = 0

         return (
             padding_len,
```

Comment: When […]
Reply: Correct!
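On the `pad_embeddings` change: a `tf.cond` branch function must return its result; assigning to a closed-over name inside the branch does not propagate anything out, and both branches must produce values of matching structure for `tf.cond` to hand back. A self-contained sketch of the pattern, with a toy padding vector standing in for `self.embeddings`:

```python
import tensorflow as tf

def maybe_pad(inputs_embeds, padding_len, pad_vector):
    def pad_embeddings():
        batch_size = tf.shape(inputs_embeds)[0]
        pad_block = tf.tile(pad_vector[None, None, :], (batch_size, padding_len, 1))
        return tf.concat([inputs_embeds, pad_block], axis=-2)  # must *return*

    # Both branches yield tensors of the same dtype/structure.
    return tf.cond(padding_len > 0, pad_embeddings, lambda: inputs_embeds)

print(maybe_pad(tf.ones((2, 5, 8)), tf.constant(3), tf.zeros(8)).shape)  # (2, 8, 8)
```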
```diff
@@ -2171,16 +2171,14 @@ def call(

         # set global attention on question tokens
         if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
-            if inputs["input_ids"] is None:
-                logger.warning(
-                    "It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set."
-                )
-            elif (
-                tf.where(inputs["input_ids"] == self.config.sep_token_id).shape[0] != 3 * inputs["input_ids"].shape[0]
+            if (
+                shape_list(tf.where(inputs["input_ids"] == self.config.sep_token_id))[0]
+                != 3 * shape_list(inputs["input_ids"])[0]
             ):
                 logger.warning(
                     f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error."
                 )
                 inputs["global_attention_mask"] = tf.zeros(shape_list(inputs["input_ids"]))
             else:
                 logger.info("Initializing global attention on question tokens...")
                 # put global attention on all tokens until `config.sep_token_id` is reached
```

Comment: can we leave this warning?
Reply: We can but here the problem is that we test if […]
Reply: Agreed with @jplu here ;-)
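To make the separator check concrete: a RoBERTa-style QA input is `<s> question </s></s> context </s>`, so each row holds exactly three `sep_token_id` occurrences, and `tf.where` over the batch should find `3 * batch_size` of them; the `(batch_size, 3, 2)` reshape in `_compute_global_attention_mask` relies on this. A toy illustration with an invented vocabulary (`sep_token_id = 2`):

```python
import tensorflow as tf

sep_token_id = 2
input_ids = tf.constant([
    [0, 5, 6, 2, 2, 7, 8, 2],   # <s> q q </s></s> c c </s>
    [0, 9, 2, 2, 7, 7, 7, 2],   # <s> q </s></s> c c c </s>
])

sep_locations = tf.where(input_ids == sep_token_id)     # (row, col) pairs
print(sep_locations.shape[0])                           # 6 == 3 * batch_size

# Global attention on everything before the first separator (the question),
# mirroring the (batch_size, 3, 2) reshape in `_compute_global_attention_mask`.
question_end = tf.reshape(sep_locations, (2, 3, 2))[:, 0, 1]
mask = tf.range(8)[None, :] < tf.cast(question_end[:, None], tf.int32)
print(tf.cast(mask, tf.int32).numpy())
# [[1 1 1 0 0 0 0 0]
#  [1 1 0 0 0 0 0 0]]
```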
```diff
@@ -2317,8 +2315,8 @@ def call(
                 inputs["global_attention_mask"] = tf.zeros_like(inputs["input_ids"])
             inputs["global_attention_mask"] = tf.tensor_scatter_nd_update(
                 inputs["global_attention_mask"],
-                [[i, 0] for i in range(inputs["input_ids"].shape[0])],
-                [1 for _ in range(inputs["input_ids"].shape[0])],
+                [[i, 0] for i in range(shape_list(inputs["input_ids"])[0])],
+                [1 for _ in range(shape_list(inputs["input_ids"])[0])],
             )

         outputs = self.longformer(
```
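The scatter update above writes a 1 into column 0 of every row, i.e. global attention on the first token. With a hypothetical batch of 3 and sequence length 5:

```python
import tensorflow as tf

global_attention_mask = tf.zeros((3, 5), dtype=tf.int32)
batch_size = 3
updated = tf.tensor_scatter_nd_update(
    global_attention_mask,
    indices=[[i, 0] for i in range(batch_size)],  # one (row, col=0) per sample
    updates=[1] * batch_size,
)
print(updated.numpy())
# [[1 0 0 0 0]
#  [1 0 0 0 0]
#  [1 0 0 0 0]]
```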
```diff
@@ -2443,7 +2441,7 @@ def call(
             tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
         )
         flat_global_attention_mask = (
-            tf.reshape(inputs["global_attention_mask"], (-1, inputs["global_attention_mask"].shape[-1]))
+            tf.reshape(inputs["global_attention_mask"], (-1, shape_list(inputs["global_attention_mask"])[-1]))
             if inputs["global_attention_mask"] is not None
             else None
         )
```
Comment: nice!
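The flattening in this last hunk is the usual multiple-choice trick: `(batch, num_choices, seq_len)` inputs are collapsed to `(batch * num_choices, seq_len)` before the encoder, with `-1` absorbing the merged leading dimensions so only the (possibly dynamic) last dimension is needed:

```python
import tensorflow as tf

global_attention_mask = tf.zeros((4, 3, 128))   # batch x num_choices x seq_len
flat = tf.reshape(global_attention_mask, (-1, tf.shape(global_attention_mask)[-1]))
print(flat.shape)   # (12, 128)
```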