Fix TF Longformer #9348
Conversation
if input_ids is not None:
    input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
When `padding_len == 0`, this won't change `input_ids`, correct?
Correct!
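For readers following along, here is a minimal standalone sketch (toy shapes, not the actual modeling code) of why a zero `padding_len` leaves `input_ids` untouched: `tf.pad` with an all-zero paddings tensor returns the tensor unchanged.

import tensorflow as tf

pad_token_id = 1
input_ids = tf.constant([[5, 6, 7, 8]])

padding_len = 0
paddings = tf.constant([[0, 0], [0, padding_len]])

# with padding_len == 0 every entry of paddings is 0, so this is a no-op
padded = tf.pad(input_ids, paddings, constant_values=pad_token_id)
print(padded.numpy())  # [[5 6 7 8]] -- identical to input_ids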
@@ -2171,16 +2171,14 @@ def call(

# set global attention on question tokens
if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
if inputs["input_ids"] is None:
    logger.warning(
        "It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set."
can we leave this warning?
We can, but the problem here is that we test whether `inputs["input_ids"]` is None inside an `if` that already tests that `inputs["input_ids"]` is not None, which seems strange.
Agreed with @jplu here ;-)
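To make the concern concrete, here is a runnable toy sketch of the control flow being discussed; the names mirror the diff, but the surrounding code is simplified and the mask-building line is a made-up stand-in rather than the model's real logic.

import logging

logging.basicConfig()
logger = logging.getLogger(__name__)

def resolve_global_attention_mask(inputs):
    # Before the fix, the `is None` check sat inside a branch that already
    # required input_ids to be non-None, so the warning could never fire.
    # Hoisting the check one level up lets both paths be reached:
    if inputs["global_attention_mask"] is None:
        if inputs["input_ids"] is None:
            logger.warning(
                "It is not possible to automatically generate the `global_attention_mask`. "
                "Please make sure that it is correctly set."
            )
        else:
            # stand-in for the model's real mask-building logic
            inputs["global_attention_mask"] = [0] * len(inputs["input_ids"])
    return inputs

print(resolve_global_attention_mask({"global_attention_mask": None, "input_ids": [5, 6, 7]}))
print(resolve_global_attention_mask({"global_attention_mask": None, "input_ids": None}))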
):
    logger.warning(
        f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error."
    )
    inputs["global_attention_mask"] = tf.fill(shape_list(inputs["input_ids"]), value=1)
This doesn't look correct to me. The "default" `global_attention_mask` is all 0s, so I think it should be:
inputs["global_attention_mask"] = tf.fill(shape_list(inputs["input_ids"]), value=0)
We could also improve the warning a bit by appending a sentence like "Disabling global attention for this forward pass...".
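A toy sketch of the fallback being suggested (standalone tensors, with `tf.shape` standing in for the library's `shape_list` helper): warn that global attention could not be placed on the question tokens, then default the mask to all zeros so no token receives global attention.

import tensorflow as tf

def fallback_global_attention_mask(input_ids):
    # warn and disable global attention for this forward pass
    tf.print(
        "Could not place global attention on the question tokens. "
        "Disabling global attention for this forward pass..."
    )
    # default global_attention_mask: all 0s (no token gets global attention)
    return tf.fill(tf.shape(input_ids), value=0)

input_ids = tf.constant([[5, 6, 7, 8]])
print(fallback_global_attention_mask(input_ids).numpy())  # [[0 0 0 0]]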
@@ -1523,8 +1523,7 @@ def call(
    training=False,
):
    all_hidden_states = () if output_hidden_states else None
    all_attentions = () if output_attentions else None
    all_global_attentions = () if (output_attentions and is_global_attn) else None
    all_attentions = all_global_attentions = () if output_attentions else None
nice!
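A small standalone illustration of why the one-liner is safe: the chained assignment binds both names to the same empty tuple, but tuples are immutable, so later accumulating into one of them with += rebinds that name and never touches the other.

output_attentions = True

all_attentions = all_global_attentions = () if output_attentions else None

all_attentions += ("layer_0_attn",)  # rebinds all_attentions only
print(all_attentions)                # ('layer_0_attn',)
print(all_global_attentions)         # () -- still empty, no aliasing issue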
I like the general direction of this PR! Here we should also run the slow tests to be sure nothing is broken. IMO, the only thing left to do is to correct the "default" global attention to all 0s instead of 1s (`global_attention_mask` is different from `attention_mask`).

I have already run the slow tests as well and they all pass!
LGTM, thanks for fixing!
LGTM, thanks @jplu!
* Fix longformer
* Apply style
* Remove serving content
* Forgot a condition
* Apply style
* Address Patrick's comments
* Fix dtype
What does this PR do?
This PR aims to fix the TF Longformer version in order to make it graph compliant. As discussed offline with @patrickvonplaten, `all_global_attentions` is now added to the output when `output_attentions=True`. The global attentions are filled with zeros when `is_global_attn` is False (see line 897 in `TFLongformerSelfAttention`).

Fixes issue #9333
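A hedged usage sketch of the behaviour described above, assuming the standard `transformers` TF API and the `allenai/longformer-base-4096` checkpoint (both are assumptions here, not taken from this PR): with `output_attentions=True` the returned output should now carry the global attentions alongside the regular per-layer attentions.

from transformers import LongformerTokenizer, TFLongformerModel

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("Hello world", return_tensors="tf")
outputs = model(inputs, output_attentions=True)

print(len(outputs.attentions))         # one attention tensor per layer
print(len(outputs.global_attentions))  # zero-filled tensors when is_global_attn is False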