Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add __eq__ and __repr__ to classes #3375

Merged
merged 8 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions src/py/flwr/common/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@

import time
import warnings
from dataclasses import dataclass
from typing import Optional, cast

from .record import RecordSet

DEFAULT_TTL = 3600


@dataclass
class Metadata: # pylint: disable=too-many-instance-attributes
"""A dataclass holding metadata associated with the current message.

Expand Down Expand Up @@ -161,8 +159,18 @@ def partition_id(self, value: int) -> None:
"""Set partition_id."""
self.__dict__["_partition_id"] = value

def __repr__(self) -> str:
"""Return a string representation of this instance."""
view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
return f"{self.__class__.__qualname__}({view})"

def __eq__(self, other: object) -> bool:
"""Compare two instances of the class."""
if not isinstance(other, self.__class__):
raise NotImplementedError
return self.__dict__ == other.__dict__


@dataclass
class Error:
"""A dataclass that stores information about an error that occurred.

Expand Down Expand Up @@ -191,8 +199,18 @@ def reason(self) -> str | None:
"""Reason reported about the error."""
return cast(Optional[str], self.__dict__["_reason"])

def __repr__(self) -> str:
"""Return a string representation of this instance."""
view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
return f"{self.__class__.__qualname__}({view})"

def __eq__(self, other: object) -> bool:
"""Compare two instances of the class."""
if not isinstance(other, self.__class__):
raise NotImplementedError
return self.__dict__ == other.__dict__


@dataclass
class Message:
"""State of your application from the viewpoint of the entity using it.

Expand Down Expand Up @@ -357,6 +375,17 @@ def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message:

return message

def __repr__(self) -> str:
"""Return a string representation of this instance."""
view = ", ".join(
[
f"{k.lstrip('_')}={v!r}"
for k, v in self.__dict__.items()
if v is not None
]
)
return f"{self.__class__.__qualname__}({view})"


def _create_reply_metadata(msg: Message, ttl: float) -> Metadata:
"""Construct metadata for a reply message."""
Expand Down
50 changes: 48 additions & 2 deletions src/py/flwr/common/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
"""Message tests."""

import time
from collections import namedtuple
from contextlib import ExitStack
from typing import Any, Callable, Optional
from typing import Any, Callable, Dict, Optional

import pytest

# pylint: enable=E0611
from . import RecordSet
from .message import Error, Message
from .message import Error, Message, Metadata
from .serde_test import RecordMaker


Expand Down Expand Up @@ -157,3 +158,48 @@ def test_create_reply(
assert message.metadata.src_node_id == reply_message.metadata.dst_node_id
assert message.metadata.dst_node_id == reply_message.metadata.src_node_id
assert reply_message.metadata.reply_to_message == message.metadata.message_id


@pytest.mark.parametrize(
"cls, kwargs",
[
(
Metadata,
{
"run_id": 123,
"message_id": "msg_456",
"src_node_id": 1,
"dst_node_id": 2,
"reply_to_message": "reply_789",
"group_id": "group_xyz",
"ttl": 10.0,
"message_type": "request",
"partition_id": None,
},
),
(Error, {"code": 1, "reason": "reason_098"}),
(
Message,
{
"metadata": RecordMaker(1).metadata(),
"content": RecordMaker(1).recordset(1, 1, 1),
},
),
(
Message,
{
"metadata": RecordMaker(2).metadata(),
"error": Error(0, "some reason"),
},
),
],
)
def test_repr(cls: type, kwargs: Dict[str, Any]) -> None:
"""Test string representations of Metadata/Message/Error."""
# Prepare
anon_cls = namedtuple(cls.__qualname__, kwargs.keys()) # type: ignore
expected = anon_cls(**kwargs)
actual = cls(**kwargs)

# Assert
assert str(actual) == str(expected)
1 change: 0 additions & 1 deletion src/py/flwr/common/record/parametersrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def _check_value(value: Array) -> None:
)


@dataclass
class ParametersRecord(TypedDict[str, Array]):
"""Parameters record.

Expand Down
13 changes: 12 additions & 1 deletion src/py/flwr/common/record/recordset.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def _check_fn_configs(self, record: ConfigsRecord) -> None:
)


@dataclass
class RecordSet:
"""RecordSet stores groups of parameters, metrics and configs."""

Expand Down Expand Up @@ -117,3 +116,15 @@ def configs_records(self) -> TypedDict[str, ConfigsRecord]:
"""Dictionary holding ConfigsRecord instances."""
data = cast(RecordSetData, self.__dict__["_data"])
return data.configs_records

def __repr__(self) -> str:
"""Return a string representation of this instance."""
flds = ("parameters_records", "metrics_records", "configs_records")
view = ", ".join([f"{fld}={getattr(self, fld)!r}" for fld in flds])
return f"{self.__class__.__qualname__}({view})"

def __eq__(self, other: object) -> bool:
"""Compare two instances of the class."""
if not isinstance(other, self.__class__):
raise NotImplementedError
return self.__dict__ == other.__dict__
16 changes: 16 additions & 0 deletions src/py/flwr/common/record/recordset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""RecordSet tests."""

import pickle
from collections import namedtuple
from copy import deepcopy
from typing import Callable, Dict, List, OrderedDict, Type, Union

Expand Down Expand Up @@ -414,3 +415,18 @@ def test_record_is_picklable() -> None:

# Execute
pickle.dumps((p_record, m_record, c_record, rs))


def test_recordset_repr() -> None:
"""Test the string representation of RecordSet."""
# Prepare
kwargs = {
"parameters_records": {"params": ParametersRecord()},
"metrics_records": {"metrics": MetricsRecord({"aa": 123})},
"configs_records": {"configs": ConfigsRecord({"cc": bytes(9)})},
}
rs = RecordSet(**kwargs) # type: ignore
expected = namedtuple("RecordSet", kwargs.keys())(**kwargs)

# Assert
assert str(rs) == str(expected)