Skip to content

Commit

Permalink
Fix heartbeat timeout config (#2878)
Browse files: browse the repository at this point in the history
* fix heartbeat timeout config

* use TaskExchanger variable
  • Loading branch information
SYangster authored Aug 30, 2024
1 parent a8245d2 commit f52be3b
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def prepare_config_for_launch(self, fl_ctx: FLContext):
ConfigKey.CLASS_NAME: pipe_export_class,
ConfigKey.ARG: pipe_export_args,
},
ConfigKey.HEARTBEAT_TIMEOUT: self.heartbeat_timeout,
}

config_data = {
Expand Down
9 changes: 7 additions & 2 deletions nvflare/app_common/executors/task_exchanger.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
from nvflare.fuel.utils.constants import PipeChannelName
from nvflare.fuel.utils.pipe.pipe import Message, Pipe
from nvflare.fuel.utils.pipe.pipe_handler import PipeHandler, Topic
from nvflare.fuel.utils.validation_utils import check_non_negative_int, check_positive_number, check_str
from nvflare.fuel.utils.validation_utils import (
check_non_negative_int,
check_non_negative_number,
check_positive_number,
check_str,
)
from nvflare.security.logging import secure_format_exception


Expand Down Expand Up @@ -70,7 +75,7 @@ def __init__(
check_positive_number("read_interval", read_interval)
check_positive_number("heartbeat_interval", heartbeat_interval)
if heartbeat_timeout is not None:
check_positive_number("heartbeat_timeout", heartbeat_timeout)
check_non_negative_number("heartbeat_timeout", heartbeat_timeout)
check_positive_number("resend_interval", resend_interval)
if max_resends is not None:
check_non_negative_int("max_resends", max_resends)
Expand Down
1 change: 1 addition & 0 deletions nvflare/app_common/widgets/metric_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,6 @@ def export(self, export_mode: str) -> Tuple[str, dict]:
ConfigKey.CLASS_NAME: pipe_export_class,
ConfigKey.ARG: pipe_export_args,
},
ConfigKey.HEARTBEAT_TIMEOUT: self._heartbeat_timeout,
}
return ConfigKey.METRICS_EXCHANGE, config_dict
7 changes: 7 additions & 0 deletions nvflare/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ def get_eval_task(self):
def get_submit_model_task(self):
    """Return the submit-model task name from the task-exchange config.

    Yields an empty string when either the task-exchange section or the
    task-name entry is absent from ``self.config``.
    """
    task_exchange_cfg = self.config.get(ConfigKey.TASK_EXCHANGE, {})
    return task_exchange_cfg.get(ConfigKey.SUBMIT_MODEL_TASK_NAME, "")

def get_heartbeat_timeout(self):
    """Return the heartbeat timeout found in the client config.

    Lookup order:
      1. the heartbeat timeout in the task-exchange section;
      2. otherwise the heartbeat timeout in the metrics-exchange section;
      3. otherwise the default of 60.
    """
    # TODO decouple task and metric heartbeat timeouts
    metrics_cfg = self.config.get(ConfigKey.METRICS_EXCHANGE, {})
    fallback_timeout = metrics_cfg.get(ConfigKey.HEARTBEAT_TIMEOUT, 60)

    task_cfg = self.config.get(ConfigKey.TASK_EXCHANGE, {})
    return task_cfg.get(ConfigKey.HEARTBEAT_TIMEOUT, fallback_timeout)

def to_json(self, config_file: str):
with open(config_file, "w") as f:
json.dump(self.config, f, indent=2)
Expand Down
2 changes: 1 addition & 1 deletion nvflare/client/ex_process/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def init(self, rank: Optional[str] = None):
task_channel_name=task_channel_name,
metric_pipe=metric_pipe,
metric_channel_name=metric_channel_name,
heartbeat_timeout=client_config.config.get(ConfigKey.HEARTBEAT_TIMEOUT, 60),
heartbeat_timeout=client_config.get_heartbeat_timeout(),
)
flare_agent.start()

Expand Down
4 changes: 3 additions & 1 deletion nvflare/job_config/script_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
script: str,
script_args: str = "",
launch_external_process: bool = False,
command: str = "python3",
command: str = "python3 -u",
framework: FrameworkType = FrameworkType.PYTORCH,
):
"""ScriptRunner is used with FedJob API to run or launch a script.
Expand Down Expand Up @@ -116,6 +116,7 @@ def add_to_fed_job(self, job, ctx, **kwargs):
pipe_id=pipe_id,
launcher_id=launcher_id,
params_exchange_format=self._params_exchange_format,
heartbeat_timeout=0,
)
job.add_executor(executor, tasks=tasks, ctx=ctx)

Expand All @@ -133,6 +134,7 @@ def add_to_fed_job(self, job, ctx, **kwargs):
component = MetricRelay(
pipe_id=metric_pipe_id,
event_type="fed.analytix_log_stats",
heartbeat_timeout=0,
)
metric_relay_id = job.add_component("metric_relay", component, ctx)
comp_ids["metric_relay_id"] = metric_relay_id
Expand Down

0 comments on commit f52be3b

Please sign in to comment.