diff --git a/src/sagemaker/modules/local_core/local_container.py b/src/sagemaker/modules/local_core/local_container.py index 5424f4f865..448330092d 100644 --- a/src/sagemaker/modules/local_core/local_container.py +++ b/src/sagemaker/modules/local_core/local_container.py @@ -108,6 +108,8 @@ class _LocalContainer(BaseModel): container_entrypoint: Optional[List[str]] container_arguments: Optional[List[str]] + _temporary_folders: List[str] = [] + def model_post_init(self, __context: Any): """Post init method to perform custom validation and set default values.""" self.hosts = [f"algo-{i}" for i in range(1, self.instance_count + 1)] @@ -201,6 +203,13 @@ def train( # Print our Job Complete line logger.info("Local training job completed, output artifacts saved to %s", artifacts) + + shutil.rmtree(os.path.join(self.container_root, "input")) + shutil.rmtree(os.path.join(self.container_root, "shared")) + for host in self.hosts: + shutil.rmtree(os.path.join(self.container_root, host)) + for folder in self._temporary_folders: + shutil.rmtree(os.path.join(self.container_root, folder)) return artifacts def retrieve_artifacts( @@ -540,6 +549,7 @@ def _get_data_source_local_path(self, data_source: DataSource): uri = data_source.s3_data_source.s3_uri parsed_uri = urlparse(uri) local_dir = TemporaryDirectory(prefix=os.path.join(self.container_root + "/")).name + self._temporary_folders.append(local_dir) download_folder(parsed_uri.netloc, parsed_uri.path, local_dir, self.sagemaker_session) return local_dir else: diff --git a/tests/integ/sagemaker/modules/train/test_local_model_trainer.py b/tests/integ/sagemaker/modules/train/test_local_model_trainer.py index adb5f85f3e..7947b2fc87 100644 --- a/tests/integ/sagemaker/modules/train/test_local_model_trainer.py +++ b/tests/integ/sagemaker/modules/train/test_local_model_trainer.py @@ -92,10 +92,7 @@ def test_single_container_local_mode_local_data(modules_sagemaker_session): "compressed_artifacts", "artifacts", "model", - "shared", - "input", "output", - "algo-1", ] for directory in directories: @@ -149,14 +146,16 @@ def test_single_container_local_mode_s3_data(modules_sagemaker_session): assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz")) finally: subprocess.run(["docker", "compose", "down", "-v"]) + + assert not os.path.exists(os.path.join(CWD, "shared")) + assert not os.path.exists(os.path.join(CWD, "input")) + assert not os.path.exists(os.path.join(CWD, "algo-1")) + directories = [ "compressed_artifacts", "artifacts", "model", - "shared", - "input", "output", - "algo-1", ] for directory in directories: @@ -204,20 +203,20 @@ def test_multi_container_local_mode(modules_sagemaker_session): model_trainer.train() assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz")) - assert os.path.exists(os.path.join(CWD, "algo-1")) - assert os.path.exists(os.path.join(CWD, "algo-2")) finally: subprocess.run(["docker", "compose", "down", "-v"]) + + assert not os.path.exists(os.path.join(CWD, "shared")) + assert not os.path.exists(os.path.join(CWD, "input")) + assert not os.path.exists(os.path.join(CWD, "algo-1")) + assert not os.path.exists(os.path.join(CWD, "algo-2")) + directories = [ "compressed_artifacts", "artifacts", "model", - "shared", - "input", "output", - "algo-1", - "algo-2", ] for directory in directories: