Fix precision issue in calibration error (#1919)
* fix implementation

* add tests

* changelog

* skip on older versions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* skip testing on older

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>

(cherry picked from commit 879595d)
SkafteNicki authored and Borda committed Aug 3, 2023
1 parent a8819d4 commit 0c732f3
Showing 3 changed files with 29 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Fixed

+- Fixed bug in `CalibrationError` where calculations for double precision input were performed in float precision ([#1919](https://github.com/Lightning-AI/torchmetrics/pull/1919))
+
+
 - Fixed bug related to the `prefix/postfix` arguments in `MetricCollection` and `ClasswiseWrapper` being duplicated ([#1918](https://github.com/Lightning-AI/torchmetrics/pull/1918))

2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/calibration_error.py
@@ -82,7 +82,7 @@ def _ce_compute(
         Tensor: Calibration error scalar.
     """
     if isinstance(bin_boundaries, int):
-        bin_boundaries = torch.linspace(0, 1, bin_boundaries + 1, dtype=torch.float, device=confidences.device)
+        bin_boundaries = torch.linspace(0, 1, bin_boundaries + 1, dtype=confidences.dtype, device=confidences.device)

     if norm not in {"l1", "l2", "max"}:
         raise ValueError(f"Argument `norm` is expected to be one of 'l1', 'l2', 'max' but got {norm}")
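The whole fix is the one-line dtype change above. Here is a minimal sketch of the failure mode it addresses (illustrative only, not part of the commit): building the bin edges in float32 while the confidences are float64 perturbs each edge by roughly 1e-8, which is enough to push a confidence sitting exactly on an edge (such as 0.33 with 100 bins) into the neighbouring bin and change the reported calibration error.

```python
# Sketch of the precision bug fixed above (illustrative, not from the commit).
# Before the fix, bin edges were always created in float32, even for float64
# confidences; the float32 rounding shifts each edge by ~1e-8.
import torch

n_bins = 100
confidences = torch.tensor([0.33, 0.34, 0.61, 0.99], dtype=torch.float64)

edges_old = torch.linspace(0, 1, n_bins + 1, dtype=torch.float).double()  # pre-fix behaviour
edges_new = torch.linspace(0, 1, n_bins + 1, dtype=confidences.dtype)     # post-fix behaviour

print((edges_old - edges_new).abs().max())      # ~1e-8: the float32 rounding error
print(torch.bucketize(confidences, edges_old))  # bin assignment may differ here...
print(torch.bucketize(confidences, edges_new))  # ...for values sitting on an edge
```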
27 changes: 25 additions & 2 deletions tests/unittests/classification/test_calibration_error.py
@@ -24,7 +24,7 @@
     binary_calibration_error,
     multiclass_calibration_error,
 )
-from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_9
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_9, _TORCH_GREATER_EQUAL_1_13

 from unittests import NUM_CLASSES
 from unittests.classification.inputs import _binary_cases, _multiclass_cases
@@ -108,7 +108,8 @@ def test_binary_calibration_error_differentiability(self, inputs):
     def test_binary_calibration_error_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-
+        if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_13:
+            pytest.xfail(reason="torch.linspace in metric not supported before pytorch v1.13 for cpu + half")
         if (preds < 0).any() and dtype == torch.half:
             pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
         self.run_precision_test_cpu(
@@ -123,6 +124,8 @@ def test_binary_calibration_error_dtype_cpu(self, inputs, dtype):
     @pytest.mark.parametrize("dtype", [torch.half, torch.double])
     def test_binary_calibration_error_dtype_gpu(self, inputs, dtype):
         """Test dtype support of the metric on GPU."""
+        if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_13:
+            pytest.xfail(reason="torch.searchsorted in metric not supported before pytorch v1.13 for gpu + half")
         preds, target = inputs
         self.run_precision_test_gpu(
             preds=preds,
@@ -246,3 +249,23 @@ def test_multiclass_calibration_error_dtype_gpu(self, inputs, dtype):
             metric_args={"num_classes": NUM_CLASSES},
             dtype=dtype,
         )
+
+
+def test_corner_case_due_to_dtype():
+    """Test that metric works with edge case where the precision is really important for the right result.
+
+    See issue: https://github.com/Lightning-AI/torchmetrics/issues/1907
+    """
+    preds = torch.tensor(
+        [0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.8000, 0.8000, 0.0100, 0.3300, 0.3400, 0.9900, 0.6100],
+        dtype=torch.float64,
+    )
+    target = torch.tensor([1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0])
+
+    assert np.allclose(
+        ECE(99).measure(preds.numpy(), target.numpy()), binary_calibration_error(preds, target, n_bins=99)
+    ), "The metric should be close to the netcal implementation"
+    assert np.allclose(
+        ECE(100).measure(preds.numpy(), target.numpy()), binary_calibration_error(preds, target, n_bins=100)
+    ), "The metric should be close to the netcal implementation"
