Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement pre and postfix for Classwise Wrapper #1866

Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `prefix` and `postfix` arguments to `ClasswiseWrapper` ([#1866](https://github.com/Lightning-AI/torchmetrics/pull/1866))


- Added speech-to-reverberation modulation energy ratio (SRMR) metric ([#1792](https://github.com/Lightning-AI/torchmetrics/pull/1792), [#1872](https://github.com/Lightning-AI/torchmetrics/pull/1872))


Expand Down
67 changes: 59 additions & 8 deletions src/torchmetrics/wrappers/classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@ class ClasswiseWrapper(Metric):
metric: base metric that should be wrapped. It is assumed that the metric outputs a single
tensor that is split along the first dimension.
labels: list of strings indicating the different classes.
prefix: string that is prepended to the metric names.
postfix: string that is appended to the metric names.

Example::
Basic example where the ouput of a metric is unwrapped into a dictionary with the class index as keys:

Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.wrappers import ClasswiseWrapper
Expand All @@ -47,7 +51,29 @@ class ClasswiseWrapper(Metric):
'multiclassaccuracy_1': tensor(0.7500),
'multiclassaccuracy_2': tensor(0.)}

Example (labels as list of strings):
Example::
Using custom name via prefix and postfix:

>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric_pre = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="acc-")
>>> metric_post = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), postfix="-acc")
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric_pre(preds, target) # doctest: +NORMALIZE_WHITESPACE
{'acc-0': tensor(0.5000),
'acc-1': tensor(0.7500),
'acc-2': tensor(0.)}
>>> metric_post(preds, target) # doctest: +NORMALIZE_WHITESPACE
{'0-acc': tensor(0.5000),
'1-acc': tensor(0.7500),
'2-acc': tensor(0.)}

Example::
Providing labels as a list of strings:

>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(
Expand All @@ -61,7 +87,10 @@ class ClasswiseWrapper(Metric):
'multiclassaccuracy_fish': tensor(0.6667),
'multiclassaccuracy_dog': tensor(0.)}

Example (in metric collection):
Example::
Classwise can also be used in combination with :class:`~torchmetrics.MetricCollection`. In this case, everything
will be flattened into a single dictionary:

>>> from torchmetrics import MetricCollection
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall
Expand All @@ -81,21 +110,43 @@ class ClasswiseWrapper(Metric):
'multiclassrecall_dog': tensor(0.4000)}
"""

def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None:
def __init__(
self,
metric: Metric,
labels: Optional[List[str]] = None,
prefix: Optional[str] = None,
postfix: Optional[str] = None,
) -> None:
super().__init__()
if not isinstance(metric, Metric):
raise ValueError(f"Expected argument `metric` to be an instance of `torchmetrics.Metric` but got {metric}")
self.metric = metric

if labels is not None and not (isinstance(labels, list) and all(isinstance(lab, str) for lab in labels)):
raise ValueError(f"Expected argument `labels` to either be `None` or a list of strings but got {labels}")
self.metric = metric
self.labels = labels

if prefix is not None and not isinstance(prefix, str):
raise ValueError(f"Expected argument `prefix` to either be `None` or a string but got {prefix}")
self.prefix = prefix

if postfix is not None and not isinstance(postfix, str):
raise ValueError(f"Expected argument `postfix` to either be `None` or a string but got {postfix}")
self.postfix = postfix

self._update_count = 1

def _convert(self, x: Tensor) -> Dict[str, Any]:
name = self.metric.__class__.__name__.lower()
# Will set the class name as prefix if neither prefix nor postfix is given
if not self.prefix and not self.postfix:
prefix = f"{self.metric.__class__.__name__.lower()}_"
postfix = ""
else:
prefix = self.prefix or ""
postfix = self.postfix or ""
if self.labels is None:
return {f"{name}_{i}": val for i, val in enumerate(x)}
return {f"{name}_{lab}": val for lab, val in zip(self.labels, x)}
return {f"{prefix}{i}{postfix}": val for i, val in enumerate(x)}
return {f"{prefix}{lab}{postfix}": val for lab, val in zip(self.labels, x)}

def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Calculate on batch and accumulate to global state."""
Expand Down
38 changes: 38 additions & 0 deletions tests/unittests/wrappers/test_classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ def test_raises_error_on_wrong_input():
with pytest.raises(ValueError, match="Expected argument `labels` to either be `None` or a list of strings.*"):
ClasswiseWrapper(MulticlassAccuracy(num_classes=3), "hest")

with pytest.raises(ValueError, match="Expected argument `prefix` to either be `None` or a string.*"):
ClasswiseWrapper(MulticlassAccuracy(num_classes=3), prefix=1)

with pytest.raises(ValueError, match="Expected argument `postfix` to either be `None` or a string.*"):
ClasswiseWrapper(MulticlassAccuracy(num_classes=3), postfix=1)


def test_output_no_labels():
"""Test that wrapper works with no label input."""
Expand Down Expand Up @@ -54,6 +60,38 @@ def test_output_with_labels():
assert val[f"multiclassaccuracy_{lab}"] == val_base[i]


def test_output_with_prefix():
"""Test that wrapper works with prefix."""
base = MulticlassAccuracy(num_classes=3, average=None)
metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="pre_")
for _ in range(2):
preds = torch.randn(20, 3).softmax(dim=-1)
target = torch.randint(3, (20,))
val = metric(preds, target)
val_base = base(preds, target)
assert isinstance(val, dict)
assert len(val) == 3
for i in range(3):
assert f"pre_{i}" in val
assert val[f"pre_{i}"] == val_base[i]


def test_output_with_postfix():
"""Test that wrapper works with postfix."""
base = MulticlassAccuracy(num_classes=3, average=None)
metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), postfix="_post")
for _ in range(2):
preds = torch.randn(20, 3).softmax(dim=-1)
target = torch.randint(3, (20,))
val = metric(preds, target)
val_base = base(preds, target)
assert isinstance(val, dict)
assert len(val) == 3
for i in range(3):
assert f"{i}_post" in val
assert val[f"{i}_post"] == val_base[i]


@pytest.mark.parametrize("prefix", [None, "pre_"])
@pytest.mark.parametrize("postfix", [None, "_post"])
def test_using_metriccollection(prefix, postfix):
Expand Down
Loading