Refactor WFController and ModelController (#2475)
* refactor wf and model controller

* clarify persistor_id
SYangster authored Apr 8, 2024
1 parent 61954d0 commit d6827bc
Showing 10 changed files with 256 additions and 226 deletions.
49 changes: 0 additions & 49 deletions nvflare/apis/wf_controller_spec.py

This file was deleted.

3 changes: 3 additions & 0 deletions nvflare/app_common/abstract/fl_model.py
@@ -30,6 +30,7 @@ class FLModelConst:
     OPTIMIZER_PARAMS = "optimizer_params"
     METRICS = "metrics"
     CURRENT_ROUND = "current_round"
+    START_ROUND = "start_round"
     TOTAL_ROUNDS = "total_rounds"
     META = "meta"

@@ -45,6 +46,7 @@ def __init__(
         params: Any = None,
         optimizer_params: Any = None,
         metrics: Optional[Dict] = None,
+        start_round: Optional[int] = 0,
         current_round: Optional[int] = None,
         total_rounds: Optional[int] = None,
         meta: Optional[Dict] = None,
@@ -79,6 +81,7 @@ def __init__(
         self.params = params
         self.optimizer_params = optimizer_params
         self.metrics = metrics
+        self.start_round = start_round
         self.current_round = current_round
         self.total_rounds = total_rounds

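For context, a minimal usage sketch of the new field (not part of this commit; it assumes `FLModel` can infer `params_type` from `params`):

```python
# Hypothetical usage: start_round records the round training (re)starts from,
# e.g. when resuming a job, while current_round advances round by round.
from nvflare.app_common.abstract.fl_model import FLModel

model = FLModel(params={"w": [0.0, 1.0]}, start_round=10, current_round=10, total_rounds=20)
print(model.start_round, model.current_round, model.total_rounds)  # 10 10 20
```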
4 changes: 3 additions & 1 deletion nvflare/app_common/decomposers/common_decomposers.py
@@ -36,20 +36,22 @@ def decompose(self, b: FLModel, manager: DatumManager = None) -> Any:
             externalizer.externalize(b.params),
             externalizer.externalize(b.optimizer_params),
             externalizer.externalize(b.metrics),
+            b.start_round,
             b.current_round,
             b.total_rounds,
             externalizer.externalize(b.meta),
         )
 
     def recompose(self, data: tuple, manager: DatumManager = None) -> FLModel:
         assert isinstance(data, tuple)
-        pt, params, opt_params, metrics, cr, tr, meta = data
+        pt, params, opt_params, metrics, sr, cr, tr, meta = data
         internalizer = Internalizer(manager)
         return FLModel(
             params_type=pt,
             params=internalizer.internalize(params),
             optimizer_params=internalizer.internalize(opt_params),
             metrics=internalizer.internalize(metrics),
+            start_round=sr,
             current_round=cr,
             total_rounds=tr,
             meta=internalizer.internalize(meta),
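Because the decomposer serializes `FLModel` as a positional tuple, `decompose` and `recompose` must agree on the new eight-field order; a small illustrative check (values made up):

```python
# Hypothetical sketch of the tuple order after this change:
# (params_type, params, optimizer_params, metrics, start_round, current_round, total_rounds, meta)
data = ("FULL", {"w": [1.0]}, None, {"accuracy": 0.9}, 0, 2, 5, {})
pt, params, opt_params, metrics, sr, cr, tr, meta = data
assert (sr, cr, tr) == (0, 2, 5)  # start_round slots in before current_round
```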
2 changes: 1 addition & 1 deletion nvflare/app_common/executors/task_exchanger.py
@@ -207,7 +207,7 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
             return make_reply(ReturnCode.EXECUTION_EXCEPTION)
 
         current_round = shareable.get_header(AppConstants.CURRENT_ROUND)
-        if current_round:
+        if current_round is not None:
             result.set_header(AppConstants.CURRENT_ROUND, current_round)
 
         if not self.check_output_shareable(task_name, result, fl_ctx):
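The `is not None` change fixes a truthiness bug: rounds are numbered from 0, and 0 is falsy in Python, so the old check silently dropped the header on the first round. A standalone illustration (not NVFlare code):

```python
current_round = 0  # the first round

if current_round:  # old check: 0 is falsy, so the header was never set for round 0
    print("old check: header set")  # does not print

if current_round is not None:  # new check: only a genuinely missing header is skipped
    print("new check: header set")  # prints
```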
6 changes: 6 additions & 0 deletions nvflare/app_common/utils/fl_model_utils.py
@@ -73,6 +73,8 @@ def to_shareable(fl_model: FLModel) -> Shareable:
             dxo.meta.update(meta)
 
         shareable = dxo.to_shareable()
+        if fl_model.start_round is not None:
+            shareable.set_header(AppConstants.START_ROUND, fl_model.start_round)
         if fl_model.current_round is not None:
             shareable.set_header(AppConstants.CURRENT_ROUND, fl_model.current_round)
         if fl_model.total_rounds is not None:
@@ -120,6 +122,7 @@ def from_shareable(shareable: Shareable, fl_ctx: Optional[FLContext] = None) ->
         if MetaKey.INITIAL_METRICS in meta:
             metrics = meta[MetaKey.INITIAL_METRICS]
 
+        start_round = shareable.get_header(AppConstants.START_ROUND, None)
         current_round = shareable.get_header(AppConstants.CURRENT_ROUND, None)
         total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS, None)
         validate_type = shareable.get_header(AppConstants.VALIDATE_TYPE, None)
@@ -138,6 +141,7 @@ def from_shareable(shareable: Shareable, fl_ctx: Optional[FLContext] = None) ->
             params_type=params_type,
             params=params,
             metrics=metrics,
+            start_round=start_round,
             current_round=current_round,
             total_rounds=total_rounds,
             meta=meta,
@@ -168,6 +172,7 @@ def from_dxo(dxo: DXO) -> FLModel:
         params_type = dxo.data.get(FLModelConst.PARAMS_TYPE, None)
         metrics = dxo.data.get(FLModelConst.METRICS, None)
         optimizer_params = dxo.data.get(FLModelConst.OPTIMIZER_PARAMS, None)
+        start_round = dxo.data.get(FLModelConst.START_ROUND, None)
         current_round = dxo.data.get(FLModelConst.CURRENT_ROUND, None)
         total_rounds = dxo.data.get(FLModelConst.TOTAL_ROUNDS, None)
         meta = dxo.data.get(FLModelConst.META, None)
@@ -177,6 +182,7 @@ def from_dxo(dxo: DXO) -> FLModel:
             params_type=params_type,
             metrics=metrics,
             optimizer_params=optimizer_params,
+            start_round=start_round,
             current_round=current_round,
             total_rounds=total_rounds,
             meta=meta,
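With the header wired through in both directions, `start_round` should survive a round trip; a minimal sketch assuming a params-bearing model (exact constructor and DXO requirements may differ):

```python
# Hypothetical round trip: FLModel -> Shareable -> FLModel preserves start_round.
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.utils.fl_model_utils import FLModelUtils

model = FLModel(params={"w": [0.5]}, start_round=3, current_round=4, total_rounds=10)
shareable = FLModelUtils.to_shareable(model)
restored = FLModelUtils.from_shareable(shareable)
assert restored.start_round == 3 and restored.current_round == 4
```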
91 changes: 64 additions & 27 deletions nvflare/app_common/workflows/base_fedavg.py
@@ -28,25 +28,52 @@


 class BaseFedAvg(WFController):
-    """The base controller for FedAvg Workflow. *Note*: This class is based on the `WFController`.
+    def __init__(
+        self,
+        *args,
+        min_clients: int = 1000,
+        num_rounds: int = 5,
+        start_round: int = 0,
+        persist_every_n_rounds: int = 1,
+        **kwargs,
+    ):
+        """The base controller for FedAvg Workflow. *Note*: This class is based on the `WFController`.
+
+        Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).
+        A model persistor can be configured via the `persistor_id` argument of the `WFController`.
+        The model persistor is used to load the initial global model, which is sent to a list of clients.
+        Each client sends its updated weights after local training, which are aggregated.
+        Next, the global model is updated.
+        The model_persistor will also save the model after training.
+
+        Provides the default implementations for the following routines:
+            - def sample_clients(self, min_clients)
+            - def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel
+            - def update_model(self, aggr_result)
+
+        The `run` routine needs to be implemented by the derived class:
+            - def run(self)
-
-    Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).
-    The model persistor (persistor_id) is used to load the initial global model, which is sent to a list of clients.
-    Each client sends its updated weights after local training, which are aggregated.
-    Next, the global model is updated.
-    The model_persistor also saves the model after training.
-
-    Provides the default implementations for the following routines:
-        - def sample_clients(self, min_clients)
-        - def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel
-        - def update_model(self, aggr_result)
-
-    The `run` routine needs to be implemented by the derived class:
+
+        Args:
+            min_clients (int, optional): The minimum number of client responses before the
+                workflow starts to wait for `wait_time_after_min_received`. Note that the workflow will move forward
+                when all available clients have responded, regardless of this value. Defaults to 1000.
+            num_rounds (int, optional): The total number of training rounds. Defaults to 5.
+            start_round (int, optional): The starting round number.
+            persist_every_n_rounds (int, optional): persist the global model every n rounds. Defaults to 1.
+                If n is 0, the model is not persisted periodically.
+        """
+        super().__init__(*args, **kwargs)
 
-        - def run(self)
-    """
+        self.min_clients = min_clients
+        self.num_rounds = num_rounds
+        self.start_round = start_round
+        self.persist_every_n_rounds = persist_every_n_rounds
+        self.current_round = None
-    def sample_clients(self, min_clients):
+    def sample_clients(self, num_clients):
         """Called by the `run` routine to get a list of available clients.
 
         Args:
@@ -55,15 +82,16 @@ def sample_clients(self, min_clients):
 
         Returns: list of clients.
         """
-        self._min_clients = min_clients
 
         clients = self.engine.get_clients()
-        if len(clients) < self._min_clients:
-            self._min_clients = len(clients)
 
-        if self._min_clients < len(clients):
+        if num_clients <= len(clients):
             random.shuffle(clients)
-            clients = clients[0 : self._min_clients]
+            clients = clients[0:num_clients]
+        else:
+            self.info(
+                f"num_clients ({num_clients}) is greater than the number of available clients. Returning all clients."
+            )
 
         return clients

@@ -85,15 +113,15 @@ def _aggregate_fn(results: List[FLModel]) -> FLModel:
                 data=_result.params,
                 weight=_result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, 1.0),
                 contributor_name=_result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN),
-                contribution_round=_result.meta.get("current_round", None),
+                contribution_round=_result.current_round,
             )
 
         aggregated_dict = aggregation_helper.get_result()
 
         aggr_result = FLModel(
             params=aggregated_dict,
             params_type=results[0].params_type,
-            meta={"nr_aggregated": len(results), "current_round": results[0].meta["current_round"]},
+            meta={"nr_aggregated": len(results), "current_round": results[0].current_round},
         )
         return aggr_result

@@ -114,7 +142,7 @@ def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel:
         if not aggregate_fn:
             aggregate_fn = self._aggregate_fn
 
-        self.info(f"aggregating {len(results)} update(s) at round {self._current_round}")
+        self.info(f"aggregating {len(results)} update(s) at round {self.current_round}")
         try:
             aggr_result = aggregate_fn(results)
         except Exception as e:
@@ -130,21 +158,30 @@ def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel:
 
         return aggr_result
 
-    def update_model(self, aggr_result):
+    def update_model(self, model, aggr_result):
         """Called by the `run` routine to update the current global model (self.model) given the aggregated result.
 
         Args:
+            model: FLModel to be updated.
             aggr_result: aggregated FLModel.
 
         Returns: None.
         """
         self.event(AppEventType.BEFORE_SHAREABLE_TO_LEARNABLE)
 
-        self.model = FLModelUtils.update_model(self.model, aggr_result)
+        model = FLModelUtils.update_model(model, aggr_result)
 
         # persistor uses Learnable format to save model
-        ml = make_model_learnable(weights=self.model.params, meta_props=self.model.meta)
+        ml = make_model_learnable(weights=model.params, meta_props=model.meta)
         self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, ml, private=True, sticky=True)
 
         self.event(AppEventType.AFTER_SHAREABLE_TO_LEARNABLE)
 
+        return model
+
+    def save_model(self, model: FLModel):
+        if (
+            self.persist_every_n_rounds != 0 and (self.current_round + 1) % self.persist_every_n_rounds == 0
+        ) or self.current_round == self.num_rounds - 1:
+            super().save_model(model)
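The new `save_model` gating persists on every n-th round (counting from 1) and always on the final round; a standalone sketch of the same condition (the helper name is hypothetical):

```python
# Mirrors BaseFedAvg.save_model's persistence condition, for illustration only.
def should_persist(current_round: int, num_rounds: int, persist_every_n_rounds: int) -> bool:
    periodic = persist_every_n_rounds != 0 and (current_round + 1) % persist_every_n_rounds == 0
    final = current_round == num_rounds - 1
    return periodic or final

# With num_rounds=5 and persist_every_n_rounds=2, rounds 1, 3, and 4 persist.
assert [r for r in range(5) if should_persist(r, 5, 2)] == [1, 3, 4]
```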
20 changes: 13 additions & 7 deletions nvflare/app_common/workflows/fedavg.py
@@ -16,7 +16,7 @@
 
 
 class FedAvg(BaseFedAvg):
-    """Controller for FedAvg Workflow. *Note*: This class is based on the experimental `ModelController`.
+    """Controller for FedAvg Workflow. *Note*: This class is based on the `WFController`.
 
     Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).
 
     Provides the implementations for the `run` routine, controlling the main workflow:
@@ -29,6 +29,7 @@ class FedAvg(BaseFedAvg):
                 Workflow starts to wait for `wait_time_after_min_received`. Note that the workflow will move forward
                 when all available clients have responded regardless of this value. Defaults to 1000.
             num_rounds (int, optional): The total number of training rounds. Defaults to 5.
+            start_round (int, optional): The starting round number.
             persistor_id (str, optional): ID of the persistor component. Defaults to "persistor".
             ignore_result_error (bool, optional): whether this controller can proceed if client result has errors.
                 Defaults to False.
@@ -43,19 +44,24 @@ class FedAvg(BaseFedAvg):
     def run(self) -> None:
         self.info("Start FedAvg.")
 
-        for self._current_round in range(self._num_rounds):
-            self.info(f"Round {self._current_round} started.")
+        model = self.load_model()
+        model.start_round = self.start_round
+        model.total_rounds = self.num_rounds
 
-            clients = self.sample_clients(self._min_clients)
+        for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
+            self.info(f"Round {self.current_round} started.")
 
-            results = self.send_model(targets=clients, data=self.model)
+            clients = self.sample_clients(self.min_clients)
+
+            model.current_round = self.current_round
+            results = self.send_model_and_wait(targets=clients, data=model)
 
             aggregate_results = self.aggregate(
                 results, aggregate_fn=None
             )  # if no `aggregate_fn` provided, default `WeightedAggregationHelper` is used
 
-            self.update_model(aggregate_results)
+            model = self.update_model(model, aggregate_results)
 
-            self.save_model()
+            self.save_model(model)
 
         self.info("Finished FedAvg.")
(3 more changed files not shown)
