Fix TF Funnel #9300
Changes from 2 commits
@@ -185,7 +185,7 @@ def init_attention_inputs(self, inputs_embeds, attention_mask=None, token_type_i
         # inputs_embeds has shape batch_size x seq_len x d_model
         # attention_mask and token_type_ids have shape batch_size x seq_len
         self.pooling_mult = 1
-        self.seq_len = seq_len = inputs_embeds.shape[1]
+        self.seq_len = seq_len = shape_list(inputs_embeds)[1]
         position_embeds = self.get_position_embeds(seq_len, dtype=inputs_embeds.dtype, training=training)
         token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
         cls_mask = (
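For context on this kind of change, a minimal sketch (not from the PR; `shape_list` below is a simplified stand-in for the helper in `transformers.modeling_tf_utils`): under `tf.function`, a dynamically shaped input has `None` static dimensions, so Python arithmetic on `tensor.shape[1]` breaks, while the dynamic shape still carries the runtime value.

import tensorflow as tf

def shape_list(tensor):
    # Simplified stand-in for the library helper: prefer static dimensions,
    # fall back to the dynamic shape where the static value is None.
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

@tf.function(input_signature=[tf.TensorSpec([None, None, 8], tf.float32)])
def double_seq_len(inputs_embeds):
    # seq_len = inputs_embeds.shape[1] would be None here and break seq_len * 2
    seq_len = shape_list(inputs_embeds)[1]
    return seq_len * tf.constant(2)

print(double_seq_len(tf.zeros((3, 5, 8))))  # tf.Tensor(10, shape=(), dtype=int32)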
@@ -241,7 +241,7 @@ def get_position_embeds(self, seq_len, dtype=tf.float32, training=False):
             inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
             # Maximum relative positions for the first input
             rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype)
-            zero_offset = seq_len * 2
+            zero_offset = seq_len * tf.constant(2)
             sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq)
             sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training)
             cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training)
@@ -257,16 +257,17 @@ def get_position_embeds(self, seq_len, dtype=tf.float32, training=False):
                 # For block_index = 0 we only need the second one and leave the first one as None.
 
                 # First type
-                if block_index == 0:
-                    position_embeds_pooling = None
-                else:
+                position_embeds_pooling = tf.fill([1], value=-1.0)
+
+                if block_index != 0:
                     pooled_pos = self.stride_pool_pos(pos, block_index)
 
                     # construct rel_pos_id
                     stride = 2 ** (block_index - 1)
                     rel_pos = self.relative_pos(pos, stride, pooled_pos, shift=2)
                     # rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
                     # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
+                    rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
                     rel_pos = rel_pos + zero_offset
                     position_embeds_pooling = tf.gather(pos_embed, rel_pos, axis=0)
@@ -277,6 +278,7 @@ def get_position_embeds(self, seq_len, dtype=tf.float32, training=False):
 
                 # rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
                 # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
+                rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
                 rel_pos = rel_pos + zero_offset
                 position_embeds_no_pooling = tf.gather(pos_embed, rel_pos, axis=0)
@@ -298,19 +300,19 @@ def stride_pool_pos(self, pos_id, block_index):
         else:
             return pos_id[::2]
 
-    def relative_pos(self, pos, stride, pooled_pos=None, shift=1):
+    def relative_pos(self, pos, stride, pooled_pos=None, shift=1.0):
         """
         Build the relative positional vector between `pos` and `pooled_pos`.
         """
         if pooled_pos is None:
             pooled_pos = pos
 
         ref_point = pooled_pos[0] - pos[0]
-        num_remove = shift * pooled_pos.shape[0]
+        num_remove = shift * tf.cast(shape_list(pooled_pos)[0], dtype=ref_point.dtype)
         max_dist = ref_point + num_remove * stride
         min_dist = pooled_pos[0] - pos[-1]
 
-        return tf.range(max_dist, min_dist - 1, -stride, dtype=tf.int64)
+        return tf.range(max_dist, min_dist - 1, -stride)
 
     def stride_pool(self, tensor, axis):
         """
@@ -330,7 +332,7 @@ def stride_pool(self, tensor, axis):
             return type(tensor)(self.stride_pool(x, axis) for x in tensor)
 
         # Deal with negative axis
-        axis %= tensor.shape.ndims
+        axis %= len(shape_list(tensor))
 
         axis_slice = slice(None, -1, 2) if self.separate_cls and self.truncate_seq else slice(None, None, 2)
         enc_slice = [slice(None)] * axis + [axis_slice]
@@ -352,7 +354,7 @@ def pool_tensor(self, tensor, mode="mean", stride=2):
             suffix = tensor[:, :-1] if self.truncate_seq else tensor
             tensor = tf.concat([tensor[:, :1], suffix], axis=1)
 
-        ndim = tensor.shape.ndims
+        ndim = len(shape_list(tensor))
         if ndim == 2:
             tensor = tensor[:, :, None]
@@ -367,25 +369,24 @@ def pool_tensor(self, tensor, mode="mean", stride=2):
 
         return tf.squeeze(tensor, 2) if ndim == 2 else tensor
 
-    def pre_attention_pooling(self, output, attention_inputs):
+    def pre_attention_pooling(self, attention_inputs):
         """ Pool `output` and the proper parts of `attention_inputs` before the attention layer. """
         position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
         if self.pool_q_only:
             if self.attention_type == "factorized":
                 position_embeds = self.stride_pool(position_embeds[:2], 0) + position_embeds[2:]
             token_type_mat = self.stride_pool(token_type_mat, 1)
             cls_mask = self.stride_pool(cls_mask, 0)
-            output = self.pool_tensor(output, mode=self.pooling_type)
         else:
             self.pooling_mult *= 2
             if self.attention_type == "factorized":
                 position_embeds = self.stride_pool(position_embeds, 0)
             token_type_mat = self.stride_pool(token_type_mat, [1, 2])
             cls_mask = self.stride_pool(cls_mask, [1, 2])
             attention_mask = self.pool_tensor(attention_mask, mode="min")
-            output = self.pool_tensor(output, mode=self.pooling_type)
 
         attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
-        return output, attention_inputs
+        return attention_inputs
 
     def post_attention_pooling(self, attention_inputs):
         """ Pool the proper parts of `attention_inputs` after the attention layer. """
@@ -485,10 +486,15 @@ def relative_positional_attention(self, position_embeds, q_head, context_len, cl
                 "bind,jd->bnij", q_r_attention_2, omega
             )
         else:
-            shift = 2 if q_head.shape[1] != context_len else 1
             # Notations from the paper, appending A.2.1, final formula (https://arxiv.org/abs/2006.03236)
             # Grab the proper positional encoding, shape max_rel_len x d_model
-            r = position_embeds[self.block_index][shift - 1]
+            # shift = 2 if shape_list(q_head)[1] != context_len else 1
Review comment: Let's clean the comment if it's not needed.
+            if shape_list(q_head)[1] != context_len:
+                shift = 2
+                r = position_embeds[self.block_index][1]
+            else:
+                shift = 1
+                r = position_embeds[self.block_index][0]
Review comment on lines +492 to +497: Looks like the `shift` variable isn't used after? In this case, we can just remove it.
Reply: It is used after, see line 508.
             # Shape n_head x d_head
             v = self.r_r_bias * self.scale
             # Shape d_model x n_head x d_head
@@ -517,7 +523,7 @@ def relative_token_type_attention(self, token_type_mat, q_head, cls_mask=None):
         # Shape batch_size x n_head x seq_len x 2
         token_type_bias = tf.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed)
         # Shape batch_size x n_head x seq_len x context_len
-        new_shape = [batch_size, q_head.shape[2], seq_len, context_len]
+        new_shape = [batch_size, shape_list(q_head)[2], seq_len, context_len]
         token_type_mat = tf.broadcast_to(token_type_mat[:, None], new_shape)
         # Shapes batch_size x n_head x seq_len
         diff_token_type, same_token_type = tf.split(token_type_bias, 2, axis=-1)
@@ -536,7 +542,7 @@ def call(self, query, key, value, attention_inputs, output_attentions=False, tra
         position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
 
         batch_size, seq_len, _ = shape_list(query)
-        context_len = key.shape[1]
+        context_len = shape_list(key)[1]
         n_head, d_head = self.n_head, self.d_head
 
         # Shape batch_size x seq_len x n_head x d_head
@@ -652,10 +658,11 @@ def call(
         for block_index, block in enumerate(self.blocks):
             pooling_flag = shape_list(hidden)[1] > (2 if self.separate_cls else 1)
             pooling_flag = pooling_flag and block_index > 0
+            pooled_hidden = self.attention_structure.pool_tensor(hidden, mode=self.attention_structure.pooling_type)
Review comments on this line:
- This seems to slightly change the logic here - is this OK?
- Maybe a comment here would be great.
- The problem was coming from […]
- If the method […] Otherwise, the line should be inside the test.
+
             if pooling_flag:
-                pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling(
-                    hidden, attention_inputs
-                )
+                attention_inputs = self.attention_structure.pre_attention_pooling(attention_inputs)
+
             for (layer_index, layer) in enumerate(block):
                 for repeat_index in range(self.block_repeats[block_index]):
                     do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag
@@ -724,7 +731,7 @@ def call(
         upsampled_hidden = upsample(
             final_hidden,
             stride=self.stride,
-            target_len=first_block_hidden.shape[1],
+            target_len=shape_list(first_block_hidden)[1],
             separate_cls=self.separate_cls,
             truncate_seq=self.truncate_seq,
         )
Review discussion:
- Why change this function input and return? The tuple always has the same length, so there is no reason for it to be incompatible with graph mode, no? The `output` can be pooled outside of the test to avoid the repetition of the same line of code, but otherwise I'd prefer to keep the same logic as the PT implementation.
- I understand, but this is not graph compliant. The reason is that `pooled_hidden` is not defined outside the `if pooling_flag` on line 655. So to fix this I had to move the `pool_tensor` call outside the `if`.
- If `pooled_hidden` cannot be set outside the `if`, it has to be deleted. Basically, either it has a value in both branches (same shape + same dtype) or it should not be set anywhere. Any idea what could be a "dummy" value?
- It's not used if `pooling_flag` is not set, so a tensor with a 0 can be a good dummy value (if it needs the same shape, a tensor of the right shape).
- Perfect! Doing the update.
- I can't manage to guess the shape of `pooled_hidden` without calling `pool_tensor`. Any idea how to do this?
- `pooled_hidden = tf.zeros(shape_list(hidden))` seems to work perfectly fine 👍
- Yes, it should have the same shape as the hidden state if no pooling is done (so on the else side).
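To make the agreed fix concrete, a minimal sketch (not the PR's code; `pool_tensor` is a hypothetical stand-in and plain `tf.shape` replaces `shape_list` to keep it self-contained): when the condition is a tensor, AutoGraph lowers the `if` to `tf.cond`, so `pooled_hidden` must already exist with the same dtype and a compatible shape on the else side.

import tensorflow as tf

def pool_tensor(tensor, stride=2):
    # Hypothetical stand-in for the model's pooling: stride the sequence axis.
    return tensor[:, ::stride]

@tf.function(input_signature=[tf.TensorSpec([None, None, 8], tf.float32)])
def block_forward(hidden):
    # Tensor-valued flag, so AutoGraph converts the `if` below into tf.cond.
    pooling_flag = tf.shape(hidden)[1] > 2

    # Dummy value with the same dtype and a compatible (dynamic) shape, so the
    # variable is defined on both branches of the conditional
    # (the tf.zeros(shape_list(hidden)) trick from the discussion above).
    pooled_hidden = tf.zeros(tf.shape(hidden))
    if pooling_flag:
        pooled_hidden = pool_tensor(hidden)
    return pooled_hidden

print(block_forward(tf.zeros((3, 5, 8))).shape)  # (3, 3, 8)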