Improve dead client handling (#2506)
* dev

* test dead client cmd

* added more info for dead client tracing

* remove unused imports

* fix unit test

* fix test case

* address PR comments

---------

Co-authored-by: Sean Yang <seany314@gmail.com>
yanchengnv and SYangster authored Apr 18, 2024
1 parent f2fd48a commit f948b6e
Showing 13 changed files with 326 additions and 172 deletions.
11 changes: 11 additions & 0 deletions nvflare/apis/controller_spec.py
@@ -542,3 +542,14 @@ def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_
            fl_ctx: the FL context
        """
        pass
+
+   def get_client_disconnect_time(self, client_name):
+       """Get the time at which the client was deemed disconnected.
+
+       Args:
+           client_name: the name of the client
+
+       Returns: time at which the client was deemed disconnected; or None if the client is not disconnected.
+       """
+       return None
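
A controller built on this spec can poll the new method to decide whether to keep waiting for a straggler. Below is a minimal sketch (illustrative only, not part of this commit); the helper name is an assumption, and `client` is an nvflare.apis.client.Client:

# Illustrative sketch: skip a client once the framework deems it disconnected.
# `self` is assumed to be a Controller subclass (an FLComponent, so log_warning is available).
def should_keep_waiting_for(self, client, fl_ctx):
    disconnect_time = self.get_client_disconnect_time(client.name)
    if disconnect_time is not None:
        # None would mean the client is still considered connected
        self.log_warning(fl_ctx, f"client {client.name} deemed disconnected at {disconnect_time}")
        return False
    return True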
3 changes: 2 additions & 1 deletion nvflare/apis/event_type.py
@@ -34,7 +34,8 @@ class EventType(object):
    JOB_COMPLETED = "_job_completed"
    JOB_ABORTED = "_job_aborted"
    JOB_CANCELLED = "_job_cancelled"
-   JOB_DEAD = "_job_dead"
+   CLIENT_DISCONNECTED = "_client_disconnected"
+   CLIENT_RECONNECTED = "_client_reconnected"

    BEFORE_PULL_TASK = "_before_pull_task"
    AFTER_PULL_TASK = "_after_pull_task"
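
Components can subscribe to the two new events in place of the removed JOB_DEAD. A minimal sketch (illustrative, not part of this commit) of a component reacting to them, assuming the events carry the corresponding FLContextKey props added in fl_constant.py below; the class name is made up:

# Illustrative sketch: log disconnect/reconnect events from a custom component.
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext


class DisconnectLogger(FLComponent):
    def handle_event(self, event_type: str, fl_ctx: FLContext):
        if event_type == EventType.CLIENT_DISCONNECTED:
            name = fl_ctx.get_prop(FLContextKey.DISCONNECTED_CLIENT_NAME)
            self.log_warning(fl_ctx, f"client {name} disconnected")
        elif event_type == EventType.CLIENT_RECONNECTED:
            name = fl_ctx.get_prop(FLContextKey.RECONNECTED_CLIENT_NAME)
            self.log_info(fl_ctx, f"client {name} reconnected")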
10 changes: 9 additions & 1 deletion nvflare/apis/fl_constant.py
@@ -159,7 +159,8 @@ class FLContextKey(object):
    CLIENT_TOKEN = "__client_token"
    AUTHORIZATION_RESULT = "_authorization_result"
    AUTHORIZATION_REASON = "_authorization_reason"
-   DEAD_JOB_CLIENT_NAME = "_dead_job_client_name"
+   DISCONNECTED_CLIENT_NAME = "_disconnected_client_name"
+   RECONNECTED_CLIENT_NAME = "_reconnected_client_name"

    CLIENT_REGISTER_DATA = "_client_register_data"
    SECURITY_ITEMS = "_security_items"
@@ -263,6 +264,7 @@ class ServerCommandKey(object):
    CLIENTS = "clients"
    COLLECTOR = "collector"
    TURN_TO_COLD = "__turn_to_cold__"
+   REASON = "reason"


class FedEventHeader(object):
@@ -464,6 +466,12 @@ class ConfigVarName:
    # client and server: query interval for reliable message
    RM_QUERY_INTERVAL = "rm_query_interval"

+   # server: wait this long since client death report before treating the client as dead/disconnected
+   DEAD_CLIENT_GRACE_PERIOD = "dead_client_grace_period"
+
+   # server: wait this long since job schedule time before starting to check dead/disconnected clients
+   DEAD_CLIENT_CHECK_LEAD_TIME = "dead_client_check_lead_time"


class SystemVarName:
"""
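
Both new variables follow the existing ConfigVarName pattern, so they can be resolved like the other timeouts touched by this commit. A minimal sketch, assuming the get_positive_float_var() helper used by client_runner.py below is available on the component; the default values here are placeholders, not the shipped defaults (those live in wf_comm_server.py, whose diff is not rendered on this page):

# Illustrative sketch: resolving the new dead-client settings in a component.
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import ConfigVarName


class DeadClientSettings(FLComponent):
    def __init__(self):
        super().__init__()
        # placeholder defaults for illustration only
        self.grace_period = self.get_positive_float_var(ConfigVarName.DEAD_CLIENT_GRACE_PERIOD, 60.0)
        self.check_lead_time = self.get_positive_float_var(ConfigVarName.DEAD_CLIENT_CHECK_LEAD_TIME, 30.0)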
13 changes: 13 additions & 0 deletions nvflare/apis/impl/controller.py
@@ -145,3 +145,16 @@ def cancel_task(

    def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None):
        self.communicator.cancel_all_tasks(completion_status, fl_ctx)
+
+   def get_client_disconnect_time(self, client_name):
+       """Get the time at which the client was deemed disconnected.
+
+       Args:
+           client_name: the name of the client
+
+       Returns: time at which the client was deemed disconnected; or None if the client is not disconnected.
+       """
+       if not self.communicator:
+           return None
+       return self.communicator.get_client_disconnect_time(client_name)
313 changes: 189 additions & 124 deletions nvflare/apis/impl/wf_comm_server.py

Large diffs are not rendered by default.

31 changes: 31 additions & 0 deletions nvflare/apis/wf_comm_spec.py
@@ -284,6 +284,37 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul
"""
raise NotImplementedError

def get_client_disconnect_time(self, client_name):
"""Get the time that the client is deemed disconnected.
Args:
client_name: the name of the client
Returns: time at which the client was deemed disconnected; or None if the client is not disconnected.
"""
raise NotImplementedError

def process_dead_client_report(self, client_name: str, fl_ctx: FLContext):
"""Called by the Engine to process dead client report.
Args:
client_name: name of the client that dead report is received
fl_ctx: the FLContext
"""
raise NotImplementedError

def client_is_active(self, client_name: str, reason: str, fl_ctx: FLContext):
"""Called by the Engine to notify us that the client is active .
Args:
client_name: name of the client that is active
reason: why client is considered active
fl_ctx: the FLContext
"""
raise NotImplementedError

def process_task_check(self, task_id: str, fl_ctx: FLContext):
"""Called by the Engine to check whether a specified task still exists.
Args:
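
The server-side implementation of this spec lives in wf_comm_server.py, whose large diff is not rendered on this page. As a mental model only (a simplified sketch, not the actual NVFlare code), the bookkeeping implied by the three new methods looks roughly like this:

# Simplified sketch (not the real wf_comm_server.py): record when a dead-client
# report arrives, clear it on any client activity, and only deem the client
# disconnected after the grace period has elapsed.
import time


class DeadClientTracker:
    def __init__(self, grace_period: float):
        self.grace_period = grace_period
        self.death_reports = {}  # client_name => time the dead report was received

    def process_dead_client_report(self, client_name: str):
        # keep the earliest report time
        self.death_reports.setdefault(client_name, time.time())

    def client_is_active(self, client_name: str, reason: str):
        # any sign of life (getTask, jobHeartbeat, taskCheck, ...) revives the client
        self.death_reports.pop(client_name, None)

    def get_client_disconnect_time(self, client_name: str):
        reported = self.death_reports.get(client_name)
        if reported is None or time.time() - reported < self.grace_period:
            return None  # not (yet) deemed disconnected
        return reported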
24 changes: 9 additions & 15 deletions nvflare/app_common/workflows/cyclic_ctl.py
@@ -18,8 +18,7 @@

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
-from nvflare.apis.event_type import EventType
-from nvflare.apis.fl_constant import FLContextKey, ReturnCode
+from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
@@ -141,20 +140,24 @@ def start_controller(self, fl_ctx: FLContext):
        self._last_client = None

    def _get_relay_orders(self, fl_ctx: FLContext) -> Union[List[Client], None]:
-       if len(self._participating_clients) <= 1:
-           self.system_panic(f"Not enough client sites ({len(self._participating_clients)}).", fl_ctx)
+       active_clients_map = {}
+       for t in self._participating_clients:
+           if not self.get_client_disconnect_time(t.name):
+               active_clients_map[t.name] = t
+
+       if len(active_clients_map) <= 1:
+           self.system_panic(f"Not enough active client sites ({len(active_clients_map)}).", fl_ctx)
            return None

        if isinstance(self._order, list):
            targets = []
-           active_clients_map = {t.name: t for t in self._participating_clients}
            for c_name in self._order:
                if c_name not in active_clients_map:
                    self.system_panic(f"Required client site ({c_name}) is not in active clients.", fl_ctx)
                    return None
                targets.append(active_clients_map[c_name])
        else:
-           targets = list(self._participating_clients)
+           targets = list(active_clients_map.values())
        if self._order == RelayOrder.RANDOM or self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW:
            random.shuffle(targets)
        if self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW and self._last_client == targets[0]:
@@ -310,12 +313,3 @@ def restore(self, state_data: dict, fl_ctx: FLContext):
            self._start_round = self._current_round
        finally:
            pass
-
-   def handle_event(self, event_type, fl_ctx):
-       if event_type == EventType.JOB_DEAD:
-           client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME)
-           new_client_list = []
-           for client in self._participating_clients:
-               if client_name != client.name:
-                   new_client_list.append(client)
-           self._participating_clients = new_client_list
2 changes: 1 addition & 1 deletion nvflare/private/fed/client/client_runner.py
@@ -145,7 +145,7 @@ def __init__(

        self.task_check_timeout = self.get_positive_float_var(ConfigVarName.TASK_CHECK_TIMEOUT, 5.0)
        self.task_check_interval = self.get_positive_float_var(ConfigVarName.TASK_CHECK_INTERVAL, 5.0)
-       self.job_heartbeat_interval = self.get_positive_float_var(ConfigVarName.JOB_HEARTBEAT_INTERVAL, 30.0)
+       self.job_heartbeat_interval = self.get_positive_float_var(ConfigVarName.JOB_HEARTBEAT_INTERVAL, 10.0)
        self.get_task_timeout = self.get_positive_float_var(ConfigVarName.GET_TASK_TIMEOUT, None)
        self.submit_task_result_timeout = self.get_positive_float_var(ConfigVarName.SUBMIT_TASK_RESULT_TIMEOUT, None)
        self._register_aux_message_handlers(engine)
35 changes: 10 additions & 25 deletions nvflare/private/fed/server/fed_server.py
@@ -36,7 +36,6 @@
)
from nvflare.apis.fl_context import FLContext
from nvflare.apis.fl_exception import NotAuthenticated
-from nvflare.apis.shareable import Shareable
from nvflare.apis.workspace import Workspace
from nvflare.fuel.common.exit_codes import ProcessExitCode
from nvflare.fuel.f3.cellnet.cell import Cell
@@ -140,13 +139,8 @@ def close(self):
            self.lock.release()
        except RuntimeError:
            self.logger.info("canceling sync locks")
-       try:
-           # if self.cell:
-           #     self.cell.stop()
-           pass
-       finally:
-           self.logger.info("server off")
-           return 0
+       self.logger.info("server off")
+       return 0

def deploy(self, args, grpc_args=None, secure_train=False):
"""Start a grpc server and listening the designated port."""
@@ -624,26 +618,17 @@ def _sync_client_jobs(self, request, client_token):
                # this is a dict: token => nvflare.apis.client.Client
                client = participating_clients.get(client_token, None)
                if client:
-                   self._notify_dead_job(client, job_id)
+                   self._notify_dead_job(client, job_id, "missing job on client")

        return jobs_need_abort

-   def _notify_dead_job(self, client, job_id: str):
+   def _notify_dead_job(self, client, job_id: str, reason: str):
        try:
-           with self.engine.lock:
-               shareable = Shareable()
-               shareable.set_header(ServerCommandKey.FL_CLIENT, client.name)
-               fqcn = FQCN.join([FQCN.ROOT_SERVER, job_id])
-               request = new_cell_message({}, shareable)
-               self.cell.fire_and_forget(
-                   targets=fqcn,
-                   channel=CellChannel.SERVER_COMMAND,
-                   topic=ServerCommandNames.HANDLE_DEAD_JOB,
-                   message=request,
-                   optional=True,
-               )
-       except Exception:
-           self.logger.info("Could not connect to server runner process")
+           self.engine.notify_dead_job(job_id, client.name, reason)
+       except Exception as ex:
+           self.logger.info(
+               f"Failed to notify_dead_job to runner process of job {job_id}: {secure_format_exception(ex)}"
+           )

    def notify_dead_client(self, client):
        """Called to do further processing of the dead client
@@ -662,7 +647,7 @@ def notify_dead_client(self, client):
            assert isinstance(process_info, dict)
            participating_clients = process_info.get(RunProcessKey.PARTICIPANTS, None)
            if participating_clients and client.token in participating_clients:
-               self._notify_dead_job(client, job_id)
+               self._notify_dead_job(client, job_id, "client dead")

    def start_run(self, job_id, run_root, conf, args, snapshot):
        # Create the FL Engine
2 changes: 2 additions & 0 deletions nvflare/private/fed/server/server_commands.py
@@ -263,6 +263,8 @@ def process(self, data: Shareable, fl_ctx: FLContext):
"""
client_name = data.get_header(ServerCommandKey.FL_CLIENT)
reason = data.get_header(ServerCommandKey.REASON)
self.logger.warning(f"received dead job notification: {reason=}")
server_runner = fl_ctx.get_prop(FLContextKey.RUNNER)
if server_runner:
server_runner.handle_dead_job(client_name, fl_ctx)
17 changes: 15 additions & 2 deletions nvflare/private/fed/server/server_engine.py
@@ -577,13 +577,26 @@ def update_job_run_status(self):
            data = {"execution_error": execution_error}
            job_id = fl_ctx.get_job_id()
            request = new_cell_message({CellMessageHeaderKeys.JOB_ID: job_id}, data)
-           return_data = self.server.cell.fire_and_forget(
+           self.server.cell.fire_and_forget(
                targets=FQCN.ROOT_SERVER,
                channel=CellChannel.SERVER_PARENT_LISTENER,
                topic=ServerCommandNames.UPDATE_RUN_STATUS,
                message=request,
            )

+   def notify_dead_job(self, job_id: str, client_name: str, reason: str):
+       shareable = Shareable()
+       shareable.set_header(ServerCommandKey.FL_CLIENT, client_name)
+       shareable.set_header(ServerCommandKey.REASON, reason)
+       self.send_command_to_child_runner_process(
+           job_id=job_id,
+           command_name=ServerCommandNames.HANDLE_DEAD_JOB,
+           command_data=shareable,
+           timeout=0.0,
+           optional=True,
+       )
+       self.logger.warning(f"notified SJ of dead-job: {job_id=}; {client_name=}; {reason=}")

    def send_command_to_child_runner_process(
        self, job_id: str, command_name: str, command_data, timeout=5.0, optional=False
    ):
@@ -595,7 +608,7 @@ def send_command_to_child_runner_process(
            targets=fqcn,
            channel=CellChannel.SERVER_COMMAND,
            topic=command_name,
-           request=request,
+           message=request,
            optional=optional,
        )
        return None
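
Taken together with fed_server.py and server_commands.py above, the dead-client notification now flows through a single, traceable path. A simplified sketch of the chain, using names from this diff (comments only, not executable code):

# Simplified call chain after this commit:
#
#   FederatedServer.notify_dead_client(client)                    # fed_server.py
#     -> FederatedServer._notify_dead_job(client, job_id, "client dead")
#       -> engine.notify_dead_job(job_id, client_name, reason)    # server_engine.py
#         -> cell message ServerCommandNames.HANDLE_DEAD_JOB to the job cell
#           -> the HANDLE_DEAD_JOB command handler's process()    # server_commands.py
#             -> server_runner.handle_dead_job(client_name, fl_ctx)
#               -> controller.communicator.process_dead_client_report(client_name, fl_ctx)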
19 changes: 16 additions & 3 deletions nvflare/private/fed/server/server_runner.py
@@ -116,6 +116,7 @@ def _register_aux_message_handler(self, engine):

    def _handle_sync_runner(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
        # simply ack
+       self._report_client_active("syncRunner", fl_ctx)
        return make_reply(ReturnCode.OK)

def _execute_run(self):
@@ -278,6 +279,8 @@ def process_task_request(self, client: Client, fl_ctx: FLContext) -> (str, str,
            self.log_debug(fl_ctx, "invalid task request: no peer context - asked client to try again later")
            return self._task_try_again()

+       self._report_client_active("getTask", fl_ctx)

        peer_job_id = peer_ctx.get_job_id()
        if not peer_job_id or peer_job_id != self.job_id:
            # the client is in a different RUN
@@ -383,9 +386,8 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
                if self.current_wf is None:
                    return

-               fl_ctx.set_prop(FLContextKey.DEAD_JOB_CLIENT_NAME, client_name)
-               self.log_debug(fl_ctx, "firing event EventType.JOB_DEAD")
-               self.fire_event(EventType.JOB_DEAD, fl_ctx)
+               if self.current_wf.controller:
+                   self.current_wf.controller.communicator.process_dead_client_report(client_name, fl_ctx)

        except Exception as e:
            self.log_exception(
@@ -408,6 +410,7 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul
            fl_ctx: FLContext
        """
        self.log_info(fl_ctx, f"got result from client {client.name} for task: name={task_name}, id={task_id}")
+       self._report_client_active("submitTaskResult", fl_ctx)

        if not isinstance(result, Shareable):
            self.log_error(fl_ctx, "invalid result submission: must be Shareable but got {}".format(type(result)))
@@ -503,11 +506,21 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul
"Error processing client result by {}: {}".format(self.current_wf.id, secure_format_exception(e)),
)

def _report_client_active(self, reason: str, fl_ctx: FLContext):
with self.wf_lock:
if self.current_wf and self.current_wf.controller:
peer_ctx = fl_ctx.get_peer_context()
assert isinstance(peer_ctx, FLContext)
client_name = peer_ctx.get_identity_name()
self.current_wf.controller.communicator.client_is_active(client_name, reason, fl_ctx)

def _handle_job_heartbeat(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
self.log_debug(fl_ctx, "received client job_heartbeat")
self._report_client_active("jobHeartbeat", fl_ctx)
return make_reply(ReturnCode.OK)

def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
self._report_client_active("taskCheck", fl_ctx)
task_id = request.get_header(ReservedHeaderKey.TASK_ID)
if not task_id:
self.log_error(fl_ctx, f"missing {ReservedHeaderKey.TASK_ID} in task_check request")
18 changes: 18 additions & 0 deletions nvflare/private/fed/server/sys_cmd.py
@@ -76,6 +76,14 @@ def get_spec(self):
                    authz_func=self.authorize_client_operation,
                    visible=True,
                ),
+               CommandSpec(
+                   name="dead",
+                   description="send dead client msg to SJ",
+                   usage="dead <client-name> <job-id>",
+                   handler_func=self.dead_client,
+                   authz_func=self.must_be_project_admin,
+                   visible=False,
+               ),
            ],
        )

@@ -175,3 +183,13 @@ def report_env(self, conn: Connection, args: List[str]):
        table = conn.append_table(["Sites", "Env"], name=MetaKey.CLIENTS)
        for k, v in site_resources.items():
            table.add_row([str(k), str(v)], meta=v)
+
+   def dead_client(self, conn: Connection, args: List[str]):
+       if len(args) != 3:
+           conn.append_error(f"Usage: {args[0]} client_name job_id")
+           return
+       client_name = args[1]
+       job_id = args[2]
+       engine = conn.app_ctx
+       engine.notify_dead_job(job_id, client_name, f"AdminCommand: {args[0]}")
+       conn.append_string(f"called notify_dead_job for client {client_name=} {job_id=}")
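
Because the command is registered with visible=False, it does not appear in the admin console's help listing, but a project admin can still invoke it by name to simulate a dead client against a running job. A hypothetical session (the client name and job ID below are made up for illustration):

> dead site-1 job-1234
called notify_dead_job for client client_name='site-1' job_id='job-1234'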
