Skip to content

Commit

Permalink
add params_transfer_type to ScriptRunner (#2922)
Browse files Browse the repository at this point in the history
  • Loading branch information
SYangster authored Sep 6, 2024
1 parent 355f02c commit f79d832
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion nvflare/job_config/script_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from nvflare.app_common.executors.client_api_launcher_executor import ClientAPILauncherExecutor
from nvflare.app_common.executors.in_process_client_api_executor import InProcessClientAPIExecutor
from nvflare.client.config import ExchangeFormat
from nvflare.client.config import ExchangeFormat, TransferType
from nvflare.fuel.utils.import_utils import optional_import


Expand All @@ -35,6 +35,7 @@ def __init__(
launch_external_process: bool = False,
command: str = "python3 -u",
framework: FrameworkType = FrameworkType.PYTORCH,
params_transfer_type: str = TransferType.FULL,
):
"""ScriptRunner is used with FedJob API to run or launch a script.
Expand All @@ -47,12 +48,15 @@ def __init__(
launch_external_process (bool): Whether to launch the script in external process. Defaults to False.
command (str): If launch_external_process=True, command to run script (preprended to script). Defaults to "python3".
framework (str): Framework type to connfigure converter and params exchange formats. Defaults to FrameworkType.PYTORCH.
params_transfer_type (str): How to transfer the parameters. FULL means the whole model parameters are sent.
DIFF means that only the difference is sent. Defaults to TransferType.FULL.
"""
self._script = script
self._script_args = script_args
self._command = command
self._launch_external_process = launch_external_process
self._framework = framework
self._params_transfer_type = params_transfer_type

self._params_exchange_format = None

Expand Down Expand Up @@ -116,6 +120,7 @@ def add_to_fed_job(self, job, ctx, **kwargs):
pipe_id=pipe_id,
launcher_id=launcher_id,
params_exchange_format=self._params_exchange_format,
params_transfer_type=self._params_transfer_type,
heartbeat_timeout=0,
)
job.add_executor(executor, tasks=tasks, ctx=ctx)
Expand Down Expand Up @@ -148,6 +153,7 @@ def add_to_fed_job(self, job, ctx, **kwargs):
task_script_path=self._script,
task_script_args=self._script_args,
params_exchange_format=self._params_exchange_format,
params_transfer_type=self._params_transfer_type,
)
job.add_executor(executor, tasks=tasks, ctx=ctx)

Expand Down

0 comments on commit f79d832

Please sign in to comment.