
[Metrics] AUC/AUROC class interface #5479

Merged

42 commits
f1af411
base files
SkafteNicki Jan 4, 2021
70678eb
auc done
SkafteNicki Jan 4, 2021
0b8efe9
init files
SkafteNicki Jan 4, 2021
91bf148
auc class interface
SkafteNicki Jan 4, 2021
5af053e
fixing auc
SkafteNicki Jan 5, 2021
a695aea
more fixes
SkafteNicki Jan 6, 2021
2e6380f
working auroc
SkafteNicki Jan 11, 2021
99b49dc
update auc
SkafteNicki Jan 11, 2021
e325f4f
add docs
SkafteNicki Jan 12, 2021
6a1f610
remove leftovers from merge
SkafteNicki Jan 12, 2021
076facf
suggestions
SkafteNicki Jan 12, 2021
5622adf
fix f-string
SkafteNicki Jan 12, 2021
563b30e
Apply suggestions from code review
SkafteNicki Jan 13, 2021
0b93fc4
add deprecated tests
SkafteNicki Jan 13, 2021
d7e999a
make logic clearer
SkafteNicki Jan 13, 2021
8116b6e
Update pytorch_lightning/metrics/classification/auroc.py
SkafteNicki Jan 13, 2021
763536a
Merge branch 'release/1.2-dev' into metrics/auc_auroc
SkafteNicki Jan 14, 2021
bcfdf84
fix
SkafteNicki Jan 14, 2021
15512d9
fix
SkafteNicki Jan 14, 2021
a3ceeee
fix docs
SkafteNicki Jan 14, 2021
52b1bbc
fix isort
SkafteNicki Jan 14, 2021
712f7f1
fix deprecated test
SkafteNicki Jan 14, 2021
a4d94c1
fix tests
SkafteNicki Jan 14, 2021
20559fc
Merge branch 'release/1.2-dev' into metrics/auc_auroc
SkafteNicki Jan 19, 2021
d1a9c2f
fix tests
SkafteNicki Jan 20, 2021
1eefc6f
fix isort
SkafteNicki Jan 20, 2021
58b1dc7
Apply suggestions from code review
Borda Jan 24, 2021
e03d5c2
Merge branch 'release/1.2-dev' into metrics/auc_auroc
mergify[bot] Jan 24, 2021
6ae1add
Merge branch 'release/1.2-dev' into metrics/auc_auroc
mergify[bot] Jan 24, 2021
7fe97d5
Merge branch 'release/1.2-dev' into metrics/auc_auroc
mergify[bot] Jan 24, 2021
ef1f954
Merge branch 'release/1.2-dev' into metrics/auc_auroc
mergify[bot] Jan 24, 2021
54920c1
Merge branch 'release/1.2-dev' into metrics/auc_auroc
mergify[bot] Jan 24, 2021
f384c1d
add enum
SkafteNicki Jan 25, 2021
c3b1de8
deprecate old impl
SkafteNicki Jan 25, 2021
67dfd09
merge
SkafteNicki Jan 25, 2021
f747923
update from suggestions
SkafteNicki Jan 25, 2021
156a947
chlog
Borda Jan 26, 2021
35a2f8d
Merge branch 'release/1.2-dev' into metrics/auc_auroc
Borda Jan 26, 2021
b37393f
Merge branch 'release/1.2-dev' into metrics/auc_auroc
SkafteNicki Jan 26, 2021
d69bd68
Merge branch 'release/1.2-dev' into metrics/auc_auroc
mergify[bot] Jan 26, 2021
ec0e65a
merge
SkafteNicki Jan 27, 2021
6d5f889
Merge branch 'release/1.2-dev' into metrics/auc_auroc
mergify[bot] Jan 27, 2021
24 changes: 15 additions & 9 deletions docs/source/metrics.rst
@@ -453,6 +453,18 @@ AveragePrecision
.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision
:noindex:

AUC
~~~

.. autoclass:: pytorch_lightning.metrics.classification.AUC
:noindex:

AUROC
~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.AUROC
:noindex:

ConfusionMatrix
~~~~~~~~~~~~~~~

@@ -524,24 +536,18 @@ accuracy [func]
.. autofunction:: pytorch_lightning.metrics.functional.accuracy
:noindex:


auc [func]
~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.auc
.. autofunction:: pytorch_lightning.metrics.functional.auc
:noindex:


auroc [func]
~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.auroc
:noindex:


multiclass_auroc [func]
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.multiclass_auroc
.. autofunction:: pytorch_lightning.metrics.functional.auroc
:noindex:


4 changes: 3 additions & 1 deletion pytorch_lightning/metrics/__init__.py
@@ -25,7 +25,9 @@
ROC,
FBeta,
F1,
StatScores
StatScores,
AUC,
AUROC
)

from pytorch_lightning.metrics.regression import ( # noqa: F401
2 changes: 2 additions & 0 deletions pytorch_lightning/metrics/classification/__init__.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.metrics.classification.accuracy import Accuracy # noqa: F401
from pytorch_lightning.metrics.classification.auc import AUC # noqa: F401
from pytorch_lightning.metrics.classification.auroc import AUROC # noqa: F401
from pytorch_lightning.metrics.classification.average_precision import AveragePrecision # noqa: F401
from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401
from pytorch_lightning.metrics.classification.f_beta import FBeta, F1 # noqa: F401
89 changes: 89 additions & 0 deletions pytorch_lightning/metrics/classification/auc.py
@@ -0,0 +1,89 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional

import torch

from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.functional.auc import _auc_update, _auc_compute
from pytorch_lightning.utilities import rank_zero_warn


class AUC(Metric):
r"""
Computes Area Under the Curve (AUC) using the trapezoidal rule

Forward accepts two input tensors that should be 1D and have the same number
of elements

Args:
reorder: AUC expects its first input to be sorted. If this is not the case,
setting this argument to ``True`` will use a stable sorting algorithm to
sort the input in descending order
compute_on_step:
Forward only calls ``update()`` and returns ``None`` if this is set to ``False``.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather
"""
def __init__(
self,
reorder: bool = False,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)

self.reorder = reorder

self.add_state("x", default=[], dist_reduce_fx=None)
self.add_state("y", default=[], dist_reduce_fx=None)

rank_zero_warn(
'Metric `AUC` will save all targets and predictions in buffer.'
' For large datasets this may lead to large memory footprint.'
)

def update(self, x: torch.Tensor, y: torch.Tensor):
"""
Update state with predictions and targets.

Args:
x: Predictions from model (probabilities, or labels)
y: Ground truth labels
"""
x, y = _auc_update(x, y)

self.x.append(x)
self.y.append(y)

def compute(self) -> torch.Tensor:
"""
Computes AUC based on inputs passed in to ``update`` previously.
"""
x = torch.cat(self.x, dim=0)
y = torch.cat(self.y, dim=0)
return _auc_compute(x, y, reorder=self.reorder)
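The class above only buffers its inputs and defers the math to `_auc_compute`. A minimal pure-Python sketch of that pattern — accumulate `(x, y)` pairs across batches in `update`, then integrate with the trapezoidal rule in `compute` — might look like this (the `SimpleAUC` name and the standalone implementation are illustrative, not the actual Lightning internals):

```python
class SimpleAUC:
    """Illustrative stand-in for the buffer-then-integrate pattern of the AUC metric."""

    def __init__(self, reorder: bool = False):
        self.reorder = reorder
        self.x = []
        self.y = []

    def update(self, x, y):
        # Buffer every batch; nothing is computed until `compute()` is called.
        self.x.extend(x)
        self.y.extend(y)

    def compute(self) -> float:
        pts = list(zip(self.x, self.y))
        if self.reorder:
            # Stable sort by x so the trapezoidal rule is well defined.
            pts = sorted(pts, key=lambda p: p[0])
        # Trapezoidal rule: sum of 0.5 * dx * (y0 + y1) over consecutive points.
        area = 0.0
        for (x0, y0), (x1, y1) in zip(pts, pts[1:]):
            area += 0.5 * (x1 - x0) * (y0 + y1)
        return area


auc = SimpleAUC(reorder=True)
auc.update([0.0, 1.0], [0.0, 1.0])  # first "batch"
auc.update([2.0, 3.0], [1.0, 0.0])  # second "batch"
print(auc.compute())  # -> 2.0
```

This also makes the warning in `__init__` concrete: every sample seen by `update` stays in memory until `compute`, which is why large datasets can lead to a large memory footprint.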
163 changes: 163 additions & 0 deletions pytorch_lightning/metrics/classification/auroc.py
@@ -0,0 +1,163 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional

import torch
from distutils.version import LooseVersion

from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.functional.auroc import _auroc_compute, _auroc_update
from pytorch_lightning.utilities import rank_zero_warn


class AUROC(Metric):
r"""Computes Area Under the Receiver Operating Characteristic Curve (`ROC AUC
<https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Further_interpretations>`_).
Works for both binary, multilabel and multiclass problems. In the case of
multiclass, the values will be calculated based on a one-vs-the-rest approach.

Forward accepts

- ``preds`` (float tensor): ``(N, )`` (binary) or ``(N, C, ...)`` (multilabel/multiclass)
where C is the number of classes

- ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)``

For non-binary input, if the ``preds`` and ``target`` tensors have the same
size, the input will be interpreted as multilabel; if ``preds`` has one
dimension more than the ``target`` tensor, the input will be interpreted as
multiclass.

Args:
num_classes: integer with number of classes. Not necessary to provide
for binary problems.
pos_label: integer determining the positive class. Default is ``None``,
which for binary problems is translated to 1. For multiclass problems
this argument should not be set, as it is iteratively changed in the
range [0, num_classes-1]
average:
- ``'macro'`` computes metric for each class and uniformly averages them
- ``'weighted'`` computes metric for each class and does a weighted-average,
where each class is weighted by their support (accounts for class imbalance)
- ``None`` computes and returns the metric per class
max_fpr:
If not ``None``, calculates standardized partial AUC over the
range [0, max_fpr]. Should be a float between 0 and 1.
compute_on_step:
Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather

Example (binary case):

>>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
>>> target = torch.tensor([0, 0, 1, 1, 1])
>>> auroc = AUROC(pos_label=1)
>>> auroc(preds, target)
tensor(0.5000)

Example (multiclass case):

>>> preds = torch.tensor([[0.90, 0.05, 0.05],
... [0.05, 0.90, 0.05],
... [0.05, 0.05, 0.90],
... [0.85, 0.05, 0.10],
... [0.10, 0.10, 0.80]])
>>> target = torch.tensor([0, 1, 1, 2, 2])
>>> auroc = AUROC(num_classes=3)
>>> auroc(preds, target)
tensor(0.7778)

"""
def __init__(
self,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
average: Optional[str] = 'macro',
max_fpr: Optional[float] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)

self.num_classes = num_classes
self.pos_label = pos_label
self.average = average
self.max_fpr = max_fpr

allowed_average = (None, 'macro', 'weighted')
if self.average not in allowed_average:
raise ValueError('Argument `average` expected to be one of the following:'
f' {allowed_average} but got {average}')

if self.max_fpr is not None:
if not isinstance(max_fpr, float) or not 0 < max_fpr <= 1:
raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")

if LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
raise RuntimeError(
'`max_fpr` argument requires `torch.bucketize` which is not available below PyTorch version 1.6'
)

self.mode = None
self.add_state("preds", default=[], dist_reduce_fx=None)
self.add_state("target", default=[], dist_reduce_fx=None)

rank_zero_warn(
'Metric `AUROC` will save all targets and predictions in buffer.'
' For large datasets this may lead to large memory footprint.'
)

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.

Args:
preds: Predictions from model (probabilities, or labels)
target: Ground truth labels
"""
preds, target, mode = _auroc_update(preds, target)

self.preds.append(preds)
self.target.append(target)
self.mode = mode

def compute(self) -> torch.Tensor:
"""
Computes AUROC based on inputs passed in to ``update`` previously.
"""
preds = torch.cat(self.preds, dim=0)
target = torch.cat(self.target, dim=0)
return _auroc_compute(
preds,
target,
self.mode,
self.num_classes,
self.pos_label,
self.average,
self.max_fpr
)
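For the binary case, the value `_auroc_compute` ultimately produces is equivalent to the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, with ties counting half. A small self-contained sketch of that pairwise (Mann-Whitney) formulation, checked against the scores from the binary docstring example above (the `binary_auroc` helper is illustrative, not part of the Lightning API):

```python
def binary_auroc(preds, target):
    """Binary AUROC as the fraction of positive/negative pairs ranked
    correctly, counting tied scores as half a correct ranking."""
    pos = [p for p, t in zip(preds, target) if t == 1]
    neg = [p for p, t in zip(preds, target) if t == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Same inputs as the binary docstring example; the class interface returns tensor(0.5000).
preds = [0.13, 0.26, 0.08, 0.19, 0.34]
target = [0, 0, 1, 1, 1]
print(binary_auroc(preds, target))  # -> 0.5
```

The multiclass path in the class above follows from this by computing the same quantity once per class in a one-vs-rest fashion (iterating `pos_label` over `[0, num_classes-1]`) and then averaging according to the ``average`` argument.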
6 changes: 3 additions & 3 deletions pytorch_lightning/metrics/functional/__init__.py
@@ -11,10 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.metrics.functional.average_precision import average_precision # noqa: F401
from pytorch_lightning.metrics.functional.classification import ( # noqa: F401
auc,
auroc,
dice_score,
get_num_classes,
multiclass_auroc,
@@ -27,6 +24,9 @@
)
# TODO: unify metrics between class and functional, add below
from pytorch_lightning.metrics.functional.accuracy import accuracy # noqa: F401
from pytorch_lightning.metrics.functional.auc import auc # noqa: F401
from pytorch_lightning.metrics.functional.auroc import auroc # noqa: F401
from pytorch_lightning.metrics.functional.average_precision import average_precision # noqa: F401
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401
from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401
from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 # noqa: F401