diff --git a/tests/zero_code_change/test_tensorflow2_gradtape_integration.py b/tests/zero_code_change/test_tensorflow2_gradtape_integration.py index 414ec4587..e7cfcc42e 100644 --- a/tests/zero_code_change/test_tensorflow2_gradtape_integration.py +++ b/tests/zero_code_change/test_tensorflow2_gradtape_integration.py @@ -12,7 +12,8 @@ # Third Party import pytest import tensorflow.compat.v2 as tf -from tests.tensorflow2.utils import is_tf_2_2, is_tf_2_3 +from packaging import version +from tests.tensorflow2.utils import is_tf_2_2 # First Party import smdebug.tensorflow as smd @@ -104,8 +105,8 @@ def helper_test_keras_v2_gradienttape( print(log) train_acc_metric.reset_states() hook = smd.get_hook() - if not (is_tf_2_2() or is_tf_2_3()): - assert not hook # only supported on TF 2.2 and greater + if version.parse(tf.__version__) < version.parse("2.1.2"): + assert not hook # only supported on TF 2.1.2 and greater return assert hook hook.close()