Skip to content

Commit

Permalink
feat(framework) Add grpc-adapter transport (#3540)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Jun 19, 2024
1 parent 31afc93 commit 88b08f4
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 11 deletions.
4 changes: 4 additions & 0 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from flwr.common.address import parse_address
from flwr.common.constant import (
MISSING_EXTRA_REST,
TRANSPORT_TYPE_GRPC_ADAPTER,
TRANSPORT_TYPE_GRPC_BIDI,
TRANSPORT_TYPE_GRPC_RERE,
TRANSPORT_TYPE_REST,
Expand All @@ -41,6 +42,7 @@
from flwr.common.message import Error
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential

from .grpc_adapter_client.connection import grpc_adapter
from .grpc_client.connection import grpc_connection
from .grpc_rere_client.connection import grpc_request_response
from .message_handler.message_handler import handle_control_message
Expand Down Expand Up @@ -600,6 +602,8 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
connection, error_type = http_request_response, RequestsConnectionError
elif transport == TRANSPORT_TYPE_GRPC_RERE:
connection, error_type = grpc_request_response, RpcError
elif transport == TRANSPORT_TYPE_GRPC_ADAPTER:
connection, error_type = grpc_adapter, RpcError
elif transport == TRANSPORT_TYPE_GRPC_BIDI:
connection, error_type = grpc_connection, RpcError
else:
Expand Down
31 changes: 27 additions & 4 deletions src/py/flwr/client/supernode/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@
from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.common import EventType, event
from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
from flwr.common.constant import (
TRANSPORT_TYPE_GRPC_ADAPTER,
TRANSPORT_TYPE_GRPC_RERE,
TRANSPORT_TYPE_REST,
)
from flwr.common.exit_handlers import register_exit_handlers
from flwr.common.logger import log, warn_deprecated_feature
from flwr.common.object_ref import load_app, validate
Expand All @@ -56,7 +61,7 @@ def run_supernode() -> None:
_start_client_internal(
server_address=args.superlink,
load_client_app_fn=load_fn,
transport="rest" if args.rest else "grpc-rere",
transport=args.transport,
root_certificates=root_certificates,
insecure=args.insecure,
authentication_keys=authentication_keys,
Expand Down Expand Up @@ -87,7 +92,7 @@ def run_client_app() -> None:
_start_client_internal(
server_address=args.superlink,
load_client_app_fn=load_fn,
transport="rest" if args.rest else "grpc-rere",
transport=args.transport,
root_certificates=root_certificates,
insecure=args.insecure,
authentication_keys=authentication_keys,
Expand Down Expand Up @@ -295,9 +300,27 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
help="Run the client without HTTPS. By default, the client runs with "
"HTTPS enabled. Use this flag only if you understand the risks.",
)
parser.add_argument(
ex_group = parser.add_mutually_exclusive_group()
ex_group.add_argument(
"--grpc-rere",
action="store_const",
dest="transport",
const=TRANSPORT_TYPE_GRPC_RERE,
default=TRANSPORT_TYPE_GRPC_RERE,
help="Use grpc-rere as a transport layer for the client.",
)
ex_group.add_argument(
"--grpc-adapter",
action="store_const",
dest="transport",
const=TRANSPORT_TYPE_GRPC_ADAPTER,
help="Use grpc-adapter as a transport layer for the client.",
)
ex_group.add_argument(
"--rest",
action="store_true",
action="store_const",
dest="transport",
const=TRANSPORT_TYPE_REST,
help="Use REST as a transport layer for the client.",
)
parser.add_argument(
Expand Down
4 changes: 4 additions & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

TRANSPORT_TYPE_GRPC_BIDI = "grpc-bidi"
TRANSPORT_TYPE_GRPC_RERE = "grpc-rere"
TRANSPORT_TYPE_GRPC_ADAPTER = "grpc-adapter"
TRANSPORT_TYPE_REST = "rest"
TRANSPORT_TYPE_VCE = "vce"
TRANSPORT_TYPES = [
Expand All @@ -45,6 +46,9 @@
PING_RANDOM_RANGE = (-0.1, 0.1)
PING_MAX_INTERVAL = 1e300

GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"

# Constants for FAB
APP_DIR = "apps"
FAB_CONFIG_FILE = "pyproject.toml"
Expand Down
54 changes: 47 additions & 7 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from flwr.common.address import parse_address
from flwr.common.constant import (
MISSING_EXTRA_REST,
TRANSPORT_TYPE_GRPC_ADAPTER,
TRANSPORT_TYPE_GRPC_RERE,
TRANSPORT_TYPE_REST,
)
Expand All @@ -48,13 +49,15 @@
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
add_FleetServicer_to_server,
)
from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server

from .client_manager import ClientManager
from .history import History
from .server import Server, init_defaults, run_fl
from .server_config import ServerConfig
from .strategy import Strategy
from .superlink.driver.driver_grpc import run_driver_api_grpc
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
from .superlink.fleet.grpc_bidi.grpc_server import (
generic_create_grpc_server,
start_grpc_server,
Expand Down Expand Up @@ -218,11 +221,13 @@ def run_superlink() -> None:
grpc_servers = [driver_server]
bckg_threads = []
if not args.fleet_api_address:
args.fleet_api_address = (
ADDRESS_FLEET_API_GRPC_RERE
if args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE
else ADDRESS_FLEET_API_REST
)
if args.fleet_api_type in [
TRANSPORT_TYPE_GRPC_RERE,
TRANSPORT_TYPE_GRPC_ADAPTER,
]:
args.fleet_api_address = ADDRESS_FLEET_API_GRPC_RERE
elif args.fleet_api_type == TRANSPORT_TYPE_REST:
args.fleet_api_address = ADDRESS_FLEET_API_REST

fleet_address, host, port = _format_address(args.fleet_api_address)

Expand Down Expand Up @@ -293,6 +298,13 @@ def run_superlink() -> None:
interceptors=interceptors,
)
grpc_servers.append(fleet_server)
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER:
fleet_server = _run_fleet_api_grpc_adapter(
address=fleet_address,
state_factory=state_factory,
certificates=certificates,
)
grpc_servers.append(fleet_server)
else:
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")

Expand Down Expand Up @@ -419,7 +431,7 @@ def _try_obtain_certificates(
log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.")
return None
# Check if certificates are provided
if args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
if args.fleet_api_type in [TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_GRPC_ADAPTER]:
if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
if not isfile(args.ssl_ca_certfile):
sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
Expand Down Expand Up @@ -491,6 +503,30 @@ def _run_fleet_api_grpc_rere(
return fleet_grpc_server


def _run_fleet_api_grpc_adapter(
address: str,
state_factory: StateFactory,
certificates: Optional[Tuple[bytes, bytes, bytes]],
) -> grpc.Server:
"""Run Fleet API (GrpcAdapter)."""
# Create Fleet API gRPC server
fleet_servicer = GrpcAdapterServicer(
state_factory=state_factory,
)
fleet_add_servicer_to_server_fn = add_GrpcAdapterServicer_to_server
fleet_grpc_server = generic_create_grpc_server(
servicer_and_add_fn=(fleet_servicer, fleet_add_servicer_to_server_fn),
server_address=address,
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
certificates=certificates,
)

log(INFO, "Flower ECE: Starting Fleet API (GrpcAdapter) on %s", address)
fleet_grpc_server.start()

return fleet_grpc_server


# pylint: disable=import-outside-toplevel,too-many-arguments
def _run_fleet_api_rest(
host: str,
Expand Down Expand Up @@ -606,7 +642,11 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
"--fleet-api-type",
default=TRANSPORT_TYPE_GRPC_RERE,
type=str,
choices=[TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST],
choices=[
TRANSPORT_TYPE_GRPC_RERE,
TRANSPORT_TYPE_GRPC_ADAPTER,
TRANSPORT_TYPE_REST,
],
help="Start a gRPC-rere or REST (experimental) Fleet API server.",
)
parser.add_argument(
Expand Down
4 changes: 4 additions & 0 deletions src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
)
from flwr.server.client_manager import ClientManager
from flwr.server.superlink.driver.driver_servicer import DriverServicer
from flwr.server.superlink.fleet.grpc_adapter.grpc_adapter_servicer import (
GrpcAdapterServicer,
)
from flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer import (
FlowerServiceServicer,
)
Expand Down Expand Up @@ -154,6 +157,7 @@ def start_grpc_server( # pylint: disable=too-many-arguments
def generic_create_grpc_server( # pylint: disable=too-many-arguments
servicer_and_add_fn: Union[
Tuple[FleetServicer, AddServicerToServerFn],
Tuple[GrpcAdapterServicer, AddServicerToServerFn],
Tuple[FlowerServiceServicer, AddServicerToServerFn],
Tuple[DriverServicer, AddServicerToServerFn],
],
Expand Down

0 comments on commit 88b08f4

Please sign in to comment.