Fix TF Funnel (#9300)
* Fix Funnel

* Apply Patrick's comment

* Remove comment

* Fix dummy value

* Apply style
jplu authored Jan 5, 2021
1 parent 748006c commit 52d62e6
Showing 1 changed file with 24 additions and 15 deletions.
39 changes: 24 additions & 15 deletions src/transformers/models/funnel/modeling_tf_funnel.py
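Most of the hunks below swap eager-only attribute lookups such as `tensor.shape[1]` and `tensor.shape.ndims` for the library's `shape_list` helper, which stays usable when a dimension is unknown at trace time (e.g. a dynamic sequence length under `tf.function`). A minimal sketch of that idea, assuming TF 2.x; the helper body and the tensors below are illustrative, not the exact library code:

import tensorflow as tf

def shape_list(tensor):
    # Return the shape as a list mixing static dims (Python ints) and
    # dynamic dims (scalar tensors), so it also works with unknown dimensions.
    dynamic = tf.shape(tensor)
    static = tensor.shape.as_list()
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

@tf.function(input_signature=[tf.TensorSpec([None, None, 8], tf.float32)])
def seq_len_of(inputs_embeds):
    # inputs_embeds.shape[1] is None while tracing; shape_list still yields a usable value.
    return shape_list(inputs_embeds)[1]

print(seq_len_of(tf.zeros([2, 5, 8])))  # tf.Tensor(5, shape=(), dtype=int32)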
@@ -185,7 +185,7 @@ def init_attention_inputs(self, inputs_embeds, attention_mask=None, token_type_i
         # inputs_embeds has shape batch_size x seq_len x d_model
         # attention_mask and token_type_ids have shape batch_size x seq_len
         self.pooling_mult = 1
-        self.seq_len = seq_len = inputs_embeds.shape[1]
+        self.seq_len = seq_len = shape_list(inputs_embeds)[1]
         position_embeds = self.get_position_embeds(seq_len, dtype=inputs_embeds.dtype, training=training)
         token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
         cls_mask = (
@@ -241,7 +241,7 @@ def get_position_embeds(self, seq_len, dtype=tf.float32, training=False):
             inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
             # Maximum relative positions for the first input
             rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype)
-            zero_offset = seq_len * 2
+            zero_offset = seq_len * tf.constant(2)
             sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq)
             sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training)
             cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training)
@@ -257,16 +257,17 @@ def get_position_embeds(self, seq_len, dtype=tf.float32, training=False):
                 # For block_index = 0 we only need the second one and leave the first one as None.
 
                 # First type
-                if block_index == 0:
-                    position_embeds_pooling = None
-                else:
+                position_embeds_pooling = tf.fill([1], value=-1.0)
+
+                if block_index != 0:
                     pooled_pos = self.stride_pool_pos(pos, block_index)
 
                     # construct rel_pos_id
                     stride = 2 ** (block_index - 1)
                     rel_pos = self.relative_pos(pos, stride, pooled_pos, shift=2)
                     # rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
                     # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
+                    rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
                     rel_pos = rel_pos + zero_offset
                     position_embeds_pooling = tf.gather(pos_embed, rel_pos, axis=0)

@@ -277,6 +278,7 @@ def get_position_embeds(self, seq_len, dtype=tf.float32, training=False):
 
                 # rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
                 # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
+                rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
                 rel_pos = rel_pos + zero_offset
                 position_embeds_no_pooling = tf.gather(pos_embed, rel_pos, axis=0)
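With `seq_len` coming from `shape_list` and `zero_offset` built via `tf.constant(2)`, the offset is an integer tensor rather than a Python int, so the float relative ids have to be cast to the same dtype before the addition and the following `tf.gather`. A minimal standalone sketch of that pattern; the sizes and names here are illustrative:

import tensorflow as tf

pos_embed = tf.random.normal([20, 4])                # stand-in for the sinusoidal table
seq_len = tf.shape(tf.zeros([2, 5, 4]))[1]           # dynamic seq_len, an int32 scalar
zero_offset = seq_len * tf.constant(2)               # int32 tensor, not a Python int

rel_pos = tf.range(-3.0, 4.0, 1.0)                   # float relative ids, as in get_position_embeds
rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)  # align dtypes before adding the offset
rel_pos = rel_pos + zero_offset                      # shift ids into the table's index range
print(tf.gather(pos_embed, rel_pos, axis=0).shape)   # (7, 4)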

@@ -298,19 +300,19 @@ def stride_pool_pos(self, pos_id, block_index):
         else:
             return pos_id[::2]
 
-    def relative_pos(self, pos, stride, pooled_pos=None, shift=1):
+    def relative_pos(self, pos, stride, pooled_pos=None, shift=1.0):
         """
         Build the relative positional vector between `pos` and `pooled_pos`.
         """
         if pooled_pos is None:
             pooled_pos = pos
 
         ref_point = pooled_pos[0] - pos[0]
-        num_remove = shift * pooled_pos.shape[0]
+        num_remove = shift * tf.cast(shape_list(pooled_pos)[0], dtype=ref_point.dtype)
         max_dist = ref_point + num_remove * stride
         min_dist = pooled_pos[0] - pos[-1]
 
-        return tf.range(max_dist, min_dist - 1, -stride, dtype=tf.int64)
+        return tf.range(max_dist, min_dist - 1, -stride)
 
     def stride_pool(self, tensor, axis):
         """
@@ -330,7 +332,7 @@ def stride_pool(self, tensor, axis):
             return type(tensor)(self.stride_pool(x, axis) for x in tensor)
 
         # Deal with negative axis
-        axis %= tensor.shape.ndims
+        axis %= len(shape_list(tensor))
 
         axis_slice = slice(None, -1, 2) if self.separate_cls and self.truncate_seq else slice(None, None, 2)
         enc_slice = [slice(None)] * axis + [axis_slice]
@@ -352,7 +354,7 @@ def pool_tensor(self, tensor, mode="mean", stride=2):
             suffix = tensor[:, :-1] if self.truncate_seq else tensor
             tensor = tf.concat([tensor[:, :1], suffix], axis=1)
 
-        ndim = tensor.shape.ndims
+        ndim = len(shape_list(tensor))
         if ndim == 2:
             tensor = tensor[:, :, None]

@@ -485,10 +487,14 @@ def relative_positional_attention(self, position_embeds, q_head, context_len, cl
                 "bind,jd->bnij", q_r_attention_2, omega
             )
         else:
-            shift = 2 if q_head.shape[1] != context_len else 1
             # Notations from the paper, appending A.2.1, final formula (https://arxiv.org/abs/2006.03236)
             # Grab the proper positional encoding, shape max_rel_len x d_model
-            r = position_embeds[self.block_index][shift - 1]
+            if shape_list(q_head)[1] != context_len:
+                shift = 2
+                r = position_embeds[self.block_index][1]
+            else:
+                shift = 1
+                r = position_embeds[self.block_index][0]
             # Shape n_head x d_head
             v = self.r_r_bias * self.scale
             # Shape d_model x n_head x d_head
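The rewritten branch above keeps the index into `position_embeds[self.block_index]` a literal 0 or 1: that container is a plain Python list, so indexing it with `shift - 1` only works when `shift` is a concrete Python int. Branching explicitly sidesteps that. A minimal sketch of the pattern; the names and values are illustrative:

import tensorflow as tf

# A plain Python list of candidate tables, like position_embeds[block_index].
tables = [tf.zeros([4, 8]), tf.ones([4, 8])]

def pick_table(q_len, context_len):
    # Branch explicitly so the list index stays a literal Python int (0 or 1)
    # instead of being computed from the comparison result.
    if q_len != context_len:
        return 2, tables[1]  # shifted case
    return 1, tables[0]      # unshifted case

shift, r = pick_table(3, 6)
print(shift, r[0, 0].numpy())  # 2 1.0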
@@ -517,7 +523,7 @@ def relative_token_type_attention(self, token_type_mat, q_head, cls_mask=None):
         # Shape batch_size x n_head x seq_len x 2
         token_type_bias = tf.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed)
         # Shape batch_size x n_head x seq_len x context_len
-        new_shape = [batch_size, q_head.shape[2], seq_len, context_len]
+        new_shape = [batch_size, shape_list(q_head)[2], seq_len, context_len]
         token_type_mat = tf.broadcast_to(token_type_mat[:, None], new_shape)
         # Shapes batch_size x n_head x seq_len
         diff_token_type, same_token_type = tf.split(token_type_bias, 2, axis=-1)
@@ -536,7 +542,7 @@ def call(self, query, key, value, attention_inputs, output_attentions=False, tra
         position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
 
         batch_size, seq_len, _ = shape_list(query)
-        context_len = key.shape[1]
+        context_len = shape_list(key)[1]
         n_head, d_head = self.n_head, self.d_head
 
         # Shape batch_size x seq_len x n_head x d_head
@@ -652,10 +658,13 @@ def call(
         for block_index, block in enumerate(self.blocks):
             pooling_flag = shape_list(hidden)[1] > (2 if self.separate_cls else 1)
             pooling_flag = pooling_flag and block_index > 0
+            pooled_hidden = tf.zeros(shape_list(hidden))
+
             if pooling_flag:
                 pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling(
                     hidden, attention_inputs
                 )
+
             for (layer_index, layer) in enumerate(block):
                 for repeat_index in range(self.block_repeats[block_index]):
                     do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag
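The added `pooled_hidden = tf.zeros(shape_list(hidden))` matches the "Fix dummy value" note in the commit message: the variable now has a well-shaped default on every path, so later code can reference it even when `pooling_flag` is false. A minimal sketch of the pattern; `maybe_pool` and the slicing stand in for the real `pre_attention_pooling`:

import tensorflow as tf

def maybe_pool(hidden, pooling_flag):
    # Give pooled_hidden a defined, correctly shaped dummy value up front
    # (the diff builds the shape via shape_list(hidden)).
    pooled_hidden = tf.zeros_like(hidden)
    if pooling_flag:
        pooled_hidden = hidden[:, ::2]  # stand-in for the real pooling step
    return pooled_hidden

print(maybe_pool(tf.ones([2, 6, 4]), True).shape)   # (2, 3, 4)
print(maybe_pool(tf.ones([2, 6, 4]), False).shape)  # (2, 6, 4)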
@@ -724,7 +733,7 @@ def call(
         upsampled_hidden = upsample(
             final_hidden,
             stride=self.stride,
-            target_len=first_block_hidden.shape[1],
+            target_len=shape_list(first_block_hidden)[1],
             separate_cls=self.separate_cls,
             truncate_seq=self.truncate_seq,
         )
