Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions tests/zero_code_change/test_tensorflow2_gradtape_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# Third Party
import pytest
import tensorflow.compat.v2 as tf
from tests.tensorflow2.utils import is_tf_2_2

# First Party
import smdebug.tensorflow as smd
Expand Down Expand Up @@ -80,7 +81,7 @@ def helper_test_keras_v2_gradienttape(script_mode: bool = False, json_file_conte
assert len(trial.tensor_names()) > 0, "Tensors were not saved."
assert len(trial.tensor_names(collection="losses")) > 0
else:
# ZCC doesn't support yet (as of smdebug v0.7.2)
# ZCC support added from smdebug v0.8.0)
for epoch in range(n_epochs):
print("Epoch %d/%d" % (epoch + 1, n_epochs))
for data, labels in dataset:
Expand All @@ -97,7 +98,16 @@ def helper_test_keras_v2_gradienttape(script_mode: bool = False, json_file_conte
print(log)
train_acc_metric.reset_states()
hook = smd.get_hook()
assert not hook
if not is_tf_2_2():
assert not hook # only supported on TF 2.2 and greater
return
assert hook
hook.close()
# Check that hook created and tensors saved
trial = smd.create_trial(path=sim.out_dir)
assert len(trial.steps()) > 0, "Nothing saved at any step."
assert len(trial.tensor_names()) > 0, "Tensors were not saved."
assert len(trial.tensor_names(collection="losses")) > 0


@pytest.mark.parametrize("script_mode", [False])
Expand Down