diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 8ee727ebcba75..3e5852ac5fbe0 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -285,6 +285,7 @@ def _stable_1d_sort(x: torch, N: int = 2049): n = x.numel() if N - n > 0: x_max = x.max() - x_pad = torch.cat([x, (x_max + 1) * torch.ones(2049 - n, dtype=x.dtype, device=x.device)], 0) - x_sort = x_pad.sort() - return x_sort.values[:n], x_sort.indices[:n] + x = torch.cat([x, (x_max + 1) * torch.ones(N - n, dtype=x.dtype, device=x.device)], 0) + x_sort = x.sort() + i = min(N, n) + return x_sort.values[:i], x_sort.indices[:i] diff --git a/tests/metrics/classification/test_auc.py b/tests/metrics/classification/test_auc.py index 70d61b696711f..e902151ecffce 100644 --- a/tests/metrics/classification/test_auc.py +++ b/tests/metrics/classification/test_auc.py @@ -61,4 +61,4 @@ def test_auc_functional(self, x, y): ]) def test_auc(x, y, expected): # Test Area Under Curve (AUC) computation - assert auc(torch.tensor(x), torch.tensor(y)) == expected + assert auc(torch.tensor(x), torch.tensor(y), reorder=True) == expected