Prune deprecated metrics for 1.3 #6161

Merged: 7 commits, merged on Feb 24, 2021
8 changes: 7 additions & 1 deletion CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))


### Changed

@@ -24,6 +26,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated Trainer arguments `enable_pl_optimizer` and `automatic_optimization` ([#6163](https://github.com/PyTorchLightning/pytorch-lightning/pull/6163))


- Removed deprecated metrics ([#6161](https://github.com/PyTorchLightning/pytorch-lightning/pull/6161))
* from `pytorch_lightning.metrics.functional.classification` removed `to_onehot`, `to_categorical`, `get_num_classes`, `roc`, `multiclass_roc`, `average_precision`, `precision_recall_curve`, `multiclass_precision_recall_curve`
* from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce`


### Fixed

- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
@@ -93,7 +100,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `Trainer` flag to activate Stochastic Weight Averaging (SWA) `Trainer(stochastic_weight_avg=True)` ([#6038](https://github.com/PyTorchLightning/pytorch-lightning/pull/6038))
- Added DeepSpeed integration ([#5954](https://github.com/PyTorchLightning/pytorch-lightning/pull/5954),
[#6042](https://github.com/PyTorchLightning/pytorch-lightning/pull/6042))
- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))

### Changed

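For users updating to 1.3, the deprecation warnings in the shims deleted below point to the replacement imports. A minimal migration sketch, not exhaustive, and assuming the replacement modules keep the names shown in this diff:

```python
# Old imports, removed by this PR:
#   from pytorch_lightning.metrics.functional.classification import (
#       to_onehot, to_categorical, get_num_classes, roc, average_precision, precision_recall_curve)
#   from pytorch_lightning.metrics.functional.reduction import reduce, class_reduce

# Replacement imports, taken from the deprecation messages in the deleted shims:
from pytorch_lightning.metrics.functional.roc import roc
from pytorch_lightning.metrics.functional.average_precision import average_precision
from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve
from pytorch_lightning.metrics.utils import (
    class_reduce,
    get_num_classes,
    reduce,
    to_categorical,
    to_onehot,
)
```

The removed `multiclass_roc` and `multiclass_precision_recall_curve` shims have no dedicated replacement module; per their deprecation messages, the general `roc` and `precision_recall_curve` functions are the intended replacements (the deleted multiclass precision-recall shim forwarded `num_classes` to the new `precision_recall_curve`).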
1 change: 0 additions & 1 deletion pytorch_lightning/metrics/functional/__init__.py
@@ -21,7 +21,6 @@
multiclass_auroc,
stat_scores_multiple_classes,
to_categorical,
to_onehot,
)
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401
from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401
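The `__init__.py` change drops the re-export of the deprecated `to_onehot` from `pytorch_lightning.metrics.functional`. Per the warning in the removed shim, the supported import is from `pytorch_lightning.metrics.utils`; a small sketch (the output shape is an assumption based on the usual one-hot convention):

```python
import torch
from pytorch_lightning.metrics.utils import to_onehot  # replacement named in the removed shim

labels = torch.tensor([0, 2, 1])  # dense class indices
one_hot = to_onehot(labels, 3)    # expected shape: (3, 3), one row per sample
```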
257 changes: 10 additions & 247 deletions pytorch_lightning/metrics/functional/classification.py
@@ -18,70 +18,11 @@

from pytorch_lightning.metrics.functional.auc import auc as __auc
from pytorch_lightning.metrics.functional.auroc import auroc as __auroc
from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap
from pytorch_lightning.metrics.functional.iou import iou as __iou
from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve
from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve as __prc
from pytorch_lightning.metrics.functional.roc import roc as __roc
from pytorch_lightning.metrics.utils import class_reduce
from pytorch_lightning.metrics.utils import get_num_classes as __gnc
from pytorch_lightning.metrics.utils import reduce
from pytorch_lightning.metrics.utils import to_categorical as __tc
from pytorch_lightning.metrics.utils import to_onehot as __to
from pytorch_lightning.metrics.utils import class_reduce, get_num_classes, reduce, to_categorical
from pytorch_lightning.utilities import rank_zero_warn


def to_onehot(
tensor: torch.Tensor,
num_classes: Optional[int] = None,
) -> torch.Tensor:
"""
Converts a dense label tensor to one-hot format

.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_onehot`
"""
rank_zero_warn(
"This `to_onehot` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.utils import to_onehot`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __to(tensor, num_classes)


def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
"""
Converts a tensor of probabilities to a dense label tensor

.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_categorical`

"""
rank_zero_warn(
"This `to_categorical` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.utils import to_categorical`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __tc(tensor)


def get_num_classes(
pred: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
) -> int:
"""
Calculates the number of classes for a given prediction and target tensor.

.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.get_num_classes`

"""
rank_zero_warn(
"This `get_num_classes` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.utils import get_num_classes`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __gnc(pred, target, num_classes)


def stat_scores(
pred: torch.Tensor,
target: torch.Tensor,
@@ -122,6 +63,7 @@ def stat_scores(
return tp, fp, tn, fn, sup


# todo: remove in 1.4
def stat_scores_multiple_classes(
pred: torch.Tensor,
target: torch.Tensor,
@@ -210,6 +152,7 @@ def _confmat_normalize(cm):
return cm


# todo: remove in 1.4
def precision_recall(
pred: torch.Tensor,
target: torch.Tensor,
@@ -268,6 +211,7 @@ def precision_recall(
return precision, recall


# todo: remove in 1.4
def precision(
pred: torch.Tensor,
target: torch.Tensor,
@@ -311,6 +255,7 @@ def precision(
return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[0]


# todo: remove in 1.4
def recall(
pred: torch.Tensor,
target: torch.Tensor,
@@ -353,6 +298,7 @@ def recall(
return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[1]


# todo: remove in 1.3
def roc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.

.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
"""
rank_zero_warn(
"This `multiclass_roc` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.roc import roc`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)


# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
def _roc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.

.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`

Example:

>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 1, 1])
>>> fpr, tpr, thresholds = _roc(x, y)
>>> fpr
tensor([0., 0., 0., 0., 1.])
>>> tpr
tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
>>> thresholds
tensor([4, 3, 2, 1, 0])

"""
rank_zero_warn(
"This `multiclass_roc` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.roc import roc`."
" It will be removed in v1.3.0", DeprecationWarning
)
fps, tps, thresholds = _binary_clf_curve(pred, target, sample_weights=sample_weight, pos_label=pos_label)

# Add an extra threshold position
# to make sure that the curve starts at (0, 0)
tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])

if fps[-1] <= 0:
raise ValueError("No negative samples in targets, false positive value should be meaningless")

fpr = fps / fps[-1]

if tps[-1] <= 0:
raise ValueError("No positive samples in targets, true positive value should be meaningless")

tpr = tps / tps[-1]

return fpr, tpr, thresholds


# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
def multiclass_roc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
num_classes: Optional[int] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.

.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`

Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
num_classes: number of classes (default: None, computes automatically from data)

Return:
returns roc for each class.
Number of classes, false-positive rate (fpr), true-positive rate (tpr), thresholds

Example:

>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE
((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))
"""
rank_zero_warn(
"This `multiclass_roc` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.roc import roc`."
" It will be removed in v1.3.0", DeprecationWarning
)
num_classes = get_num_classes(pred, target, num_classes)

class_roc_vals = []
for c in range(num_classes):
pred_c = pred[:, c]

class_roc_vals.append(_roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c))

return tuple(class_roc_vals)


# todo: remove in 1.4
def auc(
x: torch.Tensor,
y: torch.Tensor,
@@ -508,6 +332,7 @@ def auc(
return __auc(x, y)


# todo: remove in 1.4
def auc_decorator() -> Callable:
rank_zero_warn("This `auc_decorator` was deprecated in v1.2.0." " It will be removed in v1.4.0", DeprecationWarning)

@@ -524,6 +349,7 @@ def new_func(*args, **kwargs) -> torch.Tensor:
return wrapper


# todo: remove in 1.4
def multiclass_auc_decorator() -> Callable:
rank_zero_warn(
"This `multiclass_auc_decorator` was deprecated in v1.2.0."
@@ -546,6 +372,7 @@ def new_func(*args, **kwargs) -> torch.Tensor:
return wrapper


# todo: remove in 1.4
def auroc(
pred: torch.Tensor,
target: torch.Tensor,
@@ -588,6 +415,7 @@ def auroc(
)


# todo: remove in 1.4
def multiclass_auroc(
pred: torch.Tensor,
target: torch.Tensor,
@@ -767,68 +595,3 @@ def iou(
num_classes=num_classes,
reduction=reduction
)


# todo: remove in 1.3
def precision_recall_curve(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
):
"""
Computes precision-recall pairs for different thresholds.

.. warning :: Deprecated in favor of
:func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
"""
rank_zero_warn(
"This `precision_recall_curve` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __prc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)


# todo: remove in 1.3
def multiclass_precision_recall_curve(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
num_classes: Optional[int] = None,
):
"""
Computes precision-recall pairs for different thresholds given a multiclass scores.

.. warning :: Deprecated in favor of
:func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
"""
rank_zero_warn(
"This `multiclass_precision_recall_curve` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`."
" It will be removed in v1.3.0", DeprecationWarning
)
if num_classes is None:
num_classes = get_num_classes(pred, target, num_classes)
return __prc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes)


# todo: remove in 1.3
def average_precision(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
):
"""
Compute average precision from prediction scores.

.. warning :: Deprecated in favor of
:func:`~pytorch_lightning.metrics.functional.average_precision.average_precision`
"""
rank_zero_warn(
"This `average_precision` was deprecated in v1.1.0 in favor of"
" `pytorch_lightning.metrics.functional.average_precision import average_precision`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __ap(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)
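One migration detail visible in the deleted shims: they forwarded `pred` and `sample_weight` to the new functions as `preds` and `sample_weights`, so call sites that used keyword arguments need to rename them. A hedged before/after sketch for the binary case:

```python
import torch
from pytorch_lightning.metrics.functional.roc import roc
from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve

preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0, 0, 1, 1])

# Before (removed shim): roc(pred=preds, target=target, sample_weight=None, pos_label=1)
fpr, tpr, thresholds = roc(preds=preds, target=target, pos_label=1)

# Before (removed shim): precision_recall_curve(pred=preds, target=target, pos_label=1)
precision, recall, pr_thresholds = precision_recall_curve(preds=preds, target=target, pos_label=1)
```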
3 changes: 1 addition & 2 deletions pytorch_lightning/metrics/functional/iou.py
@@ -16,8 +16,7 @@
import torch

from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_update
from pytorch_lightning.metrics.functional.reduction import reduce
from pytorch_lightning.metrics.utils import get_num_classes
from pytorch_lightning.metrics.utils import get_num_classes, reduce


def _iou_from_confmat(
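The `iou.py` edit consolidates the helper imports: `reduce` now comes from `pytorch_lightning.metrics.utils` alongside `get_num_classes`, matching the removal of `reduce` from `pytorch_lightning.metrics.functional.reduction` noted in the CHANGELOG. Downstream code can be updated the same way; a sketch, assuming `reduce` keeps its `'elementwise_mean'` reduction mode:

```python
import torch
# Before (removed): from pytorch_lightning.metrics.functional.reduction import reduce
from pytorch_lightning.metrics.utils import reduce  # consolidated location, as used by iou.py above

per_class_scores = torch.tensor([0.2, 0.4, 0.9])
mean_score = reduce(per_class_scores, reduction="elementwise_mean")  # assumed mode name
```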