From 124948290fceefc400edf2556409d60c7cc08c98 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 4 Jun 2020 13:47:03 +0200 Subject: [PATCH] fix bert and gpt2 tf --- src/transformers/modeling_tf_bert.py | 52 ++++++++++++++++++-------- src/transformers/modeling_tf_ctrl.py | 16 ++++---- src/transformers/modeling_tf_gpt2.py | 54 ++++++++++++++------------- src/transformers/modeling_tf_t5.py | 10 +---- src/transformers/modeling_tf_utils.py | 21 +++++++++++ 5 files changed, 95 insertions(+), 58 deletions(-) diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index 25d17374fda5b6..d492bd8ebe0f16 100644 --- a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -23,7 +23,13 @@ from .configuration_bert import BertConfig from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list +from .modeling_tf_utils import ( + TFPreTrainedModel, + get_initializer, + keras_serializable, + shape_list, + cast_bool_to_primitive, +) from .tokenization_utils import BatchEncoding @@ -224,8 +230,8 @@ def transpose_for_scores(self, x, batch_size): x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size)) return tf.transpose(x, perm=[0, 2, 1, 3]) - def call(self, inputs, training=False, output_attentions=False): - hidden_states, attention_mask, head_mask = inputs + def call(self, inputs, training=False): + hidden_states, attention_mask, head_mask, output_attentions = inputs batch_size = shape_list(hidden_states)[0] mixed_query_layer = self.query(hidden_states) @@ -265,7 +271,10 @@ def call(self, inputs, training=False, output_attentions=False): context_layer, (batch_size, -1, self.all_head_size) ) # (batch_size, seq_len_q, all_head_size) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + outputs = ( + (context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,) + ) + return outputs @@ -297,9 +306,11 @@ def prune_heads(self, heads): raise NotImplementedError def call(self, inputs, training=False): - input_tensor, attention_mask, head_mask = inputs + input_tensor, attention_mask, head_mask, output_attentions = inputs - self_outputs = self.self_attention([input_tensor, attention_mask, head_mask], training=training) + self_outputs = self.self_attention( + [input_tensor, attention_mask, head_mask, output_attentions], training=training + ) attention_output = self.dense_output([self_outputs[0], input_tensor], training=training) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs @@ -348,9 +359,11 @@ def __init__(self, config, **kwargs): self.bert_output = TFBertOutput(config, name="output") def call(self, inputs, training=False): - hidden_states, attention_mask, head_mask = inputs + hidden_states, attention_mask, head_mask, output_attentions = inputs - attention_outputs = self.attention([hidden_states, attention_mask, head_mask], training=training) + attention_outputs = self.attention( + [hidden_states, attention_mask, head_mask, output_attentions], training=training + ) attention_output = attention_outputs[0] intermediate_output = self.intermediate(attention_output) layer_output = self.bert_output([intermediate_output, attention_output], training=training) @@ -364,8 +377,8 @@ def __init__(self, config, **kwargs): self.output_hidden_states = 
config.output_hidden_states self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)] - def call(self, inputs, training=False, output_attentions=False): - hidden_states, attention_mask, head_mask = inputs + def call(self, inputs, training=False): + hidden_states, attention_mask, head_mask, output_attentions = inputs all_hidden_states = () all_attentions = () @@ -373,10 +386,12 @@ def call(self, inputs, training=False, output_attentions=False): if self.output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = layer_module([hidden_states, attention_mask, head_mask[i]], training=training) + layer_outputs = layer_module( + [hidden_states, attention_mask, head_mask[i], output_attentions], training=training + ) hidden_states = layer_outputs[0] - if output_attentions: + if cast_bool_to_primitive(output_attentions) is True: all_attentions = all_attentions + (layer_outputs[1],) # Add last layer @@ -386,7 +401,7 @@ def call(self, inputs, training=False, output_attentions=False): outputs = (hidden_states,) if self.output_hidden_states: outputs = outputs + (all_hidden_states,) - if output_attentions: + if cast_bool_to_primitive(output_attentions) is True: outputs = outputs + (all_attentions,) return outputs # outputs, (hidden states), (attentions) @@ -504,6 +519,7 @@ def call( position_ids=None, head_mask=None, inputs_embeds=None, + output_attentions=False, training=False, ): if isinstance(inputs, (tuple, list)): @@ -513,7 +529,8 @@ def call( position_ids = inputs[3] if len(inputs) > 3 else position_ids head_mask = inputs[4] if len(inputs) > 4 else head_mask inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds - assert len(inputs) <= 6, "Too many inputs." + output_attentions = inputs[6] if len(inputs) > 6 else output_attentions + assert len(inputs) <= 7, "Too many inputs." elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) @@ -521,7 +538,8 @@ def call( position_ids = inputs.get("position_ids", position_ids) head_mask = inputs.get("head_mask", head_mask) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) - assert len(inputs) <= 6, "Too many inputs." + output_attentions = inputs.get("output_attentions", output_attentions) + assert len(inputs) <= 7, "Too many inputs." 
else: input_ids = inputs @@ -567,7 +585,9 @@ def call( # head_mask = tf.constant([0] * self.num_hidden_layers) embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training) - encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training) + encoder_outputs = self.encoder( + [embedding_output, extended_attention_mask, head_mask, output_attentions], training=training + ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) diff --git a/src/transformers/modeling_tf_ctrl.py b/src/transformers/modeling_tf_ctrl.py index b61155d612f960..631df0693a049f 100644 --- a/src/transformers/modeling_tf_ctrl.py +++ b/src/transformers/modeling_tf_ctrl.py @@ -23,7 +23,13 @@ from .configuration_ctrl import CTRLConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list +from .modeling_tf_utils import ( + TFPreTrainedModel, + TFSharedEmbeddings, + keras_serializable, + shape_list, + cast_bool_to_primitive, +) from .tokenization_utils import BatchEncoding @@ -113,13 +119,7 @@ def call(self, inputs, training=False, output_attentions=False): v = tf.concat((past_value, v), axis=-2) # to cope with keras serialization - # we need to cast `use_cache` to correct bool - # if it is a tensor - if tf.is_tensor(use_cache): - if hasattr(use_cache, "numpy"): - use_cache = bool(use_cache.numpy()) - else: - use_cache = True + use_cache = cast_bool_to_primitive(use_cache) if use_cache is True: present = tf.stack((k, v), axis=0) diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py index 781754d4b67c26..fe776f9f91d28d 100644 --- a/src/transformers/modeling_tf_gpt2.py +++ b/src/transformers/modeling_tf_gpt2.py @@ -31,6 +31,7 @@ get_initializer, keras_serializable, shape_list, + cast_bool_to_primitive, ) from .tokenization_utils import BatchEncoding @@ -91,8 +92,8 @@ def causal_attention_mask(nd, ns, dtype): m = i >= j - ns + nd return tf.cast(m, dtype) - def _attn(self, inputs, training=False, output_attentions=False): - q, k, v, attention_mask, head_mask = inputs + def _attn(self, inputs, training=False): + q, k, v, attention_mask, head_mask, output_attentions = inputs # q, k, v have shape [batch, heads, sequence, features] w = tf.matmul(q, k, transpose_b=True) if self.scale: @@ -117,7 +118,7 @@ def _attn(self, inputs, training=False, output_attentions=False): w = w * head_mask outputs = [tf.matmul(w, v)] - if output_attentions: + if cast_bool_to_primitive(output_attentions) is True: outputs.append(w) return outputs @@ -133,8 +134,8 @@ def split_heads(self, x): x = tf.reshape(x, new_x_shape) return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) - def call(self, inputs, training=False, output_attentions=False): - x, layer_past, attention_mask, head_mask, use_cache = inputs + def call(self, inputs, training=False): + x, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs x = self.c_attn(x) query, key, value = tf.split(x, 3, axis=2) @@ -147,22 +148,12 @@ def call(self, inputs, training=False, output_attentions=False): value = tf.concat([past_value, value], axis=-2) # to cope with keras serialization - # we need to cast `use_cache` to correct bool - # if it is a tensor - if tf.is_tensor(use_cache): - if hasattr(use_cache, "numpy"): - use_cache = bool(use_cache.numpy()) - else: - use_cache = True - - if 
use_cache is True: + if cast_bool_to_primitive(use_cache, True) is True: present = tf.stack([key, value], axis=0) else: present = (None,) - attn_outputs = self._attn( - [query, key, value, attention_mask, head_mask], training=training, output_attentions=output_attentions - ) + attn_outputs = self._attn([query, key, value, attention_mask, head_mask, output_attentions], training=training) a = attn_outputs[0] a = self.merge_heads(a) @@ -199,10 +190,12 @@ def __init__(self, n_ctx, config, scale=False, **kwargs): self.mlp = TFMLP(4 * nx, config, name="mlp") def call(self, inputs, training=False): - x, layer_past, attention_mask, head_mask, use_cache = inputs + x, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs a = self.ln_1(x) - output_attn = self.attn([a, layer_past, attention_mask, head_mask, use_cache], training=training) + output_attn = self.attn( + [a, layer_past, attention_mask, head_mask, use_cache, output_attentions], training=training + ) a = output_attn[0] # output_attn: a, present, (attentions) x = x + a @@ -272,7 +265,8 @@ def call( head_mask = inputs[5] if len(inputs) > 5 else head_mask inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds use_cache = inputs[7] if len(inputs) > 7 else use_cache - assert len(inputs) <= 8, "Too many inputs." + output_attentions = inputs[8] if len(inputs) > 8 else output_attentions + assert len(inputs) <= 9, "Too many inputs." elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") past = inputs.get("past", past) @@ -282,7 +276,8 @@ def call( head_mask = inputs.get("head_mask", head_mask) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) use_cache = inputs.get("use_cache", use_cache) - assert len(inputs) <= 8, "Too many inputs." + output_attentions = inputs.get("output_attentions", output_attentions) + assert len(inputs) <= 9, "Too many inputs." else: input_ids = inputs @@ -356,12 +351,15 @@ def call( if self.output_hidden_states: all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) - outputs = block([hidden_states, layer_past, attention_mask, head_mask[i], use_cache], training=training) + outputs = block( + [hidden_states, layer_past, attention_mask, head_mask[i], use_cache, output_attentions], + training=training, + ) hidden_states, present = outputs[:2] presents = presents + (present,) - if output_attentions: + if cast_bool_to_primitive(output_attentions) is True: all_attentions.append(outputs[2]) hidden_states = self.ln_f(hidden_states) @@ -377,7 +375,7 @@ def call( outputs = outputs + (presents,) if self.output_hidden_states: outputs = outputs + (all_hidden_states,) - if output_attentions: + if cast_bool_to_primitive(output_attentions) is True: # let the number of heads free (-1) so we can extract attention even after head pruning attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) @@ -615,6 +613,7 @@ def call( inputs_embeds=None, mc_token_ids=None, use_cache=True, + output_attentions=False, training=False, ): r""" @@ -682,7 +681,8 @@ def call( inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids use_cache = inputs[8] if len(inputs) > 8 else use_cache - assert len(inputs) <= 9, "Too many inputs." + output_attentions = inputs[9] if len(inputs) > 9 else output_attentions + assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, dict): input_ids = inputs.get("input_ids") past = inputs.get("past", past) @@ -693,7 +693,8 @@ def call( inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) mc_token_ids = inputs.get("mc_token_ids", mc_token_ids) use_cache = inputs.get("use_cache", use_cache) - assert len(inputs) <= 9, "Too many inputs." + output_attentions = inputs.get("output_attentions", output_attentions) + assert len(inputs) <= 10, "Too many inputs." else: input_ids = inputs @@ -718,6 +719,7 @@ def call( head_mask, inputs_embeds, use_cache, + output_attentions, ] transformer_outputs = self.transformer(flat_inputs, training=training) diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index 5d8c2e9f347493..95bc8053ba1129 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -25,7 +25,7 @@ from .configuration_t5 import T5Config from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list +from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list, cast_bool_to_primitive logger = logging.getLogger(__name__) @@ -249,13 +249,7 @@ def unshape(x): k, v = past_key_value_state # to cope with keras serialization - # we need to cast `use_cache` to correct bool - # if it is a tensor - if tf.is_tensor(use_cache): - if hasattr(use_cache, "numpy"): - use_cache = bool(use_cache.numpy()) - else: - use_cache = True + use_cache = cast_bool_to_primitive(use_cache) if self.is_decoder and use_cache is True: present_key_value_state = ((k, v),) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index e0bbeb0d0bd963..405ee7bee9643a 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1694,3 +1694,24 @@ def get_initializer(initializer_range=0.02): TruncatedNormal initializer with stddev = `initializer_range`. """ return tf.keras.initializers.TruncatedNormal(stddev=initializer_range) + + +def cast_bool_to_primitive(bool_variable, default_tensor_to_true=False): + """Function arguments can be passed either as boolean tensors + or as plain Python bools. To cope with keras serialization, + boolean tensors (e.g. `output_attentions` or `use_cache`) + need to be cast back to primitive bools whenever possible. + + Args: + default_tensor_to_true: bool, whether to default to True + when the tensor has no numpy attribute (i.e. is symbolic) + """ + # if bool variable is a tensor with a concrete numpy value + if tf.is_tensor(bool_variable): + if hasattr(bool_variable, "numpy"): + return bool(bool_variable.numpy()) + elif default_tensor_to_true: + return True + + # else variable is a plain bool (or a symbolic tensor, returned unchanged) + return bool_variable
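
A minimal eager-mode sketch (not part of the diff) of how the new cast_bool_to_primitive helper is expected to behave, assuming it is importable from transformers.modeling_tf_utils as added above:

    import tensorflow as tf

    from transformers.modeling_tf_utils import cast_bool_to_primitive

    # plain Python bools are returned unchanged
    assert cast_bool_to_primitive(True) is True
    # eager tensors expose .numpy(), so they are cast to a primitive bool
    assert cast_bool_to_primitive(tf.constant(False)) is False
    # a symbolic tensor (e.g. during Keras serialization) has no concrete numpy
    # value: the helper returns True only when default_tensor_to_true=True,
    # otherwise it returns the tensor unchanged, which callers treat as "not True".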