Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor WFController and ModelController #2475

Merged
merged 2 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 0 additions & 49 deletions nvflare/apis/wf_controller_spec.py

This file was deleted.

3 changes: 3 additions & 0 deletions nvflare/app_common/abstract/fl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class FLModelConst:
OPTIMIZER_PARAMS = "optimizer_params"
METRICS = "metrics"
CURRENT_ROUND = "current_round"
START_ROUND = "start_round"
TOTAL_ROUNDS = "total_rounds"
META = "meta"

Expand All @@ -45,6 +46,7 @@ def __init__(
params: Any = None,
optimizer_params: Any = None,
metrics: Optional[Dict] = None,
start_round: Optional[int] = 0,
current_round: Optional[int] = None,
total_rounds: Optional[int] = None,
meta: Optional[Dict] = None,
Expand Down Expand Up @@ -79,6 +81,7 @@ def __init__(
self.params = params
self.optimizer_params = optimizer_params
self.metrics = metrics
self.start_round = start_round
self.current_round = current_round
self.total_rounds = total_rounds

Expand Down
4 changes: 3 additions & 1 deletion nvflare/app_common/decomposers/common_decomposers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,22 @@ def decompose(self, b: FLModel, manager: DatumManager = None) -> Any:
externalizer.externalize(b.params),
externalizer.externalize(b.optimizer_params),
externalizer.externalize(b.metrics),
b.start_round,
b.current_round,
b.total_rounds,
externalizer.externalize(b.meta),
)

def recompose(self, data: tuple, manager: DatumManager = None) -> FLModel:
assert isinstance(data, tuple)
pt, params, opt_params, metrics, cr, tr, meta = data
pt, params, opt_params, metrics, sr, cr, tr, meta = data
internalizer = Internalizer(manager)
return FLModel(
params_type=pt,
params=internalizer.internalize(params),
optimizer_params=internalizer.internalize(opt_params),
metrics=internalizer.internalize(metrics),
start_round=sr,
current_round=cr,
total_rounds=tr,
meta=internalizer.internalize(meta),
Expand Down
2 changes: 1 addition & 1 deletion nvflare/app_common/executors/task_exchanger.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort
return make_reply(ReturnCode.EXECUTION_EXCEPTION)

current_round = shareable.get_header(AppConstants.CURRENT_ROUND)
if current_round:
if current_round is not None:
result.set_header(AppConstants.CURRENT_ROUND, current_round)

if not self.check_output_shareable(task_name, result, fl_ctx):
Expand Down
6 changes: 6 additions & 0 deletions nvflare/app_common/utils/fl_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def to_shareable(fl_model: FLModel) -> Shareable:
dxo.meta.update(meta)

shareable = dxo.to_shareable()
if fl_model.start_round is not None:
shareable.set_header(AppConstants.START_ROUND, fl_model.start_round)
if fl_model.current_round is not None:
shareable.set_header(AppConstants.CURRENT_ROUND, fl_model.current_round)
if fl_model.total_rounds is not None:
Expand Down Expand Up @@ -120,6 +122,7 @@ def from_shareable(shareable: Shareable, fl_ctx: Optional[FLContext] = None) ->
if MetaKey.INITIAL_METRICS in meta:
metrics = meta[MetaKey.INITIAL_METRICS]

start_round = shareable.get_header(AppConstants.START_ROUND, None)
current_round = shareable.get_header(AppConstants.CURRENT_ROUND, None)
total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS, None)
validate_type = shareable.get_header(AppConstants.VALIDATE_TYPE, None)
Expand All @@ -138,6 +141,7 @@ def from_shareable(shareable: Shareable, fl_ctx: Optional[FLContext] = None) ->
params_type=params_type,
params=params,
metrics=metrics,
start_round=start_round,
current_round=current_round,
total_rounds=total_rounds,
meta=meta,
Expand Down Expand Up @@ -168,6 +172,7 @@ def from_dxo(dxo: DXO) -> FLModel:
params_type = dxo.data.get(FLModelConst.PARAMS_TYPE, None)
metrics = dxo.data.get(FLModelConst.METRICS, None)
optimizer_params = dxo.data.get(FLModelConst.OPTIMIZER_PARAMS, None)
start_round = dxo.data.get(FLModelConst.START_ROUND, None)
current_round = dxo.data.get(FLModelConst.CURRENT_ROUND, None)
total_rounds = dxo.data.get(FLModelConst.TOTAL_ROUNDS, None)
meta = dxo.data.get(FLModelConst.META, None)
Expand All @@ -177,6 +182,7 @@ def from_dxo(dxo: DXO) -> FLModel:
params_type=params_type,
metrics=metrics,
optimizer_params=optimizer_params,
start_round=start_round,
current_round=current_round,
total_rounds=total_rounds,
meta=meta,
Expand Down
91 changes: 64 additions & 27 deletions nvflare/app_common/workflows/base_fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,52 @@


class BaseFedAvg(WFController):
"""The base controller for FedAvg Workflow. *Note*: This class is based on the `WFController`.
def __init__(
self,
*args,
min_clients: int = 1000,
num_rounds: int = 5,
start_round: int = 0,
persist_every_n_rounds: int = 1,
**kwargs,
):
"""The base controller for FedAvg Workflow. *Note*: This class is based on the `WFController`.

Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).

A model persistor can be configured via the `persistor_id` argument of the `WFController`.
The model persistor is used to load the initial global model which is sent to a list of clients.
Each client sends its updated weights after local training, which are aggregated.
Next, the global model is updated.
The model_persistor will also save the model after training.

Provides the default implementations for the following routines:
- def sample_clients(self, min_clients)
- def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel
- def update_model(self, aggr_result)

The `run` routine needs to be implemented by the derived class:

- def run(self)

Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).
The model persistor (persistor_id) is used to load the initial global model which is sent to a list of clients.
Each client sends its updated weights after local training, which are aggregated.
Next, the global model is updated.
The model_persistor also saves the model after training.

Provides the default implementations for the following routines:
- def sample_clients(self, min_clients)
- def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel
- def update_model(self, aggr_result)

The `run` routine needs to be implemented by the derived class:
Args:
min_clients (int, optional): The minimum number of clients responses before
Workflow starts to wait for `wait_time_after_min_received`. Note that the workflow will move forward
when all available clients have responded regardless of this value. Defaults to 1000.
num_rounds (int, optional): The total number of training rounds. Defaults to 5.
start_round (int, optional): The starting round number.
persist_every_n_rounds (int, optional): persist the global model every n rounds. Defaults to 1.
If n is 0 then no persist.
"""
super().__init__(*args, **kwargs)

- def run(self)
"""
self.min_clients = min_clients
self.num_rounds = num_rounds
self.start_round = start_round
self.persist_every_n_rounds = persist_every_n_rounds
self.current_round = None

def sample_clients(self, min_clients):
def sample_clients(self, num_clients):
"""Called by the `run` routine to get a list of available clients.

Args:
Expand All @@ -55,15 +82,16 @@ def sample_clients(self, min_clients):
Returns: list of clients.

"""
self._min_clients = min_clients

clients = self.engine.get_clients()
if len(clients) < self._min_clients:
self._min_clients = len(clients)

if self._min_clients < len(clients):
if num_clients <= len(clients):
random.shuffle(clients)
clients = clients[0 : self._min_clients]
clients = clients[0:num_clients]
else:
self.info(
f"num_clients ({num_clients}) is greater than the number of available clients. Returning all clients."
)

return clients

Expand All @@ -85,15 +113,15 @@ def _aggregate_fn(results: List[FLModel]) -> FLModel:
data=_result.params,
weight=_result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, 1.0),
contributor_name=_result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN),
contribution_round=_result.meta.get("current_round", None),
contribution_round=_result.current_round,
)

aggregated_dict = aggregation_helper.get_result()

aggr_result = FLModel(
params=aggregated_dict,
params_type=results[0].params_type,
meta={"nr_aggregated": len(results), "current_round": results[0].meta["current_round"]},
meta={"nr_aggregated": len(results), "current_round": results[0].current_round},
)
return aggr_result

Expand All @@ -114,7 +142,7 @@ def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel:
if not aggregate_fn:
aggregate_fn = self._aggregate_fn

self.info(f"aggregating {len(results)} update(s) at round {self._current_round}")
self.info(f"aggregating {len(results)} update(s) at round {self.current_round}")
try:
aggr_result = aggregate_fn(results)
except Exception as e:
Expand All @@ -130,21 +158,30 @@ def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel:

return aggr_result

def update_model(self, aggr_result):
def update_model(self, model, aggr_result):
"""Called by the `run` routine to update the current global model (self.model) given the aggregated result.

Args:
model: FLModel to be updated.
aggr_result: aggregated FLModel.

Returns: None.

"""
self.event(AppEventType.BEFORE_SHAREABLE_TO_LEARNABLE)

self.model = FLModelUtils.update_model(self.model, aggr_result)
model = FLModelUtils.update_model(model, aggr_result)

# persistor uses Learnable format to save model
ml = make_model_learnable(weights=self.model.params, meta_props=self.model.meta)
ml = make_model_learnable(weights=model.params, meta_props=model.meta)
self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, ml, private=True, sticky=True)

self.event(AppEventType.AFTER_SHAREABLE_TO_LEARNABLE)

return model

def save_model(self, model: FLModel) -> None:
    """Persist the global model via the parent's persistor, honoring the persist schedule.

    Saves when either condition holds:
      - periodic persist: `persist_every_n_rounds` is non-zero and the (1-based)
        round count since round 0 is a multiple of it;
      - final round: this is the last round of the run. Because the training loop
        iterates `range(start_round, start_round + num_rounds)`, the last round is
        `start_round + num_rounds - 1` (the original check against
        `num_rounds - 1` missed the final persist whenever `start_round > 0`).

    Args:
        model: the global FLModel to persist.

    Returns:
        None.
    """
    is_periodic_save = (
        self.persist_every_n_rounds != 0 and (self.current_round + 1) % self.persist_every_n_rounds == 0
    )
    is_final_round = self.current_round == self.start_round + self.num_rounds - 1
    if is_periodic_save or is_final_round:
        super().save_model(model)
20 changes: 13 additions & 7 deletions nvflare/app_common/workflows/fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


class FedAvg(BaseFedAvg):
"""Controller for FedAvg Workflow. *Note*: This class is based on the experimental `ModelController`.
"""Controller for FedAvg Workflow. *Note*: This class is based on the `WFController`.
Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).

Provides the implementations for the `run` routine, controlling the main workflow:
Expand All @@ -29,6 +29,7 @@ class FedAvg(BaseFedAvg):
Workflow starts to wait for `wait_time_after_min_received`. Note that the workflow will move forward
when all available clients have responded regardless of this value. Defaults to 1000.
num_rounds (int, optional): The total number of training rounds. Defaults to 5.
start_round (int, optional): The starting round number.
persistor_id (str, optional): ID of the persistor component. Defaults to "persistor".
ignore_result_error (bool, optional): whether this controller can proceed if client result has errors.
Defaults to False.
Expand All @@ -43,19 +44,24 @@ class FedAvg(BaseFedAvg):
def run(self) -> None:
self.info("Start FedAvg.")

for self._current_round in range(self._num_rounds):
self.info(f"Round {self._current_round} started.")
model = self.load_model()
model.start_round = self.start_round
model.total_rounds = self.num_rounds

clients = self.sample_clients(self._min_clients)
for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
self.info(f"Round {self.current_round} started.")

results = self.send_model(targets=clients, data=self.model)
clients = self.sample_clients(self.min_clients)

model.current_round = self.current_round
results = self.send_model_and_wait(targets=clients, data=model)

aggregate_results = self.aggregate(
results, aggregate_fn=None
) # if no `aggregate_fn` provided, default `WeightedAggregationHelper` is used

self.update_model(aggregate_results)
model = self.update_model(model, aggregate_results)

self.save_model()
self.save_model(model)

self.info("Finished FedAvg.")
Loading
Loading