diff --git a/e2e/bare-https/driver.py b/e2e/bare-https/driver.py index dd7c9eab724..f7bfeb613f6 100644 --- a/e2e/bare-https/driver.py +++ b/e2e/bare-https/driver.py @@ -3,7 +3,7 @@ # Start Flower server -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="127.0.0.1:9091", config=fl.server.ServerConfig(num_rounds=3), root_certificates=Path("certificates/ca.crt").read_bytes(), diff --git a/e2e/bare/driver.py b/e2e/bare/driver.py index d428fe757aa..defc2ad5621 100644 --- a/e2e/bare/driver.py +++ b/e2e/bare/driver.py @@ -2,7 +2,7 @@ # Start Flower server -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/fastai/driver.py b/e2e/fastai/driver.py index b7b1c41ff5a..cc452ea523c 100644 --- a/e2e/fastai/driver.py +++ b/e2e/fastai/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/jax/driver.py b/e2e/jax/driver.py index b7b1c41ff5a..cc452ea523c 100644 --- a/e2e/jax/driver.py +++ b/e2e/jax/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/opacus/driver.py b/e2e/opacus/driver.py index 5bc40800c33..75acd9ccea2 100644 --- a/e2e/opacus/driver.py +++ b/e2e/opacus/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/pandas/driver.py b/e2e/pandas/driver.py index 78120fc946f..f5dc74c9f3f 100644 --- a/e2e/pandas/driver.py +++ b/e2e/pandas/driver.py @@ -3,7 +3,7 @@ from strategy import FedAnalytics # Start Flower server -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=1), strategy=FedAnalytics(), diff --git a/e2e/pytorch-lightning/driver.py b/e2e/pytorch-lightning/driver.py index b7b1c41ff5a..cc452ea523c 100644 --- a/e2e/pytorch-lightning/driver.py +++ b/e2e/pytorch-lightning/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/pytorch/driver.py b/e2e/pytorch/driver.py index 9f9b076ee75..2ea4de69a62 100644 --- a/e2e/pytorch/driver.py +++ b/e2e/pytorch/driver.py @@ -18,7 +18,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) # Start Flower server -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), strategy=strategy, diff --git a/e2e/scikit-learn/driver.py b/e2e/scikit-learn/driver.py index e7ce124e5ea..29051d02c6b 100644 --- a/e2e/scikit-learn/driver.py +++ b/e2e/scikit-learn/driver.py @@ -36,7 +36,7 @@ def evaluate(server_round, parameters: fl.common.NDArrays, config): evaluate_fn=get_evaluate_fn(model), on_fit_config_fn=fit_round, ) - hist = fl.server.driver.start_driver( + hist = fl.server.start_driver( server_address="0.0.0.0:9091", strategy=strategy, config=fl.server.ServerConfig(num_rounds=3), diff --git a/e2e/tabnet/driver.py b/e2e/tabnet/driver.py index b7b1c41ff5a..cc452ea523c 100644 --- a/e2e/tabnet/driver.py +++ b/e2e/tabnet/driver.py @@ -1,6 +1,6 @@ import flwr as fl -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), ) diff --git a/e2e/tensorflow/driver.py b/e2e/tensorflow/driver.py index 9f9b076ee75..2ea4de69a62 100644 --- a/e2e/tensorflow/driver.py +++ b/e2e/tensorflow/driver.py @@ -18,7 +18,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) # Start Flower server -hist = fl.server.driver.start_driver( +hist = fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), strategy=strategy, diff --git a/examples/quickstart-cpp/driver.py b/examples/quickstart-cpp/driver.py index 3b3036f7e92..f19cf0e9bd9 100644 --- a/examples/quickstart-cpp/driver.py +++ b/examples/quickstart-cpp/driver.py @@ -3,7 +3,7 @@ # Start Flower server for three rounds of federated learning if __name__ == "__main__": - fl.server.driver.start_driver( + fl.server.start_driver( server_address="0.0.0.0:9091", config=fl.server.ServerConfig(num_rounds=3), strategy=FedAvgCpp(), diff --git a/src/py/flwr/server/__init__.py b/src/py/flwr/server/__init__.py index b0f95f90381..09372e25886 100644 --- a/src/py/flwr/server/__init__.py +++ b/src/py/flwr/server/__init__.py @@ -18,12 +18,13 @@ from . import driver, strategy from .app import run_driver_api as run_driver_api from .app import run_fleet_api as run_fleet_api -from .app import run_server_app as run_server_app from .app import run_superlink as run_superlink from .app import start_server as start_server from .client_manager import ClientManager as ClientManager from .client_manager import SimpleClientManager as SimpleClientManager +from .compat import start_driver as start_driver from .history import History as History +from .run_serverapp import run_server_app as run_server_app from .server import Server as Server from .server_config import ServerConfig as ServerConfig from .serverapp import ServerApp as ServerApp @@ -40,6 +41,7 @@ "ServerApp", "ServerConfig", "SimpleClientManager", + "start_driver", "start_server", "strategy", ] diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 66adcbdb6b8..dbbf63b0fe5 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -19,7 +19,7 @@ import importlib.util import sys import threading -from logging import DEBUG, ERROR, INFO, WARN +from logging import ERROR, INFO, WARN from os.path import isfile from pathlib import Path from signal import SIGINT, SIGTERM, signal @@ -47,7 +47,6 @@ from .history import History from .server import Server from .server_config import ServerConfig -from .serverapp import ServerApp, load_server_app from .strategy import FedAvg, Strategy from .superlink.driver.driver_servicer import DriverServicer from .superlink.fleet.grpc_bidi.grpc_server import ( @@ -65,72 +64,6 @@ DATABASE = ":flwr-in-memory-state:" -def run_server_app() -> None: - """Run Flower server app.""" - event(EventType.RUN_SERVER_APP_ENTER) - - args = _parse_args_run_server_app().parse_args() - - # Obtain certificates - if args.insecure: - if args.root_certificates is not None: - sys.exit( - "Conflicting options: The '--insecure' flag disables HTTPS, " - "but '--root-certificates' was also specified. Please remove " - "the '--root-certificates' option when running in insecure mode, " - "or omit '--insecure' to use HTTPS." - ) - log( - WARN, - "Option `--insecure` was set. " - "Starting insecure HTTP client connected to %s.", - args.server, - ) - root_certificates = None - else: - # Load the certificates if provided, or load the system certificates - cert_path = args.root_certificates - if cert_path is None: - root_certificates = None - else: - root_certificates = Path(cert_path).read_bytes() - log( - DEBUG, - "Starting secure HTTPS client connected to %s " - "with the following certificates: %s.", - args.server, - cert_path, - ) - - log( - DEBUG, - "Flower will load ServerApp `%s`", - getattr(args, "server-app"), - ) - - log( - DEBUG, - "root_certificates: `%s`", - root_certificates, - ) - - log(WARN, "Not implemented: run_server_app") - - server_app_dir = args.dir - if server_app_dir is not None: - sys.path.insert(0, server_app_dir) - - def _load() -> ServerApp: - server_app: ServerApp = load_server_app(getattr(args, "server-app")) - return server_app - - server_app = _load() - - log(DEBUG, "server_app: `%s`", server_app) - - event(EventType.RUN_SERVER_APP_LEAVE) - - def start_server( # pylint: disable=too-many-arguments,too-many-locals *, server_address: str = ADDRESS_FLEET_API_GRPC_BIDI, @@ -816,42 +749,3 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None: type=int, default=1, ) - - -def _parse_args_run_server_app() -> argparse.ArgumentParser: - """Parse flower-server-app command line arguments.""" - parser = argparse.ArgumentParser( - description="Start a Flower server app", - ) - - parser.add_argument( - "server-app", - help="For example: `server:app` or `project.package.module:wrapper.app`", - ) - parser.add_argument( - "--insecure", - action="store_true", - help="Run the server app without HTTPS. By default, the app runs with " - "HTTPS enabled. Use this flag only if you understand the risks.", - ) - parser.add_argument( - "--root-certificates", - metavar="ROOT_CERT", - type=str, - help="Specifies the path to the PEM-encoded root certificate file for " - "establishing secure HTTPS connections.", - ) - parser.add_argument( - "--server", - default="0.0.0.0:9092", - help="Server address", - ) - parser.add_argument( - "--dir", - default="", - help="Add specified directory to the PYTHONPATH and load Flower " - "app from there." - " Default: current working directory.", - ) - - return parser diff --git a/src/py/flwr/server/compat/__init__.py b/src/py/flwr/server/compat/__init__.py new file mode 100644 index 00000000000..3a0c2b4e83a --- /dev/null +++ b/src/py/flwr/server/compat/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower ServerApp compatibility package.""" + + +from .app import start_driver as start_driver + +__all__ = [ + "start_driver", +] diff --git a/src/py/flwr/server/driver/app.py b/src/py/flwr/server/compat/app.py similarity index 97% rename from src/py/flwr/server/driver/app.py rename to src/py/flwr/server/compat/app.py index b47454b7b4b..06debb858c3 100644 --- a/src/py/flwr/server/driver/app.py +++ b/src/py/flwr/server/compat/app.py @@ -33,8 +33,8 @@ from flwr.server.server_config import ServerConfig from flwr.server.strategy import Strategy +from ..driver.grpc_driver import GrpcDriver from .driver_client_proxy import DriverClientProxy -from .grpc_driver import GrpcDriver DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091" @@ -111,10 +111,11 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals # Create the Driver if isinstance(root_certificates, str): root_certificates = Path(root_certificates).read_bytes() - driver = GrpcDriver( + grpc_driver = GrpcDriver( driver_service_address=address, root_certificates=root_certificates ) - driver.connect() + + grpc_driver.connect() lock = threading.Lock() # Initialize the Driver API server and config @@ -134,7 +135,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals thread = threading.Thread( target=update_client_manager, args=( - driver, + grpc_driver, initialized_server.client_manager(), lock, ), @@ -149,7 +150,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals # Stop the Driver API server and the thread with lock: - driver.disconnect() + grpc_driver.disconnect() thread.join() event(EventType.START_SERVER_LEAVE) diff --git a/src/py/flwr/server/driver/app_test.py b/src/py/flwr/server/compat/app_test.py similarity index 98% rename from src/py/flwr/server/driver/app_test.py rename to src/py/flwr/server/compat/app_test.py index 03f49080787..5f8f04ff2a0 100644 --- a/src/py/flwr/server/driver/app_test.py +++ b/src/py/flwr/server/compat/app_test.py @@ -26,7 +26,8 @@ ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.server.client_manager import SimpleClientManager -from flwr.server.driver.app import update_client_manager + +from .app import update_client_manager class TestClientManagerWithDriver(unittest.TestCase): diff --git a/src/py/flwr/server/driver/driver_client_proxy.py b/src/py/flwr/server/compat/driver_client_proxy.py similarity index 99% rename from src/py/flwr/server/driver/driver_client_proxy.py rename to src/py/flwr/server/compat/driver_client_proxy.py index 8ea288dbb50..1dc992106f6 100644 --- a/src/py/flwr/server/driver/driver_client_proxy.py +++ b/src/py/flwr/server/compat/driver_client_proxy.py @@ -31,7 +31,7 @@ from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 from flwr.server.client_proxy import ClientProxy -from .grpc_driver import GrpcDriver +from ..driver.grpc_driver import GrpcDriver SLEEP_TIME = 1 diff --git a/src/py/flwr/server/driver/driver_client_proxy_test.py b/src/py/flwr/server/compat/driver_client_proxy_test.py similarity index 99% rename from src/py/flwr/server/driver/driver_client_proxy_test.py rename to src/py/flwr/server/compat/driver_client_proxy_test.py index aa60448bd72..de6566622b7 100644 --- a/src/py/flwr/server/driver/driver_client_proxy_test.py +++ b/src/py/flwr/server/compat/driver_client_proxy_test.py @@ -44,7 +44,8 @@ Status, ) from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 -from flwr.server.driver.driver_client_proxy import DriverClientProxy + +from .driver_client_proxy import DriverClientProxy MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np") diff --git a/src/py/flwr/server/driver/__init__.py b/src/py/flwr/server/driver/__init__.py index 1c3b09cc334..b61f6eebf6a 100644 --- a/src/py/flwr/server/driver/__init__.py +++ b/src/py/flwr/server/driver/__init__.py @@ -15,12 +15,10 @@ """Flower driver SDK.""" -from .app import start_driver from .driver import Driver from .grpc_driver import GrpcDriver __all__ = [ "Driver", "GrpcDriver", - "start_driver", ] diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py new file mode 100644 index 00000000000..35fffcf2d7b --- /dev/null +++ b/src/py/flwr/server/run_serverapp.py @@ -0,0 +1,131 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Run ServerApp.""" + + +import argparse +import sys +from logging import DEBUG, WARN +from pathlib import Path + +from flwr.common import EventType, event +from flwr.common.logger import log + +from .serverapp import ServerApp, load_server_app + + +def run_server_app() -> None: + """Run Flower server app.""" + event(EventType.RUN_SERVER_APP_ENTER) + + args = _parse_args_run_server_app().parse_args() + + # Obtain certificates + if args.insecure: + if args.root_certificates is not None: + sys.exit( + "Conflicting options: The '--insecure' flag disables HTTPS, " + "but '--root-certificates' was also specified. Please remove " + "the '--root-certificates' option when running in insecure mode, " + "or omit '--insecure' to use HTTPS." + ) + log( + WARN, + "Option `--insecure` was set. " + "Starting insecure HTTP client connected to %s.", + args.server, + ) + root_certificates = None + else: + # Load the certificates if provided, or load the system certificates + cert_path = args.root_certificates + if cert_path is None: + root_certificates = None + else: + root_certificates = Path(cert_path).read_bytes() + log( + DEBUG, + "Starting secure HTTPS client connected to %s " + "with the following certificates: %s.", + args.server, + cert_path, + ) + + log( + DEBUG, + "Flower will load ServerApp `%s`", + getattr(args, "server-app"), + ) + + log( + DEBUG, + "root_certificates: `%s`", + root_certificates, + ) + + log(WARN, "Not implemented: run_server_app") + + server_app_dir = args.dir + if server_app_dir is not None: + sys.path.insert(0, server_app_dir) + + def _load() -> ServerApp: + server_app: ServerApp = load_server_app(getattr(args, "server-app")) + return server_app + + server_app = _load() + + log(DEBUG, "server_app: `%s`", server_app) + + event(EventType.RUN_SERVER_APP_LEAVE) + + +def _parse_args_run_server_app() -> argparse.ArgumentParser: + """Parse flower-server-app command line arguments.""" + parser = argparse.ArgumentParser( + description="Start a Flower server app", + ) + + parser.add_argument( + "server-app", + help="For example: `server:app` or `project.package.module:wrapper.app`", + ) + parser.add_argument( + "--insecure", + action="store_true", + help="Run the server app without HTTPS. By default, the app runs with " + "HTTPS enabled. Use this flag only if you understand the risks.", + ) + parser.add_argument( + "--root-certificates", + metavar="ROOT_CERT", + type=str, + help="Specifies the path to the PEM-encoded root certificate file for " + "establishing secure HTTPS connections.", + ) + parser.add_argument( + "--server", + default="0.0.0.0:9092", + help="Server address", + ) + parser.add_argument( + "--dir", + default="", + help="Add specified directory to the PYTHONPATH and load Flower " + "app from there." + " Default: current working directory.", + ) + + return parser