Skip to content

Commit

Permalink
Add new fields to Metadata class (#2961)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Feb 18, 2024
1 parent 1884222 commit da682a1
Show file tree
Hide file tree
Showing 16 changed files with 398 additions and 271 deletions.
6 changes: 4 additions & 2 deletions src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,11 @@ def receive() -> Message:
metadata=Metadata(
run_id=0,
message_id=str(uuid.uuid4()),
src_node_id=0,
dst_node_id=0,
reply_to_message="",
group_id="",
ttl="",
node_id=0,
message_type=message_type,
),
content=recordset,
Expand Down Expand Up @@ -205,7 +207,7 @@ def send(message: Message) -> None:
disconnect_res=ClientMessage.DisconnectRes(reason=reason)
)
else:
raise ValueError(f"Invalid task type: {message_type}")
raise ValueError(f"Invalid message type: {message_type}")

# Send ClientMessage proto
return queue.put(msg_proto, block=False)
Expand Down
8 changes: 6 additions & 2 deletions src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@
metadata=Metadata(
run_id=0,
message_id="",
src_node_id=0,
dst_node_id=0,
reply_to_message="",
group_id="",
node_id=0,
ttl="",
message_type=MESSAGE_TYPE_GET_PROPERTIES,
),
Expand All @@ -59,8 +61,10 @@
metadata=Metadata(
run_id=0,
message_id="",
src_node_id=0,
dst_node_id=0,
reply_to_message="",
group_id="",
node_id=0,
ttl="",
message_type="reconnect",
),
Expand Down
56 changes: 28 additions & 28 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,17 @@


from contextlib import contextmanager
from copy import copy
from logging import DEBUG, ERROR
from pathlib import Path
from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast

from flwr.client.message_handler.task_handler import (
configure_task_res,
get_task_ins,
validate_task_ins,
validate_task_res,
)
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Message
from flwr.client.message_handler.message_handler import validate_out_message
from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
from flwr.common.grpc import create_channel
from flwr.common.logger import log, warn_experimental_feature
from flwr.common.message import Message, Metadata
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
Expand All @@ -41,7 +39,7 @@
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611

KEY_NODE = "node"
KEY_TASK_INS = "current_task_ins"
KEY_METADATA = "in_message_metadata"


def on_channel_state_change(channel_connectivity: str) -> None:
Expand Down Expand Up @@ -102,8 +100,8 @@ def grpc_request_response(
channel.subscribe(on_channel_state_change)
stub = FleetStub(channel)

# Necessary state to link TaskRes to TaskIns
state: Dict[str, Optional[TaskIns]] = {KEY_TASK_INS: None}
# Necessary state to validate messages to be sent
state: Dict[str, Optional[Metadata]] = {KEY_METADATA: None}

# Enable create_node and delete_node to store node
node_store: Dict[str, Optional[Node]] = {KEY_NODE: None}
Expand Down Expand Up @@ -149,45 +147,47 @@ def receive() -> Optional[Message]:
task_ins: Optional[TaskIns] = get_task_ins(response)

# Discard the current TaskIns if not valid
if task_ins is not None and not validate_task_ins(task_ins):
if task_ins is not None and not (
task_ins.task.consumer.node_id == node.node_id
and validate_task_ins(task_ins)
):
task_ins = None

# Remember `task_ins` until `task_res` is available
state[KEY_TASK_INS] = task_ins
# Construct the Message
in_message = message_from_taskins(task_ins) if task_ins else None

# Remember `metadata` of the in message
state[KEY_METADATA] = copy(in_message.metadata) if in_message else None

# Return the message if available
return message_from_taskins(task_ins) if task_ins is not None else None
return in_message

def send(message: Message) -> None:
"""Send task result back to server."""
# Get Node
if node_store[KEY_NODE] is None:
log(ERROR, "Node instance missing")
return
node: Node = cast(Node, node_store[KEY_NODE])

# Get incoming TaskIns
if state[KEY_TASK_INS] is None:
log(ERROR, "No current TaskIns")
# Get incoming message
in_metadata = state[KEY_METADATA]
if in_metadata is None:
log(ERROR, "No current message")
return

# Validate out message
if not validate_out_message(message, in_metadata):
log(ERROR, "Invalid out message")
return
task_ins: TaskIns = cast(TaskIns, state[KEY_TASK_INS])

# Construct TaskRes
task_res = message_to_taskres(message)

# Check if fields to be set are not initialized
if not validate_task_res(task_res):
state[KEY_TASK_INS] = None
log(ERROR, "TaskRes has been initialized accidentally")

# Configure TaskRes
task_res = configure_task_res(task_res, task_ins, node)

# Serialize ProtoBuf to bytes
request = PushTaskResRequest(task_res_list=[task_res])
_ = stub.PushTaskRes(request)

state[KEY_TASK_INS] = None
state[KEY_METADATA] = None

try:
# Yield methods
Expand Down
46 changes: 21 additions & 25 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
reason = cast(int, disconnect_msg.disconnect_res.reason)
recordset = RecordSet()
recordset.set_configs("config", ConfigsRecord({"reason": reason}))
out_message = Message(
metadata=Metadata(
run_id=0,
message_id="",
group_id="",
node_id=0,
ttl="",
message_type="reconnect",
),
content=recordset,
)
out_message = message.create_reply(recordset, ttl="")
# Return TaskRes and sleep duration
return out_message, sleep_duration

Expand All @@ -107,7 +97,7 @@ def handle_legacy_message_from_msgtype(
client_fn: ClientFn, message: Message, context: Context
) -> Message:
"""Handle legacy message in the inner most mod."""
client = client_fn(str(message.metadata.node_id))
client = client_fn(str(message.metadata.dst_node_id))

client.set_context(context)

Expand Down Expand Up @@ -144,21 +134,10 @@ def handle_legacy_message_from_msgtype(
)
out_recordset = evaluateres_to_recordset(evaluate_res)
else:
raise ValueError(f"Invalid task type: {message_type}")
raise ValueError(f"Invalid message type: {message_type}")

# Return Message
out_message = Message(
metadata=Metadata(
run_id=0,
message_id="",
group_id="",
node_id=0,
ttl="",
message_type=message_type,
),
content=out_recordset,
)
return out_message
return message.create_reply(out_recordset, ttl="")


def _reconnect(
Expand All @@ -173,3 +152,20 @@ def _reconnect(
# Build DisconnectRes message
disconnect_res = ClientMessage.DisconnectRes(reason=reason)
return ClientMessage(disconnect_res=disconnect_res), sleep_duration


def validate_out_message(out_message: Message, in_message_metadata: Metadata) -> bool:
"""Validate the out message."""
out_meta = out_message.metadata
in_meta = in_message_metadata
if ( # pylint: disable-next=too-many-boolean-expressions
out_meta.run_id == in_meta.run_id
and out_meta.message_id == "" # This will be generated by the server
and out_meta.src_node_id == in_meta.dst_node_id
and out_meta.dst_node_id == in_meta.src_node_id
and out_meta.reply_to_message == in_meta.message_id
and out_meta.group_id == in_meta.group_id
and out_meta.message_type == in_meta.message_type
):
return True
return False
117 changes: 106 additions & 11 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
"""Client-side message handler tests."""


import unittest
import uuid
from copy import copy
from typing import List

from flwr.client import Client
from flwr.client.typing import ClientFn
Expand All @@ -40,7 +43,7 @@
from flwr.common import typing
from flwr.common.constant import MESSAGE_TYPE_GET_PROPERTIES

from .message_handler import handle_legacy_message_from_msgtype
from .message_handler import handle_legacy_message_from_msgtype, validate_out_message


class ClientWithoutProps(Client):
Expand Down Expand Up @@ -122,10 +125,12 @@ def test_client_without_get_properties() -> None:
recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({}))
message = Message(
metadata=Metadata(
run_id=0,
run_id=123,
message_id=str(uuid.uuid4()),
group_id="",
node_id=0,
group_id="some group ID",
src_node_id=0,
dst_node_id=1123,
reply_to_message="",
ttl="",
message_type=MESSAGE_TYPE_GET_PROPERTIES,
),
Expand All @@ -148,10 +153,22 @@ def test_client_without_get_properties() -> None:
properties={},
)
expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res)
expected_msg = Message(message.metadata, expected_rs)
expected_msg = Message(
metadata=Metadata(
run_id=123,
message_id="",
group_id="some group ID",
src_node_id=1123,
dst_node_id=0,
reply_to_message=message.metadata.message_id,
ttl="",
message_type=MESSAGE_TYPE_GET_PROPERTIES,
),
content=expected_rs,
)

assert actual_msg.content == expected_msg.content
assert actual_msg.metadata.message_type == expected_msg.metadata.message_type
assert actual_msg.metadata == expected_msg.metadata


def test_client_with_get_properties() -> None:
Expand All @@ -161,10 +178,12 @@ def test_client_with_get_properties() -> None:
recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({}))
message = Message(
metadata=Metadata(
run_id=0,
run_id=123,
message_id=str(uuid.uuid4()),
group_id="",
node_id=0,
group_id="some group ID",
src_node_id=0,
dst_node_id=1123,
reply_to_message="",
ttl="",
message_type=MESSAGE_TYPE_GET_PROPERTIES,
),
Expand All @@ -187,7 +206,83 @@ def test_client_with_get_properties() -> None:
properties={"str_prop": "val", "int_prop": 1},
)
expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res)
expected_msg = Message(message.metadata, expected_rs)
expected_msg = Message(
metadata=Metadata(
run_id=123,
message_id="",
group_id="some group ID",
src_node_id=1123,
dst_node_id=0,
reply_to_message=message.metadata.message_id,
ttl="",
message_type=MESSAGE_TYPE_GET_PROPERTIES,
),
content=expected_rs,
)

assert actual_msg.content == expected_msg.content
assert actual_msg.metadata.message_type == expected_msg.metadata.message_type
assert actual_msg.metadata == expected_msg.metadata


class TestMessageValidation(unittest.TestCase):
"""Test message validation."""

def setUp(self) -> None:
"""Set up the message validation."""
# Common setup for tests
self.in_metadata = Metadata(
run_id=123,
message_id="qwerty",
src_node_id=10,
dst_node_id=20,
reply_to_message="",
group_id="group1",
ttl="60",
message_type="mock",
)
self.valid_out_metadata = Metadata(
run_id=123,
message_id="",
src_node_id=20,
dst_node_id=10,
reply_to_message="qwerty",
group_id="group1",
ttl="60",
message_type="mock",
)
self.common_content = RecordSet()

def test_valid_message(self) -> None:
"""Test a valid message."""
# Prepare
valid_message = Message(metadata=self.valid_out_metadata, content=RecordSet())

# Assert
self.assertTrue(validate_out_message(valid_message, self.in_metadata))

def test_invalid_message_run_id(self) -> None:
"""Test invalid messages."""
# Prepare
msg = Message(metadata=self.valid_out_metadata, content=RecordSet())

# Execute
invalid_metadata_list: List[Metadata] = []
attrs = list(vars(self.valid_out_metadata).keys())
for attr in attrs:
if attr == "_ttl": # Skip configurable ttl
continue
# Make an invalid metadata
invalid_metadata = copy(self.valid_out_metadata)
value = getattr(invalid_metadata, attr)
if isinstance(value, int):
value = 999
elif isinstance(value, str):
value = "999"
setattr(invalid_metadata, attr, value)
# Add to list
invalid_metadata_list.append(invalid_metadata)

# Assert
for invalid_metadata in invalid_metadata_list:
msg._metadata = invalid_metadata # pylint: disable=protected-access
self.assertFalse(validate_out_message(msg, self.in_metadata))
Loading

0 comments on commit da682a1

Please sign in to comment.