Skip to content

Commit

Permalink
Fix when _stable_1d_sort to work when n >= N (#6177)
Browse files Browse the repository at this point in the history
* Fix when _stable_1d_sort to work when n >= N

* Apply suggestions

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
  • Loading branch information
2 people authored and tchaton committed Mar 9, 2021
1 parent de0efa9 commit 5f9bb9d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions pytorch_lightning/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def _stable_1d_sort(x: torch, N: int = 2049):
n = x.numel()
if N - n > 0:
x_max = x.max()
x_pad = torch.cat([x, (x_max + 1) * torch.ones(2049 - n, dtype=x.dtype, device=x.device)], 0)
x_sort = x_pad.sort()
return x_sort.values[:n], x_sort.indices[:n]
x = torch.cat([x, (x_max + 1) * torch.ones(N - n, dtype=x.dtype, device=x.device)], 0)
x_sort = x.sort()
i = min(N, n)
return x_sort.values[:i], x_sort.indices[:i]
2 changes: 1 addition & 1 deletion tests/metrics/classification/test_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ def test_auc_functional(self, x, y):
])
def test_auc(x, y, expected):
# Test Area Under Curve (AUC) computation
assert auc(torch.tensor(x), torch.tensor(y)) == expected
assert auc(torch.tensor(x), torch.tensor(y), reorder=True) == expected

0 comments on commit 5f9bb9d

Please sign in to comment.