
Better support for classwise logging #832

Merged: 10 commits, Feb 10, 2022
Changes from 9 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added smart update of `MetricCollection` ([#709](https://github.com/PyTorchLightning/metrics/pull/709))


- Added `ClasswiseWrapper` for better logging of classification metrics with multiple output values ([#832](https://github.com/PyTorchLightning/metrics/pull/832))


### Changed


6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
@@ -710,6 +710,12 @@ BootStrapper
.. autoclass:: torchmetrics.BootStrapper
    :noindex:

ClasswiseWrapper
~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.ClasswiseWrapper
    :noindex:

MetricTracker
~~~~~~~~~~~~~

16 changes: 15 additions & 1 deletion tests/test_utilities.py
@@ -16,7 +16,7 @@
from torch import tensor

from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn
from torchmetrics.utilities.data import get_num_classes, to_categorical, to_onehot
from torchmetrics.utilities.data import _flatten, _flatten_dict, get_num_classes, to_categorical, to_onehot
from torchmetrics.utilities.distributed import class_reduce, reduce


@@ -102,3 +102,17 @@ def test_to_categorical():
)
def test_get_num_classes(preds, target, num_classes, expected_num_classes):
    assert get_num_classes(preds, target, num_classes) == expected_num_classes


def test_flatten_list():
    """Check that _flatten utility function works as expected."""
    inp = [[1, 2, 3], [4, 5], [6]]
    out = _flatten(inp)
    assert out == [1, 2, 3, 4, 5, 6]


def test_flatten_dict():
    """Check that _flatten_dict utility function works as expected."""
    inp = {"a": {"b": 1, "c": 2}, "d": 3}
    out = _flatten_dict(inp)
    assert out == {"b": 1, "c": 2, "d": 3}
57 changes: 57 additions & 0 deletions tests/wrappers/test_classwise.py
@@ -0,0 +1,57 @@
import pytest
import torch

from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall


def test_raises_error_on_wrong_input():
    """Test that errors are raised on wrong input."""
    with pytest.raises(ValueError, match="Expected argument `metric` to be an instance of `torchmetrics.Metric` but.*"):
        ClasswiseWrapper([])

    with pytest.raises(ValueError, match="Expected argument `labels` to either be `None` or a list of strings.*"):
        ClasswiseWrapper(Accuracy(), "hest")


def test_output_no_labels():
    """Test that wrapper works with no label input."""
    metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None))
    preds = torch.randn(10, 3).softmax(dim=-1)
    target = torch.randint(3, (10,))
    val = metric(preds, target)
    assert isinstance(val, dict)
    assert len(val) == 3
    for i in range(3):
        assert f"accuracy_{i}" in val


def test_output_with_labels():
    """Test that wrapper works with label input."""
    labels = ["horse", "fish", "cat"]
    metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels)
    preds = torch.randn(10, 3).softmax(dim=-1)
    target = torch.randint(3, (10,))
    val = metric(preds, target)
    assert isinstance(val, dict)
    assert len(val) == 3
    for lab in labels:
        assert f"accuracy_{lab}" in val


def test_using_metriccollection():
    """Test wrapper in combination with metric collection."""
    labels = ["horse", "fish", "cat"]
    metric = MetricCollection(
        {
[Inline review thread on this line]

Member: but here it won't share states, right?

Member (author): You are right. That is a common problem for all the wrappers, because we assume that states are not nested when trying to figure out the compute groups. I can try to come up with a general solution for this problem.

Member: Can we maybe have a property that kind of forwards the states of the wrapped metrics? (A sketch of this idea follows this file's diff below.)
"accuracy": ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels),
"recall": ClasswiseWrapper(Recall(num_classes=3, average=None), labels=labels),
}
)
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
val = metric(preds, target)
assert isinstance(val, dict)
assert len(val) == 6
for lab in labels:
assert f"accuracy_{lab}" in val
assert f"recall_{lab}" in val
9 changes: 8 additions & 1 deletion torchmetrics/__init__.py
@@ -89,7 +89,13 @@
    WordInfoLost,
    WordInfoPreserved,
)
from torchmetrics.wrappers import BootStrapper, MetricTracker, MinMaxMetric, MultioutputWrapper # noqa: E402
from torchmetrics.wrappers import (  # noqa: E402
    BootStrapper,
    ClasswiseWrapper,
    MetricTracker,
    MinMaxMetric,
    MultioutputWrapper,
)

__all__ = [
    "functional",
@@ -104,6 +110,7 @@
"BootStrapper",
"CalibrationError",
"CatMetric",
"ClasswiseWrapper",
"CHRFScore",
"CohenKappa",
"ConfusionMatrix",
7 changes: 4 additions & 3 deletions torchmetrics/collections.py
@@ -19,6 +19,7 @@

from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import _flatten_dict

# this is just a bypass for this module name collision with build-in one
from torchmetrics.utilities.imports import OrderedDict
@@ -130,7 +131,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
        will be filtered based on the signature of the individual metric.
        """
        return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}
        return _flatten_dict({k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()})

    def update(self, *args: Any, **kwargs: Any) -> None:
        """Iteratively call update for each metric.
@@ -218,8 +219,8 @@ def compute(self) -> Dict[str, Any]:
            mi = getattr(self, cg[i])
            for state in m0._defaults:
                setattr(mi, state, getattr(m0, state))

        return {k: m.compute() for k, m in self.items()}
        res = {k: m.compute() for k, m in self.items()}
        return _flatten_dict(res)

    def reset(self) -> None:
        """Iteratively call reset for each metric."""
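To make the effect of the flattening change concrete: when a member metric such as `ClasswiseWrapper` itself returns a dict, `forward` and `compute` now merge that dict into the top-level result via the private `_flatten_dict` helper added in torchmetrics/utilities/data.py below. The values here are illustrative, not taken from the PR:

from torchmetrics.utilities.data import _flatten_dict

# Nested result, as produced when one member metric returns a dict:
nested = {"accuracy": {"accuracy_horse": 0.5, "accuracy_fish": 0.25}, "recall": 0.4}

# One level of nesting is removed, so classwise entries sit next to scalars:
assert _flatten_dict(nested) == {"accuracy_horse": 0.5, "accuracy_fish": 0.25, "recall": 0.4}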
15 changes: 14 additions & 1 deletion torchmetrics/utilities/data.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Mapping, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union

import torch
from torch import Tensor, tensor
@@ -51,9 +51,22 @@ def dim_zero_min(x: Tensor) -> Tensor:


def _flatten(x: Sequence) -> list:
    """Flatten list of lists into a single list."""
    return [item for sublist in x for item in sublist]


def _flatten_dict(x: Dict) -> Dict:
    """Flatten dict of dicts into a single dict; only one level of nesting is removed."""
    new_dict = {}
    for key, value in x.items():
        if isinstance(value, dict):
            for k, v in value.items():
                new_dict[k] = v
        else:
            new_dict[key] = value
    return new_dict


def to_onehot(
    label_tensor: Tensor,
    num_classes: Optional[int] = None,
1 change: 1 addition & 0 deletions torchmetrics/wrappers/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401
from torchmetrics.wrappers.classwise import ClasswiseWrapper # noqa: F401
from torchmetrics.wrappers.minmax import MinMaxMetric # noqa: F401
from torchmetrics.wrappers.multioutput import MultioutputWrapper # noqa: F401
from torchmetrics.wrappers.tracker import MetricTracker # noqa: F401
74 changes: 74 additions & 0 deletions torchmetrics/wrappers/classwise.py
@@ -0,0 +1,74 @@
from typing import Any, Dict, List, Optional

from torch import Tensor

from torchmetrics import Metric


class ClasswiseWrapper(Metric):
    """Wrapper class for altering the output of classification metrics that return multiple values, to include
    label information.

    Args:
        metric: base metric that should be wrapped. It is assumed that the metric outputs a single
            tensor that is split along the first dimension.
        labels: list of strings indicating the different classes.

    Example:
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> from torchmetrics import Accuracy, ClasswiseWrapper
        >>> metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None))
        >>> preds = torch.randn(10, 3).softmax(dim=-1)
        >>> target = torch.randint(3, (10,))
        >>> metric(preds, target)
        {'accuracy_0': tensor(0.5000), 'accuracy_1': tensor(0.7500), 'accuracy_2': tensor(0.)}

    Example (labels as list of strings):
        >>> import torch
        >>> from torchmetrics import Accuracy, ClasswiseWrapper
        >>> metric = ClasswiseWrapper(
        ...     Accuracy(num_classes=3, average=None),
        ...     labels=["horse", "fish", "dog"]
        ... )
        >>> preds = torch.randn(10, 3).softmax(dim=-1)
        >>> target = torch.randint(3, (10,))
        >>> metric(preds, target)
        {'accuracy_horse': tensor(0.3333), 'accuracy_fish': tensor(0.6667), 'accuracy_dog': tensor(0.)}

    Example (in metric collection):
        >>> import torch
        >>> from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall
        >>> labels = ["horse", "fish", "dog"]
        >>> metric = MetricCollection(
        ...     {'accuracy': ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels),
        ...      'recall': ClasswiseWrapper(Recall(num_classes=3, average=None), labels)}
        ... )
        >>> preds = torch.randn(10, 3).softmax(dim=-1)
        >>> target = torch.randint(3, (10,))
        >>> metric(preds, target)  # doctest: +NORMALIZE_WHITESPACE
        {'accuracy_horse': tensor(0.), 'accuracy_fish': tensor(0.3333), 'accuracy_dog': tensor(0.4000),
         'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)}
    """

    def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None:
        super().__init__()
        if not isinstance(metric, Metric):
            raise ValueError(f"Expected argument `metric` to be an instance of `torchmetrics.Metric` but got {metric}")
        if labels is not None and not (isinstance(labels, list) and all(isinstance(lab, str) for lab in labels)):
            raise ValueError(f"Expected argument `labels` to either be `None` or a list of strings but got {labels}")
        self.metric = metric
        self.labels = labels

    def _convert(self, x: Tensor) -> Dict[str, Tensor]:
        # Prefix each per-class value with the metric name, keyed by label if
        # labels were provided and by class index otherwise.
        name = self.metric.__class__.__name__.lower()
        if self.labels is None:
            return {f"{name}_{i}": val for i, val in enumerate(x)}
        return {f"{name}_{lab}": val for lab, val in zip(self.labels, x)}

    def update(self, *args: Any, **kwargs: Any) -> None:
        self.metric.update(*args, **kwargs)

    def compute(self) -> Dict[str, Tensor]:
        return self._convert(self.metric.compute())
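Since the wrapper returns a plain dict of scalar tensors, its output plugs directly into dict-based loggers, which is the point of the PR title. A minimal usage sketch; the `self.log_dict` call assumes a PyTorch Lightning `LightningModule` context, which is outside this PR:

import torch

from torchmetrics import Accuracy, ClasswiseWrapper

metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=["horse", "fish", "dog"])
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
val = metric(preds, target)  # one scalar tensor per class, keyed by label

# Inside a LightningModule training/validation step one could then log
# every class at once:
# self.log_dict(val)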