Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(framework) Make DriverClientProxy use driver's send_and_receive #4289

Merged
merged 2 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 15 additions & 29 deletions src/py/flwr/server/compat/driver_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Flower ClientProxy implementation for Driver API."""


import time
from typing import Optional

from flwr import common
Expand All @@ -25,8 +24,6 @@

from ..driver.driver import Driver

SLEEP_TIME = 1


class DriverClientProxy(ClientProxy):
"""Flower client proxy which delegates work using the Driver API."""
Expand Down Expand Up @@ -122,29 +119,18 @@ def _send_receive_recordset(
ttl=timeout,
)

# Push message
message_ids = list(self.driver.push_messages(messages=[message]))
if len(message_ids) != 1:
raise ValueError("Unexpected number of message_ids")

message_id = message_ids[0]
if message_id == "":
raise ValueError(f"Failed to send message to node {self.node_id}")

if timeout:
start_time = time.time()

while True:
messages = list(self.driver.pull_messages(message_ids))
if len(messages) == 1:
msg: Message = messages[0]
if msg.has_error():
raise ValueError(
f"Message contains an Error (reason: {msg.error.reason}). "
"It originated during client-side execution of a message."
)
return msg.content

if timeout is not None and time.time() > start_time + timeout:
raise RuntimeError("Timeout reached")
time.sleep(SLEEP_TIME)
# Send message and wait for reply
messages = list(self.driver.send_and_receive(messages=[message]))

# A single reply is expected
if len(messages) != 1:
raise ValueError(f"Expected one Message but got: {len(messages)}")

# Only messages without errors can be handled beyond these point
msg: Message = messages[0]
if msg.has_error():
raise ValueError(
f"Message contains an Error (reason: {msg.error.reason}). "
"It originated during client-side execution of a message."
)
return msg.content
53 changes: 20 additions & 33 deletions src/py/flwr/server/compat/driver_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@

RUN_ID = 61016
NODE_ID = 1
INSTRUCTION_MESSAGE_ID = "mock instruction message id"
REPLY_MESSAGE_ID = "mock reply message id"


class DriverClientProxyTestCase(unittest.TestCase):
Expand All @@ -77,7 +75,7 @@ def test_get_properties(self) -> None:
"""Test positive case."""
# Prepare
res = GetPropertiesRes(status=CLIENT_STATUS, properties=CLIENT_PROPERTIES)
self.driver.push_messages.side_effect = self._get_push_messages(res)
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(res)
request_properties: Config = {"tensor_type": "str"}
ins = GetPropertiesIns(config=request_properties)

Expand All @@ -95,7 +93,7 @@ def test_get_parameters(self) -> None:
status=CLIENT_STATUS,
parameters=MESSAGE_PARAMETERS,
)
self.driver.push_messages.side_effect = self._get_push_messages(res)
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(res)
ins = GetParametersIns(config={})

# Execute
Expand All @@ -114,7 +112,7 @@ def test_fit(self) -> None:
num_examples=10,
metrics={},
)
self.driver.push_messages.side_effect = self._get_push_messages(res)
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(res)
parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))])
ins = FitIns(parameters, {})

Expand All @@ -134,7 +132,7 @@ def test_evaluate(self) -> None:
num_examples=0,
metrics={},
)
self.driver.push_messages.side_effect = self._get_push_messages(res)
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(res)
parameters = Parameters(tensors=[b"random params%^&*F"], tensor_type="np")
ins = EvaluateIns(parameters, {})

Expand All @@ -148,7 +146,7 @@ def test_evaluate(self) -> None:
def test_get_properties_and_fail(self) -> None:
"""Test negative case."""
# Prepare
self.driver.push_messages.side_effect = self._get_push_messages(
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(
None, error_reply=True
)
request_properties: Config = {"tensor_type": "str"}
Expand All @@ -163,7 +161,7 @@ def test_get_properties_and_fail(self) -> None:
def test_get_parameters_and_fail(self) -> None:
"""Test negative case."""
# Prepare
self.driver.push_messages.side_effect = self._get_push_messages(
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(
None, error_reply=True
)
ins = GetParametersIns(config={})
Expand All @@ -177,7 +175,7 @@ def test_get_parameters_and_fail(self) -> None:
def test_fit_and_fail(self) -> None:
"""Test negative case."""
# Prepare
self.driver.push_messages.side_effect = self._get_push_messages(
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(
None, error_reply=True
)
parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))])
Expand All @@ -190,7 +188,7 @@ def test_fit_and_fail(self) -> None:
def test_evaluate_and_fail(self) -> None:
"""Test negative case."""
# Prepare
self.driver.push_messages.side_effect = self._get_push_messages(
self.driver.send_and_receive.side_effect = self._exec_send_and_receive(
None, error_reply=True
)
parameters = Parameters(tensors=[b"random params%^&*F"], tensor_type="np")
Expand Down Expand Up @@ -229,15 +227,15 @@ def _create_message_dummy( # pylint: disable=R0913
self.created_msg = Message(metadata=metadata, content=content)
return self.created_msg

def _get_push_messages(
def _exec_send_and_receive(
self,
res: Union[GetParametersRes, GetPropertiesRes, FitRes, EvaluateRes, None],
error_reply: bool = False,
) -> Callable[[Iterable[Message]], Iterable[str]]:
"""Get the push_messages function that sets the return value of pull_messages
when called."""
) -> Callable[[Iterable[Message]], Iterable[Message]]:
"""Get the generate_replies function that sets the return value of driver's
send_and_receive when called."""

def push_messages(messages: Iterable[Message]) -> Iterable[str]:
def generate_replies(messages: Iterable[Message]) -> Iterable[Message]:
msg = list(messages)[0]
if error_reply:
recordset = None
Expand All @@ -254,13 +252,11 @@ def push_messages(messages: Iterable[Message]) -> Iterable[str]:
raise ValueError(f"Unsupported type: {type(res)}")
if recordset is not None:
ret = msg.create_reply(recordset)
ret.metadata.__dict__["_message_id"] = REPLY_MESSAGE_ID

# Set the return value of `pull_messages`
self.driver.pull_messages.return_value = [ret]
return [INSTRUCTION_MESSAGE_ID]
# Reply messages given the push message
return [ret]

return push_messages
return generate_replies

def _common_assertions(self, original_ins: Any) -> None:
"""Check common assertions."""
Expand All @@ -275,18 +271,9 @@ def _common_assertions(self, original_ins: Any) -> None:
self.assertEqual(self.called_times, 1)
self.assertEqual(actual_ins, original_ins)

# Check if push_messages is called once with expected args/kwargs.
self.driver.push_messages.assert_called_once()
# Check if send_and_receive is called once with expected args/kwargs.
self.driver.send_and_receive.assert_called_once()
try:
self.driver.push_messages.assert_any_call([self.created_msg])
self.driver.send_and_receive.assert_any_call([self.created_msg])
except AssertionError:
self.driver.push_messages.assert_any_call(messages=[self.created_msg])

# Check if pull_messages is called once with expected args/kwargs.
self.driver.pull_messages.assert_called_once()
try:
self.driver.pull_messages.assert_called_with([INSTRUCTION_MESSAGE_ID])
except AssertionError:
self.driver.pull_messages.assert_called_with(
message_ids=[INSTRUCTION_MESSAGE_ID]
)
self.driver.send_and_receive.assert_any_call(messages=[self.created_msg])