Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions smdebug/tensorflow/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,11 +556,13 @@ def _save_layer_input_and_outputs(self):
if self._is_collection_being_saved_for_step(CollectionKeys.LAYERS)
else set()
)
if hasattr(tensor, "numpy"):
self._save_tensor_to_file(export_name, tensor.numpy(), input_collection)
else:
t = tensor[0] if isinstance(tensor, list) and len(tensor) else tensor
if hasattr(t, "numpy") is False:
self.logger.warning("cannot save layer values during forward pass with tf.function")
continue
else:
self._save_tensor_to_file(export_name, tensor, input_collection)

# Save Output
tensor = self.saved_layers[layer_name].layer_output
export_name = get_export_name_for_keras(layer_name, tensor_type="output", tensor=tensor)
Expand All @@ -570,8 +572,11 @@ def _save_layer_input_and_outputs(self):
if self._is_collection_being_saved_for_step(CollectionKeys.LAYERS)
else set()
)
if hasattr(tensor, "numpy"):
self._save_tensor_to_file(export_name, tensor.numpy(), output_collection)
t = tensor[0] if isinstance(tensor, list) and len(tensor) else tensor
if hasattr(t, "numpy") is False:
self.logger.warning("cannot save layer values during forward pass with tf.function")
else:
self._save_tensor_to_file(export_name, tensor, output_collection)

def _save_tensors_post_step(self, batch, logs):
# some tensors available as value from within hook are saved here
Expand Down
38 changes: 38 additions & 0 deletions tests/tensorflow2/test_concat_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Third Party
import numpy as np
from tensorflow.keras.layers import Concatenate, Dense
from tensorflow.python.keras.models import Model

# First Party
import smdebug.tensorflow as smd
from smdebug.trials import create_trial


class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self.con = Concatenate()
self.dense = Dense(10, activation="relu")

def call(self, x):
x = self.con([x, x])
return self.dense(x)


def test_multiple_inputs(out_dir):
my_model = MyModel()
hook = smd.KerasHook(
out_dir, save_all=True, save_config=smd.SaveConfig(save_steps=[0], save_interval=1)
)

hook.register_model(my_model)
x_train = np.random.random((1000, 20))
y_train = np.random.random((1000, 1))
my_model.compile(optimizer="Adam", loss="mse", run_eagerly=True)
my_model.fit(x_train, y_train, epochs=1, steps_per_epoch=1, callbacks=[hook])

trial = create_trial(path=out_dir)
tnames = sorted(trial.tensor_names(collection=smd.CollectionKeys.LAYERS))
assert "concatenate" in tnames[0]
assert len(trial.tensor(tnames[0]).value(0)) == 2
assert trial.tensor(tnames[0]).shape(0) == (2, 1000, 20)