diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 854acc255cb5b..c8e1afe46f870 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -84,8 +84,6 @@ def get_num_classes( return __gnc(pred,target, num_classes) - - def stat_scores( pred: torch.Tensor, target: torch.Tensor, @@ -448,9 +446,15 @@ def roc( sample_weight: Optional[Sequence] = None, pos_label: int = 1., ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc` + """ rank_zero_warn( - "This `multiclass_roc` was deprecated in v1.1.0 in favor of `pytorch_lightning.metrics.functional.roc.roc`." - "It will be removed in v1.3.0", DeprecationWarning + "This `multiclass_roc` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.roc import roc`." + " It will be removed in v1.3.0", DeprecationWarning ) return __roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) @@ -465,16 +469,7 @@ def _roc( """ Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. - .. warning:: Deprecated - - Args: - pred: estimated probabilities - target: ground-truth labels - sample_weight: sample weights - pos_label: the label for the positive class - - Return: - false-positive rate (fpr), true-positive rate (tpr), thresholds + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc` Example: @@ -490,8 +485,9 @@ def _roc( """ rank_zero_warn( - "This `multiclass_roc` was deprecated in v1.1.0 in favor of `pytorch_lightning.metrics.functional.roc.roc`." - "It will be removed in v1.3.0", DeprecationWarning + "This `multiclass_roc` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.roc import roc`." + " It will be removed in v1.3.0", DeprecationWarning ) fps, tps, thresholds = _binary_clf_curve( pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label @@ -526,7 +522,7 @@ def multiclass_roc( """ Computes the Receiver Operating Characteristic (ROC) for multiclass predictors. - .. warning:: Deprecated + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc` Args: pred: estimated probabilities @@ -869,6 +865,11 @@ def precision_recall_curve( sample_weight: Optional[Sequence] = None, pos_label: int = 1., ): + """ + Computes precision-recall pairs for different thresholds. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve` + """ rank_zero_warn( "This `precision_recall_curve` was deprecated in v1.1.0 in favor of" " `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`." @@ -884,6 +885,11 @@ def multiclass_precision_recall_curve( sample_weight: Optional[Sequence] = None, num_classes: Optional[int] = None, ): + """ + Computes precision-recall pairs for different thresholds given a multiclass scores. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve` + """ rank_zero_warn( "This `multiclass_precision_recall_curve` was deprecated in v1.1.0 in favor of" " `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`." @@ -901,9 +907,14 @@ def average_precision( sample_weight: Optional[Sequence] = None, pos_label: int = 1., ): + """ + Compute average precision from prediction scores. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.average_precision.average_precision` + """ rank_zero_warn( "This `average_precision` was deprecated in v1.1.0 in favor of" - " `pytorch_lightning.metrics.functional.average_precision import precision_recall_curve`." + " `pytorch_lightning.metrics.functional.average_precision import average_precision`." " It will be removed in v1.3.0", DeprecationWarning ) return __ap(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) diff --git a/pytorch_lightning/metrics/functional/reduction.py b/pytorch_lightning/metrics/functional/reduction.py index bc1425b71d507..c116b16d363a9 100644 --- a/pytorch_lightning/metrics/functional/reduction.py +++ b/pytorch_lightning/metrics/functional/reduction.py @@ -32,4 +32,4 @@ def class_reduce(num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, " `pytorch_lightning.metrics.utils import class_reduce`." " It will be removed in v1.3.0", DeprecationWarning ) - return __cr(num=num, denom=denom, weights=weights, class_reduction=class_reduction) \ No newline at end of file + return __cr(num=num, denom=denom, weights=weights, class_reduction=class_reduction)