Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Implement DriverAPI GetRun #3580

Merged
merged 39 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
82a399a
add run proto
panh99 Jun 11, 2024
019d490
implement getrun in driver_servicer and grpc_driver
panh99 Jun 11, 2024
54b185a
Merge branch 'main' into impl-get-run-driver
panh99 Jun 12, 2024
e6cbede
Merge branch 'main' into impl-get-run-driver
panh99 Jun 13, 2024
bf3cd3b
Merge branch 'main' into impl-get-run-driver
panh99 Jun 14, 2024
5bc1c9c
Merge branch 'main' into impl-get-run-driver
danieljanes Jun 18, 2024
7c82082
amend driver class and in mem driver
panh99 Jun 18, 2024
f9f8a10
Merge branch 'main' into impl-get-run-driver
danieljanes Jun 18, 2024
1150474
Merge branch 'main' into impl-get-run-driver
danieljanes Jun 18, 2024
6ef8ceb
update with main
panh99 Jun 19, 2024
bf10f81
fix the test for in mem driver
panh99 Jun 19, 2024
b57eeb7
Merge branch 'main' into impl-get-run-driver
panh99 Jun 19, 2024
0a09223
Merge branch 'main' into impl-get-run-driver
panh99 Jun 19, 2024
5461e5d
make run_id mandatory
panh99 Jun 19, 2024
77f8927
Merge branch 'main' into impl-get-run-driver
panh99 Jun 19, 2024
dcc6cfd
fix a bug in _init_run_id in simulation
panh99 Jun 19, 2024
dc03386
format
panh99 Jun 19, 2024
4d73b9c
Merge branch 'main' into impl-get-run-driver
danieljanes Jun 19, 2024
5aa7ced
update doc string
panh99 Jun 19, 2024
30aa173
Merge remote-tracking branch 'refs/remotes/origin/impl-get-run-driver…
panh99 Jun 19, 2024
084e22c
update doc string
panh99 Jun 19, 2024
db16408
update GrpcDriverStub
panh99 Jun 19, 2024
6c9657e
use _run & _run_id
panh99 Jun 20, 2024
fe31189
Merge branch 'main' into impl-get-run-driver
panh99 Jun 20, 2024
5110873
update sim
panh99 Jun 20, 2024
d60d295
fix in run_simulation() (#3654)
jafermarq Jun 20, 2024
e65af5f
fix doc string
panh99 Jun 20, 2024
f914b2c
Update src/py/flwr/server/driver/grpc_driver.py
danieljanes Jun 20, 2024
c62f496
fix naming conflicts
panh99 Jun 20, 2024
729d418
Merge remote-tracking branch 'refs/remotes/origin/impl-get-run-driver…
panh99 Jun 20, 2024
b8515e3
fix a bug that driver stub not connected
panh99 Jun 20, 2024
fe5a8db
quick fix
panh99 Jun 20, 2024
aeebcfb
Merge branch 'main' into impl-get-run-driver
panh99 Jun 20, 2024
25e6354
update naming
panh99 Jun 20, 2024
62c3c81
Merge remote-tracking branch 'refs/remotes/origin/impl-get-run-driver…
panh99 Jun 20, 2024
05945f0
fix unit tests
panh99 Jun 20, 2024
1efc53d
fix get_run
panh99 Jun 20, 2024
2cdaf83
update in mem driver
panh99 Jun 20, 2024
bba27ea
Merge branch 'main' into impl-get-run-driver
panh99 Jun 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/py/flwr/server/compat/app_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _update_client_manager(
node_id=node_id,
driver=driver,
anonymous=False,
run_id=driver.run_id, # type: ignore
run_id=driver.run.run_id,
)
if client_manager.register(client_proxy):
registered_nodes[node_id] = client_proxy
Expand Down
6 changes: 6 additions & 0 deletions src/py/flwr/server/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@
from typing import Iterable, List, Optional

from flwr.common import Message, RecordSet
from flwr.common.typing import Run


class Driver(ABC):
"""Abstract base Driver class for the Driver API."""

@property
@abstractmethod
def run(self) -> Run:
"""Run information."""

@abstractmethod
def create_message( # pylint: disable=too-many-arguments
self,
Expand Down
146 changes: 84 additions & 62 deletions src/py/flwr/server/driver/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
import time
import warnings
from logging import DEBUG, ERROR, WARNING
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, cast

import grpc

from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.common.serde import message_from_taskres, message_to_taskins
from flwr.common.typing import Run
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
CreateRunRequest,
CreateRunResponse,
Expand All @@ -37,6 +38,7 @@
)
from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611

from .driver import Driver
Expand All @@ -46,13 +48,24 @@
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
[Driver] Error: Not connected.

Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other
`GrpcDriverHelper` methods.
Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
`GrpcDriverStub` methods.
"""


class GrpcDriverHelper:
"""`GrpcDriverHelper` provides access to the gRPC Driver API/service."""
class GrpcDriverStub:
"""`GrpcDriverStub` provides access to the gRPC Driver API/service.

Parameters
----------
driver_service_address : Optional[str]
The IPv4 or IPv6 address of the Driver API server.
Defaults to `"[::]:9091"`.
root_certificates : Optional[bytes] (default: None)
The PEM-encoded root certificates as a byte string.
If provided, a secure connection using the certificates will be
established to an SSL-enabled Flower server.
"""

def __init__(
self,
Expand All @@ -64,6 +77,10 @@ def __init__(
self.channel: Optional[grpc.Channel] = None
self.stub: Optional[DriverStub] = None

def is_connected(self) -> bool:
"""Return True if connected to the Driver API server, otherwise False."""
return self.channel is not None

def connect(self) -> None:
"""Connect to the Driver API."""
event(EventType.DRIVER_CONNECT)
Expand Down Expand Up @@ -95,18 +112,29 @@ def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise ConnectionError("`GrpcDriverHelper` instance not connected")
raise ConnectionError("`GrpcDriverStub` instance not connected")

# Call Driver API
res: CreateRunResponse = self.stub.CreateRun(request=req)
return res

def get_run(self, req: GetRunRequest) -> GetRunResponse:
"""Get run information."""
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise ConnectionError("`GrpcDriverStub` instance not connected")

# Call gRPC Driver API
res: GetRunResponse = self.stub.GetRun(request=req)
return res

def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
"""Get client IDs."""
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise ConnectionError("`GrpcDriverHelper` instance not connected")
raise ConnectionError("`GrpcDriverStub` instance not connected")

# Call gRPC Driver API
res: GetNodesResponse = self.stub.GetNodes(request=req)
Expand All @@ -117,7 +145,7 @@ def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise ConnectionError("`GrpcDriverHelper` instance not connected")
raise ConnectionError("`GrpcDriverStub` instance not connected")

# Call gRPC Driver API
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
Expand All @@ -128,7 +156,7 @@ def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise ConnectionError("`GrpcDriverHelper` instance not connected")
raise ConnectionError("`GrpcDriverStub` instance not connected")

# Call Driver API
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
Expand All @@ -140,56 +168,52 @@ class GrpcDriver(Driver):

Parameters
----------
driver_service_address : Optional[str]
The IPv4 or IPv6 address of the Driver API server.
Defaults to `"[::]:9091"`.
certificates : bytes (default: None)
Tuple containing root certificate, server certificate, and private key
to start a secure SSL-enabled server. The tuple is expected to have
three bytes elements in the following order:

* CA certificate.
* server certificate.
* server private key.
fab_id : str (default: None)
The identifier of the FAB used in the run.
fab_version : str (default: None)
The version of the FAB used in the run.
run_id : int
The identifier of the run.
stub : Optional[GrpcDriverStub] (default: None)
The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
If None, a new instance targeting the default address "[::]:9091" will be
created; it is connected lazily on first use.
"""

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
root_certificates: Optional[bytes] = None,
fab_id: Optional[str] = None,
fab_version: Optional[str] = None,
run_id: int,
stub: Optional[GrpcDriverStub] = None,
panh99 marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
self.addr = driver_service_address
self.root_certificates = root_certificates
self.driver_helper: Optional[GrpcDriverHelper] = None
self.run_id: Optional[int] = None
self.fab_id = fab_id if fab_id is not None else ""
self.fab_version = fab_version if fab_version is not None else ""
self._run_id = run_id
self._run: Optional[Run] = None
self.stub = stub if stub is not None else GrpcDriverStub()
self.node = Node(node_id=0, anonymous=True)

def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
# Check if the GrpcDriverHelper is initialized
if self.driver_helper is None or self.run_id is None:
# Connect and create run
self.driver_helper = GrpcDriverHelper(
driver_service_address=self.addr,
root_certificates=self.root_certificates,
@property
def run(self) -> Run:
"""Run information."""
self._get_stub_and_run_id()
return Run(**vars(cast(Run, self._run)))

def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]:
# Check if the run info is initialized
if self._run is None:
# Connect
if not self.stub.is_connected():
self.stub.connect()
# Get the run info
req = GetRunRequest(run_id=self._run_id)
res = self.stub.get_run(req)
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
if not res.HasField("run"):
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
self._run = Run(
run_id=res.run.run_id,
fab_id=res.run.fab_id,
fab_version=res.run.fab_version,
)
self.driver_helper.connect()
req = CreateRunRequest(fab_id=self.fab_id, fab_version=self.fab_version)
res = self.driver_helper.create_run(req)
self.run_id = res.run_id
return self.driver_helper, self.run_id

return self.stub, self._run.run_id

def _check_message(self, message: Message) -> None:
# Check if the message is valid
if not (
message.metadata.run_id == self.run_id
message.metadata.run_id == cast(Run, self._run).run_id
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
and message.metadata.src_node_id == self.node.node_id
and message.metadata.message_id == ""
and message.metadata.reply_to_message == ""
Expand All @@ -210,7 +234,7 @@ def create_message( # pylint: disable=too-many-arguments
This method constructs a new `Message` with given content and metadata.
The `run_id` and `src_node_id` will be set automatically.
"""
_, run_id = self._get_grpc_driver_helper_and_run_id()
_, run_id = self._get_stub_and_run_id()
if ttl:
warnings.warn(
"A custom TTL was set, but note that the SuperLink does not enforce "
Expand All @@ -234,9 +258,9 @@ def create_message( # pylint: disable=too-many-arguments

def get_node_ids(self) -> List[int]:
"""Get node IDs."""
grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id()
# Call GrpcDriverHelper method
res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id))
stub, run_id = self._get_stub_and_run_id()
# Call GrpcDriverStub method
res = stub.get_nodes(GetNodesRequest(run_id=run_id))
return [node.node_id for node in res.nodes]

def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
Expand All @@ -245,7 +269,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
This method takes an iterable of messages and sends each message
to the node specified in `dst_node_id`.
"""
grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id()
stub, _ = self._get_stub_and_run_id()
# Construct TaskIns
task_ins_list: List[TaskIns] = []
for msg in messages:
Expand All @@ -255,10 +279,8 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
taskins = message_to_taskins(msg)
# Add to list
task_ins_list.append(taskins)
# Call GrpcDriverHelper method
res = grpc_driver_helper.push_task_ins(
PushTaskInsRequest(task_ins_list=task_ins_list)
)
# Call GrpcDriverStub method
res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
return list(res.task_ids)

def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
Expand All @@ -267,9 +289,9 @@ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
This method is used to collect messages from the SuperLink that correspond to a
set of given message IDs.
"""
grpc_driver, _ = self._get_grpc_driver_helper_and_run_id()
stub, _ = self._get_stub_and_run_id()
# Pull TaskRes
res = grpc_driver.pull_task_res(
res = stub.pull_task_res(
PullTaskResRequest(node=self.node, task_ids=message_ids)
)
# Convert TaskRes to Message
Expand Down Expand Up @@ -308,8 +330,8 @@ def send_and_receive(

def close(self) -> None:
"""Disconnect from the SuperLink if connected."""
# Check if GrpcDriverHelper is initialized
if self.driver_helper is None:
# Check if `connect` was called before
if not self.stub.is_connected():
return
# Disconnect
self.driver_helper.disconnect()
self.stub.disconnect()
Loading