From 39f49a3e2336dfcc069241814deffe513ccfaaf2 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Thu, 24 Dec 2020 16:51:01 +0100 Subject: [PATCH 1/5] Fix Funnel --- .../models/funnel/modeling_tf_funnel.py | 54 +++++++++++-------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/funnel/modeling_tf_funnel.py b/src/transformers/models/funnel/modeling_tf_funnel.py index 38208112bfff..03828266ae33 100644 --- a/src/transformers/models/funnel/modeling_tf_funnel.py +++ b/src/transformers/models/funnel/modeling_tf_funnel.py @@ -185,7 +185,7 @@ def init_attention_inputs(self, inputs_embeds, attention_mask=None, token_type_i # inputs_embeds has shape batch_size x seq_len x d_model # attention_mask and token_type_ids have shape batch_size x seq_len self.pooling_mult = 1 - self.seq_len = seq_len = inputs_embeds.shape[1] + self.seq_len = seq_len = shape_list(inputs_embeds)[1] position_embeds = self.get_position_embeds(seq_len, dtype=inputs_embeds.dtype, training=training) token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None cls_mask = ( @@ -241,7 +241,7 @@ def get_position_embeds(self, seq_len, dtype=tf.float32, training=False): inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2))) # Maximum relative positions for the first input rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype) - zero_offset = seq_len * 2 + zero_offset = seq_len * tf.constant(2) sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq) sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training) cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training) @@ -257,9 +257,9 @@ def get_position_embeds(self, seq_len, dtype=tf.float32, training=False): # For block_index = 0 we only need the second one and leave the first one as None. # First type - if block_index == 0: - position_embeds_pooling = None - else: + position_embeds_pooling = None + + if block_index != 0: pooled_pos = self.stride_pool_pos(pos, block_index) # construct rel_pos_id @@ -267,6 +267,7 @@ def get_position_embeds(self, seq_len, dtype=tf.float32, training=False): rel_pos = self.relative_pos(pos, stride, pooled_pos, shift=2) # rel_pos = tf.expand_dims(rel_pos,1) + zero_offset # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model)) + rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype) rel_pos = rel_pos + zero_offset position_embeds_pooling = tf.gather(pos_embed, rel_pos, axis=0) @@ -277,9 +278,13 @@ def get_position_embeds(self, seq_len, dtype=tf.float32, training=False): # rel_pos = tf.expand_dims(rel_pos,1) + zero_offset # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model)) + rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype) rel_pos = rel_pos + zero_offset position_embeds_no_pooling = tf.gather(pos_embed, rel_pos, axis=0) + if position_embeds_pooling is None: + position_embeds_pooling = tf.fill(shape_list(position_embeds_no_pooling), value=-1.0) + position_embeds_list.append([position_embeds_no_pooling, position_embeds_pooling]) return position_embeds_list @@ -298,7 +303,7 @@ def stride_pool_pos(self, pos_id, block_index): else: return pos_id[::2] - def relative_pos(self, pos, stride, pooled_pos=None, shift=1): + def relative_pos(self, pos, stride, pooled_pos=None, shift=1.0): """ Build the relative positional vector between `pos` and `pooled_pos`. """ @@ -306,11 +311,11 @@ def relative_pos(self, pos, stride, pooled_pos=None, shift=1): pooled_pos = pos ref_point = pooled_pos[0] - pos[0] - num_remove = shift * pooled_pos.shape[0] + num_remove = shift * tf.cast(shape_list(pooled_pos)[0], dtype=ref_point.dtype) max_dist = ref_point + num_remove * stride min_dist = pooled_pos[0] - pos[-1] - return tf.range(max_dist, min_dist - 1, -stride, dtype=tf.int64) + return tf.range(max_dist, min_dist - 1, -stride) def stride_pool(self, tensor, axis): """ @@ -330,7 +335,7 @@ def stride_pool(self, tensor, axis): return type(tensor)(self.stride_pool(x, axis) for x in tensor) # Deal with negative axis - axis %= tensor.shape.ndims + axis %= len(shape_list(tensor)) axis_slice = slice(None, -1, 2) if self.separate_cls and self.truncate_seq else slice(None, None, 2) enc_slice = [slice(None)] * axis + [axis_slice] @@ -352,7 +357,7 @@ def pool_tensor(self, tensor, mode="mean", stride=2): suffix = tensor[:, :-1] if self.truncate_seq else tensor tensor = tf.concat([tensor[:, :1], suffix], axis=1) - ndim = tensor.shape.ndims + ndim = len(shape_list(tensor)) if ndim == 2: tensor = tensor[:, :, None] @@ -367,7 +372,7 @@ def pool_tensor(self, tensor, mode="mean", stride=2): return tf.squeeze(tensor, 2) if ndim == 2 else tensor - def pre_attention_pooling(self, output, attention_inputs): + def pre_attention_pooling(self, attention_inputs): """ Pool `output` and the proper parts of `attention_inputs` before the attention layer. """ position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs if self.pool_q_only: @@ -375,7 +380,6 @@ def pre_attention_pooling(self, output, attention_inputs): position_embeds = self.stride_pool(position_embeds[:2], 0) + position_embeds[2:] token_type_mat = self.stride_pool(token_type_mat, 1) cls_mask = self.stride_pool(cls_mask, 0) - output = self.pool_tensor(output, mode=self.pooling_type) else: self.pooling_mult *= 2 if self.attention_type == "factorized": @@ -383,9 +387,9 @@ def pre_attention_pooling(self, output, attention_inputs): token_type_mat = self.stride_pool(token_type_mat, [1, 2]) cls_mask = self.stride_pool(cls_mask, [1, 2]) attention_mask = self.pool_tensor(attention_mask, mode="min") - output = self.pool_tensor(output, mode=self.pooling_type) + attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask) - return output, attention_inputs + return attention_inputs def post_attention_pooling(self, attention_inputs): """ Pool the proper parts of `attention_inputs` after the attention layer. """ @@ -485,10 +489,15 @@ def relative_positional_attention(self, position_embeds, q_head, context_len, cl "bind,jd->bnij", q_r_attention_2, omega ) else: - shift = 2 if q_head.shape[1] != context_len else 1 # Notations from the paper, appending A.2.1, final formula (https://arxiv.org/abs/2006.03236) # Grab the proper positional encoding, shape max_rel_len x d_model - r = position_embeds[self.block_index][shift - 1] + # shift = 2 if shape_list(q_head)[1] != context_len else 1 + if shape_list(q_head)[1] != context_len: + shift = 2 + r = position_embeds[self.block_index][1] + else: + shift = 1 + r = position_embeds[self.block_index][0] # Shape n_head x d_head v = self.r_r_bias * self.scale # Shape d_model x n_head x d_head @@ -517,7 +526,7 @@ def relative_token_type_attention(self, token_type_mat, q_head, cls_mask=None): # Shape batch_size x n_head x seq_len x 2 token_type_bias = tf.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed) # Shape batch_size x n_head x seq_len x context_len - new_shape = [batch_size, q_head.shape[2], seq_len, context_len] + new_shape = [batch_size, shape_list(q_head)[2], seq_len, context_len] token_type_mat = tf.broadcast_to(token_type_mat[:, None], new_shape) # Shapes batch_size x n_head x seq_len diff_token_type, same_token_type = tf.split(token_type_bias, 2, axis=-1) @@ -536,7 +545,7 @@ def call(self, query, key, value, attention_inputs, output_attentions=False, tra position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs batch_size, seq_len, _ = shape_list(query) - context_len = key.shape[1] + context_len = shape_list(key)[1] n_head, d_head = self.n_head, self.d_head # Shape batch_size x seq_len x n_head x d_head @@ -652,10 +661,11 @@ def call( for block_index, block in enumerate(self.blocks): pooling_flag = shape_list(hidden)[1] > (2 if self.separate_cls else 1) pooling_flag = pooling_flag and block_index > 0 + pooled_hidden = self.attention_structure.pool_tensor(hidden, mode=self.attention_structure.pooling_type) + if pooling_flag: - pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling( - hidden, attention_inputs - ) + attention_inputs = self.attention_structure.pre_attention_pooling(attention_inputs) + for (layer_index, layer) in enumerate(block): for repeat_index in range(self.block_repeats[block_index]): do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag @@ -724,7 +734,7 @@ def call( upsampled_hidden = upsample( final_hidden, stride=self.stride, - target_len=first_block_hidden.shape[1], + target_len=shape_list(first_block_hidden)[1], separate_cls=self.separate_cls, truncate_seq=self.truncate_seq, ) From 8a1a5c021be55cf9c448c557c766def699b06055 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Mon, 28 Dec 2020 10:46:32 +0100 Subject: [PATCH 2/5] Apply Patrick's comment --- src/transformers/models/funnel/modeling_tf_funnel.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transformers/models/funnel/modeling_tf_funnel.py b/src/transformers/models/funnel/modeling_tf_funnel.py index 03828266ae33..16bf5ce56726 100644 --- a/src/transformers/models/funnel/modeling_tf_funnel.py +++ b/src/transformers/models/funnel/modeling_tf_funnel.py @@ -257,7 +257,7 @@ def get_position_embeds(self, seq_len, dtype=tf.float32, training=False): # For block_index = 0 we only need the second one and leave the first one as None. # First type - position_embeds_pooling = None + position_embeds_pooling = tf.fill([1], value=-1.0) if block_index != 0: pooled_pos = self.stride_pool_pos(pos, block_index) @@ -282,9 +282,6 @@ def get_position_embeds(self, seq_len, dtype=tf.float32, training=False): rel_pos = rel_pos + zero_offset position_embeds_no_pooling = tf.gather(pos_embed, rel_pos, axis=0) - if position_embeds_pooling is None: - position_embeds_pooling = tf.fill(shape_list(position_embeds_no_pooling), value=-1.0) - position_embeds_list.append([position_embeds_no_pooling, position_embeds_pooling]) return position_embeds_list From e682b34e5306d6000982ca54e13adaf696cad38b Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Mon, 4 Jan 2021 16:25:47 +0100 Subject: [PATCH 3/5] Remove comment --- src/transformers/models/funnel/modeling_tf_funnel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/funnel/modeling_tf_funnel.py b/src/transformers/models/funnel/modeling_tf_funnel.py index 16bf5ce56726..db4d72a74653 100644 --- a/src/transformers/models/funnel/modeling_tf_funnel.py +++ b/src/transformers/models/funnel/modeling_tf_funnel.py @@ -488,7 +488,6 @@ def relative_positional_attention(self, position_embeds, q_head, context_len, cl else: # Notations from the paper, appending A.2.1, final formula (https://arxiv.org/abs/2006.03236) # Grab the proper positional encoding, shape max_rel_len x d_model - # shift = 2 if shape_list(q_head)[1] != context_len else 1 if shape_list(q_head)[1] != context_len: shift = 2 r = position_embeds[self.block_index][1] From 78d7e1e215034c82f346be26596e67fd15e0868f Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Mon, 4 Jan 2021 17:07:28 +0100 Subject: [PATCH 4/5] Fix dummy value --- .../models/funnel/modeling_tf_funnel.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/funnel/modeling_tf_funnel.py b/src/transformers/models/funnel/modeling_tf_funnel.py index db4d72a74653..1b5e2fc62cad 100644 --- a/src/transformers/models/funnel/modeling_tf_funnel.py +++ b/src/transformers/models/funnel/modeling_tf_funnel.py @@ -369,7 +369,7 @@ def pool_tensor(self, tensor, mode="mean", stride=2): return tf.squeeze(tensor, 2) if ndim == 2 else tensor - def pre_attention_pooling(self, attention_inputs): + def pre_attention_pooling(self, output, attention_inputs): """ Pool `output` and the proper parts of `attention_inputs` before the attention layer. """ position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs if self.pool_q_only: @@ -377,6 +377,7 @@ def pre_attention_pooling(self, attention_inputs): position_embeds = self.stride_pool(position_embeds[:2], 0) + position_embeds[2:] token_type_mat = self.stride_pool(token_type_mat, 1) cls_mask = self.stride_pool(cls_mask, 0) + output = self.pool_tensor(output, mode=self.pooling_type) else: self.pooling_mult *= 2 if self.attention_type == "factorized": @@ -384,9 +385,10 @@ def pre_attention_pooling(self, attention_inputs): token_type_mat = self.stride_pool(token_type_mat, [1, 2]) cls_mask = self.stride_pool(cls_mask, [1, 2]) attention_mask = self.pool_tensor(attention_mask, mode="min") - + output = self.pool_tensor(output, mode=self.pooling_type) attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask) - return attention_inputs + return output, attention_inputs + def post_attention_pooling(self, attention_inputs): """ Pool the proper parts of `attention_inputs` after the attention layer. """ @@ -657,10 +659,12 @@ def call( for block_index, block in enumerate(self.blocks): pooling_flag = shape_list(hidden)[1] > (2 if self.separate_cls else 1) pooling_flag = pooling_flag and block_index > 0 - pooled_hidden = self.attention_structure.pool_tensor(hidden, mode=self.attention_structure.pooling_type) + pooled_hidden = tf.zeros((1)) if pooling_flag: - attention_inputs = self.attention_structure.pre_attention_pooling(attention_inputs) + pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling( + hidden, attention_inputs + ) for (layer_index, layer) in enumerate(block): for repeat_index in range(self.block_repeats[block_index]): From de9f43c8055cb7cce7078a1e78c2af3c9b981768 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Mon, 4 Jan 2021 17:53:00 +0100 Subject: [PATCH 5/5] Apply style --- src/transformers/models/funnel/modeling_tf_funnel.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/funnel/modeling_tf_funnel.py b/src/transformers/models/funnel/modeling_tf_funnel.py index 1b5e2fc62cad..c9c087578133 100644 --- a/src/transformers/models/funnel/modeling_tf_funnel.py +++ b/src/transformers/models/funnel/modeling_tf_funnel.py @@ -389,7 +389,6 @@ def pre_attention_pooling(self, output, attention_inputs): attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask) return output, attention_inputs - def post_attention_pooling(self, attention_inputs): """ Pool the proper parts of `attention_inputs` after the attention layer. """ position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs @@ -659,7 +658,7 @@ def call( for block_index, block in enumerate(self.blocks): pooling_flag = shape_list(hidden)[1] > (2 if self.separate_cls else 1) pooling_flag = pooling_flag and block_index > 0 - pooled_hidden = tf.zeros((1)) + pooled_hidden = tf.zeros(shape_list(hidden)) if pooling_flag: pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling(