diff --git a/pyproject.toml b/pyproject.toml index 688c598c5a35..ecd8d584a4ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,8 @@ exclude = [ ] [tool.poetry.scripts] +flower-driver-api = "flwr.server:run_driver_api" +flower-fleet-api = "flwr.server:run_fleet_api" flower-server = "flwr.server:run_server" flower-client = "flwr.client:run_client" diff --git a/src/py/flwr/common/telemetry.py b/src/py/flwr/common/telemetry.py index 541615eecd16..43038721ce81 100644 --- a/src/py/flwr/common/telemetry.py +++ b/src/py/flwr/common/telemetry.py @@ -127,7 +127,15 @@ def _generate_next_value_(name: str, start: int, count: int, last_values: List[A START_SERVER_ENTER = auto() START_SERVER_LEAVE = auto() - # New Server + # Driver API + RUN_DRIVER_API_ENTER = auto() + RUN_DRIVER_API_LEAVE = auto() + + # Fleet API + RUN_FLEET_API_ENTER = auto() + RUN_FLEET_API_LEAVE = auto() + + # Driver API and Fleet API RUN_SERVER_ENTER = auto() RUN_SERVER_LEAVE = auto() diff --git a/src/py/flwr/server/__init__.py b/src/py/flwr/server/__init__.py index abcc03c1b745..b419b22d2fa8 100644 --- a/src/py/flwr/server/__init__.py +++ b/src/py/flwr/server/__init__.py @@ -16,6 +16,8 @@ from .app import ServerConfig as ServerConfig +from .app import run_driver_api as run_driver_api +from .app import run_fleet_api as run_fleet_api from .app import run_server as run_server from .app import start_server as start_server from .client_manager import ClientManager as ClientManager @@ -25,10 +27,12 @@ __all__ = [ "ClientManager", - "ServerConfig", "History", + "run_driver_api", + "run_fleet_api", "run_server", "Server", + "ServerConfig", "SimpleClientManager", "start_server", ] diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 648694269169..e25e7fc0ea94 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -210,6 +210,57 @@ def _fl( return hist +def run_driver_api() -> None: + """Run Flower server (Driver API).""" + + log(INFO, "Starting Flower server (Driver API)") + event(EventType.RUN_DRIVER_API_ENTER) + args = _parse_args_driver() + + # Init state + state = InMemoryState() + + # Start server + grpc_server: grpc.Server = _run_driver_api_grpc( + address=args.driver_api_address, + state=state, + ) + + # Graceful shutdown + _register_exit_handlers( + grpc_servers=[grpc_server], + event_type=EventType.RUN_DRIVER_API_LEAVE, + ) + + # Block + grpc_server.wait_for_termination() + + +def run_fleet_api() -> None: + """Run Flower server (Fleet API).""" + + log(INFO, "Starting Flower server (Fleet API)") + event(EventType.RUN_FLEET_API_ENTER) + args = _parse_args_fleet() + + # Init state + state = InMemoryState() + + # Start server + grpc_server: grpc.Server = _run_fleet_api_grpc_bidi( + address=args.fleet_api_address, + state=state, + ) + + _register_exit_handlers( + grpc_servers=[grpc_server], + event_type=EventType.RUN_FLEET_API_LEAVE, + ) + + # Block + grpc_server.wait_for_termination() + + def run_server() -> None: """Run Flower server (Driver API and Fleet API).""" @@ -338,6 +389,28 @@ def _run_fleet_api_grpc_bidi( return fleet_grpc_server +def _parse_args_driver() -> argparse.Namespace: + """Parse command line arguments for Driver API.""" + parser = argparse.ArgumentParser( + description="Start Flower server (Driver API)", + ) + + _add_arg_driver_api_address(parser=parser) + + return parser.parse_args() + + +def _parse_args_fleet() -> argparse.Namespace: + """Parse command line arguments for Fleet API.""" + parser = argparse.ArgumentParser( + description="Start Flower server (Fleet API)", + ) + + _add_arg_fleet_api_address(parser=parser) + + return parser.parse_args() + + def _parse_args() -> argparse.Namespace: """Parse command line arguments for both Driver API and Fleet API.""" parser = argparse.ArgumentParser(