diff --git a/nvflare/apis/wf_controller_spec.py b/nvflare/apis/wf_controller_spec.py deleted file mode 100644 index d5f735da5b..0000000000 --- a/nvflare/apis/wf_controller_spec.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from typing import Callable, List, Union - - -class WFControllerSpec(ABC): - @abstractmethod - def run(self): - """Main `run` routine for the controller workflow.""" - raise NotImplementedError - - def send_model( - self, - task_name: str, - data: any, - targets: Union[List[str], None], - timeout: int, - wait_time_after_min_received: int, - blocking: bool, - callback: Callable, - ) -> List: - """Send a task with data to a list of targets. - - Args: - task_name (str): name of the task. - data (any): data to be sent to clients. - targets (List[str]): the list of target client names. - timeout (int): time to wait for clients to perform task. - wait_time_after_min_received (int): time to wait after minimum number of clients responses have been received. - blocking (bool): whether to block to wait for task result. - callback (Callable[any]): callback when a result is received, only called when blocking=False. - - Returns: - List[any] if blocking=True else None - """ - raise NotImplementedError diff --git a/nvflare/app_common/abstract/fl_model.py b/nvflare/app_common/abstract/fl_model.py index a82ae57c2c..5b115df150 100644 --- a/nvflare/app_common/abstract/fl_model.py +++ b/nvflare/app_common/abstract/fl_model.py @@ -30,6 +30,7 @@ class FLModelConst: OPTIMIZER_PARAMS = "optimizer_params" METRICS = "metrics" CURRENT_ROUND = "current_round" + START_ROUND = "start_round" TOTAL_ROUNDS = "total_rounds" META = "meta" @@ -45,6 +46,7 @@ def __init__( params: Any = None, optimizer_params: Any = None, metrics: Optional[Dict] = None, + start_round: Optional[int] = 0, current_round: Optional[int] = None, total_rounds: Optional[int] = None, meta: Optional[Dict] = None, @@ -79,6 +81,7 @@ def __init__( self.params = params self.optimizer_params = optimizer_params self.metrics = metrics + self.start_round = start_round self.current_round = current_round self.total_rounds = total_rounds diff --git a/nvflare/app_common/decomposers/common_decomposers.py b/nvflare/app_common/decomposers/common_decomposers.py index 1b9f7514f9..c479d89d22 100644 --- a/nvflare/app_common/decomposers/common_decomposers.py +++ b/nvflare/app_common/decomposers/common_decomposers.py @@ -36,6 +36,7 @@ def decompose(self, b: FLModel, manager: DatumManager = None) -> Any: externalizer.externalize(b.params), externalizer.externalize(b.optimizer_params), externalizer.externalize(b.metrics), + b.start_round, b.current_round, b.total_rounds, externalizer.externalize(b.meta), @@ -43,13 +44,14 @@ def decompose(self, b: FLModel, manager: DatumManager = None) -> Any: def recompose(self, data: tuple, manager: DatumManager = None) -> FLModel: assert isinstance(data, tuple) - pt, params, opt_params, metrics, cr, tr, meta = data + pt, params, opt_params, metrics, sr, cr, tr, meta = data internalizer = Internalizer(manager) return FLModel( params_type=pt, params=internalizer.internalize(params), optimizer_params=internalizer.internalize(opt_params), metrics=internalizer.internalize(metrics), + start_round=sr, current_round=cr, total_rounds=tr, meta=internalizer.internalize(meta), diff --git a/nvflare/app_common/executors/task_exchanger.py b/nvflare/app_common/executors/task_exchanger.py index 77a7b19bb9..e9053629a7 100644 --- a/nvflare/app_common/executors/task_exchanger.py +++ b/nvflare/app_common/executors/task_exchanger.py @@ -207,7 +207,7 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort return make_reply(ReturnCode.EXECUTION_EXCEPTION) current_round = shareable.get_header(AppConstants.CURRENT_ROUND) - if current_round: + if current_round is not None: result.set_header(AppConstants.CURRENT_ROUND, current_round) if not self.check_output_shareable(task_name, result, fl_ctx): diff --git a/nvflare/app_common/utils/fl_model_utils.py b/nvflare/app_common/utils/fl_model_utils.py index 2d84daa14f..ac7201e414 100644 --- a/nvflare/app_common/utils/fl_model_utils.py +++ b/nvflare/app_common/utils/fl_model_utils.py @@ -73,6 +73,8 @@ def to_shareable(fl_model: FLModel) -> Shareable: dxo.meta.update(meta) shareable = dxo.to_shareable() + if fl_model.start_round is not None: + shareable.set_header(AppConstants.START_ROUND, fl_model.start_round) if fl_model.current_round is not None: shareable.set_header(AppConstants.CURRENT_ROUND, fl_model.current_round) if fl_model.total_rounds is not None: @@ -120,6 +122,7 @@ def from_shareable(shareable: Shareable, fl_ctx: Optional[FLContext] = None) -> if MetaKey.INITIAL_METRICS in meta: metrics = meta[MetaKey.INITIAL_METRICS] + start_round = shareable.get_header(AppConstants.START_ROUND, None) current_round = shareable.get_header(AppConstants.CURRENT_ROUND, None) total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS, None) validate_type = shareable.get_header(AppConstants.VALIDATE_TYPE, None) @@ -138,6 +141,7 @@ def from_shareable(shareable: Shareable, fl_ctx: Optional[FLContext] = None) -> params_type=params_type, params=params, metrics=metrics, + start_round=start_round, current_round=current_round, total_rounds=total_rounds, meta=meta, @@ -168,6 +172,7 @@ def from_dxo(dxo: DXO) -> FLModel: params_type = dxo.data.get(FLModelConst.PARAMS_TYPE, None) metrics = dxo.data.get(FLModelConst.METRICS, None) optimizer_params = dxo.data.get(FLModelConst.OPTIMIZER_PARAMS, None) + start_round = dxo.data.get(FLModelConst.START_ROUND, None) current_round = dxo.data.get(FLModelConst.CURRENT_ROUND, None) total_rounds = dxo.data.get(FLModelConst.TOTAL_ROUNDS, None) meta = dxo.data.get(FLModelConst.META, None) @@ -177,6 +182,7 @@ def from_dxo(dxo: DXO) -> FLModel: params_type=params_type, metrics=metrics, optimizer_params=optimizer_params, + start_round=start_round, current_round=current_round, total_rounds=total_rounds, meta=meta, diff --git a/nvflare/app_common/workflows/base_fedavg.py b/nvflare/app_common/workflows/base_fedavg.py index fc0e9c7abf..745d0ff840 100644 --- a/nvflare/app_common/workflows/base_fedavg.py +++ b/nvflare/app_common/workflows/base_fedavg.py @@ -28,25 +28,50 @@ class BaseFedAvg(WFController): - """The base controller for FedAvg Workflow. *Note*: This class is based on the `WFController`. + def __init__( + self, + *args, + min_clients: int = 1000, + num_rounds: int = 5, + start_round: int = 0, + persist_every_n_rounds: int = 1, + **kwargs, + ): + """The base controller for FedAvg Workflow. *Note*: This class is based on the `WFController`. + + Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629). + The model persistor (persistor_id) is used to load the initial global model which is sent to a list of clients. + Each client sends it's updated weights after local training which is aggregated. + Next, the global model is updated. + The model_persistor also saves the model after training. + + Provides the default implementations for the follow routines: + - def sample_clients(self, min_clients) + - def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel + - def update_model(self, aggr_result) + + The `run` routine needs to be implemented by the derived class: + + - def run(self) - Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629). - The model persistor (persistor_id) is used to load the initial global model which is sent to a list of clients. - Each client sends it's updated weights after local training which is aggregated. - Next, the global model is updated. - The model_persistor also saves the model after training. - - Provides the default implementations for the follow routines: - - def sample_clients(self, min_clients) - - def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel - - def update_model(self, aggr_result) - - The `run` routine needs to be implemented by the derived class: + Args: + min_clients (int, optional): The minimum number of clients responses before + Workflow starts to wait for `wait_time_after_min_received`. Note that the workflow will move forward + when all available clients have responded regardless of this value. Defaults to 1000. + num_rounds (int, optional): The total number of training rounds. Defaults to 5. + start_round (int, optional): The starting round number. + persist_every_n_rounds (int, optional): persist the global model every n rounds. Defaults to 1. + If n is 0 then no persist. + """ + super().__init__(*args, **kwargs) - - def run(self) - """ + self.min_clients = min_clients + self.num_rounds = num_rounds + self.start_round = start_round + self.persist_every_n_rounds = persist_every_n_rounds + self.current_round = None - def sample_clients(self, min_clients): + def sample_clients(self, num_clients): """Called by the `run` routine to get a list of available clients. Args: @@ -55,15 +80,16 @@ def sample_clients(self, min_clients): Returns: list of clients. """ - self._min_clients = min_clients clients = self.engine.get_clients() - if len(clients) < self._min_clients: - self._min_clients = len(clients) - if self._min_clients < len(clients): + if num_clients <= len(clients): random.shuffle(clients) - clients = clients[0 : self._min_clients] + clients = clients[0:num_clients] + else: + self.info( + f"num_clients ({num_clients}) is greater than the number of available clients. Returning all clients." + ) return clients @@ -85,7 +111,7 @@ def _aggregate_fn(results: List[FLModel]) -> FLModel: data=_result.params, weight=_result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, 1.0), contributor_name=_result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN), - contribution_round=_result.meta.get("current_round", None), + contribution_round=_result.current_round, ) aggregated_dict = aggregation_helper.get_result() @@ -93,7 +119,7 @@ def _aggregate_fn(results: List[FLModel]) -> FLModel: aggr_result = FLModel( params=aggregated_dict, params_type=results[0].params_type, - meta={"nr_aggregated": len(results), "current_round": results[0].meta["current_round"]}, + meta={"nr_aggregated": len(results), "current_round": results[0].current_round}, ) return aggr_result @@ -114,7 +140,7 @@ def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel: if not aggregate_fn: aggregate_fn = self._aggregate_fn - self.info(f"aggregating {len(results)} update(s) at round {self._current_round}") + self.info(f"aggregating {len(results)} update(s) at round {self.current_round}") try: aggr_result = aggregate_fn(results) except Exception as e: @@ -130,10 +156,11 @@ def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel: return aggr_result - def update_model(self, aggr_result): + def update_model(self, model, aggr_result): """Called by the `run` routine to update the current global model (self.model) given the aggregated result. Args: + model: FLModel to be updated. aggr_result: aggregated FLModel. Returns: None. @@ -141,10 +168,18 @@ def update_model(self, aggr_result): """ self.event(AppEventType.BEFORE_SHAREABLE_TO_LEARNABLE) - self.model = FLModelUtils.update_model(self.model, aggr_result) + model = FLModelUtils.update_model(model, aggr_result) # persistor uses Learnable format to save model - ml = make_model_learnable(weights=self.model.params, meta_props=self.model.meta) + ml = make_model_learnable(weights=model.params, meta_props=model.meta) self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, ml, private=True, sticky=True) self.event(AppEventType.AFTER_SHAREABLE_TO_LEARNABLE) + + return model + + def save_model(self, model: FLModel): + if ( + self.persist_every_n_rounds != 0 and (self.current_round + 1) % self.persist_every_n_rounds == 0 + ) or self.current_round == self.num_rounds - 1: + super().save_model(model) diff --git a/nvflare/app_common/workflows/fedavg.py b/nvflare/app_common/workflows/fedavg.py index 07df9962e3..61edca34fc 100644 --- a/nvflare/app_common/workflows/fedavg.py +++ b/nvflare/app_common/workflows/fedavg.py @@ -16,7 +16,7 @@ class FedAvg(BaseFedAvg): - """Controller for FedAvg Workflow. *Note*: This class is based on the experimental `ModelController`. + """Controller for FedAvg Workflow. *Note*: This class is based on the `WFController`. Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629). Provides the implementations for the `run` routine, controlling the main workflow: @@ -29,6 +29,7 @@ class FedAvg(BaseFedAvg): Workflow starts to wait for `wait_time_after_min_received`. Note that the workflow will move forward when all available clients have responded regardless of this value. Defaults to 1000. num_rounds (int, optional): The total number of training rounds. Defaults to 5. + start_round (int, optional): The starting round number. persistor_id (str, optional): ID of the persistor component. Defaults to "persistor". ignore_result_error (bool, optional): whether this controller can proceed if client result has errors. Defaults to False. @@ -43,19 +44,24 @@ class FedAvg(BaseFedAvg): def run(self) -> None: self.info("Start FedAvg.") - for self._current_round in range(self._num_rounds): - self.info(f"Round {self._current_round} started.") + model = self.load_model() + model.start_round = self.start_round + model.total_rounds = self.num_rounds - clients = self.sample_clients(self._min_clients) + for self.current_round in range(self.start_round, self.start_round + self.num_rounds): + self.info(f"Round {self.current_round} started.") - results = self.send_model(targets=clients, data=self.model) + clients = self.sample_clients(self.min_clients) + + model.current_round = self.current_round + results = self.send_model_and_wait(targets=clients, data=model) aggregate_results = self.aggregate( results, aggregate_fn=None ) # if no `aggregate_fn` provided, default `WeightedAggregationHelper` is used - self.update_model(aggregate_results) + model = self.update_model(model, aggregate_results) - self.save_model() + self.save_model(model) self.info("Finished FedAvg.") diff --git a/nvflare/app_common/workflows/model_controller.py b/nvflare/app_common/workflows/model_controller.py index 087209202e..f891ef8ee0 100644 --- a/nvflare/app_common/workflows/model_controller.py +++ b/nvflare/app_common/workflows/model_controller.py @@ -30,29 +30,21 @@ from nvflare.app_common.app_event_type import AppEventType from nvflare.app_common.utils.fl_component_wrapper import FLComponentWrapper from nvflare.app_common.utils.fl_model_utils import FLModelUtils -from nvflare.fuel.utils.validation_utils import check_non_negative_int, check_positive_int, check_str +from nvflare.fuel.utils.validation_utils import check_non_negative_int, check_str from nvflare.security.logging import secure_format_exception -from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector class ModelController(Controller, FLComponentWrapper, ABC): def __init__( self, - min_clients: int = 1000, - num_rounds: int = 5, persistor_id="", ignore_result_error: bool = False, allow_empty_global_weights: bool = False, task_check_period: float = 0.5, - persist_every_n_rounds: int = 1, ): """FLModel based controller. Args: - min_clients (int, optional): The minimum number of clients responses before - Workflow starts to wait for `wait_time_after_min_received`. Note that the workflow will move forward - when all available clients have responded regardless of this value. Defaults to 1000. - num_rounds (int, optional): The total number of training rounds. Defaults to 5. persistor_id (str, optional): ID of the persistor component. Defaults to "persistor". ignore_result_error (bool, optional): whether this controller can proceed if client result has errors. Defaults to False. @@ -60,99 +52,47 @@ def __init__( empty global weights at first round, such that clients start training from scratch without any global info. Defaults to False. task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.5. - persist_every_n_rounds (int, optional): persist the global model every n rounds. Defaults to 1. - If n is 0 then no persist. """ super().__init__(task_check_period=task_check_period) # Check arguments - check_positive_int("min_clients", min_clients) - check_non_negative_int("num_rounds", num_rounds) - check_non_negative_int("persist_every_n_rounds", persist_every_n_rounds) check_str("persistor_id", persistor_id) if not isinstance(task_check_period, (int, float)): raise TypeError(f"task_check_period must be an int or float but got {type(task_check_period)}") elif task_check_period <= 0: raise ValueError("task_check_period must be greater than 0.") self._task_check_period = task_check_period - self.persistor_id = persistor_id - self.persistor = None + self._persistor_id = persistor_id + self._persistor = None # config data - self._min_clients = min_clients - self._num_rounds = num_rounds - self._persist_every_n_rounds = persist_every_n_rounds - self.ignore_result_error = ignore_result_error - self.allow_empty_global_weights = allow_empty_global_weights - - # workflow phases: init, train, validate - self._phase = AppConstants.PHASE_INIT - self._current_round = None + self._ignore_result_error = ignore_result_error + self._allow_empty_global_weights = allow_empty_global_weights # model related - self.model = None self._results = [] def start_controller(self, fl_ctx: FLContext) -> None: self.fl_ctx = fl_ctx self.info("Initializing ModelController workflow.") - if self.persistor_id: - self.persistor = self._engine.get_component(self.persistor_id) - if not isinstance(self.persistor, LearnablePersistor): + if self._persistor_id: + self._persistor = self._engine.get_component(self._persistor_id) + if not isinstance(self._persistor, LearnablePersistor): self.panic( - f"Model Persistor {self.persistor_id} must be a LearnablePersistor type object, " - f"but got {type(self.persistor)}" - ) - return - - # initialize global model - if self.persistor: - global_weights = self.persistor.load(self.fl_ctx) - - if not isinstance(global_weights, ModelLearnable): - self.panic( - f"Expected global weights to be of type `ModelLearnable` but received {type(global_weights)}" + f"Model Persistor {self._persistor_id} must be a LearnablePersistor type object, " + f"but got {type(self._persistor)}" ) return - if global_weights.is_empty(): - if not self.allow_empty_global_weights: - # if empty not allowed, further check whether it is available from fl_ctx - global_weights = self.fl_ctx.get_prop(AppConstants.GLOBAL_MODEL) - - if not global_weights.is_empty(): - self.model = FLModel( - params_type=ParamsType.FULL, - params=global_weights[ModelLearnableKey.WEIGHTS], - meta=global_weights[ModelLearnableKey.META], - ) - elif self.allow_empty_global_weights: - self.model = FLModel(params_type=ParamsType.FULL, params={}) - else: - self.panic( - f"Neither `persistor` {self.persistor_id} or `fl_ctx` returned a global model! If this was intended, set `self.allow_empty_global_weights` to `True`." - ) - return - else: - self.model = FLModel(params_type=ParamsType.FULL, params={}) - - # persistor uses Learnable format to save model - ml = make_model_learnable(weights=self.model.params, meta_props=self.model.meta) - self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, ml, private=True, sticky=True) - self.event(AppEventType.INITIAL_MODEL_LOADED) - self.engine = self.fl_ctx.get_engine() FLComponentWrapper.initialize(self) def _build_shareable(self, data: FLModel = None) -> Shareable: - if not data: # if no data is given, send self.model - data = self.model - data_shareable: Shareable = FLModelUtils.to_shareable(data) - data_shareable.set_header(AppConstants.CURRENT_ROUND, self._current_round) - data_shareable.set_header(AppConstants.NUM_ROUNDS, self._num_rounds) - data_shareable.add_cookie(AppConstants.CONTRIBUTION_ROUND, self._current_round) + data_shareable.add_cookie( + AppConstants.CONTRIBUTION_ROUND, data_shareable.get_header(AppConstants.CURRENT_ROUND) + ) return data_shareable @@ -170,13 +110,13 @@ def send_model( Args: task_name (str, optional): name of the task. Defaults to "train". - data (FLModel, optional): FLModel to be sent to clients. If no data is given, send `self.model`. + data (FLModel, optional): FLModel to be sent to clients. If no data is given, send empty FLModel. targets (List[str], optional): the list of target client names or None (all clients). Defaults to None. timeout (int, optional): time to wait for clients to perform task. Defaults to 0, i.e., never time out. wait_time_after_min_received (int, optional): time to wait after minimum number of clients responses has been received. Defaults to 10. - blocking (bool, optional): whether to block to wait for task result. - callback (Callable[[FLModel], None], optional): callback when a result is received, only called when blocking=False. + blocking (bool, optional): whether to block to wait for task result. Defaults to True. + callback (Callable[[FLModel], None], optional): callback when a result is received, only called when blocking=False. Defaults to None. Returns: List[FLModel] if blocking=True else None @@ -191,6 +131,10 @@ def send_model( if not blocking and not isinstance(callback, Callable): raise TypeError("callback must be defined if blocking is False, but got {}".format(type(callback))) + if not data: + self.warning("data is None. Sending empty FLModel.") + data = FLModel(params_type=ParamsType.FULL, params={}) + task = self._prepare_task(data=data, task_name=task_name, timeout=timeout, callback=callback) if targets: @@ -268,18 +212,16 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None: result = client_task.result client_name = client_task.client.name - self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self._current_round, private=True, sticky=True) + # Turn result into FLModel + result_model = FLModelUtils.from_shareable(result) + result_model.meta["client_name"] = client_name + + self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, result_model.current_round, private=True, sticky=True) self.event(AppEventType.BEFORE_CONTRIBUTION_ACCEPT) self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx) self.event(AppEventType.AFTER_CONTRIBUTION_ACCEPT) - # Turn result into FLModel - result_model = FLModelUtils.from_shareable(result) - result_model.meta["client_name"] = client_name - result_model.meta["current_round"] = self._current_round - result_model.meta["total_rounds"] = self._num_rounds - callback = client_task.task.get_prop(AppConstants.TASK_PROP_CALLBACK) if callback: try: @@ -297,7 +239,7 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None: def process_result_of_unknown_task( self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext ) -> None: - if self._phase == AppConstants.PHASE_TRAIN and task_name == task_name: + if task_name == AppConstants.TASK_TRAIN: self._accept_train_result(client_name=client.name, result=result, fl_ctx=fl_ctx) self.info(f"Result of unknown task {task_name} sent to aggregator.") else: @@ -307,16 +249,18 @@ def _accept_train_result(self, client_name: str, result: Shareable, fl_ctx: FLCo self.fl_ctx = fl_ctx rc = result.get_return_code() + current_round = result.get_header(AppConstants.CURRENT_ROUND, None) + # Raise panic if bad peer context or execution exception. if rc and rc != ReturnCode.OK: - if self.ignore_result_error: + if self._ignore_result_error: self.warning( - f"Ignore the train result from {client_name} at round {self._current_round}. Train result error code: {rc}", + f"Ignore the train result from {client_name} at round {current_round}. Train result error code: {rc}", ) else: self.panic( f"Result from {client_name} is bad, error code: {rc}. " - f"{self.__class__.__name__} exiting at round {self._current_round}." + f"{self.__class__.__name__} exiting at round {current_round}." ) return @@ -332,50 +276,72 @@ def run(self): raise NotImplementedError def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None: - self._phase = AppConstants.PHASE_TRAIN - fl_ctx.set_prop(AppConstants.PHASE, self._phase, private=True, sticky=False) - fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self._num_rounds, private=True, sticky=False) self.fl_ctx = fl_ctx self.abort_signal = abort_signal try: self.info("Beginning model controller run.") self.event(AppEventType.TRAINING_STARTED) - self._phase = AppConstants.PHASE_TRAIN self.run() - self._phase = AppConstants.PHASE_FINISHED except Exception as e: error_msg = f"Exception in model controller run: {secure_format_exception(e)}" self.exception(error_msg) self.panic(error_msg) - def save_model(self): - if self.persistor: - if ( - self._persist_every_n_rounds != 0 and (self._current_round + 1) % self._persist_every_n_rounds == 0 - ) or self._current_round == self._num_rounds - 1: - self.info("Start persist model on server.") - self.event(AppEventType.BEFORE_LEARNABLE_PERSIST) - # persistor uses Learnable format to save model - ml = make_model_learnable(weights=self.model.params, meta_props=self.model.meta) - self.persistor.save(ml, self.fl_ctx) - self.event(AppEventType.AFTER_LEARNABLE_PERSIST) - self.info("End persist model on server.") + def load_model(self): + # initialize global model + model = None + if self._persistor: + self.info("loading initial model from persistor") + global_weights = self._persistor.load(self.fl_ctx) + + if not isinstance(global_weights, ModelLearnable): + self.panic( + f"Expected global weights to be of type `ModelLearnable` but received {type(global_weights)}" + ) + return + + if global_weights.is_empty(): + if not self._allow_empty_global_weights: + # if empty not allowed, further check whether it is available from fl_ctx + global_weights = self.fl_ctx.get_prop(AppConstants.GLOBAL_MODEL) + + if not global_weights.is_empty(): + model = FLModel( + params_type=ParamsType.FULL, + params=global_weights[ModelLearnableKey.WEIGHTS], + meta=global_weights[ModelLearnableKey.META], + ) + elif self._allow_empty_global_weights: + model = FLModel(params_type=ParamsType.FULL, params={}) + else: + self.panic( + f"Neither `persistor` {self._persistor_id} or `fl_ctx` returned a global model! If this was intended, set `self._allow_empty_global_weights` to `True`." + ) + return + else: + self.info("persistor not configured, creating empty initial FLModel") + model = FLModel(params_type=ParamsType.FULL, params={}) + + # persistor uses Learnable format to save model + ml = make_model_learnable(weights=model.params, meta_props=model.meta) + self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, ml, private=True, sticky=True) + self.event(AppEventType.INITIAL_MODEL_LOADED) + + return model + + def save_model(self, model): + if self._persistor: + self.info("Start persist model on server.") + self.event(AppEventType.BEFORE_LEARNABLE_PERSIST) + # persistor uses Learnable format to save model + ml = make_model_learnable(weights=model.params, meta_props=model.meta) + self._persistor.save(ml, self.fl_ctx) + self.event(AppEventType.AFTER_LEARNABLE_PERSIST) + self.info("End persist model on server.") + else: + self.error("persistor not configured, model will not be saved") def stop_controller(self, fl_ctx: FLContext): - self._phase = AppConstants.PHASE_FINISHED self.fl_ctx = fl_ctx self.finalize() - - def handle_event(self, event_type: str, fl_ctx: FLContext): - super().handle_event(event_type, fl_ctx) - if event_type == InfoCollector.EVENT_TYPE_GET_STATS: - collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR, None) - if collector: - if not isinstance(collector, GroupInfoCollector): - raise TypeError("collector must be GroupInfoCollector but got {}".format(type(collector))) - - collector.add_info( - group_name=self._name, - info={"phase": self._phase, "current_round": self._current_round, "num_rounds": self._num_rounds}, - ) diff --git a/nvflare/app_common/workflows/scaffold.py b/nvflare/app_common/workflows/scaffold.py index 7d5f4a1996..116ab9ca9b 100644 --- a/nvflare/app_common/workflows/scaffold.py +++ b/nvflare/app_common/workflows/scaffold.py @@ -21,7 +21,6 @@ from nvflare.app_common.abstract.fl_model import FLModel from nvflare.app_common.aggregators.weighted_aggregation_helper import WeightedAggregationHelper from nvflare.app_common.app_constant import AlgorithmConstants, AppConstants -from nvflare.app_common.utils.fl_component_wrapper import FLComponentWrapper from .base_fedavg import BaseFedAvg @@ -51,8 +50,12 @@ class Scaffold(BaseFedAvg): If n is 0 then no persist. """ - def initialize(self): - FLComponentWrapper.initialize(self) + def initialize(self, fl_ctx): + super().initialize(fl_ctx) + self.model = self.load_model() + self.model.start_round = self.start_round + self.model.total_rounds = self.num_rounds + self._global_ctrl_weights = copy.deepcopy(self.model.params) # Initialize correction term with zeros for k in self._global_ctrl_weights.keys(): @@ -61,27 +64,28 @@ def initialize(self): def run(self) -> None: self.info("Start FedAvg.") - for self._current_round in range(self._num_rounds): - self.info(f"Round {self._current_round} started.") + for self.current_round in range(self.start_round, self.start_round + self.num_rounds): + self.info(f"Round {self.current_round} started.") + self.model.current_round = self.current_round - clients = self.sample_clients(self._min_clients) + clients = self.sample_clients(self.min_clients) # Add SCAFFOLD global control terms to global model meta global_model = self.model global_model.meta[AlgorithmConstants.SCAFFOLD_CTRL_GLOBAL] = self._global_ctrl_weights - results = self.send_model(targets=clients, data=global_model) + results = self.send_model_and_wait(targets=clients, data=global_model) aggregate_results = self.aggregate(results, aggregate_fn=scaffold_aggregate_fn) - self.update_model(aggregate_results) + self.model = self.update_model(self.model, aggregate_results) # update SCAFFOLD global controls ctr_diff = aggregate_results.meta[AlgorithmConstants.SCAFFOLD_CTRL_DIFF] for v_name, v_value in ctr_diff.items(): self._global_ctrl_weights[v_name] += v_value - self.save_model() + self.save_model(self.model) self.info("Finished FedAvg.") @@ -96,13 +100,13 @@ def scaffold_aggregate_fn(results: List[FLModel]) -> FLModel: data=_result.params, weight=_result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, 1.0), contributor_name=_result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN), - contribution_round=_result.meta.get("current_round", None), + contribution_round=_result.current_round, ) crtl_aggregation_helper.add( data=_result.meta[AlgorithmConstants.SCAFFOLD_CTRL_DIFF], weight=_result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, 1.0), contributor_name=_result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN), - contribution_round=_result.meta.get("current_round", None), + contribution_round=_result.current_round, ) aggregated_dict = aggregation_helper.get_result() @@ -113,7 +117,7 @@ def scaffold_aggregate_fn(results: List[FLModel]) -> FLModel: meta={ AlgorithmConstants.SCAFFOLD_CTRL_DIFF: crtl_aggregation_helper.get_result(), "nr_aggregated": len(results), - "current_round": results[0].meta["current_round"], + "current_round": results[0].current_round, }, ) diff --git a/nvflare/app_common/workflows/wf_controller.py b/nvflare/app_common/workflows/wf_controller.py index 51a148a4b0..76dd4ff482 100644 --- a/nvflare/app_common/workflows/wf_controller.py +++ b/nvflare/app_common/workflows/wf_controller.py @@ -15,19 +15,46 @@ from abc import ABC, abstractmethod from typing import Callable, List, Union -from nvflare.apis.wf_controller_spec import WFControllerSpec from nvflare.app_common.abstract.fl_model import FLModel from nvflare.app_common.workflows.model_controller import ModelController -class WFController(ModelController, WFControllerSpec, ABC): - """Workflow Controller for FLModel based ModelController.""" +class WFController(ModelController, ABC): + """Workflow Controller API for FLModel-based ModelController.""" @abstractmethod def run(self): """Main `run` routine for the controller workflow.""" raise NotImplementedError + def send_model_and_wait( + self, + task_name: str = "train", + data: FLModel = None, + targets: Union[List[str], None] = None, + timeout: int = 0, + wait_time_after_min_received: int = 10, + ) -> List[FLModel]: + """Send a task with data to targets and wait for results. + + Args: + task_name (str, optional): name of the task. Defaults to "train". + data (FLModel, optional): FLModel to be sent to clients. Defaults to None. + targets (List[str], optional): the list of target client names or None (all clients). Defaults to None. + timeout (int, optional): time to wait for clients to perform task. Defaults to 0 (never time out). + wait_time_after_min_received (int, optional): time to wait after minimum number of client responses have been received. Defaults to 10. + + Returns: + List[FLModel] + """ + return super().send_model( + task_name=task_name, + data=data, + targets=targets, + timeout=timeout, + wait_time_after_min_received=wait_time_after_min_received, + ) + def send_model( self, task_name: str = "train", @@ -35,10 +62,9 @@ def send_model( targets: Union[List[str], None] = None, timeout: int = 0, wait_time_after_min_received: int = 10, - blocking: bool = True, callback: Callable[[FLModel], None] = None, - ) -> Union[List[FLModel], None]: - """Send a task with data to targets. + ) -> None: + """Send a task with data to targets (non-blocking). Callback is called when a result is received. Args: task_name (str, optional): name of the task. Defaults to "train". @@ -46,18 +72,36 @@ def send_model( targets (List[str], optional): the list of target client names or None (all clients). Defaults to None. timeout (int, optional): time to wait for clients to perform task. Defaults to 0 (never time out). wait_time_after_min_received (int, optional): time to wait after minimum number of client responses have been received. Defaults to 10. - blocking (bool, optional): whether to block to wait for task result. Defaults to True. - callback (Callable[[FLModel], None], optional): callback when a result is received. Only called when blocking=False. Defaults to None. + callback (Callable[[FLModel], None], optional): callback when a result is received. Defaults to None. Returns: - List[FLModel] if blocking = True else None + None """ - return super().send_model( + super().send_model( task_name=task_name, data=data, targets=targets, timeout=timeout, wait_time_after_min_received=wait_time_after_min_received, - blocking=blocking, + blocking=False, callback=callback, ) + + def load_model(self): + """Load initial model from persistor. If persistor is not configured, returns empty FLModel. + + Returns: + FLModel + """ + return super().load_model() + + def save_model(self, model: FLModel): + """Saves model with persistor. If persistor is not configured, does not save. + + Args: + model (FLModel): model to save. + + Returns: + None + """ + super().save_model(model)