diff --git a/tests/zero_code_change/test_tensorflow2_integration.py b/tests/zero_code_change/test_tensorflow2_integration.py index cb3cc7ddb..703a9abcf 100644 --- a/tests/zero_code_change/test_tensorflow2_integration.py +++ b/tests/zero_code_change/test_tensorflow2_integration.py @@ -55,7 +55,9 @@ def helper_test_keras_v2(script_mode: bool = False, eager_mode: bool = True): # v1 training APIs are currently not supported # in ZCC mode with smdebug 0.9 and AWS TF 2.3.0 tf.compat.v1.disable_eager_execution() - enable_tb = False if tf.__version__ == "2.0.2" else True + + # Performance regression in the _make_histogram fn + enable_tb = False if tf.__version__ == "2.0.2" or is_tf_2_3() else True with SagemakerSimulator(enable_tb=enable_tb) as sim: model = get_keras_model_v2() (x_train, y_train), (x_test, y_test) = get_keras_data() @@ -118,7 +120,10 @@ def helper_test_keras_v2_json_config( # v1 training APIs are currently not supported # in ZCC mode with smdebug 0.9 and AWS TF 2.3.0 tf.compat.v1.disable_eager_execution() - enable_tb = False if tf.__version__ == "2.0.2" else True + + # Performance regression in the _make_histogram fn + enable_tb = False if tf.__version__ == "2.0.2" or is_tf_2_3() else True + with SagemakerSimulator(json_file_contents=json_file_contents, enable_tb=enable_tb) as sim: model = get_keras_model_v2() (x_train, y_train), (x_test, y_test) = get_keras_data()