Fix TF Funnel #9300

Merged 5 commits on Jan 5, 2021
Changes from 2 commits
51 changes: 29 additions & 22 deletions src/transformers/models/funnel/modeling_tf_funnel.py
@@ -185,7 +185,7 @@ def init_attention_inputs(self, inputs_embeds, attention_mask=None, token_type_i
# inputs_embeds has shape batch_size x seq_len x d_model
# attention_mask and token_type_ids have shape batch_size x seq_len
self.pooling_mult = 1
self.seq_len = seq_len = inputs_embeds.shape[1]
self.seq_len = seq_len = shape_list(inputs_embeds)[1]
position_embeds = self.get_position_embeds(seq_len, dtype=inputs_embeds.dtype, training=training)
token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
cls_mask = (
@@ -241,7 +241,7 @@ def get_position_embeds(self, seq_len, dtype=tf.float32, training=False):
inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
# Maximum relative positions for the first input
rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype)
zero_offset = seq_len * 2
zero_offset = seq_len * tf.constant(2)
sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq)
sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training)
cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training)
@@ -257,16 +257,17 @@ def get_position_embeds(self, seq_len, dtype=tf.float32, training=False):
# For block_index = 0 we only need the second one and leave the first one as None.

# First type
if block_index == 0:
position_embeds_pooling = None
else:
position_embeds_pooling = tf.fill([1], value=-1.0)

if block_index != 0:
pooled_pos = self.stride_pool_pos(pos, block_index)

# construct rel_pos_id
stride = 2 ** (block_index - 1)
rel_pos = self.relative_pos(pos, stride, pooled_pos, shift=2)
# rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
# rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
rel_pos = rel_pos + zero_offset
position_embeds_pooling = tf.gather(pos_embed, rel_pos, axis=0)

@@ -277,6 +278,7 @@ def get_position_embeds(self, seq_len, dtype=tf.float32, training=False):

# rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
# rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
rel_pos = rel_pos + zero_offset
position_embeds_no_pooling = tf.gather(pos_embed, rel_pos, axis=0)

@@ -298,19 +300,19 @@ def stride_pool_pos(self, pos_id, block_index):
else:
return pos_id[::2]

def relative_pos(self, pos, stride, pooled_pos=None, shift=1):
def relative_pos(self, pos, stride, pooled_pos=None, shift=1.0):
"""
Build the relative positional vector between `pos` and `pooled_pos`.
"""
if pooled_pos is None:
pooled_pos = pos

ref_point = pooled_pos[0] - pos[0]
num_remove = shift * pooled_pos.shape[0]
num_remove = shift * tf.cast(shape_list(pooled_pos)[0], dtype=ref_point.dtype)
max_dist = ref_point + num_remove * stride
min_dist = pooled_pos[0] - pos[-1]

return tf.range(max_dist, min_dist - 1, -stride, dtype=tf.int64)
return tf.range(max_dist, min_dist - 1, -stride)

def stride_pool(self, tensor, axis):
"""
@@ -330,7 +332,7 @@ def stride_pool(self, tensor, axis):
return type(tensor)(self.stride_pool(x, axis) for x in tensor)

# Deal with negative axis
axis %= tensor.shape.ndims
axis %= len(shape_list(tensor))

axis_slice = slice(None, -1, 2) if self.separate_cls and self.truncate_seq else slice(None, None, 2)
enc_slice = [slice(None)] * axis + [axis_slice]
@@ -352,7 +354,7 @@ def pool_tensor(self, tensor, mode="mean", stride=2):
suffix = tensor[:, :-1] if self.truncate_seq else tensor
tensor = tf.concat([tensor[:, :1], suffix], axis=1)

ndim = tensor.shape.ndims
ndim = len(shape_list(tensor))
if ndim == 2:
tensor = tensor[:, :, None]

@@ -367,25 +369,24 @@ def pool_tensor(self, tensor, mode="mean", stride=2):

return tf.squeeze(tensor, 2) if ndim == 2 else tensor

def pre_attention_pooling(self, output, attention_inputs):
def pre_attention_pooling(self, attention_inputs):
""" Pool `output` and the proper parts of `attention_inputs` before the attention layer. """
position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
if self.pool_q_only:
if self.attention_type == "factorized":
position_embeds = self.stride_pool(position_embeds[:2], 0) + position_embeds[2:]
token_type_mat = self.stride_pool(token_type_mat, 1)
cls_mask = self.stride_pool(cls_mask, 0)
output = self.pool_tensor(output, mode=self.pooling_type)
else:
self.pooling_mult *= 2
if self.attention_type == "factorized":
position_embeds = self.stride_pool(position_embeds, 0)
token_type_mat = self.stride_pool(token_type_mat, [1, 2])
cls_mask = self.stride_pool(cls_mask, [1, 2])
attention_mask = self.pool_tensor(attention_mask, mode="min")
output = self.pool_tensor(output, mode=self.pooling_type)

attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
return output, attention_inputs
return attention_inputs
Collaborator:
Why change this function's input and return value? The tuple always has the same length, so there is no reason for it to be incompatible with graph mode, no? The output can be pooled outside of the test to avoid repeating the same line of code, but otherwise I'd prefer to keep the same logic as the PT implementation.

Contributor Author:
I understand, but this is not graph-compliant: pooled_hidden is not defined outside the if pooling_flag branch (line 655). So to fix this I had to move the pool_tensor call outside the if.
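For illustration, a minimal standalone sketch of the failure being described, not the Funnel code itself: when the condition depends on a tensor value, AutoGraph rewrites the if into a tf.cond, and a symbol assigned in only one branch cannot be used after it.

```python
import tensorflow as tf


@tf.function
def broken(hidden):
    # The condition depends on a tensor value, so AutoGraph turns this `if`
    # into a tf.cond at trace time.
    if tf.shape(hidden)[1] > 2:
        pooled = hidden[:, ::2]
    # `pooled` is assigned only in the True branch, so using it here makes
    # tracing fail: the implicit else branch never defines it.
    return pooled


# broken(tf.zeros((1, 8, 4)))  # raises at trace time
```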

Contributor Author:
If pooled_hidden cannot be set outside the if, it has to be deleted. Basically, either it has a value in both branches (same shape + same dtype) or it should not be set anywhere. Any idea what could be a "dummy" value?

Collaborator:
It's not used when pooling_flag is False, so a tensor with a 0 can be a good dummy value (if it needs the same shape, a tensor of the right shape).

Contributor Author:
Perfect! Doing the update.

Contributor Author:
I can't manage to infer the shape of pooled_hidden without calling pool_tensor. Any idea how to do this?

Contributor Author:
pooled_hidden = tf.zeros(shape_list(hidden)) seems to work perfectly fine 👍

Collaborator:
Yes, it should have the same shape as the hidden state if no pooling is done (so on the else side).
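A minimal sketch of the dummy-value pattern discussed here, with illustrative names rather than the PR's actual code: pre-assign a tensor of matching dtype so both sides of the tf.cond that AutoGraph generates define the symbol.

```python
import tensorflow as tf


@tf.function
def fixed(hidden):
    # Pre-assign a dummy tensor of the same dtype as `hidden` so the symbol
    # exists on both sides of the tensor-dependent branch.
    pooled = tf.zeros(tf.shape(hidden), dtype=hidden.dtype)
    if tf.shape(hidden)[1] > 2:
        pooled = hidden[:, ::2]  # stand-in for the real pooling op
    return pooled


# fixed(tf.random.uniform((1, 8, 4)))  # traces and runs in graph mode
```

The dummy value is never read when pooling is skipped; it only has to exist with a compatible dtype, which is why the pooled_hidden = tf.zeros(shape_list(hidden)) suggestion above works.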


def post_attention_pooling(self, attention_inputs):
""" Pool the proper parts of `attention_inputs` after the attention layer. """
@@ -485,10 +486,15 @@ def relative_positional_attention(self, position_embeds, q_head, context_len, cl
"bind,jd->bnij", q_r_attention_2, omega
)
else:
shift = 2 if q_head.shape[1] != context_len else 1
# Notations from the paper, appending A.2.1, final formula (https://arxiv.org/abs/2006.03236)
# Grab the proper positional encoding, shape max_rel_len x d_model
r = position_embeds[self.block_index][shift - 1]
# shift = 2 if shape_list(q_head)[1] != context_len else 1
Collaborator:
Let's clean the comment if it's not needed.

if shape_list(q_head)[1] != context_len:
shift = 2
r = position_embeds[self.block_index][1]
else:
shift = 1
r = position_embeds[self.block_index][0]
Comment on lines +492 to +497
Collaborator:
Looks like the shift variable isn't used afterwards? In this case, we can just remove it.

Contributor Author:
It is used afterwards, see line 508.

# Shape n_head x d_head
v = self.r_r_bias * self.scale
# Shape d_model x n_head x d_head
@@ -517,7 +523,7 @@ def relative_token_type_attention(self, token_type_mat, q_head, cls_mask=None):
# Shape batch_size x n_head x seq_len x 2
token_type_bias = tf.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed)
# Shape batch_size x n_head x seq_len x context_len
new_shape = [batch_size, q_head.shape[2], seq_len, context_len]
new_shape = [batch_size, shape_list(q_head)[2], seq_len, context_len]
token_type_mat = tf.broadcast_to(token_type_mat[:, None], new_shape)
# Shapes batch_size x n_head x seq_len
diff_token_type, same_token_type = tf.split(token_type_bias, 2, axis=-1)
@@ -536,7 +542,7 @@ def call(self, query, key, value, attention_inputs, output_attentions=False, tra
position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs

batch_size, seq_len, _ = shape_list(query)
context_len = key.shape[1]
context_len = shape_list(key)[1]
n_head, d_head = self.n_head, self.d_head

# Shape batch_size x seq_len x n_head x d_head
@@ -652,10 +658,11 @@ def call(
for block_index, block in enumerate(self.blocks):
pooling_flag = shape_list(hidden)[1] > (2 if self.separate_cls else 1)
pooling_flag = pooling_flag and block_index > 0
pooled_hidden = self.attention_structure.pool_tensor(hidden, mode=self.attention_structure.pooling_type)
Contributor:
This seems to slightly change the logic here; is this OK?

Contributor:
Maybe a comment here would be great.

Contributor Author:
The problem came from shift = 2 if q_head.shape[1] != context_len else 1. In graph mode, shift becomes a tensor whose value is undefined at trace time (it can be either 1 or 2), which makes the line r = position_embeds[self.block_index][shift - 1] impossible to compile, because in a graph TF cannot mix tensors and plain numbers here: (shift - 1) ends up being used as an index into a Python list.
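A self-contained sketch of the pattern the new lines adopt, with an illustrative function name; the real code compares shape_list(q_head)[1], which falls back to a static Python int when the dimension is known. Since position_embeds[self.block_index] is a plain Python list of tensors, it cannot be indexed with a tensor-valued shift - 1, so each branch picks the entry with a constant index.

```python
import tensorflow as tf


@tf.function
def pick_positional(q_head, context_len, embeds):
    # `embeds` stands in for position_embeds[self.block_index]: a Python list
    # of tensors. It cannot be indexed with a tensor, so each branch uses a
    # constant Python index; `shift` is still produced for the later math.
    if tf.shape(q_head)[1] != context_len:
        shift = 2
        r = embeds[1]
    else:
        shift = 1
        r = embeds[0]
    return r, shift
```

The two list entries are assumed to share a dtype, which is all the generated tf.cond requires across branches; their lengths may differ.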

Collaborator:
If the method self.attention_structure.pre_attention_pooling is not changed as suggested above, this line should be removed (not sure how it is linked to the code you mention, which is in self.attention_structure.post_attention_pooling).

Otherwise, the line should go inside the if pooling_flag test, or it breaks the current behavior.


if pooling_flag:
pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling(
hidden, attention_inputs
)
attention_inputs = self.attention_structure.pre_attention_pooling(attention_inputs)

for (layer_index, layer) in enumerate(block):
for repeat_index in range(self.block_repeats[block_index]):
do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag
@@ -724,7 +731,7 @@ def call(
upsampled_hidden = upsample(
final_hidden,
stride=self.stride,
target_len=first_block_hidden.shape[1],
target_len=shape_list(first_block_hidden)[1],
separate_cls=self.separate_cls,
truncate_seq=self.truncate_seq,
)