Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making TF MPNet model compliant with XLA #10260

Merged
merged 3 commits into from
Feb 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions src/transformers/models/mpnet/modeling_tf_mpnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,15 +348,22 @@ def __init__(self, config, **kwargs):
self.n_heads = config.num_attention_heads
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.initializer_range = config.initializer_range

self.layer = [TFMPNetLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
self.relative_attention_bias = tf.keras.layers.Embedding(
config.relative_attention_num_buckets,
self.n_heads,
name="relative_attention_bias",
)
self.relative_attention_num_buckets = config.relative_attention_num_buckets

def build(self, input_shape):
with tf.name_scope("relative_attention_bias"):
self.relative_attention_bias = self.add_weight(
name="embeddings",
shape=[self.relative_attention_num_buckets, self.n_heads],
initializer=get_initializer(self.initializer_range),
)

return super().build(input_shape)

def call(
self,
hidden_states,
Expand Down Expand Up @@ -405,18 +412,16 @@ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=12
n = -relative_position

num_buckets //= 2
ret += tf.dtypes.cast(tf.math.less(n, 0), tf.int32) * num_buckets
ret += tf.cast(tf.math.less(n, 0), dtype=relative_position.dtype) * num_buckets
n = tf.math.abs(n)

# now n is in the range [0, inf)
max_exact = num_buckets // 2
is_small = tf.math.less(n, max_exact)

val_if_large = max_exact + tf.dtypes.cast(
tf.math.log(tf.dtypes.cast(n, tf.float32) / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact),
tf.int32,
val_if_large = max_exact + tf.cast(
tf.math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact),
dtype=relative_position.dtype,
)

val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
Expand All @@ -441,7 +446,7 @@ def compute_position_bias(self, x, position_ids=None):
relative_position,
num_buckets=self.relative_attention_num_buckets,
)
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
values = tf.gather(self.relative_attention_bias, rp_bucket) # shape (qlen, klen, num_heads)
values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0) # shape (1, num_heads, qlen, klen)
return values

Expand Down Expand Up @@ -541,7 +546,9 @@ def call(
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
Expand Down
4 changes: 0 additions & 4 deletions tests/test_modeling_tf_mpnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,6 @@ def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mpnet_for_token_classification(*config_and_inputs)

def test_xla_mode(self):
# TODO JP: Make MPNet XLA compliant
pass

@slow
def test_model_from_pretrained(self):
for model_name in ["microsoft/mpnet-base"]:
Expand Down