From 529d5549feec945b345098a45c714dc3f6908c28 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 10 Feb 2022 13:25:00 +0100 Subject: [PATCH 1/8] implementation --- torchmetrics/wrappers/classwise.py | 73 ++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 torchmetrics/wrappers/classwise.py diff --git a/torchmetrics/wrappers/classwise.py b/torchmetrics/wrappers/classwise.py new file mode 100644 index 00000000000..15d7020250a --- /dev/null +++ b/torchmetrics/wrappers/classwise.py @@ -0,0 +1,73 @@ +from typing import Dict, List, Optional, Union + +from torch import Tensor + +from torchmetrics import Metric + + +class ClasswiseWrapper(Metric): + """Wrapper class for altering the output of classification metrics that returns multiple values. + + Args: + metric: base metric that should be wrapped + + class_labels: list of strings indicating the different classes. + + Example: + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import Accuracy, ClasswiseWrapper + >>> metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None)) + >>> preds = torch.randn(10, 3).softmax(dim=-1) + >>> target = torch.randint(3, (10,)) + >>> metric(preds, target) + {'accuracy_0': tensor(0.5000), 'accuracy_1': tensor(0.7500), 'accuracy_2': tensor(0.)} + + Example (labels as list of strings): + >>> import torch + >>> from torchmetrics import Accuracy, ClasswiseWrapper + >>> metric = ClasswiseWrapper( + ... Accuracy(num_classes=3, average=None), + ... labels=["horse", "fish", "dog"] + ... ) + >>> preds = torch.randn(10, 3).softmax(dim=-1) + >>> target = torch.randint(3, (10,)) + >>> metric(preds, target) + {'accuracy_horse': tensor(0.3333), 'accuracy_fish': tensor(0.6667), 'accuracy_dog': tensor(0.)} + + Example (in metric collection): + >>> import torch + >>> from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall + >>> labels = ["horse", "fish", "dog"] + >>> metric = MetricCollection( + ... 
{'accuracy': ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels), + ... 'recall': ClasswiseWrapper(Recall(num_classes=3, average=None), labels)} + ... ) + >>> preds = torch.randn(10, 3).softmax(dim=-1) + >>> target = torch.randint(3, (10,)) + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + {'accuracy_horse': tensor(0.), 'accuracy_fish': tensor(0.3333), 'accuracy_dog': tensor(0.4000), + 'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)} + """ + + def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None: + super().__init__() + if not isinstance(metric, Metric): + raise ValueError(f"Expected argument `metric` to be an instance of `torchmetrics.Metric` but got {metric}") + if labels is not None and not (isinstance(labels, list) and all(isinstance(lab, str) for lab in labels)): + raise ValueError(f"Expected argument `labels` to either be `None` or a list of strings but got {labels}") + self.metric = metric + self.labels = labels + + def _convert(self, x: Tensor) -> Dict[Union[str, int], float]: + name = self.metric.__class__.__name__.lower() + if self.labels is None: + return {f"{name}_{i}": val for i, val in enumerate(x)} + else: + return {f"{name}_{lab}": val for lab, val in zip(self.labels, x)} + + def update(self, *args, **kwargs) -> None: + self.metric.update(*args, **kwargs) + + def compute(self) -> Dict[str, Tensor]: + return self._convert(self.metric.compute()) From 33262bac4b248fd95d493f37f83da232da5d77aa Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 10 Feb 2022 13:28:34 +0100 Subject: [PATCH 2/8] init --- torchmetrics/__init__.py | 9 ++++++++- torchmetrics/wrappers/__init__.py | 1 + 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 1f1568d29c9..dce6adf3302 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -89,7 +89,13 @@ WordInfoLost, WordInfoPreserved, ) -from 
torchmetrics.wrappers import BootStrapper, MetricTracker, MinMaxMetric, MultioutputWrapper # noqa: E402 +from torchmetrics.wrappers import ( # noqa: E402 + BootStrapper, + ClasswiseWrapper, + MetricTracker, + MinMaxMetric, + MultioutputWrapper, +) __all__ = [ "functional", @@ -104,6 +110,7 @@ "BootStrapper", "CalibrationError", "CatMetric", + "ClasswiseWrapper", "CHRFScore", "CohenKappa", "ConfusionMatrix", diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py index 1f9dab6da45..98dbb137b19 100644 --- a/torchmetrics/wrappers/__init__.py +++ b/torchmetrics/wrappers/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401 +from torchmetrics.wrappers.classwise import ClasswiseWrapper # noqa: F401 from torchmetrics.wrappers.minmax import MinMaxMetric # noqa: F401 from torchmetrics.wrappers.multioutput import MultioutputWrapper # noqa: F401 from torchmetrics.wrappers.tracker import MetricTracker # noqa: F401 From 280cefceff01f82d5ae0d0766a35b0593093ea7a Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 10 Feb 2022 13:30:35 +0100 Subject: [PATCH 3/8] collection --- torchmetrics/collections.py | 7 ++++--- torchmetrics/utilities/data.py | 15 ++++++++++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index b1009902b6b..6004a9281cd 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -19,6 +19,7 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.data import _flatten_dict # this is just a bypass for this module name collision with build-in one from torchmetrics.utilities.imports import OrderedDict @@ -130,7 +131,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: Positional arguments (args) will be passed to 
every metric in the collection, while keyword arguments (kwargs) will be filtered based on the signature of the individual metric. """ - return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()} + return _flatten_dict({k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}) def update(self, *args: Any, **kwargs: Any) -> None: """Iteratively call update for each metric. @@ -218,8 +219,8 @@ def compute(self) -> Dict[str, Any]: mi = getattr(self, cg[i]) for state in m0._defaults: setattr(mi, state, getattr(m0, state)) - - return {k: m.compute() for k, m in self.items()} + res = {k: m.compute() for k, m in self.items()} + return _flatten_dict(res) def reset(self) -> None: """Iteratively call reset for each metric.""" diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 8603888b262..8db95511c90 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union import torch from torch import Tensor, tensor @@ -51,9 +51,22 @@ def dim_zero_min(x: Tensor) -> Tensor: def _flatten(x: Sequence) -> list: + """Flatten a list of lists into a single list.""" return [item for sublist in x for item in sublist] +def _flatten_dict(x: Dict) -> Dict: + """Flatten one level of nested dicts into a single dict, keeping the inner keys.""" + new_dict = {} + for key, value in x.items(): + if isinstance(value, dict): + for k, v in value.items(): + new_dict[k] = v + else: + new_dict[key] = value + return new_dict + + def to_onehot( label_tensor: Tensor, num_classes: Optional[int] = None, From c314b44dd9286add4153e04c2a3126af9f762eb5 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 10 Feb 2022 13:33:22 +0100 Subject: [PATCH 4/8] tests --- tests/wrappers/test_classwise.py | 57 ++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 tests/wrappers/test_classwise.py diff --git a/tests/wrappers/test_classwise.py b/tests/wrappers/test_classwise.py new file mode 100644 index 00000000000..a30dead233b --- /dev/null +++ b/tests/wrappers/test_classwise.py @@ -0,0 +1,57 @@ +import pytest +import torch + +from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall + + +def test_raises_error_on_wrong_input(): + """Test that errors are raised on wrong input.""" + with pytest.raises(ValueError, match="Expected argument `metric` to be an instance of `torchmetrics.Metric` but.*"): + ClasswiseWrapper([]) + + with pytest.raises(ValueError, match="Expected argument `labels` to either be `None` or a list of strings.*"): + ClasswiseWrapper(Accuracy(), "hest") + + +def test_output_no_labels(): + """Test that wrapper works with no label input.""" + metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None)) + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + val = 
metric(preds, target) + assert isinstance(val, dict) + assert len(val) == 3 + for i in range(3): + assert f"accuracy_{i}" in val + + +def test_output_with_labels(): + """Test that wrapper works with label input.""" + labels = ["horse", "fish", "cat"] + metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels) + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + val = metric(preds, target) + assert isinstance(val, dict) + assert len(val) == 3 + for lab in labels: + assert f"accuracy_{lab}" in val + + +def test_using_metriccollection(): + """Test wrapper in combination with metric collection.""" + labels = ["horse", "fish", "cat"] + metric = MetricCollection( + { + "accuracy": ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels), + "recall": ClasswiseWrapper(Recall(num_classes=3, average=None), labels=labels), + } + ) + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + val = metric(preds, target) + assert isinstance(val, dict) + assert len(val) == 6 + for lab in labels: + assert f"accuracy_{lab}" in val + assert f"recall_{lab}" in val From 96c3ae09b1de07e353dea8f543459e42d4da7dc6 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 10 Feb 2022 13:37:55 +0100 Subject: [PATCH 5/8] docs --- docs/source/references/modules.rst | 6 ++++++ torchmetrics/wrappers/classwise.py | 6 ++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 460a1a54134..7c8b5f181e3 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -710,6 +710,12 @@ BootStrapper .. autoclass:: torchmetrics.BootStrapper :noindex: +ClasswiseWrapper +~~~~~~~~~~~~~~~~ + +.. 
autoclass:: torchmetrics.ClasswiseWrapper + :noindex: + MetricTracker ~~~~~~~~~~~~~ diff --git a/torchmetrics/wrappers/classwise.py b/torchmetrics/wrappers/classwise.py index 15d7020250a..ece1301dbfa 100644 --- a/torchmetrics/wrappers/classwise.py +++ b/torchmetrics/wrappers/classwise.py @@ -6,10 +6,12 @@ class ClasswiseWrapper(Metric): - """Wrapper class for altering the output of classification metrics that returns multiple values. + """Wrapper class for altering the output of classification metrics that return multiple values to include + label information. Args: - metric: base metric that should be wrapped + metric: base metric that should be wrapped. It is assumed that the metric outputs a single + tensor that is split along the first dimension. class_labels: list of strings indicating the different classes. From 748c83cbee32ee1946f71d8a3c4372a496829f87 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 10 Feb 2022 13:42:38 +0100 Subject: [PATCH 6/8] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 543116f989d..6a673688743 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added smart update of `MetricCollection` ([#709](https://github.com/PyTorchLightning/metrics/pull/709)) +- Added `ClasswiseWrapper` for better logging of classification metrics with multiple output values ([#832](https://github.com/PyTorchLightning/metrics/pull/832)) + + ### Changed From 0919136771ddc7f4a23864b07ff93a9ca0cf06e8 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 10 Feb 2022 13:56:45 +0100 Subject: [PATCH 7/8] Update torchmetrics/wrappers/classwise.py Co-authored-by: Jirka Borovec --- torchmetrics/wrappers/classwise.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchmetrics/wrappers/classwise.py b/torchmetrics/wrappers/classwise.py index ece1301dbfa..3137bd05fd7 100644 --- 
a/torchmetrics/wrappers/classwise.py +++ b/torchmetrics/wrappers/classwise.py @@ -65,8 +65,7 @@ def _convert(self, x: Tensor) -> Dict[Union[str, int], float]: name = self.metric.__class__.__name__.lower() if self.labels is None: return {f"{name}_{i}": val for i, val in enumerate(x)} - else: - return {f"{name}_{lab}": val for lab, val in zip(self.labels, x)} + return {f"{name}_{lab}": val for lab, val in zip(self.labels, x)} def update(self, *args, **kwargs) -> None: self.metric.update(*args, **kwargs) From 3f0419dfc1870111683a43f18e2d3114bb815ff6 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 10 Feb 2022 14:02:16 +0100 Subject: [PATCH 8/8] simple tests --- tests/test_utilities.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index e1b59884793..9b989929331 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -16,7 +16,7 @@ from torch import tensor from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn -from torchmetrics.utilities.data import get_num_classes, to_categorical, to_onehot +from torchmetrics.utilities.data import _flatten, _flatten_dict, get_num_classes, to_categorical, to_onehot from torchmetrics.utilities.distributed import class_reduce, reduce @@ -102,3 +102,17 @@ def test_to_categorical(): ) def test_get_num_classes(preds, target, num_classes, expected_num_classes): assert get_num_classes(preds, target, num_classes) == expected_num_classes + + +def test_flatten_list(): + """Check that _flatten utility function works as expected.""" + inp = [[1, 2, 3], [4, 5], [6]] + out = _flatten(inp) + assert out == [1, 2, 3, 4, 5, 6] + + +def test_flatten_dict(): + """Check that _flatten_dict utility function works as expected.""" + inp = {"a": {"b": 1, "c": 2}, "d": 3} + out = _flatten_dict(inp) + assert out == {"b": 1, "c": 2, "d": 3}