From ee58904cbc28ce470d4775e91dbed94c2b2ffc25 Mon Sep 17 00:00:00 2001 From: oguz-hanoglu Date: Thu, 2 Nov 2023 10:55:04 +0000 Subject: [PATCH 01/10] fix typo of specicity --- src/torchmetrics/functional/classification/__init__.py | 4 ++-- .../functional/classification/specificity_sensitivity.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 069fc3625ad..af4892496fe 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -119,7 +119,7 @@ binary_specificity_at_sensitivity, multiclass_specificity_at_sensitivity, multilabel_specificity_at_sensitivity, - specicity_at_sensitivity, + specificity_at_sensitivity, ) from torchmetrics.functional.classification.stat_scores import ( binary_stat_scores, @@ -211,7 +211,7 @@ "binary_specificity_at_sensitivity", "multiclass_specificity_at_sensitivity", "multilabel_specificity_at_sensitivity", - "specicity_at_sensitivity", + "specificity_at_sensitivity", "binary_stat_scores", "multiclass_stat_scores", "multilabel_stat_scores", diff --git a/src/torchmetrics/functional/classification/specificity_sensitivity.py b/src/torchmetrics/functional/classification/specificity_sensitivity.py index a44948f8570..96ac34d17bb 100644 --- a/src/torchmetrics/functional/classification/specificity_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_sensitivity.py @@ -403,7 +403,7 @@ def multilabel_specificity_at_sensitivity( return _multilabel_specificity_at_sensitivity_compute(state, num_labels, thresholds, ignore_index, min_sensitivity) -def specicity_at_sensitivity( +def specificity_at_sensitivity( preds: Tensor, target: Tensor, task: Literal["binary", "multiclass", "multilabel"], @@ -414,7 +414,7 @@ def specicity_at_sensitivity( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - r"""Compute the highest possible specicity value given the minimum sensitivity thresholds provided. + r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the find the specificity for a given sensitivity level. From 54307c2f5280671cc9efb7fa111d2cc60b9d0738 Mon Sep 17 00:00:00 2001 From: oguz-hanoglu Date: Thu, 2 Nov 2023 11:45:52 +0000 Subject: [PATCH 02/10] __init__.py's organized ordered similarly, duplicates removed. --- src/torchmetrics/classification/__init__.py | 31 +++++++++---------- .../functional/classification/__init__.py | 4 +-- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 684f0f2ae9f..079119a6f0d 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -117,18 +117,6 @@ ) __all__ = [ - "BinaryConfusionMatrix", - "ConfusionMatrix", - "MulticlassConfusionMatrix", - "MultilabelConfusionMatrix", - "PrecisionRecallCurve", - "BinaryPrecisionRecallCurve", - "MulticlassPrecisionRecallCurve", - "MultilabelPrecisionRecallCurve", - "BinaryStatScores", - "MulticlassStatScores", - "MultilabelStatScores", - "StatScores", "Accuracy", "BinaryAccuracy", "MulticlassAccuracy", @@ -147,6 +135,10 @@ "BinaryCohenKappa", "CohenKappa", "MulticlassCohenKappa", + "BinaryConfusionMatrix", + "ConfusionMatrix", + "MulticlassConfusionMatrix", + "MultilabelConfusionMatrix", "Dice", "ExactMatch", "MulticlassExactMatch", @@ -184,16 +176,21 @@ "MultilabelRecall", "Precision", "Recall", + "BinaryPrecisionRecallCurve", + "MulticlassPrecisionRecallCurve", + "MultilabelPrecisionRecallCurve", + "PrecisionRecallCurve", "MultilabelCoverageError", "MultilabelRankingAveragePrecision", "MultilabelRankingLoss", + "RecallAtFixedPrecision", "BinaryRecallAtFixedPrecision", "MulticlassRecallAtFixedPrecision", "MultilabelRecallAtFixedPrecision", - "ROC", "BinaryROC", "MulticlassROC", "MultilabelROC", + "ROC", "BinarySpecificity", "MulticlassSpecificity", "MultilabelSpecificity", @@ -201,12 +198,12 @@ "BinarySpecificityAtSensitivity", "MulticlassSpecificityAtSensitivity", "MultilabelSpecificityAtSensitivity", - "BinaryPrecisionAtFixedRecall", "SpecificityAtSensitivity", - "MulticlassPrecisionAtFixedRecall", - "MultilabelPrecisionAtFixedRecall", + "BinaryStatScores", + "MulticlassStatScores", + "MultilabelStatScores", + "StatScores", "PrecisionAtFixedRecall", - "RecallAtFixedPrecision", "BinaryPrecisionAtFixedRecall", "MulticlassPrecisionAtFixedRecall", "MultilabelPrecisionAtFixedRecall", diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index af4892496fe..919b7faf9e6 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -165,8 +165,6 @@ "multilabel_fbeta_score", "binary_fairness", "binary_groups_stat_rates", - "demographic_parity", - "equal_opportunity", "binary_hamming_distance", "hamming_distance", "multiclass_hamming_distance", @@ -219,4 +217,6 @@ "binary_precision_at_fixed_recall", "multilabel_precision_at_fixed_recall", "multiclass_precision_at_fixed_recall", + "demographic_parity", + "equal_opportunity", ] From 36be5326b598dd61d029b01a1de26a4cfa5b873b Mon Sep 17 00:00:00 2001 From: oguz-hanoglu Date: Thu, 23 Nov 2023 12:33:28 +0000 Subject: [PATCH 03/10] keeps the old one around with deprecation warning --- .../classification/specificity_sensitivity.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/torchmetrics/functional/classification/specificity_sensitivity.py b/src/torchmetrics/functional/classification/specificity_sensitivity.py index 96ac34d17bb..30a4184e3ed 100644 --- a/src/torchmetrics/functional/classification/specificity_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_sensitivity.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import List, Optional, Tuple, Union import torch @@ -403,6 +404,56 @@ def multilabel_specificity_at_sensitivity( return _multilabel_specificity_at_sensitivity_compute(state, num_labels, thresholds, ignore_index, min_sensitivity) +# create specicity_at_sensitivity that calls specificity_at_sensitivity +def specicity_at_sensitivity( + preds: Tensor, + target: Tensor, + task: Literal["binary", "multiclass", "multilabel"], + min_sensitivity: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided. + + This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and + the find the specificity for a given sensitivity level. + + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`~torchmetrics.functional.classification.binary_specificity_at_sensitivity`, + :func:`~torchmetrics.functional.classification.multiclass_specificity_at_sensitivity` and + :func:`~torchmetrics.functional.classification.multilabel_specificity_at_sensitivity` for the specific details of + each argument influence and examples. + + """ + warnings.warn( + "This method has will be removed in 2.0.0. Use `specificity_at_sensitivity` instead.", + DeprecationWarning, + stacklevel=1, + ) + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: + return binary_specificity_at_sensitivity( # type: ignore + preds, target, min_sensitivity, thresholds, ignore_index, validate_args + ) + if task == ClassificationTask.MULTICLASS: + if not isinstance(num_classes, int): + raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") + return multiclass_specificity_at_sensitivity( # type: ignore + preds, target, num_classes, min_sensitivity, thresholds, ignore_index, validate_args + ) + if task == ClassificationTask.MULTILABEL: + if not isinstance(num_labels, int): + raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") + return multilabel_specificity_at_sensitivity( # type: ignore + preds, target, num_labels, min_sensitivity, thresholds, ignore_index, validate_args + ) + raise ValueError(f"Not handled value: {task}") + + def specificity_at_sensitivity( preds: Tensor, target: Tensor, From cc60ed2ddd8757346c116d99fde7167215074274 Mon Sep 17 00:00:00 2001 From: oguz-hanoglu Date: Thu, 23 Nov 2023 12:38:55 +0000 Subject: [PATCH 04/10] __init__.py also includes the version with typo --- src/torchmetrics/functional/classification/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 919b7faf9e6..514cef8091d 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -119,6 +119,7 @@ binary_specificity_at_sensitivity, multiclass_specificity_at_sensitivity, multilabel_specificity_at_sensitivity, + specicity_at_sensitivity, specificity_at_sensitivity, ) from torchmetrics.functional.classification.stat_scores import ( @@ -209,6 +210,7 @@ "binary_specificity_at_sensitivity", "multiclass_specificity_at_sensitivity", "multilabel_specificity_at_sensitivity", + "specicity_at_sensitivity", "specificity_at_sensitivity", "binary_stat_scores", "multiclass_stat_scores", From efe141587ca002f7657e98c2eb168a8ad24e8af3 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 28 Nov 2023 17:10:11 +0100 Subject: [PATCH 05/10] Apply suggestions from code review --- .../classification/specificity_sensitivity.py | 29 +++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/src/torchmetrics/functional/classification/specificity_sensitivity.py b/src/torchmetrics/functional/classification/specificity_sensitivity.py index 30a4184e3ed..222cf75415b 100644 --- a/src/torchmetrics/functional/classification/specificity_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_sensitivity.py @@ -434,24 +434,17 @@ def specicity_at_sensitivity( DeprecationWarning, stacklevel=1, ) - task = ClassificationTask.from_str(task) - if task == ClassificationTask.BINARY: - return binary_specificity_at_sensitivity( # type: ignore - preds, target, min_sensitivity, thresholds, ignore_index, validate_args - ) - if task == ClassificationTask.MULTICLASS: - if not isinstance(num_classes, int): - raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") - return multiclass_specificity_at_sensitivity( # type: ignore - preds, target, num_classes, min_sensitivity, thresholds, ignore_index, validate_args - ) - if task == ClassificationTask.MULTILABEL: - if not isinstance(num_labels, int): - raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return multilabel_specificity_at_sensitivity( # type: ignore - preds, target, num_labels, min_sensitivity, thresholds, ignore_index, validate_args - ) - raise ValueError(f"Not handled value: {task}") + return specificity_at_sensitivity( + preds=preds, + target=target, + task=task, + min_sensitivity=min_sensitivity, + thresholds=thresholds, + num_classes=num_classes, + num_labels=num_labels, + ignore_index=ignore_index, + validate_args=validate_args + ) def specificity_at_sensitivity( From 90de3b12ea161de28b50d371fcbd1bbc03a201a9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Nov 2023 16:12:42 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../classification/specificity_sensitivity.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/torchmetrics/functional/classification/specificity_sensitivity.py b/src/torchmetrics/functional/classification/specificity_sensitivity.py index 222cf75415b..cf176632158 100644 --- a/src/torchmetrics/functional/classification/specificity_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_sensitivity.py @@ -435,15 +435,15 @@ def specicity_at_sensitivity( stacklevel=1, ) return specificity_at_sensitivity( - preds=preds, - target=target, - task=task, - min_sensitivity=min_sensitivity, - thresholds=thresholds, - num_classes=num_classes, - num_labels=num_labels, - ignore_index=ignore_index, - validate_args=validate_args + preds=preds, + target=target, + task=task, + min_sensitivity=min_sensitivity, + thresholds=thresholds, + num_classes=num_classes, + num_labels=num_labels, + ignore_index=ignore_index, + validate_args=validate_args, ) From 006e9e41c7d166d01e6ea03cd5a86160a506d00d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 28 Nov 2023 17:16:02 +0100 Subject: [PATCH 07/10] changelog --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 72549469d59..f07aeb13493 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,12 +33,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Changed minimum supported Pytorch version from 1.8 to 1.10 ([#2145](https://github.com/Lightning-AI/torchmetrics/pull/2145)) +- Changed x-/y-axis order for `PrecisionRecallCurve` to be consistent with scikit-learn ([#2183](https://github.com/Lightning-AI/torchmetrics/pull/2183)) + + ### Deprecated - Deprecated `metric._update_called` ([#2141](https://github.com/Lightning-AI/torchmetrics/pull/2141)) -- Changed x-/y-axis order for `PrecisionRecallCurve` to be consistent with scikit-learn ([#2183](https://github.com/Lightning-AI/torchmetrics/pull/2183)) +- Deprecated `specicity_at_sensitivity` in favour of `specificity_at_sensitivity` ([#2199](https://github.com/Lightning-AI/torchmetrics/pull/2199)) ### Removed From e3cf3f4f2b4fc628983764e75e521ed2c86c8ebf Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 28 Nov 2023 17:16:26 +0100 Subject: [PATCH 08/10] Update src/torchmetrics/functional/classification/specificity_sensitivity.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- .../functional/classification/specificity_sensitivity.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/torchmetrics/functional/classification/specificity_sensitivity.py b/src/torchmetrics/functional/classification/specificity_sensitivity.py index cf176632158..edc0b9b226e 100644 --- a/src/torchmetrics/functional/classification/specificity_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_sensitivity.py @@ -404,7 +404,6 @@ def multilabel_specificity_at_sensitivity( return _multilabel_specificity_at_sensitivity_compute(state, num_labels, thresholds, ignore_index, min_sensitivity) -# create specicity_at_sensitivity that calls specificity_at_sensitivity def specicity_at_sensitivity( preds: Tensor, target: Tensor, From f20eb8c090e3e4ade613300243b60323ce00d3f7 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 28 Nov 2023 17:19:59 +0100 Subject: [PATCH 09/10] add doc warning --- .../classification/specificity_sensitivity.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/torchmetrics/functional/classification/specificity_sensitivity.py b/src/torchmetrics/functional/classification/specificity_sensitivity.py index edc0b9b226e..15200ea0e61 100644 --- a/src/torchmetrics/functional/classification/specificity_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_sensitivity.py @@ -417,15 +417,9 @@ def specicity_at_sensitivity( ) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided. - This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and - the find the specificity for a given sensitivity level. - - This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the - ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of - :func:`~torchmetrics.functional.classification.binary_specificity_at_sensitivity`, - :func:`~torchmetrics.functional.classification.multiclass_specificity_at_sensitivity` and - :func:`~torchmetrics.functional.classification.multilabel_specificity_at_sensitivity` for the specific details of - each argument influence and examples. + .. warning:: + This function was deprecated in v1.3.0 of Torchmetrics and will be removed in v2.0.0. + Use `specificity_at_sensitivity` instead. """ warnings.warn( From 3d84a9e58ae768d4f8b194f57bf4c245d8d63a1a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Nov 2023 16:21:25 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../functional/classification/specificity_sensitivity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/classification/specificity_sensitivity.py b/src/torchmetrics/functional/classification/specificity_sensitivity.py index 15200ea0e61..d85b47eb453 100644 --- a/src/torchmetrics/functional/classification/specificity_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_sensitivity.py @@ -418,7 +418,7 @@ def specicity_at_sensitivity( r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided. .. warning:: - This function was deprecated in v1.3.0 of Torchmetrics and will be removed in v2.0.0. + This function was deprecated in v1.3.0 of Torchmetrics and will be removed in v2.0.0. Use `specificity_at_sensitivity` instead. """