Skip to content

Commit

Permalink
Add `_filter_kwargs` method to `ClasswiseWrapper` for better integration with `MetricCollection` (#2575)
Browse files Browse the repository at this point in the history

(cherry picked from commit bf0d6e2)
  • Loading branch information
SkafteNicki authored and Borda committed Aug 2, 2024
1 parent 4b18aea commit d30ace2
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `BootstrapWrapper` not being reset correctly ([#2574](https://github.com/Lightning-AI/torchmetrics/pull/2574))


- Fixed integration between `ClasswiseWrapper` and `MetricCollection` with custom `_filter_kwargs` method ([#2575](https://github.com/Lightning-AI/torchmetrics/pull/2575))


## [1.4.0] - 2024-05-03

### Added
Expand Down
11 changes: 8 additions & 3 deletions src/torchmetrics/wrappers/classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,12 @@ def __init__(

self._update_count = 1

def _convert(self, x: Tensor) -> Dict[str, Any]:
def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
"""Filter kwargs for the metric."""
return self.metric._filter_kwargs(**kwargs)

def _convert_output(self, x: Tensor) -> Dict[str, Any]:
"""Convert output to dictionary with labels as keys."""
# Will set the class name as prefix if neither prefix nor postfix is given
if not self._prefix and not self._postfix:
prefix = f"{self.metric.__class__.__name__.lower()}_"
Expand All @@ -156,15 +161,15 @@ def _convert(self, x: Tensor) -> Dict[str, Any]:

def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Calculate metric on the current batch and accumulate to global state.

    Args:
        args: Positional arguments forwarded to the wrapped metric.
        kwargs: Keyword arguments forwarded to the wrapped metric.

    Returns:
        The wrapped metric's batch result, converted to a dictionary keyed
        by class label via ``_convert_output``.
    """
    # Only a single return: the batch output is converted so each class gets
    # its own labeled entry in the result dictionary.
    return self._convert_output(self.metric(*args, **kwargs))

def update(self, *args: Any, **kwargs: Any) -> None:
    """Accumulate the given batch into the wrapped metric's internal state."""
    inner = self.metric
    inner.update(*args, **kwargs)

def compute(self) -> Dict[str, Tensor]:
    """Compute the final metric value over all accumulated state.

    Returns:
        A dictionary mapping class labels to the per-class metric values,
        produced by running the wrapped metric's ``compute`` and converting
        the resulting tensor via ``_convert_output``.
    """
    # Only a single return: the aggregated output is converted so each class
    # gets its own labeled entry in the result dictionary.
    return self._convert_output(self.metric.compute())

def reset(self) -> None:
"""Reset metric."""
Expand Down
25 changes: 25 additions & 0 deletions tests/unittests/wrappers/test_classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassRecall
from torchmetrics.clustering import CalinskiHarabaszScore
from torchmetrics.wrappers import ClasswiseWrapper


Expand Down Expand Up @@ -150,3 +151,27 @@ def test_double_use_of_prefix_with_metriccollection():
assert "val/accuracy" in res
assert "val/f_score_Tree" in res
assert "val/f_score_Bush" in res


def test_filter_kwargs_and_metriccollection():
    """Test that kwargs are correctly filtered when using metric collection."""
    # Mix a classification wrapper with a clustering metric so the collection
    # must route different keyword arguments to each member.
    collection = MetricCollection(
        {
            "accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None)),
            "cluster": CalinskiHarabaszScore(),
        },
    )
    scores = torch.randn(10, 3).softmax(dim=-1)
    truth = torch.randint(3, (10,))
    features = torch.randn(10, 3)

    collection.update(preds=scores, target=truth, data=features, labels=truth)
    collection(preds=scores, target=truth, data=features, labels=truth)
    result = collection.compute()

    assert isinstance(result, dict)
    assert len(result) == 4
    for expected_key in ("multiclassaccuracy_0", "multiclassaccuracy_1", "multiclassaccuracy_2", "cluster"):
        assert expected_key in result

0 comments on commit d30ace2

Please sign in to comment.