Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TF: tf.debugging assertions without tf.running_eagerly() protection #19030

Merged
merged 3 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 42 additions & 62 deletions src/transformers/models/bart/modeling_tf_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)

if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))

# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)

return shifted_input_ids

Expand Down Expand Up @@ -229,49 +228,40 @@ def call(
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)

if attention_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)

if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)

attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

attn_weights = stable_softmax(attn_weights, axis=-1)

if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)

attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
Expand All @@ -281,17 +271,14 @@ def call(
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)

# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)

attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
Expand Down Expand Up @@ -339,14 +326,11 @@ def call(
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)

# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)

hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
Expand Down Expand Up @@ -776,9 +760,7 @@ def call(
all_attentions = () if output_attentions else None

# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if head_mask is not None and tf.executing_eagerly():
if head_mask is not None:
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
Expand Down Expand Up @@ -983,10 +965,8 @@ def call(
present_key_values = () if use_cache else None

# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
if attn_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
Expand Down
104 changes: 42 additions & 62 deletions src/transformers/models/blenderbot/modeling_tf_blenderbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)

if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))

# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)

return shifted_input_ids

Expand Down Expand Up @@ -225,49 +224,40 @@ def call(
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)

if attention_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)

if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)

attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

attn_weights = stable_softmax(attn_weights, axis=-1)

if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)

attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
Expand All @@ -277,17 +267,14 @@ def call(
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)

# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)

attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
Expand Down Expand Up @@ -337,14 +324,11 @@ def call(
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)

# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)

hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
Expand Down Expand Up @@ -755,9 +739,7 @@ def call(
all_attentions = () if output_attentions else None

# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if head_mask is not None and tf.executing_eagerly():
if head_mask is not None:
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
Expand Down Expand Up @@ -966,10 +948,8 @@ def call(
present_key_values = () if use_cache else None

# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
if attn_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
Expand Down
Loading