diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index 28a66e136639..1f7e5a9f5b9b 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -34,7 +34,6 @@ from flwr.common.constant import DRIVER_API_DEFAULT_ADDRESS from flwr.common.logger import log, update_console_handler, warn_deprecated_feature from flwr.common.object_ref import load_app -from flwr.common.typing import UserConfig from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from flwr.proto.run_pb2 import ( # pylint: disable=E0611 CreateRunRequest, @@ -46,13 +45,14 @@ from .server_app import LoadServerAppError, ServerApp +# pylint: disable-next=too-many-arguments,too-many-positional-arguments def run( driver: Driver, + context: Context, server_app_dir: str, - server_app_run_config: UserConfig, server_app_attr: Optional[str] = None, loaded_server_app: Optional[ServerApp] = None, -) -> None: +) -> Context: """Run ServerApp with a given Driver.""" if not (server_app_attr is None) ^ (loaded_server_app is None): raise ValueError( @@ -78,15 +78,11 @@ def _load() -> ServerApp: server_app = _load() - # Initialize Context - context = Context( - node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config - ) - # Call ServerApp server_app(driver=driver, context=context) log(DEBUG, "ServerApp finished running.") + return context # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals @@ -225,11 +221,19 @@ def run_server_app() -> None: root_certificates, ) + # Initialize Context + context = Context( + node_id=0, + node_config={}, + state=RecordSet(), + run_config=server_app_run_config, + ) + # Run the ServerApp with the Driver run( driver=driver, + context=context, server_app_dir=app_path, - server_app_run_config=server_app_run_config, server_app_attr=server_app_attr, ) diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 15ff6bf7d206..29834342554b 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -29,7 +29,7 @@ from flwr.cli.config_utils import load_and_validate from flwr.client import ClientApp -from flwr.common import EventType, event, log, now +from flwr.common import Context, EventType, RecordSet, event, log, now from flwr.common.config import get_fused_config_from_dir, parse_config_args from flwr.common.constant import RUN_ID_NUM_BYTES, Status from flwr.common.logger import ( @@ -40,7 +40,7 @@ ) from flwr.common.typing import Run, RunStatus, UserConfig from flwr.server.driver import Driver, InMemoryDriver -from flwr.server.run_serverapp import run as run_server_app +from flwr.server.run_serverapp import run as _run from flwr.server.server_app import ServerApp from flwr.server.superlink.fleet import vce from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig @@ -333,11 +333,19 @@ def server_th_with_start_checks( log(INFO, "Enabling GPU growth for Tensorflow on the server thread.") enable_gpu_growth() + # Initialize Context + context = Context( + node_id=0, + node_config={}, + state=RecordSet(), + run_config=_server_app_run_config, + ) + # Run ServerApp - run_server_app( + _run( driver=_driver, + context=context, server_app_dir=_server_app_dir, - server_app_run_config=_server_app_run_config, server_app_attr=_server_app_attr, loaded_server_app=_server_app, )