Skip to content

Commit 8401dbb

Browse files
committed
revert update
1 parent 4cacbc0 commit 8401dbb

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

smdebug/tensorflow/keras.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -964,12 +964,16 @@ def _save_layer_values(self, logs):
964964
layer_input_tensor_name_with_idx, layer_input_tensor, collections_to_write
965965
)
966966
layer_output_tensor_name = get_export_name_for_keras(str(layer_name), "output")
967-
if not isinstance(layer_output, list):
968-
layer_output = [layer_output]
969-
for idx, l_output in enumerate(layer_output):
970-
layer_output_tensor_name_with_idx = f"{layer_output_tensor_name}_{idx}"
967+
if isinstance(layer_output, list):
968+
for idx, l_output in enumerate(layer_output):
969+
layer_output_tensor_name_with_idx = f"{layer_output_tensor_name}_{idx}"
970+
self._save_tensor_to_file(
971+
layer_output_tensor_name_with_idx, l_output, collections_to_write
972+
)
973+
else:
974+
layer_output_tensor_name_with_idx = f"{layer_output_tensor_name}_{0}"
971975
self._save_tensor_to_file(
972-
layer_output_tensor_name_with_idx, l_output, collections_to_write
976+
layer_output_tensor_name_with_idx, layer_output, collections_to_write
973977
)
974978

975979
def _write_optimizer_variables(self):

0 commit comments

Comments
 (0)