
Added the current-round info to the fl_ctx for BaseModelController #2916

Merged · 12 commits · Sep 6, 2024
20 changes: 12 additions & 8 deletions nvflare/app_common/workflows/base_model_controller.py
@@ -129,8 +129,10 @@ def broadcast_model(

         if not isinstance(task_name, str):
             raise TypeError("task_name must be a string but got {}".format(type(task_name)))
-        if data and not isinstance(data, FLModel):
-            raise TypeError("data must be a FLModel or None but got {}".format(type(data)))
+        if not data:
+            raise TypeError("data must be a FLModel but got None")
+        if not isinstance(data, FLModel):
+            raise TypeError("data must be a FLModel but got {}".format(type(data)))
         if min_responses is None:
             min_responses = 0  # this is internally used by controller's broadcast to represent all targets
         check_non_negative_int("min_responses", min_responses)
@@ -139,9 +141,7 @@

         if not blocking and not isinstance(callback, Callable):
             raise TypeError("callback must be defined if blocking is False, but got {}".format(type(callback)))

-        if not data:
-            self.warning("data is None. Sending empty FLModel.")
-            data = FLModel(params_type=ParamsType.FULL, params={})
+        self.set_fl_context(data)

         task = self._prepare_task(data=data, task_name=task_name, timeout=timeout, callback=callback)

@@ -224,9 +224,6 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None:
         result_model.meta["props"] = client_task.task.props[AppConstants.META_DATA]
         result_model.meta["client_name"] = client_name

-        if result_model.current_round is not None:
-            self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, result_model.current_round, private=True, sticky=True)
-
         self.event(AppEventType.BEFORE_CONTRIBUTION_ACCEPT)
         self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx)
         self.event(AppEventType.AFTER_CONTRIBUTION_ACCEPT)
@@ -379,6 +376,13 @@ def sample_clients(self, num_clients: int = None) -> List[str]:

         return clients

+    def set_fl_context(self, data: FLModel):
+        """Set up the fl_ctx information based on the passed in FLModel data."""
+        if data and data.current_round is not None:
+            self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, data.current_round, private=True, sticky=True)
+        else:
+            self.warning("The FLModel data does not contain the current_round information.")
+
     def get_component(self, component_id: str):
         return self.engine.get_component(component_id)
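With this change, broadcast_model no longer substitutes an empty FLModel when data is None; callers must pass a valid FLModel and should set current_round on it so that set_fl_context can publish AppConstants.CURRENT_ROUND into the fl_ctx. Below is a minimal caller-side sketch of that expectation, not code from this PR; MyWorkflow and the num_rounds attribute are hypothetical placeholders.

from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.workflows.model_controller import ModelController


class MyWorkflow(ModelController):  # hypothetical workflow, not part of this PR
    def run(self) -> None:
        for current_round in range(self.num_rounds):  # num_rounds is an assumed attribute
            clients = self.sample_clients()

            # current_round must be set on the FLModel so that set_fl_context()
            # can store AppConstants.CURRENT_ROUND in fl_ctx; passing data=None
            # now raises a TypeError instead of sending an empty FLModel.
            model = FLModel(params={}, current_round=current_round)
            results = self.send_model_and_wait(targets=clients, data=model)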
2 changes: 1 addition & 1 deletion research/fed-bpt/src/global_es.py
@@ -95,7 +95,7 @@ def run(self) -> None:

         clients = self.sample_clients(self.num_clients)

-        global_model = FLModel(params={"global_es": global_es})
+        global_model = FLModel(params={"global_es": global_es}, current_round=self.current_round)
         results = self.send_model_and_wait(targets=clients, data=global_model)

         # get solutions from clients
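For reference, the value written by set_fl_context can be read back from the FLContext wherever the context is available, for example in an event handler. This is a minimal sketch of such a reader, assuming a hypothetical FLComponent subclass named RoundLogger; it is not part of this PR.

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.app_constant import AppConstants


class RoundLogger(FLComponent):  # hypothetical consumer of the new fl_ctx prop
    def handle_event(self, event_type: str, fl_ctx: FLContext):
        # set_fl_context() stores the round as a private, sticky prop,
        # so it can be looked up from the context later in the job.
        current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
        if current_round is not None:
            self.log_info(fl_ctx, f"current round: {current_round}")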