From 4a3f9069cb8fabf588f08361225d129cc52c933e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 11 Dec 2020 22:11:21 +0100 Subject: [PATCH] add back compatibility for deprecated metrics 1/n (#5067) * add back compatibility for metrics * tests * Add deprecated metric utility functions back to functional (#5062) * add back *deprecated* metric utility functions to functional * pep * pep * suggestions * move Co-authored-by: Jirka Borovec * more * fix * import * docs * tests * fix Co-authored-by: Teddy Koker --- CHANGELOG.md | 4 +- .../metrics/functional/__init__.py | 5 +- .../metrics/functional/classification.py | 217 +++++++++++++++++- .../metrics/functional/reduction.py | 35 +++ tests/test_deprecated.py | 57 +++++ 5 files changed, 306 insertions(+), 12 deletions(-) create mode 100644 pytorch_lightning/metrics/functional/reduction.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e188256ec600..f078349ef3665 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -85,9 +85,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- Removed `multiclass_roc` and `multiclass_precision_recall_curve`, use `roc` and `precision_recall_curve` instead ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549)) - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903)) -- WandbLogger does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648)) +- `WandbLogger` does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648)) - Changed `automatic_optimization` to be a model attribute ([#4602](https://github.com/PyTorchLightning/pytorch-lightning/pull/4602)) - Changed `Simple Profiler` report to order by percentage time spent + num calls ([#4880](https://github.com/PyTorchLightning/pytorch-lightning/pull/4880)) - Simplify optimization Logic ([#4984](https://github.com/PyTorchLightning/pytorch-lightning/pull/4984)) @@ -105,6 +104,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed - Removed `reorder` parameter of the `auc` metric ([#5004](https://github.com/PyTorchLightning/pytorch-lightning/pull/5004)) +- Removed `multiclass_roc` and `multiclass_precision_recall_curve`, use `roc` and `precision_recall_curve` instead ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549)) ### Fixed diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index e13242e40b0ac..fe9c93525aa69 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -17,13 +17,16 @@ auc, auroc, dice_score, + get_num_classes, + iou, multiclass_auroc, precision, precision_recall, recall, stat_scores, stat_scores_multiple_classes, - iou, + to_categorical, + to_onehot, ) from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # TODO: unify metrics between class and functional, add below diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 4fd0b9258f4a1..cc4318ac63969 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -16,12 +16,73 @@ import torch -from pytorch_lightning.metrics.functional.roc import roc -from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve -from pytorch_lightning.metrics.utils import to_categorical, get_num_classes, reduce, class_reduce +from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap +from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve, precision_recall_curve as __prc +from pytorch_lightning.metrics.functional.roc import roc as __roc +from pytorch_lightning.metrics.utils import ( + to_categorical as __tc, + to_onehot as __to, + get_num_classes as __gnc, + reduce, + class_reduce, +) from pytorch_lightning.utilities import rank_zero_warn +def to_onehot( + tensor: torch.Tensor, + num_classes: Optional[int] = None, +) -> torch.Tensor: + """ + Converts a dense label tensor to one-hot format + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_onehot` + """ + rank_zero_warn( + "This `to_onehot` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.utils import to_onehot`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __to(tensor, num_classes) + + +def to_categorical( + tensor: torch.Tensor, + argmax_dim: int = 1 +) -> torch.Tensor: + """ + Converts a tensor of probabilities to a dense label tensor + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_categorical` + + """ + rank_zero_warn( + "This `to_categorical` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.utils import to_categorical`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __tc(tensor) + + +def get_num_classes( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, +) -> int: + """ + Calculates the number of classes for a given prediction and target tensor. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.get_num_classes` + + """ + rank_zero_warn( + "This `get_num_classes` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.utils import get_num_classes`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __gnc(pred,target, num_classes) + + def stat_scores( pred: torch.Tensor, target: torch.Tensor, @@ -333,8 +394,79 @@ def recall( num_classes=num_classes, class_reduction=class_reduction)[1] +# todo: remove in 1.3 +def roc( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc` + """ + rank_zero_warn( + "This `multiclass_roc` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.roc import roc`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) + + +# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py +def _roc( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc` + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 1, 1]) + >>> fpr, tpr, thresholds = _roc(x, y) + >>> fpr + tensor([0., 0., 0., 0., 1.]) + >>> tpr + tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) + >>> thresholds + tensor([4, 3, 2, 1, 0]) + + """ + rank_zero_warn( + "This `multiclass_roc` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.roc import roc`." + " It will be removed in v1.3.0", DeprecationWarning + ) + fps, tps, thresholds = _binary_clf_curve(pred, target, sample_weights=sample_weight, pos_label=pos_label) + + # Add an extra threshold position + # to make sure that the curve starts at (0, 0) + tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps]) + fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps]) + thresholds = torch.cat([thresholds[0][None] + 1, thresholds]) + + if fps[-1] <= 0: + raise ValueError("No negative samples in targets, false positive value should be meaningless") + + fpr = fps / fps[-1] + + if tps[-1] <= 0: + raise ValueError("No positive samples in targets, true positive value should be meaningless") + + tpr = tps / tps[-1] + + return fpr, tpr, thresholds + + # TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py -def __multiclass_roc( +def multiclass_roc( pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None, @@ -343,7 +475,7 @@ def __multiclass_roc( """ Computes the Receiver Operating Characteristic (ROC) for multiclass predictors. - .. warning:: Deprecated + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc` Args: pred: estimated probabilities @@ -362,19 +494,24 @@ def __multiclass_roc( ... [0.05, 0.05, 0.85, 0.05], ... [0.05, 0.05, 0.05, 0.85]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> __multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE + >>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500]))) """ + rank_zero_warn( + "This `multiclass_roc` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.roc import roc`." + " It will be removed in v1.3.0", DeprecationWarning + ) num_classes = get_num_classes(pred, target, num_classes) class_roc_vals = [] for c in range(num_classes): pred_c = pred[:, c] - class_roc_vals.append(roc(preds=pred_c, target=target, sample_weights=sample_weight, pos_label=c, num_classes=1)) + class_roc_vals.append(_roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c)) return tuple(class_roc_vals) @@ -472,7 +609,7 @@ def auroc( @auc_decorator() def _auroc(pred, target, sample_weight, pos_label): - return roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label, num_classes=1) + return _roc(pred, target, sample_weight, pos_label) return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) @@ -525,7 +662,7 @@ def multiclass_auroc( @multiclass_auc_decorator() def _multiclass_auroc(pred, target, sample_weight, num_classes): - return __multiclass_roc(pred=pred, target=target, sample_weight=sample_weight, num_classes=num_classes) + return multiclass_roc(pred, target, sample_weight, num_classes) class_aurocs = _multiclass_auroc(pred=pred, target=target, sample_weight=sample_weight, @@ -672,3 +809,65 @@ def iou( ]) return reduce(scores, reduction=reduction) + + +# todo: remove in 1.3 +def precision_recall_curve( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., +): + """ + Computes precision-recall pairs for different thresholds. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve` + """ + rank_zero_warn( + "This `precision_recall_curve` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __prc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) + + +# todo: remove in 1.3 +def multiclass_precision_recall_curve( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + num_classes: Optional[int] = None, +): + """ + Computes precision-recall pairs for different thresholds given a multiclass scores. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve` + """ + rank_zero_warn( + "This `multiclass_precision_recall_curve` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`." + " It will be removed in v1.3.0", DeprecationWarning + ) + if num_classes is None: + num_classes = get_num_classes(pred, target, num_classes) + return __prc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes) + + +# todo: remove in 1.3 +def average_precision( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., +): + """ + Compute average precision from prediction scores. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.average_precision.average_precision` + """ + rank_zero_warn( + "This `average_precision` was deprecated in v1.1.0 in favor of" + " `pytorch_lightning.metrics.functional.average_precision import average_precision`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __ap(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) diff --git a/pytorch_lightning/metrics/functional/reduction.py b/pytorch_lightning/metrics/functional/reduction.py new file mode 100644 index 0000000000000..c116b16d363a9 --- /dev/null +++ b/pytorch_lightning/metrics/functional/reduction.py @@ -0,0 +1,35 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from pytorch_lightning.metrics.utils import reduce as __reduce, class_reduce as __cr +from pytorch_lightning.utilities import rank_zero_warn + + +def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: + rank_zero_warn( + "This `reduce` was deprecated in v1.1.0 in favor of" + " `pytorch_lightning.metrics.utils import reduce`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __reduce(to_reduce=to_reduce, reduction=reduction) + + +def class_reduce(num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = 'none'): + rank_zero_warn( + "This `class_reduce` was deprecated in v1.1.0 in favor of" + " `pytorch_lightning.metrics.utils import class_reduce`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __cr(num=num, denom=denom, weights=weights, class_reduction=class_reduction) diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index f549de1f4d71e..58384fed0cd4f 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -51,6 +51,63 @@ def __init__(self, hparams): DeprecatedHparamsModel({}) +def test_tbd_remove_in_v1_3_0_metrics(): + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import to_onehot + to_onehot(torch.tensor([1, 2, 3])) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import to_categorical + to_categorical(torch.tensor([[0.2, 0.5], [0.9, 0.1]])) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import get_num_classes + get_num_classes(pred=torch.tensor([0, 1]), target=torch.tensor([1, 1])) + + x_binary = torch.tensor([0, 1, 2, 3]) + y_binary = torch.tensor([0, 1, 2, 3]) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import roc + roc(pred=x_binary, target=y_binary) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import _roc + _roc(pred=x_binary, target=y_binary) + + x_multy = torch.tensor([[0.85, 0.05, 0.05, 0.05], + [0.05, 0.85, 0.05, 0.05], + [0.05, 0.05, 0.85, 0.05], + [0.05, 0.05, 0.05, 0.85]]) + y_multy = torch.tensor([0, 1, 3, 2]) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import multiclass_roc + multiclass_roc(pred=x_multy, target=y_multy) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import average_precision + average_precision(pred=x_binary, target=y_binary) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import precision_recall_curve + precision_recall_curve(pred=x_binary, target=y_binary) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import multiclass_precision_recall_curve + multiclass_precision_recall_curve(pred=x_multy, target=y_multy) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.reduction import reduce + reduce(torch.tensor([0, 1, 1, 0]), 'sum') + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.reduction import class_reduce + class_reduce(torch.randint(1, 10, (50,)).float(), + torch.randint(10, 20, (50,)).float(), + torch.randint(1, 100, (50,)).float()) + + def test_tbd_remove_in_v1_2_0(): with pytest.deprecated_call(match='will be removed in v1.2'): checkpoint_cb = ModelCheckpoint(filepath='.')