Add a mixed precision test and fix mixed precision errors for layers
mattdangerw committed Sep 13, 2023
1 parent afe0432 commit 44b1005
Showing 7 changed files with 71 additions and 9 deletions.
4 changes: 4 additions & 0 deletions keras_nlp/layers/modeling/cached_multi_head_attention_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.backend import config
from keras_nlp.backend import ops
from keras_nlp.layers.modeling.cached_multi_head_attention import (
CachedMultiHeadAttention,
@@ -34,6 +35,9 @@ def test_layer_behaviors(self):
expected_output_shape=(2, 4, 6),
expected_num_trainable_weights=8,
expected_num_non_trainable_variables=1,
# tf.keras does not handle mixed precision correctly when not set
# globally.
run_mixed_precision_check=config.multi_backend(),
)

def test_cache_call_is_correct(self):
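The new run_mixed_precision_check=config.multi_backend() flag skips the float16 check under legacy tf.keras, where a per-layer policy is only honored when set globally. A hedged sketch of what the check exercises for this layer under multi-backend Keras, on a backend and device with float16 support (the num_heads/key_dim values below are illustrative, not the test's):

import numpy as np

from keras_nlp.layers.modeling.cached_multi_head_attention import (
    CachedMultiHeadAttention,
)

# Build the layer under a per-layer mixed_float16 policy.
layer = CachedMultiHeadAttention(num_heads=2, key_dim=4, dtype="mixed_float16")
query = np.ones((2, 4, 6), dtype="float32")
# Self-attention call: compute runs in float16, variables stay float32.
outputs = layer(query, query)
print(outputs.dtype)           # expected: float16
print(layer.weights[0].dtype)  # expected: float32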
15 changes: 13 additions & 2 deletions keras_nlp/layers/modeling/f_net_encoder.py
@@ -97,11 +97,13 @@ def build(self, inputs_shape):
# Layer Norm layers.
self._mixing_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="mixing_layer_norm",
)
self._mixing_layer_norm.build(inputs_shape)
self._output_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="output_layer_norm",
)
self._output_layer_norm.build(inputs_shape)
@@ -112,19 +114,25 @@ def build(self, inputs_shape):
activation=self.activation,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="intermediate_dense",
)
self._intermediate_dense.build(inputs_shape)
self._output_dense = keras.layers.Dense(
feature_size,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="output_dense",
)
self._output_dense.build(
self._intermediate_dense.compute_output_shape(inputs_shape)
)
self._output_dropout = keras.layers.Dropout(rate=self.dropout)
self._output_dropout = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
name="output_dropout",
)
self.built = True

def call(self, inputs):
@@ -140,9 +148,12 @@ def call(self, inputs):

def fourier_transform(input):
# Apply FFT on the input and take the real part.
input_dtype = input.dtype
# FFT transforms do not support float16.
input = ops.cast(input, "float32")
real_in, imaginary_in = (input, ops.zeros_like(input))
real_out, _ = ops.fft2((real_in, imaginary_in))
return real_out
return ops.cast(real_out, input_dtype)

def add_and_norm(input1, input2, norm_layer):
return norm_layer(input1 + input2)
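The cast-to-float32-and-back in fourier_transform is what keeps the FFT working under a mixed_float16 policy, since FFT kernels do not accept float16 inputs. A minimal standalone sketch of the same pattern, assuming Keras 3's keras.ops API (the layer itself goes through keras_nlp.backend.ops):

from keras import ops

def fourier_real_part(x):
    # Remember the incoming dtype (float16 under a mixed_float16 policy).
    input_dtype = x.dtype
    # Compute the 2D FFT in float32, since float16 is unsupported...
    x = ops.cast(x, "float32")
    real_out, _ = ops.fft2((x, ops.zeros_like(x)))
    # ...then cast back so downstream layers see the policy's compute dtype.
    return ops.cast(real_out, input_dtype)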
3 changes: 3 additions & 0 deletions keras_nlp/layers/modeling/masked_lm_head.py
@@ -141,10 +141,12 @@ def build(self, inputs_shape, mask_positions_shape=None):
activation=self.intermediate_activation,
kernel_initializer=self.kernel_initializer,
bias_initializer=self.bias_initializer,
dtype=self.dtype_policy,
name="intermediate_dense",
)
self._intermediate_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="intermediate_layer_norm",
)
# The gather length does not affect any of our built variables, so
@@ -185,6 +187,7 @@ def call(self, inputs, mask_positions):
outputs = self.token_embedding(x, reverse=True)
else:
outputs = ops.matmul(x, self._kernel)
outputs = ops.cast(outputs, self.compute_dtype)
outputs = outputs + self._bias

# Apply a final activation.
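The extra ops.cast(outputs, self.compute_dtype) guards the projection back to vocabulary space: under mixed_float16, variables (and a shared embedding still running its own float32 policy) can hand back float32 while the rest of the head computes in float16. A small sketch of that dtype split, assuming the keras.mixed_precision.Policy API:

import keras
from keras import ops

policy = keras.mixed_precision.Policy("mixed_float16")
print(policy.compute_dtype)   # "float16" -- dtype for activations and matmuls
print(policy.variable_dtype)  # "float32" -- dtype for kernels and biases

# If a branch hands back float32 (e.g. a float32 embedding reused in reverse),
# casting to the layer's compute dtype keeps the bias add and the final
# activation in a single, consistent dtype.
outputs = ops.ones((2, 8), dtype="float32")
outputs = ops.cast(outputs, policy.compute_dtype)  # now float16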
2 changes: 2 additions & 0 deletions keras_nlp/layers/modeling/token_and_position_embedding.py
@@ -93,11 +93,13 @@ def __init__(
self.embeddings_initializer
),
mask_zero=mask_zero,
dtype=self.dtype_policy,
name="token_embedding",
)
self.position_embedding = PositionEmbedding(
sequence_length=sequence_length,
initializer=clone_initializer(self.embeddings_initializer),
dtype=self.dtype_policy,
name="position_embedding",
)
self.supports_masking = self.token_embedding.supports_masking
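Passing dtype=self.dtype_policy when constructing each sublayer is the recurring fix in this commit: without it, children fall back to the global policy and silently ignore a per-layer mixed_float16 setting. A minimal sketch of the pattern with a hypothetical composite layer (MyBlock is not part of KerasNLP), assuming Keras 3:

import keras

class MyBlock(keras.layers.Layer):
    """Hypothetical composite layer that forwards its dtype policy."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)  # `dtype` / `name` arrive via **kwargs
        self.units = units

    def build(self, input_shape):
        # Pass the parent's policy so mixed precision reaches the child.
        self.dense = keras.layers.Dense(
            self.units,
            dtype=self.dtype_policy,
            name="dense",
        )
        self.dense.build(input_shape)
        self.built = True

    def call(self, x):
        return self.dense(x)

block = MyBlock(8, dtype="mixed_float16")
# The inner Dense now computes in float16 while keeping float32 variables.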
22 changes: 17 additions & 5 deletions keras_nlp/layers/modeling/transformer_decoder.py
@@ -111,15 +111,14 @@ def __init__(
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
normalize_first=False,
name=None,
**kwargs,
):
# Work around for model saving, we need to ensure our model is built
# immediately after restoring from config.
decoder_sequence_shape = kwargs.pop("decoder_sequence_shape", None)
encoder_sequence_shape = kwargs.pop("encoder_sequence_shape", None)

super().__init__(name=name, **kwargs)
super().__init__(**kwargs)
self.intermediate_dim = intermediate_dim
self.num_heads = num_heads
self.dropout = dropout
@@ -160,6 +159,7 @@ def build(
dropout=self.dropout,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="self_attention",
)
if hasattr(self._self_attention_layer, "_build_from_signature"):
@@ -174,11 +174,14 @@
)
self._self_attention_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="self_attention_layer_norm",
)
self._self_attention_layer_norm.build(decoder_sequence_shape)
self._self_attention_dropout = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
name="self_attention_dropout",
)

# Cross attention layers are optional.
@@ -191,6 +194,7 @@
dropout=self.dropout,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="cross_attention",
)
if hasattr(self._cross_attention_layer, "_build_from_signature"):
@@ -205,11 +209,14 @@
)
self._cross_attention_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="cross_attention_layer_norm",
)
self._cross_attention_layer_norm.build(encoder_sequence_shape)
self._cross_attention_dropout = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
name="cross_attention_dropout",
)

# Feedforward layers.
@@ -218,25 +225,30 @@
activation=self.activation,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
name="intermediate_dense",
dtype=self.dtype_policy,
name="feedforward_intermediate_dense",
)
self._feedforward_intermediate_dense.build(decoder_sequence_shape)
self._feedforward_output_dense = keras.layers.Dense(
hidden_dim,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
name="output_dense",
dtype=self.dtype_policy,
name="feedforward_output_dense",
)
intermediate_shape = list(decoder_sequence_shape)
intermediate_shape[-1] = self.intermediate_dim
self._feedforward_output_dense.build(tuple(intermediate_shape))
self._feedforward_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
name="output_layer_norm",
dtype=self.dtype_policy,
name="feedforward_layer_norm",
)
self._feedforward_layer_norm.build(decoder_sequence_shape)
self._feedforward_dropout = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
name="feedforward_dropout",
)
# Create layers based on input shape.
self.built = True
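Besides propagating the dtype policy and renaming the feedforward sublayers to feedforward_*, the decoder drops its explicit name=None parameter and forwards **kwargs straight to super().__init__, so base-Layer arguments such as name and dtype reach keras.layers.Layer unchanged. A small sketch of that constructor pattern with a hypothetical block (not the KerasNLP class):

import keras

class MyDecoderBlock(keras.layers.Layer):
    """Hypothetical block: base-Layer args flow through **kwargs."""

    def __init__(self, intermediate_dim, **kwargs):
        # `name`, `dtype`, `trainable`, ... are handled by the base class.
        super().__init__(**kwargs)
        self.intermediate_dim = intermediate_dim

block = MyDecoderBlock(32, name="decoder_block", dtype="mixed_float16")
print(block.dtype_policy.name)  # expected: "mixed_float16"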
12 changes: 10 additions & 2 deletions keras_nlp/layers/modeling/transformer_encoder.py
@@ -92,10 +92,9 @@ def __init__(
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
normalize_first=False,
name=None,
**kwargs,
):
super().__init__(name=name, **kwargs)
super().__init__(**kwargs)
self.intermediate_dim = intermediate_dim
self.num_heads = num_heads
self.dropout = dropout
@@ -125,6 +124,7 @@ def build(self, inputs_shape):
dropout=self.dropout,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="self_attention_layer",
)
if hasattr(self._self_attention_layer, "_build_from_signature"):
@@ -139,38 +139,46 @@ def build(self, inputs_shape):
)
self._self_attention_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="self_attention_layer_norm",
)
self._self_attention_layer_norm.build(inputs_shape)
self._self_attention_dropout = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
name="self_attention_dropout",
)

# Feedforward layers.
self._feedforward_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="feedforward_layer_norm",
)
self._feedforward_layer_norm.build(inputs_shape)
self._feedforward_intermediate_dense = keras.layers.Dense(
self.intermediate_dim,
activation=self.activation,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="feedforward_intermediate_dense",
)
self._feedforward_intermediate_dense.build(inputs_shape)
self._feedforward_output_dense = keras.layers.Dense(
hidden_dim,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="feedforward_output_dense",
)
intermediate_shape = list(inputs_shape)
intermediate_shape[-1] = self.intermediate_dim
self._feedforward_output_dense.build(tuple(intermediate_shape))
self._feedforward_dropout = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
name="feedforward_dropout",
)
self.built = True

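With the encoder's sublayers now built under the parent's policy, setting mixed precision per layer is enough; no global policy change is needed under multi-backend Keras. A usage sketch (shapes and sizes are illustrative, not from the tests):

import numpy as np

from keras_nlp.layers import TransformerEncoder

encoder = TransformerEncoder(
    intermediate_dim=64,
    num_heads=2,
    dtype="mixed_float16",
)
outputs = encoder(np.ones((2, 10, 16), dtype="float32"))
print(outputs.dtype)             # expected: float16 activations
print(encoder.weights[0].dtype)  # expected: float32 variables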
22 changes: 22 additions & 0 deletions keras_nlp/tests/test_case.py
@@ -78,6 +78,7 @@ def run_layer_test(
expected_num_non_trainable_weights=0,
expected_num_non_trainable_variables=0,
run_training_check=True,
run_mixed_precision_check=True,
):
# Serialization test.
layer = layer_cls(**init_kwargs)
@@ -168,6 +169,27 @@ def call(self, x):
if run_training_check:
run_training_step(layer, input_data, output_data)

# Never test mixed precision on torch CPU. Torch lacks support.
if run_mixed_precision_check and config.backend() == "torch":
import torch

run_mixed_precision_check = torch.cuda.is_available()

if run_mixed_precision_check:
layer = layer_cls(**{**init_kwargs, "dtype": "mixed_float16"})
if isinstance(input_data, dict):
output_data = layer(**input_data)
else:
output_data = layer(input_data)
for tensor in tree.flatten(output_data):
dtype = standardize_dtype(tensor.dtype)
if "float" in dtype:
self.assertEqual(dtype, "float16")
for weight in layer.weights:
dtype = standardize_dtype(weight.dtype)
if "float" in dtype:
self.assertEqual(dtype, "float32")

def run_class_serialization_test(self, instance):
# get_config roundtrip
cls = instance.__class__
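The new check in run_layer_test rebuilds the layer with dtype="mixed_float16", runs a forward pass, and asserts that floating-point outputs are float16 while floating-point weights stay float32 (skipping torch on CPU, which lacks float16 support). A hedged, standalone sketch of the same assertions, using a stock Dense layer rather than a KerasNLP layer and assuming Keras 3's keras.backend.standardize_dtype:

import keras
import numpy as np

layer = keras.layers.Dense(4, dtype="mixed_float16")
output = layer(np.ones((2, 8), dtype="float32"))

# Float outputs should follow the policy's compute dtype.
assert keras.backend.standardize_dtype(output.dtype) == "float16"
# Float variables should keep the policy's variable dtype.
for weight in layer.weights:
    assert keras.backend.standardize_dtype(weight.dtype) == "float32"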
