Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# Third Party
import pytest
import tensorflow.compat.v2 as tf
from tests.tensorflow2.utils import is_tf_2_2, is_tf_2_3
from packaging import version
from tests.tensorflow2.utils import is_tf_2_2

# First Party
import smdebug.tensorflow as smd
Expand Down Expand Up @@ -104,8 +105,8 @@ def helper_test_keras_v2_gradienttape(
print(log)
train_acc_metric.reset_states()
hook = smd.get_hook()
if not (is_tf_2_2() or is_tf_2_3()):
assert not hook # only supported on TF 2.2 and greater
if version.parse(tf.__version__) < version.parse("2.1.2"):
assert not hook # only supported on TF 2.1.2 and greater
return
assert hook
hook.close()
Expand Down