Skip to content

Commit

Permalink
Bugfix for using metric collection and aggregation metric (#1896)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Jul 12, 2023
1 parent e1fd252 commit a448ad3
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 16 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixes corner case when using `MetricCollection` together with aggregation metrics ([#1896](https://github.com/Lightning-AI/torchmetrics/pull/1896))


- Fixed the use of `max_fpr` in `AUROC` metric when only one class is present ([#1895](https://github.com/Lightning-AI/torchmetrics/pull/1895))


Expand Down
30 changes: 22 additions & 8 deletions src/torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ class BaseAggregator(Metric):
- ``'ignore'``: all `nan` values are silently removed
- a float: if a float is provided will impude any `nan` values with this value
state_name: name of the metric state
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
ValueError:
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
"""

value: Tensor
is_differentiable = None
higher_is_better = None
full_state_update: bool = False
Expand All @@ -56,6 +56,7 @@ def __init__(
fn: Union[Callable, str],
default_value: Union[Tensor, List],
nan_strategy: Union[str, float] = "error",
state_name: str = "value",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -67,7 +68,8 @@ def __init__(
)

self.nan_strategy = nan_strategy
self.add_state("value", default=default_value, dist_reduce_fx=fn)
self.add_state(state_name, default=default_value, dist_reduce_fx=fn)
self.state_name = state_name

def _cast_and_nan_check_input(
self, x: Union[float, Tensor], weight: Optional[Union[float, Tensor]] = None
Expand Down Expand Up @@ -105,7 +107,7 @@ def update(self, value: Union[float, Tensor]) -> None:

def compute(self) -> Tensor:
"""Compute the aggregated value."""
return self.value
return getattr(self, self.state_name)


class MaxMetric(BaseAggregator):
Expand Down Expand Up @@ -144,6 +146,7 @@ class MaxMetric(BaseAggregator):
"""

full_state_update: bool = True
max_value: Tensor

def __init__(
self,
Expand All @@ -154,6 +157,7 @@ def __init__(
"max",
-torch.tensor(float("inf")),
nan_strategy,
state_name="max_value",
**kwargs,
)

Expand All @@ -166,7 +170,7 @@ def update(self, value: Union[float, Tensor]) -> None:
"""
value, _ = self._cast_and_nan_check_input(value)
if value.numel(): # make sure tensor not empty
self.value = torch.max(self.value, torch.max(value))
self.max_value = torch.max(self.max_value, torch.max(value))

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down Expand Up @@ -244,6 +248,7 @@ class MinMetric(BaseAggregator):
"""

full_state_update: bool = True
min_value: Tensor

def __init__(
self,
Expand All @@ -254,6 +259,7 @@ def __init__(
"min",
torch.tensor(float("inf")),
nan_strategy,
state_name="min_value",
**kwargs,
)

Expand All @@ -266,7 +272,7 @@ def update(self, value: Union[float, Tensor]) -> None:
"""
value, _ = self._cast_and_nan_check_input(value)
if value.numel(): # make sure tensor not empty
self.value = torch.min(self.value, torch.min(value))
self.min_value = torch.min(self.min_value, torch.min(value))

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down Expand Up @@ -343,6 +349,8 @@ class SumMetric(BaseAggregator):
tensor(6.)
"""

sum_value: Tensor

def __init__(
self,
nan_strategy: Union[str, float] = "warn",
Expand All @@ -352,6 +360,7 @@ def __init__(
"sum",
torch.tensor(0.0),
nan_strategy,
state_name="sum_value",
**kwargs,
)

Expand All @@ -364,7 +373,7 @@ def update(self, value: Union[float, Tensor]) -> None:
"""
value, _ = self._cast_and_nan_check_input(value)
if value.numel():
self.value += value.sum()
self.sum_value += value.sum()

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down Expand Up @@ -442,6 +451,8 @@ class CatMetric(BaseAggregator):
tensor([1., 2., 3.])
"""

value: Tensor

def __init__(
self,
nan_strategy: Union[str, float] = "warn",
Expand Down Expand Up @@ -503,6 +514,8 @@ class MeanMetric(BaseAggregator):
tensor(2.)
"""

mean_value: Tensor

def __init__(
self,
nan_strategy: Union[str, float] = "warn",
Expand All @@ -512,6 +525,7 @@ def __init__(
"sum",
torch.tensor(0.0),
nan_strategy,
state_name="mean_value",
**kwargs,
)
self.add_state("weight", default=torch.tensor(0.0), dist_reduce_fx="sum")
Expand All @@ -537,12 +551,12 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0

if value.numel() == 0:
return
self.value += (value * weight).sum()
self.mean_value += (value * weight).sum()
self.weight += weight.sum()

def compute(self) -> Tensor:
"""Compute the aggregated value."""
return self.value / self.weight
return self.mean_value / self.weight

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down
16 changes: 8 additions & 8 deletions tests/integrations/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,22 +419,22 @@ def configure_optimizers(self):
return torch.optim.SGD(self.layer.parameters(), lr=0.1)

model = BoringModel()
assert model.metric.value.dtype == torch.float32
assert model.metric.sum_value.dtype == torch.float32
model = model.half()
assert model.metric.value.dtype == torch.float32
assert model.metric.sum_value.dtype == torch.float32

model = BoringModel()
assert model.metric.value.dtype == torch.float32
assert model.metric.sum_value.dtype == torch.float32
model = model.double()
assert model.metric.value.dtype == torch.float32
assert model.metric.sum_value.dtype == torch.float32

model = BoringModel(metric_dtype=torch.float16)
assert model.metric.value.dtype == torch.float16
assert model.metric.sum_value.dtype == torch.float16
model = model.float()
assert model.metric.value.dtype == torch.float16
assert model.metric.sum_value.dtype == torch.float16

model = BoringModel()
assert model.metric.value.dtype == torch.float32
assert model.metric.sum_value.dtype == torch.float32

model = model.type(torch.half)
assert model.metric.value.dtype == torch.float32
assert model.metric.sum_value.dtype == torch.float32
17 changes: 17 additions & 0 deletions tests/unittests/bases/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
import torch
from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric
from torchmetrics.collections import MetricCollection

from unittests import BATCH_SIZE, NUM_BATCHES
from unittests.helpers.testers import MetricTester
Expand Down Expand Up @@ -166,6 +167,22 @@ def test_mean_metric_broadcasting(weights, expected):
assert avg(values, weights) == expected


def test_aggregation_in_collection_with_compute_groups():
"""Check that aggregation metrics work in MetricCollection with compute_groups=True."""
m = MetricCollection(MinMetric(), MaxMetric(), SumMetric(), MeanMetric(), compute_groups=True)
assert len(m.compute_groups) == 4, "Expected 4 compute groups"
m.update(1)
assert len(m.compute_groups) == 4, "Expected 4 compute groups"
m.update(2)
assert len(m.compute_groups) == 4, "Expected 4 compute groups"

res = m.compute()
assert res["MinMetric"] == 1
assert res["MaxMetric"] == 2
assert res["SumMetric"] == 3
assert res["MeanMetric"] == 1.5


@pytest.mark.skipif(not hasattr(torch, "broadcast_to"), reason="PyTorch <1.8 does not have broadcast_to")
@pytest.mark.parametrize("nan_strategy", ["ignore", "warn"])
def test_mean_metric_broadcast(nan_strategy):
Expand Down

0 comments on commit a448ad3

Please sign in to comment.