diff --git a/CHANGELOG.md b/CHANGELOG.md
index 60f4b4d1a51..6692332fbcf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

-- Added prefix arg to metric collection ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
+- Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70))

 - Added `CohenKappa` metric ([#69](https://github.com/PyTorchLightning/metrics/pull/69))
@@ -18,12 +18,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - Added `RetrievalMAP` metric for Information Retrieval ([#5032](https://github.com/PyTorchLightning/pytorch-lightning/pull/5032))

-- Added `average='micro'` as an option in auroc for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))
+- Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))

 - Added `MatthewsCorrcoef` metric ([#98](https://github.com/PyTorchLightning/metrics/pull/98))

+- Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114))
+
 ### Changed

 - Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))
diff --git a/tests/classification/test_roc.py b/tests/classification/test_roc.py
index e4156bdb7f6..99895e6c855 100644
--- a/tests/classification/test_roc.py
+++ b/tests/classification/test_roc.py
@@ -22,6 +22,7 @@
 from tests.classification.inputs import _input_binary_prob
 from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
 from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
+from tests.classification.inputs import _input_multilabel_multidim_prob, _input_multilabel_prob
 from tests.helpers.testers import NUM_CLASSES, MetricTester
 from torchmetrics.classification.roc import ROC
 from torchmetrics.functional import roc
@@ -29,15 +30,19 @@
 torch.manual_seed(42)


-def _sk_roc_curve(y_true, probas_pred, num_classes=1):
-    """ Adjusted comparison function that can also handles multiclass """
+def _sk_roc_curve(y_true, probas_pred, num_classes: int = 1, multilabel: bool = False):
+    """ Adjusted comparison function that can also handle multiclass and multilabel """
     if num_classes == 1:
         return sk_roc_curve(y_true, probas_pred, drop_intermediate=False)

     fpr, tpr, thresholds = [], [], []
     for i in range(num_classes):
-        y_true_temp = np.zeros_like(y_true)
-        y_true_temp[y_true == i] = 1
+        if multilabel:
+            y_true_temp = y_true[:, i]
+        else:
+            y_true_temp = np.zeros_like(y_true)
+            y_true_temp[y_true == i] = 1
+
         res = sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False)
         fpr.append(res[0])
         tpr.append(res[1])
@@ -65,11 +70,40 @@ def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1):
     return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)


+def _sk_roc_multilabel_prob(preds, target, num_classes=1):
+    sk_preds = preds.numpy()
+    sk_target = target.numpy()
+    return _sk_roc_curve(
+        y_true=sk_target,
+        probas_pred=sk_preds,
+        num_classes=num_classes,
+        multilabel=True
+    )
+
+
+def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1):
+    sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
+    sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
+    return _sk_roc_curve(
+        y_true=sk_target,
+        probas_pred=sk_preds,
+        num_classes=num_classes,
+        multilabel=True
+    )
+
+
 @pytest.mark.parametrize(
     "preds, target, sk_metric, num_classes", [
         (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1),
         (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES),
         (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES),
+        (_input_multilabel_prob.preds, _input_multilabel_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES),
+        (
+            _input_multilabel_multidim_prob.preds,
+            _input_multilabel_multidim_prob.target,
+            _sk_roc_multilabel_multidim_prob,
+            NUM_CLASSES
+        )
     ]
 )
 class TestROC(MetricTester):
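For reference, the multilabel branch added to `_sk_roc_curve` above amounts to running sklearn's `roc_curve` once per label column, with no one-hot encoding step. A minimal standalone sketch of that comparison logic (the data here is made up, not taken from the test fixtures):

```python
import numpy as np
from sklearn.metrics import roc_curve

# Multilabel targets: each column is already binary, so each label is
# scored as its own binary ROC problem.
y_true = np.array([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
y_score = np.array([[0.8, 0.4, 0.1], [0.4, 0.8, 0.1], [0.2, 0.3, 0.1], [0.9, 0.1, 0.2]])
curves = [roc_curve(y_true[:, i], y_score[:, i], drop_intermediate=False) for i in range(y_true.shape[1])]
fpr, tpr, thresholds = map(list, zip(*curves))
```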
diff --git a/torchmetrics/classification/roc.py b/torchmetrics/classification/roc.py
index d5c9fda74e1..6cab71d7534 100644
--- a/torchmetrics/classification/roc.py
+++ b/torchmetrics/classification/roc.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
 from torch import Tensor
@@ -24,13 +24,13 @@ class ROC(Metric):
     """
-    Computes the Receiver Operating Characteristic (ROC). Works for both
-    binary and multiclass problems. In the case of multiclass, the values will
+    Computes the Receiver Operating Characteristic (ROC). Works for
+    binary, multiclass and multilabel problems. In the case of multiclass, the values will
     be calculated based on a one-vs-the-rest approach.

     Forward accepts

-    - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
-      with probabilities, where C is the number of classes.
+    - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass/multilabel) tensor
+      with probabilities, where C is the number of classes/labels.

     - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels
@@ -48,9 +48,12 @@ class ROC(Metric):
             before returning the value at the step. default: False
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When ``None``, DDP
+            will be used to perform the allgather.
+
+    Example (binary case):

-    Example:
-        >>> # binary case
         >>> from torchmetrics import ROC
         >>> pred = torch.tensor([0, 1, 2, 3])
         >>> target = torch.tensor([0, 1, 1, 1])
@@ -63,7 +66,9 @@ class ROC(Metric):
         >>> thresholds
         tensor([4, 3, 2, 1, 0])

-        >>> # multiclass case
+    Example (multiclass case):
+
+        >>> from torchmetrics import ROC
         >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
         ...                      [0.05, 0.75, 0.05, 0.05],
         ...                      [0.05, 0.05, 0.75, 0.05],
@@ -81,8 +86,30 @@ class ROC(Metric):
         tensor([1.7500, 0.7500, 0.0500]),
         tensor([1.7500, 0.7500, 0.0500])]

-    """
+    Example (multilabel case):

+        >>> from torchmetrics import ROC
+        >>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
+        ...                      [0.3584, 0.7576, 0.1183],
+        ...                      [0.2286, 0.3468, 0.1338],
+        ...                      [0.8603, 0.0745, 0.1837]])
+        >>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
+        >>> roc = ROC(num_classes=3, pos_label=1)
+        >>> fpr, tpr, thresholds = roc(pred, target)
+        >>> fpr # doctest: +NORMALIZE_WHITESPACE
+        [tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]),
+         tensor([0., 0., 0., 1., 1.]),
+         tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
+        >>> tpr # doctest: +NORMALIZE_WHITESPACE
+        [tensor([0., 0., 1., 1., 1.]),
+         tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]),
+         tensor([0., 1., 1., 1., 1.])]
+        >>> thresholds # doctest: +NORMALIZE_WHITESPACE
+        [tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]),
+         tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]),
+         tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]
+
+    """
     def __init__(
         self,
         num_classes: Optional[int] = None,
@@ -90,11 +117,13 @@ def __init__(
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
+        dist_sync_fn: Optional[Callable] = None,
     ):
         super().__init__(
             compute_on_step=compute_on_step,
             dist_sync_on_step=dist_sync_on_step,
             process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
         )

         self.num_classes = num_classes
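The multilabel path also works through the usual `Metric` workflow, where state accumulates over batches before the curves are computed. A small usage sketch (batch shapes here are illustrative):

```python
import torch
from torchmetrics import ROC

roc = ROC(num_classes=3, pos_label=1)
for _ in range(2):  # e.g. two validation batches
    preds = torch.rand(4, 3)              # (N, C) probabilities, one column per label
    target = torch.randint(0, 2, (4, 3))  # (N, C) binary targets
    roc.update(preds, target)
fpr, tpr, thresholds = roc.compute()      # lists with one curve per label
```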
diff --git a/torchmetrics/functional/classification/auroc.py b/torchmetrics/functional/classification/auroc.py
index 22093a02470..d8c8ddd7be0 100644
--- a/torchmetrics/functional/classification/auroc.py
+++ b/torchmetrics/functional/classification/auroc.py
@@ -75,7 +75,7 @@ def _auroc_compute(
     # calculate fpr, tpr
     if mode == 'multi-label':
         if average == AverageMethod.MICRO:
-            fpr, tpr, _ = roc(preds.flatten(), target.flatten(), num_classes, pos_label, sample_weights)
+            fpr, tpr, _ = roc(preds.flatten(), target.flatten(), 1, pos_label, sample_weights)
         else:
             # for multilabel we iteratively evaluate roc in a binary fashion
             output = [
diff --git a/torchmetrics/functional/classification/precision_recall_curve.py b/torchmetrics/functional/classification/precision_recall_curve.py
index 0c2d32dc1d6..7d6881b23f3 100644
--- a/torchmetrics/functional/classification/precision_recall_curve.py
+++ b/torchmetrics/functional/classification/precision_recall_curve.py
@@ -71,16 +71,28 @@ def _precision_recall_curve_update(
 ) -> Tuple[Tensor, Tensor, int, int]:
     if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
         raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")
-    # single class evaluation
+
     if len(preds.shape) == len(target.shape):
-        num_classes = 1
         if pos_label is None:
             rank_zero_warn('`pos_label` automatically set 1.')
             pos_label = 1
-        preds = preds.flatten()
-        target = target.flatten()
-
-    # multi class evaluation
+        if num_classes is not None and num_classes != 1:
+            # multilabel problem
+            if num_classes != preds.shape[1]:
+                raise ValueError(
+                    f'Argument `num_classes` was set to {num_classes} in'
+                    f' metric `precision_recall_curve` but detected {preds.shape[1]}'
+                    ' classes from predictions'
+                )
+            preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
+            target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
+        else:
+            # binary problem
+            preds = preds.flatten()
+            target = target.flatten()
+            num_classes = 1
+
+    # multiclass problem
     if len(preds.shape) == len(target.shape) + 1:
         if pos_label is not None:
             rank_zero_warn(
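The `transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)` pattern used in the multilabel branch (and in the test helpers above) collapses any trailing dimensions while keeping one column per label, turning a `(N, C, ...)` input into `(N * ..., C)`. A quick shape check:

```python
import torch

num_classes = 3
preds = torch.rand(4, num_classes, 5)  # multidim multilabel input: (N=4, C=3, extra dim=5)
flat = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
print(flat.shape)  # torch.Size([20, 3]): one row per (sample, extra-dim) pair
```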
diff --git a/torchmetrics/functional/classification/roc.py b/torchmetrics/functional/classification/roc.py
index cee1873a2a0..b7fee4ba620 100644
--- a/torchmetrics/functional/classification/roc.py
+++ b/torchmetrics/functional/classification/roc.py
@@ -27,8 +27,9 @@ def _roc_update(
     target: Tensor,
     num_classes: Optional[int] = None,
     pos_label: Optional[int] = None,
 ) -> Tuple[Tensor, Tensor, int, int]:
-    return _precision_recall_curve_update(preds, target, num_classes, pos_label)
+    preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label)
+    return preds, target, num_classes, pos_label


 def _roc_compute(
@@ -39,7 +40,7 @@ def _roc_compute(
     sample_weights: Optional[Sequence] = None,
 ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:

-    if num_classes == 1:
+    if num_classes == 1 and preds.ndim == 1:  # binary
         fps, tps, thresholds = _binary_clf_curve(
             preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label
         )
@@ -62,12 +63,19 @@ def _roc_compute(
     # Recursively call per class
     fpr, tpr, thresholds = [], [], []
     for c in range(num_classes):
-        preds_c = preds[:, c]
+        if preds.shape == target.shape:
+            preds_c = preds[:, c]
+            target_c = target[:, c]
+            pos_label = 1
+        else:
+            preds_c = preds[:, c]
+            target_c = target
+            pos_label = c
         res = roc(
             preds=preds_c,
-            target=target,
+            target=target_c,
             num_classes=1,
-            pos_label=c,
+            pos_label=pos_label,
             sample_weights=sample_weights,
         )
         fpr.append(res[0])
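The per-class loop above distinguishes the two non-binary cases purely by shape: multilabel targets have the same shape as `preds`, while multiclass targets hold class indices. A condensed restatement of that dispatch (`_select_class` is a hypothetical helper, not part of this PR):

```python
from typing import Tuple

from torch import Tensor


def _select_class(preds: Tensor, target: Tensor, c: int) -> Tuple[Tensor, Tensor, int]:
    """Mirror of the branching inside _roc_compute's per-class loop."""
    if preds.shape == target.shape:
        # multilabel: column c is its own binary problem, positives are the 1s
        return preds[:, c], target[:, c], 1
    # multiclass: one-vs-rest, treating class index c as the positive class
    return preds[:, c], target, c
```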
@@ -86,6 +94,7 @@ def roc(
 ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
     """
     Computes the Receiver Operating Characteristic (ROC).
+    Works with binary, multiclass and multilabel input.

     Args:
         preds: predictions from model (logits or probabilities)
@@ -103,15 +112,16 @@
         fpr:
             tensor with false positive rates.
-            If multiclass, this is a list of such tensors, one for each class.
+            If multiclass or multilabel, this is a list of such tensors, one for each class/label.
         tpr:
             tensor with true positive rates.
-            If multiclass, this is a list of such tensors, one for each class.
+            If multiclass or multilabel, this is a list of such tensors, one for each class/label.
         thresholds:
-            thresholds used for computing false- and true postive rates
+            tensor with thresholds used for computing false- and true positive rates.
+            If multiclass or multilabel, this is a list of such tensors, one for each class/label.
+
+    Example (binary case):

-    Example:
-        >>> # binary case
         >>> from torchmetrics.functional import roc
         >>> pred = torch.tensor([0, 1, 2, 3])
         >>> target = torch.tensor([0, 1, 1, 1])
@@ -123,7 +133,9 @@ def roc(
         >>> thresholds
         tensor([4, 3, 2, 1, 0])

-        >>> # multiclass case
+    Example (multiclass case):
+
+        >>> from torchmetrics.functional import roc
         >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
         ...                      [0.05, 0.75, 0.05, 0.05],
         ...                      [0.05, 0.05, 0.75, 0.05],
@@ -139,6 +151,27 @@ def roc(
         tensor([1.7500, 0.7500, 0.0500]),
         tensor([1.7500, 0.7500, 0.0500]),
         tensor([1.7500, 0.7500, 0.0500])]
+
+    Example (multilabel case):
+
+        >>> from torchmetrics.functional import roc
+        >>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
+        ...                      [0.3584, 0.7576, 0.1183],
+        ...                      [0.2286, 0.3468, 0.1338],
+        ...                      [0.8603, 0.0745, 0.1837]])
+        >>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
+        >>> fpr, tpr, thresholds = roc(pred, target, num_classes=3, pos_label=1)
+        >>> fpr # doctest: +NORMALIZE_WHITESPACE
+        [tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]),
+         tensor([0., 0., 0., 1., 1.]),
+         tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
+        >>> tpr
+        [tensor([0., 0., 1., 1., 1.]), tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), tensor([0., 1., 1., 1., 1.])]
+        >>> thresholds # doctest: +NORMALIZE_WHITESPACE
+        [tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]),
+         tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]),
+         tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]
+
     """
     preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes, pos_label)
     return _roc_compute(preds, target, num_classes, pos_label, sample_weights)
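The one-character fix in `_auroc_compute` above is tied to this new dispatch: after flattening, `preds` and `target` have the same shape, so passing the original `num_classes` would now route the 1-d input through the multilabel branch. With `num_classes=1` the pooled predictions are treated as a single binary problem, which is what micro averaging intends. A sketch of the fixed call path (input shapes are illustrative):

```python
import torch
from torchmetrics.functional import roc

preds = torch.rand(4, 3)                 # (N, C) multilabel probabilities
target = torch.randint(0, 2, (4, 3))     # (N, C) binary targets
# micro average pools every (sample, label) decision into one binary problem
fpr, tpr, _ = roc(preds.flatten(), target.flatten(), num_classes=1, pos_label=1)
```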