diff --git a/tests/constants.py b/tests/constants.py new file mode 100644 index 000000000..b7b95f3bf --- /dev/null +++ b/tests/constants.py @@ -0,0 +1 @@ +TEST_DATASET_S3_PATH = "s3://smdebug-testing/datasets/" diff --git a/tests/tensorflow/test_keras_to_estimator.py b/tests/tensorflow/test_keras_to_estimator.py index d171990bc..0122b34b9 100644 --- a/tests/tensorflow/test_keras_to_estimator.py +++ b/tests/tensorflow/test_keras_to_estimator.py @@ -1,6 +1,7 @@ # Third Party import tensorflow as tf import tensorflow_datasets as tfds +from tests.constants import TEST_DATASET_S3_PATH # First Party from smdebug.tensorflow import EstimatorHook, modes @@ -17,7 +18,7 @@ def test_keras_to_estimator(out_dir): def input_fn(): split = tfds.Split.TRAIN - dataset = tfds.load("iris", split=split, as_supervised=True) + dataset = tfds.load("iris", data_dir=TEST_DATASET_S3_PATH, split=split, as_supervised=True) dataset = dataset.map(lambda features, labels: ({"dense_input": features}, labels)) dataset = dataset.batch(32).repeat() return dataset diff --git a/tests/tensorflow2/test_keras.py b/tests/tensorflow2/test_keras.py index 3359826fb..d7fc42759 100644 --- a/tests/tensorflow2/test_keras.py +++ b/tests/tensorflow2/test_keras.py @@ -13,6 +13,7 @@ import pytest import tensorflow.compat.v2 as tf import tensorflow_datasets as tfds +from tests.constants import TEST_DATASET_S3_PATH from tests.tensorflow2.utils import is_tf_2_2 from tests.tensorflow.utils import create_trial_fast_refresh @@ -749,7 +750,7 @@ def test_keras_to_estimator(out_dir, tf_eager_mode): def input_fn(): split = tfds.Split.TRAIN - dataset = tfds.load("iris", split=split, as_supervised=True) + dataset = tfds.load("iris", data_dir=TEST_DATASET_S3_PATH, split=split, as_supervised=True) dataset = dataset.map(lambda features, labels: ({"dense_input": features}, labels)) dataset = dataset.batch(32).repeat() return dataset diff --git a/tests/zero_code_change/test_tensorflow_integration.py b/tests/zero_code_change/test_tensorflow_integration.py index d002a94e7..95a1df5f9 100644 --- a/tests/zero_code_change/test_tensorflow_integration.py +++ b/tests/zero_code_change/test_tensorflow_integration.py @@ -21,6 +21,7 @@ import pytest import tensorflow.compat.v1 as tf import tensorflow_datasets as tfds +from tests.constants import TEST_DATASET_S3_PATH from tests.tensorflow.hooks.test_mirrored_strategy import test_basic from tests.tensorflow.keras.test_keras_mirrored import test_tf_keras from tests.zero_code_change.tf_utils import ( @@ -421,7 +422,9 @@ def test_keras_to_estimator(script_mode): def input_fn(): split = tfds.Split.TRAIN - dataset = tfds.load("iris", split=split, as_supervised=True) + dataset = tfds.load( + "iris", data_dir=TEST_DATASET_S3_PATH, split=split, as_supervised=True + ) dataset = dataset.map(lambda features, labels: ({"dense_input": features}, labels)) dataset = dataset.batch(32).repeat() return dataset