-
Notifications
You must be signed in to change notification settings - Fork 83
Support TF 2.3 Tests #312
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support TF 2.3 Tests #312
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
# Third Party | ||
import pytest | ||
import tensorflow.compat.v2 as tf | ||
from tests.tensorflow2.utils import is_tf_2_3 | ||
from tests.utils import SagemakerSimulator | ||
|
||
# First Party | ||
|
@@ -50,28 +51,40 @@ def helper_test_keras_v2(script_mode: bool = False, eager_mode: bool = True): | |
""" Test the default ZCC behavior of saving losses and metrics in eager and non-eager modes.""" | ||
smd.del_hook() | ||
tf.keras.backend.clear_session() | ||
if not eager_mode: | ||
if not eager_mode and is_tf_2_3() is False: | ||
# v1 training APIs are currently not supported | ||
# in ZCC mode with smdebug 0.9 and AWS TF 2.3.0 | ||
tf.compat.v1.disable_eager_execution() | ||
enable_tb = False if tf.__version__ == "2.0.2" else True | ||
with SagemakerSimulator(enable_tb=enable_tb) as sim: | ||
model = get_keras_model_v2() | ||
(x_train, y_train), (x_test, y_test) = get_keras_data() | ||
x_train, x_test = x_train / 255, x_test / 255 | ||
run_eagerly = None | ||
if is_tf_2_3(): | ||
# Test eager and non eager mode for v2 | ||
run_eagerly = eager_mode | ||
|
||
opt = tf.keras.optimizers.RMSprop() | ||
if script_mode: | ||
hook = smd.KerasHook(out_dir=sim.out_dir, export_tensorboard=True) | ||
opt = hook.wrap_optimizer(opt) | ||
model.compile( | ||
loss="sparse_categorical_crossentropy", optimizer=opt, metrics=["accuracy"] | ||
loss="sparse_categorical_crossentropy", | ||
optimizer=opt, | ||
metrics=["accuracy"], | ||
run_eagerly=run_eagerly, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isn't this setting required in the else part below on L84? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Copied the changes to L84 |
||
) | ||
history = model.fit( | ||
x_train, y_train, batch_size=64, epochs=2, validation_split=0.2, callbacks=[hook] | ||
) | ||
test_scores = model.evaluate(x_test, y_test, verbose=2, callbacks=[hook]) | ||
else: | ||
model.compile( | ||
loss="sparse_categorical_crossentropy", optimizer=opt, metrics=["accuracy"] | ||
loss="sparse_categorical_crossentropy", | ||
optimizer=opt, | ||
metrics=["accuracy"], | ||
run_eagerly=run_eagerly, | ||
) | ||
history = model.fit(x_train, y_train, batch_size=64, epochs=2, validation_split=0.2) | ||
test_scores = model.evaluate(x_test, y_test, verbose=2) | ||
|
@@ -101,7 +114,9 @@ def helper_test_keras_v2_json_config( | |
""" Tests ZCC with custom hook configs """ | ||
smd.del_hook() | ||
tf.keras.backend.clear_session() | ||
if not eager_mode: | ||
if not eager_mode and is_tf_2_3() is False: | ||
# v1 training APIs are currently not supported | ||
# in ZCC mode with smdebug 0.9 and AWS TF 2.3.0 | ||
tf.compat.v1.disable_eager_execution() | ||
enable_tb = False if tf.__version__ == "2.0.2" else True | ||
with SagemakerSimulator(json_file_contents=json_file_contents, enable_tb=enable_tb) as sim: | ||
|
@@ -110,19 +125,29 @@ def helper_test_keras_v2_json_config( | |
x_train, x_test = x_train / 255, x_test / 255 | ||
|
||
opt = tf.keras.optimizers.RMSprop() | ||
run_eagerly = None | ||
if is_tf_2_3(): | ||
# Test eager and non eager mode for v2 | ||
run_eagerly = eager_mode | ||
if script_mode: | ||
hook = smd.KerasHook.create_from_json_file() | ||
opt = hook.wrap_optimizer(opt) | ||
model.compile( | ||
loss="sparse_categorical_crossentropy", optimizer=opt, metrics=["accuracy"] | ||
loss="sparse_categorical_crossentropy", | ||
optimizer=opt, | ||
metrics=["accuracy"], | ||
run_eagerly=run_eagerly, | ||
) | ||
history = model.fit( | ||
x_train, y_train, batch_size=64, epochs=2, validation_split=0.2, callbacks=[hook] | ||
) | ||
test_scores = model.evaluate(x_test, y_test, verbose=2, callbacks=[hook]) | ||
else: | ||
model.compile( | ||
loss="sparse_categorical_crossentropy", optimizer=opt, metrics=["accuracy"] | ||
loss="sparse_categorical_crossentropy", | ||
optimizer=opt, | ||
metrics=["accuracy"], | ||
run_eagerly=run_eagerly, | ||
) | ||
history = model.fit(x_train, y_train, epochs=2, batch_size=64, validation_split=0.2) | ||
test_scores = model.evaluate(x_test, y_test, verbose=2) | ||
|
@@ -134,7 +159,9 @@ def helper_test_keras_v2_json_config( | |
trial = smd.create_trial(path=sim.out_dir) | ||
assert len(trial.steps()) > 0, "Nothing saved at any step." | ||
assert len(trial.tensor_names()) > 0, "Tensors were not saved." | ||
if not eager_mode: | ||
if not eager_mode and is_tf_2_3() is False: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what about when its not eager and its TF 2.3? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If |
||
# Gradients are currently not saved in ZCC mode with AWS TF 2.3.0 | ||
# and smdebug 0.9 | ||
assert len(trial.tensor_names(collection="gradients")) > 0 | ||
assert len(trial.tensor_names(collection="weights")) > 0 | ||
assert len(trial.tensor_names(collection="losses")) > 0 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this required in the other unit tests as well or something unique to ZCC?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't required in other unit tests because they run in script mode.
zero_code_change tests need changes to AWS TF, which currently supports basic ZCC functionality