Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions smdebug/tensorflow/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
is_keras_optimizer,
is_tf_version_2_3_x,
is_tf_version_2x,
supported_tf_variables,
)


Expand Down Expand Up @@ -241,7 +242,7 @@ def _create_tensors_for_matching_collections(
tensor_refs = []
for coll in colls_with_tensor:
if not tensor_refs:
if isinstance(tensor, tf.Variable):
if isinstance(tensor, supported_tf_variables()):
tensor_refs.append(
coll.add_variable(tensor, export_name=export_name, mode=mode)
)
Expand Down Expand Up @@ -469,7 +470,7 @@ def save_gradients_from_logs(self, gradients):
if isinstance(v, tf.Tensor):
# Tensor.name is meaningless with eager execution
layer_name = str(v.numpy(), "utf-8")
elif isinstance(v, tf.Variable):
elif isinstance(v, supported_tf_variables()):
layer_name = v.name
elif isinstance(v, bytes):
layer_name = str(v, "utf-8")
Expand Down Expand Up @@ -785,7 +786,7 @@ def _save_layer_values(self, logs):
if isinstance(layer_name, tf.Tensor):
# Tensor.name is meaningless with eager execution
layer_name = str(layer_name.numpy(), "utf-8")
elif isinstance(layer_name, tf.Variable):
elif isinstance(layer_name, supported_tf_variables()):
layer_name = layer_name.name
elif isinstance(layer_name, bytes):
layer_name = str(layer_name, "utf-8")
Expand Down Expand Up @@ -996,7 +997,12 @@ def run(*args, **kwargs):
if (
(not grads or not vars)
or (not isinstance(grads, list) or not isinstance(vars, list))
or (not ((isinstance(vars[0], tf.Variable)) and hasattr(vars[0], "numpy")))
or (
not (
(isinstance(vars[0], supported_tf_variables()))
and hasattr(vars[0], "numpy")
)
)
or (not ((isinstance(grads[0], tf.Tensor)) and hasattr(grads[0], "numpy")))
):
return grads
Expand Down
10 changes: 7 additions & 3 deletions smdebug/tensorflow/tensor_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from smdebug.core.logger import get_logger

# Local
from .utils import is_tf_version_2x
from .utils import is_tf_version_2x, supported_tf_variables

logger = get_logger()


def get_tf_names(arg):
if isinstance(arg, tf.Variable):
if isinstance(arg, supported_tf_variables()):
tf_names = [arg.name]
elif isinstance(arg, tf.Tensor):
tf_names = [arg.name]
Expand Down Expand Up @@ -101,7 +101,11 @@ def from_variable(cls, variable, export_name=None, mode=None, original_tensor=No
# for mirrored variable value this will be the mirrored variable
original_tensor = variable

if is_tf_version_2x() and tf.executing_eagerly() and isinstance(variable, tf.Variable):
if (
is_tf_version_2x()
and tf.executing_eagerly()
and isinstance(variable, supported_tf_variables())
):
# In TF 2.X eager mode, TF throws an error if you try to access a tensor's name.
# We need to pass it in as a variable, not a tensor, to maintain the name.
tf_obj = variable
Expand Down
15 changes: 15 additions & 0 deletions smdebug/tensorflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,21 @@
from smdebug.core.modes import ModeKeys


def does_tf_support_mixed_precision_training():
    """Return True when the installed TensorFlow is at least version 2.1.0.

    The Keras mixed precision API is first available in TensorFlow 2.1.0.
    See: https://www.tensorflow.org/guide/mixed_precision
    """
    minimum_release = version.parse("2.1.0")
    installed_release = version.parse(tf.__version__)
    return installed_release >= minimum_release


def supported_tf_variables():
    """Return the classes that should be treated as TensorFlow variables.

    The result is always a tuple so it can be handed directly to
    ``isinstance`` and so callers see the same return type regardless of the
    installed TensorFlow version.

    Returns:
        ``(tf.Variable, AutoCastVariable)`` when the mixed precision variable
        wrapper can be imported, otherwise ``(tf.Variable,)``.
    """
    if not does_tf_support_mixed_precision_training():
        return (tf.Variable,)

    # AutoCastVariable is imported from a private TensorFlow module path.
    # Guard the import so a relocated module degrades to plain tf.Variable
    # support instead of raising from every isinstance() call site.
    try:
        from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
    except ImportError:
        try:
            # NOTE(review): presumed location once the API left "experimental";
            # confirm against the TensorFlow releases this package supports.
            from tensorflow.python.keras.mixed_precision import autocast_variable
        except ImportError:
            return (tf.Variable,)

    return (tf.Variable, autocast_variable.AutoCastVariable)


class ModelOutput:
LABELS = "smdebug_y"
PREDICTIONS = "smdebug_y_pred"
Expand Down
69 changes: 69 additions & 0 deletions tests/tensorflow2/test_mixed_precision_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Third Party
import pytest
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# First Party
import smdebug.tensorflow as smd
from smdebug.tensorflow.utils import does_tf_support_mixed_precision_training
from smdebug.trials import create_trial

# Test Reference: https://github.com/tensorflow/docs/blob/master/site/en/guide/mixed_precision.ipynb


@pytest.mark.skipif(
    does_tf_support_mixed_precision_training() is False,
    reason="The Keras mixed precision API is first available in TensorFlow 2.1.0",
)
def test_mixed_precision_training(out_dir):
    """Train and evaluate a mixed-precision Keras model with the smdebug hook.

    The model is built under the "mixed_float16" policy, fitted and evaluated
    on MNIST with a ``KerasHook`` that saves all tensors, and the number of
    tensor names recorded in the resulting trial is then checked.

    The mixed precision policy is process-global, so the policy that was active
    before the test is restored afterwards to avoid leaking into other tests.
    """
    from tensorflow.keras.mixed_precision import experimental as mixed_precision

    hook = smd.KerasHook(out_dir=out_dir, save_all=True)

    # Remember the active policy so it can be put back in the finally clause.
    previous_policy = mixed_precision.global_policy()
    mixed_precision.set_policy(mixed_precision.Policy("mixed_float16"))
    try:
        inputs = keras.Input(shape=(784,), name="digits")
        if tf.config.list_physical_devices("GPU"):
            # The model will run with 4096 units on a GPU
            num_units = 4096
        else:
            # Use fewer units on CPUs so the model finishes in a reasonable amount of time
            # The model will run with 64 units on a CPU
            num_units = 64
        x = layers.Dense(num_units, activation="relu", name="dense_1")(inputs)
        x = layers.Dense(num_units, activation="relu", name="dense_2")(x)

        # Softmax and the model output are kept in float32
        x = layers.Dense(10, name="dense_logits")(x)
        outputs = layers.Activation("softmax", dtype="float32", name="predictions")(x)

        # The linear activation is an identity function. So this simply casts 'outputs'
        # to float32. In this particular case, 'outputs' is already float32 so this is a
        # no-op.
        outputs = layers.Activation("linear", dtype="float32")(outputs)

        model = keras.Model(inputs=inputs, outputs=outputs)
        model.compile(
            loss="sparse_categorical_crossentropy",
            optimizer=keras.optimizers.RMSprop(),
            metrics=["accuracy"],
        )

        (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
        x_train = x_train.reshape(60000, 784).astype("float32") / 255
        x_test = x_test.reshape(10000, 784).astype("float32") / 255

        # Return values are not needed; the calls are made for the tensors the hook saves.
        model.fit(
            x_train, y_train, batch_size=8192, epochs=5, callbacks=[hook], validation_split=0.2
        )
        model.evaluate(x_test, y_test, verbose=2)
    finally:
        # Undo the global policy change even if training or evaluation fails.
        mixed_precision.set_policy(previous_policy)

    trial = create_trial(out_dir)
    assert len(trial.tensor_names()) == 30