Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix consistency in the output of scalar tensors #622

Merged
merged 15 commits into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Scalar metrics will now consistently have additional dimensions squeezed ([#622](https://github.com/PyTorchLightning/metrics/pull/622))


### Deprecated

Expand Down
4 changes: 2 additions & 2 deletions torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore

def compute(self) -> Tensor:
"""Compute the aggregated value."""
return self.value.squeeze() if isinstance(self.value, Tensor) else self.value
return self.value


class MaxMetric(BaseAggregator):
Expand Down Expand Up @@ -398,7 +398,7 @@ class MeanMetric(BaseAggregator):
>>> metric.update(1)
>>> metric.update(torch.tensor([2, 3]))
>>> metric.compute()
tensor([2.])
tensor(2.)
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/classification/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class PrecisionRecallCurve(Metric):
>>> recall
[tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
>>> thresholds
[tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
[tensor(0.7500), tensor(0.7500), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor(0.0500)]
"""

is_differentiable = False
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/image/lpip_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class LPIPS(Metric):
>>> img1 = torch.rand(10, 3, 100, 100)
>>> img2 = torch.rand(10, 3, 100, 100)
>>> lpips(img1, img2)
tensor([0.3566], grad_fn=<DivBackward0>)
tensor(0.3566, grad_fn=<SqueezeBackward0>)
"""

is_differentiable = True
Expand Down
13 changes: 11 additions & 2 deletions torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,15 @@
from torch.nn import Module

from torchmetrics.utilities import apply_to_collection, rank_zero_warn
from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_max, dim_zero_mean, dim_zero_min, dim_zero_sum
from torchmetrics.utilities.data import (
_flatten,
_squeeze_if_scalar,
dim_zero_cat,
dim_zero_max,
dim_zero_mean,
dim_zero_min,
dim_zero_sum,
)
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.exceptions import TorchMetricsUserError
from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version
Expand Down Expand Up @@ -369,7 +377,8 @@ def wrapped_func(*args: Any, **kwargs: Any) -> Any:
with self.sync_context(
dist_sync_fn=self.dist_sync_fn, should_sync=self._to_sync, should_unsync=self._should_unsync
):
self._computed = compute(*args, **kwargs)
value = compute(*args, **kwargs)
self._computed = _squeeze_if_scalar(value)

return self._computed

Expand Down
8 changes: 8 additions & 0 deletions torchmetrics/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,11 @@ def get_group_indexes(indexes: Tensor) -> List[Tensor]:
res[_id] = [i]

return [tensor(x, dtype=torch.long) for x in res.values()]


def _squeeze_scalar_element_tensor(x: Tensor) -> Tensor:
return x.squeeze() if x.numel() == 1 else x


def _squeeze_if_scalar(data: Any) -> Any:
    """Walk *data* (tensor, or any collection of tensors) and squeeze every
    single-element tensor it contains to a 0-dim scalar tensor.

    Non-tensor leaves are left untouched.
    """
    result = apply_to_collection(data, Tensor, _squeeze_scalar_element_tensor)
    return result
carmocca marked this conversation as resolved.
Show resolved Hide resolved