Multilabel support in ROC (#114)
* multilabel_roc_supp

* formatting

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
SkafteNicki and Borda authored Mar 23, 2021
1 parent fc6c8ef commit 2af13fb
Showing 6 changed files with 142 additions and 32 deletions.
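Net effect of the commit: `roc` and `ROC` now accept multilabel input, where `preds` and `target` share the shape `(N, C)` and one curve is returned per label. A minimal usage sketch, not part of the diff itself (tensors and values are illustrative only):

```python
import torch
from torchmetrics.functional import roc

# Multilabel input: preds and target share the shape (N, C)
preds = torch.tensor([[0.8, 0.2], [0.3, 0.7], [0.6, 0.4]])
target = torch.tensor([[1, 0], [0, 1], [1, 1]])

# Returns one (fpr, tpr, thresholds) triple per label, as lists of tensors
fpr, tpr, thresholds = roc(preds, target, num_classes=2, pos_label=1)
```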
6 changes: 4 additions & 2 deletions CHANGELOG.md
@@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-- Added prefix arg to metric collection ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
+- Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70))


- Added `CohenKappa` metric ([#69](https://github.com/PyTorchLightning/metrics/pull/69))
@@ -18,12 +18,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `RetrievalMAP` metric for Information Retrieval ([#5032](https://github.com/PyTorchLightning/pytorch-lightning/pull/5032))


-- Added `average='micro'` as an option in auroc for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))
+- Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))


- Added `MatthewsCorrcoef` metric ([#98](https://github.com/PyTorchLightning/metrics/pull/98))


+- Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114))

### Changed

- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))
40 changes: 37 additions & 3 deletions tests/classification/test_roc.py
@@ -22,22 +22,27 @@
from tests.classification.inputs import _input_binary_prob
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
+from tests.classification.inputs import _input_multilabel_multidim_prob, _input_multilabel_prob
from tests.helpers.testers import NUM_CLASSES, MetricTester
from torchmetrics.classification.roc import ROC
from torchmetrics.functional import roc

torch.manual_seed(42)


-def _sk_roc_curve(y_true, probas_pred, num_classes=1):
+def _sk_roc_curve(y_true, probas_pred, num_classes: int = 1, multilabel: bool = False):
    """ Adjusted comparison function that can also handle multiclass """
    if num_classes == 1:
        return sk_roc_curve(y_true, probas_pred, drop_intermediate=False)

    fpr, tpr, thresholds = [], [], []
    for i in range(num_classes):
-        y_true_temp = np.zeros_like(y_true)
-        y_true_temp[y_true == i] = 1
+        if multilabel:
+            y_true_temp = y_true[:, i]
+        else:
+            y_true_temp = np.zeros_like(y_true)
+            y_true_temp[y_true == i] = 1
+
        res = sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False)
        fpr.append(res[0])
        tpr.append(res[1])
@@ -65,11 +70,40 @@ def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1):
    return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)


+def _sk_roc_multilabel_prob(preds, target, num_classes=1):
+    sk_preds = preds.numpy()
+    sk_target = target.numpy()
+    return _sk_roc_curve(
+        y_true=sk_target,
+        probas_pred=sk_preds,
+        num_classes=num_classes,
+        multilabel=True
+    )
+
+
+def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1):
+    sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
+    sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
+    return _sk_roc_curve(
+        y_true=sk_target,
+        probas_pred=sk_preds,
+        num_classes=num_classes,
+        multilabel=True
+    )


@pytest.mark.parametrize(
    "preds, target, sk_metric, num_classes", [
        (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1),
        (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES),
        (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES),
+        (_input_multilabel_prob.preds, _input_multilabel_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES),
+        (
+            _input_multilabel_multidim_prob.preds,
+            _input_multilabel_multidim_prob.target,
+            _sk_roc_multilabel_multidim_prob,
+            NUM_CLASSES
+        )
    ]
)
class TestROC(MetricTester):
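The new reference functions above reduce to calling sklearn's `roc_curve` once per label column. A sketch of that comparison logic under the same assumptions (toy arrays, not from the test suite):

```python
import numpy as np
from sklearn.metrics import roc_curve

y_true = np.array([[1, 0], [0, 1], [1, 1]])                   # (N, num_labels)
probas_pred = np.array([[0.8, 0.2], [0.3, 0.7], [0.6, 0.4]])

# Mirrors the multilabel branch of _sk_roc_curve: one binary curve per column
for i in range(y_true.shape[1]):
    fpr, tpr, thresholds = roc_curve(y_true[:, i], probas_pred[:, i], drop_intermediate=False)
```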
45 changes: 37 additions & 8 deletions torchmetrics/classification/roc.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from torch import Tensor
@@ -24,13 +24,13 @@
class ROC(Metric):
    """
    Computes the Receiver Operating Characteristic (ROC). Works for both
-    binary and multiclass problems. In the case of multiclass, the values will
+    binary, multiclass and multilabel problems. In the case of multiclass, the values will
    be calculated based on a one-vs-the-rest approach.

    Forward accepts

-    - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
-      with probabilities, where C is the number of classes.
+    - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass/multilabel) tensor
+      with probabilities, where C is the number of classes/labels.

    - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels
@@ -48,9 +48,12 @@ class ROC(Metric):
            before returning the value at the step. default: False
        process_group:
            Specify the process group on which synchronization is called. default: None (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When ``None``, DDP
+            will be used to perform the allgather

-    Example:
-
-        >>> # binary case
+    Example (binary case):
+
+        >>> from torchmetrics import ROC
        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 1, 1])
@@ -63,7 +66,9 @@ class ROC(Metric):
        >>> thresholds
        tensor([4, 3, 2, 1, 0])

-        >>> # multiclass case
+    Example (multiclass case):
+
+        >>> from torchmetrics import ROC
        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.75, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.75, 0.05],
@@ -81,20 +86,44 @@ class ROC(Metric):
        tensor([1.7500, 0.7500, 0.0500]),
        tensor([1.7500, 0.7500, 0.0500])]

-    """
+    Example (multilabel case):
+
+        >>> from torchmetrics import ROC
+        >>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
+        ...                      [0.3584, 0.7576, 0.1183],
+        ...                      [0.2286, 0.3468, 0.1338],
+        ...                      [0.8603, 0.0745, 0.1837]])
+        >>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
+        >>> roc = ROC(num_classes=3, pos_label=1)
+        >>> fpr, tpr, thresholds = roc(pred, target)
+        >>> fpr # doctest: +NORMALIZE_WHITESPACE
+        [tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]),
+         tensor([0., 0., 0., 1., 1.]),
+         tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
+        >>> tpr # doctest: +NORMALIZE_WHITESPACE
+        [tensor([0., 0., 1., 1., 1.]),
+         tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]),
+         tensor([0., 1., 1., 1., 1.])]
+        >>> thresholds # doctest: +NORMALIZE_WHITESPACE
+        [tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]),
+         tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]),
+         tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]
+    """
    def __init__(
        self,
        num_classes: Optional[int] = None,
        pos_label: Optional[int] = None,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
        )

        self.num_classes = num_classes
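The module-based metric follows the usual torchmetrics update/compute cycle, which the docstring examples do not show. A hedged sketch (the batch iterable is hypothetical):

```python
import torch
from torchmetrics import ROC

roc_metric = ROC(num_classes=3, pos_label=1)
# Hypothetical stream of (preds, target) multilabel batches, each (N, 3)
batches = [(torch.rand(8, 3), torch.randint(2, (8, 3))) for _ in range(4)]
for preds, target in batches:
    roc_metric.update(preds, target)          # accumulate state across batches
fpr, tpr, thresholds = roc_metric.compute()   # one curve per label
```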
2 changes: 1 addition & 1 deletion torchmetrics/functional/classification/auroc.py
@@ -75,7 +75,7 @@ def _auroc_compute(
    # calculate fpr, tpr
    if mode == 'multi-label':
        if average == AverageMethod.MICRO:
-            fpr, tpr, _ = roc(preds.flatten(), target.flatten(), num_classes, pos_label, sample_weights)
+            fpr, tpr, _ = roc(preds.flatten(), target.flatten(), 1, pos_label, sample_weights)
        else:
            # for multilabel we iteratively evaluate roc in a binary fashion
            output = [
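The one-line fix matters because micro averaging pools every (sample, label) pair into a single binary problem, so the flattened call must run as binary (`num_classes=1`) rather than with the original label count. A sketch of the intended behaviour (tensors are illustrative):

```python
import torch
from torchmetrics.functional import roc

preds = torch.rand(4, 3)            # (N, num_labels) probabilities
target = torch.randint(2, (4, 3))   # binary targets per label

# Micro ROC: flatten both tensors into one binary problem over all pairs
fpr, tpr, _ = roc(preds.flatten(), target.flatten(), num_classes=1, pos_label=1)
```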
24 changes: 18 additions & 6 deletions torchmetrics/functional/classification/precision_recall_curve.py
@@ -71,16 +71,28 @@ def _precision_recall_curve_update(
) -> Tuple[Tensor, Tensor, int, int]:
    if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
        raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")
-    # single class evaluation

    if len(preds.shape) == len(target.shape):
-        num_classes = 1
        if pos_label is None:
            rank_zero_warn('`pos_label` automatically set to 1.')
            pos_label = 1
-        preds = preds.flatten()
-        target = target.flatten()

-    # multi class evaluation
+        if num_classes is not None and num_classes != 1:
+            # multilabel problem
+            if num_classes != preds.shape[1]:
+                raise ValueError(
+                    f'Argument `num_classes` was set to {num_classes} in'
+                    f' metric `precision_recall_curve` but detected {preds.shape[1]}'
+                    ' number of classes from predictions'
+                )
+            preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
+            target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
+        else:
+            # binary problem
+            preds = preds.flatten()
+            target = target.flatten()
+            num_classes = 1

+    # multi class problem
    if len(preds.shape) == len(target.shape) + 1:
        if pos_label is not None:
            rank_zero_warn(
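The transpose/reshape pair collapses any trailing dimensions of a multidimensional multilabel input into extra rows while keeping the label axis last. A quick shape check (my example, not from the diff):

```python
import torch

num_classes = 3
preds = torch.rand(4, num_classes, 5)   # (N, C, ...) multidim multilabel input
# (4, 3, 5) -> (3, 4, 5) -> (3, 20) -> (20, 3): each extra position becomes a row
flat = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
assert flat.shape == (4 * 5, num_classes)
```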
57 changes: 45 additions & 12 deletions torchmetrics/functional/classification/roc.py
@@ -27,8 +27,9 @@ def _roc_update(
    target: Tensor,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
) -> Tuple[Tensor, Tensor, int, int]:
-    return _precision_recall_curve_update(preds, target, num_classes, pos_label)
+    preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label)
+    return preds, target, num_classes, pos_label


def _roc_compute(
@@ -39,7 +40,7 @@ def _roc_compute(
    sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:

-    if num_classes == 1:
+    if num_classes == 1 and preds.ndim == 1:  # binary
        fps, tps, thresholds = _binary_clf_curve(
            preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label
        )
@@ -62,12 +63,19 @@
    # Recursively call per class
    fpr, tpr, thresholds = [], [], []
    for c in range(num_classes):
-        preds_c = preds[:, c]
+        if preds.shape == target.shape:
+            preds_c = preds[:, c]
+            target_c = target[:, c]
+            pos_label = 1
+        else:
+            preds_c = preds[:, c]
+            target_c = target
+            pos_label = c
        res = roc(
            preds=preds_c,
-            target=target,
+            target=target_c,
            num_classes=1,
-            pos_label=c,
+            pos_label=pos_label,
            sample_weights=sample_weights,
        )
        fpr.append(res[0])
@@ -86,6 +94,7 @@ def roc(
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
    """
    Computes the Receiver Operating Characteristic (ROC).
+    Works with binary, multiclass and multilabel input.

    Args:
        preds: predictions from model (logits or probabilities)
@@ -103,15 +112,16 @@
        fpr:
            tensor with false positive rates.
-            If multiclass, this is a list of such tensors, one for each class.
+            If multiclass or multilabel, this is a list of such tensors, one for each class/label.
        tpr:
            tensor with true positive rates.
-            If multiclass, this is a list of such tensors, one for each class.
+            If multiclass or multilabel, this is a list of such tensors, one for each class/label.
        thresholds:
-            thresholds used for computing false- and true postive rates
+            tensor with thresholds used for computing false- and true positive rates
+            If multiclass or multilabel, this is a list of such tensors, one for each class/label.
-    Example:
-
-        >>> # binary case
+    Example (binary case):
+
+        >>> from torchmetrics.functional import roc
        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 1, 1])
@@ -123,7 +133,9 @@
        >>> thresholds
        tensor([4, 3, 2, 1, 0])

-        >>> # multiclass case
+    Example (multiclass case):
+
+        >>> from torchmetrics.functional import roc
        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.75, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.75, 0.05],
@@ -139,6 +151,27 @@
        tensor([1.7500, 0.7500, 0.0500]),
        tensor([1.7500, 0.7500, 0.0500]),
        tensor([1.7500, 0.7500, 0.0500])]
+
+    Example (multilabel case):
+
+        >>> from torchmetrics.functional import roc
+        >>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
+        ...                      [0.3584, 0.7576, 0.1183],
+        ...                      [0.2286, 0.3468, 0.1338],
+        ...                      [0.8603, 0.0745, 0.1837]])
+        >>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
+        >>> fpr, tpr, thresholds = roc(pred, target, num_classes=3, pos_label=1)
+        >>> fpr # doctest: +NORMALIZE_WHITESPACE
+        [tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]),
+         tensor([0., 0., 0., 1., 1.]),
+         tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
+        >>> tpr
+        [tensor([0., 0., 1., 1., 1.]), tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), tensor([0., 1., 1., 1., 1.])]
+        >>> thresholds # doctest: +NORMALIZE_WHITESPACE
+        [tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]),
+         tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]),
+         tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]
    """
    preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes, pos_label)
    return _roc_compute(preds, target, num_classes, pos_label, sample_weights)
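Taken together, the updated `roc` dispatches on input shapes. A summary sketch of the three modes (toy tensors, outputs omitted):

```python
import torch
from torchmetrics.functional import roc

# binary: preds and target both (N,) -> single curve
fpr, tpr, thresholds = roc(torch.rand(6), torch.randint(2, (6,)), pos_label=1)

# multiclass: preds (N, C), target (N,) -> one-vs-rest list of C curves
fpr, tpr, thresholds = roc(torch.rand(6, 3), torch.randint(3, (6,)), num_classes=3)

# multilabel (new here): preds and target both (N, C) -> one curve per label
fpr, tpr, thresholds = roc(torch.rand(6, 3), torch.randint(2, (6, 3)), num_classes=3, pos_label=1)
```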
