From 50697e555414878a2087c9d5d173e8d9c9e3570f Mon Sep 17 00:00:00 2001
From: Dave Spencer
Date: Fri, 4 Aug 2023 18:22:40 +0000
Subject: [PATCH] Update MaskedLMHead to support dtype=bfloat16/float16/float64.

---
 keras_nlp/layers/modeling/masked_lm_head.py  |  2 ++
 .../layers/modeling/masked_lm_head_test.py   | 53 +++++++++++++++++++
 2 files changed, 55 insertions(+)

diff --git a/keras_nlp/layers/modeling/masked_lm_head.py b/keras_nlp/layers/modeling/masked_lm_head.py
index 0bb0a421c..81b3924c1 100644
--- a/keras_nlp/layers/modeling/masked_lm_head.py
+++ b/keras_nlp/layers/modeling/masked_lm_head.py
@@ -153,9 +153,11 @@ def build(self, inputs_shape, masked_positions_shape=None):
             activation=self.intermediate_activation,
             kernel_initializer=self.kernel_initializer,
             bias_initializer=self.bias_initializer,
+            dtype=self._dtype_policy,
         )
         self._layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
+            dtype=self._dtype_policy,
         )
         if masked_positions_shape:
             gather_length = masked_positions_shape[1]
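Note on the change above: the only functional edits are the two `dtype=self._dtype_policy` arguments (leftover `print("XXX...")` debugging in the `call()` method has been dropped from the patch). Sublayers created inside `build()` do not inherit the parent layer's dtype policy; without these arguments, `_dense` and `_layer_norm` would fall back to the global policy (usually float32) even when the head was constructed with `dtype="bfloat16"`. A minimal sketch of the same pattern under standalone Keras, using the public `dtype_policy` property; the `ToyHead` name is hypothetical, not part of keras_nlp:

import numpy as np
import keras


class ToyHead(keras.layers.Layer):
    """Toy composite layer that forwards its dtype policy to sublayers."""

    def build(self, input_shape):
        # Without dtype=self.dtype_policy these sublayers would pick up
        # the *global* dtype policy, not the one passed to ToyHead.
        self._dense = keras.layers.Dense(8, dtype=self.dtype_policy)
        self._norm = keras.layers.LayerNormalization(dtype=self.dtype_policy)

    def call(self, x):
        return self._norm(self._dense(x))


head = ToyHead(dtype="bfloat16")
outputs = head(np.ones((2, 16), dtype="float32"))
print(outputs.dtype)  # bfloat16
print([w.dtype for w in head.weights])  # all bfloat16

The patch reads the private `_dtype_policy` attribute to match existing keras_nlp conventions; the public `dtype_policy` property used in the sketch should expose the same object.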
diff --git a/keras_nlp/layers/modeling/masked_lm_head_test.py b/keras_nlp/layers/modeling/masked_lm_head_test.py
index f5c3b9d07..c6a701d90 100644
--- a/keras_nlp/layers/modeling/masked_lm_head_test.py
+++ b/keras_nlp/layers/modeling/masked_lm_head_test.py
@@ -15,6 +15,9 @@
 
 import os
 
+import tensorflow as tf
+from absl.testing import parameterized
+
 from keras_nlp.backend import keras
 from keras_nlp.backend import ops
 from keras_nlp.layers.modeling import masked_lm_head
@@ -36,6 +39,30 @@ def test_valid_call(self):
         position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
         model((token_data, position_data))
 
+    @parameterized.named_parameters(
+        ("bfloat16", tf.bfloat16),
+        ("float16", tf.float16),
+        ("float32", tf.float32),
+        ("float64", tf.float64),
+    )
+    def test_valid_call_with_dtype(self, dtype):
+        head = masked_lm_head.MaskedLMHead(
+            vocabulary_size=100,
+            activation="softmax",
+            dtype=dtype,
+        )
+        encoded_tokens = keras.Input(shape=(10, 16))
+        positions = keras.Input(shape=(5,), dtype="int32")
+        outputs = head(encoded_tokens, masked_positions=positions)
+        model = keras.Model((encoded_tokens, positions), outputs)
+
+        token_data = ops.random.uniform(shape=(4, 10, 16))
+        position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
+        model((token_data, position_data))
+
+        for w in head.weights:
+            self.assertEqual(w.dtype, dtype, "Wrong type: " + w.name)
+
     def test_valid_call_with_embedding_weights(self):
         embedding = keras.layers.Embedding(100, 16)
         embedding.build((4, 10))
@@ -119,6 +146,32 @@ def test_one_train_step(self):
         loss = model.train_on_batch(x=(token_data, position_data), y=label_data)
         self.assertGreater(loss, 0)
 
+    @parameterized.named_parameters(
+        ("bfloat16", tf.bfloat16),
+        ("float16", tf.float16),
+        ("float32", tf.float32),
+        ("float64", tf.float64),
+    )
+    def test_one_train_step_with_dtype(self, dtype):
+        head = masked_lm_head.MaskedLMHead(
+            vocabulary_size=100,
+            dtype=dtype,
+        )
+        encoded_tokens = keras.Input(shape=(10, 16))
+        positions = keras.Input(shape=(5,), dtype="int32")
+        outputs = head(encoded_tokens, masked_positions=positions)
+        model = keras.Model((encoded_tokens, positions), outputs)
+
+        token_data = ops.random.uniform(shape=(4, 10, 16))
+        position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
+        label_data = ops.random.randint(minval=0, maxval=2, shape=(4, 5, 1))
+
+        loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
+        optimizer = keras.optimizers.Adam()
+        model.compile(loss=loss, optimizer=optimizer)
+        loss = model.train_on_batch(x=(token_data, position_data), y=label_data)
+        self.assertGreater(loss, 0)
+
     def test_saved_model(self):
         head = masked_lm_head.MaskedLMHead(
             vocabulary_size=100,
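For reference, a hedged end-to-end sketch of the behavior the new tests pin down. It assumes a keras_nlp build that includes this patch; the asserts mirror the `assertEqual` dtype checks added above:

import numpy as np
from keras_nlp.layers import MaskedLMHead

head = MaskedLMHead(vocabulary_size=100, activation="softmax", dtype="bfloat16")
token_embeddings = np.random.uniform(size=(4, 10, 16)).astype("float32")
masked_positions = np.random.randint(0, 10, size=(4, 5))

# Predictions over the vocabulary at each masked position: shape (4, 5, 100).
preds = head(token_embeddings, masked_positions=masked_positions)

# With the build() fix, the dense and layer-norm weights inherit the
# requested dtype instead of silently defaulting to float32.
for w in head.weights:
    assert "bfloat16" in str(w.dtype), w.name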