From e097ea438660d1e84842e46955e2ae230d844b7d Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 30 May 2024 09:43:31 +0100 Subject: [PATCH 01/25] add proto --- src/proto/flwr/proto/grpcadapter.proto | 28 +++++++++ src/py/flwr/proto/grpcadapter_pb2.py | 32 +++++++++++ src/py/flwr/proto/grpcadapter_pb2.pyi | 43 ++++++++++++++ src/py/flwr/proto/grpcadapter_pb2_grpc.py | 66 ++++++++++++++++++++++ src/py/flwr/proto/grpcadapter_pb2_grpc.pyi | 24 ++++++++ 5 files changed, 193 insertions(+) create mode 100644 src/proto/flwr/proto/grpcadapter.proto create mode 100644 src/py/flwr/proto/grpcadapter_pb2.py create mode 100644 src/py/flwr/proto/grpcadapter_pb2.pyi create mode 100644 src/py/flwr/proto/grpcadapter_pb2_grpc.py create mode 100644 src/py/flwr/proto/grpcadapter_pb2_grpc.pyi diff --git a/src/proto/flwr/proto/grpcadapter.proto b/src/proto/flwr/proto/grpcadapter.proto new file mode 100644 index 000000000000..53fd716787ab --- /dev/null +++ b/src/proto/flwr/proto/grpcadapter.proto @@ -0,0 +1,28 @@ +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================== + +syntax = "proto3"; + +package flwr.proto; + +service GrpcAdapter { + rpc SendReceive(MessageContainer) returns (MessageContainer) {} +} + +message MessageContainer { + map metadata = 1; + string grpc_message_name = 2; + bytes grpc_message_content = 3; +} \ No newline at end of file diff --git a/src/py/flwr/proto/grpcadapter_pb2.py b/src/py/flwr/proto/grpcadapter_pb2.py new file mode 100644 index 000000000000..7c0374736850 --- /dev/null +++ b/src/py/flwr/proto/grpcadapter_pb2.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: flwr/proto/grpcadapter.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/grpcadapter.proto\x12\nflwr.proto\"\xba\x01\n\x10MessageContainer\x12<\n\x08metadata\x18\x01 \x03(\x0b\x32*.flwr.proto.MessageContainer.MetadataEntry\x12\x19\n\x11grpc_message_name\x18\x02 \x01(\t\x12\x1c\n\x14grpc_message_content\x18\x03 \x01(\x0c\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c:\x02\x38\x01\x32Z\n\x0bGrpcAdapter\x12K\n\x0bSendReceive\x12\x1c.flwr.proto.MessageContainer\x1a\x1c.flwr.proto.MessageContainer\"\x00\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.grpcadapter_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_MESSAGECONTAINER_METADATAENTRY']._options = None + _globals['_MESSAGECONTAINER_METADATAENTRY']._serialized_options = b'8\001' + _globals['_MESSAGECONTAINER']._serialized_start=45 + _globals['_MESSAGECONTAINER']._serialized_end=231 + _globals['_MESSAGECONTAINER_METADATAENTRY']._serialized_start=184 + _globals['_MESSAGECONTAINER_METADATAENTRY']._serialized_end=231 + _globals['_GRPCADAPTER']._serialized_start=233 + _globals['_GRPCADAPTER']._serialized_end=323 +# @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/grpcadapter_pb2.pyi b/src/py/flwr/proto/grpcadapter_pb2.pyi new file mode 100644 index 000000000000..d5f89ac27c4a --- /dev/null +++ b/src/py/flwr/proto/grpcadapter_pb2.pyi @@ -0,0 +1,43 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import builtins +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import typing +import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class MessageContainer(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class MetadataEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + value: builtins.bytes + def __init__(self, + *, + key: typing.Text = ..., + value: builtins.bytes = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + METADATA_FIELD_NUMBER: builtins.int + GRPC_MESSAGE_NAME_FIELD_NUMBER: builtins.int + GRPC_MESSAGE_CONTENT_FIELD_NUMBER: builtins.int + @property + def metadata(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, builtins.bytes]: ... + grpc_message_name: typing.Text + grpc_message_content: builtins.bytes + def __init__(self, + *, + metadata: typing.Optional[typing.Mapping[typing.Text, builtins.bytes]] = ..., + grpc_message_name: typing.Text = ..., + grpc_message_content: builtins.bytes = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["grpc_message_content",b"grpc_message_content","grpc_message_name",b"grpc_message_name","metadata",b"metadata"]) -> None: ... +global___MessageContainer = MessageContainer diff --git a/src/py/flwr/proto/grpcadapter_pb2_grpc.py b/src/py/flwr/proto/grpcadapter_pb2_grpc.py new file mode 100644 index 000000000000..831f99d7b237 --- /dev/null +++ b/src/py/flwr/proto/grpcadapter_pb2_grpc.py @@ -0,0 +1,66 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from flwr.proto import grpcadapter_pb2 as flwr_dot_proto_dot_grpcadapter__pb2 + + +class GrpcAdapterStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SendReceive = channel.unary_unary( + '/flwr.proto.GrpcAdapter/SendReceive', + request_serializer=flwr_dot_proto_dot_grpcadapter__pb2.MessageContainer.SerializeToString, + response_deserializer=flwr_dot_proto_dot_grpcadapter__pb2.MessageContainer.FromString, + ) + + +class GrpcAdapterServicer(object): + """Missing associated documentation comment in .proto file.""" + + def SendReceive(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_GrpcAdapterServicer_to_server(servicer, server): + rpc_method_handlers = { + 'SendReceive': grpc.unary_unary_rpc_method_handler( + servicer.SendReceive, + request_deserializer=flwr_dot_proto_dot_grpcadapter__pb2.MessageContainer.FromString, + response_serializer=flwr_dot_proto_dot_grpcadapter__pb2.MessageContainer.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'flwr.proto.GrpcAdapter', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class GrpcAdapter(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def SendReceive(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.GrpcAdapter/SendReceive', + flwr_dot_proto_dot_grpcadapter__pb2.MessageContainer.SerializeToString, + flwr_dot_proto_dot_grpcadapter__pb2.MessageContainer.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/grpcadapter_pb2_grpc.pyi b/src/py/flwr/proto/grpcadapter_pb2_grpc.pyi new file mode 100644 index 000000000000..640f983e6e04 --- /dev/null +++ b/src/py/flwr/proto/grpcadapter_pb2_grpc.pyi @@ -0,0 +1,24 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import abc +import flwr.proto.grpcadapter_pb2 +import grpc + +class GrpcAdapterStub: + def __init__(self, channel: grpc.Channel) -> None: ... + SendReceive: grpc.UnaryUnaryMultiCallable[ + flwr.proto.grpcadapter_pb2.MessageContainer, + flwr.proto.grpcadapter_pb2.MessageContainer] + + +class GrpcAdapterServicer(metaclass=abc.ABCMeta): + @abc.abstractmethod + def SendReceive(self, + request: flwr.proto.grpcadapter_pb2.MessageContainer, + context: grpc.ServicerContext, + ) -> flwr.proto.grpcadapter_pb2.MessageContainer: ... + + +def add_GrpcAdapterServicer_to_server(servicer: GrpcAdapterServicer, server: grpc.Server) -> None: ... From 2b3f44c4f2deae67f47e549902771ccafef98bff Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 30 May 2024 09:53:20 +0100 Subject: [PATCH 02/25] update test --- src/py/flwr_tool/protoc_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr_tool/protoc_test.py b/src/py/flwr_tool/protoc_test.py index 2d48582eb441..6aec4251c384 100644 --- a/src/py/flwr_tool/protoc_test.py +++ b/src/py/flwr_tool/protoc_test.py @@ -28,4 +28,4 @@ def test_directories() -> None: def test_proto_file_count() -> None: """Test if the correct number of proto files were captured by the glob.""" - assert len(PROTO_FILES) == 7 + assert len(PROTO_FILES) == 9 From a69ee45e02995a8efb510793f5485037c4761f36 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 30 May 2024 10:41:46 +0100 Subject: [PATCH 03/25] fix test error --- src/py/flwr_tool/protoc_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr_tool/protoc_test.py b/src/py/flwr_tool/protoc_test.py index 6aec4251c384..8dcf4c6474d6 100644 --- a/src/py/flwr_tool/protoc_test.py +++ b/src/py/flwr_tool/protoc_test.py @@ -28,4 +28,4 @@ def test_directories() -> None: def test_proto_file_count() -> None: """Test if the correct number of proto files were captured by the glob.""" - assert len(PROTO_FILES) == 9 + assert len(PROTO_FILES) == 8 From 58d823e31d51b819773a2070db3217a0194ee583 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 30 May 2024 16:44:57 +0100 Subject: [PATCH 04/25] add grpc adapter --- .../client/grpc_rere_client/grpc_adapter.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 src/py/flwr/client/grpc_rere_client/grpc_adapter.py diff --git a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py new file mode 100644 index 000000000000..979f3c05cf2a --- /dev/null +++ b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py @@ -0,0 +1,53 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Grpc Adapter.""" + + +from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub +from flwr.proto.grpcadapter_pb2 import MessageContainer +from typing import Any, TypeVar, Type, cast +from google.protobuf.message import Message as GrpcMessage +from flwr.proto.fleet_pb2 import CreateNodeRequest, CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, PingRequest, PingResponse, PullTaskInsResponse, PullTaskInsRequest, PushTaskResRequest, PushTaskResResponse, GetRunRequest, GetRunResponse +import flwr + + +KEY_FLOWER_VERSION = "flower-version" +T = TypeVar("T", bound=GrpcMessage) + + +class GrpcAdapter: + def __init__(self, channel: Any) -> None: + self.stub = GrpcAdapterStub(channel) + + def _send_and_receive(self, request: GrpcMessage, response_type: Type[T]) -> T: + container_req = MessageContainer( + metadata={KEY_FLOWER_VERSION: flwr.__version__}, + grpc_message_name=request.__class__.__qualname__, + grpc_message_content=request.SerializeToString() + ) + container_res = cast(MessageContainer, self.stub.SendReceive(container_req)) + if container_res.grpc_message_name != response_type.__qualname__: + raise ValueError( + f"Invalid grpc_message_name. Expected {response_type.__qualname__}" + f", but got {container_res.grpc_message_name}." + ) + response = response_type() + response.ParseFromString(container_res) + return response + + def CreateNode(): + ... + + From 88e69ceeb9387cd7f4e24237e6b190461f228578 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 30 May 2024 17:02:02 +0100 Subject: [PATCH 05/25] complete grpc adapter class --- .../client/grpc_rere_client/grpc_adapter.py | 55 +++++++++++++++---- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py index 979f3c05cf2a..71ae332e2c68 100644 --- a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py +++ b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py @@ -15,19 +15,35 @@ """Grpc Adapter.""" -from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub -from flwr.proto.grpcadapter_pb2 import MessageContainer -from typing import Any, TypeVar, Type, cast +from typing import Any, Type, TypeVar, cast + from google.protobuf.message import Message as GrpcMessage -from flwr.proto.fleet_pb2 import CreateNodeRequest, CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, PingRequest, PingResponse, PullTaskInsResponse, PullTaskInsRequest, PushTaskResRequest, PushTaskResResponse, GetRunRequest, GetRunResponse -import flwr +import flwr +from flwr.proto.fleet_pb2 import ( + CreateNodeRequest, + CreateNodeResponse, + DeleteNodeRequest, + DeleteNodeResponse, + GetRunRequest, + GetRunResponse, + PingRequest, + PingResponse, + PullTaskInsRequest, + PullTaskInsResponse, + PushTaskResRequest, + PushTaskResResponse, +) +from flwr.proto.grpcadapter_pb2 import MessageContainer +from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub KEY_FLOWER_VERSION = "flower-version" T = TypeVar("T", bound=GrpcMessage) class GrpcAdapter: + """Grpc Adapter.""" + def __init__(self, channel: Any) -> None: self.stub = GrpcAdapterStub(channel) @@ -35,7 +51,7 @@ def _send_and_receive(self, request: GrpcMessage, response_type: Type[T]) -> T: container_req = MessageContainer( metadata={KEY_FLOWER_VERSION: flwr.__version__}, grpc_message_name=request.__class__.__qualname__, - grpc_message_content=request.SerializeToString() + grpc_message_content=request.SerializeToString(), ) container_res = cast(MessageContainer, self.stub.SendReceive(container_req)) if container_res.grpc_message_name != response_type.__qualname__: @@ -44,10 +60,29 @@ def _send_and_receive(self, request: GrpcMessage, response_type: Type[T]) -> T: f", but got {container_res.grpc_message_name}." ) response = response_type() - response.ParseFromString(container_res) + response.ParseFromString(container_res.grpc_message_content) return response - - def CreateNode(): - ... + def CreateNode(self, request: CreateNodeRequest) -> CreateNodeResponse: + """.""" + return self._send_and_receive(request, CreateNodeResponse) + + def DeleteNode(self, request: DeleteNodeRequest) -> DeleteNodeResponse: + """.""" + return self._send_and_receive(request, DeleteNodeResponse) + + def Ping(self, request: PingRequest) -> PingResponse: + """.""" + return self._send_and_receive(request, PingResponse) + + def PullTaskIns(self, request: PullTaskInsRequest) -> PullTaskInsResponse: + """.""" + return self._send_and_receive(request, PullTaskInsResponse) + + def PushTaskRes(self, request: PushTaskResRequest) -> PushTaskResResponse: + """.""" + return self._send_and_receive(request, PushTaskResResponse) + def GetRun(self, request: GetRunRequest) -> GetRunResponse: + """.""" + return self._send_and_receive(request, GetRunResponse) From 6783995eadc7075efa65db443d4ab01eed930f91 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 30 May 2024 17:03:23 +0100 Subject: [PATCH 06/25] fix type --- src/proto/flwr/proto/grpcadapter.proto | 2 +- src/py/flwr/proto/grpcadapter_pb2.py | 2 +- src/py/flwr/proto/grpcadapter_pb2.pyi | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/proto/flwr/proto/grpcadapter.proto b/src/proto/flwr/proto/grpcadapter.proto index 826efec4315e..acf9a9d3d94f 100644 --- a/src/proto/flwr/proto/grpcadapter.proto +++ b/src/proto/flwr/proto/grpcadapter.proto @@ -22,7 +22,7 @@ service GrpcAdapter { } message MessageContainer { - map metadata = 1; + map metadata = 1; string grpc_message_name = 2; bytes grpc_message_content = 3; } diff --git a/src/py/flwr/proto/grpcadapter_pb2.py b/src/py/flwr/proto/grpcadapter_pb2.py index 7c0374736850..2eff4bb78e47 100644 --- a/src/py/flwr/proto/grpcadapter_pb2.py +++ b/src/py/flwr/proto/grpcadapter_pb2.py @@ -14,7 +14,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/grpcadapter.proto\x12\nflwr.proto\"\xba\x01\n\x10MessageContainer\x12<\n\x08metadata\x18\x01 \x03(\x0b\x32*.flwr.proto.MessageContainer.MetadataEntry\x12\x19\n\x11grpc_message_name\x18\x02 \x01(\t\x12\x1c\n\x14grpc_message_content\x18\x03 \x01(\x0c\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c:\x02\x38\x01\x32Z\n\x0bGrpcAdapter\x12K\n\x0bSendReceive\x12\x1c.flwr.proto.MessageContainer\x1a\x1c.flwr.proto.MessageContainer\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/grpcadapter.proto\x12\nflwr.proto\"\xba\x01\n\x10MessageContainer\x12<\n\x08metadata\x18\x01 \x03(\x0b\x32*.flwr.proto.MessageContainer.MetadataEntry\x12\x19\n\x11grpc_message_name\x18\x02 \x01(\t\x12\x1c\n\x14grpc_message_content\x18\x03 \x01(\x0c\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x32Z\n\x0bGrpcAdapter\x12K\n\x0bSendReceive\x12\x1c.flwr.proto.MessageContainer\x1a\x1c.flwr.proto.MessageContainer\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) diff --git a/src/py/flwr/proto/grpcadapter_pb2.pyi b/src/py/flwr/proto/grpcadapter_pb2.pyi index d5f89ac27c4a..35889b30d2b6 100644 --- a/src/py/flwr/proto/grpcadapter_pb2.pyi +++ b/src/py/flwr/proto/grpcadapter_pb2.pyi @@ -18,11 +18,11 @@ class MessageContainer(google.protobuf.message.Message): KEY_FIELD_NUMBER: builtins.int VALUE_FIELD_NUMBER: builtins.int key: typing.Text - value: builtins.bytes + value: typing.Text def __init__(self, *, key: typing.Text = ..., - value: builtins.bytes = ..., + value: typing.Text = ..., ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... @@ -30,12 +30,12 @@ class MessageContainer(google.protobuf.message.Message): GRPC_MESSAGE_NAME_FIELD_NUMBER: builtins.int GRPC_MESSAGE_CONTENT_FIELD_NUMBER: builtins.int @property - def metadata(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, builtins.bytes]: ... + def metadata(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... grpc_message_name: typing.Text grpc_message_content: builtins.bytes def __init__(self, *, - metadata: typing.Optional[typing.Mapping[typing.Text, builtins.bytes]] = ..., + metadata: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., grpc_message_name: typing.Text = ..., grpc_message_content: builtins.bytes = ..., ) -> None: ... From 6af6680b4682e1aa2a2bf581c9e745c5fb3b655b Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 30 May 2024 17:17:14 +0100 Subject: [PATCH 07/25] fix check errors --- src/py/flwr/client/grpc_rere_client/grpc_adapter.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py index 71ae332e2c68..7635e9c0c734 100644 --- a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py +++ b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py @@ -20,7 +20,7 @@ from google.protobuf.message import Message as GrpcMessage import flwr -from flwr.proto.fleet_pb2 import ( +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, DeleteNodeRequest, @@ -34,7 +34,7 @@ PushTaskResRequest, PushTaskResResponse, ) -from flwr.proto.grpcadapter_pb2 import MessageContainer +from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611 from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub KEY_FLOWER_VERSION = "flower-version" @@ -63,26 +63,32 @@ def _send_and_receive(self, request: GrpcMessage, response_type: Type[T]) -> T: response.ParseFromString(container_res.grpc_message_content) return response + # pylint: disable-next=C0103 def CreateNode(self, request: CreateNodeRequest) -> CreateNodeResponse: """.""" return self._send_and_receive(request, CreateNodeResponse) + # pylint: disable-next=C0103 def DeleteNode(self, request: DeleteNodeRequest) -> DeleteNodeResponse: """.""" return self._send_and_receive(request, DeleteNodeResponse) + # pylint: disable-next=C0103 def Ping(self, request: PingRequest) -> PingResponse: """.""" return self._send_and_receive(request, PingResponse) + # pylint: disable-next=C0103 def PullTaskIns(self, request: PullTaskInsRequest) -> PullTaskInsResponse: """.""" return self._send_and_receive(request, PullTaskInsResponse) + # pylint: disable-next=C0103 def PushTaskRes(self, request: PushTaskResRequest) -> PushTaskResResponse: """.""" return self._send_and_receive(request, PushTaskResResponse) + # pylint: disable-next=C0103 def GetRun(self, request: GetRunRequest) -> GetRunResponse: """.""" return self._send_and_receive(request, GetRunResponse) From f61b550f3567c52e831640f1f6e224ecea3f0d03 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 30 May 2024 20:09:20 +0100 Subject: [PATCH 08/25] update doc string --- src/py/flwr/client/grpc_rere_client/grpc_adapter.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py index 7635e9c0c734..f773d788ca67 100644 --- a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py +++ b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py @@ -42,7 +42,11 @@ class GrpcAdapter: - """Grpc Adapter.""" + """The Adapter class to send and receive gRPC messages via the GrpcAdapterStub. + + This class utilizes the GrpcAdapterStub to send and receive gRPC messages which are + defined and used by the Fleet API, as defined in `fleet.proto`. + """ def __init__(self, channel: Any) -> None: self.stub = GrpcAdapterStub(channel) @@ -63,32 +67,27 @@ def _send_and_receive(self, request: GrpcMessage, response_type: Type[T]) -> T: response.ParseFromString(container_res.grpc_message_content) return response - # pylint: disable-next=C0103 + # pylint: disable=C0103 def CreateNode(self, request: CreateNodeRequest) -> CreateNodeResponse: """.""" return self._send_and_receive(request, CreateNodeResponse) - # pylint: disable-next=C0103 def DeleteNode(self, request: DeleteNodeRequest) -> DeleteNodeResponse: """.""" return self._send_and_receive(request, DeleteNodeResponse) - # pylint: disable-next=C0103 def Ping(self, request: PingRequest) -> PingResponse: """.""" return self._send_and_receive(request, PingResponse) - # pylint: disable-next=C0103 def PullTaskIns(self, request: PullTaskInsRequest) -> PullTaskInsResponse: """.""" return self._send_and_receive(request, PullTaskInsResponse) - # pylint: disable-next=C0103 def PushTaskRes(self, request: PushTaskResRequest) -> PushTaskResResponse: """.""" return self._send_and_receive(request, PushTaskResResponse) - # pylint: disable-next=C0103 def GetRun(self, request: GetRunRequest) -> GetRunResponse: """.""" return self._send_and_receive(request, GetRunResponse) From 98cfda422ad516e506b1fedbafe028e06b928f35 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Fri, 31 May 2024 13:06:54 +0100 Subject: [PATCH 09/25] add servicer --- .../superlink/fleet/grpc_adapter/__init__.py | 15 ++ .../grpc_adapter/grpc_adapter_servicer.py | 132 ++++++++++++++++++ 2 files changed, 147 insertions(+) create mode 100644 src/py/flwr/server/superlink/fleet/grpc_adapter/__init__.py create mode 100644 src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py diff --git a/src/py/flwr/server/superlink/fleet/grpc_adapter/__init__.py b/src/py/flwr/server/superlink/fleet/grpc_adapter/__init__.py new file mode 100644 index 000000000000..43a63baf921a --- /dev/null +++ b/src/py/flwr/server/superlink/fleet/grpc_adapter/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Server-side part of the gRPC transport layer using GrpcAdapter.""" diff --git a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py new file mode 100644 index 000000000000..c7fb4030f0d9 --- /dev/null +++ b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py @@ -0,0 +1,132 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Fleet API gRPC adapter servicer.""" + + +from logging import DEBUG, INFO +from typing import Callable, Type, TypeVar + +import grpc +from google.protobuf.message import Message as GrpcMessage + +from flwr.common.logger import log +from flwr.proto import grpcadapter_pb2_grpc # pylint: disable=E0611 +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 + CreateNodeRequest, + CreateNodeResponse, + DeleteNodeRequest, + DeleteNodeResponse, + GetRunRequest, + GetRunResponse, + PingRequest, + PingResponse, + PullTaskInsRequest, + PullTaskInsResponse, + PushTaskResRequest, + PushTaskResResponse, +) +from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611 +from flwr.server.superlink.fleet.message_handler import message_handler +from flwr.server.superlink.state import StateFactory + +T = TypeVar("T", bound=GrpcMessage) + + +def _handle( + msg_container: MessageContainer, + request_type: Type[T], + handler: Callable[[T], GrpcMessage], +) -> MessageContainer: + req = request_type.FromString(msg_container.grpc_message_content) + res = handler(req) + return MessageContainer( + metadata={}, + grpc_message_name=res.__class__.__qualname__, + grpc_message_content=res.SerializeToString(), + ) + + +class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer): + """Fleet API via GrpcAdapter servicer.""" + + def __init__(self, state_factory: StateFactory) -> None: + self.state_factory = state_factory + + def SendReceive( + self, request: MessageContainer, context: grpc.ServicerContext + ) -> MessageContainer: + """.""" + log(DEBUG, "GrpcAdapterServicer.SendReceive") + if request.grpc_message_name == CreateNodeRequest.__qualname__: + return _handle(request, CreateNodeRequest, self._create_node) + if request.grpc_message_name == DeleteNodeRequest.__qualname__: + return _handle(request, DeleteNodeRequest, self._delete_node) + if request.grpc_message_name == PingRequest.__qualname__: + return _handle(request, PingRequest, self._ping) + if request.grpc_message_name == PullTaskInsRequest.__qualname__: + return _handle(request, PullTaskInsRequest, self._pull_task_ins) + if request.grpc_message_name == PushTaskResRequest.__qualname__: + return _handle(request, PushTaskResRequest, self._push_task_res) + if request.grpc_message_name == GetRunRequest.__qualname__: + return _handle(request, GetRunRequest, self._get_run) + raise ValueError(f"Invalid grpc_message_name: {request.grpc_message_name}") + + def _create_node(self, request: CreateNodeRequest) -> CreateNodeResponse: + """.""" + log(INFO, "GrpcAdapter.CreateNode") + return message_handler.create_node( + request=request, + state=self.state_factory.state(), + ) + + def _delete_node(self, request: DeleteNodeRequest) -> DeleteNodeResponse: + """.""" + log(INFO, "GrpcAdapter.DeleteNode") + return message_handler.delete_node( + request=request, + state=self.state_factory.state(), + ) + + def _ping(self, request: PingRequest) -> PingResponse: + """.""" + log(DEBUG, "GrpcAdapter.Ping") + return message_handler.ping( + request=request, + state=self.state_factory.state(), + ) + + def _pull_task_ins(self, request: PullTaskInsRequest) -> PullTaskInsResponse: + """Pull TaskIns.""" + log(INFO, "GrpcAdapter.PullTaskIns") + return message_handler.pull_task_ins( + request=request, + state=self.state_factory.state(), + ) + + def _push_task_res(self, request: PushTaskResRequest) -> PushTaskResResponse: + """Push TaskRes.""" + log(INFO, "GrpcAdapter.PushTaskRes") + return message_handler.push_task_res( + request=request, + state=self.state_factory.state(), + ) + + def _get_run(self, request: GetRunRequest) -> GetRunResponse: + """Get run information.""" + log(INFO, "GrpcAdapter.GetRun") + return message_handler.get_run( + request=request, + state=self.state_factory.state(), + ) From 05af29c0dcf335359ffa85eea6f94c6fde29f5bf Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Fri, 31 May 2024 13:47:20 +0100 Subject: [PATCH 10/25] make directories --- .../client/grpc_adapter_client/__init__.py | 15 ++++ .../client/grpc_adapter_client/connection.py | 87 +++++++++++++++++++ .../client/grpc_rere_client/connection.py | 3 +- .../client/grpc_rere_client/grpc_adapter.py | 44 ++++++---- 4 files changed, 132 insertions(+), 17 deletions(-) create mode 100644 src/py/flwr/client/grpc_adapter_client/__init__.py create mode 100644 src/py/flwr/client/grpc_adapter_client/connection.py diff --git a/src/py/flwr/client/grpc_adapter_client/__init__.py b/src/py/flwr/client/grpc_adapter_client/__init__.py new file mode 100644 index 000000000000..5900e2dc2d06 --- /dev/null +++ b/src/py/flwr/client/grpc_adapter_client/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Client-side part of the GrpcAdapter transport layer.""" diff --git a/src/py/flwr/client/grpc_adapter_client/connection.py b/src/py/flwr/client/grpc_adapter_client/connection.py new file mode 100644 index 000000000000..c0566d1d406c --- /dev/null +++ b/src/py/flwr/client/grpc_adapter_client/connection.py @@ -0,0 +1,87 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contextmanager for a GrpcAdapter channel to the Flower server.""" + + +from contextlib import contextmanager +from typing import Callable, Iterator, Optional, Tuple, Union + +from cryptography.hazmat.primitives.asymmetric import ec + +from flwr.client.grpc_rere_client.connection import grpc_request_response +from flwr.client.grpc_rere_client.grpc_adapter import GrpcAdapter +from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from flwr.common.message import Message +from flwr.common.retry_invoker import RetryInvoker + + +@contextmanager +def grpc_adapter( # pylint: disable=R0913 + server_address: str, + insecure: bool, + retry_invoker: RetryInvoker, + max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 + root_certificates: Optional[Union[bytes, str]] = None, + authentication_keys: Optional[ # pylint: disable=unused-argument + Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + ] = None, +) -> Iterator[ + Tuple[ + Callable[[], Optional[Message]], + Callable[[Message], None], + Optional[Callable[[], None]], + Optional[Callable[[], None]], + Optional[Callable[[int], Tuple[str, str]]], + ] +]: + """Primitives for request/response-based interaction with a server via GrpcAdapter. + + Parameters + ---------- + server_address : str + The IPv6 address of the server with `http://` or `https://`. + If the Flower server runs on the same machine + on port 8080, then `server_address` would be `"http://[::]:8080"`. + insecure : bool + Starts an insecure gRPC connection when True. Enables HTTPS connection + when False, using system certificates if `root_certificates` is None. + retry_invoker: RetryInvoker + `RetryInvoker` object that will try to reconnect the client to the server + after gRPC errors. If None, the client will only try to + reconnect once after a failure. + max_message_length : int + Ignored, only present to preserve API-compatibility. + root_certificates : Optional[Union[bytes, str]] (default: None) + Path of the root certificate. If provided, a secure + connection using the certificates will be established to an SSL-enabled + Flower server. Bytes won't work for the REST API. + + Returns + ------- + receive : Callable + send : Callable + create_node : Optional[Callable] + delete_node : Optional[Callable] + """ + with grpc_request_response( + server_address=server_address, + insecure=insecure, + retry_invoker=retry_invoker, + max_message_length=max_message_length, + root_certificates=root_certificates, + authentication_keys=None, # Authentication is not supported + adapter_cls=GrpcAdapter, + ) as conn: + yield conn diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 8ef8e7ebf62a..19bf0c88749d 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -56,6 +56,7 @@ from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 from .client_interceptor import AuthenticateClientInterceptor +from .grpc_adapter import GrpcAdapter def on_channel_state_change(channel_connectivity: str) -> None: @@ -73,7 +74,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 authentication_keys: Optional[ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, - adapter_cls: Optional[Type[FleetStub]] = None, + adapter_cls: Optional[Union[Type[FleetStub], Type[GrpcAdapter]]] = None, ) -> Iterator[ Tuple[ Callable[[], Optional[Message]], diff --git a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py index f773d788ca67..8ae10b21a58d 100644 --- a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py +++ b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py @@ -19,7 +19,7 @@ from google.protobuf.message import Message as GrpcMessage -import flwr +from flwr.common.version import package_version from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, @@ -51,13 +51,17 @@ class GrpcAdapter: def __init__(self, channel: Any) -> None: self.stub = GrpcAdapterStub(channel) - def _send_and_receive(self, request: GrpcMessage, response_type: Type[T]) -> T: + def _send_and_receive( + self, request: GrpcMessage, response_type: Type[T], **kwargs: Any + ) -> T: container_req = MessageContainer( - metadata={KEY_FLOWER_VERSION: flwr.__version__}, + metadata={KEY_FLOWER_VERSION: package_version}, grpc_message_name=request.__class__.__qualname__, grpc_message_content=request.SerializeToString(), ) - container_res = cast(MessageContainer, self.stub.SendReceive(container_req)) + container_res = cast( + MessageContainer, self.stub.SendReceive(container_req, **kwargs) + ) if container_res.grpc_message_name != response_type.__qualname__: raise ValueError( f"Invalid grpc_message_name. Expected {response_type.__qualname__}" @@ -68,26 +72,34 @@ def _send_and_receive(self, request: GrpcMessage, response_type: Type[T]) -> T: return response # pylint: disable=C0103 - def CreateNode(self, request: CreateNodeRequest) -> CreateNodeResponse: + def CreateNode( + self, request: CreateNodeRequest, **kwargs: Any + ) -> CreateNodeResponse: """.""" - return self._send_and_receive(request, CreateNodeResponse) + return self._send_and_receive(request, CreateNodeResponse, **kwargs) - def DeleteNode(self, request: DeleteNodeRequest) -> DeleteNodeResponse: + def DeleteNode( + self, request: DeleteNodeRequest, **kwargs: Any + ) -> DeleteNodeResponse: """.""" - return self._send_and_receive(request, DeleteNodeResponse) + return self._send_and_receive(request, DeleteNodeResponse, **kwargs) - def Ping(self, request: PingRequest) -> PingResponse: + def Ping(self, request: PingRequest, **kwargs: Any) -> PingResponse: """.""" - return self._send_and_receive(request, PingResponse) + return self._send_and_receive(request, PingResponse, **kwargs) - def PullTaskIns(self, request: PullTaskInsRequest) -> PullTaskInsResponse: + def PullTaskIns( + self, request: PullTaskInsRequest, **kwargs: Any + ) -> PullTaskInsResponse: """.""" - return self._send_and_receive(request, PullTaskInsResponse) + return self._send_and_receive(request, PullTaskInsResponse, **kwargs) - def PushTaskRes(self, request: PushTaskResRequest) -> PushTaskResResponse: + def PushTaskRes( + self, request: PushTaskResRequest, **kwargs: Any + ) -> PushTaskResResponse: """.""" - return self._send_and_receive(request, PushTaskResResponse) + return self._send_and_receive(request, PushTaskResResponse, **kwargs) - def GetRun(self, request: GetRunRequest) -> GetRunResponse: + def GetRun(self, request: GetRunRequest, **kwargs: Any) -> GetRunResponse: """.""" - return self._send_and_receive(request, GetRunResponse) + return self._send_and_receive(request, GetRunResponse, **kwargs) From 6c3ff4c020e9f8b0a884f4a9c31ac3c43382218b Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Fri, 31 May 2024 13:50:11 +0100 Subject: [PATCH 11/25] fix date, update doc string --- src/py/flwr/server/superlink/fleet/grpc_adapter/__init__.py | 4 ++-- .../superlink/fleet/grpc_adapter/grpc_adapter_servicer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/grpc_adapter/__init__.py b/src/py/flwr/server/superlink/fleet/grpc_adapter/__init__.py index 43a63baf921a..cf875a1b9666 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_adapter/__init__.py +++ b/src/py/flwr/server/superlink/fleet/grpc_adapter/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Server-side part of the gRPC transport layer using GrpcAdapter.""" +"""Server-side part of the GrpcAdapter transport layer.""" diff --git a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py index c7fb4030f0d9..c0426985cb4f 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 4cd4a4f98b30336d3b06f6c9803e1fec03773471 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Fri, 31 May 2024 15:54:24 +0100 Subject: [PATCH 12/25] amend run_client_app and run_superlink --- src/py/flwr/client/app.py | 4 +++ src/py/flwr/client/supernode/app.py | 29 +++++++++++++++-- src/py/flwr/common/constant.py | 1 + src/py/flwr/server/app.py | 50 ++++++++++++++++++++++++++++- 4 files changed, 80 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index d7c05d8afbb2..8b7ab1367c74 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -29,6 +29,7 @@ from flwr.common.address import parse_address from flwr.common.constant import ( MISSING_EXTRA_REST, + TRANSPORT_TYPE_GRPC_ADAPTER, TRANSPORT_TYPE_GRPC_BIDI, TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST, @@ -39,6 +40,7 @@ from flwr.common.message import Error from flwr.common.retry_invoker import RetryInvoker, exponential +from .grpc_adapter_client.connection import grpc_adapter from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response from .message_handler.message_handler import handle_control_message @@ -571,6 +573,8 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ connection, error_type = http_request_response, RequestsConnectionError elif transport == TRANSPORT_TYPE_GRPC_RERE: connection, error_type = grpc_request_response, RpcError + elif transport == TRANSPORT_TYPE_GRPC_ADAPTER: + connection, error_type = grpc_adapter, RpcError elif transport == TRANSPORT_TYPE_GRPC_BIDI: connection, error_type = grpc_connection, RpcError else: diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index ac58e9aa4a81..5d02d18001f6 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -29,6 +29,11 @@ from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.common import EventType, event +from flwr.common.constant import ( + TRANSPORT_TYPE_GRPC_ADAPTER, + TRANSPORT_TYPE_GRPC_RERE, + TRANSPORT_TYPE_REST, +) from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log from flwr.common.object_ref import load_app, validate @@ -75,7 +80,7 @@ def run_client_app() -> None: _start_client_internal( server_address=args.server, load_client_app_fn=load_fn, - transport="rest" if args.rest else "grpc-rere", + transport=args.transport, root_certificates=root_certificates, insecure=args.insecure, authentication_keys=authentication_keys, @@ -199,9 +204,27 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None: help="Run the client without HTTPS. By default, the client runs with " "HTTPS enabled. Use this flag only if you understand the risks.", ) - parser.add_argument( + ex_group = parser.add_mutually_exclusive_group() + ex_group.add_argument( + "--grpc-rere", + action="store_const", + dest="transport", + const=TRANSPORT_TYPE_GRPC_RERE, + default=TRANSPORT_TYPE_GRPC_RERE, + help="Use grpc-rere as a transport layer for the client.", + ) + ex_group.add_argument( + "--grpc-adapter", + action="store_const", + dest="transport", + const=TRANSPORT_TYPE_GRPC_ADAPTER, + help="Use grpc-adapter as a transport layer for the client.", + ) + ex_group.add_argument( "--rest", - action="store_true", + action="store_const", + dest="transport", + const=TRANSPORT_TYPE_REST, help="Use REST as a transport layer for the client.", ) parser.add_argument( diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index b6d39b6e8932..c0f671f0a649 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -27,6 +27,7 @@ TRANSPORT_TYPE_GRPC_BIDI = "grpc-bidi" TRANSPORT_TYPE_GRPC_RERE = "grpc-rere" +TRANSPORT_TYPE_GRPC_ADAPTER = "grpc-adapter" TRANSPORT_TYPE_REST = "rest" TRANSPORT_TYPE_VCE = "vce" TRANSPORT_TYPES = [ diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 147ec5fb0f65..c96fdecb0fcb 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -36,6 +36,7 @@ from flwr.common.address import parse_address from flwr.common.constant import ( MISSING_EXTRA_REST, + TRANSPORT_TYPE_GRPC_ADAPTER, TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST, ) @@ -48,6 +49,7 @@ from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611 add_FleetServicer_to_server, ) +from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server from .client_manager import ClientManager from .history import History @@ -55,6 +57,7 @@ from .server_config import ServerConfig from .strategy import Strategy from .superlink.driver.driver_grpc import run_driver_api_grpc +from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer from .superlink.fleet.grpc_bidi.grpc_server import ( generic_create_grpc_server, start_grpc_server, @@ -401,6 +404,20 @@ def run_superlink() -> None: interceptors=interceptors, ) grpc_servers.append(fleet_server) + elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER: + address_arg = args.grpc_rere_fleet_api_address + parsed_address = parse_address(address_arg) + if not parsed_address: + sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.") + host, port, is_v6 = parsed_address + address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" + + fleet_server = _run_fleet_api_grpc_adapter( + address=address, + state_factory=state_factory, + certificates=certificates, + ) + grpc_servers.append(fleet_server) else: raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}") @@ -517,7 +534,7 @@ def _try_obtain_certificates( log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.") return None # Check if certificates are provided - if args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE: + if args.fleet_api_type in [TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_GRPC_ADAPTER]: if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile: if not isfile(args.ssl_ca_certfile): sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.") @@ -589,6 +606,30 @@ def _run_fleet_api_grpc_rere( return fleet_grpc_server +def _run_fleet_api_grpc_adapter( + address: str, + state_factory: StateFactory, + certificates: Optional[Tuple[bytes, bytes, bytes]], +) -> grpc.Server: + """Run Fleet API (GrpcAdapter).""" + # Create Fleet API gRPC server + fleet_servicer = GrpcAdapterServicer( + state_factory=state_factory, + ) + fleet_add_servicer_to_server_fn = add_GrpcAdapterServicer_to_server + fleet_grpc_server = generic_create_grpc_server( + servicer_and_add_fn=(fleet_servicer, fleet_add_servicer_to_server_fn), + server_address=address, + max_message_length=GRPC_MAX_MESSAGE_LENGTH, + certificates=certificates, + ) + + log(INFO, "Flower ECE: Starting Fleet API (GrpcAdapter) on %s", address) + fleet_grpc_server.start() + + return fleet_grpc_server + + # pylint: disable=import-outside-toplevel,too-many-arguments def _run_fleet_api_rest( host: str, @@ -748,6 +789,13 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None: default=TRANSPORT_TYPE_GRPC_RERE, help="Start a Fleet API server (gRPC-rere)", ) + ex_group.add_argument( + "--grpc-adapter", + action="store_const", + dest="fleet_api_type", + const=TRANSPORT_TYPE_GRPC_ADAPTER, + help="Start a Fleet API server (GrpcAdapter, experimental)", + ) ex_group.add_argument( "--rest", action="store_const", From e21aa1e0b7498959a78ed6d345c48cbd4265f82a Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Fri, 31 May 2024 16:11:25 +0100 Subject: [PATCH 13/25] fix typing --- src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py index 6aeaa7ef413f..ae685fda91a7 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py @@ -29,6 +29,9 @@ ) from flwr.server.client_manager import ClientManager from flwr.server.superlink.driver.driver_servicer import DriverServicer +from flwr.server.superlink.fleet.grpc_adapter.grpc_adapter_servicer import ( + GrpcAdapterServicer, +) from flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer import ( FlowerServiceServicer, ) @@ -154,6 +157,7 @@ def start_grpc_server( # pylint: disable=too-many-arguments def generic_create_grpc_server( # pylint: disable=too-many-arguments servicer_and_add_fn: Union[ Tuple[FleetServicer, AddServicerToServerFn], + Tuple[GrpcAdapterServicer, AddServicerToServerFn], Tuple[FlowerServiceServicer, AddServicerToServerFn], Tuple[DriverServicer, AddServicerToServerFn], ], From ac0975283e8b4273d0da0f0ba15e2dfc9b667aa3 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 3 Jun 2024 21:14:04 +0100 Subject: [PATCH 14/25] add arg for flower-superlink to specify grpc-adapter fleet api server address --- src/py/flwr/server/app.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index c96fdecb0fcb..382270b22454 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -814,6 +814,16 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None: default=ADDRESS_FLEET_API_GRPC_RERE, ) + # Fleet API gRPC-adapter options + grpc_adapter_group = parser.add_argument_group( + "Fleet API (gRPC-adapter) server options", "" + ) + grpc_adapter_group.add_argument( + "--grpc-adapter-fleet-api-address", + help="Fleet API (gRPC-adapter) server address (IPv4, IPv6, or a domain name)", + default=ADDRESS_FLEET_API_GRPC_RERE, + ) + # Fleet API REST options rest_group = parser.add_argument_group("Fleet API (REST) server options", "") rest_group.add_argument( From 49aae70da2d67bb0829673a1920219de1f56055f Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 3 Jun 2024 21:31:56 +0100 Subject: [PATCH 15/25] parse arg --- src/py/flwr/server/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 382270b22454..c8a52db1aae3 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -405,7 +405,7 @@ def run_superlink() -> None: ) grpc_servers.append(fleet_server) elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER: - address_arg = args.grpc_rere_fleet_api_address + address_arg = args.grpc_adapter_fleet_api_address parsed_address = parse_address(address_arg) if not parsed_address: sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.") From c3d35bfc7ad013eee4070a1369fdea067bb23120 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 4 Jun 2024 09:16:00 +0100 Subject: [PATCH 16/25] mv flwr-version-key to constant.py --- src/py/flwr/client/grpc_rere_client/grpc_adapter.py | 4 ++-- src/py/flwr/common/constant.py | 2 ++ .../superlink/fleet/grpc_adapter/grpc_adapter_servicer.py | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py index 8ae10b21a58d..823949e48278 100644 --- a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py +++ b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py @@ -19,6 +19,7 @@ from google.protobuf.message import Message as GrpcMessage +from flwr.common.constant import GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY from flwr.common.version import package_version from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, @@ -37,7 +38,6 @@ from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611 from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub -KEY_FLOWER_VERSION = "flower-version" T = TypeVar("T", bound=GrpcMessage) @@ -55,7 +55,7 @@ def _send_and_receive( self, request: GrpcMessage, response_type: Type[T], **kwargs: Any ) -> T: container_req = MessageContainer( - metadata={KEY_FLOWER_VERSION: package_version}, + metadata={GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY: package_version}, grpc_message_name=request.__class__.__qualname__, grpc_message_content=request.SerializeToString(), ) diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index c0f671f0a649..59ff0e99a74c 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -44,6 +44,8 @@ PING_RANDOM_RANGE = (-0.1, 0.1) PING_MAX_INTERVAL = 1e300 +GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version" + class MessageType: """Message type.""" diff --git a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py index c0426985cb4f..c29d0008c441 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py @@ -21,7 +21,9 @@ import grpc from google.protobuf.message import Message as GrpcMessage +from flwr.common.constant import GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY from flwr.common.logger import log +from flwr.common.version import package_version from flwr.proto import grpcadapter_pb2_grpc # pylint: disable=E0611 from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, @@ -52,7 +54,7 @@ def _handle( req = request_type.FromString(msg_container.grpc_message_content) res = handler(req) return MessageContainer( - metadata={}, + metadata={GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY: package_version}, grpc_message_name=res.__class__.__qualname__, grpc_message_content=res.SerializeToString(), ) From f4e097821fb60e7d88bc3ee0d4296c846d47b15b Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 10 Jun 2024 09:59:52 +0100 Subject: [PATCH 17/25] mv constants to constant.py and handle control message --- .../client/grpc_rere_client/grpc_adapter.py | 25 +++++++++++++++++-- src/py/flwr/common/constant.py | 4 +++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py index 8ae10b21a58d..d17c5fe991ef 100644 --- a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py +++ b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py @@ -15,10 +15,17 @@ """Grpc Adapter.""" +import sys +from logging import DEBUG from typing import Any, Type, TypeVar, cast from google.protobuf.message import Message as GrpcMessage +from flwr.common import log +from flwr.common.constant import ( + GRPC_ADAPTER_METADATA_EXIT_FLAG_KEY, + GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY, +) from flwr.common.version import package_version from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, @@ -37,7 +44,6 @@ from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611 from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub -KEY_FLOWER_VERSION = "flower-version" T = TypeVar("T", bound=GrpcMessage) @@ -54,19 +60,34 @@ def __init__(self, channel: Any) -> None: def _send_and_receive( self, request: GrpcMessage, response_type: Type[T], **kwargs: Any ) -> T: + # Serialize request container_req = MessageContainer( - metadata={KEY_FLOWER_VERSION: package_version}, + metadata={GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY: package_version}, grpc_message_name=request.__class__.__qualname__, grpc_message_content=request.SerializeToString(), ) + + # Send via the stub container_res = cast( MessageContainer, self.stub.SendReceive(container_req, **kwargs) ) + + # Handle control message + exit_flag = container_res.metadata.get( + GRPC_ADAPTER_METADATA_EXIT_FLAG_KEY, False + ) + if exit_flag: + log(DEBUG, "Exit flag is set to True, exiting...") + sys.exit(0) + + # Check the grpc_message_name of the response if container_res.grpc_message_name != response_type.__qualname__: raise ValueError( f"Invalid grpc_message_name. Expected {response_type.__qualname__}" f", but got {container_res.grpc_message_name}." ) + + # Deserialize response response = response_type() response.ParseFromString(container_res.grpc_message_content) return response diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index b6d39b6e8932..711d3e95dabc 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -44,6 +44,10 @@ PING_MAX_INTERVAL = 1e300 +GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version" +GRPC_ADAPTER_METADATA_EXIT_FLAG_KEY = "exit-flag" + + class MessageType: """Message type.""" From 0967d483d8abb3c4c80c85f8e1222ef1c0121e85 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 10 Jun 2024 10:19:02 +0100 Subject: [PATCH 18/25] merge main --- src/py/flwr/server/app.py | 67 ++++++++------------------------------- 1 file changed, 13 insertions(+), 54 deletions(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 4dbbe954bbf6..f10e2db271d1 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -443,15 +443,8 @@ def run_superlink() -> None: ) grpc_servers.append(fleet_server) elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER: - address_arg = args.grpc_adapter_fleet_api_address - parsed_address = parse_address(address_arg) - if not parsed_address: - sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.") - host, port, is_v6 = parsed_address - address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" - fleet_server = _run_fleet_api_grpc_adapter( - address=address, + address=fleet_address, state_factory=state_factory, certificates=certificates, ) @@ -816,54 +809,20 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--fleet-api-type", default=TRANSPORT_TYPE_GRPC_RERE, - help="Start a Fleet API server (gRPC-rere)", - ) - ex_group.add_argument( - "--grpc-adapter", - action="store_const", - dest="fleet_api_type", - const=TRANSPORT_TYPE_GRPC_ADAPTER, - help="Start a Fleet API server (GrpcAdapter, experimental)", - ) - ex_group.add_argument( - "--rest", - action="store_const", - dest="fleet_api_type", - const=TRANSPORT_TYPE_REST, - help="Start a Fleet API server (REST, experimental)", - ) - - # Fleet API gRPC-rere options - grpc_rere_group = parser.add_argument_group( - "Fleet API (gRPC-rere) server options", "" - ) - grpc_rere_group.add_argument( - "--grpc-rere-fleet-api-address", - help="Fleet API (gRPC-rere) server address (IPv4, IPv6, or a domain name)", - default=ADDRESS_FLEET_API_GRPC_RERE, - ) - - # Fleet API gRPC-adapter options - grpc_adapter_group = parser.add_argument_group( - "Fleet API (gRPC-adapter) server options", "" - ) - grpc_adapter_group.add_argument( - "--grpc-adapter-fleet-api-address", - help="Fleet API (gRPC-adapter) server address (IPv4, IPv6, or a domain name)", - default=ADDRESS_FLEET_API_GRPC_RERE, + type=str, + choices=[ + TRANSPORT_TYPE_GRPC_RERE, + TRANSPORT_TYPE_REST, + TRANSPORT_TYPE_GRPC_ADAPTER, + ], + help="Start a gRPC-rere or REST (experimental) Fleet API server.", ) - - # Fleet API REST options - rest_group = parser.add_argument_group("Fleet API (REST) server options", "") - rest_group.add_argument( - "--rest-fleet-api-address", - help="Fleet API (REST) server address (IPv4, IPv6, or a domain name)", - default=ADDRESS_FLEET_API_REST, + parser.add_argument( + "--fleet-api-address", + help="Fleet API server address (IPv4, IPv6, or a domain name).", ) - rest_group.add_argument( - "--rest-fleet-api-workers", - help="Set the number of concurrent workers for the Fleet API REST server.", - type=int, + parser.add_argument( + "--fleet-api-num-workers", default=1, type=int, help="Set the number of concurrent workers for the Fleet API server.", From 30c4a3e418b3a6704cbcd2ec3b2b0746493ea1fd Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 10 Jun 2024 12:35:29 +0100 Subject: [PATCH 19/25] update --- src/py/flwr/server/app.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index f10e2db271d1..569458b70e48 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -362,11 +362,13 @@ def run_superlink() -> None: grpc_servers = [driver_server] bckg_threads = [] if not args.fleet_api_address: - args.fleet_api_address = ( - ADDRESS_FLEET_API_GRPC_RERE - if args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE - else ADDRESS_FLEET_API_REST - ) + if args.fleet_api_type in [ + TRANSPORT_TYPE_GRPC_RERE, + TRANSPORT_TYPE_GRPC_ADAPTER, + ]: + args.fleet_api_address = ADDRESS_FLEET_API_GRPC_RERE + elif args.fleet_api_type == TRANSPORT_TYPE_REST: + args.fleet_api_address = ADDRESS_FLEET_API_REST parsed_fleet_address = parse_address(args.fleet_api_address) if not parsed_fleet_address: sys.exit(f"Fleet IP address ({args.fleet_api_address}) cannot be parsed.") From ce2e4a576bd1dbb0ef751a3e37859a5d540ab7c7 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 10 Jun 2024 13:05:51 +0100 Subject: [PATCH 20/25] update log message and key name --- src/py/flwr/client/grpc_rere_client/grpc_adapter.py | 11 ++++++----- src/py/flwr/common/constant.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py index d17c5fe991ef..ab0b09432c78 100644 --- a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py +++ b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py @@ -23,8 +23,8 @@ from flwr.common import log from flwr.common.constant import ( - GRPC_ADAPTER_METADATA_EXIT_FLAG_KEY, GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY, + GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY, ) from flwr.common.version import package_version from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 @@ -73,11 +73,12 @@ def _send_and_receive( ) # Handle control message - exit_flag = container_res.metadata.get( - GRPC_ADAPTER_METADATA_EXIT_FLAG_KEY, False + should_exit = ( + container_res.metadata.get(GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY, "false") + == "true" ) - if exit_flag: - log(DEBUG, "Exit flag is set to True, exiting...") + if should_exit: + log(DEBUG, "Received shutdown signal: exit flag is set to True. Exiting...") sys.exit(0) # Check the grpc_message_name of the response diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 711d3e95dabc..dc23d54a5b29 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -45,7 +45,7 @@ GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version" -GRPC_ADAPTER_METADATA_EXIT_FLAG_KEY = "exit-flag" +GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit" class MessageType: From e6be0cc5cc0c44a410d1868fc8a946d0e61b6873 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 18 Jun 2024 08:28:40 +0100 Subject: [PATCH 21/25] update imports --- src/py/flwr/client/grpc_rere_client/grpc_adapter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py index ab0b09432c78..3f1b9ceb6d54 100644 --- a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py +++ b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py @@ -32,8 +32,6 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, - GetRunRequest, - GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -43,6 +41,7 @@ ) from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611 from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 T = TypeVar("T", bound=GrpcMessage) From 0e1d7375f9632e14339d651fe30b32e7dc7e2b9f Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 18 Jun 2024 08:45:42 +0100 Subject: [PATCH 22/25] update imports --- .../superlink/fleet/grpc_adapter/grpc_adapter_servicer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py index c0426985cb4f..9325041061ac 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py @@ -28,8 +28,6 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, - GetRunRequest, - GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -38,6 +36,7 @@ PushTaskResResponse, ) from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611 +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.server.superlink.fleet.message_handler import message_handler from flwr.server.superlink.state import StateFactory From 8b2be0d678fb75f992770bcbcccc7e65761c2aba Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 18 Jun 2024 09:00:23 +0100 Subject: [PATCH 23/25] update supernode/app.py --- src/py/flwr/client/supernode/app.py | 27 +++------------------------ 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index 3d2a49465441..59c12339e9ce 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -29,12 +29,12 @@ from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.common import EventType, event +from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir from flwr.common.constant import ( TRANSPORT_TYPE_GRPC_ADAPTER, TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST, ) -from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log, warn_deprecated_feature from flwr.common.object_ref import load_app, validate @@ -61,7 +61,7 @@ def run_supernode() -> None: _start_client_internal( server_address=args.superlink, load_client_app_fn=load_fn, - transport="rest" if args.rest else "grpc-rere", + transport=args.transport, root_certificates=root_certificates, insecure=args.insecure, authentication_keys=authentication_keys, @@ -92,7 +92,7 @@ def run_client_app() -> None: _start_client_internal( server_address=args.superlink, load_client_app_fn=load_fn, - transport="rest" if args.rest else "grpc-rere", + transport=args.transport, root_certificates=root_certificates, insecure=args.insecure, authentication_keys=authentication_keys, @@ -121,27 +121,6 @@ def _warn_deprecated_server_arg(args: argparse.Namespace) -> None: else: args.superlink = args.server - root_certificates = _get_certificates(args) - log( - DEBUG, - "Flower will load ClientApp `%s`", - getattr(args, "client-app"), - ) - load_fn = _get_load_client_app_fn(args) - authentication_keys = _try_setup_client_authentication(args) - - _start_client_internal( - server_address=args.superlink, - load_client_app_fn=load_fn, - transport=args.transport, - root_certificates=root_certificates, - insecure=args.insecure, - authentication_keys=authentication_keys, - max_retries=args.max_retries, - max_wait_time=args.max_wait_time, - ) - register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE) - def _get_certificates(args: argparse.Namespace) -> Optional[bytes]: """Load certificates if specified in args.""" From 7ad13df3002d3fc5f40a1c55365a6dfe2c23d044 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 18 Jun 2024 09:39:36 +0100 Subject: [PATCH 24/25] update client/app.py and server/app.py --- src/py/flwr/server/app.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index d2a02286e732..822defdb5b13 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -221,11 +221,13 @@ def run_superlink() -> None: grpc_servers = [driver_server] bckg_threads = [] if not args.fleet_api_address: - args.fleet_api_address = ( - ADDRESS_FLEET_API_GRPC_RERE - if args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE - else ADDRESS_FLEET_API_REST - ) + if args.fleet_api_type in [ + TRANSPORT_TYPE_GRPC_RERE, + TRANSPORT_TYPE_GRPC_ADAPTER, + ]: + args.fleet_api_address = ADDRESS_FLEET_API_GRPC_RERE + elif args.fleet_api_type == TRANSPORT_TYPE_REST: + args.fleet_api_address = ADDRESS_FLEET_API_REST fleet_address, host, port = _format_address(args.fleet_api_address) @@ -642,8 +644,8 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None: type=str, choices=[ TRANSPORT_TYPE_GRPC_RERE, - TRANSPORT_TYPE_REST, TRANSPORT_TYPE_GRPC_ADAPTER, + TRANSPORT_TYPE_REST, ], help="Start a gRPC-rere or REST (experimental) Fleet API server.", ) From 5c61be438bcc6ce57b5826b7bf0b9f4ca63e36b1 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 19 Jun 2024 16:13:09 +0100 Subject: [PATCH 25/25] restore --- .../client/grpc_adapter_client/connection.py | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/src/py/flwr/client/grpc_adapter_client/connection.py b/src/py/flwr/client/grpc_adapter_client/connection.py index e69de29bb2d1..e4e32b3accd0 100644 --- a/src/py/flwr/client/grpc_adapter_client/connection.py +++ b/src/py/flwr/client/grpc_adapter_client/connection.py @@ -0,0 +1,94 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contextmanager for a GrpcAdapter channel to the Flower server.""" + + +from contextlib import contextmanager +from logging import ERROR +from typing import Callable, Iterator, Optional, Tuple, Union + +from cryptography.hazmat.primitives.asymmetric import ec + +from flwr.client.grpc_rere_client.connection import grpc_request_response +from flwr.client.grpc_rere_client.grpc_adapter import GrpcAdapter +from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from flwr.common.logger import log +from flwr.common.message import Message +from flwr.common.retry_invoker import RetryInvoker + + +@contextmanager +def grpc_adapter( # pylint: disable=R0913 + server_address: str, + insecure: bool, + retry_invoker: RetryInvoker, + max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 + root_certificates: Optional[Union[bytes, str]] = None, + authentication_keys: Optional[ # pylint: disable=unused-argument + Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + ] = None, +) -> Iterator[ + Tuple[ + Callable[[], Optional[Message]], + Callable[[Message], None], + Optional[Callable[[], None]], + Optional[Callable[[], None]], + Optional[Callable[[int], Tuple[str, str]]], + ] +]: + """Primitives for request/response-based interaction with a server via GrpcAdapter. + + Parameters + ---------- + server_address : str + The IPv6 address of the server with `http://` or `https://`. + If the Flower server runs on the same machine + on port 8080, then `server_address` would be `"http://[::]:8080"`. + insecure : bool + Starts an insecure gRPC connection when True. Enables HTTPS connection + when False, using system certificates if `root_certificates` is None. + retry_invoker: RetryInvoker + `RetryInvoker` object that will try to reconnect the client to the server + after gRPC errors. If None, the client will only try to + reconnect once after a failure. + max_message_length : int + Ignored, only present to preserve API-compatibility. + root_certificates : Optional[Union[bytes, str]] (default: None) + Path of the root certificate. If provided, a secure + connection using the certificates will be established to an SSL-enabled + Flower server. Bytes won't work for the REST API. + authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None) + Client authentication is not supported for this transport type. + + Returns + ------- + receive : Callable + send : Callable + create_node : Optional[Callable] + delete_node : Optional[Callable] + get_run : Optional[Callable] + """ + if authentication_keys is not None: + log(ERROR, "Client authentication is not supported for this transport type.") + with grpc_request_response( + server_address=server_address, + insecure=insecure, + retry_invoker=retry_invoker, + max_message_length=max_message_length, + root_certificates=root_certificates, + authentication_keys=None, # Authentication is not supported + adapter_cls=GrpcAdapter, + ) as conn: + yield conn