diff --git a/nvflare/apis/dxo.py b/nvflare/apis/dxo.py
index 90b9cc5819..368248e722 100644
--- a/nvflare/apis/dxo.py
+++ b/nvflare/apis/dxo.py
@@ -29,6 +29,7 @@ class DataKind(object):
     COLLECTION = "COLLECTION"  # Dict or List of DXO objects
     STATISTICS = "STATISTICS"
     PSI = "PSI"
+    RAW = "RAW"
 
 
 class MetaKey(FLMetaKey):
diff --git a/nvflare/apis/wf_controller.py b/nvflare/apis/wf_controller.py
new file mode 100644
index 0000000000..175ee6827f
--- /dev/null
+++ b/nvflare/apis/wf_controller.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import ABC, abstractmethod
+from typing import Any, Callable, List, Optional
+
+from nvflare.app_common import wf_comm
+
+from .fl_constant import ReturnCode
+
+ABORT_WHEN_IN_ERROR = {
+    ReturnCode.EXECUTION_EXCEPTION: True,
+    ReturnCode.TASK_UNKNOWN: True,
+    ReturnCode.EXECUTION_RESULT_ERROR: False,
+    ReturnCode.TASK_DATA_FILTER_ERROR: True,
+    ReturnCode.TASK_RESULT_FILTER_ERROR: True,
+}
+
+
+class WFController(ABC):
+    def __init__(self):
+        self.communicator = wf_comm.get_wf_comm_api()
+
+    @abstractmethod
+    def run(self):
+        pass
+
+    def broadcast_and_wait(
+        self,
+        task_name: str,
+        min_responses: int,
+        data: Any,
+        meta: dict = None,
+        targets: Optional[List[str]] = None,
+        callback: Callable = None,
+    ):
+        return self.communicator.broadcast_and_wait(task_name, min_responses, data, meta, targets, callback)
+
+    def send_and_wait(
+        self,
+        task_name: str,
+        min_responses: int,
+        data: Any,
+        meta: dict = None,
+        targets: Optional[List[str]] = None,
+        send_order: str = "sequential",
+        callback: Callable = None,
+    ):
+        return self.communicator.send_and_wait(task_name, min_responses, data, meta, targets, send_order, callback)
+
+    def relay_and_wait(
+        self,
+        task_name: str,
+        min_responses: int,
+        data: Any,
+        meta: dict = None,
+        targets: Optional[List[str]] = None,
+        relay_order: str = "sequential",
+        callback: Callable = None,
+    ):
+        return self.communicator.relay_and_wait(task_name, min_responses, data, meta, targets, relay_order, callback)
+
+    def broadcast(self, task_name: str, data: Any, meta: dict = None, targets: Optional[List[str]] = None):
+        return self.communicator.broadcast(task_name, data, meta, targets)
+
+    def send(
+        self,
+        task_name: str,
+        data: Any,
+        meta: dict = None,
+        targets: Optional[List[str]] = None,
+        send_order: str = "sequential",
+    ):
+        return self.communicator.send(task_name, data, meta, targets, send_order)
+
+    def relay(
+        self,
+        task_name: str,
+        data: Any,
+        meta: dict = None,
+        targets: Optional[List[str]] = None,
+        relay_order: str = "sequential",
+    ):
+        return self.communicator.relay(task_name, data, meta, targets, relay_order)
+
+    def get_site_names(self) -> List[str]:
+        return self.communicator.get_site_names()
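For orientation, here is a minimal sketch of a workflow built on WFController. The class name, task name, round count, and the trivial "adopt the last result" update are hypothetical, not part of this diff; the sketch only assumes the run() and broadcast_and_wait() interfaces defined above, and it can only execute inside the server runtime after the communicator has been published to wf_comm.

    # sketch only: a toy WFController subclass
    from nvflare.apis.wf_controller import WFController
    from nvflare.app_common.abstract.fl_model import FLModel, ParamsType

    class ToyAvgController(WFController):
        def __init__(self, num_rounds: int = 3, min_clients: int = 2):
            super().__init__()
            self.num_rounds = num_rounds
            self.min_clients = min_clients

        def run(self):
            model = FLModel(params={}, params_type=ParamsType.FULL, start_round=0, total_rounds=self.num_rounds)
            for current_round in range(self.num_rounds):
                model.current_round = current_round
                # blocks until min_responses results arrive;
                # the return shape is {task_name: {site_name: FLModel}}
                results = self.broadcast_and_wait(task_name="train", min_responses=self.min_clients, data=model)
                for site_model in results["train"].values():
                    model.params = site_model.params  # placeholder for real aggregation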
diff --git a/nvflare/app_common/abstract/fl_model.py b/nvflare/app_common/abstract/fl_model.py
index a82ae57c2c..502e7a5361 100644
--- a/nvflare/app_common/abstract/fl_model.py
+++ b/nvflare/app_common/abstract/fl_model.py
@@ -45,6 +45,7 @@ def __init__(
         params: Any = None,
         optimizer_params: Any = None,
         metrics: Optional[Dict] = None,
+        start_round: int = 0,
         current_round: Optional[int] = None,
         total_rounds: Optional[int] = None,
         meta: Optional[Dict] = None,
@@ -79,6 +80,7 @@ def __init__(
         self.params = params
         self.optimizer_params = optimizer_params
         self.metrics = metrics
+        self.start_round = start_round
         self.current_round = current_round
         self.total_rounds = total_rounds
 
diff --git a/nvflare/app_common/utils/fl_model_utils.py b/nvflare/app_common/utils/fl_model_utils.py
index 2d84daa14f..7358070303 100644
--- a/nvflare/app_common/utils/fl_model_utils.py
+++ b/nvflare/app_common/utils/fl_model_utils.py
@@ -201,6 +201,9 @@ def get_configs(model: FLModel) -> Optional[dict]:
 
     @staticmethod
     def update_model(model: FLModel, model_update: FLModel, replace_meta: bool = True) -> FLModel:
+
+        model.metrics = model_update.metrics
+
         if model.params_type != ParamsType.FULL:
             raise RuntimeError(f"params_type {model.params_type} of `model` not supported! Expected `ParamsType.FULL`.")
 
@@ -209,8 +212,6 @@ def update_model(model: FLModel, model_update: FLModel, replace_meta: bool = Tru
         else:
             model.meta.update(model_update.meta)
 
-        model.metrics = model_update.metrics
-
         if model_update.params_type == ParamsType.FULL:
             model.params = model_update.params
         elif model_update.params_type == ParamsType.DIFF:
diff --git a/nvflare/app_common/utils/math_utils.py b/nvflare/app_common/utils/math_utils.py
new file mode 100644
index 0000000000..3dde557a4c
--- /dev/null
+++ b/nvflare/app_common/utils/math_utils.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import operator
+from typing import Callable, Optional, Tuple
+
+operator_mapping = {
+    ">=": operator.ge,
+    "<=": operator.le,
+    ">": operator.gt,
+    "<": operator.lt,
+    "=": operator.eq,
+}
+
+
+def parse_compare_criteria(compare_expr: Optional[str] = None) -> Tuple[str, float, Callable]:
+    """Parse a compare expression into its individual components.
+
+    The compare expression is a string literal in the format "<key> <op> <value>", such as
+        accuracy >= 0.5
+        loss > 2.4
+
+    Args:
+        compare_expr: string literal in the format "<key> <op> <value>"
+
+    Returns:
+        Tuple of (key, target value, operator function)
+    """
+    tokens = compare_expr.split(" ")
+    if len(tokens) != 3:
+        raise ValueError(
+            f"Invalid early_stop_condition, expecting form of '<key> <op> <value>' but got '{compare_expr}'"
+        )
+
+    key = tokens[0]
+    op = tokens[1]
+    target = tokens[2]
+    op_fn = operator_mapping.get(op, None)
+    if op_fn is None:
+        raise ValueError("Invalid operator symbol: expecting one of <=, =, >=, <, >")
+    if not target:
+        raise ValueError("Invalid empty or None target value")
+    try:
+        target_value = float(target)
+    except Exception:
+        raise ValueError(f"expected a number, but got '{target}' in '{compare_expr}'")
+
+    return key, target_value, op_fn
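A quick illustration of how parse_compare_criteria can drive an early-stop check; the metrics dict is made up:

    from nvflare.app_common.utils.math_utils import parse_compare_criteria

    key, target, op_fn = parse_compare_criteria("accuracy >= 0.85")
    metrics = {"accuracy": 0.91}  # e.g., the latest aggregated metrics
    if op_fn(metrics[key], target):
        print(f"early-stop condition met: {key}={metrics[key]} vs target {target}")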
diff --git a/nvflare/app_common/wf_comm/__init__.py b/nvflare/app_common/wf_comm/__init__.py
new file mode 100644
index 0000000000..c511d3e87b
--- /dev/null
+++ b/nvflare/app_common/wf_comm/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nvflare.app_common.wf_comm.wf_comm_api_spec import WFCommAPISpec
+from nvflare.fuel.data_event.data_bus import DataBus
+
+data_bus = DataBus()
+
+
+def get_wf_comm_api() -> WFCommAPISpec:
+    return data_bus.get_data("wf_comm_api")
diff --git a/nvflare/app_common/wf_comm/base_wf_communicator.py b/nvflare/app_common/wf_comm/base_wf_communicator.py
new file mode 100644
index 0000000000..172e7b98b7
--- /dev/null
+++ b/nvflare/app_common/wf_comm/base_wf_communicator.py
@@ -0,0 +1,304 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import ABC
+from typing import Any, Dict, List, Optional, Tuple
+
+from nvflare.apis.client import Client
+from nvflare.apis.controller_spec import ClientTask, ControllerSpec, OperatorMethod, SendOrder, Task, TaskOperatorKey
+from nvflare.apis.dxo import DXO, DataKind
+from nvflare.apis.fl_component import FLComponent
+from nvflare.apis.fl_constant import FLContextKey, ReturnCode
+from nvflare.apis.fl_context import FLContext
+from nvflare.apis.shareable import Shareable
+from nvflare.apis.wf_controller import ABORT_WHEN_IN_ERROR
+from nvflare.app_common.abstract.fl_model import FLModel
+from nvflare.app_common.app_constant import AppConstants
+from nvflare.app_common.app_event_type import AppEventType
+from nvflare.app_common.utils.fl_model_utils import FLModelUtils
+from nvflare.app_common.wf_comm.decomposer_register import DecomposerRegister
+from nvflare.app_common.wf_comm.wf_comm_api import WFCommAPI
+from nvflare.app_common.wf_comm.wf_comm_api_spec import (
+    DATA,
+    MIN_RESPONSES,
+    RESULT,
+    SITE_NAMES,
+    STATUS,
+    TARGET_SITES,
+    TASK_NAME,
+)
+from nvflare.app_common.wf_comm.wf_communicator_spec import WFCommunicatorSpec
+from nvflare.fuel.data_event.data_bus import DataBus
+from nvflare.fuel.data_event.event_manager import EventManager
+from nvflare.fuel.utils.class_utils import instantiate_class
+from nvflare.fuel.utils.component_builder import ComponentBuilder
+from nvflare.fuel.utils.fobs import fobs
+from nvflare.fuel.utils.import_utils import optional_import
+from nvflare.private.defs import CommConstants
+from nvflare.security.logging import secure_format_traceback
+
+
+class BaseWFCommunicator(FLComponent, WFCommunicatorSpec, ControllerSpec, ABC):
+    def __init__(
+        self,
+        task_timeout: int = 0,
+        result_pull_interval: float = 0.2,
+    ):
+        super().__init__()
+        self.wf_controller_fn_name = "run"
+        self.clients = None
+        self.task_timeout = task_timeout
+
+        self.result_pull_interval = result_pull_interval
+        self.engine = None
+        self.fl_ctx = None
+        self.data_bus: Optional[DataBus] = None
+        self.event_manager: Optional[EventManager] = None
+
+    def start_controller(self, fl_ctx: FLContext):
+        self.fl_ctx = fl_ctx
+        self.log_info(fl_ctx, "Initializing controller workflow.")
+
+        self.data_bus = DataBus()
+        self.event_manager = EventManager(self.data_bus)
+
+        self.engine = self.fl_ctx.get_engine()
+        self.register_decomposers()
+
+        self.clients = self.engine.get_clients()
+        self.publish_comm_api(fl_ctx)
+        self.log_info(fl_ctx, "Workflow controller started.")
+
+    def register_decomposers(self):
+        decomposer_register = self.engine.get_component("decomposer_register")
+        if decomposer_register:
+            if not isinstance(decomposer_register, DecomposerRegister):
+                raise ValueError(
+                    f"decomposer_register component must be of type 'DecomposerRegister', got {type(decomposer_register)}."
+                )
+            decomposer_register.register()
+
+    def publish_comm_api(self, fl_ctx: FLContext):
+        comm_api = WFCommAPI(cid=fl_ctx.get_prop(FLContextKey.WORKFLOW, ""))
+        comm_api.meta.update({SITE_NAMES: self.get_site_names()})
+        self.data_bus.put_data("wf_comm_api", comm_api)
+
+    def start_workflow(self, abort_signal, fl_ctx):
+        try:
+            fl_ctx.set_prop("abort_signal", abort_signal)
+            func = getattr(self.get_controller(), self.wf_controller_fn_name)
+            func()
+        except Exception:
+            error_msg = secure_format_traceback()
+            self.log_error(fl_ctx, error_msg)
+            self.system_panic(error_msg, fl_ctx=fl_ctx)
+
+    def stop_controller(self, fl_ctx: FLContext):
+        pass
+
+    def process_result_of_unknown_task(
+        self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
+    ):
+        pass
+
+    def broadcast_to_peers_and_wait(self, pay_load):
+        abort_signal = self.fl_ctx.get_prop("abort_signal")
+        current_round = self.prepare_round_info(self.fl_ctx, pay_load)
+        task, min_responses, targets = self.get_payload_task(pay_load)
+
+        self.fl_ctx.set_prop("task_name", task.name)
+
+        self.broadcast_and_wait(
+            task=task,
+            fl_ctx=self.fl_ctx,
+            targets=targets,
+            min_responses=min_responses,
+            wait_time_after_min_received=0,
+            abort_signal=abort_signal,
+        )
+        self.fire_event(AppEventType.ROUND_DONE, self.fl_ctx)
+        self.log_info(self.fl_ctx, f"Round {current_round} finished.")
+
+    def broadcast_to_peers(self, pay_load):
+        task, min_responses, targets = self.get_payload_task(pay_load)
+        self.broadcast(
+            task=task, fl_ctx=self.fl_ctx, targets=targets, min_responses=min_responses, wait_time_after_min_received=0
+        )
+
+    def send_to_peers(self, pay_load, send_order: SendOrder = SendOrder.SEQUENTIAL):
+        task, _, targets = self.get_payload_task(pay_load)
+        self.send(task=task, fl_ctx=self.fl_ctx, targets=targets, send_order=send_order, task_assignment_timeout=0)
+
+    def send_to_peers_and_wait(self, pay_load, send_order: SendOrder = SendOrder.SEQUENTIAL):
+        abort_signal = self.fl_ctx.get_prop("abort_signal")
+        task, _, targets = self.get_payload_task(pay_load)
+        self.send_and_wait(
+            task=task,
+            fl_ctx=self.fl_ctx,
+            targets=targets,
+            send_order=send_order,
+            task_assignment_timeout=0,
+            abort_signal=abort_signal,
+        )
+
+    def relay_to_peers_and_wait(self, pay_load, send_order: SendOrder = SendOrder.SEQUENTIAL):
+        abort_signal = self.fl_ctx.get_prop("abort_signal")
+        task, _, targets = self.get_payload_task(pay_load)
+        self.relay_and_wait(
+            task=task,
+            fl_ctx=self.fl_ctx,
+            targets=targets,
+            send_order=send_order,
+            task_assignment_timeout=0,
+            task_result_timeout=0,
+            dynamic_targets=True,
+            abort_signal=abort_signal,
+        )
+
+    def relay_to_peers(self, pay_load, send_order: SendOrder = SendOrder.SEQUENTIAL):
+        task, _, targets = self.get_payload_task(pay_load)
+        self.relay(
+            task=task,
+            fl_ctx=self.fl_ctx,
+            targets=targets,
+            send_order=send_order,
+            task_assignment_timeout=0,
+            task_result_timeout=0,
+            dynamic_targets=True,
+        )
+
+    def prepare_round_info(self, fl_ctx, pay_load):
+        current_round = pay_load.get(AppConstants.CURRENT_ROUND, 0)
+        start_round = pay_load.get(AppConstants.START_ROUND, 0)
+        num_rounds = pay_load.get(AppConstants.NUM_ROUNDS, 1)
+
+        fl_ctx.set_prop(AppConstants.CURRENT_ROUND, current_round, private=True, sticky=True)
+        fl_ctx.set_prop(AppConstants.NUM_ROUNDS, num_rounds, private=True, sticky=True)
+        fl_ctx.set_prop(AppConstants.START_ROUND, start_round, private=True, sticky=True)
+        if current_round == start_round:
+            self.fire_event(AppEventType.ROUND_STARTED, fl_ctx)
+        return current_round
+
+    def get_payload_task(self, pay_load) -> Tuple[Task, int, List[str]]:
+        min_responses = pay_load.get(MIN_RESPONSES)
+        current_round = pay_load.get(AppConstants.CURRENT_ROUND, 0)
+        start_round = pay_load.get(AppConstants.START_ROUND, 0)
+        num_rounds = pay_load.get(AppConstants.NUM_ROUNDS, 1)
+        targets = pay_load.get(TARGET_SITES, self.get_site_names())
+        task_name = pay_load.get(TASK_NAME)
+
+        data = pay_load.get(DATA, {})
+        data_shareable = self.get_shareable(data)
+        data_shareable.set_header(AppConstants.START_ROUND, start_round)
+        data_shareable.set_header(AppConstants.CURRENT_ROUND, current_round)
+        data_shareable.set_header(AppConstants.NUM_ROUNDS, num_rounds)
+        data_shareable.add_cookie(AppConstants.CONTRIBUTION_ROUND, current_round)
+
+        operator = {
+            TaskOperatorKey.OP_ID: task_name,
+            TaskOperatorKey.METHOD: OperatorMethod.BROADCAST,
+            TaskOperatorKey.TIMEOUT: self.task_timeout,
+        }
+
+        task = Task(
+            name=task_name,
+            data=data_shareable,
+            operator=operator,
+            props={},
+            timeout=self.task_timeout,
+            before_task_sent_cb=None,
+            result_received_cb=self._result_received_cb,
+        )
+
+        return task, min_responses, targets
+
+    def get_shareable(self, data):
+        if isinstance(data, FLModel):
+            data_shareable: Shareable = FLModelUtils.to_shareable(data)
+        elif data is None:
+            data_shareable = Shareable()
+        else:
+            dxo = DXO(DataKind.RAW, data=data, meta={})
+            data_shareable = dxo.to_shareable()
+        return data_shareable
+
+    def _result_received_cb(self, client_task: ClientTask, fl_ctx: FLContext):
+        self.log_info(
+            fl_ctx, f"\n{client_task.client.name} task:'{client_task.task.name}' result callback received.\n\n"
+        )
+
+        client_name = client_task.client.name
+        task_name = client_task.task.name
+        result = client_task.result
+        rc = result.get_return_code()
+        results: Dict[str, Any] = {STATUS: rc}
+
+        if rc == ReturnCode.OK:
+            self.log_info(fl_ctx, f"Received result entries from client:{client_name} for task {task_name}")
+            fl_model = FLModelUtils.from_shareable(result)
+            results[RESULT] = {client_name: fl_model}
+            payload = {task_name: results}
+            self.event_manager.fire_event(CommConstants.TASK_RESULT, payload)
+        else:
+            self.handle_client_errors(rc, client_task, fl_ctx)
+
+        # cleanup task result
+        client_task.result = None
+
+    def get_site_names(self):
+        return [client.name for client in self.clients]
+
+    def handle_client_errors(self, rc: str, client_task: ClientTask, fl_ctx: FLContext):
+        abort = ABORT_WHEN_IN_ERROR.get(rc, True)  # default to abort for unknown return codes
+        if abort:
+            self.log_error(fl_ctx, f"error code = {rc}")
+            self.system_panic(
+                f"Failed in client-site for {client_task.client.name} during task {client_task.task.name}.",
+                fl_ctx=fl_ctx,
+            )
+            self.log_error(fl_ctx, f"Execution failed for {client_task.client.name}")
+        else:
+            raise ValueError(f"Execution result was not received for {client_task.client.name}")
+
+    def set_controller_config(self, controller_config: Dict):
+        if controller_config is None:
+            raise ValueError("controller_config is None")
+
+        if not isinstance(controller_config, dict):
+            raise ValueError(f"controller_config should be Dict, found '{type(controller_config)}'")
+
+        self.controller_config = controller_config
+
+    def get_controller(self):
+        controller = None
+        if isinstance(self.controller_config, dict):
+            controller = ComponentBuilder().build_component(self.controller_config)
+        if controller is None:
+            raise ValueError("wf_controller should be provided, but got None")
+
+        return controller
+
+    def register_serializers(self, serializer_class_paths: List[str] = None):
+        self.register_default_serializers()
+        if serializer_class_paths:
+            for class_path in serializer_class_paths:
+                fobs.register(instantiate_class(class_path, {}))
+
+    def register_default_serializers(self):
+        _, flag = optional_import("torch")
+        if flag:
+            from nvflare.app_opt.pt.decomposers import TensorDecomposer
+
+            fobs.register(TensorDecomposer)
diff --git a/nvflare/app_common/wf_comm/decomposer_register.py b/nvflare/app_common/wf_comm/decomposer_register.py
new file mode 100644
index 0000000000..c6f8ebac94
--- /dev/null
+++ b/nvflare/app_common/wf_comm/decomposer_register.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List
+
+from nvflare.apis.fl_component import FLComponent
+from nvflare.fuel.utils.class_utils import instantiate_class
+from nvflare.fuel.utils.fobs import fobs
+
+
+class DecomposerRegister(FLComponent):
+    def __init__(self, decomposers: List[str]):
+        super(DecomposerRegister, self).__init__()
+        self.decomposers = decomposers
+
+    def register(self):
+        for class_path in self.decomposers:
+            d = instantiate_class(class_path, init_params=None)
+            fobs.register(d)
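Since register_decomposers() above looks this component up by the id "decomposer_register", it would typically be declared in the server job config. A hypothetical fragment, shown as a Python dict (the decomposer class path is made up):

    decomposer_register_component = {
        "id": "decomposer_register",
        "path": "nvflare.app_common.wf_comm.decomposer_register.DecomposerRegister",
        "args": {"decomposers": ["my_pkg.decomposers.MyObjectDecomposer"]},
    }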
diff --git a/nvflare/app_common/wf_comm/wf_comm_api.py b/nvflare/app_common/wf_comm/wf_comm_api.py
new file mode 100644
index 0000000000..0b9ffaea5b
--- /dev/null
+++ b/nvflare/app_common/wf_comm/wf_comm_api.py
@@ -0,0 +1,229 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import threading
+from typing import Any, Callable, Dict, List, Optional
+
+from nvflare.apis.controller_spec import SendOrder
+from nvflare.apis.fl_constant import ReturnCode
+from nvflare.app_common.abstract.fl_model import FLModel
+from nvflare.app_common.app_constant import AppConstants
+from nvflare.app_common.wf_comm.wf_comm_api_spec import (
+    DATA,
+    MIN_RESPONSES,
+    RESP_MAX_WAIT_TIME,
+    RESULT,
+    SITE_NAMES,
+    STATUS,
+    TARGET_SITES,
+    TASK_NAME,
+    WFCommAPISpec,
+)
+from nvflare.fuel.data_event.data_bus import DataBus
+from nvflare.fuel.data_event.event_manager import EventManager
+from nvflare.private.defs import CommConstants
+
+
+class WFCommAPI(WFCommAPISpec):
+    def __init__(self, cid: str = ""):
+        self.meta = {SITE_NAMES: []}
+        self.logger = logging.getLogger(self.__class__.__name__)
+
+        self.task_results = {}
+        self.task_result_lock = threading.Lock()
+
+        data_bus = DataBus()
+        data_bus.subscribe(topics=[CommConstants.TASK_RESULT], callback=self.result_callback)
+
+        self.event_manager = EventManager(data_bus)
+        self.comm = data_bus.get_data(cid + CommConstants.COMMUNICATOR)
+        self._check_inputs()
+
+    def get_site_names(self):
+        return self.meta.get(SITE_NAMES)
+
+    def broadcast_and_wait(
+        self,
+        task_name: str,
+        min_responses: int,
+        data: Any,
+        meta: dict = None,
+        targets: Optional[List[str]] = None,
+        callback: Callable = None,
+    ) -> Dict[str, Dict[str, FLModel]]:
+        msg_payload = self._prepare_input_payload(task_name, data, meta, min_responses, targets)
+        self.register_callback(callback)
+        self.comm.broadcast_to_peers_and_wait(msg_payload)
+
+        if callback is None:
+            return self._get_results(task_name)
+
+    def register_callback(self, callback):
+        if callback:
+            self.event_manager.data_bus.subscribe([CommConstants.POST_PROCESS_RESULT], callback)
+
+    def send_and_wait(
+        self,
+        task_name: str,
+        min_responses: int,
+        data: Any,
+        meta: dict = None,
+        targets: Optional[List[str]] = None,
+        send_order: str = "sequential",
+        callback: Callable = None,
+    ) -> Dict[str, Dict[str, FLModel]]:
+        msg_payload = self._prepare_input_payload(task_name, data, meta, min_responses, targets)
+        self.register_callback(callback)
+        self.comm.send_to_peers_and_wait(msg_payload, SendOrder(send_order))
+
+        if callback is None:
+            return self._get_results(task_name)
+
+    def relay_and_wait(
+        self,
+        task_name: str,
+        min_responses: int,
+        data: Any,
+        meta: dict = None,
+        targets: Optional[List[str]] = None,
+        relay_order: str = "sequential",
+        callback: Callable = None,
+    ) -> Dict[str, Dict[str, FLModel]]:
+        msg_payload = self._prepare_input_payload(task_name, data, meta, min_responses, targets)
+        self.register_callback(callback)
+        self.comm.relay_to_peers_and_wait(msg_payload, SendOrder(relay_order))
+
+        if callback is None:
+            return self._get_results(task_name)
+
+    def broadcast(self, task_name: str, data: Any, meta: dict = None, targets: Optional[List[str]] = None):
+        msg_payload = self._prepare_input_payload(task_name, data, meta, min_responses=0, targets=targets)
+        self.comm.broadcast_to_peers(pay_load=msg_payload)
+
+    def send(
+        self,
+        task_name: str,
+        data: Any,
+        meta: dict = None,
+        targets: Optional[List[str]] = None,
+        send_order: str = "sequential",
+    ):
+        msg_payload = self._prepare_input_payload(task_name, data, meta, min_responses=0, targets=targets)
+        self.comm.send_to_peers(pay_load=msg_payload, send_order=SendOrder(send_order))
+
+    def relay(
+        self,
+        task_name: str,
+        data: Any,
+        meta: dict = None,
+        targets: Optional[List[str]] = None,
+        relay_order: str = "sequential",
+    ):
+        msg_payload = self._prepare_input_payload(task_name, data, meta, min_responses=0, targets=targets)
+        self.comm.relay_to_peers(msg_payload, SendOrder(relay_order))
+
+    def _process_one_result(self, site_result) -> Dict[str, FLModel]:
+        self._check_result(site_result)
+        rc = site_result.get(STATUS)
+        if rc == ReturnCode.OK:
+            result = site_result.get(RESULT, {})
+            site_name, data = next(iter(result.items()))
+            task_result = {site_name: data}
+        else:
+            msg = f"task failed with '{rc}' status"
+            raise RuntimeError(msg)
+
+        return task_result
+
+    def _get_results(self, task_name) -> Dict[str, Dict[str, FLModel]]:
+        batch_result: Dict = {}
+        site_results = self.task_results.get(task_name)
+        if not site_results:
+            raise RuntimeError(f"no results for given task {task_name}")
+
+        for item in site_results:
+            one_result = self._process_one_result(item)
+            task_result = batch_result.get(task_name, {})
+            task_result.update(one_result)
+            batch_result[task_name] = task_result
+
+        with self.task_result_lock:
+            self.task_results[task_name] = []
+
+        return batch_result
+
+    def _check_result(self, site_result):
+        if site_result is None:
+            raise RuntimeError("expecting site_result to be dictionary, but got None")
+
+        if not isinstance(site_result, dict):
+            raise RuntimeError(f"expecting site_result to be dictionary, but got '{type(site_result)}', {site_result=}")
+
+        keys = [RESULT, STATUS]
+        all_keys_present = all(key in site_result for key in keys)
+        if not all_keys_present:
+            raise RuntimeError(f"expecting all keys {keys} present in site_result")
+
+    def _check_inputs(self):
+        if self.comm is None:
+            raise RuntimeError("missing communicator")
+
+    def result_callback(self, topic, data, data_bus):
+        if topic == CommConstants.TASK_RESULT:
+            task, site_result = next(iter(data.items()))
+            # fire event with processed data
+            one_result = self._process_one_result(site_result)
+            self.event_manager.fire_event(CommConstants.POST_PROCESS_RESULT, {task: one_result})
+            site_task_results = self.task_results.get(task, [])
+            site_task_results.append(site_result)
+            self.task_results[task] = site_task_results
+
+    def _prepare_input_payload(self, task_name, data, meta, min_responses, targets):
+        meta = {} if meta is None else meta
+        if data and isinstance(data, FLModel):
+            start_round = data.start_round
+            current_round = data.current_round
+            num_rounds = data.total_rounds
+        else:
+            start_round = meta.get(AppConstants.START_ROUND, 0)
+            current_round = meta.get(AppConstants.CURRENT_ROUND, 0)
+            num_rounds = meta.get(AppConstants.NUM_ROUNDS, 1)
+
+        resp_max_wait_time = meta.get(RESP_MAX_WAIT_TIME, 15)
+
+        msg_payload = {
+            TASK_NAME: task_name,
+            MIN_RESPONSES: min_responses,
+            RESP_MAX_WAIT_TIME: resp_max_wait_time,
+            AppConstants.CURRENT_ROUND: current_round,
+            AppConstants.NUM_ROUNDS: num_rounds,
+            AppConstants.START_ROUND: start_round,
+            DATA: data,
+            TARGET_SITES: targets,
+        }
+        return msg_payload
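When a callback is passed to the *_and_wait methods, WFCommAPI subscribes it to the POST_PROCESS_RESULT topic and skips the batch return value, so per-site results can be consumed as they arrive. A sketch of the callback shape, assuming it is used from inside a WFController whose run() has already prepared an FLModel named model:

    def on_site_result(topic, data, data_bus):
        # data arrives as {task_name: {site_name: FLModel}}
        task_name, site_result = next(iter(data.items()))
        site_name, fl_model = next(iter(site_result.items()))
        print(f"'{task_name}' result received from {site_name}")

    self.broadcast_and_wait(task_name="train", min_responses=2, data=model, callback=on_site_result)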
diff --git a/nvflare/app_common/wf_comm/wf_comm_api_spec.py b/nvflare/app_common/wf_comm/wf_comm_api_spec.py
new file mode 100644
index 0000000000..f30f92e15d
--- /dev/null
+++ b/nvflare/app_common/wf_comm/wf_comm_api_spec.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Optional
+
+SITE_NAMES = "SITE_NAMES"
+TASK_NAME = "TASK_NAME"
+
+MIN_RESPONSES = "min_responses"
+RESP_MAX_WAIT_TIME = "resp_max_wait_time"
+
+STATUS = "status"
+RESULT = "result"
+DATA = "data"
+TARGET_SITES = "target_sites"
+
+
+class WFCommAPISpec(ABC):
+    @abstractmethod
+    def broadcast_and_wait(
+        self,
+        task_name: str,
+        min_responses: int,
+        data: Any,
+        meta: dict = None,
+        targets: Optional[List[str]] = None,
+        callback: Callable = None,
+    ) -> Dict[str, Any]:
+        """Communication interface for the blocking version of the 'broadcast' method.
+
+        First, the task is scheduled for broadcast (see the 'broadcast' method);
+        it then waits until the task is completed.
+
+        Args:
+            task_name: the name of the task to be sent.
+            min_responses: the min number of responses expected. If 0, must get responses from
+                all clients that the task has been sent to.
+            data: the data to be sent in the task.
+            meta: the meta to be sent in the task.
+            targets: list of destination clients. If None, all clients.
+            callback: callback to be registered.
+
+        Returns:
+            result dict if callback is None
+        """
+        pass
+
+    @abstractmethod
+    def send_and_wait(
+        self,
+        task_name: str,
+        min_responses: int,
+        data: Any,
+        meta: dict = None,
+        targets: Optional[List[str]] = None,
+        send_order: str = "sequential",
+        callback: Callable = None,
+    ) -> Dict[str, Any]:
+        """Communication interface for the blocking version of the 'send' method.
+
+        First, the task is scheduled for send (see the 'send' method);
+        it then waits until the task is completed and returns the task completion status and collected result.
+
+        Args:
+            task_name: the name of the task to be sent.
+            min_responses: the min number of responses expected. If 0, must get responses from
+                all clients that the task has been sent to.
+            data: the data to be sent in the task.
+            meta: the meta to be sent in the task.
+            targets: list of destination clients.
+            send_order: order for choosing the next client.
+            callback: callback to be registered.
+
+        Returns:
+            result dict if callback is None
+        """
+        pass
+
+    @abstractmethod
+    def relay_and_wait(
+        self,
+        task_name: str,
+        min_responses: int,
+        data: Any,
+        meta: dict = None,
+        targets: Optional[List[str]] = None,
+        relay_order: str = "sequential",
+        callback: Callable = None,
+    ) -> Dict[str, Any]:
+        """Communication interface for the blocking version of the 'relay' method.
+
+        First, the task is scheduled to be done sequentially by the clients in the targets list
+        (see the 'relay' method); it then waits until the task is completed.
+
+        Args:
+            task_name: the name of the task to be sent.
+            min_responses: the min number of responses expected. If 0, must get responses from
+                all clients that the task has been sent to.
+            data: the data to be sent in the task.
+            meta: the meta to be sent in the task.
+            targets: list of destination clients. If None, all clients.
+            relay_order: order for choosing the next client.
+            callback: callback to be registered.
+
+        Returns:
+            result dict if callback is None
+        """
+        pass
+
+    @abstractmethod
+    def broadcast(self, task_name: str, data: Any, meta: dict = None, targets: Optional[List[str]] = None):
+        """Communication interface to schedule a broadcast of the task to the specified targets.
+
+        This is a non-blocking call.
+
+        The task is standing until one of the following conditions is met:
+            - if timeout is specified (> 0), the task has been standing for more than the specified time
+            - the controller has received the specified min_responses results for this task, and all target clients
+              are done.
+            - the controller has received the specified min_responses results for this task, and has waited
+              for wait_time_after_min_received.
+
+        While the task is standing:
+            - Before sending the task to a client, the before_task_sent CB (if specified) is called;
+            - When a result is received from a client, the result_received CB (if specified) is called;
+
+        After the task is done, the task_done CB (if specified) is called:
+            - If the result_received CB is specified, the 'result' in the ClientTask of each
+              client is produced by the result_received CB;
+            - Otherwise, the 'result' contains the original result submitted by the clients;
+
+        NOTE: if targets is None, the actual broadcast target clients will be dynamic, because clients
+        could join/disconnect at any moment. While the task is standing, any client that joins automatically
+        becomes a target for this broadcast.
+
+        Args:
+            task_name: the name of the task to be sent.
+            data: the data to be sent in the task.
+            meta: the meta to be sent in the task.
+            targets: list of destination clients. If None, all clients.
+        """
+        pass
+
+    @abstractmethod
+    def send(
+        self,
+        task_name: str,
+        data: Any,
+        meta: dict = None,
+        targets: Optional[List[str]] = None,
+        send_order: str = "sequential",
+    ):
+        """Communication interface to schedule the task to be sent to a single target client.
+
+        This is a non-blocking call.
+
+        In ANY order, the target client is the first target that asks for the task.
+        In SEQUENTIAL order, the controller will try its best to send the task to the first client
+        in the targets list. If it can't, it will try the next target, and so on.
+
+        NOTE: if 'targets' is None, the actual target clients will be dynamic, because clients
+        could join/disconnect at any moment. While the task is standing, any client that joins automatically
+        becomes a target for this task.
+
+        If the send_order is SEQUENTIAL, targets must be a non-empty list of client names.
+
+        Args:
+            task_name: the name of the task to be sent.
+            data: the data to be sent in the task.
+            meta: the meta to be sent in the task.
+            targets: list of destination clients. If None, all clients.
+            send_order: order for choosing the next client.
+        """
+        pass
+
+    @abstractmethod
+    def relay(
+        self,
+        task_name: str,
+        data: Any,
+        meta: dict = None,
+        targets: Optional[List[str]] = None,
+        relay_order: str = "sequential",
+    ):
+        """Communication interface to schedule a task to be done sequentially by the clients in the targets list.
+
+        This is a non-blocking call.
+
+        Args:
+            task_name: the name of the task to be sent.
+            data: the data to be sent in the task.
+            meta: the meta to be sent in the task.
+            targets: list of destination clients.
+            relay_order: order for choosing the next client.
+        """
+        pass
+
+    @abstractmethod
+    def get_site_names(self) -> List[str]:
+        """Get the list of site names."""
+        pass
diff --git a/nvflare/app_common/wf_comm/wf_communicator.py b/nvflare/app_common/wf_comm/wf_communicator.py
new file mode 100644
index 0000000000..ed6d15d103
--- /dev/null
+++ b/nvflare/app_common/wf_comm/wf_communicator.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nvflare.apis.fl_context import FLContext
+from nvflare.apis.impl.controller import Controller
+from nvflare.apis.signal import Signal
+from nvflare.app_common.wf_comm.base_wf_communicator import BaseWFCommunicator
+
+
+class WFCommunicator(BaseWFCommunicator, Controller):
+    def __init__(self):
+        super().__init__()
+        self.register_serializers()
+
+    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
+        self.start_workflow(abort_signal, fl_ctx)
diff --git a/nvflare/app_common/wf_comm/wf_communicator_spec.py b/nvflare/app_common/wf_comm/wf_communicator_spec.py
new file mode 100644
index 0000000000..a198b31e92
--- /dev/null
+++ b/nvflare/app_common/wf_comm/wf_communicator_spec.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Dict, Optional
+
+from nvflare.apis.controller_spec import SendOrder
+
+
+class WFCommunicatorSpec(ABC):
+    def __init__(self):
+        self.controller_config: Optional[Dict] = None
+
+    @abstractmethod
+    def broadcast_to_peers_and_wait(self, pay_load: Dict):
+        """Convert pay_load and call Controller's 'broadcast_and_wait' method.
+
+        Args:
+            pay_load: the task payload dict, including the task name, data, and metadata.
+        """
+        pass
+
+    @abstractmethod
+    def broadcast_to_peers(self, pay_load: Dict):
+        """Convert pay_load and call Controller's 'broadcast' method.
+
+        Args:
+            pay_load: the task payload dict, including the task name, data, and metadata.
+        """
+        pass
+
+    @abstractmethod
+    def send_to_peers(self, pay_load: Dict, send_order: SendOrder = SendOrder.SEQUENTIAL):
+        """Convert pay_load and call Controller's 'send' method.
+
+        Args:
+            pay_load: the task payload dict, including the task name, data, and metadata.
+            send_order: order for choosing the next client.
+        """
+        pass
+
+    @abstractmethod
+    def send_to_peers_and_wait(self, pay_load: Dict, send_order: SendOrder = SendOrder.SEQUENTIAL):
+        """Convert pay_load and call Controller's 'send_and_wait' method.
+
+        Args:
+            pay_load: the task payload dict, including the task name, data, and metadata.
+            send_order: order for choosing the next client.
+        """
+        pass
+
+    @abstractmethod
+    def relay_to_peers_and_wait(self, pay_load: Dict, send_order: SendOrder = SendOrder.SEQUENTIAL):
+        """Convert pay_load and call Controller's 'relay_and_wait' method.
+
+        Args:
+            pay_load: the task payload dict, including the task name, data, and metadata.
+            send_order: order for choosing the next client.
+        """
+        pass
+
+    @abstractmethod
+    def relay_to_peers(self, pay_load: Dict, send_order: SendOrder = SendOrder.SEQUENTIAL):
+        """Convert pay_load and call Controller's 'relay' method.
+
+        Args:
+            pay_load: the task payload dict, including the task name, data, and metadata.
+            send_order: order for choosing the next client.
+        """
+        pass
- if "class_attributes" in class_args: - class_args = class_args["class_attributes"] + if class_path and not lazy_instantiate: + if "class_attributes" in class_args: + class_args = class_args["class_attributes"] - return instantiate_class(class_path, class_args) + return instantiate_class(class_path, class_args) + else: + comp_dict = {} + lazy_instantiate = config_dict.get("lazy_instantiate", False) + if not lazy_instantiate: + for k, v in config_dict.items(): + if isinstance(v, dict) and self.is_class_config(v): + # try to replace the arg with a component + try: + t = self.build_component(v) + comp_dict[k] = t + except Exception as e: + raise ValueError(f"failed to instantiate class: {secure_format_exception(e)} ") + else: + comp_dict[k] = v + else: + comp_dict = config_dict + + return comp_dict def get_class_path(self, config_dict): if "path" in config_dict.keys(): diff --git a/nvflare/private/defs.py b/nvflare/private/defs.py index 6ae5c7fd8c..9318866016 100644 --- a/nvflare/private/defs.py +++ b/nvflare/private/defs.py @@ -181,6 +181,13 @@ class CellMessageHeaderKeys: ABORT_JOBS = "abort_jobs" +class CommConstants(object): + COMMUNICATOR = "communicator" + CONTROLLER = "controller" + TASK_RESULT = "TASK_RESULT" + POST_PROCESS_RESULT = "POST_PROCESS_RESULT" + + class JobFailureMsgKey: JOB_ID = "job_id" diff --git a/nvflare/private/fed/server/server_json_config.py b/nvflare/private/fed/server/server_json_config.py index 5685a2e456..a97e982c28 100644 --- a/nvflare/private/fed/server/server_json_config.py +++ b/nvflare/private/fed/server/server_json_config.py @@ -17,9 +17,16 @@ from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import SystemConfigs, SystemVarName from nvflare.apis.responder import Responder +from nvflare.apis.wf_controller import WFController +from nvflare.app_common.wf_comm.wf_communicator import WFCommunicator +from nvflare.app_common.wf_comm.wf_communicator_spec import WFCommunicatorSpec +from nvflare.fuel.data_event.data_bus import DataBus from nvflare.fuel.utils.argument_utils import parse_vars +from nvflare.fuel.utils.class_utils import instantiate_class +from nvflare.fuel.utils.component_builder import ComponentBuilder from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.json_scanner import Node +from nvflare.private.defs import CommConstants from nvflare.private.fed_json_config import FedJsonConfigurator from nvflare.private.json_configer import ConfigContext, ConfigError @@ -30,15 +37,17 @@ class WorkFlow: - def __init__(self, id, responder: Responder): + def __init__(self, id, responder: Responder, wf_controller=None): """Workflow is a responder with ID. Args: id: identification - responder (Responder): A responder + responder (Responder): A responder or communicator + wf_controller: federated learning wf_controller. 
diff --git a/nvflare/private/defs.py b/nvflare/private/defs.py
index 6ae5c7fd8c..9318866016 100644
--- a/nvflare/private/defs.py
+++ b/nvflare/private/defs.py
@@ -181,6 +181,13 @@ class CellMessageHeaderKeys:
     ABORT_JOBS = "abort_jobs"
 
 
+class CommConstants(object):
+    COMMUNICATOR = "communicator"
+    CONTROLLER = "controller"
+    TASK_RESULT = "TASK_RESULT"
+    POST_PROCESS_RESULT = "POST_PROCESS_RESULT"
+
+
 class JobFailureMsgKey:
     JOB_ID = "job_id"
 
diff --git a/nvflare/private/fed/server/server_json_config.py b/nvflare/private/fed/server/server_json_config.py
index 5685a2e456..a97e982c28 100644
--- a/nvflare/private/fed/server/server_json_config.py
+++ b/nvflare/private/fed/server/server_json_config.py
@@ -17,9 +17,16 @@
 from nvflare.apis.fl_component import FLComponent
 from nvflare.apis.fl_constant import SystemConfigs, SystemVarName
 from nvflare.apis.responder import Responder
+from nvflare.apis.wf_controller import WFController
+from nvflare.app_common.wf_comm.wf_communicator import WFCommunicator
+from nvflare.app_common.wf_comm.wf_communicator_spec import WFCommunicatorSpec
+from nvflare.fuel.data_event.data_bus import DataBus
 from nvflare.fuel.utils.argument_utils import parse_vars
+from nvflare.fuel.utils.class_utils import instantiate_class
+from nvflare.fuel.utils.component_builder import ComponentBuilder
 from nvflare.fuel.utils.config_service import ConfigService
 from nvflare.fuel.utils.json_scanner import Node
+from nvflare.private.defs import CommConstants
 from nvflare.private.fed_json_config import FedJsonConfigurator
 from nvflare.private.json_configer import ConfigContext, ConfigError
 
@@ -30,15 +37,17 @@
 
 
 class WorkFlow:
-    def __init__(self, id, responder: Responder):
+    def __init__(self, id, responder: Responder, wf_controller=None):
         """Workflow is a responder with ID.
 
         Args:
             id: identification
-            responder (Responder): A responder
+            responder (Responder): A responder or communicator
+            wf_controller: the federated learning workflow controller; if None, the responder itself implements the workflow
         """
         self.id = id
         self.responder = responder
+        self.wf_controller = wf_controller
 
 
 class ServerJsonConfigurator(FedJsonConfigurator):
@@ -124,15 +133,10 @@ def process_config_element(self, config_ctx: ConfigContext, node: Node):
             return
 
         if re.search(r"^workflows\.#[0-9]+$", path):
-            workflow = self.authorize_and_build_component(element, config_ctx, node)
-            if not isinstance(workflow, Responder):
-                raise ConfigError(
-                    '"workflow" must be a Responder or Controller object, but got {}'.format(type(workflow))
-                )
+            element = self.enhance_workflow_config(element)
+            component = self.authorize_and_build_component(element, config_ctx, node)
 
             cid = element.get("id", None)
-            if not cid:
-                cid = type(workflow).__name__
             if not isinstance(cid, str):
                 raise ConfigError('"id" must be str but got {}'.format(type(cid)))
 
@@ -143,10 +147,71 @@ def process_config_element(self, config_ctx: ConfigContext, node: Node):
             if cid in self.components:
                 raise ConfigError('duplicate component id "{}"'.format(cid))
 
-            self.workflows.append(WorkFlow(cid, workflow))
-            self.components[cid] = workflow
+            responder = self.get_responder(component, cid)
+
+            workflow = WorkFlow(cid, responder)
+            self.workflows.append(workflow)
+            self.components[cid] = responder
             return
 
+    def enhance_workflow_config(self, element: dict):
+        if CommConstants.CONTROLLER in element:
+            controller_config = element.get(CommConstants.CONTROLLER)
+            controller_config["lazy_instantiate"] = True
+            element[CommConstants.CONTROLLER] = controller_config
+        elif CommConstants.COMMUNICATOR in element:
+            wf_config = element.copy()
+            comm_config = wf_config.pop(CommConstants.COMMUNICATOR)
+            controller_config = wf_config
+            controller_config["lazy_instantiate"] = True
+            id = controller_config.pop("id")
+            element = {"id": id, CommConstants.COMMUNICATOR: comm_config, CommConstants.CONTROLLER: controller_config}
+        elif isinstance(instantiate_class(self.get_class_path(element), element.get("args", dict())), WFController):
+            controller_config = element.copy()
+            controller_config["lazy_instantiate"] = True
+            id = controller_config.pop("id")
+            element = {"id": id, CommConstants.CONTROLLER: controller_config}
+
+        return element
+
+    def get_responder(self, component, cid):
+        if isinstance(component, dict):
+            wf_config = component
+            communicator = wf_config.get(CommConstants.COMMUNICATOR)
+            if communicator is None:
+                communicator = WFCommunicator()
+
+            if isinstance(communicator, WFCommunicatorSpec):
+                controller_config = wf_config.get(CommConstants.CONTROLLER)
+                controller_config["lazy_instantiate"] = False
+                communicator.set_controller_config(controller_config)
+                data_bus = DataBus()
+                data_bus.put_data(cid + CommConstants.COMMUNICATOR, communicator)
+            responder = communicator
+        else:
+            responder = component
+
+        if not isinstance(responder, Responder):
+            raise ConfigError(
+                '"workflow" must be a Responder or Controller or have a Responder object, but got {}'.format(
+                    type(responder)
+                )
+            )
+        return responder
+
+    def get_wf_controller(self, wf_config):
+        wf_controller_comp = wf_config.get("wf_controller")
+        if isinstance(wf_controller_comp, dict):
+            wf_controller_comp["lazy_instantiate"] = False
+            wf_controller = ComponentBuilder().build_component(wf_controller_comp)
+        else:
+            wf_controller = wf_controller_comp
+
+        if wf_controller is None:
+            raise ValueError("wf_controller should be provided, but got None")
+
+        return wf_controller
+
     def _get_all_workflows_ids(self):
         ids = []
         for t in self.workflows:
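With enhance_workflow_config and get_responder in place, a workflow entry can pair a communicator with a lazily instantiated controller. A hypothetical workflow entry, shown as a Python dict in the shape these methods accept (class paths under my_pkg are made up):

    workflow = {
        "id": "my_workflow",
        "communicator": {
            "path": "nvflare.app_common.wf_comm.wf_communicator.WFCommunicator",
            "args": {},
        },
        "controller": {
            "path": "my_pkg.my_flows.MyWFController",
            "args": {"num_rounds": 3},
        },
    }
    # enhance_workflow_config marks the controller config lazy_instantiate=True;
    # get_responder later flips it back and hands it to the communicator
    # via set_controller_config.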
diff --git a/tests/unit_test/app_common/utils/math_utils_test.py b/tests/unit_test/app_common/utils/math_utils_test.py
new file mode 100644
index 0000000000..dc1ad768bc
--- /dev/null
+++ b/tests/unit_test/app_common/utils/math_utils_test.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import operator
+
+import pytest
+
+from nvflare.app_common.utils.math_utils import parse_compare_criteria
+
+TEST_CASES = [
+    ("accuracy >= 50", ("accuracy", 50, operator.ge)),
+    ("accuracy <= 50", ("accuracy", 50, operator.le)),
+    ("accuracy > 50", ("accuracy", 50, operator.gt)),
+    ("accuracy < 50", ("accuracy", 50, operator.lt)),
+    ("accuracy = 50", ("accuracy", 50, operator.eq)),
+    ("loss < 0.1", ("loss", 0.1, operator.lt)),
+    ("50 >= 50", ("50", 50, operator.ge)),
+]
+
+INVALID_TEST_CASES = [
+    ("50 >= accuracy", None),
+    ("accuracy >== 50", None),
+    ("accuracy >= accuracy", None),
+    (50, None),
+]
+
+
+class TestMathUtils:
+    @pytest.mark.parametrize("compare_expr,compare_tuple", TEST_CASES + INVALID_TEST_CASES)
+    def test_parse_compare_criteria(self, compare_expr, compare_tuple):
+        if compare_tuple is None:
+            with pytest.raises(Exception):
+                parse_compare_criteria(compare_expr)
+        else:
+            result_tuple = parse_compare_criteria(compare_expr)
+            assert result_tuple == compare_tuple