diff --git a/smdebug/tensorflow/utils.py b/smdebug/tensorflow/utils.py index 65d40db65..34a713629 100644 --- a/smdebug/tensorflow/utils.py +++ b/smdebug/tensorflow/utils.py @@ -329,7 +329,7 @@ def __init__(self): self.layer_input = None self.layer_output = None - def __call__(self, callable_inputs, *args, **kwargs) -> None: + def __call__(self, inputs, *args, **kwargs) -> None: self.layer_input = kwargs["layer_input"] self.layer_output = kwargs["layer_output"] @@ -337,11 +337,11 @@ def __call__(self, callable_inputs, *args, **kwargs) -> None: def get_layer_call_fn(layer: tf.keras.layers.Layer) -> Callable[[tf.Tensor], tf.Tensor]: old_call_fn = layer.call - def call(callable_inputs, *args, **kwargs) -> tf.Tensor: - layer_input = callable_inputs - layer_output = old_call_fn(callable_inputs) + def call(inputs, *args, **kwargs) -> tf.Tensor: + layer_input = inputs + layer_output = old_call_fn(inputs) for hook in layer._hooks: - hook_result = hook(callable_inputs, layer_input=layer_input, layer_output=layer_output) + hook_result = hook(inputs, layer_input=layer_input, layer_output=layer_output) if hook_result is not None: layer_output = hook_result return layer_output diff --git a/tests/tensorflow2/test_keras.py b/tests/tensorflow2/test_keras.py index 4323b5284..089cc2d9e 100644 --- a/tests/tensorflow2/test_keras.py +++ b/tests/tensorflow2/test_keras.py @@ -103,6 +103,8 @@ def helper_keras_fit( elif step == "predict": model.predict(x_test[:100], callbacks=hooks, verbose=0) + model.save(trial_dir, save_format="tf") + hook.close() @@ -180,6 +182,7 @@ def helper_keras_gradtape( ) train_acc_metric.reset_states() + model.save(trial_dir, save_format="tf") hook.close()