fix bert and gpt2 tf
patrickvonplaten committed Jun 4, 2020
1 parent b825601 commit 1249482
Showing 5 changed files with 95 additions and 58 deletions.
52 changes: 36 additions & 16 deletions src/transformers/modeling_tf_bert.py
@@ -23,7 +23,13 @@
 
 from .configuration_bert import BertConfig
 from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
+from .modeling_tf_utils import (
+    TFPreTrainedModel,
+    get_initializer,
+    keras_serializable,
+    shape_list,
+    cast_bool_to_primitive,
+)
 from .tokenization_utils import BatchEncoding
 
 
@@ -224,8 +230,8 @@ def transpose_for_scores(self, x, batch_size):
         x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
         return tf.transpose(x, perm=[0, 2, 1, 3])
 
-    def call(self, inputs, training=False, output_attentions=False):
-        hidden_states, attention_mask, head_mask = inputs
+    def call(self, inputs, training=False):
+        hidden_states, attention_mask, head_mask, output_attentions = inputs
 
         batch_size = shape_list(hidden_states)[0]
         mixed_query_layer = self.query(hidden_states)
@@ -265,7 +271,10 @@ def call(self, inputs, training=False, output_attentions=False):
             context_layer, (batch_size, -1, self.all_head_size)
         )  # (batch_size, seq_len_q, all_head_size)
 
-        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+        outputs = (
+            (context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
+        )
 
         return outputs
 
@@ -297,9 +306,11 @@ def prune_heads(self, heads):
         raise NotImplementedError
 
     def call(self, inputs, training=False):
-        input_tensor, attention_mask, head_mask = inputs
+        input_tensor, attention_mask, head_mask, output_attentions = inputs
 
-        self_outputs = self.self_attention([input_tensor, attention_mask, head_mask], training=training)
+        self_outputs = self.self_attention(
+            [input_tensor, attention_mask, head_mask, output_attentions], training=training
+        )
         attention_output = self.dense_output([self_outputs[0], input_tensor], training=training)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs
@@ -348,9 +359,11 @@ def __init__(self, config, **kwargs):
         self.bert_output = TFBertOutput(config, name="output")
 
     def call(self, inputs, training=False):
-        hidden_states, attention_mask, head_mask = inputs
+        hidden_states, attention_mask, head_mask, output_attentions = inputs
 
-        attention_outputs = self.attention([hidden_states, attention_mask, head_mask], training=training)
+        attention_outputs = self.attention(
+            [hidden_states, attention_mask, head_mask, output_attentions], training=training
+        )
         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.bert_output([intermediate_output, attention_output], training=training)
@@ -364,19 +377,21 @@ def __init__(self, config, **kwargs):
         self.output_hidden_states = config.output_hidden_states
         self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
 
-    def call(self, inputs, training=False, output_attentions=False):
-        hidden_states, attention_mask, head_mask = inputs
+    def call(self, inputs, training=False):
+        hidden_states, attention_mask, head_mask, output_attentions = inputs
 
         all_hidden_states = ()
         all_attentions = ()
         for i, layer_module in enumerate(self.layer):
             if self.output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            layer_outputs = layer_module([hidden_states, attention_mask, head_mask[i]], training=training)
+            layer_outputs = layer_module(
+                [hidden_states, attention_mask, head_mask[i], output_attentions], training=training
+            )
             hidden_states = layer_outputs[0]
 
-            if output_attentions:
+            if cast_bool_to_primitive(output_attentions) is True:
                 all_attentions = all_attentions + (layer_outputs[1],)
 
         # Add last layer
@@ -386,7 +401,7 @@ def call(self, inputs, training=False, output_attentions=False):
         outputs = (hidden_states,)
         if self.output_hidden_states:
             outputs = outputs + (all_hidden_states,)
-        if output_attentions:
+        if cast_bool_to_primitive(output_attentions) is True:
             outputs = outputs + (all_attentions,)
         return outputs  # outputs, (hidden states), (attentions)
 
@@ -504,6 +519,7 @@ def call(
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        output_attentions=False,
         training=False,
     ):
         if isinstance(inputs, (tuple, list)):
@@ -513,15 +529,17 @@ def call(
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
-            assert len(inputs) <= 6, "Too many inputs."
+            output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
+            assert len(inputs) <= 7, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
             token_type_ids = inputs.get("token_type_ids", token_type_ids)
             position_ids = inputs.get("position_ids", position_ids)
             head_mask = inputs.get("head_mask", head_mask)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
-            assert len(inputs) <= 6, "Too many inputs."
+            output_attentions = inputs.get("output_attentions", output_attentions)
+            assert len(inputs) <= 7, "Too many inputs."
         else:
             input_ids = inputs
 
@@ -567,7 +585,9 @@ def call(
         # head_mask = tf.constant([0] * self.num_hidden_layers)
 
         embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
-        encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
+        encoder_outputs = self.encoder(
+            [embedding_output, extended_attention_mask, head_mask, output_attentions], training=training
+        )
 
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output)
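
The BERT hunks above all follow the same pattern: `output_attentions` moves out of the keyword arguments of the inner `call` methods and into the positional inputs list, so the flag still reaches `TFBertSelfAttention` after Keras serialization has turned it into a tensor. A minimal usage sketch of the dict-style input that the new `inputs.get("output_attentions", ...)` branch parses; the checkpoint name and token ids are illustrative and not taken from the commit:

import tensorflow as tf
from transformers import TFBertModel

# Hypothetical sketch: request attentions per call via the input dict,
# mirroring the `inputs.get("output_attentions", output_attentions)` branch above.
model = TFBertModel.from_pretrained("bert-base-uncased")
input_ids = tf.constant([[101, 7592, 2088, 102]])  # arbitrary token ids

outputs = model({"input_ids": input_ids, "output_attentions": True})
sequence_output, pooled_output = outputs[0], outputs[1]
# With the default config (no hidden-states output) the attentions come last:
attentions = outputs[-1]  # one (batch, num_heads, seq_len, seq_len) tensor per layer
print(len(attentions), attentions[0].shape)
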
16 changes: 8 additions & 8 deletions src/transformers/modeling_tf_ctrl.py
@@ -23,7 +23,13 @@
 
 from .configuration_ctrl import CTRLConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
+from .modeling_tf_utils import (
+    TFPreTrainedModel,
+    TFSharedEmbeddings,
+    keras_serializable,
+    shape_list,
+    cast_bool_to_primitive,
+)
 from .tokenization_utils import BatchEncoding
 
 
@@ -113,13 +119,7 @@ def call(self, inputs, training=False, output_attentions=False):
             v = tf.concat((past_value, v), axis=-2)
 
         # to cope with keras serialization
-        # we need to cast `use_cache` to correct bool
-        # if it is a tensor
-        if tf.is_tensor(use_cache):
-            if hasattr(use_cache, "numpy"):
-                use_cache = bool(use_cache.numpy())
-            else:
-                use_cache = True
+        use_cache = cast_bool_to_primitive(use_cache)
 
         if use_cache is True:
             present = tf.stack((k, v), axis=0)
54 changes: 28 additions & 26 deletions src/transformers/modeling_tf_gpt2.py
@@ -31,6 +31,7 @@
     get_initializer,
     keras_serializable,
     shape_list,
+    cast_bool_to_primitive,
 )
 from .tokenization_utils import BatchEncoding
 
@@ -91,8 +92,8 @@ def causal_attention_mask(nd, ns, dtype):
         m = i >= j - ns + nd
         return tf.cast(m, dtype)
 
-    def _attn(self, inputs, training=False, output_attentions=False):
-        q, k, v, attention_mask, head_mask = inputs
+    def _attn(self, inputs, training=False):
+        q, k, v, attention_mask, head_mask, output_attentions = inputs
         # q, k, v have shape [batch, heads, sequence, features]
         w = tf.matmul(q, k, transpose_b=True)
         if self.scale:
@@ -117,7 +118,7 @@ def _attn(self, inputs, training=False, output_attentions=False):
             w = w * head_mask
 
         outputs = [tf.matmul(w, v)]
-        if output_attentions:
+        if cast_bool_to_primitive(output_attentions) is True:
             outputs.append(w)
         return outputs
 
@@ -133,8 +134,8 @@ def split_heads(self, x):
         x = tf.reshape(x, new_x_shape)
         return tf.transpose(x, (0, 2, 1, 3))  # (batch, head, seq_length, head_features)
 
-    def call(self, inputs, training=False, output_attentions=False):
-        x, layer_past, attention_mask, head_mask, use_cache = inputs
+    def call(self, inputs, training=False):
+        x, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
 
         x = self.c_attn(x)
         query, key, value = tf.split(x, 3, axis=2)
@@ -147,22 +148,12 @@ def call(self, inputs, training=False, output_attentions=False):
             value = tf.concat([past_value, value], axis=-2)
 
         # to cope with keras serialization
-        # we need to cast `use_cache` to correct bool
-        # if it is a tensor
-        if tf.is_tensor(use_cache):
-            if hasattr(use_cache, "numpy"):
-                use_cache = bool(use_cache.numpy())
-            else:
-                use_cache = True
-
-        if use_cache is True:
+        if cast_bool_to_primitive(use_cache, True) is True:
             present = tf.stack([key, value], axis=0)
         else:
             present = (None,)
 
-        attn_outputs = self._attn(
-            [query, key, value, attention_mask, head_mask], training=training, output_attentions=output_attentions
-        )
+        attn_outputs = self._attn([query, key, value, attention_mask, head_mask, output_attentions], training=training)
         a = attn_outputs[0]
 
         a = self.merge_heads(a)
@@ -199,10 +190,12 @@ def __init__(self, n_ctx, config, scale=False, **kwargs):
         self.mlp = TFMLP(4 * nx, config, name="mlp")
 
     def call(self, inputs, training=False):
-        x, layer_past, attention_mask, head_mask, use_cache = inputs
+        x, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
 
         a = self.ln_1(x)
-        output_attn = self.attn([a, layer_past, attention_mask, head_mask, use_cache], training=training)
+        output_attn = self.attn(
+            [a, layer_past, attention_mask, head_mask, use_cache, output_attentions], training=training
+        )
         a = output_attn[0]  # output_attn: a, present, (attentions)
         x = x + a
 
@@ -272,7 +265,8 @@ def call(
             head_mask = inputs[5] if len(inputs) > 5 else head_mask
             inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
             use_cache = inputs[7] if len(inputs) > 7 else use_cache
-            assert len(inputs) <= 8, "Too many inputs."
+            output_attentions = inputs[8] if len(inputs) > 7 else output_attentions
+            assert len(inputs) <= 9, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             past = inputs.get("past", past)
@@ -282,7 +276,8 @@
             head_mask = inputs.get("head_mask", head_mask)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             use_cache = inputs.get("use_cache", use_cache)
-            assert len(inputs) <= 8, "Too many inputs."
+            output_attentions = inputs.get("output_attentions", output_attentions)
+            assert len(inputs) <= 9, "Too many inputs."
         else:
             input_ids = inputs
 
@@ -356,12 +351,15 @@ def call(
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
 
-            outputs = block([hidden_states, layer_past, attention_mask, head_mask[i], use_cache], training=training)
+            outputs = block(
+                [hidden_states, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
+                training=training,
+            )
 
             hidden_states, present = outputs[:2]
             presents = presents + (present,)
 
-            if output_attentions:
+            if cast_bool_to_primitive(output_attentions) is True:
                 all_attentions.append(outputs[2])
 
         hidden_states = self.ln_f(hidden_states)
@@ -377,7 +375,7 @@ def call(
             outputs = outputs + (presents,)
         if self.output_hidden_states:
             outputs = outputs + (all_hidden_states,)
-        if output_attentions:
+        if cast_bool_to_primitive(output_attentions) is True:
             # let the number of heads free (-1) so we can extract attention even after head pruning
             attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
             all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
@@ -615,6 +613,7 @@ def call(
         inputs_embeds=None,
         mc_token_ids=None,
         use_cache=True,
+        output_attentions=False,
         training=False,
     ):
         r"""
@@ -682,7 +681,8 @@ def call(
            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
             mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
             use_cache = inputs[8] if len(inputs) > 8 else use_cache
-            assert len(inputs) <= 9, "Too many inputs."
+            output_attentions = inputs[9] if len(inputs) > 8 else output_attentions
+            assert len(inputs) <= 10, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get("input_ids")
             past = inputs.get("past", past)
@@ -693,7 +693,8 @@
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
             use_cache = inputs.get("use_cache", use_cache)
-            assert len(inputs) <= 9, "Too many inputs."
+            output_attentions = inputs.get("output_attentions", output_attentions)
+            assert len(inputs) <= 10, "Too many inputs."
         else:
             input_ids = inputs
 
@@ -718,6 +719,7 @@ def call(
             head_mask,
             inputs_embeds,
             use_cache,
+            output_attentions,
         ]
 
         transformer_outputs = self.transformer(flat_inputs, training=training)
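
For GPT-2 the flag is appended after `use_cache` in the positional ordering that `TFGPT2MainLayer` parses above, and `use_cache` now defaults to `True` when it arrives as a tensor without a concrete value (`cast_bool_to_primitive(use_cache, True)`). A hypothetical usage sketch, assuming `TFGPT2Model` forwards the input dict to the main layer unchanged; the checkpoint name and token ids are illustrative, not taken from the commit:

import tensorflow as tf
from transformers import TFGPT2Model

# Positional ordering expected by the parsing code above (for reference):
# [input_ids, past, attention_mask, token_type_ids, position_ids,
#  head_mask, inputs_embeds, use_cache, output_attentions]
model = TFGPT2Model.from_pretrained("gpt2")
input_ids = tf.constant([[464, 3290, 318]])  # arbitrary token ids

outputs = model({"input_ids": input_ids, "use_cache": True, "output_attentions": True})
hidden_states, presents = outputs[0], outputs[1]
# With the default config (no hidden-states output) the attentions come last:
attentions = outputs[-1]  # present because output_attentions resolved to True
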
10 changes: 2 additions & 8 deletions src/transformers/modeling_tf_t5.py
@@ -25,7 +25,7 @@
 
 from .configuration_t5 import T5Config
 from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list, cast_bool_to_primitive
 
 
 logger = logging.getLogger(__name__)
@@ -249,13 +249,7 @@ def unshape(x):
                 k, v = past_key_value_state
 
         # to cope with keras serialization
-        # we need to cast `use_cache` to correct bool
-        # if it is a tensor
-        if tf.is_tensor(use_cache):
-            if hasattr(use_cache, "numpy"):
-                use_cache = bool(use_cache.numpy())
-            else:
-                use_cache = True
+        use_cache = cast_bool_to_primitive(use_cache)
 
         if self.is_decoder and use_cache is True:
             present_key_value_state = ((k, v),)
21 changes: 21 additions & 0 deletions src/transformers/modeling_tf_utils.py
@@ -1694,3 +1694,24 @@ def get_initializer(initializer_range=0.02):
         TruncatedNormal initializer with stddev = `initializer_range`.
     """
     return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
+
+
+def cast_bool_to_primitive(bool_variable, default_tensor_to_true=False):
+    """Function arguments can be inserted as boolean tensor
+        and bool variables to cope with keras serialization
+        we need to cast `output_attentions` to correct bool
+        if it is a tensor
+
+    Args:
+        default_tensor_to_true: bool, if tensor should default to True
+        in case tensor has no numpy attribute
+    """
+    # if bool variable is tensor and has numpy value
+    if tf.is_tensor(bool_variable):
+        if hasattr(bool_variable, "numpy"):
+            return bool(bool_variable.numpy())
+        elif default_tensor_to_true:
+            return True
+
+    # else variable is bool
+    return bool_variable
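
A short sketch of how the new helper behaves, assuming it can be imported from transformers.modeling_tf_utils once this commit is installed; the graph-mode case is what the `default_tensor_to_true` flag exists for:

import tensorflow as tf
from transformers.modeling_tf_utils import cast_bool_to_primitive

# Plain Python bools pass straight through.
assert cast_bool_to_primitive(True) is True

# Eager tensors expose .numpy(), so they are converted to Python bools.
assert cast_bool_to_primitive(tf.constant(False)) is False

@tf.function
def traced(flag):
    # During tracing (e.g. Keras serialization) `flag` is a symbolic tensor
    # with no .numpy(); the helper then falls back to `default_tensor_to_true`.
    use_cache = cast_bool_to_primitive(flag, default_tensor_to_true=True)
    return tf.constant(1) if use_cache is True else tf.constant(0)

print(traced(tf.constant(True)))  # 1: defaulted to True while tracing
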
