From 960c95145843a9c232add836306c89a5697a0d1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20K=C3=B6nig?= Date: Mon, 20 Nov 2023 10:47:55 +0100 Subject: [PATCH 1/3] Fix "dim_zero_cat" reduction #2225 Fixes #2225 by changing "reduced = global_state + local_state" to "reduced = torch.cat([global_state, local_state] --- src/torchmetrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index f574320dc44..c968208faff 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -409,7 +409,7 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None: elif reduce_fn == dim_zero_min: reduced = torch.min(global_state, local_state) elif reduce_fn == dim_zero_cat: - reduced = global_state + local_state + reduced = torch.cat([global_state, local_state]) elif reduce_fn is None and isinstance(global_state, Tensor): reduced = torch.stack([global_state, local_state]) elif reduce_fn is None and isinstance(global_state, list): From c9f3d87a12f6b81a1d9a3ab39d1b1c97efe985c5 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 23 Nov 2023 10:15:42 +0100 Subject: [PATCH 2/3] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 72549469d59..5855ab8f727 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed numerical stability issue in `UniversalImageQualityIndex` metric ([#2222](https://github.com/Lightning-AI/torchmetrics/pull/2222)) +- Fixed bug in `Metric._reduce_states(...)` when using `dist_sync_fn="cat"` ([#2226](https://github.com/Lightning-AI/torchmetrics/pull/2226)) + + ## [1.2.0] - 2023-09-22 ### Added From 0cfcb6e584152bc959ba7e3b23094e7f43979524 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 25 Nov 2023 15:18:14 +0100 Subject: [PATCH 3/3] fix --- src/torchmetrics/metric.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index c968208faff..8e8b4dbe337 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -409,7 +409,10 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None: elif reduce_fn == dim_zero_min: reduced = torch.min(global_state, local_state) elif reduce_fn == dim_zero_cat: - reduced = torch.cat([global_state, local_state]) + if isinstance(global_state, Tensor): + reduced = torch.cat([global_state, local_state]) + else: + reduced = global_state + local_state elif reduce_fn is None and isinstance(global_state, Tensor): reduced = torch.stack([global_state, local_state]) elif reduce_fn is None and isinstance(global_state, list):