Add a mixed precision test and fix mixed precision errors for layers
mattdangerw committed Sep 14, 2023
1 parent afe0432 commit fc4f0f5
Showing 12 changed files with 112 additions and 19 deletions.
4 changes: 4 additions & 0 deletions keras_nlp/layers/modeling/cached_multi_head_attention_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.backend import config
from keras_nlp.backend import ops
from keras_nlp.layers.modeling.cached_multi_head_attention import (
CachedMultiHeadAttention,
@@ -34,6 +35,9 @@ def test_layer_behaviors(self):
expected_output_shape=(2, 4, 6),
expected_num_trainable_weights=8,
expected_num_non_trainable_variables=1,
# tf.keras does not handle mixed precision correctly when not set
# globally.
run_mixed_precision_check=config.multi_backend(),
)

def test_cache_call_is_correct(self):
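The flag above only enables the harness's mixed precision check on the multi-backend (Keras 3) path. A minimal standalone sketch of what such a check verifies, assuming the Keras 3 behavior where a layer accepts a dtype policy name through `dtype` (illustrative code, not the actual `run_layer_test` internals):

import numpy as np

from keras_nlp.layers.modeling.cached_multi_head_attention import (
    CachedMultiHeadAttention,
)

layer = CachedMultiHeadAttention(num_heads=2, key_dim=4, dtype="mixed_float16")
x = np.random.uniform(size=(2, 4, 6)).astype("float32")
y = layer(query=x, value=x)                 # output should be float16
assert layer.compute_dtype == "float16"     # activations in half precision
assert layer.variable_dtype == "float32"    # variables kept in full precision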
15 changes: 13 additions & 2 deletions keras_nlp/layers/modeling/f_net_encoder.py
@@ -97,11 +97,13 @@ def build(self, inputs_shape):
# Layer Norm layers.
self._mixing_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="mixing_layer_norm",
)
self._mixing_layer_norm.build(inputs_shape)
self._output_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="output_layer_norm",
)
self._output_layer_norm.build(inputs_shape)
@@ -112,19 +114,25 @@ def build(self, inputs_shape):
activation=self.activation,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="intermediate_dense",
)
self._intermediate_dense.build(inputs_shape)
self._output_dense = keras.layers.Dense(
feature_size,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="output_dense",
)
self._output_dense.build(
self._intermediate_dense.compute_output_shape(inputs_shape)
)
self._output_dropout = keras.layers.Dropout(rate=self.dropout)
self._output_dropout = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
name="output_dropout",
)
self.built = True

def call(self, inputs):
@@ -140,9 +148,12 @@ def call(self, inputs):

def fourier_transform(input):
# Apply FFT on the input and take the real part.
input_dtype = input.dtype
# FFT transforms do not support float16.
input = ops.cast(input, "float32")
real_in, imaginary_in = (input, ops.zeros_like(input))
real_out, _ = ops.fft2((real_in, imaginary_in))
return real_out
return ops.cast(real_out, input_dtype)

def add_and_norm(input1, input2, norm_layer):
return norm_layer(input1 + input2)
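The `fourier_transform` change works around the lack of float16 support in FFT kernels: the input is upcast to float32 for the transform and the real part is cast back to the incoming dtype. The same pattern in isolation, written against the keras_nlp ops shim (a sketch, not library code):

from keras_nlp.backend import ops

def float16_safe_fourier_transform(x):
    input_dtype = x.dtype
    x = ops.cast(x, "float32")                      # FFT ops require float32
    real_out, _ = ops.fft2((x, ops.zeros_like(x)))  # (real, imaginary) pair in, pair out
    return ops.cast(real_out, input_dtype)          # restore the compute dtype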
3 changes: 3 additions & 0 deletions keras_nlp/layers/modeling/masked_lm_head.py
@@ -141,10 +141,12 @@ def build(self, inputs_shape, mask_positions_shape=None):
activation=self.intermediate_activation,
kernel_initializer=self.kernel_initializer,
bias_initializer=self.bias_initializer,
dtype=self.dtype_policy,
name="intermediate_dense",
)
self._intermediate_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="intermediate_layer_norm",
)
# The gather length does not affect any of our built variables, so
@@ -185,6 +187,7 @@ def call(self, inputs, mask_positions):
outputs = self.token_embedding(x, reverse=True)
else:
outputs = ops.matmul(x, self._kernel)
outputs = ops.cast(outputs, self.compute_dtype)
outputs = outputs + self._bias

# Apply a final activation.
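The added cast keeps the projection output in the head's compute dtype before the bias add, regardless of whether the tied or untied projection path produced it (the tied path can return full-precision logits with the `ReversibleEmbedding` change below). A hedged usage sketch, with shapes chosen only for illustration:

import numpy as np

from keras_nlp.layers.modeling.masked_lm_head import MaskedLMHead

head = MaskedLMHead(vocabulary_size=100, dtype="mixed_float16")
hidden = np.random.uniform(size=(2, 10, 16)).astype("float32")
mask_positions = np.random.randint(0, 10, size=(2, 4))
preds = head(hidden, mask_positions=mask_positions)
# preds has shape (2, 4, 100) and follows the layer's float16 compute dtype.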
16 changes: 11 additions & 5 deletions keras_nlp/layers/modeling/reversible_embedding.py
@@ -49,6 +49,9 @@ class ReversibleEmbedding(keras.layers.Embedding):
the `embeddings` matrix (see `keras.constraints`).
mask_zero: Boolean, whether or not the input value 0 is a special
"padding" value that should be masked out.
reverse_dtype: The dtype for the reverse projection computation.
For stability, it is usually best to use full precision even when
working with half or mixed precision training.
Call args:
inputs: The tensor inputs to the layer.
@@ -87,6 +90,7 @@ def __init__(
embeddings_regularizer=None,
embeddings_constraint=None,
mask_zero=False,
reverse_dtype="float32",
**kwargs,
):
super().__init__(
@@ -99,6 +103,7 @@ def __init__(
**kwargs,
)
self.tie_weights = tie_weights
self.reverse_dtype = reverse_dtype

def build(self, inputs_shape=None):
super().build(inputs_shape)
@@ -114,12 +119,12 @@ def build(self, inputs_shape=None):
def call(self, inputs, reverse=False):
if reverse:
if self.tie_weights:
reverse_embeddings = ops.transpose(
ops.convert_to_tensor(self.embeddings)
)
kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
else:
reverse_embeddings = self.reverse_embeddings
return ops.matmul(inputs, reverse_embeddings)
kernel = self.reverse_embeddings
inputs = ops.cast(inputs, self.reverse_dtype)
kernel = ops.cast(kernel, self.reverse_dtype)
return ops.matmul(inputs, kernel)

return super().call(inputs)

@@ -128,6 +133,7 @@ def get_config(self):
config.update(
{
"tie_weights": self.tie_weights,
"reverse_dtype": self.reverse_dtype,
}
)
return config
12 changes: 12 additions & 0 deletions keras_nlp/layers/modeling/reversible_embedding_test.py
@@ -73,3 +73,15 @@ def test_tied_checkpoint_untied_weights(self):

input_data = ops.ones(shape=(4, 10), dtype="int32")
self.assertAllClose(untied_model(input_data), tied_model(input_data))

def test_reverse_dtype(self):
embedding = ReversibleEmbedding(100, 16, reverse_dtype="float32")
input_data = ops.ones(shape=(4, 10, 16))
output_data = embedding(input_data, reverse=True)
self.assertEqual(output_data.shape, (4, 10, 100))
self.assertDType(output_data, "float32")
embedding = ReversibleEmbedding(100, 16, reverse_dtype="float16")
input_data = ops.ones(shape=(4, 10, 16))
output_data = embedding(input_data, reverse=True)
self.assertEqual(output_data.shape, (4, 10, 100))
self.assertDType(output_data, "float16")
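Putting the new option together with a mixed precision policy, a hedged usage sketch (assumed behavior based on the changes above): the forward lookup follows the layer's compute dtype, while the reverse projection runs in `reverse_dtype`, which defaults to float32 for numerically stable logits.

from keras_nlp.backend import ops
from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding

embedding = ReversibleEmbedding(100, 16, dtype="mixed_float16")
token_ids = ops.ones((4, 10), dtype="int32")
hidden = embedding(token_ids)              # float16 activations
logits = embedding(hidden, reverse=True)   # (4, 10, 100) in float32 by default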
2 changes: 2 additions & 0 deletions keras_nlp/layers/modeling/token_and_position_embedding.py
@@ -93,11 +93,13 @@ def __init__(
self.embeddings_initializer
),
mask_zero=mask_zero,
dtype=self.dtype_policy,
name="token_embedding",
)
self.position_embedding = PositionEmbedding(
sequence_length=sequence_length,
initializer=clone_initializer(self.embeddings_initializer),
dtype=self.dtype_policy,
name="position_embedding",
)
self.supports_masking = self.token_embedding.supports_masking
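These two arguments follow the pattern applied throughout this commit: a composite layer must forward its own `dtype_policy` to every sublayer it constructs, otherwise the sublayers fall back to the global policy and ignore a per-layer `dtype="mixed_float16"`. A generic sketch of the pattern with an illustrative layer (not part of keras_nlp):

from keras_nlp.backend import keras

class TwoDenseBlock(keras.layers.Layer):  # illustrative composite layer
    def __init__(self, hidden_dim, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim

    def build(self, inputs_shape):
        self.dense_1 = keras.layers.Dense(
            self.hidden_dim,
            dtype=self.dtype_policy,  # inherit this layer's policy
            name="dense_1",
        )
        self.dense_2 = keras.layers.Dense(
            self.hidden_dim,
            dtype=self.dtype_policy,
            name="dense_2",
        )
        self.dense_1.build(inputs_shape)
        self.dense_2.build(self.dense_1.compute_output_shape(inputs_shape))
        self.built = True

    def call(self, x):
        return self.dense_2(self.dense_1(x))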
22 changes: 17 additions & 5 deletions keras_nlp/layers/modeling/transformer_decoder.py
@@ -111,15 +111,14 @@ def __init__(
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
normalize_first=False,
name=None,
**kwargs,
):
# Work around for model saving, we need to ensure our model is built
# immediately after restoring from config.
decoder_sequence_shape = kwargs.pop("decoder_sequence_shape", None)
encoder_sequence_shape = kwargs.pop("encoder_sequence_shape", None)

super().__init__(name=name, **kwargs)
super().__init__(**kwargs)
self.intermediate_dim = intermediate_dim
self.num_heads = num_heads
self.dropout = dropout
@@ -160,6 +159,7 @@ def build(
dropout=self.dropout,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="self_attention",
)
if hasattr(self._self_attention_layer, "_build_from_signature"):
@@ -174,11 +174,14 @@
)
self._self_attention_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="self_attention_layer_norm",
)
self._self_attention_layer_norm.build(decoder_sequence_shape)
self._self_attention_dropout = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
name="self_attention_dropout",
)

# Cross attention layers are optional.
@@ -191,6 +194,7 @@
dropout=self.dropout,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="cross_attention",
)
if hasattr(self._cross_attention_layer, "_build_from_signature"):
@@ -205,11 +209,14 @@
)
self._cross_attention_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="cross_attention_layer_norm",
)
self._cross_attention_layer_norm.build(encoder_sequence_shape)
self._cross_attention_dropout = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
name="cross_attention_dropout",
)

# Feedforward layers.
@@ -218,25 +225,30 @@
activation=self.activation,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
name="intermediate_dense",
dtype=self.dtype_policy,
name="feedforward_intermediate_dense",
)
self._feedforward_intermediate_dense.build(decoder_sequence_shape)
self._feedforward_output_dense = keras.layers.Dense(
hidden_dim,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
name="output_dense",
dtype=self.dtype_policy,
name="feedforward_output_dense",
)
intermediate_shape = list(decoder_sequence_shape)
intermediate_shape[-1] = self.intermediate_dim
self._feedforward_output_dense.build(tuple(intermediate_shape))
self._feedforward_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
name="output_layer_norm",
dtype=self.dtype_policy,
name="feedforward_layer_norm",
)
self._feedforward_layer_norm.build(decoder_sequence_shape)
self._feedforward_dropout = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
name="feedforward_dropout",
)
# Create layers based on input shape.
self.built = True
12 changes: 10 additions & 2 deletions keras_nlp/layers/modeling/transformer_encoder.py
@@ -92,10 +92,9 @@ def __init__(
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
normalize_first=False,
name=None,
**kwargs,
):
super().__init__(name=name, **kwargs)
super().__init__(**kwargs)
self.intermediate_dim = intermediate_dim
self.num_heads = num_heads
self.dropout = dropout
@@ -125,6 +124,7 @@ def build(self, inputs_shape):
dropout=self.dropout,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="self_attention_layer",
)
if hasattr(self._self_attention_layer, "_build_from_signature"):
@@ -139,38 +139,46 @@ def build(self, inputs_shape):
)
self._self_attention_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="self_attention_layer_norm",
)
self._self_attention_layer_norm.build(inputs_shape)
self._self_attention_dropout = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
name="self_attention_dropout",
)

# Feedforward layers.
self._feedforward_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="feedforward_layer_norm",
)
self._feedforward_layer_norm.build(inputs_shape)
self._feedforward_intermediate_dense = keras.layers.Dense(
self.intermediate_dim,
activation=self.activation,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="feedforward_intermediate_dense",
)
self._feedforward_intermediate_dense.build(inputs_shape)
self._feedforward_output_dense = keras.layers.Dense(
hidden_dim,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="feedforward_output_dense",
)
intermediate_shape = list(inputs_shape)
intermediate_shape[-1] = self.intermediate_dim
self._feedforward_output_dense.build(tuple(intermediate_shape))
self._feedforward_dropout = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
name="feedforward_dropout",
)
self.built = True

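With the policy now reaching every sublayer, a single `dtype` argument on the encoder should control the whole block. A quick end-to-end sketch (assumed multi-backend Keras 3 behavior):

import numpy as np

from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder

encoder = TransformerEncoder(intermediate_dim=32, num_heads=2, dtype="mixed_float16")
x = np.random.uniform(size=(2, 8, 16)).astype("float32")
y = encoder(x)                              # float16 activations
assert encoder.compute_dtype == "float16"
assert all(w.dtype == "float32" for w in encoder.weights)  # full precision variables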
3 changes: 1 addition & 2 deletions keras_nlp/samplers/beam_sampler.py
@@ -16,7 +16,6 @@
import tree

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.sampler import call_args_docstring
@@ -161,7 +160,7 @@ def body(prompt, cache, index, log_probs):
# Compute the softmax distribution for the next token.
logits, _, cache = next(prompt, cache, index)
vocab_size = ops.shape(logits)[-1]
probs = keras.activations.softmax(logits / self.temperature)
probs = self.compute_probabilities(logits)

# Compute the running log-likelihood of each new candidate.
next_log_probs = ops.log(probs) + log_probs[..., None]
3 changes: 1 addition & 2 deletions keras_nlp/samplers/contrastive_sampler.py
@@ -15,7 +15,6 @@
import tree

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.sampler import call_args_docstring
@@ -133,7 +132,7 @@ def cond(prompt, cache, index, logits, hidden_states):

def body(prompt, cache, index, logits, hidden_states):
# Compute the softmax distribution for the next token.
probabilities = keras.activations.softmax(logits / self.temperature)
probabilities = self.compute_probabilities(logits)

# Replicate for `self.k` times to find the best token in top-k
# candidates.
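Both sampler hunks replace the inline temperature-scaled softmax with `self.compute_probabilities(logits)`, a base `Sampler` helper that is not shown in this diff. A plausible sketch of such a helper, consistent with the mixed precision theme of the commit (full-precision softmax, result cast back to the logits dtype); this is an assumption, not the actual method body:

from keras_nlp.backend import keras
from keras_nlp.backend import ops

def compute_probabilities(logits, temperature=1.0):
    logits_dtype = logits.dtype
    # Softmax over a large vocabulary is unstable in float16, so compute it
    # in float32 and cast the probabilities back afterwards.
    logits = ops.cast(logits, "float32")
    probs = keras.activations.softmax(logits / temperature)
    return ops.cast(probs, logits_dtype)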