diff --git a/examples/advanced/job_config/hello-pt/hello_pt_job.py b/examples/advanced/job_config/hello-pt/hello_pt_job.py index d95a962a3b..9f40041dac 100644 --- a/examples/advanced/job_config/hello-pt/hello_pt_job.py +++ b/examples/advanced/job_config/hello-pt/hello_pt_job.py @@ -109,12 +109,12 @@ def _create_server_app(self): def export_job(self, job_root): self.job.generate_job_config(job_root) - def simulator_run(self, job_root, workspace): - self.job.simulator_run(job_root, workspace, threads=2) + def simulator_run(self, workspace): + self.job.simulator_run(workspace, threads=2) if __name__ == "__main__": job = HelloPTJob() # job.export_job("/tmp/nvflare/jobs") - job.simulator_run("/tmp/nvflare/jobs", "/tmp/nvflare/simulator_workspace") + job.simulator_run("/tmp/nvflare/simulator_workspace") diff --git a/nvflare/job_config/fed_job_config.py b/nvflare/job_config/fed_job_config.py index 793e40e0c2..dc440a4941 100644 --- a/nvflare/job_config/fed_job_config.py +++ b/nvflare/job_config/fed_job_config.py @@ -17,6 +17,7 @@ import os import shutil from enum import Enum +from tempfile import TemporaryDirectory from typing import Dict from nvflare import SimulatorRunner @@ -119,18 +120,19 @@ def generate_job_config(self, job_root): self._generate_meta(job_dir) - def simulator_run(self, job_root, workspace, clients=None, n_clients=None, threads=None, gpu=None): - self.generate_job_config(job_root) - - simulator = SimulatorRunner( - job_folder=os.path.join(job_root, self.job_name), - workspace=workspace, - clients=clients, - n_clients=n_clients, - threads=threads, - gpu=gpu, - ) - simulator.run() + def simulator_run(self, workspace, clients=None, n_clients=None, threads=None, gpu=None): + with TemporaryDirectory() as job_root: + self.generate_job_config(job_root) + + simulator = SimulatorRunner( + job_folder=os.path.join(job_root, self.job_name), + workspace=workspace, + clients=clients, + n_clients=n_clients, + threads=threads, + gpu=gpu, + ) + simulator.run() def _get_server_app(self, config_dir, custom_dir, fed_app): server_app = {"format_version": 2, "workflows": []}