From 5f7211d833cf6d8d627d74417a4e9d3c0b7a65c8 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 5 Sep 2024 14:38:31 -0400 Subject: [PATCH 1/8] Added the current-round info the fl_ctx for BaseModelController. --- .../app_common/workflows/base_model_controller.py | 12 +++++++++--- nvflare/app_common/workflows/model_controller.py | 2 ++ research/fed-bpt/src/global_es.py | 2 +- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/nvflare/app_common/workflows/base_model_controller.py b/nvflare/app_common/workflows/base_model_controller.py index f40ff6f9a9..388f74edda 100644 --- a/nvflare/app_common/workflows/base_model_controller.py +++ b/nvflare/app_common/workflows/base_model_controller.py @@ -224,9 +224,6 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None: result_model.meta["props"] = client_task.task.props[AppConstants.META_DATA] result_model.meta["client_name"] = client_name - if result_model.current_round is not None: - self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, result_model.current_round, private=True, sticky=True) - self.event(AppEventType.BEFORE_CONTRIBUTION_ACCEPT) self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx) self.event(AppEventType.AFTER_CONTRIBUTION_ACCEPT) @@ -379,6 +376,15 @@ def sample_clients(self, num_clients: int = None) -> List[str]: return clients + def set_fl_context(self, data: FLModel): + """Set up the fl_ctx information based on the passed in FLModel data. + + """ + if data and data.current_round is not None: + self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, data.current_round, private=True, sticky=True) + else: + self.warning(f"The FLModel data does not contain the current_round information.") + def get_component(self, component_id: str): return self.engine.get_component(component_id) diff --git a/nvflare/app_common/workflows/model_controller.py b/nvflare/app_common/workflows/model_controller.py index 61fb7b9739..6c25bb4248 100644 --- a/nvflare/app_common/workflows/model_controller.py +++ b/nvflare/app_common/workflows/model_controller.py @@ -63,6 +63,7 @@ def send_model_and_wait( Returns: List[FLModel] """ + self.set_fl_context(data) return super().broadcast_model( task_name=task_name, data=data, @@ -94,6 +95,7 @@ def send_model( Returns: None """ + self.set_fl_context(data) super().broadcast_model( task_name=task_name, data=data, diff --git a/research/fed-bpt/src/global_es.py b/research/fed-bpt/src/global_es.py index 34c4e8ba3d..2002cd7b34 100644 --- a/research/fed-bpt/src/global_es.py +++ b/research/fed-bpt/src/global_es.py @@ -95,7 +95,7 @@ def run(self) -> None: clients = self.sample_clients(self.num_clients) - global_model = FLModel(params={"global_es": global_es}) + global_model = FLModel(params={"global_es": global_es}, current_round=self.current_round) results = self.send_model_and_wait(targets=clients, data=global_model) # get solutions from clients From 990e7aef9ac6aa2f2d4267feeaa065906126fc44 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 5 Sep 2024 14:49:40 -0400 Subject: [PATCH 2/8] reformat. --- nvflare/app_common/workflows/base_model_controller.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/nvflare/app_common/workflows/base_model_controller.py b/nvflare/app_common/workflows/base_model_controller.py index 388f74edda..d7adcbedea 100644 --- a/nvflare/app_common/workflows/base_model_controller.py +++ b/nvflare/app_common/workflows/base_model_controller.py @@ -377,9 +377,7 @@ def sample_clients(self, num_clients: int = None) -> List[str]: return clients def set_fl_context(self, data: FLModel): - """Set up the fl_ctx information based on the passed in FLModel data. - - """ + """Set up the fl_ctx information based on the passed in FLModel data.""" if data and data.current_round is not None: self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, data.current_round, private=True, sticky=True) else: From 25343fb8fbfc1af660131e61b2ce520de0b92516 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 5 Sep 2024 14:59:04 -0400 Subject: [PATCH 3/8] codestyle fix. --- nvflare/app_common/workflows/base_model_controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nvflare/app_common/workflows/base_model_controller.py b/nvflare/app_common/workflows/base_model_controller.py index d7adcbedea..4093c5dcc1 100644 --- a/nvflare/app_common/workflows/base_model_controller.py +++ b/nvflare/app_common/workflows/base_model_controller.py @@ -381,7 +381,7 @@ def set_fl_context(self, data: FLModel): if data and data.current_round is not None: self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, data.current_round, private=True, sticky=True) else: - self.warning(f"The FLModel data does not contain the current_round information.") + self.warning("The FLModel data does not contain the current_round information.") def get_component(self, component_id: str): return self.engine.get_component(component_id) From 750bef367a5adbfbdf6d73ac526ce0d095f2fc0b Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 5 Sep 2024 21:52:37 -0400 Subject: [PATCH 4/8] Moved the self.set_fl_context(data) call to broadcast_model(). --- nvflare/app_common/workflows/base_model_controller.py | 9 ++++++--- nvflare/app_common/workflows/model_controller.py | 2 -- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/nvflare/app_common/workflows/base_model_controller.py b/nvflare/app_common/workflows/base_model_controller.py index 4093c5dcc1..6b2fb0f259 100644 --- a/nvflare/app_common/workflows/base_model_controller.py +++ b/nvflare/app_common/workflows/base_model_controller.py @@ -139,9 +139,12 @@ def broadcast_model( if not blocking and not isinstance(callback, Callable): raise TypeError("callback must be defined if blocking is False, but got {}".format(type(callback))) - if not data: - self.warning("data is None. Sending empty FLModel.") - data = FLModel(params_type=ParamsType.FULL, params={}) + # This will never happen. We don't need to have this logic to create a dummy FLModel object. + # if not data: + # self.warning("data is None. Sending empty FLModel.") + # data = FLModel(params_type=ParamsType.FULL, params={}) + + self.set_fl_context(data) task = self._prepare_task(data=data, task_name=task_name, timeout=timeout, callback=callback) diff --git a/nvflare/app_common/workflows/model_controller.py b/nvflare/app_common/workflows/model_controller.py index 6c25bb4248..61fb7b9739 100644 --- a/nvflare/app_common/workflows/model_controller.py +++ b/nvflare/app_common/workflows/model_controller.py @@ -63,7 +63,6 @@ def send_model_and_wait( Returns: List[FLModel] """ - self.set_fl_context(data) return super().broadcast_model( task_name=task_name, data=data, @@ -95,7 +94,6 @@ def send_model( Returns: None """ - self.set_fl_context(data) super().broadcast_model( task_name=task_name, data=data, From ff83bcf9183ab26f10a656c3a04fa8e0573e80a4 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 6 Sep 2024 09:43:05 -0400 Subject: [PATCH 5/8] Change broadcast_model() to must send a FLModel, not None. --- nvflare/app_common/workflows/base_model_controller.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/nvflare/app_common/workflows/base_model_controller.py b/nvflare/app_common/workflows/base_model_controller.py index 6b2fb0f259..f2314ff18e 100644 --- a/nvflare/app_common/workflows/base_model_controller.py +++ b/nvflare/app_common/workflows/base_model_controller.py @@ -129,8 +129,10 @@ def broadcast_model( if not isinstance(task_name, str): raise TypeError("task_name must be a string but got {}".format(type(task_name))) - if data and not isinstance(data, FLModel): - raise TypeError("data must be a FLModel or None but got {}".format(type(data))) + if not data: + raise TypeError("data must be a FLModel but got None") + if not isinstance(data, FLModel): + raise TypeError("data must be a FLModel but got {}".format(type(data))) if min_responses is None: min_responses = 0 # this is internally used by controller's broadcast to represent all targets check_non_negative_int("min_responses", min_responses) @@ -139,11 +141,6 @@ def broadcast_model( if not blocking and not isinstance(callback, Callable): raise TypeError("callback must be defined if blocking is False, but got {}".format(type(callback))) - # This will never happen. We don't need to have this logic to create a dummy FLModel object. - # if not data: - # self.warning("data is None. Sending empty FLModel.") - # data = FLModel(params_type=ParamsType.FULL, params={}) - self.set_fl_context(data) task = self._prepare_task(data=data, task_name=task_name, timeout=timeout, callback=callback) From b1d2ca7fa9ac9397fbd1037788db05bf3e55cf79 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 6 Sep 2024 17:09:35 -0400 Subject: [PATCH 6/8] Changed the BaseModelController broadcast_model data default value, and a warning message to debug. --- nvflare/app_common/workflows/base_model_controller.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nvflare/app_common/workflows/base_model_controller.py b/nvflare/app_common/workflows/base_model_controller.py index f2314ff18e..02d5f43e0d 100644 --- a/nvflare/app_common/workflows/base_model_controller.py +++ b/nvflare/app_common/workflows/base_model_controller.py @@ -100,8 +100,8 @@ def _build_shareable(self, data: FLModel = None) -> Shareable: def broadcast_model( self, + data, task_name: str = AppConstants.TASK_TRAIN, - data: FLModel = None, targets: Union[List[Client], List[str], None] = None, min_responses: int = None, timeout: int = 0, @@ -381,7 +381,7 @@ def set_fl_context(self, data: FLModel): if data and data.current_round is not None: self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, data.current_round, private=True, sticky=True) else: - self.warning("The FLModel data does not contain the current_round information.") + self.debug("The FLModel data does not contain the current_round information.") def get_component(self, component_id: str): return self.engine.get_component(component_id) From 95f75961ccf063ba23518f4d48790abe646e3c07 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 6 Sep 2024 17:20:49 -0400 Subject: [PATCH 7/8] refactoried. --- nvflare/app_common/workflows/base_model_controller.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/nvflare/app_common/workflows/base_model_controller.py b/nvflare/app_common/workflows/base_model_controller.py index 02d5f43e0d..0ef75eee04 100644 --- a/nvflare/app_common/workflows/base_model_controller.py +++ b/nvflare/app_common/workflows/base_model_controller.py @@ -128,11 +128,9 @@ def broadcast_model( """ if not isinstance(task_name, str): - raise TypeError("task_name must be a string but got {}".format(type(task_name))) - if not data: - raise TypeError("data must be a FLModel but got None") + raise TypeError(f"task_name must be a string but got {type(task_name)}") if not isinstance(data, FLModel): - raise TypeError("data must be a FLModel but got {}".format(type(data))) + raise TypeError(f"data must be a FLModel but got {type(data)}") if min_responses is None: min_responses = 0 # this is internally used by controller's broadcast to represent all targets check_non_negative_int("min_responses", min_responses) From ac502e94cecd8cb61fc9150bd6df3fbb19285b0e Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 6 Sep 2024 17:23:30 -0400 Subject: [PATCH 8/8] Updated docstring. --- nvflare/app_common/workflows/base_model_controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nvflare/app_common/workflows/base_model_controller.py b/nvflare/app_common/workflows/base_model_controller.py index 0ef75eee04..4c1b1333d0 100644 --- a/nvflare/app_common/workflows/base_model_controller.py +++ b/nvflare/app_common/workflows/base_model_controller.py @@ -112,8 +112,8 @@ def broadcast_model( """Send a task with data to a list of targets. Args: + data: FLModel to be sent to clients. It must be a FLModel object. It will raise an exception if None. task_name (str, optional): name of the task. Defaults to "train". - data (FLModel, optional): FLModel to be sent to clients. If no data is given, send empty FLModel. targets (List[str], optional): the list of target client names or None (all clients). Defaults to None. min_responses (int, optional): the minimum number of responses expected. If None, must receive responses from all clients that the task has been sent to. Defaults to None.