Fix TF Longformer #9348

Merged 7 commits into huggingface:master on Jan 5, 2021
Conversation

@jplu (Contributor) commented Dec 29, 2020

What does this PR do?

This PR fixes the TF version of Longformer to make it graph compliant. As discussed offline with @patrickvonplaten, all_global_attentions is now added to the output when output_attentions=True. The global attentions are filled with zeros when is_global_attn is False (see line 897 in TFLongformerSelfAttention).

Fix issue

#9333
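
To illustrate the graph-compliance constraint behind this change, here is a minimal sketch, not the actual TFLongformerSelfAttention code (the function and argument names are invented): inside a tf.function, the set and shapes of the returned tensors cannot depend on a tensor-valued condition, so a global-attention tensor is always returned and simply zero-filled when there is no global attention.

import tensorflow as tf

# Minimal sketch only: the output structure must be static in graph mode,
# so the global attention probabilities are always returned and zero-filled
# when is_global_attn is False, instead of being dropped from the output.
@tf.function
def attention_outputs(attn_probs, global_attn_probs, is_global_attn):
    global_attn_probs = tf.cond(
        is_global_attn,
        lambda: global_attn_probs,
        lambda: tf.zeros_like(global_attn_probs),
    )
    return attn_probs, global_attn_probs

# Both calls return two tensors with identical shapes, which keeps the traced
# graph signature constant.
local_probs = tf.random.uniform((1, 8, 8))
global_probs = tf.random.uniform((1, 8, 8))
attention_outputs(local_probs, global_probs, tf.constant(True))
attention_outputs(local_probs, global_probs, tf.constant(False))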

if input_ids is not None:
    input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
Contributor

When padding_len == 0, this won't change input_ids, correct?

@jplu (Contributor Author)

Correct!
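
As a quick check of the point above (the token values are made up and pad_token_id=1 is only an assumption here), zero-width paddings make tf.pad a no-op:

import tensorflow as tf

# With padding_len == 0 the paddings tensor is all zeros,
# so tf.pad returns input_ids unchanged.
input_ids = tf.constant([[0, 9064, 16, 372, 2]])
padding_len = 0
paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
padded_ids = tf.pad(input_ids, paddings, constant_values=1)  # pad_token_id assumed to be 1
print(bool(tf.reduce_all(padded_ids == input_ids)))  # True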

@@ -2171,16 +2171,14 @@ def call(

# set global attention on question tokens
if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
    if inputs["input_ids"] is None:
        logger.warning(
            "It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set."
Contributor

Can we leave this warning?

@jplu (Contributor Author)

We can, but the problem here is that we would be testing whether inputs["input_ids"] is None inside an if that already checks that inputs["input_ids"] is not None, which seems strange.

@sgugger (Collaborator)

Agreed with @jplu here ;-)

):
    logger.warning(
        f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error."
    )
    inputs["global_attention_mask"] = tf.fill(shape_list(inputs["input_ids"]), value=1)
Contributor

This doesn't look correct to me. The "default" global_attention_mask is all 0s, so I think it should be:

inputs["global_attention_mask"] = tf.fill(shape_list(inputs["input_ids"]), value=0)

Also, we could improve the warning a bit by appending a sentence like "Disabling global attention for this forward pass...".
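
As a side note, here is a small sketch of the distinction being made (the token values and the pad token id are invented for illustration): attention_mask defaults to 1 on every real token, whereas global_attention_mask defaults to 0 everywhere and only the tokens that should attend globally are set to 1.

import tensorflow as tf

# Illustration only, not code from the PR; pad token id 1 is an assumption.
input_ids = tf.constant([[0, 2264, 16, 5, 812, 2],
                         [0, 713, 16, 2, 1, 1]])

attention_mask = tf.cast(input_ids != 1, tf.int32)       # 1 on real tokens, 0 on padding
global_attention_mask = tf.fill(tf.shape(input_ids), 0)  # default: no token attends globally

# For question answering, only the question tokens would then be flipped to 1
# in global_attention_mask, while attention_mask stays as computed above.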

@@ -1523,8 +1523,7 @@ def call(
         training=False,
     ):
         all_hidden_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
-        all_global_attentions = () if (output_attentions and is_global_attn) else None
+        all_attentions = all_global_attentions = () if output_attentions else None
Contributor

nice!
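
A small aside on the chained assignment above (illustration, not code from the PR): both names start out bound to the same empty tuple, which is safe because tuples are immutable and += rebinds each name independently instead of mutating shared state.

# Both names are initialised from one expression, but later updates
# to one do not leak into the other.
all_attentions = all_global_attentions = ()

all_attentions += ("layer_0_local_attn",)

print(all_attentions)         # ('layer_0_local_attn',)
print(all_global_attentions)  # ()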

@patrickvonplaten (Contributor) left a comment

I like the general direction of this PR! Here we should also run the slow tests to be sure nothing is broken.

IMO, the only thing left to do is to correct the "default" global attention to all 0's instead of 1's (global_attention_mask is different from attention_mask)

@jplu (Contributor Author) commented Dec 30, 2020

I have already run the slow tests as well and they all pass!

@sgugger (Collaborator) left a comment

LGTM, thanks for fixing!

@LysandreJik (Member) left a comment

LGTM, thanks @jplu!

@LysandreJik LysandreJik merged commit 83eec97 into huggingface:master Jan 5, 2021
@jplu jplu deleted the fix-tf-longformer branch January 5, 2021 09:37
guyrosin pushed a commit to guyrosin/transformers that referenced this pull request Jan 15, 2021
* Fix longformer

* Apply style

* Remove serving content

* Forgot a condition

* Apply style

* Address Patrick's comments

* Fix dtype