Fix TF Funnel #9300
Conversation
if block_index == 0:
    position_embeds_pooling = None
else:
    position_embeds_pooling = None
position_embeds_pooling seems to be only used at line 288. IMO it makes more sense to have a single if-else statement further below (at line 287):

if block_index != 0:
    position_embeds_pooling = tf.gather(pos_embed, rel_pos, axis=0)
else:
    position_embeds_pooling = tf.fill(shape_list(position_embeds_no_pooling), value=-1.0)
@@ -652,10 +661,11 @@ def call(
for block_index, block in enumerate(self.blocks):
    pooling_flag = shape_list(hidden)[1] > (2 if self.separate_cls else 1)
    pooling_flag = pooling_flag and block_index > 0
    pooled_hidden = self.attention_structure.pool_tensor(hidden, mode=self.attention_structure.pooling_type)
this seems to slightly change the logic here - is this OK?
maybe a comment here would be great
The problem was coming from shift = 2 if q_head.shape[1] != context_len else 1. Here, shift in graph mode becomes an undefined tensor because it can be either 1 or 2, which makes the line r = position_embeds[self.block_index][shift - 1] impossible to compile: in a graph, TF cannot mix tensors and Python numbers (here, shift - 1).
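For illustration, here is a hedged toy sketch of the graph-friendly pattern (made-up tensor names and shapes, not the actual Funnel code): the positional-encoding tensor is selected inside each branch, so the Python list is always indexed with a static integer and autograph only has to build a tf.cond that chooses between two tensors.

```python
import tensorflow as tf

# Stand-in for position_embeds[self.block_index]: a plain Python list of tensors.
position_embeds_block = [tf.zeros((6, 4)), tf.ones((6, 4))]
context_len = 6

@tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
def pick_positional_encoding(q_head):
    seq_len = tf.shape(q_head)[1]  # dynamic tensor in graph mode
    # autograph turns this `if` into a tf.cond; each branch indexes the Python list
    # with a literal 0 or 1, so no tensor ever reaches the list index
    if seq_len != context_len:
        r = position_embeds_block[1]
    else:
        r = position_embeds_block[0]
    return r

print(pick_positional_encoding(tf.random.uniform((2, 6))).shape)  # (6, 4)
```

The fix shown further down in this diff follows the same shape: shift is still set, but the indexing into position_embeds happens with the literal 1 or 0 inside each branch.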
If the method self.attention_structure.pre_attention_pooling is not changed as suggested above, this line should be removed (not sure how it's linked to the code you mention, which is in self.attention_structure.post_attention_pooling). Otherwise, the line should be inside the if pooling_flag test, or it breaks the current behavior.
I think this looks ok! We should wait for @sgugger's feedback here though.
Thanks for fixing! There is one change of behavior we should revert (the last of my comments); I left a few other comments as well.
attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
return output, attention_inputs
return attention_inputs
Why change this function's input and return? The tuple always has the same length, so there is no reason for it to be incompatible with graph mode, no? The output can be pooled outside of the test to avoid repeating the same line of code, but otherwise I'd prefer to keep the same logic as the PT implementation.
I understand, but this is not graph compliant. The reason is that pooled_hidden is not defined outside the if pooling_flag at line 655. So to fix this I had to move the pool_tensor call outside the if.
If pooled_hidden cannot be set outside the if, it has to be deleted. Basically, either it has a value in both branches (same shape + same dtype) or it should not be set anywhere. Any idea what could be a "dummy" value?
It's not used when pooling_flag is False, so a tensor of zeros can be a good dummy value (if it needs the same shape, a tensor of the right shape).
Perfect! Doing the update.
I can't manage to guess the shape of pooled_hidden without calling pool_tensor. Any idea how to do this?
pooled_hidden = tf.zeros(shape_list(hidden)) seems to work perfectly fine 👍
Yes, it should have the same shape as the hidden state if no pooling is done (so on the else side).
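A minimal sketch of the dummy-value pattern settled on in this thread (toy code with assumed names; tf.shape stands in for the library's shape_list helper, and the slice stands in for pool_tensor): pooled_hidden gets a value of the same shape and dtype before the if, so both branches of the conditional that autograph traces define the variable.

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None, None, 8], dtype=tf.float32)])
def block_step(hidden):
    # Tensor-valued condition, as in the encoder loop once the sequence length is dynamic.
    pooling_flag = tf.shape(hidden)[1] > 2
    # Dummy value with the same shape/dtype as hidden; it is only there so that the
    # variable is defined in both branches of the traced conditional.
    pooled_hidden = tf.zeros(tf.shape(hidden))
    if pooling_flag:
        # Stand-in for self.attention_structure.pool_tensor(hidden, mode=...).
        pooled_hidden = hidden[:, ::2, :]
    return pooled_hidden

print(block_step(tf.random.uniform((1, 4, 8))).shape)  # (1, 2, 8)
```

In the PR itself the dummy is pooled_hidden = tf.zeros(shape_list(hidden)), as suggested just above.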
# Notations from the paper, appending A.2.1, final formula (https://arxiv.org/abs/2006.03236)
# Grab the proper positional encoding, shape max_rel_len x d_model
r = position_embeds[self.block_index][shift - 1]
# shift = 2 if shape_list(q_head)[1] != context_len else 1
Let's clean the comment if it's not needed.
if shape_list(q_head)[1] != context_len:
    shift = 2
    r = position_embeds[self.block_index][1]
else:
    shift = 1
    r = position_embeds[self.block_index][0]
Looks like the shift variable isn't used afterwards? In that case, we can just remove it.
It is used afterwards, see line 508.
@LysandreJik feel free to merge if it looks ok for you and if @sgugger approves the last fix.
Looks perfect now, thanks!
LGTM!
What does this PR do?
This PR fixes Funnel to make it fully graph compliant. Even though all the slow/quick tests are passing and I got similar results in a few experiments, @sgugger I would appreciate it if you looked thoroughly at the changes to be sure no bugs have been introduced.