diff --git a/CHANGELOG.md b/CHANGELOG.md index 51c802d05eb..5fec0c9aa1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed `BootstrapWrapper` not being reset correctly ([#2574](https://github.com/Lightning-AI/torchmetrics/pull/2574)) +- Fixed integration between `ClasswiseWrapper` and `MetricCollection` with custom `_filter_kwargs` method ([#2575](https://github.com/Lightning-AI/torchmetrics/pull/2575)) + + ## [1.4.0] - 2024-05-03 ### Added diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 0920118c919..78cb27ae46c 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -142,7 +142,12 @@ def __init__( self._update_count = 1 - def _convert(self, x: Tensor) -> Dict[str, Any]: + def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]: + """Filter kwargs for the metric.""" + return self.metric._filter_kwargs(**kwargs) + + def _convert_output(self, x: Tensor) -> Dict[str, Any]: + """Convert output to dictionary with labels as keys.""" # Will set the class name as prefix if neither prefix nor postfix is given if not self._prefix and not self._postfix: prefix = f"{self.metric.__class__.__name__.lower()}_" @@ -156,7 +161,7 @@ def _convert(self, x: Tensor) -> Dict[str, Any]: def forward(self, *args: Any, **kwargs: Any) -> Any: """Calculate on batch and accumulate to global state.""" - return self._convert(self.metric(*args, **kwargs)) + return self._convert_output(self.metric(*args, **kwargs)) def update(self, *args: Any, **kwargs: Any) -> None: """Update state.""" @@ -164,7 +169,7 @@ def update(self, *args: Any, **kwargs: Any) -> None: def compute(self) -> Dict[str, Tensor]: """Compute metric.""" - return self._convert(self.metric.compute()) + return self._convert_output(self.metric.compute()) def reset(self) -> None: """Reset metric.""" diff --git a/tests/unittests/wrappers/test_classwise.py b/tests/unittests/wrappers/test_classwise.py index 5c72362b545..e6491903145 100644 --- a/tests/unittests/wrappers/test_classwise.py +++ b/tests/unittests/wrappers/test_classwise.py @@ -2,6 +2,7 @@ import torch from torchmetrics import MetricCollection from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassRecall +from torchmetrics.clustering import CalinskiHarabaszScore from torchmetrics.wrappers import ClasswiseWrapper @@ -150,3 +151,27 @@ def test_double_use_of_prefix_with_metriccollection(): assert "val/accuracy" in res assert "val/f_score_Tree" in res assert "val/f_score_Bush" in res + + +def test_filter_kwargs_and_metriccollection(): + """Test that kwargs are correctly filtered when using metric collection.""" + metric = MetricCollection( + { + "accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None)), + "cluster": CalinskiHarabaszScore(), + }, + ) + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + data = torch.randn(10, 3) + + metric.update(preds=preds, target=target, data=data, labels=target) + metric(preds=preds, target=target, data=data, labels=target) + val = metric.compute() + + assert isinstance(val, dict) + assert len(val) == 4 + assert "multiclassaccuracy_0" in val + assert "multiclassaccuracy_1" in val + assert "multiclassaccuracy_2" in val + assert "cluster" in val