Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the current-round info the fl_ctx for BaseModelController #2916

Merged
merged 12 commits into from
Sep 6, 2024
24 changes: 13 additions & 11 deletions nvflare/app_common/workflows/base_model_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def _build_shareable(self, data: FLModel = None) -> Shareable:

def broadcast_model(
self,
data,
task_name: str = AppConstants.TASK_TRAIN,
data: FLModel = None,
targets: Union[List[Client], List[str], None] = None,
min_responses: int = None,
timeout: int = 0,
Expand All @@ -112,8 +112,8 @@ def broadcast_model(
"""Send a task with data to a list of targets.

Args:
data: FLModel to be sent to clients. It must be a FLModel object. It will raise an exception if None.
task_name (str, optional): name of the task. Defaults to "train".
data (FLModel, optional): FLModel to be sent to clients. If no data is given, send empty FLModel.
targets (List[str], optional): the list of target client names or None (all clients). Defaults to None.
min_responses (int, optional): the minimum number of responses expected. If None, must receive responses from
all clients that the task has been sent to. Defaults to None.
Expand All @@ -128,9 +128,9 @@ def broadcast_model(
"""

if not isinstance(task_name, str):
raise TypeError("task_name must be a string but got {}".format(type(task_name)))
if data and not isinstance(data, FLModel):
raise TypeError("data must be a FLModel or None but got {}".format(type(data)))
raise TypeError(f"task_name must be a string but got {type(task_name)}")
if not isinstance(data, FLModel):
raise TypeError(f"data must be a FLModel but got {type(data)}")
if min_responses is None:
min_responses = 0 # this is internally used by controller's broadcast to represent all targets
check_non_negative_int("min_responses", min_responses)
Expand All @@ -139,9 +139,7 @@ def broadcast_model(
if not blocking and not isinstance(callback, Callable):
raise TypeError("callback must be defined if blocking is False, but got {}".format(type(callback)))

if not data:
yhwen marked this conversation as resolved.
Show resolved Hide resolved
self.warning("data is None. Sending empty FLModel.")
data = FLModel(params_type=ParamsType.FULL, params={})
yhwen marked this conversation as resolved.
Show resolved Hide resolved
self.set_fl_context(data)

task = self._prepare_task(data=data, task_name=task_name, timeout=timeout, callback=callback)

Expand Down Expand Up @@ -224,9 +222,6 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None:
result_model.meta["props"] = client_task.task.props[AppConstants.META_DATA]
result_model.meta["client_name"] = client_name

if result_model.current_round is not None:
self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, result_model.current_round, private=True, sticky=True)

self.event(AppEventType.BEFORE_CONTRIBUTION_ACCEPT)
self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx)
self.event(AppEventType.AFTER_CONTRIBUTION_ACCEPT)
Expand Down Expand Up @@ -379,6 +374,13 @@ def sample_clients(self, num_clients: int = None) -> List[str]:

return clients

def set_fl_context(self, data: FLModel):
"""Set up the fl_ctx information based on the passed in FLModel data."""
if data and data.current_round is not None:
self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, data.current_round, private=True, sticky=True)
else:
yhwen marked this conversation as resolved.
Show resolved Hide resolved
self.debug("The FLModel data does not contain the current_round information.")

def get_component(self, component_id: str):
return self.engine.get_component(component_id)

Expand Down
2 changes: 1 addition & 1 deletion research/fed-bpt/src/global_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def run(self) -> None:

clients = self.sample_clients(self.num_clients)

global_model = FLModel(params={"global_es": global_es})
global_model = FLModel(params={"global_es": global_es}, current_round=self.current_round)
yhwen marked this conversation as resolved.
Show resolved Hide resolved
results = self.send_model_and_wait(targets=clients, data=global_model)

# get solutions from clients
Expand Down
Loading