Update the MHA layer to respect the dtypes.
PiperOrigin-RevId: 552523481
qlzh727 authored and tensorflower-gardener committed Jul 31, 2023
1 parent 21c25fd commit 397ad57
Showing 2 changed files with 27 additions and 6 deletions.
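
For context, a minimal sketch of the behavior this commit targets (illustrative only, not part of the diff; the "mixed_float16" policy and toy shapes are assumptions): once the dtype policy is forwarded to the sublayers, they compute in the policy's compute dtype instead of falling back to the global default.

import tensorflow as tf

# Illustrative only: MultiHeadAttention built under a mixed-precision policy.
policy = tf.keras.mixed_precision.Policy("mixed_float16")
mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4, dtype=policy)

x = tf.random.normal((2, 8, 16))
mha(query=x, value=x)  # the first call builds the internal sublayers

# With this change, the softmax and dropout sublayers share the layer's policy.
print(mha._softmax.compute_dtype)        # float16
print(mha._dropout_layer.compute_dtype)  # float16
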
13 changes: 7 additions & 6 deletions keras/layers/attention/multi_head_attention.py
@@ -407,6 +407,7 @@ def _get_common_kwargs_for_sublayer(self):
             activity_regularizer=self._activity_regularizer,
             kernel_constraint=self._kernel_constraint,
             bias_constraint=self._bias_constraint,
+            dtype=self._dtype_policy,
         )
         # Create new clone of kernel/bias initializer, so that we don't reuse
         # the initializer instance, which could lead to same init value since
@@ -474,8 +475,12 @@ def _build_attention(self, rank):
                 attn_scores_rank - len(self._attention_axes), attn_scores_rank
             )
         )
-        self._softmax = activation.Softmax(axis=norm_axes)
-        self._dropout_layer = regularization.Dropout(rate=self._dropout)
+        self._softmax = activation.Softmax(
+            axis=norm_axes, dtype=self._dtype_policy
+        )
+        self._dropout_layer = regularization.Dropout(
+            rate=self._dropout, dtype=self._dtype_policy
+        )
 
     def _masked_softmax(self, attention_scores, attention_mask=None):
         # Normalize the attention scores to probabilities.
@@ -525,17 +530,14 @@ def _compute_attention(
         # Take the dot product between "query" and "key" to get the raw
         # attention scores.
         attention_scores = tf.einsum(self._dot_product_equation, key, query)
-
         attention_scores = self._masked_softmax(
             attention_scores, attention_mask
         )
-
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
         attention_scores_dropout = self._dropout_layer(
             attention_scores, training=training
         )
-
         # `context_layer` = [B, T, N, H]
         attention_output = tf.einsum(
             self._combine_equation, attention_scores_dropout, value
@@ -702,7 +704,6 @@ def _compute_causal_mask(self, query, value=None):
         )
 
     def compute_output_shape(self, query_shape, value_shape, key_shape=None):
-
         if key_shape is None:
             key_shape = value_shape
 
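A side note on the first hunk above: _get_common_kwargs_for_sublayer() feeds the EinsumDense projection layers, so adding dtype=self._dtype_policy there is what lets the query/key/value/output projections inherit the policy. A rough standalone illustration of that effect on a single EinsumDense (the policy, equation, and shapes below are assumptions, not taken from the commit):

import tensorflow as tf

# Illustrative: an EinsumDense created with an explicit mixed-precision policy
# keeps float32 variables while computing in float16, mirroring what the MHA
# projections now do when the parent layer uses such a policy.
policy = tf.keras.mixed_precision.Policy("mixed_float16")
proj = tf.keras.layers.EinsumDense(
    "abc,cd->abd", output_shape=(None, 64), bias_axes="d", dtype=policy
)
proj.build((2, 8, 16))
print(proj.variable_dtype)  # float32
print(proj.compute_dtype)   # float16
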
20 changes: 20 additions & 0 deletions keras/layers/attention/multi_head_attention_test.py
@@ -163,6 +163,26 @@ def test_initializer(self):
             keras.backend.eval(test_layer._output_dense.kernel),
         )
 
+    @parameterized.named_parameters(
+        ("bfloat16", tf.bfloat16),
+        ("float16", tf.float16),
+        ("float32", tf.float32),
+        ("float64", tf.float64),
+    )
+    def test_sublayer_dtypes(self, dtype):
+        test_layer = keras.layers.MultiHeadAttention(
+            num_heads=12, key_dim=64, dtype=dtype
+        )
+
+        query = keras.Input(shape=(40, 80), dtype=dtype)
+        # Build the layer
+        test_layer(query=query, value=query)
+
+        self.assertEqual(test_layer._query_dense.dtype, dtype)
+        self.assertEqual(test_layer._key_dense.dtype, dtype)
+        self.assertEqual(test_layer._value_dense.dtype, dtype)
+        self.assertEqual(test_layer._output_dense.dtype, dtype)
+
     def test_masked_attention_with_scores(self):
         """Test with a mask tensor."""
         test_layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
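
The new test covers the four dense projections; the softmax and dropout sublayers that _build_attention now creates with the policy could be checked the same way. A sketch of such an extra assertion (illustrative only, not part of the commit; head count and shapes are arbitrary):

import tensorflow as tf

# Illustrative check mirroring test_sublayer_dtypes, but for the softmax and
# dropout sublayers created in _build_attention.
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4,
                                           dtype=tf.bfloat16)
query = tf.keras.Input(shape=(8, 16), dtype=tf.bfloat16)
layer(query=query, value=query)  # building the layer creates the sublayers

assert tf.as_dtype(layer._softmax.dtype) == tf.bfloat16
assert tf.as_dtype(layer._dropout_layer.dtype) == tf.bfloat16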
