Skip to content

Commit

Permalink
simplify job simulator_run to take only one workspace parameter. (#2528)
Browse files Browse the repository at this point in the history
  • Loading branch information
yhwen authored Apr 24, 2024
1 parent fa4d00f commit 724140e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 15 deletions.
6 changes: 3 additions & 3 deletions examples/advanced/job_config/hello-pt/hello_pt_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,12 @@ def _create_server_app(self):
def export_job(self, job_root):
self.job.generate_job_config(job_root)

def simulator_run(self, job_root, workspace):
self.job.simulator_run(job_root, workspace, threads=2)
def simulator_run(self, workspace):
self.job.simulator_run(workspace, threads=2)


if __name__ == "__main__":
job = HelloPTJob()

# job.export_job("/tmp/nvflare/jobs")
job.simulator_run("/tmp/nvflare/jobs", "/tmp/nvflare/simulator_workspace")
job.simulator_run("/tmp/nvflare/simulator_workspace")
26 changes: 14 additions & 12 deletions nvflare/job_config/fed_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import shutil
from enum import Enum
from tempfile import TemporaryDirectory
from typing import Dict

from nvflare import SimulatorRunner
Expand Down Expand Up @@ -119,18 +120,19 @@ def generate_job_config(self, job_root):

self._generate_meta(job_dir)

def simulator_run(self, job_root, workspace, clients=None, n_clients=None, threads=None, gpu=None):
self.generate_job_config(job_root)

simulator = SimulatorRunner(
job_folder=os.path.join(job_root, self.job_name),
workspace=workspace,
clients=clients,
n_clients=n_clients,
threads=threads,
gpu=gpu,
)
simulator.run()
def simulator_run(self, workspace, clients=None, n_clients=None, threads=None, gpu=None):
with TemporaryDirectory() as job_root:
self.generate_job_config(job_root)

simulator = SimulatorRunner(
job_folder=os.path.join(job_root, self.job_name),
workspace=workspace,
clients=clients,
n_clients=n_clients,
threads=threads,
gpu=gpu,
)
simulator.run()

def _get_server_app(self, config_dir, custom_dir, fed_app):
server_app = {"format_version": 2, "workflows": []}
Expand Down

0 comments on commit 724140e

Please sign in to comment.