Fix TF Funnel #9300

Merged: 5 commits merged into huggingface:master on Jan 5, 2021
Conversation

@jplu (Contributor) commented on Dec 24, 2020:

What does this PR do?

This PR fixes Funnel to make it fully graph compliant. Even though all the slow/quick tests pass and a few experiments gave similar results, @sgugger I would appreciate it if you looked thoroughly at the changes to make sure no bugs have been introduced.

if block_index == 0:
    position_embeds_pooling = None
else:
    position_embeds_pooling = None
Contributor:

position_embeds_pooling seems to be only used at line 288. IMO it makes more sense to have a single if-else statement further below (at line 287):

if block_index != 0:
    position_embeds_pooling = tf.gather(pos_embed, rel_pos, axis=0) 
else:
    position_embeds_pooling = tf.fill(shape_list(position_embeds_no_pooling), value=-1.0)
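
For context, shape_list is the helper used throughout the snippets in this thread; a minimal sketch of its behavior (not the exact transformers implementation) would be:

    import tensorflow as tf

    def shape_list(tensor):
        # Return the static dimension where it is known, and the dynamic tf.shape
        # entry where it is not, so comparisons such as shape_list(q_head)[1] != context_len
        # stay plain Python values whenever the dimension is static.
        static = tensor.shape.as_list()
        dynamic = tf.shape(tensor)
        return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]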

@@ -652,10 +661,11 @@ def call(
        for block_index, block in enumerate(self.blocks):
            pooling_flag = shape_list(hidden)[1] > (2 if self.separate_cls else 1)
            pooling_flag = pooling_flag and block_index > 0
            pooled_hidden = self.attention_structure.pool_tensor(hidden, mode=self.attention_structure.pooling_type)
Contributor:

this seems to slightly change the logic here - is this OK?

Contributor:

maybe a comment here would be great

Contributor Author (@jplu):

The problem came from shift = 2 if q_head.shape[1] != context_len else 1. In graph mode, shift becomes an undefined tensor because it can be either 1 or 2, which makes the line r = position_embeds[self.block_index][shift - 1] impossible to compile: in a graph, TF cannot mix tensors and Python numbers, here in (shift - 1).
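
To illustrate the constraint described above with placeholder names (a simplified sketch, not the actual Funnel code): the fixed pattern selects the embedding with a literal index inside each branch, so no tensor-valued shift is ever used to index a Python sequence.

    import tensorflow as tf

    # Hypothetical stand-in for position_embeds[self.block_index]: a plain Python
    # tuple of tensors, one per shift value.
    embeds = (tf.zeros((6, 4)), tf.ones((6, 4)))

    @tf.function
    def pick_positional(q_len, context_len):
        # When q_len is a symbolic tensor, AutoGraph lowers this `if` to tf.cond and
        # traces each branch separately, each with its own constant index.
        if q_len != context_len:
            shift = 2
            r = embeds[1]
        else:
            shift = 1
            r = embeds[0]
        return r, shift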

Collaborator:

If the method self.attention_structure.pre_attention_pooling is not changed as suggested above, this line should be removed (not sure how it's linked to the code you mention, which is in self.attention_structure.post_attention_pooling).

Otherwise, the line should be inside the if pooling_flag test, or it breaks the current behavior.

@patrickvonplaten (Contributor) left a comment:

I think this looks OK! We should wait for @sgugger's feedback here though.

@sgugger (Collaborator) left a comment:

Thanks for fixing! There is one change of behavior we should revert (the last of my comments); I left a few other comments as well.

attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
return output, attention_inputs
return attention_inputs
Collaborator:

Why change this function input and return? The tuple always has the same length so there is no reason for it to be incompatible with graph mode, no? The output can be pooled outside of the test to avoid the repetition of the same line of code, but otherwise, I'd prefer to keep the same logic as the PT implementation.

Contributor Author (@jplu):

I understand, but this is not graph compliant. The reason is that pooled_hidden is not defined outside the if pooling_flag at line 655, so to fix this I had to move the pool_tensor call outside the if.

Contributor Author (@jplu):

If pooled_hidden cannot be set outside the if, it has to be deleted. Basically, either it has a value in both branches (same shape + same dtype) or it should not be set anywhere. Any idea what could be a good "dummy" value?

Collaborator:

It's not used if pooling_flag is False, so a tensor with a 0 can be a good dummy value (if it needs the same shape, a tensor of the right shape).

Contributor Author (@jplu):

Perfect! Doing the update.

Contributor Author (@jplu):

I can't figure out the shape of pooled_hidden without calling pool_tensor. Any idea how to do this?

Contributor Author (@jplu):

pooled_hidden = tf.zeros(shape_list(hidden)) seems to work perfectly fine 👍

Collaborator:

Yes, it should have the same shape as the hidden state if no pooling is done (so on the else side).
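
A minimal sketch of that constraint with illustrative names (the pooling below is a stand-in, not the actual pool_tensor): when pooling_flag is a symbolic tensor the `if` becomes a tf.cond, so pooled_hidden has to be assigned in both branches with the same dtype, and the no-pooling branch gets a dummy tensor shaped like hidden.

    import tensorflow as tf

    @tf.function
    def maybe_pool(hidden, pooling_flag):
        if pooling_flag:
            # Stand-in pooling: average adjacent positions along the sequence axis.
            pooled_hidden = tf.nn.avg_pool1d(hidden, ksize=2, strides=2, padding="SAME")
        else:
            # Dummy value with the same dtype (and the unpooled shape); it is never
            # read downstream when pooling_flag is False.
            pooled_hidden = tf.zeros_like(hidden)
        return pooled_hidden

    hidden = tf.random.normal((2, 8, 16))
    print(maybe_pool(hidden, tf.constant(True)).shape)   # (2, 4, 16)
    print(maybe_pool(hidden, tf.constant(False)).shape)  # (2, 8, 16)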

# Notations from the paper, appendix A.2.1, final formula (https://arxiv.org/abs/2006.03236)
# Grab the proper positional encoding, shape max_rel_len x d_model
r = position_embeds[self.block_index][shift - 1]
# shift = 2 if shape_list(q_head)[1] != context_len else 1
Collaborator:

Let's clean the comment if it's not needed.

Comment on lines +492 to +497:

    if shape_list(q_head)[1] != context_len:
        shift = 2
        r = position_embeds[self.block_index][1]
    else:
        shift = 1
        r = position_embeds[self.block_index][0]
Collaborator:

Looks like the shift variable isn't used after? In this case, we can just remove it.

Contributor Author (@jplu):

It is used afterwards, see line 508.


@jplu (Contributor Author) commented on Jan 4, 2021:

@LysandreJik feel free to merge if it looks OK to you and if @sgugger approves the last fix on pooled_hidden.

@sgugger (Collaborator) left a comment:

Looks perfect now, thanks!

@LysandreJik (Member) left a comment:

LGTM!

@LysandreJik merged commit 52d62e6 into huggingface:master on Jan 5, 2021
@jplu deleted the fix-tf-funnel branch on January 5, 2021, 11:03