diff --git a/CHANGELOG.md b/CHANGELOG.md
index 543116f989d..6a673688743 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added smart update of `MetricCollection` ([#709](https://github.com/PyTorchLightning/metrics/pull/709))
 
+- Added `ClasswiseWrapper` for better logging of classification metrics with multiple output values ([#832](https://github.com/PyTorchLightning/metrics/pull/832))
+
+
 ### Changed
diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst
index 460a1a54134..7c8b5f181e3 100644
--- a/docs/source/references/modules.rst
+++ b/docs/source/references/modules.rst
@@ -710,6 +710,12 @@ BootStrapper
 
 .. autoclass:: torchmetrics.BootStrapper
     :noindex:
 
+ClasswiseWrapper
+~~~~~~~~~~~~~~~~
+
+.. autoclass:: torchmetrics.ClasswiseWrapper
+    :noindex:
+
 MetricTracker
 ~~~~~~~~~~~~~
diff --git a/tests/test_utilities.py b/tests/test_utilities.py
index e1b59884793..9b989929331 100644
--- a/tests/test_utilities.py
+++ b/tests/test_utilities.py
@@ -16,7 +16,7 @@
 from torch import tensor
 
 from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn
-from torchmetrics.utilities.data import get_num_classes, to_categorical, to_onehot
+from torchmetrics.utilities.data import _flatten, _flatten_dict, get_num_classes, to_categorical, to_onehot
 from torchmetrics.utilities.distributed import class_reduce, reduce
@@ -102,3 +102,17 @@ def test_get_num_classes(preds, target, num_classes, expected_num_classes):
     assert get_num_classes(preds, target, num_classes) == expected_num_classes
+
+
+def test_flatten_list():
+    """Check that _flatten utility function works as expected."""
+    inp = [[1, 2, 3], [4, 5], [6]]
+    out = _flatten(inp)
+    assert out == [1, 2, 3, 4, 5, 6]
+
+
+def test_flatten_dict():
+    """Check that _flatten_dict utility function works as expected."""
+    inp = {"a": {"b": 1, "c": 2}, "d": 3}
+    out = _flatten_dict(inp)
+    assert out == {"b": 1, "c": 2, "d": 3}
diff --git a/tests/wrappers/test_classwise.py b/tests/wrappers/test_classwise.py
new file mode 100644
index 00000000000..a30dead233b
--- /dev/null
+++ b/tests/wrappers/test_classwise.py
@@ -0,0 +1,57 @@
+import pytest
+import torch
+
+from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall
+
+
+def test_raises_error_on_wrong_input():
+    """Test that errors are raised on wrong input."""
+    with pytest.raises(ValueError, match="Expected argument `metric` to be an instance of `torchmetrics.Metric` but.*"):
+        ClasswiseWrapper([])
+
+    with pytest.raises(ValueError, match="Expected argument `labels` to either be `None` or a list of strings.*"):
+        ClasswiseWrapper(Accuracy(), "hest")
+
+
+def test_output_no_labels():
+    """Test that the wrapper works with no label input."""
+    metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None))
+    preds = torch.randn(10, 3).softmax(dim=-1)
+    target = torch.randint(3, (10,))
+    val = metric(preds, target)
+    assert isinstance(val, dict)
+    assert len(val) == 3
+    for i in range(3):
+        assert f"accuracy_{i}" in val
+
+
+def test_output_with_labels():
+    """Test that the wrapper works with label input."""
+    labels = ["horse", "fish", "cat"]
+    metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels)
+    preds = torch.randn(10, 3).softmax(dim=-1)
+    target = torch.randint(3, (10,))
+    val = metric(preds, target)
+    assert isinstance(val, dict)
+    assert len(val) == 3
+    for lab in labels:
+        assert f"accuracy_{lab}" in val
+
+
+def test_using_metriccollection():
+    """Test wrapper in combination with metric collection."""
+    labels = ["horse", "fish", "cat"]
+    metric = MetricCollection(
+        {
+            "accuracy": ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels),
+            "recall": ClasswiseWrapper(Recall(num_classes=3, average=None), labels=labels),
+        }
+    )
+    preds = torch.randn(10, 3).softmax(dim=-1)
+    target = torch.randint(3, (10,))
+    val = metric(preds, target)
+    assert isinstance(val, dict)
+    assert len(val) == 6
+    for lab in labels:
+        assert f"accuracy_{lab}" in val
+        assert f"recall_{lab}" in val
diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py
index 1f1568d29c9..dce6adf3302 100644
--- a/torchmetrics/__init__.py
+++ b/torchmetrics/__init__.py
@@ -89,7 +89,13 @@
     WordInfoLost,
     WordInfoPreserved,
 )
-from torchmetrics.wrappers import BootStrapper, MetricTracker, MinMaxMetric, MultioutputWrapper  # noqa: E402
+from torchmetrics.wrappers import (  # noqa: E402
+    BootStrapper,
+    ClasswiseWrapper,
+    MetricTracker,
+    MinMaxMetric,
+    MultioutputWrapper,
+)
 
 __all__ = [
     "functional",
@@ -104,6 +110,7 @@
     "BootStrapper",
     "CalibrationError",
     "CatMetric",
+    "ClasswiseWrapper",
     "CHRFScore",
     "CohenKappa",
     "ConfusionMatrix",
diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py
index b1009902b6b..6004a9281cd 100644
--- a/torchmetrics/collections.py
+++ b/torchmetrics/collections.py
@@ -19,6 +19,7 @@
 
 from torchmetrics.metric import Metric
 from torchmetrics.utilities import rank_zero_warn
+from torchmetrics.utilities.data import _flatten_dict
 
 # this is just a bypass for this module name collision with build-in one
 from torchmetrics.utilities.imports import OrderedDict
@@ -130,7 +131,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
         Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
         will be filtered based on the signature of the individual metric.
         """
-        return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}
+        return _flatten_dict({k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()})
 
     def update(self, *args: Any, **kwargs: Any) -> None:
         """Iteratively call update for each metric.
@@ -218,8 +219,8 @@ def compute(self) -> Dict[str, Any]:
             mi = getattr(self, cg[i])
             for state in m0._defaults:
                 setattr(mi, state, getattr(m0, state))
-
-        return {k: m.compute() for k, m in self.items()}
+        res = {k: m.compute() for k, m in self.items()}
+        return _flatten_dict(res)
 
     def reset(self) -> None:
         """Iteratively call reset for each metric."""
diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py
index 8603888b262..8db95511c90 100644
--- a/torchmetrics/utilities/data.py
+++ b/torchmetrics/utilities/data.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
 
 import torch
 from torch import Tensor, tensor
@@ -51,9 +51,22 @@ def dim_zero_min(x: Tensor) -> Tensor:
 
 
 def _flatten(x: Sequence) -> list:
+    """Flatten a list of lists into a single list."""
     return [item for sublist in x for item in sublist]
 
 
+def _flatten_dict(x: Dict) -> Dict:
+    """Flatten a dict of dicts into a single dict."""
+    new_dict = {}
+    for key, value in x.items():
+        if isinstance(value, dict):
+            for k, v in value.items():
+                new_dict[k] = v
+        else:
+            new_dict[key] = value
+    return new_dict
+
+
 def to_onehot(
     label_tensor: Tensor,
     num_classes: Optional[int] = None,
diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py
index 1f9dab6da45..98dbb137b19 100644
--- a/torchmetrics/wrappers/__init__.py
+++ b/torchmetrics/wrappers/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from torchmetrics.wrappers.bootstrapping import BootStrapper  # noqa: F401
+from torchmetrics.wrappers.classwise import ClasswiseWrapper  # noqa: F401
 from torchmetrics.wrappers.minmax import MinMaxMetric  # noqa: F401
 from torchmetrics.wrappers.multioutput import MultioutputWrapper  # noqa: F401
 from torchmetrics.wrappers.tracker import MetricTracker  # noqa: F401
diff --git a/torchmetrics/wrappers/classwise.py b/torchmetrics/wrappers/classwise.py
new file mode 100644
index 00000000000..3137bd05fd7
--- /dev/null
+++ b/torchmetrics/wrappers/classwise.py
@@ -0,0 +1,74 @@
+from typing import Any, Dict, List, Optional
+
+from torch import Tensor
+
+from torchmetrics import Metric
+
+
+class ClasswiseWrapper(Metric):
+    """Wrapper class for altering the output of classification metrics that return multiple values, to include
+    label information.
+
+    Args:
+        metric: base metric that should be wrapped. It is assumed that the metric outputs a single
+            tensor that is split along the first dimension.
+
+        labels: list of strings indicating the different classes.
+
+    Example:
+        >>> import torch
+        >>> _ = torch.manual_seed(42)
+        >>> from torchmetrics import Accuracy, ClasswiseWrapper
+        >>> metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None))
+        >>> preds = torch.randn(10, 3).softmax(dim=-1)
+        >>> target = torch.randint(3, (10,))
+        >>> metric(preds, target)
+        {'accuracy_0': tensor(0.5000), 'accuracy_1': tensor(0.7500), 'accuracy_2': tensor(0.)}
+
+    Example (labels as list of strings):
+        >>> import torch
+        >>> from torchmetrics import Accuracy, ClasswiseWrapper
+        >>> metric = ClasswiseWrapper(
+        ...     Accuracy(num_classes=3, average=None),
+        ...     labels=["horse", "fish", "dog"]
+        ... )
+        >>> preds = torch.randn(10, 3).softmax(dim=-1)
+        >>> target = torch.randint(3, (10,))
+        >>> metric(preds, target)
+        {'accuracy_horse': tensor(0.3333), 'accuracy_fish': tensor(0.6667), 'accuracy_dog': tensor(0.)}
+
+    Example (in metric collection):
+        >>> import torch
+        >>> from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall
+        >>> labels = ["horse", "fish", "dog"]
+        >>> metric = MetricCollection(
+        ...     {'accuracy': ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels),
+        ...      'recall': ClasswiseWrapper(Recall(num_classes=3, average=None), labels)}
+        ... )
+        >>> preds = torch.randn(10, 3).softmax(dim=-1)
+        >>> target = torch.randint(3, (10,))
+        >>> metric(preds, target)  # doctest: +NORMALIZE_WHITESPACE
+        {'accuracy_horse': tensor(0.), 'accuracy_fish': tensor(0.3333), 'accuracy_dog': tensor(0.4000),
+         'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)}
+    """
+
+    def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None:
+        super().__init__()
+        if not isinstance(metric, Metric):
+            raise ValueError(f"Expected argument `metric` to be an instance of `torchmetrics.Metric` but got {metric}")
+        if labels is not None and not (isinstance(labels, list) and all(isinstance(lab, str) for lab in labels)):
+            raise ValueError(f"Expected argument `labels` to either be `None` or a list of strings but got {labels}")
+        self.metric = metric
+        self.labels = labels
+
+    def _convert(self, x: Tensor) -> Dict[str, Tensor]:
+        # prefix each per-class value with the lowercase name of the wrapped metric
+        name = self.metric.__class__.__name__.lower()
+        if self.labels is None:
+            return {f"{name}_{i}": val for i, val in enumerate(x)}
+        return {f"{name}_{lab}": val for lab, val in zip(self.labels, x)}
+
+    def update(self, *args: Any, **kwargs: Any) -> None:
+        self.metric.update(*args, **kwargs)
+
+    def compute(self) -> Dict[str, Tensor]:
+        return self._convert(self.metric.compute())
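
For reviewers who want to try the change end to end, below is a minimal sketch of the behaviour this patch enables. It is not part of the diff; it only uses pieces introduced above, and the random `preds`/`target` tensors and label names are illustrative, mirroring the doctests and tests.

import torch

from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall

# Wrapping per-class metrics (average=None) yields one named scalar per class.
labels = ["horse", "fish", "dog"]
metrics = MetricCollection(
    {
        "accuracy": ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels),
        "recall": ClasswiseWrapper(Recall(num_classes=3, average=None), labels=labels),
    }
)

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))

# MetricCollection.forward/compute now pass their results through _flatten_dict,
# so the nested {"accuracy": {"accuracy_horse": ...}, ...} output becomes a flat
# dict of scalar tensors that loggers can consume directly.
result = metrics(preds, target)
assert set(result) == {f"{m}_{lab}" for m in ("accuracy", "recall") for lab in labels}

Flattening at the collection level, rather than inside the wrapper, is what lets a `ClasswiseWrapper` keep returning a plain dict while still composing cleanly with `MetricCollection`.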