7 changes: 7 additions & 0 deletions smdebug/tensorflow/keras.py
@@ -777,6 +777,13 @@ def _save_layer_values(self, logs):
         for layer_name, layer_input, layer_output in logs:
             # Cast layer_name to str since it can also be of type bytes
             # when run with mirrored strategy
+            if isinstance(layer_name, tf.Tensor):
+                # Tensor.name is meaningless with eager execution
+                layer_name = str(layer_name.numpy(), "utf-8")
+            elif isinstance(layer_name, tf.Variable):
+                layer_name = layer_name.name
+            elif isinstance(layer_name, bytes):
+                layer_name = str(layer_name, "utf-8")
             if len(layer_input) == 1:
                 # Layer Inputs are flattened and passed as a list into
                 # the next layer. Unpacking it speeds up the _make_numpy fn.
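The new branch normalizes the layer name before it is used as a tensor-name prefix, since under eager execution and mirrored strategy the name can arrive as an eager tf.Tensor, a tf.Variable, or raw bytes rather than str. A minimal standalone sketch of the three cases the cast handles (illustrative only, assuming TF 2.x eager mode; not part of the diff):

    import tensorflow as tf

    # Eager string tensor: Tensor.name is meaningless here, so decode the value.
    as_tensor = tf.constant("dense_1")
    assert str(as_tensor.numpy(), "utf-8") == "dense_1"

    # tf.Variable: its .name attribute is usable directly (TF appends ":0").
    as_variable = tf.Variable(0.0, name="dense_1")
    assert as_variable.name == "dense_1:0"

    # Raw bytes, as seen under mirrored strategy: decode as UTF-8.
    as_bytes = b"dense_1"
    assert str(as_bytes, "utf-8") == "dense_1"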
76 changes: 56 additions & 20 deletions tests/tensorflow2/test_keras.py
@@ -7,6 +7,7 @@
 `python tests/tensorflow2/test_keras.py` from the main directory.
 """
 # Standard Library
+import re
 import time
 
 # Third Party
@@ -29,6 +30,19 @@
 from smdebug.tensorflow import ReductionConfig, SaveConfig
 
 
+def get_model():
+    model = tf.keras.models.Sequential(
+        [
+            # WA for TF issue https://github.com/tensorflow/tensorflow/issues/36279
+            tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
+            tf.keras.layers.Dense(128, activation="relu"),
+            tf.keras.layers.Dropout(0.2),
+            tf.keras.layers.Dense(10, activation="softmax"),
+        ]
+    )
+    return model
+
+
 def helper_keras_fit(
     trial_dir,
     save_all=False,
@@ -48,15 +62,7 @@ def helper_keras_fit(
     (x_train, y_train), (x_test, y_test) = mnist.load_data()
     x_train, x_test = x_train / 255, x_test / 255
 
-    model = tf.keras.models.Sequential(
-        [
-            tf.keras.layers.Flatten(input_shape=(28, 28)),
-            tf.keras.layers.Dense(128, activation="relu"),
-            tf.keras.layers.Dropout(0.2),
-            tf.keras.layers.Dense(10, activation="softmax"),
-        ]
-    )
-
+    model = get_model()
     if hook is None:
         if save_config is None:
             save_config = SaveConfig(save_interval=3)
@@ -124,17 +130,7 @@ def helper_keras_gradtape(
         (tf.cast(x_train[..., tf.newaxis] / 255, tf.float32), tf.cast(y_train, tf.int64))
     )
     dataset = dataset.shuffle(1000).batch(batch_size)
-
-    model = tf.keras.models.Sequential(
-        [
-            # WA for TF issue https://github.com/tensorflow/tensorflow/issues/36279
-            tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
-            tf.keras.layers.Dense(128, activation="relu"),
-            tf.keras.layers.Dropout(0.2),
-            tf.keras.layers.Dense(10, activation="softmax"),
-        ]
-    )
-
+    model = get_model()
     if hook is None:
         if save_config is None:
             save_config = SaveConfig(save_interval=3)
@@ -186,6 +182,24 @@ def helper_keras_gradtape(
     hook.close()
 
 
+@pytest.mark.skip_if_non_eager
+@pytest.mark.slow
+def test_layer_names_gradient_tape(out_dir):
+    hook = smd.KerasHook(
+        out_dir,
+        save_config=SaveConfig(save_interval=9),
+        include_collections=[CollectionKeys.LAYERS],
+    )
+    helper_keras_gradtape(out_dir, hook=hook, save_config=SaveConfig(save_interval=9))
+
+    tr = create_trial_fast_refresh(out_dir)
+    tnames = tr.tensor_names(collection=CollectionKeys.LAYERS)
+    pattern = r"^(flatten|dense|dropout)(_\d+)?\/(inputs|outputs)"
+    for tname in tnames:
+        assert re.match(pattern=pattern, string=tname) is not None
+
+
+@pytest.mark.skip_if_non_eager
 def test_keras_gradtape_shapes(out_dir):
     hook = smd.KerasHook(
         out_dir=out_dir,
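Both new tests assert the same naming convention for tensors saved in the LAYERS collection: a Keras layer name, optionally deduplicated with a numeric suffix, followed by /inputs or /outputs. A quick standalone illustration of what the regex accepts (the sample names are assumptions for illustration, not captured smdebug output):

    import re

    pattern = r"^(flatten|dense|dropout)(_\d+)?\/(inputs|outputs)"
    # Keras appends _1, _2, ... when several layers share a base name.
    for name in ["flatten/inputs", "dense_1/outputs", "dropout_2/inputs"]:
        assert re.match(pattern, name) is not None
    # A layer type outside the model's three kinds is rejected.
    assert re.match(pattern, "conv2d/outputs") is None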
@@ -549,6 +563,28 @@ def test_include_regex(out_dir, tf_eager_mode):
         assert tr.tensor(tname).value(0) is not None
 
 
+@pytest.mark.slow
+def test_layer_names(out_dir, tf_eager_mode):
+    hook = smd.KerasHook(
+        out_dir,
+        save_config=SaveConfig(save_interval=9),
+        include_collections=[CollectionKeys.LAYERS],
+    )
+    helper_keras_fit(
    [Contributor review comment on the line above: "add a gradtape test too?"]
+        out_dir,
+        hook=hook,
+        save_config=SaveConfig(save_interval=9),
+        steps=["train"],
+        run_eagerly=tf_eager_mode,
+    )
+
+    tr = create_trial_fast_refresh(out_dir)
+    tnames = tr.tensor_names(collection=CollectionKeys.LAYERS)
+    pattern = r"^(flatten|dense|dropout)(_\d+)?\/(inputs|outputs)"
+    for tname in tnames:
+        assert re.match(pattern=pattern, string=tname) is not None
+
+
 @pytest.mark.skip_if_non_eager
 @pytest.mark.slow
 def test_clash_with_tb_callback(out_dir):
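To exercise just the new tests locally, something along these lines should work (the -k filter is an assumption; slow-marked tests may additionally need whatever opt-in flag this repository's pytest configuration defines):

    pytest tests/tensorflow2/test_keras.py -k "layer_names"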