diff --git a/nvflare/app_common/abstract/fl_model.py b/nvflare/app_common/abstract/fl_model.py index 5b115df150..934326ea56 100644 --- a/nvflare/app_common/abstract/fl_model.py +++ b/nvflare/app_common/abstract/fl_model.py @@ -90,10 +90,36 @@ def __init__( else: meta = {} self.meta = meta + self._summary: dict = {} - def __str__(self): - return ( - f"FLModel(params:{self.params}, params_type: {self.params_type}," - f" optimizer_params: {self.optimizer_params}, metrics: {self.metrics}," - f" current_round: {self.current_round}, meta: {self.meta})" + def _add_to_summary(self, kvs: Dict): + for key, value in kvs.items(): + if value: + if isinstance(value, dict): + self._summary[key] = len(value) + elif isinstance(value, ParamsType): + self._summary[key] = value + elif isinstance(value, int): + self._summary[key] = value + else: + self._summary[key] = type(value) + + def summary(self): + kvs = dict( + params=self.params, + optimizer_params=self.optimizer_params, + metrics=self.metrics, + meta=self.meta, + params_type=self.params_type, + start_round=self.start_round, + current_round=self.current_round, + total_rounds=self.total_rounds, ) + self._add_to_summary(kvs) + return self._summary + + def __repr__(self): + return str(self.summary()) + + def __str__(self): + return str(self.summary()) diff --git a/tests/unit_test/app_common/abstract/__init__.py b/tests/unit_test/app_common/abstract/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/tests/unit_test/app_common/abstract/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_test/app_common/abstract/fl_model_test.py b/tests/unit_test/app_common/abstract/fl_model_test.py new file mode 100644 index 0000000000..cec2ffa312 --- /dev/null +++ b/tests/unit_test/app_common/abstract/fl_model_test.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvflare.app_common.abstract.fl_model import FLModel, ParamsType + + +class TestFLModel: + model = FLModel( + params_type=ParamsType.FULL, + params={"a": 100, "b": 200, "c": {"c1": 100}}, + optimizer_params={}, + metrics={"loss": 100, "accuracy": 0.9}, + start_round=1, + current_round=100, + total_rounds=12000, + ) + summary = model.summary() + assert summary["params"] == 3 + assert summary["metrics"] == 2 + assert summary["params_type"] == ParamsType.FULL + assert summary["start_round"] == 1 + assert summary["current_round"] == 100 + assert summary["total_rounds"] == 12000