diff --git a/CHANGELOG.md b/CHANGELOG.md index e640fe410e1..ff083f29174 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed device and dtype for `LearnedPerceptualImagePatchSimilarity` functional metric ([#2234](https://github.com/Lightning-AI/torchmetrics/pull/2234)) +- Fixed bug in `Metric._reduce_states(...)` when using `dist_sync_fn="cat"` ([#2226](https://github.com/Lightning-AI/torchmetrics/pull/2226)) + + ## [1.2.0] - 2023-09-22 ### Added diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index f574320dc44..8e8b4dbe337 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -409,7 +409,10 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None: elif reduce_fn == dim_zero_min: reduced = torch.min(global_state, local_state) elif reduce_fn == dim_zero_cat: - reduced = global_state + local_state + if isinstance(global_state, Tensor): + reduced = torch.cat([global_state, local_state]) + else: + reduced = global_state + local_state elif reduce_fn is None and isinstance(global_state, Tensor): reduced = torch.stack([global_state, local_state]) elif reduce_fn is None and isinstance(global_state, list):