diff --git a/smdebug/tensorflow/keras.py b/smdebug/tensorflow/keras.py
index 30ba008c9..f0426c1ed 100644
--- a/smdebug/tensorflow/keras.py
+++ b/smdebug/tensorflow/keras.py
@@ -33,6 +33,7 @@
     is_keras_optimizer,
     is_tf_version_2_3_x,
     is_tf_version_2x,
+    supported_tf_variables,
 )
 
 
@@ -241,7 +242,7 @@ def _create_tensors_for_matching_collections(
         tensor_refs = []
         for coll in colls_with_tensor:
             if not tensor_refs:
-                if isinstance(tensor, tf.Variable):
+                if isinstance(tensor, supported_tf_variables()):
                     tensor_refs.append(
                         coll.add_variable(tensor, export_name=export_name, mode=mode)
                     )
@@ -469,7 +470,7 @@ def save_gradients_from_logs(self, gradients):
                 if isinstance(v, tf.Tensor):
                     # Tensor.name is meaningless with eager execution
                     layer_name = str(v.numpy(), "utf-8")
-                elif isinstance(v, tf.Variable):
+                elif isinstance(v, supported_tf_variables()):
                     layer_name = v.name
                 elif isinstance(v, bytes):
                     layer_name = str(v, "utf-8")
@@ -785,7 +786,7 @@ def _save_layer_values(self, logs):
             if isinstance(layer_name, tf.Tensor):
                 # Tensor.name is meaningless with eager execution
                 layer_name = str(layer_name.numpy(), "utf-8")
-            elif isinstance(layer_name, tf.Variable):
+            elif isinstance(layer_name, supported_tf_variables()):
                 layer_name = layer_name.name
             elif isinstance(layer_name, bytes):
                 layer_name = str(layer_name, "utf-8")
@@ -996,7 +997,12 @@ def run(*args, **kwargs):
             if (
                 (not grads or not vars)
                 or (not isinstance(grads, list) or not isinstance(vars, list))
-                or (not ((isinstance(vars[0], tf.Variable)) and hasattr(vars[0], "numpy")))
+                or (
+                    not (
+                        (isinstance(vars[0], supported_tf_variables()))
+                        and hasattr(vars[0], "numpy")
+                    )
+                )
                 or (not ((isinstance(grads[0], tf.Tensor)) and hasattr(grads[0], "numpy")))
             ):
                 return grads
diff --git a/smdebug/tensorflow/tensor_ref.py b/smdebug/tensorflow/tensor_ref.py
index 930cdcab6..d77e3c495 100644
--- a/smdebug/tensorflow/tensor_ref.py
+++ b/smdebug/tensorflow/tensor_ref.py
@@ -9,13 +9,13 @@
 from smdebug.core.logger import get_logger
 
 # Local
-from .utils import is_tf_version_2x
+from .utils import is_tf_version_2x, supported_tf_variables
 
 logger = get_logger()
 
 
 def get_tf_names(arg):
-    if isinstance(arg, tf.Variable):
+    if isinstance(arg, supported_tf_variables()):
         tf_names = [arg.name]
     elif isinstance(arg, tf.Tensor):
         tf_names = [arg.name]
@@ -101,7 +101,11 @@ def from_variable(cls, variable, export_name=None, mode=None, original_tensor=No
             # for mirrored variable value this will be the mirrored variable
             original_tensor = variable
 
-        if is_tf_version_2x() and tf.executing_eagerly() and isinstance(variable, tf.Variable):
+        if (
+            is_tf_version_2x()
+            and tf.executing_eagerly()
+            and isinstance(variable, supported_tf_variables())
+        ):
             # In TF 2.X eager mode, TF throws an error if you try to access a tensor's name.
             # We need to pass it in as a variable, not a tensor, to maintain the name.
             tf_obj = variable
diff --git a/smdebug/tensorflow/utils.py b/smdebug/tensorflow/utils.py
index 863a5f862..552a72fff 100644
--- a/smdebug/tensorflow/utils.py
+++ b/smdebug/tensorflow/utils.py
@@ -13,6 +13,21 @@
 from smdebug.core.modes import ModeKeys
 
 
+def does_tf_support_mixed_precision_training():
+    # The Keras mixed precision API is first available in TensorFlow 2.1.0
+    # See: https://www.tensorflow.org/guide/mixed_precision
+    return version.parse(tf.__version__) >= version.parse("2.1.0")
+
+
+def supported_tf_variables():
+    if does_tf_support_mixed_precision_training():
+        from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
+
+        return tf.Variable, autocast_variable.AutoCastVariable
+    else:
+        return tf.Variable
+
+
 class ModelOutput:
     LABELS = "smdebug_y"
     PREDICTIONS = "smdebug_y_pred"
diff --git a/tests/tensorflow2/test_mixed_precision_training.py b/tests/tensorflow2/test_mixed_precision_training.py
new file mode 100644
index 000000000..ed6289f9c
--- /dev/null
+++ b/tests/tensorflow2/test_mixed_precision_training.py
@@ -0,0 +1,69 @@
+# Third Party
+import pytest
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+
+# First Party
+import smdebug.tensorflow as smd
+from smdebug.tensorflow.utils import does_tf_support_mixed_precision_training
+from smdebug.trials import create_trial
+
+# Test Reference: https://github.com/tensorflow/docs/blob/master/site/en/guide/mixed_precision.ipynb
+
+
+@pytest.mark.skipif(
+    does_tf_support_mixed_precision_training() is False,
+    reason="The Keras mixed precision API is first available in TensorFlow 2.1.0",
+)
+def test_mixed_precision_training(out_dir):
+
+    from tensorflow.keras.mixed_precision import experimental as mixed_precision
+
+    hook = smd.KerasHook(out_dir=out_dir, save_all=True)
+    policy = mixed_precision.Policy("mixed_float16")
+    mixed_precision.set_policy(policy)
+
+    inputs = keras.Input(shape=(784,), name="digits")
+    if tf.config.list_physical_devices("GPU"):
+        # The model will run with 4096 units on a GPU
+        num_units = 4096
+    else:
+        # Use fewer units on CPUs so the model finishes in a reasonable amount of time
+        # The model will run with 64 units on a CPU
+        num_units = 64
+    dense1 = layers.Dense(num_units, activation="relu", name="dense_1")
+    x = dense1(inputs)
+    dense2 = layers.Dense(num_units, activation="relu", name="dense_2")
+    x = dense2(x)
+
+    # CORRECT: softmax and model output are float32
+    x = layers.Dense(10, name="dense_logits")(x)
+    outputs = layers.Activation("softmax", dtype="float32", name="predictions")(x)
+
+    # The linear activation is an identity function. So this simply casts 'outputs'
+    # to float32. In this particular case, 'outputs' is already float32 so this is a
+    # no-op.
+    outputs = layers.Activation("linear", dtype="float32")(outputs)
+
+    model = keras.Model(inputs=inputs, outputs=outputs)
+    model.compile(
+        loss="sparse_categorical_crossentropy",
+        optimizer=keras.optimizers.RMSprop(),
+        metrics=["accuracy"],
+    )
+
+    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
+    x_train = x_train.reshape(60000, 784).astype("float32") / 255
+    x_test = x_test.reshape(10000, 784).astype("float32") / 255
+
+    initial_weights = model.get_weights()
+
+    hooks = [hook]
+    history = model.fit(
+        x_train, y_train, batch_size=8192, epochs=5, callbacks=hooks, validation_split=0.2
+    )
+    test_scores = model.evaluate(x_test, y_test, verbose=2)
+
+    trial = create_trial(out_dir)
+    assert len(trial.tensor_names()) == 30
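
For context, a minimal sketch (not part of the diff) of what the widened isinstance checks guard against. It assumes TF >= 2.1, where the experimental Keras mixed precision API exercised by this change is available; the Dense layer and shapes are illustrative only.

    # Sketch only: shows that supported_tf_variables() matches weights created
    # under a mixed_float16 policy, whether or not the TF version in use wraps
    # them in AutoCastVariable.
    import tensorflow as tf
    from tensorflow.keras.mixed_precision import experimental as mixed_precision

    from smdebug.tensorflow.utils import supported_tf_variables

    mixed_precision.set_policy(mixed_precision.Policy("mixed_float16"))

    layer = tf.keras.layers.Dense(4)
    layer.build(input_shape=(None, 4))

    for v in layer.weights:
        # A plain isinstance(v, tf.Variable) is what the old code checked;
        # supported_tf_variables() additionally accepts AutoCastVariable.
        assert isinstance(v, supported_tf_variables())

The helper returns a single class on TF < 2.1 and a (tf.Variable, AutoCastVariable) tuple otherwise; isinstance() accepts either form, so every call site in the diff stays a one-line check.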