diff --git a/smdebug/tensorflow/keras.py b/smdebug/tensorflow/keras.py index d34b645be..ccea3ba47 100644 --- a/smdebug/tensorflow/keras.py +++ b/smdebug/tensorflow/keras.py @@ -454,7 +454,7 @@ def save_smdebug_logs(self, logs): elif key == SMDEBUG_LAYER_OUTPUTS_KEY: layer_outputs = logs[key] self.save_layer_outputs(layer_outputs) - self.save_layer_inputs(logs[ModelInput.X], layer_outputs) + self.save_layer_inputs(logs[ModelInput.INPUTS], layer_outputs) # Save Model Inputs elif key in ModelInputs: export_name = get_model_input_export_name()