Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dim_zero_cat reduction #2226

Merged
merged 11 commits into from
Nov 25, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed device and dtype for `LearnedPerceptualImagePatchSimilarity` functional metric ([#2234](https://github.com/Lightning-AI/torchmetrics/pull/2234))


- Fixed bug in `Metric._reduce_states(...)` when using `dist_sync_fn="cat"` ([#2226](https://github.com/Lightning-AI/torchmetrics/pull/2226))


## [1.2.0] - 2023-09-22

### Added
Expand Down
5 changes: 4 additions & 1 deletion src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,10 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None:
elif reduce_fn == dim_zero_min:
reduced = torch.min(global_state, local_state)
elif reduce_fn == dim_zero_cat:
reduced = global_state + local_state
if isinstance(global_state, Tensor):
reduced = torch.cat([global_state, local_state])
else:
reduced = global_state + local_state
elif reduce_fn is None and isinstance(global_state, Tensor):
reduced = torch.stack([global_state, local_state])
elif reduce_fn is None and isinstance(global_state, list):
Expand Down
Loading