diff --git a/smdebug/tensorflow/keras.py b/smdebug/tensorflow/keras.py
index b9855a29b..67d325d87 100644
--- a/smdebug/tensorflow/keras.py
+++ b/smdebug/tensorflow/keras.py
@@ -119,8 +119,7 @@ def register_model(self, model):
         # It attaches a hook to every layer of the model to capture
         # layer values
         self.model = model
-        if self.tape is not None:
-            self._wrap_model_with_input_output_saver()
+        self._wrap_model_with_input_output_saver()
         self.has_registered_model = True
 
     def _get_matching_collections(
@@ -527,8 +526,7 @@ def _save_metrics(self, batch, logs, force_save=False):
                 self._save_for_tensor(key, logs[key], check_before_write=False)
 
     def _save_layer_input_and_outputs(self):
-        # Run only for GradTape
-        if self.tape is None:
+        if not is_tf_version_2x():
             return
         for layer_name in self.saved_layers:
             # Save Input
@@ -542,8 +540,8 @@ def _save_layer_input_and_outputs(self):
             if hasattr(tensor, "numpy"):
                 self._save_tensor_to_file(export_name, tensor.numpy(), input_collection)
             else:
-                self.logger.warn("cannot save layer values during forward pass with tf.function")
-                return
+                self.logger.warning("cannot save layer values during forward pass with tf.function")
+                continue
             # Save Output
             tensor = self.saved_layers[layer_name].layer_output
             export_name = get_export_name_for_keras(layer_name, tensor_type="output", tensor=tensor)
@@ -562,6 +560,7 @@ def _save_tensors_post_step(self, batch, logs):
         self._save_metrics(batch, logs)
         self.save_smdebug_logs(logs)
         self._save_custom_tensors_post_step()
+        self._save_layer_input_and_outputs()
 
         if is_tf_version_2x() and tf.executing_eagerly():
             for tensor_ref in self.tensor_refs_to_save_this_step:
diff --git a/tests/tensorflow2/test_model_subclassing.py b/tests/tensorflow2/test_model_subclassing.py
new file mode 100644
index 000000000..d90b931e2
--- /dev/null
+++ b/tests/tensorflow2/test_model_subclassing.py
@@ -0,0 +1,82 @@
+# Third Party
+import tensorflow as tf
+from tensorflow.keras.layers import BatchNormalization, Conv2D, Dense, Flatten
+from tensorflow.keras.models import Model
+
+# First Party
+import smdebug.tensorflow as smd
+
+
+class MyModel(Model):
+    """Subclassed Keras model whose forward pass is conv1 -> flatten -> d1 -> d2."""
+
+    def __init__(self):
+        super().__init__()
+        self.conv1 = Conv2D(
+            32, 3, activation="relu", kernel_initializer=tf.keras.initializers.GlorotNormal(seed=12)
+        )
+        # conv0 and bn are created here but never called by first/second/call.
+        self.conv0 = Conv2D(
+            32, 3, activation="relu", kernel_initializer=tf.keras.initializers.GlorotNormal(seed=12)
+        )
+        self.flatten = Flatten()
+        self.d1 = Dense(
+            128, activation="relu", kernel_initializer=tf.keras.initializers.GlorotNormal(seed=192)
+        )
+        self.d2 = Dense(10, kernel_initializer=tf.keras.initializers.GlorotNormal(seed=126))
+        self.bn = BatchNormalization()
+
+    def first(self, x):
+        """Apply conv1 and flatten inside the "first" name scope."""
+        with tf.name_scope("first"):
+            tf.print("mymodel.first")
+            x = self.conv1(x)
+            return self.flatten(x)
+
+    def second(self, x):
+        """Apply the two dense layers inside the "second" name scope."""
+        with tf.name_scope("second"):
+            x = self.d1(x)
+            return self.d2(x)
+
+    def call(self, x, training=None):
+        """Forward pass: first() followed by second(); training is not used."""
+        x = self.first(x)
+        return self.second(x)
+
+
+def test_subclassed_model(out_dir):
+    """Fit MyModel eagerly with a registered KerasHook and check the saved tensor names."""
+    # Download and load MNIST dataset; only the training split is used.
+    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data("MNIST-data")
+    x_train = x_train / 255.0
+
+    # Add a channels dimension
+    x_train = x_train[..., tf.newaxis]
+
+    # Create an instance of the model
+    model = MyModel()
+
+    train_ds = (
+        tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000, seed=123).batch(2)
+    )
+
+    # MyModel never reads a hook attribute, so keep the hook local instead of
+    # assigning it onto the class, where it would outlive this test.
+    hook = smd.KerasHook(
+        out_dir,
+        save_all=True,
+        save_config=smd.SaveConfig(save_steps=list(range(10)), save_interval=1),
+    )
+
+    hook.register_model(model)
+    model.compile(optimizer="Adam", loss="mse", run_eagerly=True)
+    model.fit(train_ds, epochs=1, steps_per_epoch=10, callbacks=[hook])
+
+    trial = smd.create_trial(out_dir)
+    assert len(trial.tensor_names(collection=smd.CollectionKeys.LAYERS)) == 8
+
+    assert trial.tensor_names(collection=smd.CollectionKeys.INPUTS) == ["model_input"]
+    assert trial.tensor_names(collection=smd.CollectionKeys.OUTPUTS) == ["labels", "predictions"]
+    assert trial.tensor_names(collection=smd.CollectionKeys.LOSSES) == ["loss"]
+    assert len(trial.tensor_names(collection=smd.CollectionKeys.GRADIENTS)) == 6