From 5d71c2c736888ebc59b77a1b189c756dfc496875 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 6 Sep 2024 19:26:33 -0400 Subject: [PATCH] Added the current-round info the fl_ctx for BaseModelController (#2916) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added the current-round info the fl_ctx for BaseModelController. * reformat. * codestyle fix. * Moved the self.set_fl_context(data) call to broadcast_model(). * Change broadcast_model() to must send a FLModel, not None. * Changed the BaseModelController broadcast_model data default value, and a warning message to debug. * refactoried. * Updated docstring. --------- Co-authored-by: Chester Chen <512707+chesterxgchen@users.noreply.github.com> Co-authored-by: Yuan-Ting Hsieh (謝沅廷) --- .../workflows/base_model_controller.py | 24 ++++++++++--------- research/fed-bpt/src/global_es.py | 2 +- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/nvflare/app_common/workflows/base_model_controller.py b/nvflare/app_common/workflows/base_model_controller.py index f40ff6f9a9..4c1b1333d0 100644 --- a/nvflare/app_common/workflows/base_model_controller.py +++ b/nvflare/app_common/workflows/base_model_controller.py @@ -100,8 +100,8 @@ def _build_shareable(self, data: FLModel = None) -> Shareable: def broadcast_model( self, + data, task_name: str = AppConstants.TASK_TRAIN, - data: FLModel = None, targets: Union[List[Client], List[str], None] = None, min_responses: int = None, timeout: int = 0, @@ -112,8 +112,8 @@ def broadcast_model( """Send a task with data to a list of targets. Args: + data: FLModel to be sent to clients. It must be a FLModel object. It will raise an exception if None. task_name (str, optional): name of the task. Defaults to "train". - data (FLModel, optional): FLModel to be sent to clients. If no data is given, send empty FLModel. targets (List[str], optional): the list of target client names or None (all clients). Defaults to None. min_responses (int, optional): the minimum number of responses expected. If None, must receive responses from all clients that the task has been sent to. Defaults to None. @@ -128,9 +128,9 @@ def broadcast_model( """ if not isinstance(task_name, str): - raise TypeError("task_name must be a string but got {}".format(type(task_name))) - if data and not isinstance(data, FLModel): - raise TypeError("data must be a FLModel or None but got {}".format(type(data))) + raise TypeError(f"task_name must be a string but got {type(task_name)}") + if not isinstance(data, FLModel): + raise TypeError(f"data must be a FLModel but got {type(data)}") if min_responses is None: min_responses = 0 # this is internally used by controller's broadcast to represent all targets check_non_negative_int("min_responses", min_responses) @@ -139,9 +139,7 @@ def broadcast_model( if not blocking and not isinstance(callback, Callable): raise TypeError("callback must be defined if blocking is False, but got {}".format(type(callback))) - if not data: - self.warning("data is None. Sending empty FLModel.") - data = FLModel(params_type=ParamsType.FULL, params={}) + self.set_fl_context(data) task = self._prepare_task(data=data, task_name=task_name, timeout=timeout, callback=callback) @@ -224,9 +222,6 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None: result_model.meta["props"] = client_task.task.props[AppConstants.META_DATA] result_model.meta["client_name"] = client_name - if result_model.current_round is not None: - self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, result_model.current_round, private=True, sticky=True) - self.event(AppEventType.BEFORE_CONTRIBUTION_ACCEPT) self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx) self.event(AppEventType.AFTER_CONTRIBUTION_ACCEPT) @@ -379,6 +374,13 @@ def sample_clients(self, num_clients: int = None) -> List[str]: return clients + def set_fl_context(self, data: FLModel): + """Set up the fl_ctx information based on the passed in FLModel data.""" + if data and data.current_round is not None: + self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, data.current_round, private=True, sticky=True) + else: + self.debug("The FLModel data does not contain the current_round information.") + def get_component(self, component_id: str): return self.engine.get_component(component_id) diff --git a/research/fed-bpt/src/global_es.py b/research/fed-bpt/src/global_es.py index 34c4e8ba3d..2002cd7b34 100644 --- a/research/fed-bpt/src/global_es.py +++ b/research/fed-bpt/src/global_es.py @@ -95,7 +95,7 @@ def run(self) -> None: clients = self.sample_clients(self.num_clients) - global_model = FLModel(params={"global_es": global_es}) + global_model = FLModel(params={"global_es": global_es}, current_round=self.current_round) results = self.send_model_and_wait(targets=clients, data=global_model) # get solutions from clients