Skip to content

Commit

Permalink
Added the current-round info the fl_ctx for BaseModelController (#2916)
Browse files Browse the repository at this point in the history
* Added the current-round info the fl_ctx for BaseModelController.

* reformat.

* codestyle fix.

* Moved the self.set_fl_context(data) call to broadcast_model().

* Change broadcast_model() to must send a FLModel, not None.

* Changed the BaseModelController broadcast_model data default value, and a warning message to debug.

* refactoried.

* Updated docstring.

---------

Co-authored-by: Chester Chen <512707+chesterxgchen@users.noreply.github.com>
Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <yuantingh@nvidia.com>
  • Loading branch information
3 people authored Sep 6, 2024
1 parent d62b901 commit 5d71c2c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
24 changes: 13 additions & 11 deletions nvflare/app_common/workflows/base_model_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def _build_shareable(self, data: FLModel = None) -> Shareable:

def broadcast_model(
self,
data,
task_name: str = AppConstants.TASK_TRAIN,
data: FLModel = None,
targets: Union[List[Client], List[str], None] = None,
min_responses: int = None,
timeout: int = 0,
Expand All @@ -112,8 +112,8 @@ def broadcast_model(
"""Send a task with data to a list of targets.
Args:
data: FLModel to be sent to clients. It must be a FLModel object. It will raise an exception if None.
task_name (str, optional): name of the task. Defaults to "train".
data (FLModel, optional): FLModel to be sent to clients. If no data is given, send empty FLModel.
targets (List[str], optional): the list of target client names or None (all clients). Defaults to None.
min_responses (int, optional): the minimum number of responses expected. If None, must receive responses from
all clients that the task has been sent to. Defaults to None.
Expand All @@ -128,9 +128,9 @@ def broadcast_model(
"""

if not isinstance(task_name, str):
raise TypeError("task_name must be a string but got {}".format(type(task_name)))
if data and not isinstance(data, FLModel):
raise TypeError("data must be a FLModel or None but got {}".format(type(data)))
raise TypeError(f"task_name must be a string but got {type(task_name)}")
if not isinstance(data, FLModel):
raise TypeError(f"data must be a FLModel but got {type(data)}")
if min_responses is None:
min_responses = 0 # this is internally used by controller's broadcast to represent all targets
check_non_negative_int("min_responses", min_responses)
Expand All @@ -139,9 +139,7 @@ def broadcast_model(
if not blocking and not isinstance(callback, Callable):
raise TypeError("callback must be defined if blocking is False, but got {}".format(type(callback)))

if not data:
self.warning("data is None. Sending empty FLModel.")
data = FLModel(params_type=ParamsType.FULL, params={})
self.set_fl_context(data)

task = self._prepare_task(data=data, task_name=task_name, timeout=timeout, callback=callback)

Expand Down Expand Up @@ -224,9 +222,6 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None:
result_model.meta["props"] = client_task.task.props[AppConstants.META_DATA]
result_model.meta["client_name"] = client_name

if result_model.current_round is not None:
self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, result_model.current_round, private=True, sticky=True)

self.event(AppEventType.BEFORE_CONTRIBUTION_ACCEPT)
self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx)
self.event(AppEventType.AFTER_CONTRIBUTION_ACCEPT)
Expand Down Expand Up @@ -379,6 +374,13 @@ def sample_clients(self, num_clients: int = None) -> List[str]:

return clients

def set_fl_context(self, data: FLModel):
"""Set up the fl_ctx information based on the passed in FLModel data."""
if data and data.current_round is not None:
self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, data.current_round, private=True, sticky=True)
else:
self.debug("The FLModel data does not contain the current_round information.")

def get_component(self, component_id: str):
return self.engine.get_component(component_id)

Expand Down
2 changes: 1 addition & 1 deletion research/fed-bpt/src/global_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def run(self) -> None:

clients = self.sample_clients(self.num_clients)

global_model = FLModel(params={"global_es": global_es})
global_model = FLModel(params={"global_es": global_es}, current_round=self.current_round)
results = self.send_model_and_wait(targets=clients, data=global_model)

# get solutions from clients
Expand Down

0 comments on commit 5d71c2c

Please sign in to comment.