From 9f319600cc2027e89ce704e7f3eec26bc0e5e74b Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 17 Mar 2021 23:23:22 +0100 Subject: [PATCH 01/10] class: AUC AUROC --- .../metrics/classification/auc.py | 67 +------- .../metrics/classification/auroc.py | 158 +----------------- .../deprecated_api/test_remove_1-5_metrics.py | 36 ++-- 3 files changed, 39 insertions(+), 222 deletions(-) diff --git a/pytorch_lightning/metrics/classification/auc.py b/pytorch_lightning/metrics/classification/auc.py index 76c1959a8603a..ce28e1d4e7072 100644 --- a/pytorch_lightning/metrics/classification/auc.py +++ b/pytorch_lightning/metrics/classification/auc.py @@ -13,36 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import AUC as _AUC -from pytorch_lightning.metrics.functional.auc import _auc_compute, _auc_update -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.deprecation import deprecated -class AUC(Metric): - r""" - Computes Area Under the Curve (AUC) using the trapezoidal rule - - Forward accepts two input tensors that should be 1D and have the same number - of elements - - Args: - reorder: AUC expects its first input to be sorted. If this is not the case, - setting this argument to ``True`` will use a stable sorting algorithm to - sort the input in decending order - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When ``None``, DDP - will be used to perform the allgather - """ +class AUC(_AUC): + @deprecated(target=_AUC, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, reorder: bool = False, @@ -51,40 +29,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.reorder = reorder - - self.add_state("x", default=[], dist_reduce_fx=None) - self.add_state("y", default=[], dist_reduce_fx=None) - - rank_zero_warn( - 'Metric `AUC` will save all targets and predictions in buffer.' - ' For large datasets this may lead to large memory footprint.' - ) - - def update(self, x: torch.Tensor, y: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - x: Predictions from model (probabilities, or labels) - y: Ground truth labels """ - x, y = _auc_update(x, y) + This implementation refers to :class:`~torchmetrics.AUC`. - self.x.append(x) - self.y.append(y) - - def compute(self) -> torch.Tensor: - """ - Computes AUC based on inputs passed in to ``update`` previously. + .. deprecated:: + Use :class:`~torchmetrics.AUC`. Will be removed in v1.5.0. """ - x = torch.cat(self.x, dim=0) - y = torch.cat(self.y, dim=0) - return _auc_compute(x, y, reorder=self.reorder) diff --git a/pytorch_lightning/metrics/classification/auroc.py b/pytorch_lightning/metrics/classification/auroc.py index 7d8ba7368e45d..0866406ecea8f 100644 --- a/pytorch_lightning/metrics/classification/auroc.py +++ b/pytorch_lightning/metrics/classification/auroc.py @@ -11,95 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from distutils.version import LooseVersion from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import AUROC as _AUROC -from pytorch_lightning.metrics.functional.auroc import _auroc_compute, _auroc_update -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.deprecation import deprecated -class AUROC(Metric): - r"""Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC) - `_. - Works for both binary, multilabel and multiclass problems. In the case of - multiclass, the values will be calculated based on a one-vs-the-rest approach. - - Forward accepts - - - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor - with probabilities, where C is the number of classes. - - - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels - - For non-binary input, if the ``preds`` and ``target`` tensor have the same - size the input will be interpretated as multilabel and if ``preds`` have one - dimension more than the ``target`` tensor the input will be interpretated as - multiclass. - - Args: - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - average: - - ``'macro'`` computes metric for each class and uniformly averages them - - ``'weighted'`` computes metric for each class and does a weighted-average, - where each class is weighted by their support (accounts for class imbalance) - - ``None`` computes and returns the metric per class - max_fpr: - If not ``None``, calculates standardized partial AUC over the - range [0, max_fpr]. Should be a float between 0 and 1. - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When ``None``, DDP - will be used to perform the allgather - - Raises: - ValueError: - If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``. - ValueError: - If ``max_fpr`` is not a ``float`` in the range ``(0, 1]``. - RuntimeError: - If ``PyTorch version`` is ``below 1.6`` since max_fpr requires ``torch.bucketize`` - which is not available below 1.6. - ValueError: - If the mode of data (binary, multi-label, multi-class) changes between batches. - - Example (binary case): - - >>> from pytorch_lightning.metrics import AUROC - >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) - >>> target = torch.tensor([0, 0, 1, 1, 1]) - >>> auroc = AUROC(pos_label=1) - >>> auroc(preds, target) - tensor(0.5000) - - Example (multiclass case): - - >>> from pytorch_lightning.metrics import AUROC - >>> preds = torch.tensor([[0.90, 0.05, 0.05], - ... [0.05, 0.90, 0.05], - ... [0.05, 0.05, 0.90], - ... [0.85, 0.05, 0.10], - ... [0.10, 0.10, 0.80]]) - >>> target = torch.tensor([0, 1, 1, 2, 2]) - >>> auroc = AUROC(num_classes=3) - >>> auroc(preds, target) - tensor(0.7778) - - """ +class AUROC(_AUROC): + @deprecated(target=_AUROC, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, num_classes: Optional[int] = None, @@ -111,74 +32,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.num_classes = num_classes - self.pos_label = pos_label - self.average = average - self.max_fpr = max_fpr - - allowed_average = (None, 'macro', 'weighted') - if self.average not in allowed_average: - raise ValueError( - f'Argument `average` expected to be one of the following: {allowed_average} but got {average}' - ) - - if self.max_fpr is not None: - if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1): - raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}") - - if LooseVersion(torch.__version__) < LooseVersion('1.6.0'): - raise RuntimeError( - '`max_fpr` argument requires `torch.bucketize` which is not available below PyTorch version 1.6' - ) - - self.mode = None - self.add_state("preds", default=[], dist_reduce_fx=None) - self.add_state("target", default=[], dist_reduce_fx=None) - - rank_zero_warn( - 'Metric `AUROC` will save all targets and predictions in buffer.' - ' For large datasets this may lead to large memory footprint.' - ) - - def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. + This implementation refers to :class:`~torchmetrics.AUROC`. - Args: - preds: Predictions from model (probabilities, or labels) - target: Ground truth labels - """ - preds, target, mode = _auroc_update(preds, target) - - self.preds.append(preds) - self.target.append(target) - - if self.mode is not None and self.mode != mode: - raise ValueError( - 'The mode of data (binary, multi-label, multi-class) should be constant, but changed' - f' between batches from {self.mode} to {mode}' - ) - self.mode = mode - - def compute(self) -> torch.Tensor: - """ - Computes AUROC based on inputs passed in to ``update`` previously. + .. deprecated:: + Use :class:`~torchmetrics.AUROC`. Will be removed in v1.5.0. """ - preds = torch.cat(self.preds, dim=0) - target = torch.cat(self.target, dim=0) - return _auroc_compute( - preds, - target, - self.mode, - self.num_classes, - self.pos_label, - self.average, - self.max_fpr, - ) diff --git a/tests/deprecated_api/test_remove_1-5_metrics.py b/tests/deprecated_api/test_remove_1-5_metrics.py index 239241dfac2ed..96e97d0b6c52e 100644 --- a/tests/deprecated_api/test_remove_1-5_metrics.py +++ b/tests/deprecated_api/test_remove_1-5_metrics.py @@ -16,7 +16,7 @@ import pytest import torch -from pytorch_lightning.metrics import Accuracy, MetricCollection +from pytorch_lightning.metrics import Accuracy, MetricCollection, AUC, AUROC from pytorch_lightning.metrics.functional.accuracy import accuracy from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot @@ -38,16 +38,6 @@ def test_v1_5_metrics_utils(): assert torch.equal(to_categorical(x), torch.Tensor([1, 0]).to(int)) -def test_v1_5_metric_accuracy(): - accuracy.warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert accuracy(preds=torch.tensor([0, 1]), target=torch.tensor([0, 1])) == torch.tensor(1.) - - Accuracy.__init__.warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - Accuracy() - - def test_v1_5_metrics_collection(): target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) @@ -59,3 +49,27 @@ def test_v1_5_metrics_collection(): ): metrics = MetricCollection([Accuracy()]) assert metrics(preds, target) == {'Accuracy': torch.tensor(0.1250)} + + +def test_v1_5_metric_accuracy(): + accuracy.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert accuracy(preds=torch.tensor([0, 1]), target=torch.tensor([0, 1])) == torch.tensor(1.) + + Accuracy.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + Accuracy() + + +def test_v1_5_metric_auc(): + AUC.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + AUC() + + AUROC.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + AUROC() + + # auc.warned = False + # with pytest.deprecated_call(match='It will be removed in v1.5.0'): + # assert accuracy(preds=torch.tensor([0, 1]), target=torch.tensor([0, 1])) == torch.tensor(1.) From d4c549bdf9b6822bcaee72199d67e222b1aede8f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 17 Mar 2021 23:31:45 +0100 Subject: [PATCH 02/10] func: auc auroc --- pytorch_lightning/metrics/functional/auc.py | 66 +------ pytorch_lightning/metrics/functional/auroc.py | 179 +----------------- .../deprecated_api/test_remove_1-5_metrics.py | 18 +- 3 files changed, 26 insertions(+), 237 deletions(-) diff --git a/pytorch_lightning/metrics/functional/auc.py b/pytorch_lightning/metrics/functional/auc.py index cc5c9cf889b7a..7cd3457789bf7 100644 --- a/pytorch_lightning/metrics/functional/auc.py +++ b/pytorch_lightning/metrics/functional/auc.py @@ -11,71 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple - import torch -from torchmetrics.utilities.data import _stable_1d_sort - - -def _auc_update(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - if x.ndim > 1 or y.ndim > 1: - raise ValueError( - f'Expected both `x` and `y` tensor to be 1d, but got' - f' tensors with dimention {x.ndim} and {y.ndim}' - ) - if x.numel() != y.numel(): - raise ValueError( - f'Expected the same number of elements in `x` and `y`' - f' tensor but received {x.numel()} and {y.numel()}' - ) - return x, y - - -def _auc_compute(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor: - if reorder: - x, x_idx = _stable_1d_sort(x) - y = y[x_idx] +from torchmetrics.functional import auc as _auc - dx = x[1:] - x[:-1] - if (dx < 0).any(): - if (dx <= 0).all(): - direction = -1. - else: - raise ValueError( - "The `x` tensor is neither increasing or decreasing." - " Try setting the reorder argument to `True`." - ) - else: - direction = 1. - return direction * torch.trapz(y, x) +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_auc, ver_deprecate="1.3.0", ver_remove="1.5.0") def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor: """ - Computes Area Under the Curve (AUC) using the trapezoidal rule - - Args: - x: x-coordinates - y: y-coordinates - reorder: if True, will reorder the arrays - - Return: - Tensor containing AUC score (float) - - Raises: - ValueError: - If both ``x`` and ``y`` tensors are not ``1d``. - ValueError: - If both ``x`` and ``y`` don't have the same numnber of elements. - ValueError: - If ``x`` tesnsor is neither increasing or decreasing. - - Example: - >>> from pytorch_lightning.metrics.functional import auc - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> auc(x, y) - tensor(4.) + .. deprecated:: + Use :func:`torchmetrics.functional.auc`. Will be removed in v1.5.0. """ - x, y = _auc_update(x, y) - return _auc_compute(x, y, reorder=reorder) diff --git a/pytorch_lightning/metrics/functional/auroc.py b/pytorch_lightning/metrics/functional/auroc.py index e772b5050260c..ea29146febcdb 100644 --- a/pytorch_lightning/metrics/functional/auroc.py +++ b/pytorch_lightning/metrics/functional/auroc.py @@ -11,130 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from distutils.version import LooseVersion from typing import Optional, Sequence, Tuple import torch -from torchmetrics.classification.checks import _input_format_classification -from torchmetrics.utilities.enums import DataType +from torchmetrics.functional import auroc as _auroc -from pytorch_lightning.metrics.functional.auc import auc -from pytorch_lightning.metrics.functional.roc import roc -from pytorch_lightning.utilities import LightningEnum - - -class AverageMethods(LightningEnum): - """ Type of averages """ - MACRO = 'macro' - WEIGHTED = 'weighted' - NONE = None - - -def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, str]: - # use _input_format_classification for validating the input and get the mode of data - _, _, mode = _input_format_classification(preds, target) - - if mode == 'multi class multi dim': - n_classes = preds.shape[1] - preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1) - target = target.flatten() - if mode == 'multi-label' and preds.ndim > 2: - n_classes = preds.shape[1] - preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1) - target = target.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1) - - return preds, target, mode - - -def _auroc_compute( - preds: torch.Tensor, - target: torch.Tensor, - mode: str, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - average: Optional[str] = 'macro', - max_fpr: Optional[float] = None, - sample_weights: Optional[Sequence] = None, -) -> torch.Tensor: - # binary mode override num_classes - if mode == 'binary': - num_classes = 1 - - # check max_fpr parameter - if max_fpr is not None: - if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1): - raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}") - - if LooseVersion(torch.__version__) < LooseVersion('1.6.0'): - raise RuntimeError( - "`max_fpr` argument requires `torch.bucketize` which" - " is not available below PyTorch version 1.6" - ) - - # max_fpr parameter is only support for binary - if mode != 'binary': - raise ValueError( - f"Partial AUC computation not available in" - f" multilabel/multiclass setting, 'max_fpr' must be" - f" set to `None`, received `{max_fpr}`." - ) - - # calculate fpr, tpr - if mode == 'multi-label': - # for multilabel we iteratively evaluate roc in a binary fashion - output = [ - roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights) - for i in range(num_classes) - ] - fpr = [o[0] for o in output] - tpr = [o[1] for o in output] - else: - fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights) - - # calculate standard roc auc score - if max_fpr is None or max_fpr == 1: - if num_classes != 1: - # calculate auc scores per class - auc_scores = [auc(x, y) for x, y in zip(fpr, tpr)] - - # calculate average - if average == AverageMethods.NONE: - return auc_scores - elif average == AverageMethods.MACRO: - return torch.mean(torch.stack(auc_scores)) - elif average == AverageMethods.WEIGHTED: - if mode == DataType.MULTILABEL: - support = torch.sum(target, dim=0) - else: - support = torch.bincount(target.flatten(), minlength=num_classes) - return torch.sum(torch.stack(auc_scores) * support / support.sum()) - - allowed_average = [e.value for e in AverageMethods] - raise ValueError( - f"Argument `average` expected to be one of the following:" - f" {allowed_average} but got {average}" - ) - - return auc(fpr, tpr) - - max_fpr = torch.tensor(max_fpr, device=fpr.device) - # Add a single point at max_fpr and interpolate its tpr value - stop = torch.bucketize(max_fpr, fpr, out_int32=True, right=True) - weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1]) - interp_tpr = torch.lerp(tpr[stop - 1], tpr[stop], weight) - tpr = torch.cat([tpr[:stop], interp_tpr.view(1)]) - fpr = torch.cat([fpr[:stop], max_fpr.view(1)]) - - # Compute partial AUC - partial_auc = auc(fpr, tpr) - - # McClish correction: standardize result to be 0.5 if non-discriminant - # and 1 if maximal - min_area = 0.5 * max_fpr**2 - max_area = max_fpr - return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area)) +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_auroc, ver_deprecate="1.3.0", ver_remove="1.5.0") def auroc( preds: torch.Tensor, target: torch.Tensor, @@ -144,59 +29,7 @@ def auroc( max_fpr: Optional[float] = None, sample_weights: Optional[Sequence] = None, ) -> torch.Tensor: - """ Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC) - `_ - - Args: - preds: predictions from model (logits or probabilities) - target: Ground truth labels - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - average: - - ``'macro'`` computes metric for each class and uniformly averages them - - ``'weighted'`` computes metric for each class and does a weighted-average, - where each class is weighted by their support (accounts for class imbalance) - - ``None`` computes and returns the metric per class - max_fpr: - If not ``None``, calculates standardized partial AUC over the - range [0, max_fpr]. Should be a float between 0 and 1. - sample_weight: sample weights for each data point - - Raises: - ValueError: - If ``max_fpr`` is not a ``float`` in the range ``(0, 1]``. - RuntimeError: - If ``PyTorch version`` is ``below 1.6`` since max_fpr requires `torch.bucketize` - which is not available below 1.6. - ValueError: - If ``max_fpr`` is not set to ``None`` and the mode is ``not binary`` - since partial AUC computation is not available in multilabel/multiclass. - ValueError: - If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``. - - Example (binary case): - - >>> from pytorch_lightning.metrics.functional import auroc - >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) - >>> target = torch.tensor([0, 0, 1, 1, 1]) - >>> auroc(preds, target, pos_label=1) - tensor(0.5000) - - Example (multiclass case): - - >>> from pytorch_lightning.metrics.functional import auroc - >>> preds = torch.tensor([[0.90, 0.05, 0.05], - ... [0.05, 0.90, 0.05], - ... [0.05, 0.05, 0.90], - ... [0.85, 0.05, 0.10], - ... [0.10, 0.10, 0.80]]) - >>> target = torch.tensor([0, 1, 1, 2, 2]) - >>> auroc(preds, target, num_classes=3) - tensor(0.7778) """ - preds, target, mode = _auroc_update(preds, target) - return _auroc_compute(preds, target, mode, num_classes, pos_label, average, max_fpr, sample_weights) + .. deprecated:: + Use :func:`torchmetrics.functional.auroc`. Will be removed in v1.5.0. + """ diff --git a/tests/deprecated_api/test_remove_1-5_metrics.py b/tests/deprecated_api/test_remove_1-5_metrics.py index 96e97d0b6c52e..a0d816253c4f7 100644 --- a/tests/deprecated_api/test_remove_1-5_metrics.py +++ b/tests/deprecated_api/test_remove_1-5_metrics.py @@ -17,6 +17,7 @@ import torch from pytorch_lightning.metrics import Accuracy, MetricCollection, AUC, AUROC +from pytorch_lightning.metrics.functional import auc, auroc from pytorch_lightning.metrics.functional.accuracy import accuracy from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot @@ -53,6 +54,9 @@ def test_v1_5_metrics_collection(): def test_v1_5_metric_accuracy(): accuracy.warned = False + + preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) + target = torch.tensor([0, 0, 1, 1, 1]) with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert accuracy(preds=torch.tensor([0, 1]), target=torch.tensor([0, 1])) == torch.tensor(1.) @@ -70,6 +74,14 @@ def test_v1_5_metric_auc(): with pytest.deprecated_call(match='It will be removed in v1.5.0'): AUROC() - # auc.warned = False - # with pytest.deprecated_call(match='It will be removed in v1.5.0'): - # assert accuracy(preds=torch.tensor([0, 1]), target=torch.tensor([0, 1])) == torch.tensor(1.) + x = torch.tensor([0, 1, 2, 3]) + y = torch.tensor([0, 1, 2, 2]) + auc.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert auc(x, y) == torch.tensor(4.) + + preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) + target = torch.tensor([0, 0, 1, 1, 1]) + auroc.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert auroc(preds, target, pos_label=1) == torch.tensor(0.5) From cd46ac111359d025062e8a813386ca6e69de0f46 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 17 Mar 2021 23:34:18 +0100 Subject: [PATCH 03/10] format --- pytorch_lightning/utilities/deprecation.py | 2 +- tests/deprecated_api/test_remove_1-5_metrics.py | 6 +++--- tests/utilities/test_deprecation.py | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/utilities/deprecation.py b/pytorch_lightning/utilities/deprecation.py index 4460e5d070b10..f6591b4060b03 100644 --- a/pytorch_lightning/utilities/deprecation.py +++ b/pytorch_lightning/utilities/deprecation.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect from functools import wraps -from typing import Any, Callable, List, Tuple, Optional +from typing import Any, Callable, List, Optional, Tuple from pytorch_lightning.utilities import rank_zero_warn diff --git a/tests/deprecated_api/test_remove_1-5_metrics.py b/tests/deprecated_api/test_remove_1-5_metrics.py index a0d816253c4f7..cede1bc67c51a 100644 --- a/tests/deprecated_api/test_remove_1-5_metrics.py +++ b/tests/deprecated_api/test_remove_1-5_metrics.py @@ -16,7 +16,7 @@ import pytest import torch -from pytorch_lightning.metrics import Accuracy, MetricCollection, AUC, AUROC +from pytorch_lightning.metrics import Accuracy, AUC, AUROC, MetricCollection from pytorch_lightning.metrics.functional import auc, auroc from pytorch_lightning.metrics.functional.accuracy import accuracy from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot @@ -46,7 +46,7 @@ def test_v1_5_metrics_collection(): MetricCollection.__init__.warned = False with pytest.deprecated_call( match="`pytorch_lightning.metrics.metric.MetricCollection` was deprecated since v1.3.0 in favor" - " of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0." + " of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0." ): metrics = MetricCollection([Accuracy()]) assert metrics(preds, target) == {'Accuracy': torch.tensor(0.1250)} @@ -65,7 +65,7 @@ def test_v1_5_metric_accuracy(): Accuracy() -def test_v1_5_metric_auc(): +def test_v1_5_metric_auc_auroc(): AUC.__init__.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): AUC() diff --git a/tests/utilities/test_deprecation.py b/tests/utilities/test_deprecation.py index 2f54a77099701..42179f86b80ed 100644 --- a/tests/utilities/test_deprecation.py +++ b/tests/utilities/test_deprecation.py @@ -30,7 +30,7 @@ def dep3_sum(a, b=4): def test_deprecated_func(): with pytest.deprecated_call( match='`tests.utilities.test_deprecation.dep_sum` was deprecated since v0.1 in favor' - ' of `tests.utilities.test_deprecation.my_sum`. It will be removed in v0.5.' + ' of `tests.utilities.test_deprecation.my_sum`. It will be removed in v0.5.' ): assert dep_sum(2) == 7 @@ -41,7 +41,7 @@ def test_deprecated_func(): # and does not affect other functions with pytest.deprecated_call( match='`tests.utilities.test_deprecation.dep3_sum` was deprecated since v0.1 in favor' - ' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.' + ' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.' ): assert dep3_sum(2, 1) == 3 @@ -61,7 +61,7 @@ def test_deprecated_func_incomplete(): # does not affect other functions with pytest.deprecated_call( match='`tests.utilities.test_deprecation.dep2_sum` was deprecated since v0.1 in favor' - ' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.' + ' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.' ): assert dep2_sum(b=2, a=1) == 3 @@ -83,7 +83,7 @@ def __init__(self, c, d="efg"): def test_deprecated_class(): with pytest.deprecated_call( match='`tests.utilities.test_deprecation.PastCls` was deprecated since v0.2 in favor' - ' of `tests.utilities.test_deprecation.NewCls`. It will be removed in v0.4.' + ' of `tests.utilities.test_deprecation.NewCls`. It will be removed in v0.4.' ): past = PastCls(2) assert past.my_c == 2 From a357c093f3548d3f965c053165d92b229903d73c Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 17 Mar 2021 23:36:03 +0100 Subject: [PATCH 04/10] tests --- pytorch_lightning/metrics/functional/auroc.py | 2 +- tests/metrics/classification/test_auc.py | 64 -------- tests/metrics/classification/test_auroc.py | 142 ------------------ .../metrics/functional/test_classification.py | 48 ------ 4 files changed, 1 insertion(+), 255 deletions(-) delete mode 100644 tests/metrics/classification/test_auc.py delete mode 100644 tests/metrics/classification/test_auroc.py diff --git a/pytorch_lightning/metrics/functional/auroc.py b/pytorch_lightning/metrics/functional/auroc.py index ea29146febcdb..16058110175c5 100644 --- a/pytorch_lightning/metrics/functional/auroc.py +++ b/pytorch_lightning/metrics/functional/auroc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence import torch from torchmetrics.functional import auroc as _auroc diff --git a/tests/metrics/classification/test_auc.py b/tests/metrics/classification/test_auc.py deleted file mode 100644 index e902151ecffce..0000000000000 --- a/tests/metrics/classification/test_auc.py +++ /dev/null @@ -1,64 +0,0 @@ -from collections import namedtuple - -import numpy as np -import pytest -import torch -from sklearn.metrics import auc as _sk_auc - -from pytorch_lightning.metrics.classification.auc import AUC -from pytorch_lightning.metrics.functional.auc import auc -from tests.metrics.utils import MetricTester, NUM_BATCHES - -torch.manual_seed(42) - - -def sk_auc(x, y): - x = x.flatten() - y = y.flatten() - return _sk_auc(x, y) - - -Input = namedtuple('Input', ["x", "y"]) - -_examples = [] -# generate already ordered samples, sorted in both directions -for i in range(4): - x = np.random.randint(0, 5, (NUM_BATCHES * 8)) - y = np.random.randint(0, 5, (NUM_BATCHES * 8)) - idx = np.argsort(x, kind='stable') - x = x[idx] if i % 2 == 0 else x[idx[::-1]] - y = y[idx] if i % 2 == 0 else x[idx[::-1]] - x = x.reshape(NUM_BATCHES, 8) - y = y.reshape(NUM_BATCHES, 8) - _examples.append(Input(x=torch.tensor(x), y=torch.tensor(y))) - - -@pytest.mark.parametrize("x, y", _examples) -class TestAUC(MetricTester): - - @pytest.mark.parametrize("ddp", [False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_auc(self, x, y, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp=ddp, - preds=x, - target=y, - metric_class=AUC, - sk_metric=sk_auc, - dist_sync_on_step=dist_sync_on_step, - ) - - def test_auc_functional(self, x, y): - self.run_functional_metric_test(x, y, metric_functional=auc, sk_metric=sk_auc, metric_args={"reorder": False}) - - -@pytest.mark.parametrize(['x', 'y', 'expected'], [ - pytest.param([0, 1], [0, 1], 0.5), - pytest.param([1, 0], [0, 1], 0.5), - pytest.param([1, 0, 0], [0, 1, 1], 0.5), - pytest.param([0, 1], [1, 1], 1), - pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5), -]) -def test_auc(x, y, expected): - # Test Area Under Curve (AUC) computation - assert auc(torch.tensor(x), torch.tensor(y), reorder=True) == expected diff --git a/tests/metrics/classification/test_auroc.py b/tests/metrics/classification/test_auroc.py deleted file mode 100644 index 0affcb1010225..0000000000000 --- a/tests/metrics/classification/test_auroc.py +++ /dev/null @@ -1,142 +0,0 @@ -from distutils.version import LooseVersion -from functools import partial - -import pytest -import torch -from sklearn.metrics import roc_auc_score as sk_roc_auc_score - -from pytorch_lightning.metrics.classification.auroc import AUROC -from pytorch_lightning.metrics.functional.auroc import auroc -from tests.metrics.classification.inputs import _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob -from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES - -torch.manual_seed(42) - - -def _sk_auroc_binary_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_roc_auc_score(y_true=sk_target, y_score=sk_preds, average=average, max_fpr=max_fpr) - - -def _sk_auroc_multiclass_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -def _sk_auroc_multidim_multiclass_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -def _sk_auroc_multilabel_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.reshape(-1, num_classes).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -def _sk_auroc_multilabel_multidim_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [(_input_binary_prob.preds, _input_binary_prob.target, _sk_auroc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_auroc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_auroc_multidim_multiclass_prob, NUM_CLASSES), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_auroc_multilabel_prob, NUM_CLASSES), - (_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_auroc_multilabel_multidim_prob, NUM_CLASSES)] -) -@pytest.mark.parametrize("average", ['macro', 'weighted']) -@pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) -class TestAUROC(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, ddp, dist_sync_on_step): - # max_fpr different from None is not support in multi class - if max_fpr is not None and num_classes != 1: - pytest.skip('max_fpr parameter not support for multi class or multi label') - - # max_fpr only supported for torch v1.6 or higher - if max_fpr is not None and LooseVersion(torch.__version__) < LooseVersion('1.6.0'): - pytest.skip('requires torch v1.6 or higher to test max_fpr argument') - - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=AUROC, - sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr), - dist_sync_on_step=dist_sync_on_step, - metric_args={ - "num_classes": num_classes, - "average": average, - "max_fpr": max_fpr - }, - ) - - def test_auroc_functional(self, preds, target, sk_metric, num_classes, average, max_fpr): - # max_fpr different from None is not support in multi class - if max_fpr is not None and num_classes != 1: - pytest.skip('max_fpr parameter not support for multi class or multi label') - - # max_fpr only supported for torch v1.6 or higher - if max_fpr is not None and LooseVersion(torch.__version__) < LooseVersion('1.6.0'): - pytest.skip('requires torch v1.6 or higher to test max_fpr argument') - - self.run_functional_metric_test( - preds, - target, - metric_functional=auroc, - sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr), - metric_args={ - "num_classes": num_classes, - "average": average, - "max_fpr": max_fpr - }, - ) - - -def test_error_on_different_mode(): - """ test that an error is raised if the user pass in data of - different modes (binary, multi-label, multi-class) - """ - metric = AUROC() - # pass in multi-class data - metric.update(torch.randn(10, 5).softmax(dim=-1), torch.randint(0, 5, (10, ))) - with pytest.raises(ValueError, match=r"The mode of data.* should be constant.*"): - # pass in multi-label data - metric.update(torch.rand(10, 5), torch.randint(0, 2, (10, 5))) diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index bca50867dcb44..44109d40b2efa 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -1,59 +1,11 @@ import pytest import torch -from torchmetrics.utilities.data import get_num_classes, to_categorical, to_onehot from pytorch_lightning import seed_everything from pytorch_lightning.metrics.functional.classification import dice_score from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve -def test_onehot(): - test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) - expected = torch.stack([ - torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), - torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]) - ]) - - assert test_tensor.shape == (2, 5) - assert expected.shape == (2, 10, 5) - - onehot_classes = to_onehot(test_tensor, num_classes=10) - onehot_no_classes = to_onehot(test_tensor) - - assert torch.allclose(onehot_classes, onehot_no_classes) - - assert onehot_classes.shape == expected.shape - assert onehot_no_classes.shape == expected.shape - - assert torch.allclose(expected.to(onehot_no_classes), onehot_no_classes) - assert torch.allclose(expected.to(onehot_classes), onehot_classes) - - -def test_to_categorical(): - test_tensor = torch.stack([ - torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), - torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]) - ]).to(torch.float) - - expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) - assert expected.shape == (2, 5) - assert test_tensor.shape == (2, 10, 5) - - result = to_categorical(test_tensor) - - assert result.shape == expected.shape - assert torch.allclose(result, expected.to(result.dtype)) - - -@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [ - pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10), - pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), - pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), -]) -def test_get_num_classes(pred, target, num_classes, expected_num_classes): - assert get_num_classes(pred, target, num_classes) == expected_num_classes - - @pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [ pytest.param(1, 1., 42), pytest.param(None, 1., 42), From d6e5d58465d5c61236519514646f272698d5b820 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 17 Mar 2021 23:36:27 +0100 Subject: [PATCH 05/10] . --- tests/deprecated_api/test_remove_1-5_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-5_metrics.py b/tests/deprecated_api/test_remove_1-5_metrics.py index cede1bc67c51a..93f613e695755 100644 --- a/tests/deprecated_api/test_remove_1-5_metrics.py +++ b/tests/deprecated_api/test_remove_1-5_metrics.py @@ -58,7 +58,7 @@ def test_v1_5_metric_accuracy(): preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) target = torch.tensor([0, 0, 1, 1, 1]) with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert accuracy(preds=torch.tensor([0, 1]), target=torch.tensor([0, 1])) == torch.tensor(1.) + assert accuracy(preds, target) == torch.tensor(1.) Accuracy.__init__.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): From 85a81d8cdaa358bc43387d492f0ac90db976ea58 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 17 Mar 2021 23:39:01 +0100 Subject: [PATCH 06/10] chlog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index cc78de0f9c0c1..4d397079bb072 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,6 +77,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). [#6515](https://github.com/PyTorchLightning/pytorch-lightning/pull/6515), + [#6572](https://github.com/PyTorchLightning/pytorch-lightning/pull/6572), + ) From c2f8a7a301960db06a8ccbe7d2063ac7a1f610c0 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 18 Mar 2021 00:25:08 +0100 Subject: [PATCH 07/10] roc --- .../metrics/classification/roc.py | 128 ++---------------- pytorch_lightning/metrics/functional/roc.py | 125 +---------------- tests/metrics/classification/test_roc.py | 99 -------------- .../test_remove_1-5_metrics.py | 12 +- 4 files changed, 25 insertions(+), 339 deletions(-) delete mode 100644 tests/metrics/classification/test_roc.py rename tests/{deprecated_api => metrics}/test_remove_1-5_metrics.py (90%) diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py index 30ca0b4fe6925..2298287a53ff7 100644 --- a/pytorch_lightning/metrics/classification/roc.py +++ b/pytorch_lightning/metrics/classification/roc.py @@ -11,79 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional -import torch -from torchmetrics import Metric +from torchmetrics import ROC as _ROC -from pytorch_lightning.metrics.functional.roc import _roc_compute, _roc_update -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.deprecation import deprecated -class ROC(Metric): - """ - Computes the Receiver Operating Characteristic (ROC). Works for both - binary and multiclass problems. In the case of multiclass, the values will - be calculated based on a one-vs-the-rest approach. - - Forward accepts - - - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor - with probabilities, where C is the number of classes. - - - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels - - Args: - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example (binary case): - - >>> from pytorch_lightning.metrics import ROC - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 1]) - >>> roc = ROC(pos_label=1) - >>> fpr, tpr, thresholds = roc(pred, target) - >>> fpr - tensor([0., 0., 0., 0., 1.]) - >>> tpr - tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) - >>> thresholds - tensor([4, 3, 2, 1, 0]) - - Example (multiclass case): - - >>> from pytorch_lightning.metrics import ROC - >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05], - ... [0.05, 0.05, 0.05, 0.75]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> roc = ROC(num_classes=4) - >>> fpr, tpr, thresholds = roc(pred, target) - >>> fpr - [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])] - >>> tpr - [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])] - >>> thresholds # doctest: +NORMALIZE_WHITESPACE - [tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500])] - - """ +class ROC(_ROC): + @deprecated(target=_ROC, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, num_classes: Optional[int] = None, @@ -92,56 +29,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) - - self.num_classes = num_classes - self.pos_label = pos_label - - self.add_state("preds", default=[], dist_reduce_fx=None) - self.add_state("target", default=[], dist_reduce_fx=None) - - rank_zero_warn( - 'Metric `ROC` will save all targets and predictions in buffer.' - ' For large datasets this may lead to large memory footprint.' - ) - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - preds, target, num_classes, pos_label = _roc_update(preds, target, self.num_classes, self.pos_label) - self.preds.append(preds) - self.target.append(target) - self.num_classes = num_classes - self.pos_label = pos_label + This implementation refers to :class:`~torchmetrics.ROC`. - def compute( - self - ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: - """ - Compute the receiver operating characteristic - - Returns: - 3-element tuple containing - - fpr: - tensor with false positive rates. - If multiclass, this is a list of such tensors, one for each class. - tpr: - tensor with true positive rates. - If multiclass, this is a list of such tensors, one for each class. - thresholds: - thresholds used for computing false- and true postive rates - """ - preds = torch.cat(self.preds, dim=0) - target = torch.cat(self.target, dim=0) - return _roc_compute(preds, target, self.num_classes, self.pos_label) + .. deprecated:: + Use :class:`~torchmetrics.ROC`. Will be removed in v1.5.0. + """ \ No newline at end of file diff --git a/pytorch_lightning/metrics/functional/roc.py b/pytorch_lightning/metrics/functional/roc.py index 030c974365807..768d11f5dcc3f 100644 --- a/pytorch_lightning/metrics/functional/roc.py +++ b/pytorch_lightning/metrics/functional/roc.py @@ -14,69 +14,12 @@ from typing import List, Optional, Sequence, Tuple, Union import torch +from torchmetrics.functional import roc as _roc -from pytorch_lightning.metrics.functional.precision_recall_curve import ( - _binary_clf_curve, - _precision_recall_curve_update, -) - - -def _roc_update( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor, int, int]: - return _precision_recall_curve_update(preds, target, num_classes, pos_label) - - -def _roc_compute( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - pos_label: int, - sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: - - if num_classes == 1: - fps, tps, thresholds = _binary_clf_curve( - preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label - ) - # Add an extra threshold position - # to make sure that the curve starts at (0, 0) - tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps]) - fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps]) - thresholds = torch.cat([thresholds[0][None] + 1, thresholds]) - - if fps[-1] <= 0: - raise ValueError("No negative samples in targets, false positive value should be meaningless") - fpr = fps / fps[-1] - - if tps[-1] <= 0: - raise ValueError("No positive samples in targets, true positive value should be meaningless") - tpr = tps / tps[-1] - - return fpr, tpr, thresholds - - # Recursively call per class - fpr, tpr, thresholds = [], [], [] - for c in range(num_classes): - preds_c = preds[:, c] - res = roc( - preds=preds_c, - target=target, - num_classes=1, - pos_label=c, - sample_weights=sample_weights, - ) - fpr.append(res[0]) - tpr.append(res[1]) - thresholds.append(res[2]) - - return fpr, tpr, thresholds +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_roc, ver_deprecate="1.3.0", ver_remove="1.5.0") def roc( preds: torch.Tensor, target: torch.Tensor, @@ -84,64 +27,8 @@ def roc( pos_label: Optional[int] = None, sample_weights: Optional[Sequence] = None, ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: + List[torch.Tensor]],]: """ - Computes the Receiver Operating Characteristic (ROC). - - Args: - preds: predictions from model (logits or probabilities) - target: ground truth values - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - sample_weights: sample weights for each data point - - Returns: - 3-element tuple containing - - fpr: - tensor with false positive rates. - If multiclass, this is a list of such tensors, one for each class. - tpr: - tensor with true positive rates. - If multiclass, this is a list of such tensors, one for each class. - thresholds: - thresholds used for computing false- and true postive rates - - Example (binary case): - - >>> from pytorch_lightning.metrics.functional import roc - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 1]) - >>> fpr, tpr, thresholds = roc(pred, target, pos_label=1) - >>> fpr - tensor([0., 0., 0., 0., 1.]) - >>> tpr - tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) - >>> thresholds - tensor([4, 3, 2, 1, 0]) - - Example (multiclass case): - - >>> from pytorch_lightning.metrics.functional import roc - >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05], - ... [0.05, 0.05, 0.05, 0.75]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> fpr, tpr, thresholds = roc(pred, target, num_classes=4) - >>> fpr - [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])] - >>> tpr - [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])] - >>> thresholds # doctest: +NORMALIZE_WHITESPACE - [tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500])] + .. deprecated:: + Use :func:`torchmetrics.functional.roc`. Will be removed in v1.5.0. """ - preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes, pos_label) - return _roc_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/tests/metrics/classification/test_roc.py b/tests/metrics/classification/test_roc.py deleted file mode 100644 index 46a23322ca1c0..0000000000000 --- a/tests/metrics/classification/test_roc.py +++ /dev/null @@ -1,99 +0,0 @@ -from functools import partial - -import numpy as np -import pytest -import torch -from sklearn.metrics import roc_curve as sk_roc_curve - -from pytorch_lightning.metrics.classification.roc import ROC -from pytorch_lightning.metrics.functional.roc import roc -from tests.metrics.classification.inputs import _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES - -torch.manual_seed(42) - - -def _sk_roc_curve(y_true, probas_pred, num_classes=1): - """ Adjusted comparison function that can also handles multiclass """ - if num_classes == 1: - return sk_roc_curve(y_true, probas_pred, drop_intermediate=False) - - fpr, tpr, thresholds = [], [], [] - for i in range(num_classes): - y_true_temp = np.zeros_like(y_true) - y_true_temp[y_true == i] = 1 - res = sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False) - fpr.append(res[0]) - tpr.append(res[1]) - thresholds.append(res[2]) - return fpr, tpr, thresholds - - -def _sk_roc_binary_prob(preds, target, num_classes=1): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_roc_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() - - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), - ] -) -class TestROC(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_roc(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=ROC, - sk_metric=partial(sk_metric, num_classes=num_classes), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes} - ) - - def test_roc_functional(self, preds, target, sk_metric, num_classes): - self.run_functional_metric_test( - preds, - target, - metric_functional=roc, - sk_metric=partial(sk_metric, num_classes=num_classes), - metric_args={"num_classes": num_classes}, - ) - - -@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [ - pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]), - pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]), - pytest.param([1, 1], [1, 0], [0, 1], [0, 1]), - pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]), - pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]), -]) -def test_roc_curve(pred, target, expected_tpr, expected_fpr): - fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target)) - - assert fpr.shape == tpr.shape - assert fpr.size(0) == thresh.size(0) - assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr)) - assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr)) diff --git a/tests/deprecated_api/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py similarity index 90% rename from tests/deprecated_api/test_remove_1-5_metrics.py rename to tests/metrics/test_remove_1-5_metrics.py index 93f613e695755..3a22be2936a53 100644 --- a/tests/deprecated_api/test_remove_1-5_metrics.py +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -16,8 +16,8 @@ import pytest import torch -from pytorch_lightning.metrics import Accuracy, AUC, AUROC, MetricCollection -from pytorch_lightning.metrics.functional import auc, auroc +from pytorch_lightning.metrics import Accuracy, AUC, AUROC, MetricCollection, ROC +from pytorch_lightning.metrics.functional import auc, auroc, roc from pytorch_lightning.metrics.functional.accuracy import accuracy from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot @@ -70,6 +70,10 @@ def test_v1_5_metric_auc_auroc(): with pytest.deprecated_call(match='It will be removed in v1.5.0'): AUC() + ROC.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + ROC() + AUROC.__init__.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): AUROC() @@ -82,6 +86,10 @@ def test_v1_5_metric_auc_auroc(): preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) target = torch.tensor([0, 0, 1, 1, 1]) + roc.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert roc(preds, target, pos_label=1) == torch.tensor(0.5) + auroc.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert auroc(preds, target, pos_label=1) == torch.tensor(0.5) From b3ef98d67df626afeb67f5242a44bcbb9dfb8d80 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 18 Mar 2021 00:27:17 +0100 Subject: [PATCH 08/10] format --- pytorch_lightning/metrics/functional/roc.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/metrics/functional/roc.py b/pytorch_lightning/metrics/functional/roc.py index 768d11f5dcc3f..1ca534eb6a5be 100644 --- a/pytorch_lightning/metrics/functional/roc.py +++ b/pytorch_lightning/metrics/functional/roc.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import List, Optional, Sequence, Tuple, Union -import torch +from torch import Tensor from torchmetrics.functional import roc as _roc from pytorch_lightning.utilities.deprecation import deprecated @@ -21,13 +21,12 @@ @deprecated(target=_roc, ver_deprecate="1.3.0", ver_remove="1.5.0") def roc( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: Optional[int] = None, pos_label: Optional[int] = None, sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]],]: +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """ .. deprecated:: Use :func:`torchmetrics.functional.roc`. Will be removed in v1.5.0. From ddbf406de0757da3f82d6b4b526c1c7c7a1ea249 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 18 Mar 2021 00:50:00 +0100 Subject: [PATCH 09/10] fix --- tests/metrics/test_remove_1-5_metrics.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py index 3a22be2936a53..f284a9d85bc47 100644 --- a/tests/metrics/test_remove_1-5_metrics.py +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -55,10 +55,10 @@ def test_v1_5_metrics_collection(): def test_v1_5_metric_accuracy(): accuracy.warned = False - preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) + preds = torch.tensor([0, 0, 1, 0, 1]) target = torch.tensor([0, 0, 1, 1, 1]) with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert accuracy(preds, target) == torch.tensor(1.) + assert accuracy(preds, target) == torch.tensor(0.8) Accuracy.__init__.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): @@ -84,12 +84,17 @@ def test_v1_5_metric_auc_auroc(): with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert auc(x, y) == torch.tensor(4.) - preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) - target = torch.tensor([0, 0, 1, 1, 1]) + preds = torch.tensor([0, 1, 2, 3]) + target = torch.tensor([0, 1, 1, 1]) roc.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert roc(preds, target, pos_label=1) == torch.tensor(0.5) + fpr, tpr, thresholds = roc(preds, target, pos_label=1) + assert torch.equal(fpr, torch.tensor([0., 0., 0., 0., 1.])) + assert torch.allclose(tpr, torch.tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]), atol=1e-4) + assert torch.equal(thresholds, torch.tensor([4, 3, 2, 1, 0])) + preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) + target = torch.tensor([0, 0, 1, 1, 1]) auroc.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert auroc(preds, target, pos_label=1) == torch.tensor(0.5) + assert auroc(preds, target) == torch.tensor(0.5) From f86111bc2b9e19c21f3ab21b5c0c088b622f645c Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 18 Mar 2021 01:00:36 +0100 Subject: [PATCH 10/10] flake8 --- pytorch_lightning/metrics/classification/roc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py index 2298287a53ff7..5850913f61ed9 100644 --- a/pytorch_lightning/metrics/classification/roc.py +++ b/pytorch_lightning/metrics/classification/roc.py @@ -34,4 +34,4 @@ def __init__( .. deprecated:: Use :class:`~torchmetrics.ROC`. Will be removed in v1.5.0. - """ \ No newline at end of file + """