Improve AUC numeric stability #224

Closed
wants to merge 9 commits into from
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -51,6 +51,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `PSNR` not working with `DDP` ([#214](https://github.com/PyTorchLightning/metrics/pull/214))


- Fixed `AUC` sometimes raising errors even for sorted input due to numerical instability ([#224](https://github.com/PyTorchLightning/metrics/pull/224))


- Fixed metric calculation with unequal batch sizes ([#220](https://github.com/PyTorchLightning/metrics/pull/220))


6 changes: 3 additions & 3 deletions torchmetrics/functional/classification/auc.py
@@ -31,16 +31,16 @@ def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
     return x, y


-def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
+def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False, tol: float = 1e-6) -> Tensor:
     with torch.no_grad():
         if reorder:
             # TODO: include stable=True arg when pytorch v1.9 is released
             x, x_idx = torch.sort(x)
             y = y[x_idx]

         dx = x[1:] - x[:-1]
-        if (dx < 0).any():
-            if (dx <= 0).all():
+        if (dx + tol < 0).any():
+            if (dx <= tol).all():
Contributor

Am I reading it correctly that if I have

dx = [-1e-7, 5e-7, 5e-7, ... 10000 times ...., 5e-7] then we'll deduce an incorrect direction here?

Member

yes, it seems that direction/sign is changed, so shall we preserve it...

Member Author

maybe I am misunderstanding the question...
the tensor
dx = [1e-7, 5e-7, 5e-7, ... 10000 times ...., 5e-7]
implies that direction=1 right? so if we say that numerical instability leads to a change of sign in the first element
dx = [-1e-7, 5e-7, 5e-7, ... 10000 times ...., 5e-7]
then it will still be direction=1 with the change.
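
To make the case above concrete, here is a minimal sketch of the direction check on plain Python floats (lists rather than tensors, so it runs without torch). The helper name `deduce_direction` and the fall-through `return 1.0` branch are assumptions for illustration, and `tol=0.0` is meant to reproduce the strict comparison used before this change.

```python
from typing import List


def deduce_direction(dx: List[float], tol: float) -> float:
    # Same comparisons as in the diff, applied element by element.
    if any(d + tol < 0 for d in dx):
        if all(d <= tol for d in dx):
            return -1.0
        raise ValueError("x is neither increasing nor decreasing")
    # Assumed branch: no step falls below -tol, so x is treated as increasing.
    return 1.0


# Sorted input whose first difference flipped sign through rounding noise.
dx = [-1e-7] + [5e-7] * 10000

try:
    deduce_direction(dx, tol=0.0)  # strict check, as before this change
    strict_result = "no error"
except ValueError:
    strict_result = "ValueError"

tolerant_result = deduce_direction(dx, tol=1e-6)
print(strict_result, tolerant_result)
```

With the strict check the tiny negative step is enough to raise; with the tolerance the same input is accepted as increasing.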

Contributor

@maximsch2 May 4, 2021

Sorry, let's maybe use this instead:
dx = [-1.0, 5e-7, 5e-7, ... 10000 times ...., 5e-7]
Now both (dx+tol<0).any() and (dx <= tol).all() are true and we'll discover direction as -1, whereas the whole thing is incorrectly sorted.
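
The counterexample can be checked with a short sketch on plain Python floats. The helper `deduce_direction` is a hypothetical stand-in that mirrors the two tolerance-based comparisons from the diff; the final `return 1.0` branch is an assumption about the part of the function not shown here.

```python
from typing import List


def deduce_direction(dx: List[float], tol: float = 1e-6) -> float:
    has_negative_step = any(d + tol < 0 for d in dx)  # (dx + tol < 0).any()
    all_within_tol = all(d <= tol for d in dx)        # (dx <= tol).all()
    if has_negative_step:
        if all_within_tol:
            return -1.0
        raise ValueError("x is neither increasing nor decreasing")
    return 1.0  # assumed fall-through branch


# One large negative step followed by many small positive steps.
dx = [-1.0] + [5e-7] * 10000
direction = deduce_direction(dx)
print(direction)
```

Every positive step is below `tol`, so the `all` condition holds and the input is classified as decreasing, even though the input is not monotonic (one drop of 1.0, then 10000 small increases).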

Member Author

thanks, I see the problem now...
Do you have any way to solve this?

Contributor

Numerical issues are usually tricky. I'm suggesting we remove checks for the internal callers of this function, let me quickly show an example.
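
One way to picture that suggestion, as a rough sketch only: keep the strict monotonicity check in the public entry point, and give internal callers that already know the ordering of `x` a path that skips it. The names `auc`, `auc_without_check` and `_trapezoid` below are hypothetical, and plain Python sequences stand in for tensors.

```python
from typing import Sequence


def _trapezoid(x: Sequence[float], y: Sequence[float]) -> float:
    # Signed trapezoidal rule; the result is negative when x is decreasing.
    return sum((x[i + 1] - x[i]) * (y[i + 1] + y[i]) / 2.0 for i in range(len(x) - 1))


def auc_without_check(x: Sequence[float], y: Sequence[float], direction: float) -> float:
    # Hypothetical internal path: the caller vouches for the ordering of x.
    return _trapezoid(x, y) * direction


def auc(x: Sequence[float], y: Sequence[float]) -> float:
    # Hypothetical public path: strict check, no tolerance involved.
    dx = [b - a for a, b in zip(x, x[1:])]
    if any(d < 0 for d in dx):
        if all(d <= 0 for d in dx):
            direction = -1.0
        else:
            raise ValueError("x is neither increasing nor decreasing")
    else:
        direction = 1.0
    return auc_without_check(x, y, direction)


checked = auc([0.0, 0.5, 1.0], [0.0, 0.5, 1.0])
unchecked = auc_without_check([1.0, 0.5, 0.0], [1.0, 0.5, 0.0], direction=-1.0)
print(checked, unchecked)
```

Under this split, rounding noise in an internally produced `x` no longer trips the check, while user-supplied input is still validated strictly.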

Contributor

@SkafteNicki, check out #230

Member Author

@maximsch2 looks good, closing this in favour of yours

                 direction = -1.
             else:
                 raise ValueError(