Skip to content

Commit e29065b

Browse files
authored
Fixes: SMSimulator fix, listing files local should ignore tmp files (#137)
* SM simulator is failing in exit, just making removal of files optional * Fixing list of files in case of temp files to look for json|csv|tfevents . Was problematic becuase in local run it was picking up temp files written by debugger
1 parent 42cec6b commit e29065b

File tree

2 files changed

+32
-12
lines changed

2 files changed

+32
-12
lines changed

smdebug/core/index_reader.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,11 +351,17 @@ def _is_event_file_present(self, file):
351351

352352
def list_index_files(self):
353353
index_dirname = IndexFileLocationUtils.get_index_path(self.path)
354-
index_files = list_files_in_directory(index_dirname)
354+
# index files are json files or csv files ending with string ".csv" or ".json"
355+
index_files_regex = "(.+)\.(json|csv)$"
356+
index_files = list_files_in_directory(index_dirname, file_regex=index_files_regex)
355357
return sorted(index_files)
356358

357359
def list_event_files(self, start_after_key=None):
358-
event_files = list_files_in_directory(get_path_to_events_directory(self.path))
360+
# event files are ending with string ".tfevents"
361+
event_file_regex = "(.+)\.(tfevents)$"
362+
event_files = list_files_in_directory(
363+
get_path_to_events_directory(self.path), file_regex=event_file_regex
364+
)
359365
event_files.sort()
360366
start_after_index = bisect_left(event_files, start_after_key)
361367
return event_files[start_after_index:]

smdebug/core/utils.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,14 @@ def is_s3(path):
7272
return False, None, None
7373

7474

75-
def list_files_in_directory(directory):
75+
def list_files_in_directory(directory, file_regex=None):
7676
files = []
7777
for root, dir_name, filename in os.walk(directory):
7878
for f in filename:
79-
files.append(os.path.join(root, f))
79+
if file_regex is None:
80+
files.append(os.path.join(root, f))
81+
elif re.match(file_regex, f):
82+
files.append(os.path.join(root, f))
8083
return files
8184

8285

@@ -85,8 +88,7 @@ def list_collection_files_in_directory(directory):
8588
import re
8689

8790
collections_file_regex = re.compile(".*_?collections.json")
88-
files = [f for f in os.listdir(collections_directory) if re.match(collections_file_regex, f)]
89-
return files
91+
return list_files_in_directory(collections_directory, file_regex=collections_file_regex)
9092

9193

9294
def serialize_tf_device(device: str) -> str:
@@ -219,6 +221,11 @@ def get_tb_worker():
219221
return f"{os.getpid()}_{socket.gethostname()}"
220222

221223

224+
def remove_file_if_exists(file_path):
225+
if os.path.exists(file_path):
226+
os.remove(file_path)
227+
228+
222229
class SagemakerSimulator(object):
223230
"""
224231
Creates an environment variable pointing to a JSON config file, and creates the config file.
@@ -233,16 +240,19 @@ def __init__(
233240
tensorboard_dir="/tmp/tensorboard",
234241
training_job_name="sm_job",
235242
json_file_contents="{}",
243+
cleanup=True,
236244
):
237245
self.out_dir = DEFAULT_SAGEMAKER_OUTDIR
238246
self.json_config_path = json_config_path
239247
self.tb_json_config_path = DEFAULT_SAGEMAKER_TENSORBOARD_PATH
240248
self.tensorboard_dir = tensorboard_dir
241249
self.training_job_name = training_job_name
242250
self.json_file_contents = json_file_contents
251+
self.cleanup = cleanup
243252

244253
def __enter__(self):
245-
shutil.rmtree(self.out_dir, ignore_errors=True)
254+
if self.cleanup is True:
255+
shutil.rmtree(self.out_dir, ignore_errors=True)
246256
shutil.rmtree(self.json_config_path, ignore_errors=True)
247257
tb_parent_dir = str(Path(self.tb_json_config_path).parent)
248258
shutil.rmtree(tb_parent_dir, ignore_errors=True)
@@ -269,11 +279,15 @@ def __enter__(self):
269279
def __exit__(self, *args):
270280
# Throws errors when the writers try to close.
271281
# shutil.rmtree(self.out_dir, ignore_errors=True)
272-
os.remove(self.json_config_path)
273-
os.remove(self.tb_json_config_path)
274-
del os.environ[CONFIG_FILE_PATH_ENV_STR]
275-
del os.environ["TRAINING_JOB_NAME"]
276-
del os.environ[TENSORBOARD_CONFIG_FILE_PATH_ENV_STR]
282+
if self.cleanup is True:
283+
remove_file_if_exists(self.json_config_path)
284+
remove_file_if_exists(self.tb_json_config_path)
285+
if CONFIG_FILE_PATH_ENV_STR in os.environ:
286+
del os.environ[CONFIG_FILE_PATH_ENV_STR]
287+
if "TRAINING_JOB_NAME" in os.environ:
288+
del os.environ["TRAINING_JOB_NAME"]
289+
if TENSORBOARD_CONFIG_FILE_PATH_ENV_STR in os.environ:
290+
del os.environ[TENSORBOARD_CONFIG_FILE_PATH_ENV_STR]
277291

278292

279293
class ScriptSimulator(object):

0 commit comments

Comments
 (0)