From e32943915c5f9308f4287ebb74e9dea7494ad568 Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Wed, 20 Mar 2024 16:38:35 -0400 Subject: [PATCH 1/3] support av ipc agent --- nvflare/apis/dxo.py | 3 +- nvflare/apis/fl_constant.py | 1 + nvflare/app_common/app_defined/__init__.py | 13 + nvflare/app_common/app_defined/aggregator.py | 75 +++ .../app_common/app_defined/component_base.py | 87 ++++ .../app_common/app_defined/model_persistor.py | 57 +++ .../app_defined/shareable_generator.py | 94 ++++ .../decomposers/common_decomposers.py | 50 +- .../decomposers/numpy_decomposers.py | 79 +++ nvflare/app_common/executors/ipc_exchanger.py | 425 +++++++++++++++++ nvflare/client/__init__.py | 1 + nvflare/client/ipc/__init__.py | 13 + nvflare/client/ipc/defs.py | 106 +++++ nvflare/client/ipc/ipc_agent.py | 448 ++++++++++++++++++ nvflare/fuel/f3/cellnet/core_cell.py | 9 +- 15 files changed, 1408 insertions(+), 53 deletions(-) create mode 100644 nvflare/app_common/app_defined/__init__.py create mode 100644 nvflare/app_common/app_defined/aggregator.py create mode 100644 nvflare/app_common/app_defined/component_base.py create mode 100644 nvflare/app_common/app_defined/model_persistor.py create mode 100644 nvflare/app_common/app_defined/shareable_generator.py create mode 100644 nvflare/app_common/decomposers/numpy_decomposers.py create mode 100644 nvflare/app_common/executors/ipc_exchanger.py create mode 100644 nvflare/client/ipc/__init__.py create mode 100644 nvflare/client/ipc/defs.py create mode 100644 nvflare/client/ipc/ipc_agent.py diff --git a/nvflare/apis/dxo.py b/nvflare/apis/dxo.py index 90b9cc5819..f03c4bff78 100644 --- a/nvflare/apis/dxo.py +++ b/nvflare/apis/dxo.py @@ -29,6 +29,7 @@ class DataKind(object): COLLECTION = "COLLECTION" # Dict or List of DXO objects STATISTICS = "STATISTICS" PSI = "PSI" + APP_DEFINED = "APP_DEFINED" # data format is app defined class MetaKey(FLMetaKey): @@ -128,7 +129,7 @@ def validate(self) -> str: if self.data is None: return "missing data" - if not isinstance(self.data, dict): + if self.data_kind != DataKind.APP_DEFINED and not isinstance(self.data, dict): return "invalid data: expect dict but got {}".format(type(self.data)) if self.meta is not None and not isinstance(self.meta, dict): diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index 325050aac3..4da4e7ab71 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -41,6 +41,7 @@ class ReturnCode(object): UNSAFE_JOB = "UNSAFE_JOB" SERVER_NOT_READY = "SERVER_NOT_READY" SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE" + EARLY_TERMINATION = "EARLY_TERMINATION" class MachineStatus(Enum): diff --git a/nvflare/app_common/app_defined/__init__.py b/nvflare/app_common/app_defined/__init__.py new file mode 100644 index 0000000000..4fc50543f1 --- /dev/null +++ b/nvflare/app_common/app_defined/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
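For illustration, a minimal sketch (not part of the patch) of what the relaxed validate() above enables: with DataKind.APP_DEFINED the DXO payload may be any app-specific object rather than a dict. This assumes validate() returns an empty string for a valid DXO, consistent with the error-string returns shown in the hunk.

    from nvflare.apis.dxo import DXO, DataKind

    # APP_DEFINED payloads are opaque to the framework, so a non-dict payload
    # (e.g. raw bytes produced by a custom serializer) now passes validation.
    dxo = DXO(data_kind=DataKind.APP_DEFINED, data=b"opaque-model-bytes", meta={})
    assert dxo.validate() == ""

    # Every other data kind still requires a dict payload.
    bad = DXO(data_kind=DataKind.WEIGHTS, data=b"opaque-model-bytes")
    assert bad.validate().startswith("invalid data")
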
diff --git a/nvflare/app_common/app_defined/aggregator.py b/nvflare/app_common/app_defined/aggregator.py new file mode 100644 index 0000000000..3f33f42709 --- /dev/null +++ b/nvflare/app_common/app_defined/aggregator.py @@ -0,0 +1,75 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any + +from nvflare.apis.dxo import DXO, DataKind, from_shareable +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.app_common.abstract.aggregator import Aggregator +from nvflare.app_common.abstract.model import ModelLearnableKey +from nvflare.app_common.app_constant import AppConstants +from nvflare.app_common.app_event_type import AppEventType + +from .component_base import ComponentBase + + +class AppDefinedAggregator(Aggregator, ComponentBase, ABC): + def __init__(self): + Aggregator.__init__(self) + ComponentBase.__init__(self) + self.current_round = None + self.base_model_obj = None + + def handle_event(self, event_type, fl_ctx: FLContext): + if event_type == AppEventType.ROUND_STARTED: + self.fl_ctx = fl_ctx + self.current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND) + base_model_learnable = fl_ctx.get_prop(AppConstants.GLOBAL_MODEL) + if isinstance(base_model_learnable, dict): + self.base_model_obj = base_model_learnable.get(ModelLearnableKey.WEIGHTS) + self.reset() + + @abstractmethod + def reset(self): + pass + + @abstractmethod + def processing_training_result(self, client_name: str, trained_weights: Any, trained_meta: dict) -> bool: + pass + + @abstractmethod + def aggregate_training_result(self) -> (Any, dict): + pass + + def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool: + dxo = from_shareable(shareable) + trained_weights = dxo.data + trained_meta = dxo.meta + self.fl_ctx = fl_ctx + peer_ctx = fl_ctx.get_peer_context() + client_name = peer_ctx.get_identity_name() + return self.processing_training_result(client_name, trained_weights, trained_meta) + + def aggregate(self, fl_ctx: FLContext) -> Shareable: + self.fl_ctx = fl_ctx + aggregated_result, aggregated_meta = self.aggregate_training_result() + dxo = DXO( + data_kind=DataKind.APP_DEFINED, + data=aggregated_result, + meta=aggregated_meta, + ) + self.debug(f"learnable_to_shareable: {dxo.data}") + return dxo.to_shareable() diff --git a/nvflare/app_common/app_defined/component_base.py b/nvflare/app_common/app_defined/component_base.py new file mode 100644 index 0000000000..1c4b32e285 --- /dev/null +++ b/nvflare/app_common/app_defined/component_base.py @@ -0,0 +1,87 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvflare.apis.fl_component import FLComponent + + +class ComponentBase(FLComponent): + def __init__(self): + FLComponent.__init__(self) + self.fl_ctx = None + + def debug(self, msg: str): + """Convenience method for logging an DEBUG message with contextual info + + Args: + msg: the message to be logged + + Returns: + + """ + self.log_debug(self.fl_ctx, msg) + + def info(self, msg: str): + """Convenience method for logging an INFO message with contextual info + + Args: + msg: the message to be logged + + Returns: + + """ + self.log_info(self.fl_ctx, msg) + + def error(self, msg: str): + """Convenience method for logging an ERROR message with contextual info + + Args: + msg: the message to be logged + + Returns: + + """ + self.log_error(self.fl_ctx, msg) + + def warning(self, msg: str): + """Convenience method for logging a WARNING message with contextual info + + Args: + msg: the message to be logged + + Returns: + + """ + self.log_warning(self.fl_ctx, msg) + + def exception(self, msg: str): + """Convenience method for logging an EXCEPTION message with contextual info + + Args: + msg: the message to be logged + + Returns: + + """ + self.log_exception(self.fl_ctx, msg) + + def critical(self, msg: str): + """Convenience method for logging a CRITICAL message with contextual info + + Args: + msg: the message to be logged + + Returns: + + """ + self.log_critical(self.fl_ctx, msg) diff --git a/nvflare/app_common/app_defined/model_persistor.py b/nvflare/app_common/app_defined/model_persistor.py new file mode 100644 index 0000000000..62e747ab35 --- /dev/null +++ b/nvflare/app_common/app_defined/model_persistor.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any + +from nvflare.apis.fl_context import FLContext +from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, make_model_learnable +from nvflare.app_common.abstract.model_persistor import ModelPersistor + +from .component_base import ComponentBase + + +class AppDefinedModelPersistor(ModelPersistor, ComponentBase, ABC): + def __init__(self): + ModelPersistor.__init__(self) + ComponentBase.__init__(self) + + @abstractmethod + def read_model(self) -> Any: + """Load model object. 
+ + Returns: a model object + """ + pass + + @abstractmethod + def write_model(self, model_obj: Any): + """Save the model object + + Args: + model_obj: the model object to be saved + + Returns: None + + """ + pass + + def load_model(self, fl_ctx: FLContext) -> ModelLearnable: + self.fl_ctx = fl_ctx + model = self.read_model() + return make_model_learnable(weights=model, meta_props={}) + + def save_model(self, learnable: ModelLearnable, fl_ctx: FLContext): + self.fl_ctx = fl_ctx + self.write_model(learnable.get(ModelLearnableKey.WEIGHTS)) diff --git a/nvflare/app_common/app_defined/shareable_generator.py b/nvflare/app_common/app_defined/shareable_generator.py new file mode 100644 index 0000000000..e3d4d2907a --- /dev/null +++ b/nvflare/app_common/app_defined/shareable_generator.py @@ -0,0 +1,94 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any + +from nvflare.apis.dxo import DXO, DataKind, from_shareable +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.app_common.abstract.learnable import Learnable +from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, make_model_learnable +from nvflare.app_common.abstract.shareable_generator import ShareableGenerator +from nvflare.app_common.app_constant import AppConstants + +from .component_base import ComponentBase + + +class AppDefinedShareableGenerator(ShareableGenerator, ComponentBase, ABC): + def __init__(self): + ShareableGenerator.__init__(self) + ComponentBase.__init__(self) + self.current_round = None + + @abstractmethod + def model_to_trainable(self, model_obj: Any) -> (Any, dict): + """Convert the model weights and meta to a format that can be sent to clients to do training + + Args: + model_obj: model object + + Returns: a tuple of (weights, meta) + + The returned weights and meta will be for training and serializable + """ + pass + + @abstractmethod + def apply_weights_to_model(self, model_obj: Any, weights: Any, meta: dict) -> Any: + """Apply trained weights and meta to the base model + + Args: + model_obj: base model object that weights will be applied to + weights: trained weights + meta: trained meta + + Returns: the updated model object + + """ + pass + + def learnable_to_shareable(self, learnable: Learnable, fl_ctx: FLContext) -> Shareable: + self.fl_ctx = fl_ctx + self.current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND) + self.debug(f"{learnable=}") + base_model_obj = learnable.get(ModelLearnableKey.WEIGHTS) + trainable_weights, trainable_meta = self.model_to_trainable(base_model_obj) + self.debug(f"trainable weights: {trainable_weights}") + dxo = DXO( + data_kind=DataKind.APP_DEFINED, + data=trainable_weights, + meta=trainable_meta, + ) + self.debug(f"learnable_to_shareable: {dxo.data}") + return dxo.to_shareable() + + def shareable_to_learnable(self, shareable: Shareable, fl_ctx: FLContext) -> Learnable: + self.fl_ctx = 
fl_ctx + self.current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND) + base_model_learnable = fl_ctx.get_prop(AppConstants.GLOBAL_MODEL) + + if not base_model_learnable: + self.system_panic(reason="No global base model!", fl_ctx=fl_ctx) + return base_model_learnable + + if not isinstance(base_model_learnable, ModelLearnable): + raise ValueError(f"expect global model to be ModelLearnable but got {type(base_model_learnable)}") + base_model_obj = base_model_learnable.get(ModelLearnableKey.WEIGHTS) + + dxo = from_shareable(shareable) + trained_weights = dxo.data + trained_meta = dxo.meta + model_obj = self.apply_weights_to_model(model_obj=base_model_obj, weights=trained_weights, meta=trained_meta) + return make_model_learnable(model_obj, {}) diff --git a/nvflare/app_common/decomposers/common_decomposers.py b/nvflare/app_common/decomposers/common_decomposers.py index b30a8a13ec..1b9f7514f9 100644 --- a/nvflare/app_common/decomposers/common_decomposers.py +++ b/nvflare/app_common/decomposers/common_decomposers.py @@ -14,19 +14,15 @@ """Decomposers for types from app_common and Machine Learning libraries.""" import os -from abc import ABC -from io import BytesIO from typing import Any -import numpy as np - from nvflare.app_common.abstract.fl_model import FLModel from nvflare.app_common.abstract.learnable import Learnable from nvflare.app_common.abstract.model import ModelLearnable from nvflare.app_common.widgets.event_recorder import _CtxPropReq, _EventReq, _EventStats from nvflare.fuel.utils import fobs from nvflare.fuel.utils.fobs.datum import DatumManager -from nvflare.fuel.utils.fobs.decomposer import Decomposer, DictDecomposer, Externalizer, Internalizer +from nvflare.fuel.utils.fobs.decomposer import DictDecomposer, Externalizer, Internalizer class FLModelDecomposer(fobs.Decomposer): @@ -60,50 +56,6 @@ def recompose(self, data: tuple, manager: DatumManager = None) -> FLModel: ) -class NumpyScalarDecomposer(fobs.Decomposer, ABC): - """Decomposer base class for all numpy types with item method.""" - - def decompose(self, target: Any, manager: DatumManager = None) -> Any: - return target.item() - - def recompose(self, data: Any, manager: DatumManager = None) -> np.ndarray: - return self.supported_type()(data) - - -class Float64ScalarDecomposer(NumpyScalarDecomposer): - def supported_type(self): - return np.float64 - - -class Float32ScalarDecomposer(NumpyScalarDecomposer): - def supported_type(self): - return np.float32 - - -class Int64ScalarDecomposer(NumpyScalarDecomposer): - def supported_type(self): - return np.int64 - - -class Int32ScalarDecomposer(NumpyScalarDecomposer): - def supported_type(self): - return np.int32 - - -class NumpyArrayDecomposer(Decomposer): - def supported_type(self): - return np.ndarray - - def decompose(self, target: np.ndarray, manager: DatumManager = None) -> Any: - stream = BytesIO() - np.save(stream, target, allow_pickle=False) - return stream.getvalue() - - def recompose(self, data: Any, manager: DatumManager = None) -> np.ndarray: - stream = BytesIO(data) - return np.load(stream, allow_pickle=False) - - def register(): if register.registered: return diff --git a/nvflare/app_common/decomposers/numpy_decomposers.py b/nvflare/app_common/decomposers/numpy_decomposers.py new file mode 100644 index 0000000000..1c0e5092ff --- /dev/null +++ b/nvflare/app_common/decomposers/numpy_decomposers.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Decomposers for types from app_common and Machine Learning libraries.""" +import os +from abc import ABC +from io import BytesIO +from typing import Any + +import numpy as np + +from nvflare.fuel.utils import fobs +from nvflare.fuel.utils.fobs.datum import DatumManager + + +class NumpyScalarDecomposer(fobs.Decomposer, ABC): + """Decomposer base class for all numpy types with item method.""" + + def decompose(self, target: Any, manager: DatumManager = None) -> Any: + return target.item() + + def recompose(self, data: Any, manager: DatumManager = None) -> np.ndarray: + return self.supported_type()(data) + + +class Float64ScalarDecomposer(NumpyScalarDecomposer): + def supported_type(self): + return np.float64 + + +class Float32ScalarDecomposer(NumpyScalarDecomposer): + def supported_type(self): + return np.float32 + + +class Int64ScalarDecomposer(NumpyScalarDecomposer): + def supported_type(self): + return np.int64 + + +class Int32ScalarDecomposer(NumpyScalarDecomposer): + def supported_type(self): + return np.int32 + + +class NumpyArrayDecomposer(fobs.Decomposer): + def supported_type(self): + return np.ndarray + + def decompose(self, target: np.ndarray, manager: DatumManager = None) -> Any: + stream = BytesIO() + np.save(stream, target, allow_pickle=False) + return stream.getvalue() + + def recompose(self, data: Any, manager: DatumManager = None) -> np.ndarray: + stream = BytesIO(data) + return np.load(stream, allow_pickle=False) + + +def register(): + if register.registered: + return + + fobs.register_folder(os.path.dirname(__file__), __package__) + + register.registered = True + + +register.registered = False diff --git a/nvflare/app_common/executors/ipc_exchanger.py b/nvflare/app_common/executors/ipc_exchanger.py new file mode 100644 index 0000000000..6389ec448e --- /dev/null +++ b/nvflare/app_common/executors/ipc_exchanger.py @@ -0,0 +1,425 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import threading +import time +from typing import Union + +from nvflare.apis.dxo import DXO, DataKind, from_shareable +from nvflare.apis.event_type import EventType +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import ReturnCode, Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.app_common.app_constant import AppConstants +from nvflare.client.ipc import defs +from nvflare.fuel.f3.cellnet.cell import Cell, Message, MessageHeaderKey +from nvflare.fuel.f3.cellnet.cell import ReturnCode as CellReturnCode +from nvflare.fuel.f3.cellnet.utils import make_reply as make_cell_reply +from nvflare.security.logging import secure_format_traceback + +_SHORT_SLEEP_TIME = 0.2 + + +class _TaskContext: + def __init__(self, task_name: str, task_id: str, fl_ctx: FLContext): + self.task_id = task_id + self.task_name = task_name + self.fl_ctx = fl_ctx + self.send_rc = None + self.result_rc = None + self.result_error = None + self.result = None + self.result_received_time = None + self.result_waiter = threading.Event() + + def __str__(self): + return f"'{self.task_name} {self.task_id}'" + + +class IPCExchanger(Executor): + def __init__( + self, + send_task_timeout=5.0, + resend_task_interval=2.0, + agent_connection_timeout=60.0, + agent_heartbeat_timeout=None, + agent_heartbeat_interval=5.0, + agent_ack_timeout=5.0, + agent_id=None, + ): + """Constructor of IPCExchanger + + Args: + send_task_timeout: when sending task to Agent, how long to wait for response + resend_task_interval: when failed to send task to agent, how often to resend + agent_heartbeat_timeout: time allowed to miss heartbeat ack from agent before stopping + agent_connection_timeout: time allowed to miss heartbeat ack from agent for considering it disconnected + agent_heartbeat_interval: how often to send heartbeats to the agent + agent_ack_timeout: how long to wait for agent ack (for heartbeat and bye messages) + agent_id: the unique ID of the agent. If not specified, will get it from job's meta + """ + Executor.__init__(self) + self.flare_agent_fqcn = None + self.agent_ack_timeout = agent_ack_timeout + self.agent_heartbeat_interval = agent_heartbeat_interval + self.agent_heartbeat_timeout = agent_heartbeat_timeout + self.agent_connection_timeout = agent_connection_timeout + self.send_task_timeout = send_task_timeout + self.resend_task_interval = resend_task_interval + self.agent_id = agent_id + self.last_agent_ack_time = time.time() + self.engine = None + self.cell = None + self.is_done = False + self.is_connected = False + self.task_ctx = None + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self.engine = fl_ctx.get_engine() + self.cell = self.engine.get_cell() + + self.cell.register_request_cb( + channel=defs.CHANNEL, + topic=defs.TOPIC_SUBMIT_RESULT, + cb=self._receive_result, + ) + + # get meta + if not self.agent_id: + agent_id = None + meta = fl_ctx.get_prop(FLContextKey.JOB_META) + if isinstance(meta, dict): + agent_id = meta.get(defs.JOB_META_KEY_AGENT_ID) + + if not agent_id: + self.system_panic(reason=f"missing {defs.JOB_META_KEY_AGENT_ID} from job meta", fl_ctx=fl_ctx) + return + + if not isinstance(agent_id, str): + self.system_panic( + reason=f"invalid {defs.JOB_META_KEY_AGENT_ID} from job meta: {agent_id}. 
" + f"must be str but got {type(agent_id)}", + fl_ctx=fl_ctx, + ) + return + + self.agent_id = agent_id + + client_name = fl_ctx.get_identity_name() + self.flare_agent_fqcn = defs.agent_site_fqcn(client_name, self.agent_id) + self.log_info(fl_ctx, f"Flare Agent FQCN: {self.flare_agent_fqcn}") + t = threading.Thread(target=self._monitor, daemon=True) + t.start() + elif event_type == EventType.END_RUN: + self.is_done = True + self._say_goodbye() + + def _say_goodbye(self): + # say goodbye to agent + self.logger.info(f"job done - say goodbye to {self.flare_agent_fqcn}") + reply = self.cell.send_request( + channel=defs.CHANNEL, + topic=defs.TOPIC_BYE, + target=self.flare_agent_fqcn, + request=Message(), + optional=True, + timeout=self.agent_ack_timeout, + ) + if reply: + rc = reply.get_header(MessageHeaderKey.RETURN_CODE) + if rc != CellReturnCode.OK: + self.logger.warning(f"return code from agent {self.flare_agent_fqcn} for bye: {rc}") + + def _monitor(self): + # try to connect the flare agent + self.logger.info(f"waiting for flare agent {self.flare_agent_fqcn} ...") + assert isinstance(self.cell, Cell) + + last_hb_time = 0 + while True: + if self.is_done: + return + + if time.time() - last_hb_time > self.agent_heartbeat_interval: + reply = self.cell.send_request( + channel=defs.CHANNEL, + topic=defs.TOPIC_HEARTBEAT, + target=self.flare_agent_fqcn, + request=Message(), + timeout=self.agent_ack_timeout, + optional=True, + ) + last_hb_time = time.time() + rc = reply.get_header(MessageHeaderKey.RETURN_CODE) + if rc == CellReturnCode.OK: + self.last_agent_ack_time = last_hb_time + if not self.is_connected: + self.logger.info(f"agent {self.flare_agent_fqcn} connected") + self.is_connected = True + else: + since_last_ack = last_hb_time - self.last_agent_ack_time + if since_last_ack > self.agent_connection_timeout: + if self.is_connected: + self.logger.warning( + f"agent {self.flare_agent_fqcn} disconnected: " + f"no heartbeat for {self.agent_connection_timeout} secs" + ) + self.is_connected = False + + if self.agent_heartbeat_timeout and since_last_ack > self.agent_heartbeat_timeout: + self.is_done = True + with self.engine.new_context() as fl_ctx: + self.system_panic( + f"agent {self.flare_agent_fqcn} is dead: " + f"no heartbeat for {self.agent_heartbeat_timeout} secs", + fl_ctx=fl_ctx, + ) + return + + # sleep only small amount of time, so we can check other conditions frequently + time.sleep(_SHORT_SLEEP_TIME) + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + task_id = shareable.get_header(key=FLContextKey.TASK_ID) + current_task = self.task_ctx + if current_task: + # still working on previous task! 
+ self.log_error(fl_ctx, f"got new task {task_name=} {task_id=} while still working on {current_task}") + return make_reply(ReturnCode.BAD_REQUEST_DATA) + + # wait for flare agent + while True: + if self.is_done or abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + + if self.is_connected: + break + else: + time.sleep(_SHORT_SLEEP_TIME) + + self.task_ctx = _TaskContext(task_name, task_id, fl_ctx) + result = self._do_execute(task_name, shareable, fl_ctx, abort_signal) + self.task_ctx = None + return result + + def _send_task(self, task_ctx: _TaskContext, msg, abort_signal): + # keep sending until done + fl_ctx = task_ctx.fl_ctx + task_name = task_ctx.task_name + task_id = task_ctx.task_id + task_ctx.send_rc = ReturnCode.OK + last_send_time = 0 + + while True: + if self.is_done or abort_signal.triggered: + self.log_info(fl_ctx, "task aborted - ask agent to abort the task") + + # it's possible that the agent may have already received the task + # we ask it to abort the task. + self._ask_agent_to_abort_task(task_name, task_id) + task_ctx.send_rc = ReturnCode.TASK_ABORTED + return + + if task_ctx.result_received_time: + # the result has been received + # this could happen only when we thought the previous send didn't succeed, but it actually did! + return + + if self.is_connected and time.time() - last_send_time > self.resend_task_interval: + self.log_info(fl_ctx, f"try to send task to {self.flare_agent_fqcn}") + start = time.time() + reply = self.cell.send_request( + channel=defs.CHANNEL, + topic=defs.TOPIC_GET_TASK, + request=msg, + target=self.flare_agent_fqcn, + timeout=self.send_task_timeout, + ) + last_send_time = time.time() + rc = reply.get_header(MessageHeaderKey.RETURN_CODE) + if rc == CellReturnCode.OK: + self.log_info(fl_ctx, f"Sent task to {self.flare_agent_fqcn} in {time.time() - start} secs") + return + elif rc == CellReturnCode.INVALID_REQUEST: + self.log_error(fl_ctx, f"Task rejected by {self.flare_agent_fqcn}: {rc}") + task_ctx.send_rc = ReturnCode.BAD_REQUEST_DATA + return + else: + self.log_error( + fl_ctx, + f"Failed to send task to {self.flare_agent_fqcn}: {rc}. 
" + "Will retry in {self.resend_task_interval} secs", + ) + time.sleep(_SHORT_SLEEP_TIME) + + def _do_execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + try: + dxo = from_shareable(shareable) + except: + self.log_error(fl_ctx, f"Unable to extract dxo from shareable: {secure_format_traceback()}") + return make_reply(ReturnCode.BAD_TASK_DATA) + + # send to flare agent + is_app_defined = False + task_ctx = self.task_ctx + task_id = task_ctx.task_id + data = dxo.data + if dxo.data_kind == DataKind.APP_DEFINED: + is_app_defined = True + if not data: + data = {} + meta = dxo.meta + if not meta: + meta = {} + + current_round = shareable.get_header(AppConstants.CURRENT_ROUND, None) + total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS, None) + + meta[defs.MetaKey.DATA_KIND] = dxo.data_kind + if current_round is not None: + meta[defs.MetaKey.CURRENT_ROUND] = current_round + if total_rounds is not None: + meta[defs.MetaKey.TOTAL_ROUND] = total_rounds + + msg = Message( + headers={ + defs.MsgHeader.TASK_ID: task_id, + defs.MsgHeader.TASK_NAME: task_name, + }, + payload={defs.PayloadKey.DATA: data, defs.PayloadKey.META: meta}, + ) + + # keep sending until done + self._send_task(task_ctx, msg, abort_signal) + if task_ctx.send_rc != ReturnCode.OK: + # send_task failed + return make_reply(task_ctx.send_rc) + + # wait for result + self.log_info(fl_ctx, f"Waiting for result from {self.flare_agent_fqcn}") + while True: + if task_ctx.result_waiter.wait(timeout=_SHORT_SLEEP_TIME): + # result available + break + else: + # timed out - check other conditions + if self.is_done or abort_signal.triggered: + self.log_info(fl_ctx, "task is aborted") + + # notify the agent + self._ask_agent_to_abort_task(task_name, task_id) + self.task_ctx = None + return make_reply(ReturnCode.TASK_ABORTED) + + # convert the result + if task_ctx.result_rc not in [defs.RC.OK, defs.RC.EARLY_TERMINATION]: + return make_reply(task_ctx.result_rc) + + result = task_ctx.result + meta = result.get(defs.PayloadKey.META) + + data = result.get(defs.PayloadKey.DATA) + if is_app_defined: + data_kind = DataKind.APP_DEFINED + else: + data_kind = meta.get(defs.MetaKey.DATA_KIND, DataKind.WEIGHTS) + + dxo = DXO( + data_kind=data_kind, + data=data, + meta=meta, + ) + s = dxo.to_shareable() + s.set_return_code(task_ctx.result_rc) + return s + + def _ask_agent_to_abort_task(self, task_name, task_id): + msg = Message( + headers={ + defs.MsgHeader.TASK_ID: task_id, + defs.MsgHeader.TASK_NAME: task_name, + } + ) + + self.cell.fire_and_forget( + channel=defs.CHANNEL, + topic=defs.TOPIC_ABORT, + message=msg, + targets=[self.flare_agent_fqcn], + optional=True, + ) + + @staticmethod + def _finish_result(task_ctx: _TaskContext, result_rc="", result=None, result_is_valid=True): + task_ctx.result_rc = result_rc + task_ctx.result = result + task_ctx.result_received_time = time.time() + task_ctx.result_waiter.set() + if result_is_valid: + return make_cell_reply(CellReturnCode.OK) + else: + return make_cell_reply(CellReturnCode.INVALID_REQUEST) + + def _receive_result(self, request: Message) -> Union[None, Message]: + sender = request.get_header(MessageHeaderKey.ORIGIN) + task_id = request.get_header(defs.MsgHeader.TASK_ID) + + # When the agent is restarted, it sends a result to us, in case we are waiting for the result + # of the current task. In this case, the task_id is empty. 
+ task_ctx = self.task_ctx + if not task_ctx: + # we are not waiting for any result + if not task_id: + # this was sent by the agent when it's started or restarted - just ignore + return make_cell_reply(CellReturnCode.OK) + + self.logger.error(f"received result from {sender} for task {task_id} while not waiting for result!") + return make_cell_reply(CellReturnCode.INVALID_REQUEST) + + # the agent could send us valid result after restarted + fl_ctx = task_ctx.fl_ctx + if task_id and task_id != task_ctx.task_id: + self.log_error(fl_ctx, f"received task id {task_id} != expected {task_ctx.task_id}") + return make_cell_reply(CellReturnCode.INVALID_REQUEST) + + if task_ctx.result_received_time: + # already received - this is a dup + self.log_info(fl_ctx, f"received duplicate result from {sender}") + return make_cell_reply(CellReturnCode.OK) + + payload = request.payload + if not isinstance(payload, dict): + self.log_error(fl_ctx, f"bad result from {sender}: expect dict but got {type(payload)}") + return self._finish_result(task_ctx, result_is_valid=False, result_rc=ReturnCode.EXECUTION_EXCEPTION) + + data = payload.get(defs.PayloadKey.DATA) + if data is None: + self.log_error(fl_ctx, f"bad result from {sender}: missing {defs.PayloadKey.DATA}") + return self._finish_result(task_ctx, result_is_valid=False, result_rc=ReturnCode.EXECUTION_EXCEPTION) + + meta = payload.get(defs.PayloadKey.META) + if meta is None: + self.log_error(fl_ctx, f"bad result from {sender}: missing {defs.PayloadKey.META}") + return self._finish_result(task_ctx, result_is_valid=False, result_rc=ReturnCode.EXECUTION_EXCEPTION) + + self.log_info(fl_ctx, f"received result from {sender}") + return self._finish_result( + task_ctx, + result_is_valid=True, + result_rc=request.get_header(defs.MsgHeader.RC, defs.RC.OK), + result=payload, + ) diff --git a/nvflare/client/__init__.py b/nvflare/client/__init__.py index 0e0583816a..8d66896223 100644 --- a/nvflare/client/__init__.py +++ b/nvflare/client/__init__.py @@ -33,3 +33,4 @@ from .api import system_info as system_info from .decorator import evaluate as evaluate from .decorator import train as train +from .ipc.ipc_agent import IPCAgent diff --git a/nvflare/client/ipc/__init__.py b/nvflare/client/ipc/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/nvflare/client/ipc/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/client/ipc/defs.py b/nvflare/client/ipc/defs.py new file mode 100644 index 0000000000..17b5c4234e --- /dev/null +++ b/nvflare/client/ipc/defs.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod + +from nvflare.apis.fl_constant import ReturnCode as RC +from nvflare.fuel.f3.cellnet.fqcn import FQCN + +CHANNEL = "flare_agent" + +TOPIC_GET_TASK = "get_task" +TOPIC_SUBMIT_RESULT = "submit_result" +TOPIC_HEARTBEAT = "heartbeat" +TOPIC_HELLO = "hello" +TOPIC_BYE = "bye" +TOPIC_ABORT = "abort" + +JOB_META_KEY_AGENT_ID = "agent_id" + + +class MsgHeader: + + TASK_ID = "task_id" + TASK_NAME = "task_name" + RC = "rc" + + +class PayloadKey: + DATA = "data" + META = "meta" + + +class MetaKey: + CURRENT_ROUND = "current_round" + TOTAL_ROUND = "total_round" + DATA_KIND = "data_kind" + NUM_STEPS_CURRENT_ROUND = "NUM_STEPS_CURRENT_ROUND" + PROCESSED_ALGORITHM = "PROCESSED_ALGORITHM" + PROCESSED_KEYS = "PROCESSED_KEYS" + INITIAL_METRICS = "initial_metrics" + FILTER_HISTORY = "filter_history" + + +class Task: + + NEW = 0 + FETCHED = 1 + PROCESSED = 2 + + def __init__(self, task_name: str, task_id: str, meta: dict, data): + self.task_name = task_name + self.task_id = task_id + self.meta = meta + self.data = data + self.status = Task.NEW + self.last_send_result_time = None + self.aborted = False + self.already_received = False + + def __str__(self): + return f"'{self.task_name} {self.task_id}'" + + +class TaskResult: + def __init__(self, meta: dict, data, return_code=RC.OK): + if not meta: + meta = {} + + if not isinstance(meta, dict): + raise TypeError(f"meta must be dict but got {type(meta)}") + + if not data: + data = {} + + if not isinstance(return_code, str): + raise TypeError(f"return_code must be str but got {type(return_code)}") + + self.return_code = return_code + self.meta = meta + self.data = data + + +class AgentClosed(Exception): + pass + + +class CallStateError(Exception): + pass + + +def agent_site_fqcn(site_name: str, agent_id: str): + # add the "-" prefix to the agent_id to make a child of the site + # this prefix will make the agent site's FQCN < the CJ's FQCN + # this is necessary to enable ad-hoc connections between CJ and agent, where CJ listens + # with ad-hoc connection, the cell with greater FQCN listens. + return FQCN.join([site_name, f"-{agent_id}"]) diff --git a/nvflare/client/ipc/ipc_agent.py b/nvflare/client/ipc/ipc_agent.py new file mode 100644 index 0000000000..e7b104cc96 --- /dev/null +++ b/nvflare/client/ipc/ipc_agent.py @@ -0,0 +1,448 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import threading +import time +import traceback +from typing import Union + +from nvflare.app_common.decomposers import numpy_decomposers +from nvflare.client.ipc import defs +from nvflare.fuel.f3.cellnet.cell import Cell, Message +from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode +from nvflare.fuel.f3.cellnet.net_agent import NetAgent +from nvflare.fuel.f3.cellnet.utils import make_reply +from nvflare.fuel.f3.drivers.driver_params import DriverParams +from nvflare.fuel.utils.config_service import ConfigService + +_SSL_ROOT_CERT = "rootCA.pem" +_SHORT_SLEEP_TIME = 0.2 + + +class IPCAgent: + def __init__( + self, + flare_site_url: str, + flare_site_name: str, + agent_id: str, + workspace_dir: str, + secure_mode=False, + submit_result_timeout=30.0, + flare_site_connection_timeout=60.0, + flare_site_heartbeat_timeout=None, + resend_result_interval=2.0, + ): + """Constructor of Flare Agent. The agent is responsible for communicating with the Flare Client Job cell (CJ) + to get task and to submit task result. + + Args: + flare_site_url: the URL to the client parent cell (CP) + flare_site_name: the CJ's site name (client name) + agent_id: the unique ID of the agent + workspace_dir: directory that contains startup folder and comm_config.json + secure_mode: whether the connection is in secure mode or not + submit_result_timeout: when submitting task result, how long to wait for response from the CJ + flare_site_heartbeat_timeout: time for missing heartbeats from CJ before considering it dead + flare_site_connection_timeout: time for missing heartbeats from CJ before considering it disconnected + """ + ConfigService.initialize(section_files={}, config_path=[workspace_dir]) + + self.logger = logging.getLogger(self.__class__.__name__) + self.cell_name = defs.agent_site_fqcn(flare_site_name, agent_id) + self.workspace_dir = workspace_dir + self.secure_mode = secure_mode + self.flare_site_url = flare_site_url + self.submit_result_timeout = submit_result_timeout + self.flare_site_heartbeat_timeout = flare_site_heartbeat_timeout + self.flare_site_connection_timeout = flare_site_connection_timeout + self.resend_result_interval = resend_result_interval + self.num_results_submitted = 0 + self.current_task = None + self.pending_task = None + self.task_lock = threading.Lock() + self.last_msg_time = time.time() # last time to get msg from flare site + self.peer_fqcn = None + self.is_done = False + self.is_started = False # has the agent been started? + self.is_stopped = False # has the agent been stopped? + self.is_connected = False # is the agent connected to the flare site? 
+ self.credentials = {} # security credentials for secure connection + + if secure_mode: + root_cert_path = ConfigService.find_file(_SSL_ROOT_CERT) + if not root_cert_path: + raise ValueError(f"cannot find {_SSL_ROOT_CERT} from config path {workspace_dir}") + + self.credentials = { + DriverParams.CA_CERT.value: root_cert_path, + } + + self.cell = Cell( + fqcn=self.cell_name, + root_url="", + parent_url=self.flare_site_url, + secure=self.secure_mode, + credentials=self.credentials, + create_internal_listener=False, + ) + self.net_agent = NetAgent(self.cell) + + self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_GET_TASK, cb=self._receive_task) + self.logger.info(f"registered task CB for {defs.CHANNEL} {defs.TOPIC_GET_TASK}") + self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_HEARTBEAT, cb=self._handle_heartbeat) + self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_BYE, cb=self._handle_bye) + self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_ABORT, cb=self._handle_abort_task) + self.cell.add_incoming_request_filter( + channel=defs.CHANNEL, + topic="*", + cb=self._msg_received, + ) + numpy_decomposers.register() + + def start(self): + """Start the agent. This method must be called to enable CJ/Agent communication. + + Returns: None + + """ + if self.is_started: + self.logger.warning("the agent is already started") + return + + if self.is_stopped: + raise defs.CallStateError("cannot start the agent since it is already stopped") + + self.is_started = True + self.logger.info(f"starting agent {self.cell_name} ...") + self.cell.start() + t = threading.Thread(target=self._monitor, daemon=True) + t.start() + + def stop(self): + """Stop the agent. After this is called, there will be no more communications between CJ and agent. 
+ + Returns: None + + """ + if not self.is_started: + self.logger.warning("cannot stop the agent since it is not started") + return + + if self.is_stopped: + self.logger.warning("agent is already stopped") + return + + self.is_stopped = True + self.cell.stop() + self.net_agent.close() + + def _monitor(self): + while True: + since_last_msg = time.time() - self.last_msg_time + if since_last_msg > self.flare_site_connection_timeout: + if self.is_connected: + self.logger.error( + "flare site disconnected since no message received " + f"for {self.flare_site_connection_timeout} seconds" + ) + self.is_connected = False + + if self.flare_site_heartbeat_timeout and since_last_msg > self.flare_site_heartbeat_timeout: + self.logger.error( + f"flare site is dead since no message received for {self.flare_site_heartbeat_timeout} seconds" + ) + self.is_done = True + return + + time.sleep(_SHORT_SLEEP_TIME) + + def _handle_bye(self, request: Message) -> Union[None, Message]: + peer = request.get_header(MessageHeaderKey.ORIGIN) + self.logger.info(f"got goodbye from {peer}") + self.is_done = True + return make_reply(ReturnCode.OK) + + def _msg_received(self, request: Message): + peer = request.get_header(MessageHeaderKey.ORIGIN) + if self.peer_fqcn and self.peer_fqcn != peer: + # this could happen when a new job is started for the same training + self.logger.warning(f"got peer FQCN '{peer}' while expecting '{self.peer_fqcn}'") + + self.peer_fqcn = peer + self.last_msg_time = time.time() + if not self.is_connected: + self.is_connected = True + self.logger.info(f"connected to flare site {peer}") + + def _handle_heartbeat(self, request: Message) -> Union[None, Message]: + peer = request.get_header(MessageHeaderKey.ORIGIN) + self.logger.debug(f"got heartbeat from {peer}") + return make_reply(ReturnCode.OK) + + def _handle_abort_task(self, request: Message) -> Union[None, Message]: + peer = request.get_header(MessageHeaderKey.ORIGIN) + task_id = request.get_header(defs.MsgHeader.TASK_ID) + task_name = request.get_header(defs.MsgHeader.TASK_NAME) + self.logger.warning(f"received from {peer} to abort {task_name=} {task_id=}") + with self.task_lock: + if self.current_task and task_id == self.current_task.task_id: + self.current_task.aborted = True + elif self.pending_task and task_id == self.pending_task.task_id: + self.pending_task = None + return make_reply(ReturnCode.OK) + + def _receive_task(self, request: Message) -> Union[None, Message]: + with self.task_lock: + return self._do_receive_task(request) + + def _create_task(self, request: Message): + peer = request.get_header(MessageHeaderKey.ORIGIN) + task_id = request.get_header(defs.MsgHeader.TASK_ID) + task_name = request.get_header(defs.MsgHeader.TASK_NAME) + self.logger.info(f"received task from {peer}: {task_name=} {task_id=}") + + task_data = request.payload + if not isinstance(task_data, dict): + self.logger.error(f"bad task data from {peer}: expect dict but got {type(task_data)}") + return None + + data = task_data.get(defs.PayloadKey.DATA) + if not data: + self.logger.error(f"bad task data from {peer}: missing {defs.PayloadKey.DATA}") + return None + + meta = task_data.get(defs.PayloadKey.META) + if not meta: + self.logger.error(f"bad task data from {peer}: missing {defs.PayloadKey.META}") + return None + + return defs.Task(task_name, task_id, meta, data) + + def _do_receive_task(self, request: Message) -> Union[None, Message]: + peer = request.get_header(MessageHeaderKey.ORIGIN) + task_id = request.get_header(defs.MsgHeader.TASK_ID) + task_name = 
request.get_header(defs.MsgHeader.TASK_NAME) + + # create a new task + new_task = self._create_task(request) + if not new_task: + return make_reply(ReturnCode.INVALID_REQUEST) + + if self.pending_task: + assert isinstance(self.pending_task, defs.Task) + if task_id == self.pending_task.task_id: + return make_reply(ReturnCode.OK) + else: + # this could happen when the CJ is restarted + self.logger.warning(f"got new task from {peer} while already having a pending task!") + + # replace the pending task + self.pending_task = new_task + return make_reply(ReturnCode.OK) + + current_task = self.current_task + if current_task: + assert isinstance(current_task, defs.Task) + if task_id == current_task.task_id: + self.logger.info(f"received duplicate task {task_id} from {peer}") + return make_reply(ReturnCode.OK) + + if current_task.last_send_result_time: + # we already tried to send result back + # assume that the flare site has received + # we set the flag so the sending process will end quickly + # in the meanwhile we ask flare site to retry later + current_task.already_received = True + else: + # error - one task at a time + self.logger.warning( + f"got task {task_name} {task_id} from {peer} " + f"while still working on {current_task.task_name} {current_task.task_id}" + ) + + # this could happen when CJ is restarted while we are processing current task + # we set the current_task to be aborted. App should check this flag frequently to abort processing + current_task.aborted = True + + # treat the new task as pending task - it will become current after the current_task is submitted + self.pending_task = new_task + return make_reply(ReturnCode.OK) + else: + # no current task + self.current_task = new_task + return make_reply(ReturnCode.OK) + + def get_task(self, timeout=None): + """Get a task from FLARE. This is a blocking call. + + If timeout is specified, this call is blocked only for the specified amount of time. + If timeout is not specified, this call is blocked forever until a task is received or agent is closed. + + Args: + timeout: amount of time to block + + Returns: None if no task is available during before timeout; or a Task object if task is available. + Raises: + AgentClosed exception if the agent is closed before timeout. + CallStateError exception if the call is not made properly. + + Note: the application must make the call only when it is just started or after a previous task's result + has been submitted. 
+ + """ + if timeout is not None: + if not isinstance(timeout, (int, float)): + raise TypeError(f"timeout must be (int, float) but got {type(timeout)}") + if timeout <= 0: + raise ValueError(f"timeout must > 0, but got {timeout}") + + start = time.time() + while True: + if self.is_done or self.is_stopped: + self.logger.info("no more tasks - agent closed") + raise defs.AgentClosed("flare agent is closed") + + with self.task_lock: + current_task = self.current_task + if current_task: + assert isinstance(current_task, defs.Task) + if current_task.aborted: + pass + elif current_task.status == defs.Task.NEW: + current_task.status = defs.Task.FETCHED + return current_task + else: + raise defs.CallStateError( + f"application called get_task while the current task is in status {current_task.status}" + ) + if timeout and time.time() - start > timeout: + # no task available before timeout + self.logger.info(f"get_task timeout after {timeout} seconds") + return None + time.sleep(_SHORT_SLEEP_TIME) + + def submit_result(self, result: defs.TaskResult) -> bool: + """Submit the result of the current task. + This is a blocking call. The agent will try to send the result to flare site until it is successfully sent or + the task is aborted or the agent is closed. + + Args: + result: result to be submitted + + Returns: whether the result is submitted successfully + Raises: the CallStateError exception if the submit_result call is not made properly. + + Notes: the application must only make this call after the received task is processed. The call can only be + made a single time regardless whether the submission is successful. + + """ + try: + result_submitted = self._do_submit_result(result) + except Exception as ex: + self.logger.error(f"exception encountered: {ex}") + result_submitted = False + + with self.task_lock: + self.current_task = None + if self.pending_task: + # a new task is waiting for the current task to finish + self.current_task = self.pending_task + self.pending_task = None + return result_submitted + + def _do_submit_result(self, result: defs.TaskResult) -> bool: + if not isinstance(result, defs.TaskResult): + raise TypeError(f"result must be TaskResult but got {type(result)}") + + with self.task_lock: + current_task = self.current_task + if current_task: + if current_task.aborted: + return False + if current_task.status != defs.Task.FETCHED: + raise defs.CallStateError( + f"submit_result is called while current task is in status {current_task.status}" + ) + current_task.status = defs.Task.PROCESSED + elif self.num_results_submitted > 0: + self.logger.error("submit_result is called but there is no current task!") + return False + else: + # if the agent is restarted, it may pick up from previous checkpoint and continue training. + # then it can send the result after finish training. 
+ pass + self.num_results_submitted += 1 + try: + return self._send_result(current_task, result) + except: + self.logger.error(f"exception submitting result to {current_task.sender}") + traceback.print_exc() + return False + + def _send_result(self, current_task: defs.Task, result: defs.TaskResult): + meta = result.meta + rc = result.return_code + data = result.data + + msg = Message( + headers={ + defs.MsgHeader.TASK_NAME: current_task.task_name if current_task else "", + defs.MsgHeader.TASK_ID: current_task.task_id if current_task else "", + defs.MsgHeader.RC: rc, + }, + payload={ + defs.PayloadKey.META: meta, + defs.PayloadKey.DATA: data, + }, + ) + + last_send_time = 0 + while True: + if self.is_done or self.is_stopped: + self.logger.error(f"quit submitting result for task {current_task} since agent is closed") + raise defs.AgentClosed("agent is stopped") + + if current_task and current_task.already_received: + if not current_task.last_send_result_time: + self.logger.warning(f"task {current_task} was marked already_received but has been sent!") + return True + + if current_task and current_task.aborted: + self.logger.error(f"quit submitting result for task {current_task} since it is aborted") + return False + + if self.is_connected and time.time() - last_send_time > self.resend_result_interval: + self.logger.info(f"sending result to {self.peer_fqcn} for task {current_task}") + if current_task: + current_task.last_send_result_time = time.time() + reply = self.cell.send_request( + channel=defs.CHANNEL, + topic=defs.TOPIC_SUBMIT_RESULT, + target=self.peer_fqcn, + request=msg, + timeout=self.submit_result_timeout, + ) + last_send_time = time.time() + if reply: + rc = reply.get_header(MessageHeaderKey.RETURN_CODE) + peer = reply.get_header(MessageHeaderKey.ORIGIN) + if rc == ReturnCode.OK: + return True + elif rc == ReturnCode.INVALID_REQUEST: + self.logger.error(f"received return code from {peer}: {rc}") + return False + else: + self.logger.info(f"failed to send to {self.peer_fqcn}: {rc} - will retry") + time.sleep(_SHORT_SLEEP_TIME) diff --git a/nvflare/fuel/f3/cellnet/core_cell.py b/nvflare/fuel/f3/cellnet/core_cell.py index 9ca89a85e6..048ce526b1 100644 --- a/nvflare/fuel/f3/cellnet/core_cell.py +++ b/nvflare/fuel/f3/cellnet/core_cell.py @@ -334,10 +334,13 @@ def __init__( self.secure = secure self.logger.debug(f"{self.my_info.fqcn}: max_msg_size={self.max_msg_size}") - if not root_url: - raise ValueError(f"{self.my_info.fqcn}: root_url not provided") + if not root_url and not parent_url: + raise ValueError(f"{self.my_info.fqcn}: neither root_url nor parent_url is provided") if self.my_info.is_root and self.my_info.is_on_server: + if not root_url: + raise ValueError(f"{self.my_info.fqcn}: root_url is required for server-side cells but not provided") + if isinstance(root_url, list): for url in root_url: if not _validate_url(url): @@ -346,7 +349,7 @@ def __init__( if not _validate_url(root_url): raise ValueError(f"{self.my_info.fqcn}: invalid Root URL '{root_url}'") root_url = [root_url] - else: + elif root_url: if isinstance(root_url, list): # multiple urls are available - randomly pick one root_url = random.choice(root_url) From 3372783870f27260222f251582242d89b876f64d Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Thu, 21 Mar 2024 16:37:00 -0400 Subject: [PATCH 2/3] removed unused import --- nvflare/client/ipc/defs.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/nvflare/client/ipc/defs.py b/nvflare/client/ipc/defs.py index 17b5c4234e..9d8cf3eeff 100644 --- 
a/nvflare/client/ipc/defs.py +++ b/nvflare/client/ipc/defs.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, abstractmethod - from nvflare.apis.fl_constant import ReturnCode as RC from nvflare.fuel.f3.cellnet.fqcn import FQCN From 3af8cbf35e422eb3fae56b6e5ccd812f795f0b36 Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Thu, 4 Apr 2024 17:38:53 -0400 Subject: [PATCH 3/3] address PR comments --- nvflare/app_common/app_defined/__init__.py | 2 +- .../app_common/app_defined/shareable_generator.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/nvflare/app_common/app_defined/__init__.py b/nvflare/app_common/app_defined/__init__.py index 4fc50543f1..d9155f923f 100644 --- a/nvflare/app_common/app_defined/__init__.py +++ b/nvflare/app_common/app_defined/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nvflare/app_common/app_defined/shareable_generator.py b/nvflare/app_common/app_defined/shareable_generator.py index e3d4d2907a..d4b461bfcf 100644 --- a/nvflare/app_common/app_defined/shareable_generator.py +++ b/nvflare/app_common/app_defined/shareable_generator.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -46,12 +46,12 @@ def model_to_trainable(self, model_obj: Any) -> (Any, dict): pass @abstractmethod - def apply_weights_to_model(self, model_obj: Any, weights: Any, meta: dict) -> Any: - """Apply trained weights and meta to the base model + def update_model(self, model_obj: Any, training_result: Any, meta: dict) -> Any: + """Update model with training result and meta Args: - model_obj: base model object that weights will be applied to - weights: trained weights + model_obj: base model object to be updated + training_result: training result to be applied to the model object meta: trained meta Returns: the updated model object @@ -88,7 +88,7 @@ def shareable_to_learnable(self, shareable: Shareable, fl_ctx: FLContext) -> Lea base_model_obj = base_model_learnable.get(ModelLearnableKey.WEIGHTS) dxo = from_shareable(shareable) - trained_weights = dxo.data + training_result = dxo.data trained_meta = dxo.meta - model_obj = self.apply_weights_to_model(model_obj=base_model_obj, weights=trained_weights, meta=trained_meta) + model_obj = self.update_model(model_obj=base_model_obj, training_result=training_result, meta=trained_meta) return make_model_learnable(model_obj, {})
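
For illustration, a minimal sketch (not part of the patch series) of how an external training process might drive the new IPCAgent from PATCH 1/3, based on its constructor and the get_task()/submit_result() docstrings above. The URL, site name, agent ID, and workspace path below are placeholders for deployment-specific values.

    from nvflare.client.ipc import defs
    from nvflare.client.ipc.ipc_agent import IPCAgent

    agent = IPCAgent(
        flare_site_url="grpc://localhost:8002",  # URL of the client parent cell (CP); placeholder
        flare_site_name="site-1",                # CJ's site (client) name; placeholder
        agent_id="demo_agent",                   # must match agent_id in the job meta
        workspace_dir="/tmp/agent_workspace",    # contains the startup folder and comm_config.json
        secure_mode=False,
    )

    agent.start()
    try:
        while True:
            task = agent.get_task()  # blocks until the CJ sends a task
            # task.data and task.meta carry the app-defined payload; do the real training here
            result = defs.TaskResult(meta={}, data=task.data, return_code=defs.RC.OK)
            agent.submit_result(result)
    except defs.AgentClosed:
        pass  # the CJ said goodbye or the agent was stopped
    finally:
        agent.stop()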