Skip to content

Commit

Permalink
Fix AUROC to take advantage of known order ahead of time (#230)
Browse files Browse the repository at this point in the history
* Fix AUROC to take advantage of known order ahead of time

* changelog

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
  • Loading branch information
maximsch2 and SkafteNicki authored May 6, 2021
1 parent 33864db commit 05c45a5
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed metric concatenation for list states for zero-dim input ([#229](https://github.com/PyTorchLightning/metrics/pull/229))


- Fixed numerical instability in `AUROC` metric for large input ([#230](https://github.com/PyTorchLightning/metrics/pull/230))


## [0.3.1] - 2021-04-21

- Cleaning remaining inconsistency and fix PL develop integration (
Expand Down
7 changes: 6 additions & 1 deletion torchmetrics/functional/classification/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
return x, y


def _auc_compute_without_check(x: Tensor, y: Tensor, direction: float) -> Tensor:
with torch.no_grad():
return direction * torch.trapz(y, x)


def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
with torch.no_grad():
if reorder:
Expand All @@ -49,7 +54,7 @@ def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
)
else:
direction = 1.
return direction * torch.trapz(y, x)
return _auc_compute_without_check(x, y, direction)


def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
Expand Down
8 changes: 4 additions & 4 deletions torchmetrics/functional/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch
from torch import Tensor, tensor

from torchmetrics.functional.classification.auc import auc
from torchmetrics.functional.classification.auc import _auc_compute_without_check
from torchmetrics.functional.classification.roc import roc
from torchmetrics.utilities.checks import _input_format_classification
from torchmetrics.utilities.enums import AverageMethod, DataType
Expand Down Expand Up @@ -93,7 +93,7 @@ def _auroc_compute(
pass
elif num_classes != 1:
# calculate auc scores per class
auc_scores = [auc(x, y) for x, y in zip(fpr, tpr)]
auc_scores = [_auc_compute_without_check(x, y, 1.0) for x, y in zip(fpr, tpr)]

# calculate average
if average == AverageMethod.NONE:
Expand All @@ -113,7 +113,7 @@ def _auroc_compute(
f" {allowed_average} but got {average}"
)

return auc(fpr, tpr)
return _auc_compute_without_check(fpr, tpr, 1.0)

max_fpr = tensor(max_fpr, device=fpr.device)
# Add a single point at max_fpr and interpolate its tpr value
Expand All @@ -124,7 +124,7 @@ def _auroc_compute(
fpr = torch.cat([fpr[:stop], max_fpr.view(1)])

# Compute partial AUC
partial_auc = auc(fpr, tpr)
partial_auc = _auc_compute_without_check(fpr, tpr, 1.0)

# McClish correction: standardize result to be 0.5 if non-discriminant
# and 1 if maximal
Expand Down

0 comments on commit 05c45a5

Please sign in to comment.