Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ConfigsRecord #2803

Merged
merged 30 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
a749a0e
v0
jafermarq Jan 15, 2024
29d1a41
w/ previous
jafermarq Jan 15, 2024
b399c38
new `parametersrecord.py`; ranamed `Tensor`->`Array`; more
jafermarq Jan 16, 2024
44de0e7
updates
jafermarq Jan 16, 2024
d7b23db
`MetricsRecord` init with tests
jafermarq Jan 16, 2024
7ed2b82
w/ previous
jafermarq Jan 16, 2024
41e43fe
v0 `ConfigsRecord`
jafermarq Jan 16, 2024
46e1f3a
fix
jafermarq Jan 16, 2024
b6bd15c
Merge branch 'add-metricsrecord' into add-configsrecord
jafermarq Jan 16, 2024
3fe3728
w/ previous
jafermarq Jan 16, 2024
092a74b
merge w/ main; tweaks
jafermarq Jan 17, 2024
32b1155
better tests; definitions in typing
jafermarq Jan 17, 2024
36d3f3b
Merge branch 'main' into add-metricsrecord
jafermarq Jan 17, 2024
d3316f8
double space top of file
jafermarq Jan 17, 2024
a79aab8
w/ previous
jafermarq Jan 17, 2024
2552957
merge; `ConfigsRecord` as a `MetricsRecord`
jafermarq Jan 17, 2024
c08bdef
w/ previous
jafermarq Jan 17, 2024
22d48f4
no `str` in `MetricsRecords` values
jafermarq Jan 18, 2024
c54f438
Merge branch 'add-metricsrecord' into add-configsrecord
jafermarq Jan 18, 2024
3823ed2
update
jafermarq Jan 18, 2024
376ffb8
fix docstrings
jafermarq Jan 18, 2024
3a62bb5
Merge branch 'main' into add-metricsrecord
jafermarq Jan 18, 2024
b18b05d
more info in TypeError messsage
jafermarq Jan 18, 2024
8888c1e
Merge branch 'add-metricsrecord' into add-configsrecord
jafermarq Jan 18, 2024
6389def
more info in TypeError messsage
jafermarq Jan 18, 2024
12ed920
updates; more tests
jafermarq Jan 18, 2024
0bc9f09
fixes; support for bytes, List[bytes]
jafermarq Jan 18, 2024
0278ed5
merged w/ main
jafermarq Jan 18, 2024
e7e8cf6
Update src/py/flwr/common/configsrecord.py
danieljanes Jan 18, 2024
f058c50
extra type test
jafermarq Jan 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions src/py/flwr/common/configsrecord.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ConfigsRecord."""


from dataclasses import dataclass, field
from typing import Dict, Optional, get_args

from .typing import ConfigsRecordValues, ConfigsScalar


@dataclass
class ConfigsRecord:
"""Configs record."""

keep_input: bool
data: Dict[str, ConfigsRecordValues] = field(default_factory=dict)

def __init__(
self,
configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None,
keep_input: bool = True,
):
"""Construct a ConfigsRecord object.

Parameters
----------
configs_dict : Optional[Dict[str, ConfigsRecordValues]]
A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as
defined in `ConfigsScalar`) and lists of such types (see
`ConfigsScalarList`).
keep_input : bool (default: True)
A boolean indicating whether config passed should be deleted from the input
dictionary immediately after adding them to the record. When set
to True, the data is duplicated in memory. If memory is a concern, set
it to False.
"""
self.keep_input = keep_input
self.data = {}
if configs_dict:
self.set_configs(configs_dict)

def set_configs(self, configs_dict: Dict[str, ConfigsRecordValues]) -> None:
"""Add configs to the record.

Parameters
----------
configs_dict : Dict[str, ConfigsRecordValues]
A dictionary that stores basic types (i.e. `str`,`int`, `float`, `bytes` as
defined in `ConfigsRecordValues`) and list of such types (see
`ConfigsScalarList`).
"""
if any(not isinstance(k, str) for k in configs_dict.keys()):
raise TypeError(f"Not all keys are of valid type. Expected {str}")

def is_valid(value: ConfigsScalar) -> None:
"""Check if value is of expected type."""
if not isinstance(value, get_args(ConfigsScalar)):
raise TypeError(
"Not all values are of valid type."
f" Expected {ConfigsRecordValues} but you passed {type(value)}."
)

# Check types of values
# Split between those values that are list and those that aren't
# then process in the same way
for value in configs_dict.values():
if isinstance(value, list):
# If your lists are large (e.g. 1M+ elements) this will be slow
# 1s to check 10M element list on a M2 Pro
# In such settings, you'd be better of treating such config as
# an array and pass it to a ParametersRecord.
for list_value in value:
is_valid(list_value)
else:
is_valid(value)

# Add configs to record
if self.keep_input:
# Copy
self.data = configs_dict.copy()
else:
# Add entries to dataclass without duplicating memory
for key in list(configs_dict.keys()):
self.data[key] = configs_dict[key]
del configs_dict[key]
7 changes: 2 additions & 5 deletions src/py/flwr/common/recordset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,15 @@
# ==============================================================================
"""RecordSet."""


from dataclasses import dataclass, field
from typing import Dict

from .configsrecord import ConfigsRecord
from .metricsrecord import MetricsRecord
from .parametersrecord import ParametersRecord


@dataclass
class ConfigsRecord:
"""Configs record."""


@dataclass
class RecordSet:
"""Definition of RecordSet."""
Expand Down
80 changes: 79 additions & 1 deletion src/py/flwr/common/recordset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,21 @@
import numpy as np
import pytest

from .configsrecord import ConfigsRecord
from .metricsrecord import MetricsRecord
from .parameter import ndarrays_to_parameters, parameters_to_ndarrays
from .parametersrecord import Array, ParametersRecord
from .recordset_utils import (
parameters_to_parametersrecord,
parametersrecord_to_parameters,
)
from .typing import MetricsRecordValues, NDArray, NDArrays, Parameters
from .typing import (
ConfigsRecordValues,
MetricsRecordValues,
NDArray,
NDArrays,
Parameters,
)


def get_ndarrays() -> NDArrays:
Expand Down Expand Up @@ -255,3 +262,74 @@ def test_set_metrics_to_metricsrecord_with_and_without_keeping_input(
else:
assert my_metrics_copy == m_record.data
assert len(my_metrics) == 0


@pytest.mark.parametrize(
"key_type, value_fn",
[
(str, lambda x: str(x.flatten()[0])), # str: str
(str, lambda x: int(x.flatten()[0])), # str: int
(str, lambda x: float(x.flatten()[0])), # str: float
(str, lambda x: x.flatten().tobytes()), # str: bytes
(str, lambda x: x.flatten().astype("str").tolist()), # str: List[str]
(str, lambda x: x.flatten().astype("int").tolist()), # str: List[int]
(str, lambda x: x.flatten().astype("float").tolist()), # str: List[float]
(str, lambda x: [x.flatten().tobytes()]), # str: List[bytes]
jafermarq marked this conversation as resolved.
Show resolved Hide resolved
],
)
def test_set_configs_to_configsrecord_with_correct_types(
key_type: Type[str],
value_fn: Callable[[NDArray], ConfigsRecordValues],
) -> None:
"""Test adding configs of various types to a ConfigsRecord."""
labels = [1, 2.0]
arrays = get_ndarrays()

my_configs = OrderedDict(
{key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)}
)

c_record = ConfigsRecord(my_configs)

# check values are actually there
assert c_record.data == my_configs


@pytest.mark.parametrize(
"key_type, value_fn",
[
(str, lambda x: x), # str: NDArray (supported: unsupported)
(
str,
lambda x: {str(v): v for v in x.flatten()},
), # str: dict[str: float] (supported: unsupported)
(
str,
lambda x: [{str(v): v for v in x.flatten()}],
), # str: List[dict[str: float]] (supported: unsupported)
(
int,
lambda x: x.flatten().tolist(),
), # int: List[str] (unsupported: supported)
(
float,
lambda x: x.flatten().tolist(),
), # float: List[int] (unsupported: supported)
],
)
def test_set_configs_to_configsrecord_with_incorrect_types(
key_type: Type[Union[str, int, float]],
value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]],
) -> None:
"""Test adding configs of various unsupported types to a ConfigsRecord."""
m_record = ConfigsRecord()

labels = [1, 2.0]
arrays = get_ndarrays()

my_metrics = OrderedDict(
{key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)}
)

with pytest.raises(TypeError):
m_record.set_configs(my_metrics) # type: ignore
4 changes: 4 additions & 0 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@
MetricsScalar = Union[int, float]
MetricsScalarList = Union[List[int], List[float]]
MetricsRecordValues = Union[MetricsScalar, MetricsScalarList]
# Value types for common.ConfigsRecord
ConfigsScalar = Union[MetricsScalar, str, bytes]
ConfigsScalarList = Union[MetricsScalarList, List[str], List[bytes]]
ConfigsRecordValues = Union[ConfigsScalar, ConfigsScalarList]

Metrics = Dict[str, Scalar]
MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics]
Expand Down