diff --git a/src/py/flwr/server/compat/app.py b/src/py/flwr/server/compat/app.py index ff1d99b5366e..81da3f57e86a 100644 --- a/src/py/flwr/server/compat/app.py +++ b/src/py/flwr/server/compat/app.py @@ -29,7 +29,7 @@ from flwr.server.server_config import ServerConfig from flwr.server.strategy import Strategy -from ..driver import Driver +from ..driver import Driver, GrpcDriver from .app_utils import start_update_client_manager_thread DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091" @@ -114,7 +114,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals # Create the Driver if isinstance(root_certificates, str): root_certificates = Path(root_certificates).read_bytes() - driver = Driver( + driver = GrpcDriver( driver_service_address=address, root_certificates=root_certificates ) diff --git a/src/py/flwr/server/compat/driver_client_proxy.py b/src/py/flwr/server/compat/driver_client_proxy.py index d3f17c244eee..c8dae8a432e6 100644 --- a/src/py/flwr/server/compat/driver_client_proxy.py +++ b/src/py/flwr/server/compat/driver_client_proxy.py @@ -25,7 +25,7 @@ from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 from flwr.server.client_proxy import ClientProxy -from ..driver.driver import GrpcDriverHelper +from ..driver.grpc_driver import GrpcDriverHelper SLEEP_TIME = 1 diff --git a/src/py/flwr/server/driver/__init__.py b/src/py/flwr/server/driver/__init__.py index 53d381e179de..b24a4fd92cd4 100644 --- a/src/py/flwr/server/driver/__init__.py +++ b/src/py/flwr/server/driver/__init__.py @@ -16,7 +16,10 @@ from .driver import Driver +from .grpc_driver import GrpcDriver, GrpcDriverHelper __all__ = [ "Driver", + "GrpcDriver", + "GrpcDriverHelper", ] diff --git a/src/py/flwr/server/driver/abc_driver.py b/src/py/flwr/server/driver/abc_driver.py deleted file mode 100644 index b95cec95ab47..000000000000 --- a/src/py/flwr/server/driver/abc_driver.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Driver (abstract base class).""" - - -from abc import ABC, abstractmethod -from typing import Iterable, List, Optional - -from flwr.common import Message, RecordSet - - -class Driver(ABC): - """Abstract base Driver class for the Driver API.""" - - @abstractmethod - def create_message( # pylint: disable=too-many-arguments - self, - content: RecordSet, - message_type: str, - dst_node_id: int, - group_id: str, - ttl: Optional[float] = None, - ) -> Message: - """Create a new message with specified parameters. - - This method constructs a new `Message` with given content and metadata. - The `run_id` and `src_node_id` will be set automatically. - - Parameters - ---------- - content : RecordSet - The content for the new message. This holds records that are to be sent - to the destination node. - message_type : str - The type of the message, defining the action to be executed on - the receiving end. - dst_node_id : int - The ID of the destination node to which the message is being sent. - group_id : str - The ID of the group to which this message is associated. In some settings, - this is used as the FL round. - ttl : Optional[float] (default: None) - Time-to-live for the round trip of this message, i.e., the time from sending - this message to receiving a reply. 
It specifies in seconds the duration for - which the message and its potential reply are considered valid. If unset, - the default TTL (i.e., `common.DEFAULT_TTL`) will be used. - - Returns - ------- - message : Message - A new `Message` instance with the specified content and metadata. - """ - - @abstractmethod - def get_node_ids(self) -> List[int]: - """Get node IDs.""" - - @abstractmethod - def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: - """Push messages to specified node IDs. - - This method takes an iterable of messages and sends each message - to the node specified in `dst_node_id`. - - Parameters - ---------- - messages : Iterable[Message] - An iterable of messages to be sent. - - Returns - ------- - message_ids : Iterable[str] - An iterable of IDs for the messages that were sent, which can be used - to pull replies. - """ - - @abstractmethod - def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: - """Pull messages based on message IDs. - - This method is used to collect messages from the SuperLink - that correspond to a set of given message IDs. - - Parameters - ---------- - message_ids : Iterable[str] - An iterable of message IDs for which reply messages are to be retrieved. - - Returns - ------- - messages : Iterable[Message] - An iterable of messages received. - """ - - @abstractmethod - def send_and_receive( - self, - messages: Iterable[Message], - *, - timeout: Optional[float] = None, - ) -> Iterable[Message]: - """Push messages to specified node IDs and pull the reply messages. - - This method sends a list of messages to their destination node IDs and then - waits for the replies. It continues to pull replies until either all - replies are received or the specified timeout duration is exceeded. - - Parameters - ---------- - messages : Iterable[Message] - An iterable of messages to be sent. - timeout : Optional[float] (default: None) - The timeout duration in seconds. 
If specified, the method will wait for - replies for this duration. If `None`, there is no time limit and the method - will wait until replies for all messages are received. - - Returns - ------- - replies : Iterable[Message] - An iterable of reply messages received from the SuperLink. - - Notes - ----- - This method uses `push_messages` to send the messages and `pull_messages` - to collect the replies. If `timeout` is set, the method may not return - replies for all sent messages. A message remains valid until its TTL, - which is not affected by `timeout`. - """ diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index 2b700b4dd443..b95cec95ab47 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -1,4 +1,4 @@ -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,180 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Flower driver service client.""" +"""Driver (abstract base class).""" -import time -import warnings -from logging import DEBUG, ERROR, WARNING -from typing import Iterable, List, Optional, Tuple -import grpc +from abc import ABC, abstractmethod +from typing import Iterable, List, Optional -from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event -from flwr.common.grpc import create_channel -from flwr.common.logger import log -from flwr.common.serde import message_from_taskres, message_to_taskins -from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 - CreateRunRequest, - CreateRunResponse, - GetNodesRequest, - GetNodesResponse, - PullTaskResRequest, - PullTaskResResponse, - PushTaskInsRequest, - PushTaskInsResponse, -) -from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611 -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 +from flwr.common import Message, RecordSet -DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091" -ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """ -[Driver] Error: Not connected. - -Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other -`GrpcDriverHelper` methods. 
-""" - - -class GrpcDriverHelper: - """`GrpcDriverHelper` provides access to the gRPC Driver API/service.""" - - def __init__( - self, - driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, - root_certificates: Optional[bytes] = None, - ) -> None: - self.driver_service_address = driver_service_address - self.root_certificates = root_certificates - self.channel: Optional[grpc.Channel] = None - self.stub: Optional[DriverStub] = None - - def connect(self) -> None: - """Connect to the Driver API.""" - event(EventType.DRIVER_CONNECT) - if self.channel is not None or self.stub is not None: - log(WARNING, "Already connected") - return - self.channel = create_channel( - server_address=self.driver_service_address, - insecure=(self.root_certificates is None), - root_certificates=self.root_certificates, - ) - self.stub = DriverStub(self.channel) - log(DEBUG, "[Driver] Connected to %s", self.driver_service_address) - - def disconnect(self) -> None: - """Disconnect from the Driver API.""" - event(EventType.DRIVER_DISCONNECT) - if self.channel is None or self.stub is None: - log(DEBUG, "Already disconnected") - return - channel = self.channel - self.channel = None - self.stub = None - channel.close() - log(DEBUG, "[Driver] Disconnected") - - def create_run(self, req: CreateRunRequest) -> CreateRunResponse: - """Request for run ID.""" - # Check if channel is open - if self.stub is None: - log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") - - # Call Driver API - res: CreateRunResponse = self.stub.CreateRun(request=req) - return res - - def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: - """Get client IDs.""" - # Check if channel is open - if self.stub is None: - log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") - - # Call gRPC Driver API - res: GetNodesResponse = self.stub.GetNodes(request=req) - return res - - def 
push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: - """Schedule tasks.""" - # Check if channel is open - if self.stub is None: - log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") - - # Call gRPC Driver API - res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) - return res - - def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: - """Get task results.""" - # Check if channel is open - if self.stub is None: - log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") - - # Call Driver API - res: PullTaskResResponse = self.stub.PullTaskRes(request=req) - return res - - -class Driver: - """`Driver` class provides an interface to the Driver API. - - Parameters - ---------- - driver_service_address : Optional[str] - The IPv4 or IPv6 address of the Driver API server. - Defaults to `"[::]:9091"`. - certificates : bytes (default: None) - Tuple containing root certificate, server certificate, and private key - to start a secure SSL-enabled server. The tuple is expected to have - three bytes elements in the following order: - - * CA certificate. - * server certificate. - * server private key. 
- """ - - def __init__( - self, - driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, - root_certificates: Optional[bytes] = None, - ) -> None: - self.addr = driver_service_address - self.root_certificates = root_certificates - self.grpc_driver_helper: Optional[GrpcDriverHelper] = None - self.run_id: Optional[int] = None - self.node = Node(node_id=0, anonymous=True) - - def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]: - # Check if the GrpcDriverHelper is initialized - if self.grpc_driver_helper is None or self.run_id is None: - # Connect and create run - self.grpc_driver_helper = GrpcDriverHelper( - driver_service_address=self.addr, - root_certificates=self.root_certificates, - ) - self.grpc_driver_helper.connect() - res = self.grpc_driver_helper.create_run(CreateRunRequest()) - self.run_id = res.run_id - return self.grpc_driver_helper, self.run_id - - def _check_message(self, message: Message) -> None: - # Check if the message is valid - if not ( - message.metadata.run_id == self.run_id - and message.metadata.src_node_id == self.node.node_id - and message.metadata.message_id == "" - and message.metadata.reply_to_message == "" - and message.metadata.ttl > 0 - ): - raise ValueError(f"Invalid message: {message}") +class Driver(ABC): + """Abstract base Driver class for the Driver API.""" + @abstractmethod def create_message( # pylint: disable=too-many-arguments self, content: RecordSet, @@ -223,35 +62,12 @@ def create_message( # pylint: disable=too-many-arguments message : Message A new `Message` instance with the specified content and metadata. """ - _, run_id = self._get_grpc_driver_helper_and_run_id() - if ttl: - warnings.warn( - "A custom TTL was set, but note that the SuperLink does not enforce " - "the TTL yet. 
The SuperLink will start enforcing the TTL in a future " - "version of Flower.", - stacklevel=2, - ) - - ttl_ = DEFAULT_TTL if ttl is None else ttl - metadata = Metadata( - run_id=run_id, - message_id="", # Will be set by the server - src_node_id=self.node.node_id, - dst_node_id=dst_node_id, - reply_to_message="", - group_id=group_id, - ttl=ttl_, - message_type=message_type, - ) - return Message(metadata=metadata, content=content) + @abstractmethod def get_node_ids(self) -> List[int]: """Get node IDs.""" - grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id() - # Call GrpcDriverHelper method - res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id)) - return [node.node_id for node in res.nodes] + @abstractmethod def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: """Push messages to specified node IDs. @@ -269,22 +85,8 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: An iterable of IDs for the messages that were sent, which can be used to pull replies. """ - grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id() - # Construct TaskIns - task_ins_list: List[TaskIns] = [] - for msg in messages: - # Check message - self._check_message(msg) - # Convert Message to TaskIns - taskins = message_to_taskins(msg) - # Add to list - task_ins_list.append(taskins) - # Call GrpcDriverHelper method - res = grpc_driver_helper.push_task_ins( - PushTaskInsRequest(task_ins_list=task_ins_list) - ) - return list(res.task_ids) + @abstractmethod def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: """Pull messages based on message IDs. @@ -301,15 +103,8 @@ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: messages : Iterable[Message] An iterable of messages received. 
""" - grpc_driver, _ = self._get_grpc_driver_helper_and_run_id() - # Pull TaskRes - res = grpc_driver.pull_task_res( - PullTaskResRequest(node=self.node, task_ids=message_ids) - ) - # Convert TaskRes to Message - msgs = [message_from_taskres(taskres) for taskres in res.task_res_list] - return msgs + @abstractmethod def send_and_receive( self, messages: Iterable[Message], @@ -343,28 +138,3 @@ def send_and_receive( replies for all sent messages. A message remains valid until its TTL, which is not affected by `timeout`. """ - # Push messages - msg_ids = set(self.push_messages(messages)) - - # Pull messages - end_time = time.time() + (timeout if timeout is not None else 0.0) - ret: List[Message] = [] - while timeout is None or time.time() < end_time: - res_msgs = self.pull_messages(msg_ids) - ret.extend(res_msgs) - msg_ids.difference_update( - {msg.metadata.reply_to_message for msg in res_msgs} - ) - if len(msg_ids) == 0: - break - # Sleep - time.sleep(3) - return ret - - def close(self) -> None: - """Disconnect from the SuperLink if connected.""" - # Check if GrpcDriverHelper is initialized - if self.grpc_driver_helper is None: - return - # Disconnect - self.grpc_driver_helper.disconnect() diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py new file mode 100644 index 000000000000..84e4788ba6e1 --- /dev/null +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -0,0 +1,306 @@ +# Copyright 2022 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower gRPC Driver.""" + +import time +import warnings +from logging import DEBUG, ERROR, WARNING +from typing import Iterable, List, Optional, Tuple + +import grpc + +from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event +from flwr.common.grpc import create_channel +from flwr.common.logger import log +from flwr.common.serde import message_from_taskres, message_to_taskins +from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 + CreateRunRequest, + CreateRunResponse, + GetNodesRequest, + GetNodesResponse, + PullTaskResRequest, + PullTaskResResponse, + PushTaskInsRequest, + PushTaskInsResponse, +) +from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611 +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 + +from .driver import Driver + +DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091" + +ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """ +[Driver] Error: Not connected. + +Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other +`GrpcDriverHelper` methods. 
+""" + + +class GrpcDriverHelper: + """`GrpcDriverHelper` provides access to the gRPC Driver API/service.""" + + def __init__( + self, + driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, + root_certificates: Optional[bytes] = None, + ) -> None: + self.driver_service_address = driver_service_address + self.root_certificates = root_certificates + self.channel: Optional[grpc.Channel] = None + self.stub: Optional[DriverStub] = None + + def connect(self) -> None: + """Connect to the Driver API.""" + event(EventType.DRIVER_CONNECT) + if self.channel is not None or self.stub is not None: + log(WARNING, "Already connected") + return + self.channel = create_channel( + server_address=self.driver_service_address, + insecure=(self.root_certificates is None), + root_certificates=self.root_certificates, + ) + self.stub = DriverStub(self.channel) + log(DEBUG, "[Driver] Connected to %s", self.driver_service_address) + + def disconnect(self) -> None: + """Disconnect from the Driver API.""" + event(EventType.DRIVER_DISCONNECT) + if self.channel is None or self.stub is None: + log(DEBUG, "Already disconnected") + return + channel = self.channel + self.channel = None + self.stub = None + channel.close() + log(DEBUG, "[Driver] Disconnected") + + def create_run(self, req: CreateRunRequest) -> CreateRunResponse: + """Request for run ID.""" + # Check if channel is open + if self.stub is None: + log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) + raise ConnectionError("`GrpcDriverHelper` instance not connected") + + # Call Driver API + res: CreateRunResponse = self.stub.CreateRun(request=req) + return res + + def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: + """Get client IDs.""" + # Check if channel is open + if self.stub is None: + log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) + raise ConnectionError("`GrpcDriverHelper` instance not connected") + + # Call gRPC Driver API + res: GetNodesResponse = self.stub.GetNodes(request=req) + return res + + def 
push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: + """Schedule tasks.""" + # Check if channel is open + if self.stub is None: + log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) + raise ConnectionError("`GrpcDriverHelper` instance not connected") + + # Call gRPC Driver API + res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) + return res + + def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: + """Get task results.""" + # Check if channel is open + if self.stub is None: + log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) + raise ConnectionError("`GrpcDriverHelper` instance not connected") + + # Call Driver API + res: PullTaskResResponse = self.stub.PullTaskRes(request=req) + return res + + +class GrpcDriver(Driver): + """`GrpcDriver` provides an interface to the Driver API via gRPC. + + Parameters + ---------- + driver_service_address : str + The IPv4 or IPv6 address of the Driver API server. + Defaults to `"[::]:9091"`. + root_certificates : Optional[bytes] (default: None) + The root certificates as a byte string. If provided, they are passed + to the gRPC channel, which is then created with `insecure=False`. + If `None`, an insecure channel is created instead. + + Unlike a server-side certificate tuple, this is a single bytes value: + only root (CA) certificates are given here, not a server certificate + or a server private key.
+ """ + + def __init__( + self, + driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, + root_certificates: Optional[bytes] = None, + ) -> None: + self.addr = driver_service_address + self.root_certificates = root_certificates + self.grpc_driver_helper: Optional[GrpcDriverHelper] = None + self.run_id: Optional[int] = None + self.node = Node(node_id=0, anonymous=True) + + def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]: + # Check if the GrpcDriverHelper is initialized + if self.grpc_driver_helper is None or self.run_id is None: + # Connect and create run + self.grpc_driver_helper = GrpcDriverHelper( + driver_service_address=self.addr, + root_certificates=self.root_certificates, + ) + self.grpc_driver_helper.connect() + res = self.grpc_driver_helper.create_run(CreateRunRequest()) + self.run_id = res.run_id + return self.grpc_driver_helper, self.run_id + + def _check_message(self, message: Message) -> None: + # Check if the message is valid + if not ( + message.metadata.run_id == self.run_id + and message.metadata.src_node_id == self.node.node_id + and message.metadata.message_id == "" + and message.metadata.reply_to_message == "" + and message.metadata.ttl > 0 + ): + raise ValueError(f"Invalid message: {message}") + + def create_message( # pylint: disable=too-many-arguments + self, + content: RecordSet, + message_type: str, + dst_node_id: int, + group_id: str, + ttl: Optional[float] = None, + ) -> Message: + """Create a new message with specified parameters. + + This method constructs a new `Message` with given content and metadata. + The `run_id` and `src_node_id` will be set automatically. + """ + _, run_id = self._get_grpc_driver_helper_and_run_id() + if ttl: + warnings.warn( + "A custom TTL was set, but note that the SuperLink does not enforce " + "the TTL yet. 
The SuperLink will start enforcing the TTL in a future " + "version of Flower.", + stacklevel=2, + ) + + ttl_ = DEFAULT_TTL if ttl is None else ttl + metadata = Metadata( + run_id=run_id, + message_id="", # Will be set by the server + src_node_id=self.node.node_id, + dst_node_id=dst_node_id, + reply_to_message="", + group_id=group_id, + ttl=ttl_, + message_type=message_type, + ) + return Message(metadata=metadata, content=content) + + def get_node_ids(self) -> List[int]: + """Get node IDs.""" + grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id() + # Call GrpcDriverHelper method + res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id)) + return [node.node_id for node in res.nodes] + + def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: + """Push messages to specified node IDs. + + This method takes an iterable of messages and sends each message + to the node specified in `dst_node_id`. + """ + grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id() + # Construct TaskIns + task_ins_list: List[TaskIns] = [] + for msg in messages: + # Check message + self._check_message(msg) + # Convert Message to TaskIns + taskins = message_to_taskins(msg) + # Add to list + task_ins_list.append(taskins) + # Call GrpcDriverHelper method + res = grpc_driver_helper.push_task_ins( + PushTaskInsRequest(task_ins_list=task_ins_list) + ) + return list(res.task_ids) + + def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: + """Pull messages based on message IDs. + + This method is used to collect messages from the SuperLink that correspond to a + set of given message IDs. 
+ """ + grpc_driver, _ = self._get_grpc_driver_helper_and_run_id() + # Pull TaskRes + res = grpc_driver.pull_task_res( + PullTaskResRequest(node=self.node, task_ids=message_ids) + ) + # Convert TaskRes to Message + msgs = [message_from_taskres(taskres) for taskres in res.task_res_list] + return msgs + + def send_and_receive( + self, + messages: Iterable[Message], + *, + timeout: Optional[float] = None, + ) -> Iterable[Message]: + """Push messages to specified node IDs and pull the reply messages. + + This method sends a list of messages to their destination node IDs and then + waits for the replies. It continues to pull replies until either all replies are + received or the specified timeout duration is exceeded. + """ + # Push messages + msg_ids = set(self.push_messages(messages)) + + # Pull messages + end_time = time.time() + (timeout if timeout is not None else 0.0) + ret: List[Message] = [] + while timeout is None or time.time() < end_time: + res_msgs = self.pull_messages(msg_ids) + ret.extend(res_msgs) + msg_ids.difference_update( + {msg.metadata.reply_to_message for msg in res_msgs} + ) + if len(msg_ids) == 0: + break + # Sleep + time.sleep(3) + return ret + + def close(self) -> None: + """Disconnect from the SuperLink if connected.""" + # Check if GrpcDriverHelper is initialized + if self.grpc_driver_helper is None: + return + # Disconnect + self.grpc_driver_helper.disconnect() diff --git a/src/py/flwr/server/driver/driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py similarity index 97% rename from src/py/flwr/server/driver/driver_test.py rename to src/py/flwr/server/driver/grpc_driver_test.py index a5d887cb5736..3bead1ea2473 100644 --- a/src/py/flwr/server/driver/driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -29,11 +29,11 @@ ) from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 -from .driver import Driver +from .grpc_driver import GrpcDriver -class TestDriver(unittest.TestCase): - """Tests for `Driver` 
class.""" +class TestGrpcDriver(unittest.TestCase): + """Tests for `GrpcDriver` class.""" def setUp(self) -> None: """Initialize mock GrpcDriverHelper and Driver instance before each test.""" @@ -42,11 +42,11 @@ def setUp(self) -> None: self.mock_grpc_driver = Mock() self.mock_grpc_driver.create_run.return_value = mock_response self.patcher = patch( - "flwr.server.driver.driver.GrpcDriverHelper", + "flwr.server.driver.grpc_driver.GrpcDriverHelper", return_value=self.mock_grpc_driver, ) self.patcher.start() - self.driver = Driver() + self.driver = GrpcDriver() def tearDown(self) -> None: """Cleanup after each test.""" diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index 2f0f1185847e..c5a126f473a5 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -25,7 +25,7 @@ from flwr.common.logger import log, update_console_handler from flwr.common.object_ref import load_app -from .driver.driver import Driver +from .driver import Driver, GrpcDriver from .server_app import LoadServerAppError, ServerApp @@ -128,13 +128,13 @@ def run_server_app() -> None: server_app_dir = args.dir server_app_attr = getattr(args, "server-app") - # Initialize Driver - driver = Driver( + # Initialize GrpcDriver + driver = GrpcDriver( driver_service_address=args.server, root_certificates=root_certificates, ) - # Run the Server App with the Driver + # Run the ServerApp with the Driver run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr) # Clean up diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 56fce363726a..d5f1a655adc3 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -29,7 +29,7 @@ from flwr.client import ClientApp from flwr.common import EventType, event, log from flwr.common.typing import ConfigsRecordValues -from flwr.server.driver.driver import Driver +from flwr.server.driver 
import Driver, GrpcDriver from flwr.server.run_serverapp import run from flwr.server.server_app import ServerApp from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc @@ -204,7 +204,7 @@ def _main_loop( serverapp_th = None try: # Initialize Driver - driver = Driver( + driver = GrpcDriver( driver_service_address=driver_api_address, root_certificates=None, )