Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
TEST_DATASET_S3_PATH = "s3://smdebug-testing/datasets/"
3 changes: 2 additions & 1 deletion tests/tensorflow/test_keras_to_estimator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Third Party
import tensorflow as tf
import tensorflow_datasets as tfds
from tests.constants import TEST_DATASET_S3_PATH

# First Party
from smdebug.tensorflow import EstimatorHook, modes
Expand All @@ -17,7 +18,7 @@ def test_keras_to_estimator(out_dir):

def input_fn():
split = tfds.Split.TRAIN
dataset = tfds.load("iris", split=split, as_supervised=True)
dataset = tfds.load("iris", data_dir=TEST_DATASET_S3_PATH, split=split, as_supervised=True)
dataset = dataset.map(lambda features, labels: ({"dense_input": features}, labels))
dataset = dataset.batch(32).repeat()
return dataset
Expand Down
3 changes: 2 additions & 1 deletion tests/tensorflow2/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pytest
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from tests.constants import TEST_DATASET_S3_PATH
from tests.tensorflow2.utils import is_tf_2_2
from tests.tensorflow.utils import create_trial_fast_refresh

Expand Down Expand Up @@ -749,7 +750,7 @@ def test_keras_to_estimator(out_dir, tf_eager_mode):

def input_fn():
split = tfds.Split.TRAIN
dataset = tfds.load("iris", split=split, as_supervised=True)
dataset = tfds.load("iris", data_dir=TEST_DATASET_S3_PATH, split=split, as_supervised=True)
dataset = dataset.map(lambda features, labels: ({"dense_input": features}, labels))
dataset = dataset.batch(32).repeat()
return dataset
Expand Down
5 changes: 4 additions & 1 deletion tests/zero_code_change/test_tensorflow_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytest
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
from tests.constants import TEST_DATASET_S3_PATH
from tests.tensorflow.hooks.test_mirrored_strategy import test_basic
from tests.tensorflow.keras.test_keras_mirrored import test_tf_keras
from tests.zero_code_change.tf_utils import (
Expand Down Expand Up @@ -421,7 +422,9 @@ def test_keras_to_estimator(script_mode):

def input_fn():
split = tfds.Split.TRAIN
dataset = tfds.load("iris", split=split, as_supervised=True)
dataset = tfds.load(
"iris", data_dir=TEST_DATASET_S3_PATH, split=split, as_supervised=True
)
dataset = dataset.map(lambda features, labels: ({"dense_input": features}, labels))
dataset = dataset.batch(32).repeat()
return dataset
Expand Down