diff --git a/smdebug/tensorflow/keras.py b/smdebug/tensorflow/keras.py index 99317c91f..090b1676f 100644 --- a/smdebug/tensorflow/keras.py +++ b/smdebug/tensorflow/keras.py @@ -777,6 +777,13 @@ def _save_layer_values(self, logs): for layer_name, layer_input, layer_output in logs: # Cast layer_name to str since it can also be of type bytes # when run with mirrored strategy + if isinstance(layer_name, tf.Tensor): + # Tensor.name is meaningless with eager execution + layer_name = str(layer_name.numpy(), "utf-8") + elif isinstance(layer_name, tf.Variable): + layer_name = layer_name.name + elif isinstance(layer_name, bytes): + layer_name = str(layer_name, "utf-8") if len(layer_input) == 1: # Layer Inputs are flattened and passed as a list into # the next layer. Unpacking it speeds up the _make_numpy fn. diff --git a/tests/tensorflow2/test_keras.py b/tests/tensorflow2/test_keras.py index 0c9da2d89..7228b7ec9 100644 --- a/tests/tensorflow2/test_keras.py +++ b/tests/tensorflow2/test_keras.py @@ -7,6 +7,7 @@ `python tests/tensorflow2/test_keras.py` from the main directory. """ # Standard Library +import re import time # Third Party @@ -29,6 +30,19 @@ from smdebug.tensorflow import ReductionConfig, SaveConfig +def get_model(): + model = tf.keras.models.Sequential( + [ + # WA for TF issue https://github.com/tensorflow/tensorflow/issues/36279 + tf.keras.layers.Flatten(input_shape=(28, 28, 1)), + tf.keras.layers.Dense(128, activation="relu"), + tf.keras.layers.Dropout(0.2), + tf.keras.layers.Dense(10, activation="softmax"), + ] + ) + return model + + def helper_keras_fit( trial_dir, save_all=False, @@ -48,15 +62,7 @@ def helper_keras_fit( (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train, x_test = x_train / 255, x_test / 255 - model = tf.keras.models.Sequential( - [ - tf.keras.layers.Flatten(input_shape=(28, 28)), - tf.keras.layers.Dense(128, activation="relu"), - tf.keras.layers.Dropout(0.2), - tf.keras.layers.Dense(10, activation="softmax"), - ] - ) - + model = get_model() if hook is None: if save_config is None: save_config = SaveConfig(save_interval=3) @@ -124,17 +130,7 @@ def helper_keras_gradtape( (tf.cast(x_train[..., tf.newaxis] / 255, tf.float32), tf.cast(y_train, tf.int64)) ) dataset = dataset.shuffle(1000).batch(batch_size) - - model = tf.keras.models.Sequential( - [ - # WA for TF issue https://github.com/tensorflow/tensorflow/issues/36279 - tf.keras.layers.Flatten(input_shape=(28, 28, 1)), - tf.keras.layers.Dense(128, activation="relu"), - tf.keras.layers.Dropout(0.2), - tf.keras.layers.Dense(10, activation="softmax"), - ] - ) - + model = get_model() if hook is None: if save_config is None: save_config = SaveConfig(save_interval=3) @@ -186,6 +182,24 @@ def helper_keras_gradtape( hook.close() +@pytest.mark.skip_if_non_eager +@pytest.mark.slow +def test_layer_names_gradient_tape(out_dir): + hook = smd.KerasHook( + out_dir, + save_config=SaveConfig(save_interval=9), + include_collections=[CollectionKeys.LAYERS], + ) + helper_keras_gradtape(out_dir, hook=hook, save_config=SaveConfig(save_interval=9)) + + tr = create_trial_fast_refresh(out_dir) + tnames = tr.tensor_names(collection=CollectionKeys.LAYERS) + pattern = r"^(flatten|dense|dropout)(_\d+)?\/(inputs|outputs)" + for tname in tnames: + assert re.match(pattern=pattern, string=tname) is not None + + +@pytest.mark.skip_if_non_eager def test_keras_gradtape_shapes(out_dir): hook = smd.KerasHook( out_dir=out_dir, @@ -549,6 +563,28 @@ def test_include_regex(out_dir, tf_eager_mode): assert tr.tensor(tname).value(0) is not None +@pytest.mark.slow +def test_layer_names(out_dir, tf_eager_mode): + hook = smd.KerasHook( + out_dir, + save_config=SaveConfig(save_interval=9), + include_collections=[CollectionKeys.LAYERS], + ) + helper_keras_fit( + out_dir, + hook=hook, + save_config=SaveConfig(save_interval=9), + steps=["train"], + run_eagerly=tf_eager_mode, + ) + + tr = create_trial_fast_refresh(out_dir) + tnames = tr.tensor_names(collection=CollectionKeys.LAYERS) + pattern = r"^(flatten|dense|dropout)(_\d+)?\/(inputs|outputs)" + for tname in tnames: + assert re.match(pattern=pattern, string=tname) is not None + + @pytest.mark.skip_if_non_eager @pytest.mark.slow def test_clash_with_tb_callback(out_dir):