diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py
index a9e48affa0..8541954ddd 100644
--- a/nvflare/apis/fl_constant.py
+++ b/nvflare/apis/fl_constant.py
@@ -39,6 +39,7 @@ class ReturnCode(object):
     VALIDATE_TYPE_UNKNOWN = "VALIDATE_TYPE_UNKNOWN"
     EMPTY_RESULT = "EMPTY_RESULT"
     UNSAFE_JOB = "UNSAFE_JOB"
+    EARLY_TERMINATION = "EARLY_TERMINATION"

     SERVER_NOT_READY = "SERVER_NOT_READY"
     SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
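The new EARLY_TERMINATION return code is what a client sends when it wants the server-side workflow to end early. A minimal client-side sketch, assuming the standard Executor API; the class name and helper methods below are illustrative only and not part of this change:

from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal


class EarlyStoppingExecutor(Executor):
    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        # Hypothetical stopping criterion; the reply is honored by CyclicController
        # only when it is configured with allow_early_termination=True.
        if self._should_stop():
            reply = Shareable()
            reply.set_return_code(ReturnCode.EARLY_TERMINATION)
            return reply
        return self._do_local_training(shareable, fl_ctx, abort_signal)

    def _should_stop(self) -> bool:
        # Illustrative placeholder, e.g. a converged local metric.
        raise NotImplementedError

    def _do_local_training(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        raise NotImplementedError
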
diff --git a/nvflare/app_common/workflows/cyclic_ctl.py b/nvflare/app_common/workflows/cyclic_ctl.py
index 7fbb59096c..442aaa89be 100644
--- a/nvflare/app_common/workflows/cyclic_ctl.py
+++ b/nvflare/app_common/workflows/cyclic_ctl.py
@@ -14,6 +14,7 @@

 import gc
 import random
+from typing import List, Union

 from nvflare.apis.client import Client
 from nvflare.apis.controller_spec import ClientTask, Task
@@ -50,7 +51,8 @@ def __init__(
         task_check_period: float = 0.5,
         persist_every_n_rounds: int = 1,
         snapshot_every_n_rounds: int = 1,
-        order: str = RelayOrder.FIXED,
+        order: Union[str, List[str]] = RelayOrder.FIXED,
+        allow_early_termination=False,
     ):
         """A sample implementation to demonstrate how to use relay method for Cyclic Federated Learning.

@@ -67,11 +69,13 @@ def __init__(
                 If n is 0 then no persist.
             snapshot_every_n_rounds (int, optional): persist the server state every n rounds. Defaults to 1.
                 If n is 0 then no persist.
-            order (str, optional): the order of relay.
-                If FIXED means the same order for every round.
-                If RANDOM means random order for every round.
-                If RANDOM_WITHOUT_SAME_IN_A_ROW means every round the order gets shuffled but a client will never be
-                run twice in a row (in different round).
+            order (Union[str, List[str]], optional): The order of relay.
+                - If a string is provided:
+                    - "FIXED": Same order for every round.
+                    - "RANDOM": Random order for every round.
+                    - "RANDOM_WITHOUT_SAME_IN_A_ROW": Shuffled order, no repetition in consecutive rounds.
+                - If a list of strings is provided, it represents a custom order for relay.
+            allow_early_termination: whether to allow early workflow termination from clients

         Raises:
             TypeError: when any of input arguments does not have correct type
@@ -90,13 +94,14 @@ def __init__(
         if not isinstance(task_name, str):
             raise TypeError("task_name must be a string but got {}".format(type(task_name)))

-        if order not in SUPPORTED_ORDERS:
-            raise ValueError(f"order must be in {SUPPORTED_ORDERS}")
+        if order not in SUPPORTED_ORDERS and not isinstance(order, list):
+            raise ValueError(f"order must be in {SUPPORTED_ORDERS} or a list")

         self._num_rounds = num_rounds
         self._start_round = 0
         self._end_round = self._start_round + self._num_rounds
         self._current_round = 0
+        self._is_done = False
         self._last_learnable = None
         self.persistor_id = persistor_id
         self.shareable_generator_id = shareable_generator_id
@@ -109,6 +114,7 @@ def __init__(
         self._participating_clients = None
         self._last_client = None
         self._order = order
+        self._allow_early_termination = allow_early_termination

     def start_controller(self, fl_ctx: FLContext):
         self.log_debug(fl_ctx, "starting controller")
@@ -129,46 +135,79 @@ def start_controller(self, fl_ctx: FLContext):
         fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self._num_rounds, private=True, sticky=True)
         self.fire_event(AppEventType.INITIAL_MODEL_LOADED, fl_ctx)

-        self._participating_clients = self._engine.get_clients()
+        self._participating_clients: List[Client] = self._engine.get_clients()
         if len(self._participating_clients) <= 1:
             self.system_panic("Not enough client sites.", fl_ctx)
         self._last_client = None

-    def _get_relay_orders(self, fl_ctx: FLContext):
-        targets = list(self._participating_clients)
-        if len(targets) <= 1:
-            self.system_panic("Not enough client sites.", fl_ctx)
-        if self._order == RelayOrder.RANDOM:
-            random.shuffle(targets)
-        elif self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW:
-            random.shuffle(targets)
-            if self._last_client == targets[0]:
-                targets = targets.append(targets.pop(0))
+    def _get_relay_orders(self, fl_ctx: FLContext) -> Union[List[Client], None]:
+        if len(self._participating_clients) <= 1:
+            self.system_panic(f"Not enough client sites ({len(self._participating_clients)}).", fl_ctx)
+            return None
+
+        if isinstance(self._order, list):
+            targets = []
+            active_clients_map = {t.name: t for t in self._participating_clients}
+            for c_name in self._order:
+                if c_name not in active_clients_map:
+                    self.system_panic(f"Required client site ({c_name}) is not in active clients.", fl_ctx)
+                    return None
+                targets.append(active_clients_map[c_name])
+        else:
+            targets = list(self._participating_clients)
+            if self._order == RelayOrder.RANDOM or self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW:
+                random.shuffle(targets)
+                if self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW and self._last_client == targets[0]:
+                    targets.append(targets.pop(0))

         self._last_client = targets[-1]
         return targets

-    def _process_result(self, client_task: ClientTask, fl_ctx: FLContext):
-        result = client_task.result
-        rc = result.get_return_code()
-        client_name = client_task.client.name
-
-        # Raise errors if ReturnCode is not OK.
-        if rc and rc != ReturnCode.OK:
-            self.system_panic(
-                f"Result from {client_name} is bad, error code: {rc}. "
-                f"{self.__class__.__name__} exiting at round {self._current_round}.",
-                fl_ctx=fl_ctx,
-            )
-            return False
+    def _stop_workflow(self, task: Task):
+        self.cancel_task(task)
+        self._is_done = True

+    def _process_result(self, client_task: ClientTask, fl_ctx: FLContext):
         # submitted shareable is stored in client_task.result
         # we need to update task.data with that shareable so the next target
         # will get the updated shareable
         task = client_task.task
-        # update the global learnable with the received result (shareable)
-        # e.g. the received result could be weight_diffs, the learnable could be full weights.
-        self._last_learnable = self.shareable_generator.shareable_to_learnable(client_task.result, fl_ctx)
+        result = client_task.result
+        if isinstance(result, Shareable):
+            # update the global learnable with the received result (shareable)
+            # e.g. the received result could be weight_diffs, the learnable could be full weights.
+            rc = result.get_return_code()
+            try:
+                self._last_learnable = self.shareable_generator.shareable_to_learnable(result, fl_ctx)
+            except Exception as ex:
+                if rc != ReturnCode.EARLY_TERMINATION:
+                    self._stop_workflow(task)
+                    self.log_error(fl_ctx, f"exception {secure_format_exception(ex)} from shareable_to_learnable")
+                    return
+                else:
+                    self.log_warning(
+                        fl_ctx,
+                        f"ignored {secure_format_exception(ex)} from shareable_to_learnable in early termination",
+                    )
+
+            if rc == ReturnCode.EARLY_TERMINATION:
+                if self._allow_early_termination:
+                    # the workflow is done
+                    self._stop_workflow(task)
+                    self.log_info(fl_ctx, f"Stopping workflow due to {rc} from client {client_task.client.name}")
+                    return
+                else:
+                    self.log_warning(
+                        fl_ctx,
+                        f"Ignored {rc} from client {client_task.client.name} because early termination is not allowed",
+                    )
+        else:
+            self._stop_workflow(task)
+            self.log_error(
+                fl_ctx,
+                f"Stopping workflow due to result from client {client_task.client.name} is not a Shareable",
+            )
+            return

         # prepare task shareable data for next client
         task.data = self.shareable_generator.learnable_to_shareable(self._last_learnable, fl_ctx)
@@ -183,6 +222,9 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
             self.log_debug(fl_ctx, "Cyclic starting.")

             for self._current_round in range(self._start_round, self._end_round):
+                if self._is_done:
+                    return
+
                 if abort_signal.triggered:
                     return

@@ -191,6 +233,8 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):

                 # Task for one cyclic
                 targets = self._get_relay_orders(fl_ctx)
+                if targets is None:
+                    return
                 targets_names = [t.name for t in targets]
                 self.log_debug(fl_ctx, f"Relay on {targets_names}")

diff --git a/tests/unit_test/app_common/workflow/cyclic_ctl_test.py b/tests/unit_test/app_common/workflow/cyclic_ctl_test.py
new file mode 100644
index 0000000000..9f17e310e2
--- /dev/null
+++ b/tests/unit_test/app_common/workflow/cyclic_ctl_test.py
@@ -0,0 +1,134 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import uuid
+from unittest.mock import Mock, patch
+
+import pytest
+
+from nvflare.apis.client import Client
+from nvflare.apis.controller_spec import ClientTask, Task
+from nvflare.apis.fl_constant import ReturnCode
+from nvflare.apis.fl_context import FLContext
+from nvflare.apis.shareable import Shareable
+from nvflare.apis.signal import Signal
+from nvflare.app_common.abstract.learnable import Learnable
+from nvflare.app_common.workflows.cyclic_ctl import CyclicController, RelayOrder
+
+SITE_1_ID = uuid.uuid4()
+SITE_2_ID = uuid.uuid4()
+SITE_3_ID = uuid.uuid4()
+
+ORDER_TEST_CASES = [
+    (
+        RelayOrder.FIXED,
+        [Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
+        [Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
+    ),
+    (
+        ["site-1", "site-2"],
+        [Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
+        [Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
+    ),
+    (
+        ["site-2", "site-1"],
+        [Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
+        [Client("site-2", SITE_2_ID), Client("site-1", SITE_1_ID)],
+    ),
+    (
+        ["site-2", "site-1", "site-3"],
+        [Client("site-3", SITE_3_ID), Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
+        [Client("site-2", SITE_2_ID), Client("site-1", SITE_1_ID), Client("site-3", SITE_3_ID)],
+    ),
+]
+
+
+def gen_shareable(is_early_termination: bool = False, is_not_shareable: bool = False):
+    if is_not_shareable:
+        return [1, 2, 3]
+    return_result = Shareable()
+    if is_early_termination:
+        return_result.set_return_code(ReturnCode.EARLY_TERMINATION)
+    return return_result
+
+
+PROCESS_RESULT_TEST_CASES = [gen_shareable(is_early_termination=True), gen_shareable(is_not_shareable=True)]
+
+
+class TestCyclicController:
+    @pytest.mark.parametrize("order,active_clients,expected_result", ORDER_TEST_CASES)
+    def test_get_relay_orders(self, order, active_clients, expected_result):
+        ctl = CyclicController(order=order)
+        ctx = FLContext()
+        ctl._participating_clients = active_clients
+        targets = ctl._get_relay_orders(ctx)
+        for c, e_c in zip(targets, expected_result):
+            assert c.name == e_c.name
+            assert c.token == e_c.token
+
+    def test_control_flow_call_relay_and_wait(self):
+
+        with patch("nvflare.app_common.workflows.cyclic_ctl.CyclicController.relay_and_wait") as mock_method:
+            ctl = CyclicController(persist_every_n_rounds=0, snapshot_every_n_rounds=0, num_rounds=1)
+            ctl.shareable_generator = Mock()
+            ctl._participating_clients = [
+                Client("site-3", SITE_3_ID),
+                Client("site-1", SITE_1_ID),
+                Client("site-2", SITE_2_ID),
+            ]
+
+            abort_signal = Signal()
+            fl_ctx = FLContext()
+
+            with patch.object(ctl.shareable_generator, "learnable_to_shareable") as mock_method1, patch.object(
+                ctl.shareable_generator, "shareable_to_learnable"
+            ) as mock_method2:
+                mock_method1.return_value = Shareable()
+                mock_method2.return_value = Learnable()
+
+                ctl.control_flow(abort_signal, fl_ctx)
+
+                mock_method.assert_called_once()
+
+    @pytest.mark.parametrize("return_result", PROCESS_RESULT_TEST_CASES)
+    def test_process_result(self, return_result):
+        ctl = CyclicController(
+            persist_every_n_rounds=0, snapshot_every_n_rounds=0, num_rounds=1, allow_early_termination=True
+        )
+        ctl.shareable_generator = Mock()
+        ctl._participating_clients = [
+            Client("site-3", SITE_3_ID),
+            Client("site-1", SITE_1_ID),
+            Client("site-2", SITE_2_ID),
+        ]
+
+        fl_ctx = FLContext()
+        with patch.object(ctl, "cancel_task") as mock_method, patch.object(
+            ctl.shareable_generator, "learnable_to_shareable"
+        ) as mock_method1, patch.object(ctl.shareable_generator, "shareable_to_learnable") as mock_method2:
+            mock_method1.return_value = Shareable()
+            mock_method2.return_value = Learnable()
+
+            client_task = ClientTask(
+                client=Mock(),
+                task=Task(
+                    name="__test_task",
+                    data=Shareable(),
+                ),
+            )
+            client_task.result = return_result
+            ctl._process_result(client_task, fl_ctx)
+            mock_method.assert_called_once()
+            assert ctl._is_done is True
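A possible companion test for the opposite setting, sketched as one more method on the TestCyclicController class above (it reuses that module's imports and the gen_shareable helper): with allow_early_termination left at its default of False, an EARLY_TERMINATION result is only logged as a warning, so cancel_task is never called and _is_done stays False.

    def test_process_result_early_termination_not_allowed(self):
        ctl = CyclicController(persist_every_n_rounds=0, snapshot_every_n_rounds=0, num_rounds=1)
        ctl.shareable_generator = Mock()
        fl_ctx = FLContext()
        with patch.object(ctl, "cancel_task") as mock_cancel, patch.object(
            ctl.shareable_generator, "learnable_to_shareable"
        ) as mock_l2s, patch.object(ctl.shareable_generator, "shareable_to_learnable") as mock_s2l:
            mock_l2s.return_value = Shareable()
            mock_s2l.return_value = Learnable()

            client_task = ClientTask(client=Mock(), task=Task(name="__test_task", data=Shareable()))
            client_task.result = gen_shareable(is_early_termination=True)
            ctl._process_result(client_task, fl_ctx)

            mock_cancel.assert_not_called()
            assert ctl._is_done is False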