diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b6b3d6bea2..62891b270a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `prefix` and `postfix` arguments to `ClasswiseWrapper` ([#1866](https://github.com/Lightning-AI/torchmetrics/pull/1866)) + + - Added speech-to-reverberation modulation energy ratio (SRMR) metric ([#1792](https://github.com/Lightning-AI/torchmetrics/pull/1792), [#1872](https://github.com/Lightning-AI/torchmetrics/pull/1872)) diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 2cc0b6ed100..3e2d8eb7fc2 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -33,8 +33,12 @@ class ClasswiseWrapper(Metric): metric: base metric that should be wrapped. It is assumed that the metric outputs a single tensor that is split along the first dimension. labels: list of strings indicating the different classes. + prefix: string that is prepended to the metric names. + postfix: string that is appended to the metric names. + + Example:: + Basic example where the ouput of a metric is unwrapped into a dictionary with the class index as keys: - Example: >>> import torch >>> _ = torch.manual_seed(42) >>> from torchmetrics.wrappers import ClasswiseWrapper @@ -47,7 +51,29 @@ class ClasswiseWrapper(Metric): 'multiclassaccuracy_1': tensor(0.7500), 'multiclassaccuracy_2': tensor(0.)} - Example (labels as list of strings): + Example:: + Using custom name via prefix and postfix: + + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics.wrappers import ClasswiseWrapper + >>> from torchmetrics.classification import MulticlassAccuracy + >>> metric_pre = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="acc-") + >>> metric_post = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), postfix="-acc") + >>> preds = torch.randn(10, 3).softmax(dim=-1) + >>> target = torch.randint(3, (10,)) + >>> metric_pre(preds, target) # doctest: +NORMALIZE_WHITESPACE + {'acc-0': tensor(0.5000), + 'acc-1': tensor(0.7500), + 'acc-2': tensor(0.)} + >>> metric_post(preds, target) # doctest: +NORMALIZE_WHITESPACE + {'0-acc': tensor(0.5000), + '1-acc': tensor(0.7500), + '2-acc': tensor(0.)} + + Example:: + Providing labels as a list of strings: + >>> from torchmetrics.wrappers import ClasswiseWrapper >>> from torchmetrics.classification import MulticlassAccuracy >>> metric = ClasswiseWrapper( @@ -61,7 +87,10 @@ class ClasswiseWrapper(Metric): 'multiclassaccuracy_fish': tensor(0.6667), 'multiclassaccuracy_dog': tensor(0.)} - Example (in metric collection): + Example:: + Classwise can also be used in combination with :class:`~torchmetrics.MetricCollection`. In this case, everything + will be flattened into a single dictionary: + >>> from torchmetrics import MetricCollection >>> from torchmetrics.wrappers import ClasswiseWrapper >>> from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall @@ -81,21 +110,43 @@ class ClasswiseWrapper(Metric): 'multiclassrecall_dog': tensor(0.4000)} """ - def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None: + def __init__( + self, + metric: Metric, + labels: Optional[List[str]] = None, + prefix: Optional[str] = None, + postfix: Optional[str] = None, + ) -> None: super().__init__() if not isinstance(metric, Metric): raise ValueError(f"Expected argument `metric` to be an instance of `torchmetrics.Metric` but got {metric}") + self.metric = metric + if labels is not None and not (isinstance(labels, list) and all(isinstance(lab, str) for lab in labels)): raise ValueError(f"Expected argument `labels` to either be `None` or a list of strings but got {labels}") - self.metric = metric self.labels = labels + + if prefix is not None and not isinstance(prefix, str): + raise ValueError(f"Expected argument `prefix` to either be `None` or a string but got {prefix}") + self.prefix = prefix + + if postfix is not None and not isinstance(postfix, str): + raise ValueError(f"Expected argument `postfix` to either be `None` or a string but got {postfix}") + self.postfix = postfix + self._update_count = 1 def _convert(self, x: Tensor) -> Dict[str, Any]: - name = self.metric.__class__.__name__.lower() + # Will set the class name as prefix if neither prefix nor postfix is given + if not self.prefix and not self.postfix: + prefix = f"{self.metric.__class__.__name__.lower()}_" + postfix = "" + else: + prefix = self.prefix or "" + postfix = self.postfix or "" if self.labels is None: - return {f"{name}_{i}": val for i, val in enumerate(x)} - return {f"{name}_{lab}": val for lab, val in zip(self.labels, x)} + return {f"{prefix}{i}{postfix}": val for i, val in enumerate(x)} + return {f"{prefix}{lab}{postfix}": val for lab, val in zip(self.labels, x)} def forward(self, *args: Any, **kwargs: Any) -> Any: """Calculate on batch and accumulate to global state.""" diff --git a/tests/unittests/wrappers/test_classwise.py b/tests/unittests/wrappers/test_classwise.py index 561eb375bd2..88e71c411f7 100644 --- a/tests/unittests/wrappers/test_classwise.py +++ b/tests/unittests/wrappers/test_classwise.py @@ -13,6 +13,12 @@ def test_raises_error_on_wrong_input(): with pytest.raises(ValueError, match="Expected argument `labels` to either be `None` or a list of strings.*"): ClasswiseWrapper(MulticlassAccuracy(num_classes=3), "hest") + with pytest.raises(ValueError, match="Expected argument `prefix` to either be `None` or a string.*"): + ClasswiseWrapper(MulticlassAccuracy(num_classes=3), prefix=1) + + with pytest.raises(ValueError, match="Expected argument `postfix` to either be `None` or a string.*"): + ClasswiseWrapper(MulticlassAccuracy(num_classes=3), postfix=1) + def test_output_no_labels(): """Test that wrapper works with no label input.""" @@ -54,6 +60,38 @@ def test_output_with_labels(): assert val[f"multiclassaccuracy_{lab}"] == val_base[i] +def test_output_with_prefix(): + """Test that wrapper works with prefix.""" + base = MulticlassAccuracy(num_classes=3, average=None) + metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="pre_") + for _ in range(2): + preds = torch.randn(20, 3).softmax(dim=-1) + target = torch.randint(3, (20,)) + val = metric(preds, target) + val_base = base(preds, target) + assert isinstance(val, dict) + assert len(val) == 3 + for i in range(3): + assert f"pre_{i}" in val + assert val[f"pre_{i}"] == val_base[i] + + +def test_output_with_postfix(): + """Test that wrapper works with postfix.""" + base = MulticlassAccuracy(num_classes=3, average=None) + metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), postfix="_post") + for _ in range(2): + preds = torch.randn(20, 3).softmax(dim=-1) + target = torch.randint(3, (20,)) + val = metric(preds, target) + val_base = base(preds, target) + assert isinstance(val, dict) + assert len(val) == 3 + for i in range(3): + assert f"{i}_post" in val + assert val[f"{i}_post"] == val_base[i] + + @pytest.mark.parametrize("prefix", [None, "pre_"]) @pytest.mark.parametrize("postfix", [None, "_post"]) def test_using_metriccollection(prefix, postfix):