From a448ad3ff4329682a83fe1036ef21f35a2a8418a Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 12 Jul 2023 19:25:09 +0200 Subject: [PATCH] Bugfix for using metric collection and aggregation metric (#1896) --- CHANGELOG.md | 3 +++ src/torchmetrics/aggregation.py | 30 +++++++++++++++++------ tests/integrations/test_lightning.py | 16 ++++++------ tests/unittests/bases/test_aggregation.py | 17 +++++++++++++ 4 files changed, 50 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a74ca378c04..af23038b3ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixes corner case when using `MetricCollection` together with aggregation metrics ([#1896](https://github.com/Lightning-AI/torchmetrics/pull/1896)) + + - Fixed the use of `max_fpr` in `AUROC` metric when only one class is present ([#1895](https://github.com/Lightning-AI/torchmetrics/pull/1895)) diff --git a/src/torchmetrics/aggregation.py b/src/torchmetrics/aggregation.py index 3b1f4c1f5cc..d15325fedbe 100644 --- a/src/torchmetrics/aggregation.py +++ b/src/torchmetrics/aggregation.py @@ -39,6 +39,7 @@ class BaseAggregator(Metric): - ``'ignore'``: all `nan` values are silently removed - a float: if a float is provided will impude any `nan` values with this value + state_name: name of the metric state kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Raises: @@ -46,7 +47,6 @@ class BaseAggregator(Metric): If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float """ - value: Tensor is_differentiable = None higher_is_better = None full_state_update: bool = False @@ -56,6 +56,7 @@ def __init__( fn: Union[Callable, str], default_value: Union[Tensor, List], nan_strategy: Union[str, float] = "error", + state_name: str = "value", **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -67,7 +68,8 @@ def __init__( ) self.nan_strategy = nan_strategy - self.add_state("value", default=default_value, dist_reduce_fx=fn) + self.add_state(state_name, default=default_value, dist_reduce_fx=fn) + self.state_name = state_name def _cast_and_nan_check_input( self, x: Union[float, Tensor], weight: Optional[Union[float, Tensor]] = None @@ -105,7 +107,7 @@ def update(self, value: Union[float, Tensor]) -> None: def compute(self) -> Tensor: """Compute the aggregated value.""" - return self.value + return getattr(self, self.state_name) class MaxMetric(BaseAggregator): @@ -144,6 +146,7 @@ class MaxMetric(BaseAggregator): """ full_state_update: bool = True + max_value: Tensor def __init__( self, @@ -154,6 +157,7 @@ def __init__( "max", -torch.tensor(float("inf")), nan_strategy, + state_name="max_value", **kwargs, ) @@ -166,7 +170,7 @@ def update(self, value: Union[float, Tensor]) -> None: """ value, _ = self._cast_and_nan_check_input(value) if value.numel(): # make sure tensor not empty - self.value = torch.max(self.value, torch.max(value)) + self.max_value = torch.max(self.max_value, torch.max(value)) def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None @@ -244,6 +248,7 @@ class MinMetric(BaseAggregator): """ full_state_update: bool = True + min_value: Tensor def __init__( self, @@ -254,6 +259,7 @@ def __init__( "min", torch.tensor(float("inf")), nan_strategy, + state_name="min_value", **kwargs, ) @@ -266,7 +272,7 @@ def update(self, value: Union[float, Tensor]) -> None: """ value, _ = self._cast_and_nan_check_input(value) if value.numel(): # make sure tensor not empty - self.value = torch.min(self.value, torch.min(value)) + self.min_value = torch.min(self.min_value, torch.min(value)) def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None @@ -343,6 +349,8 @@ class SumMetric(BaseAggregator): tensor(6.) """ + sum_value: Tensor + def __init__( self, nan_strategy: Union[str, float] = "warn", @@ -352,6 +360,7 @@ def __init__( "sum", torch.tensor(0.0), nan_strategy, + state_name="sum_value", **kwargs, ) @@ -364,7 +373,7 @@ def update(self, value: Union[float, Tensor]) -> None: """ value, _ = self._cast_and_nan_check_input(value) if value.numel(): - self.value += value.sum() + self.sum_value += value.sum() def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None @@ -442,6 +451,8 @@ class CatMetric(BaseAggregator): tensor([1., 2., 3.]) """ + value: Tensor + def __init__( self, nan_strategy: Union[str, float] = "warn", @@ -503,6 +514,8 @@ class MeanMetric(BaseAggregator): tensor(2.) """ + mean_value: Tensor + def __init__( self, nan_strategy: Union[str, float] = "warn", @@ -512,6 +525,7 @@ def __init__( "sum", torch.tensor(0.0), nan_strategy, + state_name="mean_value", **kwargs, ) self.add_state("weight", default=torch.tensor(0.0), dist_reduce_fx="sum") @@ -537,12 +551,12 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0 if value.numel() == 0: return - self.value += (value * weight).sum() + self.mean_value += (value * weight).sum() self.weight += weight.sum() def compute(self) -> Tensor: """Compute the aggregated value.""" - return self.value / self.weight + return self.mean_value / self.weight def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_lightning.py index dc2aaf5f3e6..daa1742e259 100644 --- a/tests/integrations/test_lightning.py +++ b/tests/integrations/test_lightning.py @@ -419,22 +419,22 @@ def configure_optimizers(self): return torch.optim.SGD(self.layer.parameters(), lr=0.1) model = BoringModel() - assert model.metric.value.dtype == torch.float32 + assert model.metric.sum_value.dtype == torch.float32 model = model.half() - assert model.metric.value.dtype == torch.float32 + assert model.metric.sum_value.dtype == torch.float32 model = BoringModel() - assert model.metric.value.dtype == torch.float32 + assert model.metric.sum_value.dtype == torch.float32 model = model.double() - assert model.metric.value.dtype == torch.float32 + assert model.metric.sum_value.dtype == torch.float32 model = BoringModel(metric_dtype=torch.float16) - assert model.metric.value.dtype == torch.float16 + assert model.metric.sum_value.dtype == torch.float16 model = model.float() - assert model.metric.value.dtype == torch.float16 + assert model.metric.sum_value.dtype == torch.float16 model = BoringModel() - assert model.metric.value.dtype == torch.float32 + assert model.metric.sum_value.dtype == torch.float32 model = model.type(torch.half) - assert model.metric.value.dtype == torch.float32 + assert model.metric.sum_value.dtype == torch.float32 diff --git a/tests/unittests/bases/test_aggregation.py b/tests/unittests/bases/test_aggregation.py index ab107fa3d4a..2702d323564 100644 --- a/tests/unittests/bases/test_aggregation.py +++ b/tests/unittests/bases/test_aggregation.py @@ -2,6 +2,7 @@ import pytest import torch from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric +from torchmetrics.collections import MetricCollection from unittests import BATCH_SIZE, NUM_BATCHES from unittests.helpers.testers import MetricTester @@ -166,6 +167,22 @@ def test_mean_metric_broadcasting(weights, expected): assert avg(values, weights) == expected +def test_aggregation_in_collection_with_compute_groups(): + """Check that aggregation metrics work in MetricCollection with compute_groups=True.""" + m = MetricCollection(MinMetric(), MaxMetric(), SumMetric(), MeanMetric(), compute_groups=True) + assert len(m.compute_groups) == 4, "Expected 4 compute groups" + m.update(1) + assert len(m.compute_groups) == 4, "Expected 4 compute groups" + m.update(2) + assert len(m.compute_groups) == 4, "Expected 4 compute groups" + + res = m.compute() + assert res["MinMetric"] == 1 + assert res["MaxMetric"] == 2 + assert res["SumMetric"] == 3 + assert res["MeanMetric"] == 1.5 + + @pytest.mark.skipif(not hasattr(torch, "broadcast_to"), reason="PyTorch <1.8 does not have broadcast_to") @pytest.mark.parametrize("nan_strategy", ["ignore", "warn"]) def test_mean_metric_broadcast(nan_strategy):