Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions smdebug/tensorflow/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tensorflow.compat.v1 as tf
from tensorflow.python.distribute import values
from tensorflow.python.framework.indexed_slices import IndexedSlices
from tensorflow.python.util import nest

# First Party
from smdebug.core.modes import ModeKeys
Expand Down Expand Up @@ -497,6 +498,10 @@ def save_smdebug_logs(self, logs):
else set()
)
for t_name, t_value in tensors_to_save:
if isinstance(t_value, dict):
# flatten the inputs and labels
# since we cannot convert dicts into numpy
t_value = nest.flatten(t_value)
self._save_tensor_to_file(t_name, t_value, collections_to_write)

def _save_metrics(self, batch, logs, force_save=False):
Expand Down
42 changes: 42 additions & 0 deletions tests/tensorflow2/test_support_dicts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Third Party
import numpy as np
import tensorflow as tf

# First Party
import smdebug.tensorflow as smd
from smdebug.core.collection import CollectionKeys
from smdebug.trials import create_trial


def get_data():
images = np.zeros((64, 224))
labels = np.zeros((64, 5))
inputs = {"Image_input": images}
outputs = {"output-softmax": labels}
return inputs, outputs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we have some test cases for valid and invalid dicts?
empty dicts,
either inputs/output empty (is this possible?)
are the values always valid? - is there a chance they cant be converted to numpy?

is this applicable only to image inputs?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are dicts the only other valid format? can inputs be nested lists or something similar?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this used only with keras fit or can it be used with gradtape too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The flatten call made by smdebug is only after tf runs the same function on the feed_dict.

Both fit and gradtape execute the same call stack

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we have some test cases for valid and invalid dicts?
empty dicts,
either inputs/output empty (is this possible?)
are the values always valid? - is there a chance they cant be converted to numpy?

is this applicable only to image inputs?

model.fit breaks if the customer passes an empty or invalid feed_dict.



def create_hook(trial_dir):
hook = smd.KerasHook(trial_dir, save_all=True)
return hook


def create_model():
input_layer = tf.keras.layers.Input(name="Image_input", shape=(224), dtype="float32")
model = tf.keras.layers.Dense(5)(input_layer)
model = tf.keras.layers.Activation("softmax", name="output-softmax")(model)
model = tf.keras.models.Model(inputs=input_layer, outputs=[model])
return model


def test_support_dicts(out_dir):
model = create_model()
optimizer = tf.keras.optimizers.Adadelta(lr=1.0, rho=0.95, epsilon=None, decay=0.0)
model.compile(loss="categorical_crossentropy", optimizer=optimizer)
inputs, labels = get_data()
smdebug_hook = create_hook(out_dir)
model.fit(inputs, labels, batch_size=16, epochs=10, callbacks=[smdebug_hook])
model.save(out_dir, save_format="tf")
trial = create_trial(out_dir)
assert trial.tensor_names(collection=CollectionKeys.INPUTS) == ["model_input"]
assert trial.tensor_names(collection=CollectionKeys.OUTPUTS) == ["labels", "predictions"]