@@ -72,11 +72,14 @@ def is_s3(path):
7272 return False , None , None
7373
7474
def list_files_in_directory(directory, file_regex=None):
    """Recursively list the files found under ``directory``.

    Args:
        directory: Root directory that is walked with ``os.walk``.
        file_regex: Optional pattern, either a string or a compiled regex.
            When given, only files whose base name matches it from the start
            of the name (``re.match`` semantics) are returned. ``None``
            (the default) returns every file.

    Returns:
        A list of paths, each built as ``os.path.join(root, file_name)``.
        The list is empty when nothing is found or nothing matches.
    """
    # Imported locally so this helper does not rely on a module-level
    # ``import re`` being present.
    import re

    # Resolve the pattern once instead of once per file. ``re.compile``
    # returns an already-compiled pattern unchanged.
    matcher = None if file_regex is None else re.compile(file_regex)

    files = []
    for root, _dir_names, file_names in os.walk(directory):
        for file_name in file_names:
            if matcher is None or matcher.match(file_name):
                files.append(os.path.join(root, file_name))
    return files
8184
8285
@@ -85,8 +88,7 @@ def list_collection_files_in_directory(directory):
8588 import re
8689
8790 collections_file_regex = re .compile (".*_?collections.json" )
88- files = [f for f in os .listdir (collections_directory ) if re .match (collections_file_regex , f )]
89- return files
91+ return list_files_in_directory (collections_directory , file_regex = collections_file_regex )
9092
9193
9294def serialize_tf_device (device : str ) -> str :
@@ -219,6 +221,11 @@ def get_tb_worker():
219221 return f"{ os .getpid ()} _{ socket .gethostname ()} "
220222
221223
def remove_file_if_exists(file_path):
    """Delete ``file_path``, doing nothing if it is already absent.

    The removal is attempted directly rather than guarded by an
    ``os.path.exists`` check, so a file that disappears between the check
    and the removal cannot cause a ``FileNotFoundError``.

    Args:
        file_path: Path of the file to remove.

    Raises:
        OSError: For failures other than the file being missing, for example
            when ``file_path`` is a directory or permissions are insufficient.
    """
    try:
        os.remove(file_path)
    except FileNotFoundError:
        # Already gone: nothing to clean up.
        pass
227+
228+
222229class SagemakerSimulator (object ):
223230 """
224231 Creates an environment variable pointing to a JSON config file, and creates the config file.
@@ -233,16 +240,19 @@ def __init__(
233240 tensorboard_dir = "/tmp/tensorboard" ,
234241 training_job_name = "sm_job" ,
235242 json_file_contents = "{}" ,
243+ cleanup = True ,
236244 ):
237245 self .out_dir = DEFAULT_SAGEMAKER_OUTDIR
238246 self .json_config_path = json_config_path
239247 self .tb_json_config_path = DEFAULT_SAGEMAKER_TENSORBOARD_PATH
240248 self .tensorboard_dir = tensorboard_dir
241249 self .training_job_name = training_job_name
242250 self .json_file_contents = json_file_contents
251+ self .cleanup = cleanup
243252
244253 def __enter__ (self ):
245- shutil .rmtree (self .out_dir , ignore_errors = True )
254+ if self .cleanup is True :
255+ shutil .rmtree (self .out_dir , ignore_errors = True )
246256 shutil .rmtree (self .json_config_path , ignore_errors = True )
247257 tb_parent_dir = str (Path (self .tb_json_config_path ).parent )
248258 shutil .rmtree (tb_parent_dir , ignore_errors = True )
@@ -269,11 +279,15 @@ def __enter__(self):
def __exit__(self, *args):
    """Tear down the simulated environment on leaving the ``with`` block.

    When ``self.cleanup`` is ``True`` the JSON config file and the
    TensorBoard JSON config file are deleted if they exist. The three
    environment variables handled here are then removed if present.
    Nothing is returned, so an exception raised inside the ``with`` body
    is not suppressed.
    """
    # The output directory is deliberately left in place: removing it here
    # throws errors when the writers try to close.
    # shutil.rmtree(self.out_dir, ignore_errors=True)
    if self.cleanup is True:
        remove_file_if_exists(self.json_config_path)
        remove_file_if_exists(self.tb_json_config_path)

    # NOTE(review): environment cleanup is assumed to run regardless of
    # ``cleanup`` (it was unconditional before this flag existed) - confirm.
    # ``pop`` with a default tolerates variables that are already unset.
    for env_var in (
        CONFIG_FILE_PATH_ENV_STR,
        "TRAINING_JOB_NAME",
        TENSORBOARD_CONFIG_FILE_PATH_ENV_STR,
    ):
        os.environ.pop(env_var, None)
277291
278292
279293class ScriptSimulator (object ):
0 commit comments