Skip to content

Commit

Permalink
Create server-side compatibility package (#2957)
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Feb 15, 2024
1 parent 5af1679 commit 277ab1f
Show file tree
Hide file tree
Showing 21 changed files with 180 additions and 130 deletions.
2 changes: 1 addition & 1 deletion e2e/bare-https/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


# Start Flower server
hist = fl.server.driver.start_driver(
hist = fl.server.start_driver(
server_address="127.0.0.1:9091",
config=fl.server.ServerConfig(num_rounds=3),
root_certificates=Path("certificates/ca.crt").read_bytes(),
Expand Down
2 changes: 1 addition & 1 deletion e2e/bare/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


# Start Flower server
hist = fl.server.driver.start_driver(
hist = fl.server.start_driver(
server_address="0.0.0.0:9091",
config=fl.server.ServerConfig(num_rounds=3),
)
Expand Down
2 changes: 1 addition & 1 deletion e2e/fastai/driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import flwr as fl

hist = fl.server.driver.start_driver(
hist = fl.server.start_driver(
server_address="0.0.0.0:9091",
config=fl.server.ServerConfig(num_rounds=3),
)
Expand Down
2 changes: 1 addition & 1 deletion e2e/jax/driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import flwr as fl

hist = fl.server.driver.start_driver(
hist = fl.server.start_driver(
server_address="0.0.0.0:9091",
config=fl.server.ServerConfig(num_rounds=3),
)
Expand Down
2 changes: 1 addition & 1 deletion e2e/opacus/driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import flwr as fl

hist = fl.server.driver.start_driver(
hist = fl.server.start_driver(
server_address="0.0.0.0:9091",
config=fl.server.ServerConfig(num_rounds=3),
)
Expand Down
2 changes: 1 addition & 1 deletion e2e/pandas/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from strategy import FedAnalytics

# Start Flower server
hist = fl.server.driver.start_driver(
hist = fl.server.start_driver(
server_address="0.0.0.0:9091",
config=fl.server.ServerConfig(num_rounds=1),
strategy=FedAnalytics(),
Expand Down
2 changes: 1 addition & 1 deletion e2e/pytorch-lightning/driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import flwr as fl

hist = fl.server.driver.start_driver(
hist = fl.server.start_driver(
server_address="0.0.0.0:9091",
config=fl.server.ServerConfig(num_rounds=3),
)
Expand Down
2 changes: 1 addition & 1 deletion e2e/pytorch/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average)

# Start Flower server
hist = fl.server.driver.start_driver(
hist = fl.server.start_driver(
server_address="0.0.0.0:9091",
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy,
Expand Down
2 changes: 1 addition & 1 deletion e2e/scikit-learn/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def evaluate(server_round, parameters: fl.common.NDArrays, config):
evaluate_fn=get_evaluate_fn(model),
on_fit_config_fn=fit_round,
)
hist = fl.server.driver.start_driver(
hist = fl.server.start_driver(
server_address="0.0.0.0:9091",
strategy=strategy,
config=fl.server.ServerConfig(num_rounds=3),
Expand Down
2 changes: 1 addition & 1 deletion e2e/tabnet/driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import flwr as fl

hist = fl.server.driver.start_driver(
hist = fl.server.start_driver(
server_address="0.0.0.0:9091",
config=fl.server.ServerConfig(num_rounds=3),
)
Expand Down
2 changes: 1 addition & 1 deletion e2e/tensorflow/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average)

# Start Flower server
hist = fl.server.driver.start_driver(
hist = fl.server.start_driver(
server_address="0.0.0.0:9091",
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy,
Expand Down
2 changes: 1 addition & 1 deletion examples/quickstart-cpp/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

# Start Flower server for three rounds of federated learning
if __name__ == "__main__":
fl.server.driver.start_driver(
fl.server.start_driver(
server_address="0.0.0.0:9091",
config=fl.server.ServerConfig(num_rounds=3),
strategy=FedAvgCpp(),
Expand Down
4 changes: 3 additions & 1 deletion src/py/flwr/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
from . import driver, strategy
from .app import run_driver_api as run_driver_api
from .app import run_fleet_api as run_fleet_api
from .app import run_server_app as run_server_app
from .app import run_superlink as run_superlink
from .app import start_server as start_server
from .client_manager import ClientManager as ClientManager
from .client_manager import SimpleClientManager as SimpleClientManager
from .compat import start_driver as start_driver
from .history import History as History
from .run_serverapp import run_server_app as run_server_app
from .server import Server as Server
from .server_config import ServerConfig as ServerConfig
from .serverapp import ServerApp as ServerApp
Expand All @@ -40,6 +41,7 @@
"ServerApp",
"ServerConfig",
"SimpleClientManager",
"start_driver",
"start_server",
"strategy",
]
108 changes: 1 addition & 107 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import importlib.util
import sys
import threading
from logging import DEBUG, ERROR, INFO, WARN
from logging import ERROR, INFO, WARN
from os.path import isfile
from pathlib import Path
from signal import SIGINT, SIGTERM, signal
Expand Down Expand Up @@ -47,7 +47,6 @@
from .history import History
from .server import Server
from .server_config import ServerConfig
from .serverapp import ServerApp, load_server_app
from .strategy import FedAvg, Strategy
from .superlink.driver.driver_servicer import DriverServicer
from .superlink.fleet.grpc_bidi.grpc_server import (
Expand All @@ -65,72 +64,6 @@
DATABASE = ":flwr-in-memory-state:"


def run_server_app() -> None:
"""Run Flower server app."""
event(EventType.RUN_SERVER_APP_ENTER)

args = _parse_args_run_server_app().parse_args()

# Obtain certificates
if args.insecure:
if args.root_certificates is not None:
sys.exit(
"Conflicting options: The '--insecure' flag disables HTTPS, "
"but '--root-certificates' was also specified. Please remove "
"the '--root-certificates' option when running in insecure mode, "
"or omit '--insecure' to use HTTPS."
)
log(
WARN,
"Option `--insecure` was set. "
"Starting insecure HTTP client connected to %s.",
args.server,
)
root_certificates = None
else:
# Load the certificates if provided, or load the system certificates
cert_path = args.root_certificates
if cert_path is None:
root_certificates = None
else:
root_certificates = Path(cert_path).read_bytes()
log(
DEBUG,
"Starting secure HTTPS client connected to %s "
"with the following certificates: %s.",
args.server,
cert_path,
)

log(
DEBUG,
"Flower will load ServerApp `%s`",
getattr(args, "server-app"),
)

log(
DEBUG,
"root_certificates: `%s`",
root_certificates,
)

log(WARN, "Not implemented: run_server_app")

server_app_dir = args.dir
if server_app_dir is not None:
sys.path.insert(0, server_app_dir)

def _load() -> ServerApp:
server_app: ServerApp = load_server_app(getattr(args, "server-app"))
return server_app

server_app = _load()

log(DEBUG, "server_app: `%s`", server_app)

event(EventType.RUN_SERVER_APP_LEAVE)


def start_server( # pylint: disable=too-many-arguments,too-many-locals
*,
server_address: str = ADDRESS_FLEET_API_GRPC_BIDI,
Expand Down Expand Up @@ -816,42 +749,3 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
type=int,
default=1,
)


def _parse_args_run_server_app() -> argparse.ArgumentParser:
"""Parse flower-server-app command line arguments."""
parser = argparse.ArgumentParser(
description="Start a Flower server app",
)

parser.add_argument(
"server-app",
help="For example: `server:app` or `project.package.module:wrapper.app`",
)
parser.add_argument(
"--insecure",
action="store_true",
help="Run the server app without HTTPS. By default, the app runs with "
"HTTPS enabled. Use this flag only if you understand the risks.",
)
parser.add_argument(
"--root-certificates",
metavar="ROOT_CERT",
type=str,
help="Specifies the path to the PEM-encoded root certificate file for "
"establishing secure HTTPS connections.",
)
parser.add_argument(
"--server",
default="0.0.0.0:9092",
help="Server address",
)
parser.add_argument(
"--dir",
default="",
help="Add specified directory to the PYTHONPATH and load Flower "
"app from there."
" Default: current working directory.",
)

return parser
22 changes: 22 additions & 0 deletions src/py/flwr/server/compat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower ServerApp compatibility package."""


from .app import start_driver as start_driver

__all__ = [
"start_driver",
]
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from flwr.server.server_config import ServerConfig
from flwr.server.strategy import Strategy

from ..driver.grpc_driver import GrpcDriver
from .driver_client_proxy import DriverClientProxy
from .grpc_driver import GrpcDriver

DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"

Expand Down Expand Up @@ -111,10 +111,11 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
# Create the Driver
if isinstance(root_certificates, str):
root_certificates = Path(root_certificates).read_bytes()
driver = GrpcDriver(
grpc_driver = GrpcDriver(
driver_service_address=address, root_certificates=root_certificates
)
driver.connect()

grpc_driver.connect()
lock = threading.Lock()

# Initialize the Driver API server and config
Expand All @@ -134,7 +135,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
thread = threading.Thread(
target=update_client_manager,
args=(
driver,
grpc_driver,
initialized_server.client_manager(),
lock,
),
Expand All @@ -149,7 +150,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals

# Stop the Driver API server and the thread
with lock:
driver.disconnect()
grpc_driver.disconnect()
thread.join()

event(EventType.START_SERVER_LEAVE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
)
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.server.client_manager import SimpleClientManager
from flwr.server.driver.app import update_client_manager

from .app import update_client_manager


class TestClientManagerWithDriver(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
from flwr.server.client_proxy import ClientProxy

from .grpc_driver import GrpcDriver
from ..driver.grpc_driver import GrpcDriver

SLEEP_TIME = 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
Status,
)
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
from flwr.server.driver.driver_client_proxy import DriverClientProxy

from .driver_client_proxy import DriverClientProxy

MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np")

Expand Down
2 changes: 0 additions & 2 deletions src/py/flwr/server/driver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@
"""Flower driver SDK."""


from .app import start_driver
from .driver import Driver
from .grpc_driver import GrpcDriver

__all__ = [
"Driver",
"GrpcDriver",
"start_driver",
]
Loading

0 comments on commit 277ab1f

Please sign in to comment.