Skip to content

Commit

Permalink
refactor wf and model controller
Browse files Browse the repository at this point in the history
  • Loading branch information
SYangster committed Apr 6, 2024
1 parent 61954d0 commit 518a3f1
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 199 deletions.
49 changes: 0 additions & 49 deletions nvflare/apis/wf_controller_spec.py

This file was deleted.

3 changes: 3 additions & 0 deletions nvflare/app_common/abstract/fl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class FLModelConst:
OPTIMIZER_PARAMS = "optimizer_params"
METRICS = "metrics"
CURRENT_ROUND = "current_round"
START_ROUND = "start_round"
TOTAL_ROUNDS = "total_rounds"
META = "meta"

Expand All @@ -45,6 +46,7 @@ def __init__(
params: Any = None,
optimizer_params: Any = None,
metrics: Optional[Dict] = None,
start_round: Optional[int] = 0,
current_round: Optional[int] = None,
total_rounds: Optional[int] = None,
meta: Optional[Dict] = None,
Expand Down Expand Up @@ -79,6 +81,7 @@ def __init__(
self.params = params
self.optimizer_params = optimizer_params
self.metrics = metrics
self.start_round = start_round
self.current_round = current_round
self.total_rounds = total_rounds

Expand Down
6 changes: 6 additions & 0 deletions nvflare/app_common/utils/fl_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def to_shareable(fl_model: FLModel) -> Shareable:
dxo.meta.update(meta)

shareable = dxo.to_shareable()
if fl_model.start_round is not None:
shareable.set_header(AppConstants.START_ROUND, fl_model.start_round)
if fl_model.current_round is not None:
shareable.set_header(AppConstants.CURRENT_ROUND, fl_model.current_round)
if fl_model.total_rounds is not None:
Expand Down Expand Up @@ -120,6 +122,7 @@ def from_shareable(shareable: Shareable, fl_ctx: Optional[FLContext] = None) ->
if MetaKey.INITIAL_METRICS in meta:
metrics = meta[MetaKey.INITIAL_METRICS]

start_round = shareable.get_header(AppConstants.START_ROUND, None)
current_round = shareable.get_header(AppConstants.CURRENT_ROUND, None)
total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS, None)
validate_type = shareable.get_header(AppConstants.VALIDATE_TYPE, None)
Expand All @@ -138,6 +141,7 @@ def from_shareable(shareable: Shareable, fl_ctx: Optional[FLContext] = None) ->
params_type=params_type,
params=params,
metrics=metrics,
start_round=start_round,
current_round=current_round,
total_rounds=total_rounds,
meta=meta,
Expand Down Expand Up @@ -168,6 +172,7 @@ def from_dxo(dxo: DXO) -> FLModel:
params_type = dxo.data.get(FLModelConst.PARAMS_TYPE, None)
metrics = dxo.data.get(FLModelConst.METRICS, None)
optimizer_params = dxo.data.get(FLModelConst.OPTIMIZER_PARAMS, None)
start_round = dxo.data.get(FLModelConst.START_ROUND, None)
current_round = dxo.data.get(FLModelConst.CURRENT_ROUND, None)
total_rounds = dxo.data.get(FLModelConst.TOTAL_ROUNDS, None)
meta = dxo.data.get(FLModelConst.META, None)
Expand All @@ -177,6 +182,7 @@ def from_dxo(dxo: DXO) -> FLModel:
params_type=params_type,
metrics=metrics,
optimizer_params=optimizer_params,
start_round=start_round,
current_round=current_round,
total_rounds=total_rounds,
meta=meta,
Expand Down
51 changes: 39 additions & 12 deletions nvflare/app_common/workflows/base_fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,24 @@ class BaseFedAvg(WFController):
- def run(self)
"""

def sample_clients(self, min_clients):
# Constructor: forwards any extra positional/keyword arguments to the parent
# class and records the round and client settings used by the workflow.
def __init__(
self,
*args,
min_clients: int = 1000,  # FedAvg.run passes this to sample_clients() as the number of clients to pick
num_rounds: int = 5,  # how many rounds FedAvg.run iterates over
start_round: int = 0,  # index of the first round in FedAvg.run's loop
persist_every_n_rounds: int = 1,  # save_model() persists when (current_round + 1) is a multiple of this; 0 disables the periodic save
**kwargs,
):
super().__init__(*args, **kwargs)

self.min_clients = min_clients
self.num_rounds = num_rounds
self.start_round = start_round
self.persist_every_n_rounds = persist_every_n_rounds
# Stays None until the run loop assigns the active round index.
self.current_round = None

def sample_clients(self, num_clients):
"""Called by the `run` routine to get a list of available clients.
Args:
Expand All @@ -55,15 +72,16 @@ def sample_clients(self, min_clients):
Returns: list of clients.
"""
self._min_clients = min_clients

clients = self.engine.get_clients()
if len(clients) < self._min_clients:
self._min_clients = len(clients)

if self._min_clients < len(clients):
if num_clients <= len(clients):
random.shuffle(clients)
clients = clients[0 : self._min_clients]
clients = clients[0:num_clients]
else:
self.info(
f"num_clients ({num_clients}) is greater than the number of available clients. Returning all clients."
)

return clients

Expand All @@ -85,15 +103,15 @@ def _aggregate_fn(results: List[FLModel]) -> FLModel:
data=_result.params,
weight=_result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, 1.0),
contributor_name=_result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN),
contribution_round=_result.meta.get("current_round", None),
contribution_round=_result.current_round,
)

aggregated_dict = aggregation_helper.get_result()

aggr_result = FLModel(
params=aggregated_dict,
params_type=results[0].params_type,
meta={"nr_aggregated": len(results), "current_round": results[0].meta["current_round"]},
meta={"nr_aggregated": len(results), "current_round": results[0].current_round},
)
return aggr_result

Expand All @@ -114,7 +132,7 @@ def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel:
if not aggregate_fn:
aggregate_fn = self._aggregate_fn

self.info(f"aggregating {len(results)} update(s) at round {self._current_round}")
self.info(f"aggregating {len(results)} update(s) at round {self.current_round}")
try:
aggr_result = aggregate_fn(results)
except Exception as e:
Expand All @@ -130,21 +148,30 @@ def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel:

return aggr_result

def update_model(self, aggr_result):
def update_model(self, model, aggr_result):
"""Called by the `run` routine to update the current global model (self.model) given the aggregated result.
Args:
model: FLModel to be updated.
aggr_result: aggregated FLModel.
Returns: None.
"""
self.event(AppEventType.BEFORE_SHAREABLE_TO_LEARNABLE)

self.model = FLModelUtils.update_model(self.model, aggr_result)
model = FLModelUtils.update_model(model, aggr_result)

# persistor uses Learnable format to save model
ml = make_model_learnable(weights=self.model.params, meta_props=self.model.meta)
ml = make_model_learnable(weights=model.params, meta_props=model.meta)
self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, ml, private=True, sticky=True)

self.event(AppEventType.AFTER_SHAREABLE_TO_LEARNABLE)

return model

def save_model(self, model: FLModel):
    """Persist `model` on periodic persistence rounds and on the final round.

    The model is handed to the parent class's `save_model` when either:
      - `persist_every_n_rounds` is non-zero and `current_round + 1` is a
        multiple of it, or
      - `current_round` is the last round of this run.

    Args:
        model: the FLModel to persist.

    Returns: None.
    """
    interval = self.persist_every_n_rounds
    is_periodic_round = interval != 0 and (self.current_round + 1) % interval == 0

    # The run loop covers range(start_round, start_round + num_rounds), so the
    # index of the last round must be offset by start_round; comparing against
    # num_rounds - 1 alone is only correct when start_round == 0.
    last_round = self.start_round + self.num_rounds - 1
    is_final_round = self.current_round == last_round

    if is_periodic_round or is_final_round:
        super().save_model(model)
17 changes: 11 additions & 6 deletions nvflare/app_common/workflows/fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,24 @@ class FedAvg(BaseFedAvg):
def run(self) -> None:
self.info("Start FedAvg.")

for self._current_round in range(self._num_rounds):
self.info(f"Round {self._current_round} started.")
model = self.load_model()
model.start_round = self.start_round
model.total_rounds = self.num_rounds

clients = self.sample_clients(self._min_clients)
for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
self.info(f"Round {self.current_round} started.")

results = self.send_model(targets=clients, data=self.model)
clients = self.sample_clients(self.min_clients)

model.current_round = self.current_round
results = self.send_model_and_wait(targets=clients, data=model)

aggregate_results = self.aggregate(
results, aggregate_fn=None
) # if no `aggregate_fn` provided, default `WeightedAggregationHelper` is used

self.update_model(aggregate_results)
model = self.update_model(model, aggregate_results)

self.save_model()
self.save_model(model)

self.info("Finished FedAvg.")
Loading

0 comments on commit 518a3f1

Please sign in to comment.