Skip to content

Commit

Permalink
docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Dec 10, 2020
1 parent c8473e2 commit cdc6fc7
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 19 deletions.
47 changes: 29 additions & 18 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ def get_num_classes(
return __gnc(pred,target, num_classes)




def stat_scores(
pred: torch.Tensor,
target: torch.Tensor,
Expand Down Expand Up @@ -448,9 +446,15 @@ def roc(
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
"""
rank_zero_warn(
"This `multiclass_roc` was deprecated in v1.1.0 in favor of `pytorch_lightning.metrics.functional.roc.roc`."
"It will be removed in v1.3.0", DeprecationWarning
"This `multiclass_roc` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.roc import roc`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)

Expand All @@ -465,16 +469,7 @@ def _roc(
"""
Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
.. warning:: Deprecated
Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class
Return:
false-positive rate (fpr), true-positive rate (tpr), thresholds
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
Example:
Expand All @@ -490,8 +485,9 @@ def _roc(
"""
rank_zero_warn(
"This `multiclass_roc` was deprecated in v1.1.0 in favor of `pytorch_lightning.metrics.functional.roc.roc`."
"It will be removed in v1.3.0", DeprecationWarning
"This `multiclass_roc` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.roc import roc`."
" It will be removed in v1.3.0", DeprecationWarning
)
fps, tps, thresholds = _binary_clf_curve(
pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label
Expand Down Expand Up @@ -526,7 +522,7 @@ def multiclass_roc(
"""
Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.
.. warning:: Deprecated
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
Args:
pred: estimated probabilities
Expand Down Expand Up @@ -869,6 +865,11 @@ def precision_recall_curve(
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
):
"""
Computes precision-recall pairs for different thresholds.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
"""
rank_zero_warn(
"This `precision_recall_curve` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`."
Expand All @@ -884,6 +885,11 @@ def multiclass_precision_recall_curve(
sample_weight: Optional[Sequence] = None,
num_classes: Optional[int] = None,
):
"""
Computes precision-recall pairs for different thresholds given a multiclass scores.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
"""
rank_zero_warn(
"This `multiclass_precision_recall_curve` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`."
Expand All @@ -901,9 +907,14 @@ def average_precision(
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
):
"""
Compute average precision from prediction scores.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.average_precision.average_precision`
"""
rank_zero_warn(
"This `average_precision` was deprecated in v1.1.0 in favor of"
" `pytorch_lightning.metrics.functional.average_precision import precision_recall_curve`."
" `pytorch_lightning.metrics.functional.average_precision import average_precision`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __ap(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/functional/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ def class_reduce(num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor,
" `pytorch_lightning.metrics.utils import class_reduce`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __cr(num=num, denom=denom, weights=weights, class_reduction=class_reduction)
return __cr(num=num, denom=denom, weights=weights, class_reduction=class_reduction)

0 comments on commit cdc6fc7

Please sign in to comment.