-
Notifications
You must be signed in to change notification settings - Fork 83
Support Inputs and Labels in the dict format #345
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Third Party | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
# First Party | ||
import smdebug.tensorflow as smd | ||
from smdebug.core.collection import CollectionKeys | ||
from smdebug.trials import create_trial | ||
|
||
|
||
def get_data(): | ||
images = np.zeros((64, 224)) | ||
labels = np.zeros((64, 5)) | ||
inputs = {"Image_input": images} | ||
outputs = {"output-softmax": labels} | ||
return inputs, outputs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we have some test cases for valid and invalid dicts? is this applicable only to image inputs? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are dicts the only other valid format? can inputs be nested lists or something similar? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this used only with keras fit or can it be used with gradtape too? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The flatten call made by smdebug is only after tf runs the same function on the feed_dict. Both fit and gradtape execute the same call stack There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
model.fit breaks if the customer passes an empty or invalid feed_dict. |
||
|
||
|
||
def create_hook(trial_dir): | ||
hook = smd.KerasHook(trial_dir, save_all=True) | ||
return hook | ||
|
||
|
||
def create_model(): | ||
input_layer = tf.keras.layers.Input(name="Image_input", shape=(224), dtype="float32") | ||
model = tf.keras.layers.Dense(5)(input_layer) | ||
model = tf.keras.layers.Activation("softmax", name="output-softmax")(model) | ||
model = tf.keras.models.Model(inputs=input_layer, outputs=[model]) | ||
return model | ||
|
||
|
||
def test_support_dicts(out_dir): | ||
model = create_model() | ||
optimizer = tf.keras.optimizers.Adadelta(lr=1.0, rho=0.95, epsilon=None, decay=0.0) | ||
model.compile(loss="categorical_crossentropy", optimizer=optimizer) | ||
inputs, labels = get_data() | ||
smdebug_hook = create_hook(out_dir) | ||
model.fit(inputs, labels, batch_size=16, epochs=10, callbacks=[smdebug_hook]) | ||
vandanavk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
model.save(out_dir, save_format="tf") | ||
trial = create_trial(out_dir) | ||
assert trial.tensor_names(collection=CollectionKeys.INPUTS) == ["model_input"] | ||
assert trial.tensor_names(collection=CollectionKeys.OUTPUTS) == ["labels", "predictions"] |
Uh oh!
There was an error while loading. Please reload this page.