diff --git a/nvflare/apis/controller_spec.py b/nvflare/apis/controller_spec.py index f0019da206..2f18f95623 100644 --- a/nvflare/apis/controller_spec.py +++ b/nvflare/apis/controller_spec.py @@ -542,3 +542,14 @@ def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ fl_ctx: the FL context """ pass + + def get_client_disconnect_time(self, client_name): + """Get the time that the client is deemed disconnected. + + Args: + client_name: the name of the client + + Returns: time at which the client was deemed disconnected; or None if the client is not disconnected. + + """ + return None diff --git a/nvflare/apis/event_type.py b/nvflare/apis/event_type.py index e2448da113..5bdbe50bc8 100644 --- a/nvflare/apis/event_type.py +++ b/nvflare/apis/event_type.py @@ -34,7 +34,8 @@ class EventType(object): JOB_COMPLETED = "_job_completed" JOB_ABORTED = "_job_aborted" JOB_CANCELLED = "_job_cancelled" - JOB_DEAD = "_job_dead" + CLIENT_DISCONNECTED = "_client_disconnected" + CLIENT_RECONNECTED = "_client_reconnected" BEFORE_PULL_TASK = "_before_pull_task" AFTER_PULL_TASK = "_after_pull_task" diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index 4752eb4a29..71ae9c0c45 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -159,7 +159,8 @@ class FLContextKey(object): CLIENT_TOKEN = "__client_token" AUTHORIZATION_RESULT = "_authorization_result" AUTHORIZATION_REASON = "_authorization_reason" - DEAD_JOB_CLIENT_NAME = "_dead_job_client_name" + DISCONNECTED_CLIENT_NAME = "_disconnected_client_name" + RECONNECTED_CLIENT_NAME = "_reconnected_client_name" CLIENT_REGISTER_DATA = "_client_register_data" SECURITY_ITEMS = "_security_items" @@ -263,6 +264,7 @@ class ServerCommandKey(object): CLIENTS = "clients" COLLECTOR = "collector" TURN_TO_COLD = "__turn_to_cold__" + REASON = "reason" class FedEventHeader(object): @@ -464,6 +466,12 @@ class ConfigVarName: # client and server: query interval for reliable message RM_QUERY_INTERVAL = "rm_query_interval" + # server: wait this long since client death report before treating the client as dead/disconnected + DEAD_CLIENT_GRACE_PERIOD = "dead_client_grace_period" + + # server: wait this long since job schedule time before starting to check dead/disconnected clients + DEAD_CLIENT_CHECK_LEAD_TIME = "dead_client_check_lead_time" + class SystemVarName: """ diff --git a/nvflare/apis/impl/controller.py b/nvflare/apis/impl/controller.py index 6091ac23f1..924512f77b 100644 --- a/nvflare/apis/impl/controller.py +++ b/nvflare/apis/impl/controller.py @@ -145,3 +145,16 @@ def cancel_task( def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None): self.communicator.cancel_all_tasks(completion_status, fl_ctx) + + def get_client_disconnect_time(self, client_name): + """Get the time when the client is deemed disconnected. + + Args: + client_name: the name of the client + + Returns: time at which the client was deemed disconnected; or None if the client is not disconnected. 
+ + """ + if not self.communicator: + return None + return self.communicator.get_client_disconnect_time(client_name) diff --git a/nvflare/apis/impl/wf_comm_server.py b/nvflare/apis/impl/wf_comm_server.py index 06b5d13457..679efe643d 100644 --- a/nvflare/apis/impl/wf_comm_server.py +++ b/nvflare/apis/impl/wf_comm_server.py @@ -20,7 +20,7 @@ from nvflare.apis.controller_spec import ClientTask, SendOrder, Task, TaskCompletionStatus from nvflare.apis.event_type import EventType from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_constant import ConfigVarName, FLContextKey, SystemConfigs from nvflare.apis.fl_context import FLContext from nvflare.apis.job_def import job_from_meta from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_copy @@ -40,12 +40,6 @@ _TASK_KEY_MANAGER = "___mgr" _TASK_KEY_DONE = "___done" -# wait this long since client death report before treating the client as dead -_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD = "dead_client_grace_period" - -# wait this long since job schedule time before starting to check dead clients -_CONFIG_VAR_DEAD_CLIENT_CHECK_LEAD_TIME = "dead_client_check_lead_time" - def _check_positive_int(name, value): if not isinstance(value, int): @@ -79,6 +73,12 @@ def _get_client_task(target, task: Task): return None +class _DeadClientStatus: + def __init__(self): + self.report_time = time.time() + self.disconnect_time = None + + class WFCommServer(FLComponent, WFCommSpec): def __init__(self, task_check_period=0.2): """Manage life cycles of tasks and their destinations. @@ -93,9 +93,10 @@ def __init__(self, task_check_period=0.2): self._client_task_map = {} # client_task_id => client_task self._all_done = False self._task_lock = Lock() - self._task_monitor = threading.Thread(target=self._monitor_tasks, args=()) + self._task_monitor = threading.Thread(target=self._monitor_tasks, args=(), daemon=True) self._task_check_period = task_check_period - self._dead_client_reports = {} # clients that reported the job is dead on it: name => report time + self._dead_client_grace = 60.0 + self._dead_clients = {} # clients reported dead: name => _DeadClientStatus self._dead_clients_lock = Lock() # need lock since dead_clients can be modified from different threads # make sure check_tasks, process_task_request, process_submission does not interfere with each other self._controller_lock = Lock() @@ -112,13 +113,16 @@ def initialize_run(self, fl_ctx: FLContext): """ engine = fl_ctx.get_engine() if not engine: - self.system_panic(f"Engine not found. {self.__class__.__name__} exiting.", fl_ctx) + self.system_panic(f"Engine not found. {self.name} exiting.", fl_ctx) return self._engine = engine + self._dead_client_grace = ConfigService.get_float_var( + name=ConfigVarName.DEAD_CLIENT_GRACE_PERIOD, conf=SystemConfigs.APPLICATION_CONF, default=60.0 + ) self._task_monitor.start() - def _try_again(self) -> Tuple[str, str, Shareable]: + def _try_again(self) -> Tuple[str, str, Optional[Shareable]]: # TODO: how to tell client no shareable available now? 
return "", "", None @@ -135,7 +139,7 @@ def _set_stats(self, fl_ctx: FLContext): "collector must be an instance of GroupInfoCollector, but got {}".format(type(collector)) ) collector.add_info( - group_name=self.controller._name, + group_name=self.name, info={ "tasks": {t.name: [ct.client.name for ct in t.client_tasks] for t in self._tasks}, }, @@ -150,12 +154,15 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): """ if event_type == InfoCollector.EVENT_TYPE_GET_STATS: self._set_stats(fl_ctx) - elif event_type == EventType.JOB_DEAD: - client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME) - with self._dead_clients_lock: - self.log_info(fl_ctx, f"received dead job report from client {client_name}") - if not self._dead_client_reports.get(client_name): - self._dead_client_reports[client_name] = time.time() + + def process_dead_client_report(self, client_name: str, fl_ctx: FLContext): + with self._dead_clients_lock: + self.log_warning(fl_ctx, f"received dead job report for client {client_name}") + if not self._dead_clients.get(client_name): + self.log_warning(fl_ctx, f"client {client_name} is placed on dead client watch list") + self._dead_clients[client_name] = _DeadClientStatus() + else: + self.log_warning(fl_ctx, f"discarded dead client report {client_name=}: already on watch list") def process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[str, str, Shareable]: """Called by runner when a client asks for a task. @@ -183,9 +190,6 @@ def _do_process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[s if not isinstance(client, Client): raise TypeError("client must be an instance of Client, but got {}".format(type(client))) - with self._dead_clients_lock: - self._dead_client_reports.pop(client.name, None) - if not isinstance(fl_ctx, FLContext): raise TypeError("fl_ctx must be an instance of FLContext, but got {}".format(type(fl_ctx))) @@ -371,13 +375,6 @@ def _do_process_submission( if not isinstance(client, Client): raise TypeError("client must be an instance of Client, but got {}".format(type(client))) - # reset the dead job report! - # note that due to potential race conditions, a client may fail to include the job id in its - # heartbeat (since the job hasn't started at the time of heartbeat report), but then includes - # the job ID later. - with self._dead_clients_lock: - self._dead_client_reports.pop(client.name, None) - if not isinstance(fl_ctx, FLContext): raise TypeError("fl_ctx must be an instance of FLContext, but got {}".format(type(fl_ctx))) if not isinstance(result, Shareable): @@ -490,18 +487,23 @@ def broadcast( ): """Schedule a broadcast task. This is a non-blocking call. - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients. + The task is scheduled into a task list. + Clients can request tasks and controller will dispatch the task to eligible clients. Args: task (Task): the task to be scheduled fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - min_responses (int, optional): the condition to mark this task as completed because enough clients respond with submission. Defaults to 1. - wait_time_after_min_received (int, optional): a grace period for late clients to contribute their submission. 0 means no grace period. 
+ targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names + or None (all clients). Defaults to None. + min_responses (int, optional): the condition to mark this task as completed because enough clients + respond with submission. Defaults to 1. + wait_time_after_min_received (int, optional): a grace period for late clients to contribute their + submission. 0 means no grace period. Submission of late clients in the grace period are still collected as valid submission. Defaults to 0. Raises: - ValueError: min_responses is greater than the length of targets since this condition will make the task, if allowed to be scheduled, never exit. + ValueError: min_responses is greater than the length of targets since this condition will make the task, + if allowed to be scheduled, never exit. """ _check_inputs(task=task, fl_ctx=fl_ctx, targets=targets) _check_positive_int("min_responses", min_responses) @@ -527,16 +529,21 @@ def broadcast_and_wait( ): """Schedule a broadcast task. This is a blocking call. - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients. + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task + to eligible clients. Args: task (Task): the task to be scheduled fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - min_responses (int, optional): the condition to mark this task as completed because enough clients respond with submission. Defaults to 1. - wait_time_after_min_received (int, optional): a grace period for late clients to contribute their submission. 0 means no grace period. + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names + or None (all clients). Defaults to None. + min_responses (int, optional): the condition to mark this task as completed because enough clients + respond with submission. Defaults to 1. + wait_time_after_min_received (int, optional): a grace period for late clients to contribute their + submission. 0 means no grace period. Submission of late clients in the grace period are still collected as valid submission. Defaults to 0. - abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs this method to return. Defaults to None. + abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs + this method to return. Defaults to None. """ self.broadcast( task=task, @@ -550,13 +557,16 @@ def broadcast_and_wait( def broadcast_forever(self, task: Task, fl_ctx: FLContext, targets: Union[List[Client], List[str], None] = None): """Schedule a broadcast task. This is a non-blocking call. - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients. + The task is scheduled into a task list. Clients can request tasks and controller will dispatch + the task to eligible clients. + This broadcast will not end. Args: task (Task): the task to be scheduled fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. 
+ targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names + or None (all clients). Defaults to None. """ _check_inputs(task=task, fl_ctx=fl_ctx, targets=targets) manager = BcastForeverTaskManager() @@ -572,14 +582,18 @@ def send( ): """Schedule a single task to targets. This is a non-blocking call. - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task + to eligible clients based on the send_order. Args: task (Task): the task to be scheduled fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - send_order (SendOrder, optional): the order for clients to become eligible. SEQUENTIAL means the order in targets is enforced. ANY means - clients in targets and haven't received task are eligible for task. Defaults to SendOrder.SEQUENTIAL. + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names + or None (all clients). Defaults to None. + send_order (SendOrder, optional): the order for clients to become eligible. + SEQUENTIAL means the order in targets is enforced. + ANY means clients in targets and haven't received task are eligible for task. + Defaults to SendOrder.SEQUENTIAL. task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. Raises: @@ -625,16 +639,21 @@ def send_and_wait( ): """Schedule a single task to targets. This is a blocking call. - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task + to eligible clients based on the send_order. Args: task (Task): the task to be scheduled fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - send_order (SendOrder, optional): the order for clients to become eligible. SEQUENTIAL means the order in targets is enforced. ANY means - clients in targets and haven't received task are eligible for task. Defaults to SendOrder.SEQUENTIAL. + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names + or None (all clients). Defaults to None. + send_order (SendOrder, optional): the order for clients to become eligible. + SEQUENTIAL means the order in targets is enforced. + ANY means clients in targets and haven't received task are eligible for task. + Defaults to SendOrder.SEQUENTIAL. task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. - abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs this method to return. Defaults to None. + abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs this + method to return. Defaults to None. """ self.send( @@ -663,11 +682,13 @@ def cancel_task( note:: - We only mark the task as completed and leave it to the task monitor to clean up. This is to avoid potential deadlock of task_lock. 
+ We only mark the task as completed and leave it to the task monitor to clean up. + This is to avoid potential deadlock of task_lock. Args: task (Task): the task to be cancelled - completion_status (str, optional): the completion status for this cancellation. Defaults to TaskCompletionStatus.CANCELLED. + completion_status (str, optional): the completion status for this cancellation. + Defaults to TaskCompletionStatus.CANCELLED. fl_ctx (Optional[FLContext], optional): FLContext associated with this cancellation. Defaults to None. """ task.completion_status = completion_status @@ -676,7 +697,8 @@ def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ """Cancel all standing tasks in this controller. Args: - completion_status (str, optional): the completion status for this cancellation. Defaults to TaskCompletionStatus.CANCELLED. + completion_status (str, optional): the completion status for this cancellation. + Defaults to TaskCompletionStatus.CANCELLED. fl_ctx (Optional[FLContext], optional): FLContext associated with this cancellation. Defaults to None. """ with self._task_lock: @@ -695,11 +717,6 @@ def finalize_run(self, fl_ctx: FLContext): """ self.cancel_all_tasks() # unconditionally cancel all tasks self._all_done = True - try: - if self._task_monitor.is_alive(): - self._task_monitor.join() - except RuntimeError: - self.log_debug(fl_ctx, "unable to join monitor thread (not started?)") def relay( self, @@ -713,18 +730,23 @@ def relay( ): """Schedule a single task to targets in one-after-another style. This is a non-blocking call. - The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task + to eligible clients based on the send_order. Args: task (Task): the task to be scheduled fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names + or None (all clients). Defaults to None. send_order (SendOrder, optional): the order for clients to become eligible. SEQUENTIAL means the order in targets is enforced. - ANY means any clients that are inside the targets and haven't received the task are eligible. Defaults to SendOrder.SEQUENTIAL. + ANY means any clients that are inside the targets and haven't received the task are eligible. + Defaults to SendOrder.SEQUENTIAL. task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. - task_result_timeout (int, optional): how long to wait for current working client to reply its result. Defaults to 0. - dynamic_targets (bool, optional): allow clients not in targets to join at the end of targets list. Defaults to True. + task_result_timeout (int, optional): how long to wait for current working client to reply its result. + Defaults to 0. + dynamic_targets (bool, optional): allow clients not in targets to join at the end of targets list. + Defaults to True. Raises: ValueError: when task_assignment_timeout is greater than task's timeout @@ -792,18 +814,25 @@ def relay_and_wait( ): """Schedule a single task to targets in one-after-another style. This is a blocking call. - The task is scheduled into a task list. 
Clients can request tasks and controller will dispatch the task to eligible clients based on the send_order. + The task is scheduled into a task list. Clients can request tasks and controller will dispatch the task + to eligible clients based on the send_order. Args: task (Task): the task to be scheduled fl_ctx (FLContext): FLContext associated with this task - targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names or None (all clients). Defaults to None. - send_order (SendOrder, optional): the order for clients to become eligible. SEQUENTIAL means the order in targets is enforced. ANY means - clients in targets and haven't received task are eligible for task. Defaults to SendOrder.SEQUENTIAL. + targets (Union[List[Client], List[str], None], optional): the list of eligible clients or client names + or None (all clients). Defaults to None. + send_order (SendOrder, optional): the order for clients to become eligible. + SEQUENTIAL means the order in targets is enforced. + ANY means clients in targets and haven't received task are eligible for task. + Defaults to SendOrder.SEQUENTIAL. task_assignment_timeout (int, optional): how long to wait for one client to pick the task. Defaults to 0. - task_result_timeout (int, optional): how long to wait for current working client to reply its result. Defaults to 0. - dynamic_targets (bool, optional): allow clients not in targets to join at the end of targets list. Defaults to True. - abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs this method to return. Defaults to None. + task_result_timeout (int, optional): how long to wait for current working client to reply its result. + Defaults to 0. + dynamic_targets (bool, optional): allow clients not in targets to join at the end of targets list. + Defaults to True. + abort_signal (Optional[Signal], optional): as this is a blocking call, this abort_signal informs + this method to return. Defaults to None. 
""" self.relay( task=task, @@ -816,8 +845,33 @@ def relay_and_wait( ) self.wait_for_task(task, abort_signal) + def _check_dead_clients(self): + if not self._dead_clients: + return + + now = time.time() + with self._dead_clients_lock: + for client_name, status in self._dead_clients.items(): + if status.disconnect_time: + # already disconnected + continue + + if now - status.report_time < self._dead_client_grace: + # this report is still fresh - consider the client to be still alive + continue + + # consider client disconnected + status.disconnect_time = now + self.logger.error(f"Client {client_name} is deemed disconnected!") + with self._engine.new_context() as fl_ctx: + fl_ctx.set_prop(FLContextKey.DISCONNECTED_CLIENT_NAME, client_name) + self.fire_event(EventType.CLIENT_DISCONNECTED, fl_ctx) + def _monitor_tasks(self): while not self._all_done: + # determine clients are still active or not + self._check_dead_clients() + should_abort_job = self._job_policy_violated() if not should_abort_job: self.check_tasks() @@ -907,29 +961,30 @@ def _get_task_dead_clients(self, task: Task): See whether the task is only waiting for response from a dead client """ now = time.time() - lead_time = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_CHECK_LEAD_TIME, default=30.0) + lead_time = ConfigService.get_float_var( + name=ConfigVarName.DEAD_CLIENT_CHECK_LEAD_TIME, conf=SystemConfigs.APPLICATION_CONF, default=30.0 + ) if now - task.schedule_time < lead_time: # due to potential race conditions, we'll wait for at least 1 minute after the task # is started before checking dead clients. return None dead_clients = [] - with self._dead_clients_lock: - for target in task.targets: - ct = _get_client_task(target, task) - if ct is not None and ct.result_received_time: - # response has been received from this client - continue - - # either we have not sent the task to this client or we have not received response - # is the client already dead? - if self._client_still_alive(target): - # this client is still alive - # we let the task continue its course since we still have live clients - return None - else: - # this client is dead - remember it - dead_clients.append(target) + for target in task.targets: + ct = _get_client_task(target, task) + if ct is not None and ct.result_received_time: + # response has been received from this client + continue + + # either we have not sent the task to this client or we have not received response + # is the client already dead? 
+ if self.get_client_disconnect_time(target): + # this client is dead - remember it + dead_clients.append(target) + else: + # this client is still alive + # we let the task continue its course since we still have live clients + return None return dead_clients @@ -964,47 +1019,57 @@ def _job_policy_violated(self): with self._engine.new_context() as fl_ctx: clients = self._engine.get_clients() - with self._dead_clients_lock: - alive_clients = [] - dead_clients = [] - - for client in clients: - if self._client_still_alive(client.name): - alive_clients.append(client.name) - else: - dead_clients.append(client.name) - - if not dead_clients: - return False - - if not alive_clients: - self.log_error(fl_ctx, f"All clients are dead: {dead_clients}") - return True + alive_clients = [] + dead_clients = [] - job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) - job = job_from_meta(job_meta) - if len(alive_clients) < job.min_sites: - self.log_error(fl_ctx, f"Alive clients {len(alive_clients)} < required min {job.min_sites}") + for client in clients: + if self.get_client_disconnect_time(client.name): + dead_clients.append(client.name) + else: + alive_clients.append(client.name) + + if not dead_clients: + return False + + if not alive_clients: + self.log_error(fl_ctx, f"All clients are dead: {dead_clients}") + return True + + job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) + job = job_from_meta(job_meta) + if len(alive_clients) < job.min_sites: + self.log_error(fl_ctx, f"Alive clients {len(alive_clients)} < required min {job.min_sites}") + return True + + # check required clients: + if dead_clients and job.required_sites: + dead_required_clients = [c for c in dead_clients if c in job.required_sites] + if dead_required_clients: + self.log_error(fl_ctx, f"Required client(s) dead: {dead_required_clients}") return True - - # check required clients: - if dead_clients and job.required_sites: - dead_required_clients = [c for c in dead_clients if c in job.required_sites] - if dead_required_clients: - self.log_error(fl_ctx, f"Required client(s) dead: {dead_required_clients}") - return True return False - def _client_still_alive(self, client_name): - now = time.time() - report_time = self._dead_client_reports.get(client_name, None) - grace_period = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD, default=30.0) + def client_is_active(self, client_name: str, reason: str, fl_ctx: FLContext): + with self._dead_clients_lock: + self.log_debug(fl_ctx, f"client {client_name} is active: {reason}") + if client_name in self._dead_clients: + self.log_info(fl_ctx, f"Client {client_name} is removed from watch list: {reason}") + status = self._dead_clients.pop(client_name) + if status.disconnect_time: + self.log_info(fl_ctx, f"Client {client_name} is reconnected") + fl_ctx.set_prop(FLContextKey.RECONNECTED_CLIENT_NAME, client_name) + self.fire_event(EventType.CLIENT_RECONNECTED, fl_ctx) + + def get_client_disconnect_time(self, client_name: str): + """Get the time that the client was deemed disconnected - if not report_time: - # this client is still alive - return True - elif now - report_time < grace_period: - # this report is still fresh - consider the client to be still alive - return True + Args: + client_name: name of the client - return False + Returns: time at which the client was deemed disconnected; or None if the client is not disconnected + + """ + status = self._dead_clients.get(client_name) + if status: + return status.disconnect_time + return None diff --git a/nvflare/apis/wf_comm_spec.py 
b/nvflare/apis/wf_comm_spec.py index a32c34948a..0ac504f9a1 100644 --- a/nvflare/apis/wf_comm_spec.py +++ b/nvflare/apis/wf_comm_spec.py @@ -284,6 +284,37 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul """ raise NotImplementedError + def get_client_disconnect_time(self, client_name): + """Get the time when the client was deemed disconnected. + + Args: + client_name: the name of the client + + Returns: time at which the client was deemed disconnected; or None if the client is not disconnected. + + """ + raise NotImplementedError + + def process_dead_client_report(self, client_name: str, fl_ctx: FLContext): + """Called by the Engine to process a dead client report. + + Args: + client_name: name of the client for which the dead report was received + fl_ctx: the FLContext + + """ + raise NotImplementedError + + def client_is_active(self, client_name: str, reason: str, fl_ctx: FLContext): + """Called by the Engine to notify us that the client is active. + + Args: + client_name: name of the client that is active + reason: why the client is considered active + fl_ctx: the FLContext + """ + raise NotImplementedError + def process_task_check(self, task_id: str, fl_ctx: FLContext): """Called by the Engine to check whether a specified task still exists. Args: diff --git a/nvflare/app_common/workflows/cyclic_ctl.py b/nvflare/app_common/workflows/cyclic_ctl.py index 442aaa89be..034a1a7d11 100644 --- a/nvflare/app_common/workflows/cyclic_ctl.py +++ b/nvflare/app_common/workflows/cyclic_ctl.py @@ -18,8 +18,7 @@ from nvflare.apis.client import Client from nvflare.apis.controller_spec import ClientTask, Task -from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import FLContextKey, ReturnCode +from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import Shareable @@ -141,20 +140,24 @@ def start_controller(self, fl_ctx: FLContext): self._last_client = None def _get_relay_orders(self, fl_ctx: FLContext) -> Union[List[Client], None]: - if len(self._participating_clients) <= 1: - self.system_panic(f"Not enough client sites ({len(self._participating_clients)}).", fl_ctx) + active_clients_map = {} + for t in self._participating_clients: + if not self.get_client_disconnect_time(t.name): + active_clients_map[t.name] = t + + if len(active_clients_map) <= 1: + self.system_panic(f"Not enough active client sites ({len(active_clients_map)}).", fl_ctx) return None if isinstance(self._order, list): targets = [] - active_clients_map = {t.name: t for t in self._participating_clients} for c_name in self._order: if c_name not in active_clients_map: self.system_panic(f"Required client site ({c_name}) is not in active clients.", fl_ctx) return None targets.append(active_clients_map[c_name]) else: - targets = list(self._participating_clients) + targets = list(active_clients_map.values()) if self._order == RelayOrder.RANDOM or self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW: random.shuffle(targets) if self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW and self._last_client == targets[0]: @@ -310,12 +313,3 @@ def restore(self, state_data: dict, fl_ctx: FLContext): self._start_round = self._current_round finally: pass - - def handle_event(self, event_type, fl_ctx): - if event_type == EventType.JOB_DEAD: - client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME) - new_client_list = [] - for client in self._participating_clients: - if client_name != 
client.name: - new_client_list.append(client) - self._participating_clients = new_client_list diff --git a/nvflare/private/fed/client/client_runner.py b/nvflare/private/fed/client/client_runner.py index b57f2f68fd..365aefef3f 100644 --- a/nvflare/private/fed/client/client_runner.py +++ b/nvflare/private/fed/client/client_runner.py @@ -145,7 +145,7 @@ def __init__( self.task_check_timeout = self.get_positive_float_var(ConfigVarName.TASK_CHECK_TIMEOUT, 5.0) self.task_check_interval = self.get_positive_float_var(ConfigVarName.TASK_CHECK_INTERVAL, 5.0) - self.job_heartbeat_interval = self.get_positive_float_var(ConfigVarName.JOB_HEARTBEAT_INTERVAL, 30.0) + self.job_heartbeat_interval = self.get_positive_float_var(ConfigVarName.JOB_HEARTBEAT_INTERVAL, 10.0) self.get_task_timeout = self.get_positive_float_var(ConfigVarName.GET_TASK_TIMEOUT, None) self.submit_task_result_timeout = self.get_positive_float_var(ConfigVarName.SUBMIT_TASK_RESULT_TIMEOUT, None) self._register_aux_message_handlers(engine) diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index ae50e4c3a6..2f6334b470 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -36,7 +36,6 @@ ) from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import NotAuthenticated -from nvflare.apis.shareable import Shareable from nvflare.apis.workspace import Workspace from nvflare.fuel.common.exit_codes import ProcessExitCode from nvflare.fuel.f3.cellnet.cell import Cell @@ -140,13 +139,8 @@ def close(self): self.lock.release() except RuntimeError: self.logger.info("canceling sync locks") - try: - # if self.cell: - # self.cell.stop() - pass - finally: - self.logger.info("server off") - return 0 + self.logger.info("server off") + return 0 def deploy(self, args, grpc_args=None, secure_train=False): """Start a grpc server and listening the designated port.""" @@ -624,26 +618,17 @@ def _sync_client_jobs(self, request, client_token): # this is a dict: token => nvflare.apis.client.Client client = participating_clients.get(client_token, None) if client: - self._notify_dead_job(client, job_id) + self._notify_dead_job(client, job_id, "missing job on client") return jobs_need_abort - def _notify_dead_job(self, client, job_id: str): + def _notify_dead_job(self, client, job_id: str, reason: str): try: - with self.engine.lock: - shareable = Shareable() - shareable.set_header(ServerCommandKey.FL_CLIENT, client.name) - fqcn = FQCN.join([FQCN.ROOT_SERVER, job_id]) - request = new_cell_message({}, shareable) - self.cell.fire_and_forget( - targets=fqcn, - channel=CellChannel.SERVER_COMMAND, - topic=ServerCommandNames.HANDLE_DEAD_JOB, - message=request, - optional=True, - ) - except Exception: - self.logger.info("Could not connect to server runner process") + self.engine.notify_dead_job(job_id, client.name, reason) + except Exception as ex: + self.logger.info( + f"Failed to notify_dead_job to runner process of job {job_id}: {secure_format_exception(ex)}" + ) def notify_dead_client(self, client): """Called to do further processing of the dead client @@ -662,7 +647,7 @@ def notify_dead_client(self, client): assert isinstance(process_info, dict) participating_clients = process_info.get(RunProcessKey.PARTICIPANTS, None) if participating_clients and client.token in participating_clients: - self._notify_dead_job(client, job_id) + self._notify_dead_job(client, job_id, "client dead") def start_run(self, job_id, run_root, conf, args, snapshot): # Create the FL 
Engine diff --git a/nvflare/private/fed/server/server_commands.py b/nvflare/private/fed/server/server_commands.py index 37ca856ef8..a4456bc5ce 100644 --- a/nvflare/private/fed/server/server_commands.py +++ b/nvflare/private/fed/server/server_commands.py @@ -263,6 +263,8 @@ def process(self, data: Shareable, fl_ctx: FLContext): """ client_name = data.get_header(ServerCommandKey.FL_CLIENT) + reason = data.get_header(ServerCommandKey.REASON) + self.logger.warning(f"received dead job notification: {reason=}") server_runner = fl_ctx.get_prop(FLContextKey.RUNNER) if server_runner: server_runner.handle_dead_job(client_name, fl_ctx) diff --git a/nvflare/private/fed/server/server_engine.py b/nvflare/private/fed/server/server_engine.py index cc0c3a0484..60f28bb758 100644 --- a/nvflare/private/fed/server/server_engine.py +++ b/nvflare/private/fed/server/server_engine.py @@ -577,13 +577,26 @@ def update_job_run_status(self): data = {"execution_error": execution_error} job_id = fl_ctx.get_job_id() request = new_cell_message({CellMessageHeaderKeys.JOB_ID: job_id}, data) - return_data = self.server.cell.fire_and_forget( + self.server.cell.fire_and_forget( targets=FQCN.ROOT_SERVER, channel=CellChannel.SERVER_PARENT_LISTENER, topic=ServerCommandNames.UPDATE_RUN_STATUS, message=request, ) + def notify_dead_job(self, job_id: str, client_name: str, reason: str): + shareable = Shareable() + shareable.set_header(ServerCommandKey.FL_CLIENT, client_name) + shareable.set_header(ServerCommandKey.REASON, reason) + self.send_command_to_child_runner_process( + job_id=job_id, + command_name=ServerCommandNames.HANDLE_DEAD_JOB, + command_data=shareable, + timeout=0.0, + optional=True, + ) + self.logger.warning(f"notified SJ of dead-job: {job_id=}; {client_name=}; {reason=}") + def send_command_to_child_runner_process( self, job_id: str, command_name: str, command_data, timeout=5.0, optional=False ): @@ -595,7 +608,7 @@ def send_command_to_child_runner_process( targets=fqcn, channel=CellChannel.SERVER_COMMAND, topic=command_name, - request=request, + message=request, optional=optional, ) return None diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py index 9fbceb8880..4831cacfe8 100644 --- a/nvflare/private/fed/server/server_runner.py +++ b/nvflare/private/fed/server/server_runner.py @@ -116,6 +116,7 @@ def _register_aux_message_handler(self, engine): def _handle_sync_runner(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: # simply ack + self._report_client_active("syncRunner", fl_ctx) return make_reply(ReturnCode.OK) def _execute_run(self): @@ -278,6 +279,8 @@ def process_task_request(self, client: Client, fl_ctx: FLContext) -> (str, str, self.log_debug(fl_ctx, "invalid task request: no peer context - asked client to try again later") return self._task_try_again() + self._report_client_active("getTask", fl_ctx) + peer_job_id = peer_ctx.get_job_id() if not peer_job_id or peer_job_id != self.job_id: # the client is in a different RUN @@ -383,9 +386,8 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext): if self.current_wf is None: return - fl_ctx.set_prop(FLContextKey.DEAD_JOB_CLIENT_NAME, client_name) - self.log_debug(fl_ctx, "firing event EventType.JOB_DEAD") - self.fire_event(EventType.JOB_DEAD, fl_ctx) + if self.current_wf.controller: + self.current_wf.controller.communicator.process_dead_client_report(client_name, fl_ctx) except Exception as e: self.log_exception( @@ -408,6 +410,7 @@ def process_submission(self, client: Client, 
task_name: str, task_id: str, resul fl_ctx: FLContext """ self.log_info(fl_ctx, f"got result from client {client.name} for task: name={task_name}, id={task_id}") + self._report_client_active("submitTaskResult", fl_ctx) if not isinstance(result, Shareable): self.log_error(fl_ctx, "invalid result submission: must be Shareable but got {}".format(type(result))) @@ -503,11 +506,21 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul "Error processing client result by {}: {}".format(self.current_wf.id, secure_format_exception(e)), ) + def _report_client_active(self, reason: str, fl_ctx: FLContext): + with self.wf_lock: + if self.current_wf and self.current_wf.controller: + peer_ctx = fl_ctx.get_peer_context() + assert isinstance(peer_ctx, FLContext) + client_name = peer_ctx.get_identity_name() + self.current_wf.controller.communicator.client_is_active(client_name, reason, fl_ctx) + def _handle_job_heartbeat(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: self.log_debug(fl_ctx, "received client job_heartbeat") + self._report_client_active("jobHeartbeat", fl_ctx) return make_reply(ReturnCode.OK) def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: + self._report_client_active("taskCheck", fl_ctx) task_id = request.get_header(ReservedHeaderKey.TASK_ID) if not task_id: self.log_error(fl_ctx, f"missing {ReservedHeaderKey.TASK_ID} in task_check request") diff --git a/nvflare/private/fed/server/sys_cmd.py b/nvflare/private/fed/server/sys_cmd.py index c684e08073..6fb89266ae 100644 --- a/nvflare/private/fed/server/sys_cmd.py +++ b/nvflare/private/fed/server/sys_cmd.py @@ -76,6 +76,14 @@ def get_spec(self): authz_func=self.authorize_client_operation, visible=True, ), + CommandSpec( + name="dead", + description="send dead client msg to SJ", + usage="dead ", + handler_func=self.dead_client, + authz_func=self.must_be_project_admin, + visible=False, + ), ], ) @@ -175,3 +183,13 @@ def report_env(self, conn: Connection, args: List[str]): table = conn.append_table(["Sites", "Env"], name=MetaKey.CLIENTS) for k, v in site_resources.items(): table.add_row([str(k), str(v)], meta=v) + + def dead_client(self, conn: Connection, args: List[str]): + if len(args) != 3: + conn.append_error(f"Usage: {args[0]} client_name job_id") + return + client_name = args[1] + job_id = args[2] + engine = conn.app_ctx + engine.notify_dead_job(job_id, client_name, f"AdminCommand: {args[0]}") + conn.append_string(f"called notify_dead_job for client {client_name=} {job_id=}")
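
Usage note (illustrative sketch, not part of the patch): a workflow component can react to the new CLIENT_DISCONNECTED / CLIENT_RECONNECTED events that replace JOB_DEAD. The event types, FLContext keys, and FLComponent APIs below are the ones introduced or used in this change; the component itself is hypothetical.

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext


class ClientConnectivityLogger(FLComponent):
    # Hypothetical component: logs connectivity changes reported by the server-side
    # WFCommServer, which fires these events from its task monitor thread.

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        if event_type == EventType.CLIENT_DISCONNECTED:
            # set via FLContextKey.DISCONNECTED_CLIENT_NAME before the event is fired
            name = fl_ctx.get_prop(FLContextKey.DISCONNECTED_CLIENT_NAME)
            self.log_warning(fl_ctx, f"client {name} is deemed disconnected")
        elif event_type == EventType.CLIENT_RECONNECTED:
            # set via FLContextKey.RECONNECTED_CLIENT_NAME when a watched client reports back
            name = fl_ctx.get_prop(FLContextKey.RECONNECTED_CLIENT_NAME)
            self.log_info(fl_ctx, f"client {name} has reconnected")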
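
A second hedged sketch: a controller can use the new get_client_disconnect_time() API to drop disconnected clients from a target list before scheduling a task, mirroring the updated CyclicController._get_relay_orders(). The helper function and its name are hypothetical; the Controller API and Client.name attribute come from this change and existing code.

from typing import List

from nvflare.apis.client import Client
from nvflare.apis.impl.controller import Controller


def active_targets(controller: Controller, clients: List[Client]) -> List[Client]:
    # keep only clients for which the communicator reports no disconnect time
    return [c for c in clients if not controller.get_client_disconnect_time(c.name)]

For example, a controller could filter its participating clients through a helper like this before calling broadcast() or relay().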