From a749a0ef7bd99ec41c4165b3f90435b9dacb5ac1 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 15 Jan 2024 22:36:18 +0000 Subject: [PATCH 01/21] v0 --- src/py/flwr/common/recordset.py | 40 ++++++++++-- src/py/flwr/common/recordset_test.py | 89 +++++++++++++++++++++++++++ src/py/flwr/common/recordset_utils.py | 56 +++++++++++++++++ 3 files changed, 180 insertions(+), 5 deletions(-) create mode 100644 src/py/flwr/common/recordset_test.py create mode 100644 src/py/flwr/common/recordset_utils.py diff --git a/src/py/flwr/common/recordset.py b/src/py/flwr/common/recordset.py index 0088b7397a6d..e39865f02b1b 100644 --- a/src/py/flwr/common/recordset.py +++ b/src/py/flwr/common/recordset.py @@ -14,14 +14,44 @@ # ============================================================================== """RecordSet.""" -from dataclasses import dataclass -from typing import Dict +from dataclasses import dataclass, field +from typing import Dict, List + + +@dataclass +class Tensor: + """Tensor type.""" + + data: bytes + dtype: str + shape: List[int] + ref: str = "" # future functionality @dataclass class ParametersRecord: """Parameters record.""" + data: Dict[str, Tensor] = field(default_factory=dict) + + def add_parameters(self, tensor_dict: Dict[str, Tensor]) -> None: + """Add parameters to record. + + This not implemented as a constructor so we can cleanly create and empyt + ParametersRecord object. + """ + if any(not isinstance(k, str) for k in tensor_dict.keys()): + raise TypeError(f"Not all keys are of valide type. Expected {str}") + if any(not isinstance(v, Tensor) for v in tensor_dict.values()): + raise TypeError(f"Not all values are of valide type. Expected {Tensor}") + + # Add entries to dataclass without duplicating memor footprint + for key in list(tensor_dict.keys()): + self.data[key] = tensor_dict[key] + del tensor_dict[key] + + self.data = tensor_dict + @dataclass class MetricsRecord: @@ -37,9 +67,9 @@ class ConfigsRecord: class RecordSet: """Definition of RecordSet.""" - parameters: Dict[str, ParametersRecord] = {} - metrics: Dict[str, MetricsRecord] = {} - configs: Dict[str, ConfigsRecord] = {} + parameters: Dict[str, ParametersRecord] = field(default_factory=dict) + metrics: Dict[str, MetricsRecord] = field(default_factory=dict) + configs: Dict[str, ConfigsRecord] = field(default_factory=dict) def set_parameters(self, name: str, record: ParametersRecord) -> None: """Add a ParametersRecord.""" diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py new file mode 100644 index 000000000000..9f459d7f937d --- /dev/null +++ b/src/py/flwr/common/recordset_test.py @@ -0,0 +1,89 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RecordSet tests.""" +import secrets + +import numpy as np + +from .parameter import ndarrays_to_parameters, parameters_to_ndarrays +from .recordset import Tensor +from .recordset_utils import ( + parameters_to_parametersrecord, + parametersrecord_to_parameters, +) +from .typing import NDArrays, Parameters + + +def get_ndarrays() -> NDArrays: + """Return list of NumPy arrays.""" + arr1 = np.array([[1.0, 2.0], [3.0, 4], [5.0, 6.0]]) + arr2 = np.eye(2, 7, 3) + + return [arr1, arr2] + + +def test_ndarray_to_tensor() -> None: + """Test creation of Tensor object from NumPy array.""" + shape = (2, 7, 9) + arr = np.eye(*shape) + + tensor = Tensor( + arr.tobytes(), + dtype=str(arr.dtype), + shape=list(arr.shape), + ref=secrets.token_hex(16), + ) + + arr_ = np.frombuffer(buffer=tensor.data, dtype=tensor.dtype).reshape(tensor.shape) + + assert np.allclose(arr, arr_) + + +def test_parameters_to_tensor_and_back() -> None: + """Test conversion between legacy Parameters and Tensor.""" + ndarrays = get_ndarrays() + + # Tensor represents a single array, unlike Paramters, which represent a + # list of arrays + ndarray = ndarrays[0] + + parameters = ndarrays_to_parameters([ndarray]) + + tensor = Tensor(data=parameters.tensors[0], dtype=parameters.tensor_type, shape=[]) + + parameters = Parameters(tensors=[tensor.data], tensor_type=tensor.dtype) + + ndarray_ = parameters_to_ndarrays(parameters=parameters)[0] + + assert np.allclose(ndarray, ndarray_) + + +def test_parameters_to_parametersrecord_and_back() -> None: + """Test utility function to convert between legacy Parameters. + + and ParametersRecords. + """ + ndarrays = get_ndarrays() + + parameters = ndarrays_to_parameters(ndarrays) + + params_record = parameters_to_parametersrecord(parameters=parameters) + + parameters_ = parametersrecord_to_parameters(params_record) + + ndarrays_ = parameters_to_ndarrays(parameters=parameters_) + + for arr, arr_ in zip(ndarrays, ndarrays_): + assert np.allclose(arr, arr_) diff --git a/src/py/flwr/common/recordset_utils.py b/src/py/flwr/common/recordset_utils.py new file mode 100644 index 000000000000..af8f9505e80b --- /dev/null +++ b/src/py/flwr/common/recordset_utils.py @@ -0,0 +1,56 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RecordSet utilities.""" + +from secrets import token_hex + +from .recordset import ParametersRecord, Tensor +from .typing import Parameters + + +def parametersrecord_to_parameters(record: ParametersRecord) -> Parameters: + """Convert ParameterRecord to legacy Parameters. + + The data in ParameterRecord will be freed. Because legacy Parameters do not keep + names of tensors, this information will be discarded. + """ + parameters = Parameters(tensors=[], tensor_type="") + + for key in list(record.data.keys()): + parameters.tensors.append(record.data[key].data) + + del record.data[key] + + return parameters + + +def parameters_to_parametersrecord(parameters: Parameters) -> ParametersRecord: + """Convert legacy Parameters into a single ParametersRecord. + + The memory ocupied by inputed parameters will be freed. Because there is no concept + of names in the legacy Paramters, arbitrary keys will be used when constructing the + ParametersRecord. Similarly, the shape won't be recorded in the Tensor objects. + """ + tensor_type = parameters.tensor_type + + p_record = ParametersRecord() + + for _ in range(len(parameters.tensors)): + tensor = parameters.tensors.pop(0) + p_record.add_parameters( + {token_hex(8): Tensor(data=tensor, dtype=tensor_type, shape=[])} + ) + + return p_record From 29d1a411eec6cb900e8e8cf7948906647da03431 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 15 Jan 2024 23:02:52 +0000 Subject: [PATCH 02/21] w/ previous --- src/py/flwr/common/recordset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/py/flwr/common/recordset.py b/src/py/flwr/common/recordset.py index e39865f02b1b..4640365abf4d 100644 --- a/src/py/flwr/common/recordset.py +++ b/src/py/flwr/common/recordset.py @@ -45,13 +45,11 @@ def add_parameters(self, tensor_dict: Dict[str, Tensor]) -> None: if any(not isinstance(v, Tensor) for v in tensor_dict.values()): raise TypeError(f"Not all values are of valide type. Expected {Tensor}") - # Add entries to dataclass without duplicating memor footprint + # Add entries to dataclass without duplicating memory for key in list(tensor_dict.keys()): self.data[key] = tensor_dict[key] del tensor_dict[key] - self.data = tensor_dict - @dataclass class MetricsRecord: From b399c3806de6bb71704f547f33e151b791cdc70d Mon Sep 17 00:00:00 2001 From: jafermarq Date: Tue, 16 Jan 2024 11:34:06 +0000 Subject: [PATCH 03/21] new `parametersrecord.py`; ranamed `Tensor`->`Array`; more --- src/py/flwr/common/parametersrecord.py | 52 ++++++++++++++++++++++++++ src/py/flwr/common/recordset.py | 35 +---------------- src/py/flwr/common/recordset_test.py | 21 ++++++----- src/py/flwr/common/recordset_utils.py | 10 ++--- 4 files changed, 70 insertions(+), 48 deletions(-) create mode 100644 src/py/flwr/common/parametersrecord.py diff --git a/src/py/flwr/common/parametersrecord.py b/src/py/flwr/common/parametersrecord.py new file mode 100644 index 000000000000..783aed0ea466 --- /dev/null +++ b/src/py/flwr/common/parametersrecord.py @@ -0,0 +1,52 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ParametersRecord and Tensor.""" + +from dataclasses import dataclass, field +from typing import Dict, List + + +@dataclass +class Array: + """Array type.""" + + data: bytes + dtype: str + stype: str + shape: List[int] + ref: str = "" # future functionality + + +@dataclass +class ParametersRecord: + """Parameters record.""" + + data: Dict[str, Array] = field(default_factory=dict) + + def add_parameters(self, tensor_dict: Dict[str, Array]) -> None: + """Add parameters to record. + + This not implemented as a constructor so we can cleanly create and empyt + ParametersRecord object. + """ + if any(not isinstance(k, str) for k in tensor_dict.keys()): + raise TypeError(f"Not all keys are of valide type. Expected {str}") + if any(not isinstance(v, Array) for v in tensor_dict.values()): + raise TypeError(f"Not all values are of valide type. Expected {Array}") + + # Add entries to dataclass without duplicating memory + for key in list(tensor_dict.keys()): + self.data[key] = tensor_dict[key] + del tensor_dict[key] diff --git a/src/py/flwr/common/recordset.py b/src/py/flwr/common/recordset.py index 4640365abf4d..dc723a2cea86 100644 --- a/src/py/flwr/common/recordset.py +++ b/src/py/flwr/common/recordset.py @@ -15,40 +15,9 @@ """RecordSet.""" from dataclasses import dataclass, field -from typing import Dict, List +from typing import Dict - -@dataclass -class Tensor: - """Tensor type.""" - - data: bytes - dtype: str - shape: List[int] - ref: str = "" # future functionality - - -@dataclass -class ParametersRecord: - """Parameters record.""" - - data: Dict[str, Tensor] = field(default_factory=dict) - - def add_parameters(self, tensor_dict: Dict[str, Tensor]) -> None: - """Add parameters to record. - - This not implemented as a constructor so we can cleanly create and empyt - ParametersRecord object. - """ - if any(not isinstance(k, str) for k in tensor_dict.keys()): - raise TypeError(f"Not all keys are of valide type. Expected {str}") - if any(not isinstance(v, Tensor) for v in tensor_dict.values()): - raise TypeError(f"Not all values are of valide type. Expected {Tensor}") - - # Add entries to dataclass without duplicating memory - for key in list(tensor_dict.keys()): - self.data[key] = tensor_dict[key] - del tensor_dict[key] +from .parametersrecord import ParametersRecord @dataclass diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index 9f459d7f937d..a87149255085 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -18,7 +18,7 @@ import numpy as np from .parameter import ndarrays_to_parameters, parameters_to_ndarrays -from .recordset import Tensor +from .parametersrecord import Array from .recordset_utils import ( parameters_to_parametersrecord, parametersrecord_to_parameters, @@ -34,14 +34,15 @@ def get_ndarrays() -> NDArrays: return [arr1, arr2] -def test_ndarray_to_tensor() -> None: - """Test creation of Tensor object from NumPy array.""" +def test_ndarray_to_array() -> None: + """Test creation of Array object from NumPy array.""" shape = (2, 7, 9) arr = np.eye(*shape) - tensor = Tensor( + tensor = Array( arr.tobytes(), dtype=str(arr.dtype), + stype="np.tobytes", shape=list(arr.shape), ref=secrets.token_hex(16), ) @@ -51,19 +52,21 @@ def test_ndarray_to_tensor() -> None: assert np.allclose(arr, arr_) -def test_parameters_to_tensor_and_back() -> None: - """Test conversion between legacy Parameters and Tensor.""" +def test_parameters_to_array_and_back() -> None: + """Test conversion between legacy Parameters and Array.""" ndarrays = get_ndarrays() - # Tensor represents a single array, unlike Paramters, which represent a + # Array represents a single array, unlike Paramters, which represent a # list of arrays ndarray = ndarrays[0] parameters = ndarrays_to_parameters([ndarray]) - tensor = Tensor(data=parameters.tensors[0], dtype=parameters.tensor_type, shape=[]) + array = Array( + data=parameters.tensors[0], dtype=parameters.tensor_type, stype="", shape=[] + ) - parameters = Parameters(tensors=[tensor.data], tensor_type=tensor.dtype) + parameters = Parameters(tensors=[array.data], tensor_type=array.dtype) ndarray_ = parameters_to_ndarrays(parameters=parameters)[0] diff --git a/src/py/flwr/common/recordset_utils.py b/src/py/flwr/common/recordset_utils.py index af8f9505e80b..13ef2a91fe94 100644 --- a/src/py/flwr/common/recordset_utils.py +++ b/src/py/flwr/common/recordset_utils.py @@ -14,9 +14,7 @@ # ============================================================================== """RecordSet utilities.""" -from secrets import token_hex - -from .recordset import ParametersRecord, Tensor +from .parametersrecord import Array, ParametersRecord from .typing import Parameters @@ -41,16 +39,16 @@ def parameters_to_parametersrecord(parameters: Parameters) -> ParametersRecord: The memory ocupied by inputed parameters will be freed. Because there is no concept of names in the legacy Paramters, arbitrary keys will be used when constructing the - ParametersRecord. Similarly, the shape won't be recorded in the Tensor objects. + ParametersRecord. Similarly, the shape won't be recorded in the Array objects. """ tensor_type = parameters.tensor_type p_record = ParametersRecord() - for _ in range(len(parameters.tensors)): + for idx in range(len(parameters.tensors)): tensor = parameters.tensors.pop(0) p_record.add_parameters( - {token_hex(8): Tensor(data=tensor, dtype=tensor_type, shape=[])} + {str(idx): Array(data=tensor, dtype=tensor_type, stype="", shape=[])} ) return p_record From 44de0e752d3d5bc97e3e48086b96f92f999ec878 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Tue, 16 Jan 2024 14:35:57 +0000 Subject: [PATCH 04/21] updates --- src/py/flwr/common/parametersrecord.py | 3 +- src/py/flwr/common/recordset_test.py | 39 ++++++++++++++++++++------ 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/py/flwr/common/parametersrecord.py b/src/py/flwr/common/parametersrecord.py index 783aed0ea466..076f909fa1b2 100644 --- a/src/py/flwr/common/parametersrecord.py +++ b/src/py/flwr/common/parametersrecord.py @@ -14,6 +14,7 @@ # ============================================================================== """ParametersRecord and Tensor.""" +from collections import OrderedDict from dataclasses import dataclass, field from typing import Dict, List @@ -33,7 +34,7 @@ class Array: class ParametersRecord: """Parameters record.""" - data: Dict[str, Array] = field(default_factory=dict) + data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array]) def add_parameters(self, tensor_dict: Dict[str, Array]) -> None: """Add parameters to record. diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index a87149255085..ba3bb3f5cefd 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -23,7 +23,7 @@ parameters_to_parametersrecord, parametersrecord_to_parameters, ) -from .typing import NDArrays, Parameters +from .typing import NDArray, NDArrays, Parameters def get_ndarrays() -> NDArrays: @@ -34,20 +34,25 @@ def get_ndarrays() -> NDArrays: return [arr1, arr2] +def nparray_to_array(np_array: NDArray) -> Array: + """Represent NumPy array as Array.""" + return Array( + np_array.tobytes(), + dtype=str(np_array.dtype), + stype="np.tobytes", + shape=list(np_array.shape), + ref=secrets.token_hex(16), + ) + + def test_ndarray_to_array() -> None: """Test creation of Array object from NumPy array.""" shape = (2, 7, 9) arr = np.eye(*shape) - tensor = Array( - arr.tobytes(), - dtype=str(arr.dtype), - stype="np.tobytes", - shape=list(arr.shape), - ref=secrets.token_hex(16), - ) + array = nparray_to_array(arr) - arr_ = np.frombuffer(buffer=tensor.data, dtype=tensor.dtype).reshape(tensor.shape) + arr_ = np.frombuffer(buffer=array.data, dtype=array.dtype).reshape(array.shape) assert np.allclose(arr, arr_) @@ -90,3 +95,19 @@ def test_parameters_to_parametersrecord_and_back() -> None: for arr, arr_ in zip(ndarrays, ndarrays_): assert np.allclose(arr, arr_) + + +# def test_torch_statedict_to_parametersrecord() -> None: +# """.""" +# import torch +# from .parametersrecord import ParametersRecord + +# layer = torch.nn.Conv2d(3, 5, 16) +# layer_sd = layer.state_dict() + +# p_c = ParametersRecord() + +# for k in layer_sd.keys(): +# layer_sd[k] = nparray_to_array(layer_sd[k].numpy()) + +# p_c.add_parameters(layer_sd) From d7b23dbda692092bb009025e5eb4b412a07180ba Mon Sep 17 00:00:00 2001 From: jafermarq Date: Tue, 16 Jan 2024 16:06:31 +0000 Subject: [PATCH 05/21] `MetricsRecord` init with tests --- src/py/flwr/common/metricsrecord.py | 63 ++++++++++++++++++++++++++++ src/py/flwr/common/recordset_test.py | 18 +++++++- src/py/flwr/common/typing.py | 7 ++++ 3 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 src/py/flwr/common/metricsrecord.py diff --git a/src/py/flwr/common/metricsrecord.py b/src/py/flwr/common/metricsrecord.py new file mode 100644 index 000000000000..76ab41d1ad30 --- /dev/null +++ b/src/py/flwr/common/metricsrecord.py @@ -0,0 +1,63 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MetricsRecord.""" + +from dataclasses import dataclass, field +from typing import Dict, Union, get_args + +from .typing import Scalar, ScalarList + + +@dataclass +class MetricsRecord: + """Parameters record.""" + + data: Dict[str, Union[Scalar, ScalarList]] = field(default_factory=dict) + + def add_metrics(self, metrics_dict: Dict[str, Union[Scalar, ScalarList]]) -> None: + """Add metrics to record. + + This not implemented as a constructor so we can cleanly create and empyt + MetricsRecord object. + """ + if any(not (k, str) for k in metrics_dict.keys()): + raise TypeError(f"Not all keys are of valide type. Expected {str}") + + def is_valid(value: Scalar) -> None: + """Check if value is of expected type.""" + if not isinstance(value, get_args(Scalar)): + raise TypeError( + "Not all values are of valide type." + f" Expected {Union[Scalar, ScalarList]}" + ) + + # Check types of values + # Split between those values that are list and those that aren't + # then process in the same way + for value in metrics_dict.values(): + if isinstance(value, list): + # If your lists are large (e.g. 1M+ elements) this will be slow + # 1s to check 10M element list on a M2 Pro + # In such settings, you'd be better of treating such metric as + # an array and pass it to a ParametersRecord. + for list_value in value: + is_valid(list_value) + else: + is_valid(value) + + # Add entries to dataclass without duplicating memory + for key in list(metrics_dict.keys()): + self.data[key] = metrics_dict[key] + del metrics_dict[key] diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index ba3bb3f5cefd..b1814351bba2 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -14,16 +14,18 @@ # ============================================================================== """RecordSet tests.""" import secrets +from typing import Dict, Union import numpy as np +from .metricsrecord import MetricsRecord from .parameter import ndarrays_to_parameters, parameters_to_ndarrays from .parametersrecord import Array from .recordset_utils import ( parameters_to_parametersrecord, parametersrecord_to_parameters, ) -from .typing import NDArray, NDArrays, Parameters +from .typing import NDArray, NDArrays, Parameters, Scalar, ScalarList def get_ndarrays() -> NDArrays: @@ -97,6 +99,20 @@ def test_parameters_to_parametersrecord_and_back() -> None: assert np.allclose(arr, arr_) +def test_add_metrics_to_metricsrecord() -> None: + """Test adding metrics of various types to a MetricsRecord.""" + m_record = MetricsRecord() + + my_metrics: Dict[str, Union[Scalar, ScalarList]] = { + "loss": 0.12445, + "converged": True, + "my_int": 2, + "embeddings": np.random.randn(10).tolist(), + } + + m_record.add_metrics(my_metrics) + + # def test_torch_statedict_to_parametersrecord() -> None: # """.""" # import torch diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index 6c0266f5eec8..ffa7be88e40c 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -32,6 +32,13 @@ # not conform to other definitions of what a scalar is. Source: # https://developers.google.com/protocol-buffers/docs/overview#scalar Scalar = Union[bool, bytes, float, int, str] +ScalarList = Union[ + List[bool], + List[bytes], + List[float], + List[int], + List[str], +] Value = Union[ bool, bytes, From 7ed2b8248614c9240463ac4558018bb3068b5c8a Mon Sep 17 00:00:00 2001 From: jafermarq Date: Tue, 16 Jan 2024 16:20:11 +0000 Subject: [PATCH 06/21] w/ previous --- src/py/flwr/common/recordset.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/py/flwr/common/recordset.py b/src/py/flwr/common/recordset.py index dc723a2cea86..a5af909911fe 100644 --- a/src/py/flwr/common/recordset.py +++ b/src/py/flwr/common/recordset.py @@ -17,14 +17,10 @@ from dataclasses import dataclass, field from typing import Dict +from .metricsrecord import MetricsRecord from .parametersrecord import ParametersRecord -@dataclass -class MetricsRecord: - """Metrics record.""" - - @dataclass class ConfigsRecord: """Configs record.""" From 41e43fe2f9efd09915e3b4cc22d0f10b4e7cbffb Mon Sep 17 00:00:00 2001 From: jafermarq Date: Tue, 16 Jan 2024 16:53:52 +0000 Subject: [PATCH 07/21] v0 `ConfigsRecord` --- src/py/flwr/common/configsrecord.py | 67 ++++++++++++++++++++++++++++ src/py/flwr/common/recordset_test.py | 19 ++++++++ 2 files changed, 86 insertions(+) create mode 100644 src/py/flwr/common/configsrecord.py diff --git a/src/py/flwr/common/configsrecord.py b/src/py/flwr/common/configsrecord.py new file mode 100644 index 000000000000..1c5d5cbec255 --- /dev/null +++ b/src/py/flwr/common/configsrecord.py @@ -0,0 +1,67 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ConfigsRecord.""" + +from dataclasses import dataclass, field +from typing import Dict, Union, get_args + +from .typing import Scalar, ScalarList + +ConfigKeys = Union[str, int] + + +@dataclass +class ConfigsRecord: + """Configs record.""" + + data: Dict[ConfigKeys, Union[Scalar, ScalarList]] = field(default_factory=dict) + + def add_configs( + self, configs_dict: Dict[ConfigKeys, Union[Scalar, ScalarList]] + ) -> None: + """Add configs to record. + + This not implemented as a constructor so we can cleanly create and empyt + ConfigsRecord object. + """ + if any(not isinstance(k, get_args(ConfigKeys)) for k in configs_dict.keys()): + raise TypeError(f"Not all keys are of valide type. Expected {ConfigKeys}") + + def is_valid(value: Scalar) -> None: + """Check if value is of expected type.""" + if not isinstance(value, get_args(Scalar)): + raise TypeError( + "Not all values are of valide type." + f" Expected {Union[Scalar, ScalarList]}" + ) + + # Check types of values + # Split between those values that are list and those that aren't + # then process in the same way + for value in configs_dict.values(): + if isinstance(value, list): + # If your lists are large (e.g. 1M+ elements) this will be slow + # 1s to check 10M element list on a M2 Pro + # In such settings, you'd be better of treating such metric as + # an array and pass it to a ParametersRecord. + for list_value in value: + is_valid(list_value) + else: + is_valid(value) + + # Add entries to dataclass without duplicating memory + for key in list(configs_dict.keys()): + self.data[key] = configs_dict[key] + del configs_dict[key] diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index b1814351bba2..d0bc73e57a1a 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -18,6 +18,7 @@ import numpy as np +from .configsrecord import ConfigsRecord from .metricsrecord import MetricsRecord from .parameter import ndarrays_to_parameters, parameters_to_ndarrays from .parametersrecord import Array @@ -113,6 +114,24 @@ def test_add_metrics_to_metricsrecord() -> None: m_record.add_metrics(my_metrics) +def test_add_config_to_configsrecord() -> None: + """Test adding configs of various types to a ConfigsRecord.""" + m_record = ConfigsRecord() + + some_stage_id = 12345 + some_bytes = np.random.randn(256).tobytes() + + my_metrics: Dict[Union[int, str], Union[Scalar, ScalarList]] = { + "loss": 0.12445, + "converged": True, + "my_int": 2, + "embeddings": np.random.randn(10).tolist(), + some_stage_id: some_bytes, + } + + m_record.add_configs(my_metrics) + + # def test_torch_statedict_to_parametersrecord() -> None: # """.""" # import torch From 46e1f3a64d805c3ca662a943b0677c11c854e975 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Tue, 16 Jan 2024 16:56:49 +0000 Subject: [PATCH 08/21] fix --- src/py/flwr/common/metricsrecord.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/common/metricsrecord.py b/src/py/flwr/common/metricsrecord.py index 76ab41d1ad30..aa683ff7a2ca 100644 --- a/src/py/flwr/common/metricsrecord.py +++ b/src/py/flwr/common/metricsrecord.py @@ -32,7 +32,7 @@ def add_metrics(self, metrics_dict: Dict[str, Union[Scalar, ScalarList]]) -> Non This not implemented as a constructor so we can cleanly create and empyt MetricsRecord object. """ - if any(not (k, str) for k in metrics_dict.keys()): + if any(not isinstance(k, str) for k in metrics_dict.keys()): raise TypeError(f"Not all keys are of valide type. Expected {str}") def is_valid(value: Scalar) -> None: From 3fe3728b62239fc3815b9586ee8637a2ed7d8ceb Mon Sep 17 00:00:00 2001 From: jafermarq Date: Tue, 16 Jan 2024 17:20:44 +0000 Subject: [PATCH 09/21] w/ previous --- src/py/flwr/common/recordset.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/py/flwr/common/recordset.py b/src/py/flwr/common/recordset.py index a5af909911fe..2f3f08dddf6b 100644 --- a/src/py/flwr/common/recordset.py +++ b/src/py/flwr/common/recordset.py @@ -17,15 +17,11 @@ from dataclasses import dataclass, field from typing import Dict +from .configsrecord import ConfigsRecord from .metricsrecord import MetricsRecord from .parametersrecord import ParametersRecord -@dataclass -class ConfigsRecord: - """Configs record.""" - - @dataclass class RecordSet: """Definition of RecordSet.""" From 32b1155a017fac4c5f2fbd73c12e0d3a209b347a Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 17 Jan 2024 21:10:39 +0000 Subject: [PATCH 10/21] better tests; definitions in typing --- src/py/flwr/common/metricsrecord.py | 45 ++++++++++------ src/py/flwr/common/recordset_test.py | 76 ++++++++++++++++++++++++---- src/py/flwr/common/typing.py | 11 ++-- 3 files changed, 98 insertions(+), 34 deletions(-) diff --git a/src/py/flwr/common/metricsrecord.py b/src/py/flwr/common/metricsrecord.py index aa683ff7a2ca..3b9d1dc35b2e 100644 --- a/src/py/flwr/common/metricsrecord.py +++ b/src/py/flwr/common/metricsrecord.py @@ -15,32 +15,50 @@ """MetricsRecord.""" from dataclasses import dataclass, field -from typing import Dict, Union, get_args +from typing import Dict, Optional, Union, get_args -from .typing import Scalar, ScalarList +from .typing import MetricsScalar, MetricsScalarList + +MetricsRecordValues = Union[MetricsScalar, MetricsScalarList] @dataclass class MetricsRecord: """Parameters record.""" - data: Dict[str, Union[Scalar, ScalarList]] = field(default_factory=dict) + data: Dict[str, MetricsRecordValues] = field(default_factory=dict) + + def __init__(self, metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None): + """Construct a MetricsRecord object. - def add_metrics(self, metrics_dict: Dict[str, Union[Scalar, ScalarList]]) -> None: + Parameters + ---------- + array_dict : Optional[Dict[str, MetricsRecordValues]] + A dictionary that stores basic types (i.e. `str`, `int`, `float` as defined + in `MetricsScalar`) and list of such types (see `MetricsScalarList`). + """ + self.data = {} + if metrics_dict: + self.set_metrics(metrics_dict) + + def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None: """Add metrics to record. - This not implemented as a constructor so we can cleanly create and empyt - MetricsRecord object. + Parameters + ---------- + array_dict : Optional[Dict[str, MetricsRecordValues]] + A dictionary that stores basic types (i.e. `str`, `int`, `float` as defined + in `MetricsScalar`) and list of such types (see `MetricsScalarList`). """ if any(not isinstance(k, str) for k in metrics_dict.keys()): - raise TypeError(f"Not all keys are of valide type. Expected {str}") + raise TypeError(f"Not all keys are of valid type. Expected {str}") - def is_valid(value: Scalar) -> None: + def is_valid(value: MetricsScalar) -> None: """Check if value is of expected type.""" - if not isinstance(value, get_args(Scalar)): + if not isinstance(value, get_args(MetricsScalar)): raise TypeError( - "Not all values are of valide type." - f" Expected {Union[Scalar, ScalarList]}" + "Not all values are of valid type." + f" Expected {MetricsRecordValues}" ) # Check types of values @@ -56,8 +74,3 @@ def is_valid(value: Scalar) -> None: is_valid(list_value) else: is_valid(value) - - # Add entries to dataclass without duplicating memory - for key in list(metrics_dict.keys()): - self.data[key] = metrics_dict[key] - del metrics_dict[key] diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index 17a501ed50cd..a8ae6f0778a6 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -15,7 +15,7 @@ """RecordSet tests.""" -from typing import Callable, List, OrderedDict, Type, Union +from typing import Callable, Dict, List, OrderedDict, Type, Union import numpy as np import pytest @@ -27,7 +27,7 @@ parameters_to_parametersrecord, parametersrecord_to_parameters, ) -from .typing import NDArray, NDArrays, Parameters, Scalar, ScalarList +from .typing import NDArray, NDArrays, Parameters def get_ndarrays() -> NDArrays: @@ -148,17 +148,71 @@ def test_set_parameters_with_incorrect_types( p_record.set_parameters(array_dict) # type: ignore -def test_add_metrics_to_metricsrecord() -> None: +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: str(x.flatten()[0])), # str: str + (str, lambda x: int(x.flatten()[0])), # str: int + (str, lambda x: float(x.flatten()[0])), # str: float + (str, lambda x: x.flatten().astype("str").tolist()), # str: List[str] + (str, lambda x: x.flatten().astype("int").tolist()), # str: List[int] + (str, lambda x: x.flatten().astype("float").tolist()), # str: List[float] + ], +) +def test_set_metrics_to_metricsrecord_with_correct_types( + key_type: Type[str], + value_fn: Callable[ + [NDArray], Union[str, int, float, List[str], List[int], List[float]] + ], +) -> None: """Test adding metrics of various types to a MetricsRecord.""" m_record = MetricsRecord() - my_metrics: OrderedDict[str, Union[Scalar, ScalarList]] = OrderedDict( - { - "loss": 0.12445, - "converged": True, - "my_int": 2, - "embeddings": np.random.randn(10).tolist(), - } + labels = [1, 2.0] + arrays = get_ndarrays() + + my_metrics = OrderedDict( + {key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)} + ) + + m_record.set_metrics(my_metrics) + + +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: x), # str: NDArray (supported: unsupported) + ( + str, + lambda x: {str(v): v for v in x.flatten()}, + ), # str: dict[str: float] (supported: unsupported) + ( + str, + lambda x: [{str(v): v for v in x.flatten()}], + ), # str: List[dict[str: float]] (supported: unsupported) + ( + int, + lambda x: x.flatten().tolist(), + ), # int: List[str] (unsupported: supported) + ( + float, + lambda x: x.flatten().tolist(), + ), # float: List[int] (unsupported: supported) + ], +) +def test_set_metrics_to_metricsrecord_with_incorrect_types( + key_type: Type[Union[str, int, float]], + value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]], +) -> None: + """Test adding metrics of various unsupported types to a MetricsRecord.""" + m_record = MetricsRecord() + + labels = [1, 2.0] + arrays = get_ndarrays() + + my_metrics = OrderedDict( + {key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)} ) - m_record.add_metrics(my_metrics) + with pytest.raises(TypeError): + m_record.set_metrics(my_metrics) # type: ignore diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index ffa7be88e40c..6ec7979835fe 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -32,13 +32,6 @@ # not conform to other definitions of what a scalar is. Source: # https://developers.google.com/protocol-buffers/docs/overview#scalar Scalar = Union[bool, bytes, float, int, str] -ScalarList = Union[ - List[bool], - List[bytes], - List[float], - List[int], - List[str], -] Value = Union[ bool, bytes, @@ -52,6 +45,10 @@ List[str], ] +# Value types for common.MetricsRecord +MetricsScalar = Union[str, int, float] +MetricsScalarList = Union[List[str], List[int], List[float]] + Metrics = Dict[str, Scalar] MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics] From d3316f81b6aa875693058ad2bcb39088c22dce43 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 17 Jan 2024 21:12:37 +0000 Subject: [PATCH 11/21] double space top of file --- src/py/flwr/common/metricsrecord.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/py/flwr/common/metricsrecord.py b/src/py/flwr/common/metricsrecord.py index 3b9d1dc35b2e..f8f99a62101a 100644 --- a/src/py/flwr/common/metricsrecord.py +++ b/src/py/flwr/common/metricsrecord.py @@ -14,6 +14,7 @@ # ============================================================================== """MetricsRecord.""" + from dataclasses import dataclass, field from typing import Dict, Optional, Union, get_args From a79aab8542ff748ca57b58ee3f65fd203fa06df1 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 17 Jan 2024 21:13:26 +0000 Subject: [PATCH 12/21] w/ previous --- src/py/flwr/common/metricsrecord.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/common/metricsrecord.py b/src/py/flwr/common/metricsrecord.py index f8f99a62101a..29bdaaa2adac 100644 --- a/src/py/flwr/common/metricsrecord.py +++ b/src/py/flwr/common/metricsrecord.py @@ -25,7 +25,7 @@ @dataclass class MetricsRecord: - """Parameters record.""" + """Metrics record.""" data: Dict[str, MetricsRecordValues] = field(default_factory=dict) @@ -43,7 +43,7 @@ def __init__(self, metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None self.set_metrics(metrics_dict) def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None: - """Add metrics to record. + """Add metrics to the record. Parameters ---------- From c08bdefc3d237d71e8d23826296d5214812c754c Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 17 Jan 2024 21:22:00 +0000 Subject: [PATCH 13/21] w/ previous --- src/py/flwr/common/configsrecord.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/py/flwr/common/configsrecord.py b/src/py/flwr/common/configsrecord.py index abd0cd5dd410..03e6ee5f7e20 100644 --- a/src/py/flwr/common/configsrecord.py +++ b/src/py/flwr/common/configsrecord.py @@ -26,9 +26,5 @@ class ConfigsRecord(MetricsRecord): """Configs record.""" def set_configs(self, configs_dict: Dict[str, MetricsRecordValues]) -> None: - """Add configs to record. - - This not implemented as a constructor so we can cleanly create and empyt - ConfigsRecord object. - """ + """Add configs to record.""" super().set_metrics(configs_dict) From 22d48f4fefeaf5421d48b6048f87a70ebae5fb0f Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 18 Jan 2024 08:28:46 +0000 Subject: [PATCH 14/21] no `str` in `MetricsRecords` values --- src/py/flwr/common/metricsrecord.py | 4 ++-- src/py/flwr/common/recordset_test.py | 11 ++++++----- src/py/flwr/common/typing.py | 4 ++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/py/flwr/common/metricsrecord.py b/src/py/flwr/common/metricsrecord.py index 29bdaaa2adac..f1cc107c949b 100644 --- a/src/py/flwr/common/metricsrecord.py +++ b/src/py/flwr/common/metricsrecord.py @@ -35,7 +35,7 @@ def __init__(self, metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None Parameters ---------- array_dict : Optional[Dict[str, MetricsRecordValues]] - A dictionary that stores basic types (i.e. `str`, `int`, `float` as defined + A dictionary that stores basic types (i.e. `int`, `float` as defined in `MetricsScalar`) and list of such types (see `MetricsScalarList`). """ self.data = {} @@ -48,7 +48,7 @@ def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None: Parameters ---------- array_dict : Optional[Dict[str, MetricsRecordValues]] - A dictionary that stores basic types (i.e. `str`, `int`, `float` as defined + A dictionary that stores basic types (i.e. `int`, `float` as defined in `MetricsScalar`) and list of such types (see `MetricsScalarList`). """ if any(not isinstance(k, str) for k in metrics_dict.keys()): diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index a8ae6f0778a6..350fd1f5781c 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -151,19 +151,15 @@ def test_set_parameters_with_incorrect_types( @pytest.mark.parametrize( "key_type, value_fn", [ - (str, lambda x: str(x.flatten()[0])), # str: str (str, lambda x: int(x.flatten()[0])), # str: int (str, lambda x: float(x.flatten()[0])), # str: float - (str, lambda x: x.flatten().astype("str").tolist()), # str: List[str] (str, lambda x: x.flatten().astype("int").tolist()), # str: List[int] (str, lambda x: x.flatten().astype("float").tolist()), # str: List[float] ], ) def test_set_metrics_to_metricsrecord_with_correct_types( key_type: Type[str], - value_fn: Callable[ - [NDArray], Union[str, int, float, List[str], List[int], List[float]] - ], + value_fn: Callable[[NDArray], Union[int, float, List[int], List[float]]], ) -> None: """Test adding metrics of various types to a MetricsRecord.""" m_record = MetricsRecord() @@ -181,6 +177,11 @@ def test_set_metrics_to_metricsrecord_with_correct_types( @pytest.mark.parametrize( "key_type, value_fn", [ + (str, lambda x: str(x.flatten()[0])), # str: str (supported: unsupported) + ( + str, + lambda x: x.flatten().astype("str").tolist(), + ), # str: List[str] (supported: unsupported) (str, lambda x: x), # str: NDArray (supported: unsupported) ( str, diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index 6ec7979835fe..31a9cca6379a 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -46,8 +46,8 @@ ] # Value types for common.MetricsRecord -MetricsScalar = Union[str, int, float] -MetricsScalarList = Union[List[str], List[int], List[float]] +MetricsScalar = Union[int, float] +MetricsScalarList = Union[List[int], List[float]] Metrics = Dict[str, Scalar] MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics] From 3823ed248392aedb19de24cb470ba33bdb5765a0 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 18 Jan 2024 08:45:04 +0000 Subject: [PATCH 15/21] update --- src/py/flwr/common/configsrecord.py | 61 ++++++++++++++++++++++--- src/py/flwr/common/recordset_test.py | 68 ++++++++++++++++++++++++++-- src/py/flwr/common/typing.py | 2 + 3 files changed, 120 insertions(+), 11 deletions(-) diff --git a/src/py/flwr/common/configsrecord.py b/src/py/flwr/common/configsrecord.py index 03e6ee5f7e20..27757f3b66f2 100644 --- a/src/py/flwr/common/configsrecord.py +++ b/src/py/flwr/common/configsrecord.py @@ -15,16 +15,63 @@ """ConfigsRecord.""" -from dataclasses import dataclass -from typing import Dict +from dataclasses import dataclass, field +from typing import Dict, Optional, Union, get_args -from .metricsrecord import MetricsRecord, MetricsRecordValues +from .typing import ConfigsScalar, ConfigsScalarList + +ConfigsRecordValues = Union[ConfigsScalar, ConfigsScalarList] @dataclass -class ConfigsRecord(MetricsRecord): +class ConfigsRecord: """Configs record.""" - def set_configs(self, configs_dict: Dict[str, MetricsRecordValues]) -> None: - """Add configs to record.""" - super().set_metrics(configs_dict) + data: Dict[str, ConfigsRecordValues] = field(default_factory=dict) + + def __init__(self, configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None): + """Construct a ConfigsRecord object. + + Parameters + ---------- + configs_dict : Optional[Dict[str, ConfigsRecordValues]] + A dictionary that stores basic types (i.e. `str`, `int`, `float` as defined + in `MetricsScalar`) and list of such types (see `ConfigsScalarList`). + """ + self.data = {} + if configs_dict: + self.set_configs(configs_dict) + + def set_configs(self, configs_dict: Dict[str, ConfigsRecordValues]) -> None: + """Add configs to the record. + + Parameters + ---------- + configs_dict : Optional[Dict[str, ConfigsRecordValues]] + A dictionary that stores basic types (i.e. `str`,`int`, `float` as defined + in `ConfigsRecordValues`) and list of such types (see `ConfigsScalarList`). + """ + if any(not isinstance(k, str) for k in configs_dict.keys()): + raise TypeError(f"Not all keys are of valid type. Expected {str}") + + def is_valid(value: ConfigsScalar) -> None: + """Check if value is of expected type.""" + if not isinstance(value, get_args(ConfigsScalar)): + raise TypeError( + "Not all values are of valid type." + f" Expected {ConfigsRecordValues}" + ) + + # Check types of values + # Split between those values that are list and those that aren't + # then process in the same way + for value in configs_dict.values(): + if isinstance(value, list): + # If your lists are large (e.g. 1M+ elements) this will be slow + # 1s to check 10M element list on a M2 Pro + # In such settings, you'd be better of treating such metric as + # an array and pass it to a ParametersRecord. + for list_value in value: + is_valid(list_value) + else: + is_valid(value) diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index cd036f85fbbf..ea77c3208f4a 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -20,6 +20,7 @@ import numpy as np import pytest +from .configsrecord import ConfigsRecord from .metricsrecord import MetricsRecord from .parameter import ndarrays_to_parameters, parameters_to_ndarrays from .parametersrecord import Array, ParametersRecord @@ -27,7 +28,7 @@ parameters_to_parametersrecord, parametersrecord_to_parameters, ) -from .typing import NDArray, NDArrays, Parameters +from .typing import ConfigsScalar, ConfigsScalarList, NDArray, NDArrays, Parameters def get_ndarrays() -> NDArrays: @@ -219,7 +220,66 @@ def test_set_metrics_to_metricsrecord_with_incorrect_types( m_record.set_metrics(my_metrics) # type: ignore -# def test_add_config_to_configsrecord() -> None: -# """Test adding configs of various types to a ConfigsRecord.""" +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: str(x.flatten()[0])), # str: str + (str, lambda x: int(x.flatten()[0])), # str: int + (str, lambda x: float(x.flatten()[0])), # str: float + (str, lambda x: x.flatten().astype("int").tolist()), # str: List[int] + (str, lambda x: x.flatten().astype("float").tolist()), # str: List[float] + ], +) +def test_set_configs_to_configsrecord_with_correct_types( + key_type: Type[str], + value_fn: Callable[[NDArray], Union[ConfigsScalar, ConfigsScalarList]], +) -> None: + """Test adding configs of various types to a ConfigsRecord.""" + labels = [1, 2.0] + arrays = get_ndarrays() + + my_configs = OrderedDict( + {key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)} + ) + + _ = ConfigsRecord(my_configs) + + +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: x), # str: NDArray (supported: unsupported) + ( + str, + lambda x: {str(v): v for v in x.flatten()}, + ), # str: dict[str: float] (supported: unsupported) + ( + str, + lambda x: [{str(v): v for v in x.flatten()}], + ), # str: List[dict[str: float]] (supported: unsupported) + ( + int, + lambda x: x.flatten().tolist(), + ), # int: List[str] (unsupported: supported) + ( + float, + lambda x: x.flatten().tolist(), + ), # float: List[int] (unsupported: supported) + ], +) +def test_set_configs_to_configsrecord_with_incorrect_types( + key_type: Type[Union[str, int, float]], + value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]], +) -> None: + """Test adding configs of various unsupported types to a ConfigsRecord.""" + m_record = ConfigsRecord() -# # TODO + labels = [1, 2.0] + arrays = get_ndarrays() + + my_metrics = OrderedDict( + {key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)} + ) + + with pytest.raises(TypeError): + m_record.set_configs(my_metrics) # type: ignore diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index 31a9cca6379a..13eef29ab26d 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -48,6 +48,8 @@ # Value types for common.MetricsRecord MetricsScalar = Union[int, float] MetricsScalarList = Union[List[int], List[float]] +ConfigsScalar = Union[MetricsScalar, str] +ConfigsScalarList = Union[MetricsScalarList, List[float]] Metrics = Dict[str, Scalar] MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics] From 376ffb8a56623eb6ec1e8bf0fed7e6cc5b947a33 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 18 Jan 2024 08:46:17 +0000 Subject: [PATCH 16/21] fix docstrings --- src/py/flwr/common/metricsrecord.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/common/metricsrecord.py b/src/py/flwr/common/metricsrecord.py index f1cc107c949b..b73e32030854 100644 --- a/src/py/flwr/common/metricsrecord.py +++ b/src/py/flwr/common/metricsrecord.py @@ -34,7 +34,7 @@ def __init__(self, metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None Parameters ---------- - array_dict : Optional[Dict[str, MetricsRecordValues]] + metrics_dict : Optional[Dict[str, MetricsRecordValues]] A dictionary that stores basic types (i.e. `int`, `float` as defined in `MetricsScalar`) and list of such types (see `MetricsScalarList`). """ @@ -47,7 +47,7 @@ def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None: Parameters ---------- - array_dict : Optional[Dict[str, MetricsRecordValues]] + metrics_dict : Optional[Dict[str, MetricsRecordValues]] A dictionary that stores basic types (i.e. `int`, `float` as defined in `MetricsScalar`) and list of such types (see `MetricsScalarList`). """ From b18b05d1c15510ffe501083d63447628c6eed8e2 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 18 Jan 2024 08:58:19 +0000 Subject: [PATCH 17/21] more info in TypeError messsage --- src/py/flwr/common/metricsrecord.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/common/metricsrecord.py b/src/py/flwr/common/metricsrecord.py index b73e32030854..5dbdf00286b9 100644 --- a/src/py/flwr/common/metricsrecord.py +++ b/src/py/flwr/common/metricsrecord.py @@ -52,14 +52,14 @@ def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None: in `MetricsScalar`) and list of such types (see `MetricsScalarList`). """ if any(not isinstance(k, str) for k in metrics_dict.keys()): - raise TypeError(f"Not all keys are of valid type. Expected {str}") + raise TypeError(f"Not all keys are of valid type. Expected {str}.") def is_valid(value: MetricsScalar) -> None: """Check if value is of expected type.""" if not isinstance(value, get_args(MetricsScalar)): raise TypeError( "Not all values are of valid type." - f" Expected {MetricsRecordValues}" + f" Expected {MetricsRecordValues} but you passed {type(value)}." ) # Check types of values From 6389def943b95daa30bf910a68245fbe804ef2ba Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 18 Jan 2024 09:01:18 +0000 Subject: [PATCH 18/21] more info in TypeError messsage --- src/py/flwr/common/configsrecord.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/common/configsrecord.py b/src/py/flwr/common/configsrecord.py index 27757f3b66f2..a3a20243583a 100644 --- a/src/py/flwr/common/configsrecord.py +++ b/src/py/flwr/common/configsrecord.py @@ -59,7 +59,7 @@ def is_valid(value: ConfigsScalar) -> None: if not isinstance(value, get_args(ConfigsScalar)): raise TypeError( "Not all values are of valid type." - f" Expected {ConfigsRecordValues}" + f" Expected {ConfigsRecordValues} but you passed {type(value)}." ) # Check types of values From 12ed920481b36b7a233deed502c044f4676c5ef5 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 18 Jan 2024 10:30:10 +0000 Subject: [PATCH 19/21] updates; more tests --- src/py/flwr/common/metricsrecord.py | 31 ++++++++++++++++---- src/py/flwr/common/recordset_test.py | 42 ++++++++++++++++++++++++++-- src/py/flwr/common/typing.py | 1 + 3 files changed, 66 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/common/metricsrecord.py b/src/py/flwr/common/metricsrecord.py index 5dbdf00286b9..68eca732efa2 100644 --- a/src/py/flwr/common/metricsrecord.py +++ b/src/py/flwr/common/metricsrecord.py @@ -16,20 +16,23 @@ from dataclasses import dataclass, field -from typing import Dict, Optional, Union, get_args +from typing import Dict, Optional, get_args -from .typing import MetricsScalar, MetricsScalarList - -MetricsRecordValues = Union[MetricsScalar, MetricsScalarList] +from .typing import MetricsRecordValues, MetricsScalar @dataclass class MetricsRecord: """Metrics record.""" + keep_input: bool data: Dict[str, MetricsRecordValues] = field(default_factory=dict) - def __init__(self, metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None): + def __init__( + self, + metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None, + keep_input: bool = True, + ): """Construct a MetricsRecord object. Parameters @@ -37,7 +40,13 @@ def __init__(self, metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None metrics_dict : Optional[Dict[str, MetricsRecordValues]] A dictionary that stores basic types (i.e. `int`, `float` as defined in `MetricsScalar`) and list of such types (see `MetricsScalarList`). + keep_input : bool (default: True) + A boolean indicating whether metrics should be deleted from the input + dictionary immediately after adding them to the record. When set + to True, the data is duplicated in memory. If memory is a concern, set + it to False. """ + self.keep_input = keep_input self.data = {} if metrics_dict: self.set_metrics(metrics_dict) @@ -47,7 +56,7 @@ def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None: Parameters ---------- - metrics_dict : Optional[Dict[str, MetricsRecordValues]] + metrics_dict : Dict[str, MetricsRecordValues] A dictionary that stores basic types (i.e. `int`, `float` as defined in `MetricsScalar`) and list of such types (see `MetricsScalarList`). """ @@ -75,3 +84,13 @@ def is_valid(value: MetricsScalar) -> None: is_valid(list_value) else: is_valid(value) + + # Add metrics to record + if self.keep_input: + # Copy + self.data = metrics_dict.copy() + else: + # Add entries to dataclass without duplicating memory + for key in list(metrics_dict.keys()): + self.data[key] = metrics_dict[key] + del metrics_dict[key] diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index 350fd1f5781c..13e357552a20 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -27,7 +27,7 @@ parameters_to_parametersrecord, parametersrecord_to_parameters, ) -from .typing import NDArray, NDArrays, Parameters +from .typing import MetricsRecordValues, NDArray, NDArrays, Parameters def get_ndarrays() -> NDArrays: @@ -159,7 +159,7 @@ def test_set_parameters_with_incorrect_types( ) def test_set_metrics_to_metricsrecord_with_correct_types( key_type: Type[str], - value_fn: Callable[[NDArray], Union[int, float, List[int], List[float]]], + value_fn: Callable[[NDArray], MetricsRecordValues], ) -> None: """Test adding metrics of various types to a MetricsRecord.""" m_record = MetricsRecord() @@ -171,8 +171,12 @@ def test_set_metrics_to_metricsrecord_with_correct_types( {key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)} ) + # Add metric m_record.set_metrics(my_metrics) + # Check metrics are actually added + assert list(my_metrics.keys()) == list(m_record.data.keys()) + @pytest.mark.parametrize( "key_type, value_fn", @@ -217,3 +221,37 @@ def test_set_metrics_to_metricsrecord_with_incorrect_types( with pytest.raises(TypeError): m_record.set_metrics(my_metrics) # type: ignore + + +@pytest.mark.parametrize( + "keep_input", + [ + (True), + (False), + ], +) +def test_set_metrics_to_metricsrecord_with_and_without_keeping_input( + keep_input: bool, +) -> None: + """Test keep_input functionality for MetricsRecord.""" + m_record = MetricsRecord(keep_input=keep_input) + + # constructing a valid input + labels = [1, 2.0] + arrays = get_ndarrays() + my_metrics = OrderedDict( + {str(label): arr.flatten().tolist() for label, arr in zip(labels, arrays)} + ) + + my_metrics_copy = my_metrics.copy() + + # Add metric + m_record.set_metrics(my_metrics) + + # Check metrics are actually added + # Check that input dict has been emptied when enabled such behaviour + if keep_input: + assert my_metrics == m_record.data + else: + assert my_metrics_copy == m_record.data + assert len(my_metrics) == 0 diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index 31a9cca6379a..a8196126ecfc 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -48,6 +48,7 @@ # Value types for common.MetricsRecord MetricsScalar = Union[int, float] MetricsScalarList = Union[List[int], List[float]] +MetricsRecordValues = Union[MetricsScalar, MetricsScalarList] Metrics = Dict[str, Scalar] MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics] From e7e8cf6c0393f9aa95879a305604493ec8bbaece Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Thu, 18 Jan 2024 14:14:41 +0100 Subject: [PATCH 20/21] Update src/py/flwr/common/configsrecord.py --- src/py/flwr/common/configsrecord.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/common/configsrecord.py b/src/py/flwr/common/configsrecord.py index 03ba0add1592..494cb88586ac 100644 --- a/src/py/flwr/common/configsrecord.py +++ b/src/py/flwr/common/configsrecord.py @@ -39,7 +39,7 @@ def __init__( ---------- configs_dict : Optional[Dict[str, ConfigsRecordValues]] A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as - defined in `ConfigsScalar`) and list of such types (see + defined in `ConfigsScalar`) and lists of such types (see `ConfigsScalarList`). keep_input : bool (default: True) A boolean indicating whether config passed should be deleted from the input From f058c5003a7381306822d66a41b4a4e2e8de2fa8 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 18 Jan 2024 13:28:26 +0000 Subject: [PATCH 21/21] extra type test --- src/py/flwr/common/recordset_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index f64244902ae9..3f0917d75cf5 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -271,6 +271,7 @@ def test_set_metrics_to_metricsrecord_with_and_without_keeping_input( (str, lambda x: int(x.flatten()[0])), # str: int (str, lambda x: float(x.flatten()[0])), # str: float (str, lambda x: x.flatten().tobytes()), # str: bytes + (str, lambda x: x.flatten().astype("str").tolist()), # str: List[str] (str, lambda x: x.flatten().astype("int").tolist()), # str: List[int] (str, lambda x: x.flatten().astype("float").tolist()), # str: List[float] (str, lambda x: [x.flatten().tobytes()]), # str: List[bytes]