diff --git a/.azure/gpu-pipeline.yml b/.azure/gpu-pipeline.yml index c071b72ff01..8f97521b477 100644 --- a/.azure/gpu-pipeline.yml +++ b/.azure/gpu-pipeline.yml @@ -32,7 +32,7 @@ jobs: agents: 'azure-jirka-spot' #maxParallel: '2' # how long to run the job before automatically cancelling - timeoutInMinutes: "55" + timeoutInMinutes: "65" # how much time to give 'run always even if cancelled tasks' before stopping them cancelTimeoutInMinutes: "2" diff --git a/.github/actions/unittesting/action.yml b/.github/actions/unittesting/action.yml index e69fce887b9..cf45d95e5e6 100644 --- a/.github/actions/unittesting/action.yml +++ b/.github/actions/unittesting/action.yml @@ -57,7 +57,7 @@ runs: - name: Unittests working-directory: ./tests - run: python -m pytest ${{ inputs.dirs }} --cov=torchmetrics --junitxml="$PYTEST_ARTEFACT.xml" --durations=50 ${{ inputs.test-timeout }} + run: python -m pytest -v --maxfail=5 ${{ inputs.dirs }} --cov=torchmetrics --junitxml="$PYTEST_ARTEFACT.xml" --durations=50 ${{ inputs.test-timeout }} shell: ${{ inputs.shell-type }} - name: Upload pytest results diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index 0fd16ca43d2..6cb556f2042 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -77,9 +77,11 @@ jobs: python ./.github/assistant.py prune-packages requirements/detection.txt torchvision # import of PILLOW_VERSION which they recently removed in v9.0 in favor of __version__ pip install -q "Pillow<9.0" # It messes with torchvision - pip install -e . -r requirements/devel.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html + pip install -e . -r requirements/devel.txt "torch==${{ matrix.pytorch-version }}.*" -f $TORCH_URL pip list python -c "from torch import __version__ as ver; assert '.'.join(ver.split('.')[:2]) == '${{ matrix.pytorch-version }}', ver" + env: + TORCH_URL: https://download.pytorch.org/whl/cpu/torch_stable.html - name: DocTests working-directory: ./src diff --git a/.github/workflows/code-format.yml b/.github/workflows/code-format.yml index 41340d8138e..003302733b0 100644 --- a/.github/workflows/code-format.yml +++ b/.github/workflows/code-format.yml @@ -2,8 +2,7 @@ name: Code formatting # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows on: # Trigger the workflow on push or pull request, but only for the master branch - push: - branches: [master, "release/*"] + push: {} pull_request: branches: [master, "release/*"] diff --git a/CHANGELOG.md b/CHANGELOG.md index 3cac9ed61cf..6662ec8fa14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Classification refactor ( + [#1054](https://github.com/Lightning-AI/metrics/pull/1054), + [#1143](https://github.com/Lightning-AI/metrics/pull/1143), + [#1145](https://github.com/Lightning-AI/metrics/pull/1145), + [#1151](https://github.com/Lightning-AI/metrics/pull/1151), + [#1159](https://github.com/Lightning-AI/metrics/pull/1159), + [#1163](https://github.com/Lightning-AI/metrics/pull/1163), + [#1167](https://github.com/Lightning-AI/metrics/pull/1167), + [#1175](https://github.com/Lightning-AI/metrics/pull/1175), + [#1189](https://github.com/Lightning-AI/metrics/pull/1189), + [#1197](https://github.com/Lightning-AI/metrics/pull/1197), + [#1215](https://github.com/Lightning-AI/metrics/pull/1215), + [#1195](https://github.com/Lightning-AI/metrics/pull/1195) +) + - Changed update in `FID` metric to be done in a online fashion to save memory ([#1199](https://github.com/PyTorchLightning/metrics/pull/1199)) diff --git a/docs/source/classification/accuracy.rst b/docs/source/classification/accuracy.rst index 4ca5becef90..9c22eab6b71 100644 --- a/docs/source/classification/accuracy.rst +++ b/docs/source/classification/accuracy.rst @@ -13,8 +13,44 @@ ________________ .. autoclass:: torchmetrics.Accuracy :noindex: +BinaryAccuracy +^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryAccuracy + :noindex: + +MulticlassAccuracy +^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassAccuracy + :noindex: + +MultilabelAccuracy +^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelAccuracy + :noindex: + Functional Interface ____________________ -.. autofunction:: torchmetrics.functional.accuracy +.. autofunction:: torchmetrics.functional.classification.accuracy + :noindex: + +binary_accuracy +^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_accuracy + :noindex: + +multiclass_accuracy +^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_accuracy + :noindex: + +multilabel_accuracy +^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_accuracy :noindex: diff --git a/docs/source/classification/auroc.rst b/docs/source/classification/auroc.rst index 6c0a2342b8b..99f3ee83862 100644 --- a/docs/source/classification/auroc.rst +++ b/docs/source/classification/auroc.rst @@ -15,8 +15,44 @@ ________________ .. autoclass:: torchmetrics.AUROC :noindex: +BinaryAUROC +^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryAUROC + :noindex: + +MulticlassAUROC +^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassAUROC + :noindex: + +MultilabelAUROC +^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelAUROC + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.auroc :noindex: + +binary_auroc +^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_auroc + :noindex: + +multiclass_auroc +^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_auroc + :noindex: + +multilabel_auroc +^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_auroc + :noindex: diff --git a/docs/source/classification/average_precision.rst b/docs/source/classification/average_precision.rst index 2061c284400..f75428832e1 100644 --- a/docs/source/classification/average_precision.rst +++ b/docs/source/classification/average_precision.rst @@ -13,8 +13,44 @@ ________________ .. autoclass:: torchmetrics.AveragePrecision :noindex: +BinaryAveragePrecision +^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryAveragePrecision + :noindex: + +MulticlassAveragePrecision +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassAveragePrecision + :noindex: + +MultilabelAveragePrecision +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelAveragePrecision + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.average_precision :noindex: + +binary_average_precision +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_average_precision + :noindex: + +multiclass_average_precision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_average_precision + :noindex: + +multilabel_average_precision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_average_precision + :noindex: diff --git a/docs/source/classification/calibration_error.rst b/docs/source/classification/calibration_error.rst index b10cb834f5f..baa2e5e0b1b 100644 --- a/docs/source/classification/calibration_error.rst +++ b/docs/source/classification/calibration_error.rst @@ -15,8 +15,32 @@ ________________ .. autoclass:: torchmetrics.CalibrationError :noindex: +BinaryCalibrationError +^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryCalibrationError + :noindex: + +MulticlassCalibrationError +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassCalibrationError + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.calibration_error :noindex: + +binary_calibration_error +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_calibration_error + :noindex: + +multiclass_calibration_error +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_calibration_error + :noindex: diff --git a/docs/source/classification/cohen_kappa.rst b/docs/source/classification/cohen_kappa.rst index 41127fa4187..e5f924e71d2 100644 --- a/docs/source/classification/cohen_kappa.rst +++ b/docs/source/classification/cohen_kappa.rst @@ -12,11 +12,43 @@ Cohen Kappa Module Interface ________________ +CohenKappa +^^^^^^^^^^ + .. autoclass:: torchmetrics.CohenKappa :noindex: +BinaryCohenKappa +^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryCohenKappa + :noindex: + :exclude-members: update, compute + +MulticlassCohenKappa +^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassCohenKappa + :noindex: + :exclude-members: update, compute + Functional Interface ____________________ +cohen_kappa +^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.cohen_kappa :noindex: + +binary_cohen_kappa +^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_cohen_kappa + :noindex: + +multiclass_cohen_kappa +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_cohen_kappa + :noindex: diff --git a/docs/source/classification/confusion_matrix.rst b/docs/source/classification/confusion_matrix.rst index a1cc43fdfe9..78d1edb6250 100644 --- a/docs/source/classification/confusion_matrix.rst +++ b/docs/source/classification/confusion_matrix.rst @@ -12,11 +12,53 @@ Confusion Matrix Module Interface ________________ +ConfusionMatrix +^^^^^^^^^^^^^^^ + .. autoclass:: torchmetrics.ConfusionMatrix :noindex: +BinaryConfusionMatrix +^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryConfusionMatrix + :noindex: + +MulticlassConfusionMatrix +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassConfusionMatrix + :noindex: + +MultilabelConfusionMatrix +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelConfusionMatrix + :noindex: + Functional Interface ____________________ +confusion_matrix +^^^^^^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.confusion_matrix :noindex: + +binary_confusion_matrix +^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_confusion_matrix + :noindex: + +multiclass_confusion_matrix +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_confusion_matrix + :noindex: + +multilabel_confusion_matrix +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_confusion_matrix + :noindex: diff --git a/docs/source/classification/coverage_error.rst b/docs/source/classification/coverage_error.rst index b1d159658d6..16db100c474 100644 --- a/docs/source/classification/coverage_error.rst +++ b/docs/source/classification/coverage_error.rst @@ -13,8 +13,14 @@ ________________ .. autoclass:: torchmetrics.CoverageError :noindex: +.. autoclass:: torchmetrics.classification.MultilabelCoverageError + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.coverage_error :noindex: + +.. autofunction:: torchmetrics.functional.classification.multilabel_coverage_error + :noindex: diff --git a/docs/source/classification/exact_match.rst b/docs/source/classification/exact_match.rst new file mode 100644 index 00000000000..c3a9000d4c5 --- /dev/null +++ b/docs/source/classification/exact_match.rst @@ -0,0 +1,26 @@ +.. customcarditem:: + :header: Exact Match + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Classification + +########### +Exact Match +########### + +Module Interface +________________ + +MultilabelExactMatch +^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelExactMatch + :noindex: + +Functional Interface +____________________ + +multilabel_exact_match +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_exact_match + :noindex: diff --git a/docs/source/classification/f1_score.rst b/docs/source/classification/f1_score.rst index 7122faf2453..3bf41a63df5 100644 --- a/docs/source/classification/f1_score.rst +++ b/docs/source/classification/f1_score.rst @@ -1,20 +1,62 @@ .. customcarditem:: - :header: F1 Score + :header: F-1 Score :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg :tags: Classification -######## -F1 Score -######## +######### +F-1 Score +######### Module Interface ________________ +F1Score +^^^^^^^ + .. autoclass:: torchmetrics.F1Score :noindex: +BinaryF1Score +^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryF1Score + :noindex: + +MulticlassF1Score +^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassF1Score + :noindex: + +MultilabelF1Score +^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelF1Score + :noindex: + Functional Interface ____________________ +f1_score +^^^^^^^^ + .. autofunction:: torchmetrics.functional.f1_score :noindex: + +binary_f1_score +^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_f1_score + :noindex: + +multiclass_f1_score +^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_f1_score + :noindex: + +multilabel_f1_score +^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_f1_score + :noindex: diff --git a/docs/source/classification/fbeta_score.rst b/docs/source/classification/fbeta_score.rst index 05ab8618a34..a00eac98905 100644 --- a/docs/source/classification/fbeta_score.rst +++ b/docs/source/classification/fbeta_score.rst @@ -1,22 +1,64 @@ .. customcarditem:: - :header: FBeta Score + :header: F-Beta Score :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg :tags: Classification .. include:: ../links.rst -########### -FBeta Score -########### +############ +F-Beta Score +############ Module Interface ________________ +FBetaScore +^^^^^^^^^^ + .. autoclass:: torchmetrics.FBetaScore :noindex: +BinaryFBetaScore +^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryFBetaScore + :noindex: + +MulticlassFBetaScore +^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassFBetaScore + :noindex: + +MultilabelFBetaScore +^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelFBetaScore + :noindex: + Functional Interface ____________________ +fbeta_score +^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.fbeta_score :noindex: + +binary_fbeta_score +^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_fbeta_score + :noindex: + +multiclass_fbeta_score +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_fbeta_score + :noindex: + +multilabel_fbeta_score +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_fbeta_score + :noindex: diff --git a/docs/source/classification/hamming_distance.rst b/docs/source/classification/hamming_distance.rst index 4af52d20492..29cc56b765d 100644 --- a/docs/source/classification/hamming_distance.rst +++ b/docs/source/classification/hamming_distance.rst @@ -10,11 +10,53 @@ Hamming Distance Module Interface ________________ +HammingDistance +^^^^^^^^^^^^^^^ + .. autoclass:: torchmetrics.HammingDistance :noindex: +BinaryHammingDistance +^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryHammingDistance + :noindex: + +MulticlassHammingDistance +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassHammingDistance + :noindex: + +MultilabelHammingDistance +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelHammingDistance + :noindex: + Functional Interface ____________________ +hamming_distance +^^^^^^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.hamming_distance :noindex: + +binary_hamming_distance +^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_hamming_distance + :noindex: + +multiclass_hamming_distance +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_hamming_distance + :noindex: + +multilabel_hamming_distance +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_hamming_distance + :noindex: diff --git a/docs/source/classification/hinge_loss.rst b/docs/source/classification/hinge_loss.rst index 4aa5562ab6a..5a0b992f2a3 100644 --- a/docs/source/classification/hinge_loss.rst +++ b/docs/source/classification/hinge_loss.rst @@ -13,8 +13,32 @@ ________________ .. autoclass:: torchmetrics.HingeLoss :noindex: +BinaryHingeLoss +^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryHingeLoss + :noindex: + +MulticlassHingeLoss +^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassHingeLoss + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.hinge_loss :noindex: + +binary_hinge_loss +^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_hinge_loss + :noindex: + +multiclass_hinge_loss +^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_hinge_loss + :noindex: diff --git a/docs/source/classification/jaccard_index.rst b/docs/source/classification/jaccard_index.rst index b37e264992d..2fb3f3332e0 100644 --- a/docs/source/classification/jaccard_index.rst +++ b/docs/source/classification/jaccard_index.rst @@ -10,11 +10,57 @@ Jaccard Index Module Interface ________________ +CohenKappa +^^^^^^^^^^ + .. autoclass:: torchmetrics.JaccardIndex :noindex: +BinaryJaccardIndex +^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryJaccardIndex + :noindex: + :exclude-members: update, compute + +MulticlassJaccardIndex +^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassJaccardIndex + :noindex: + :exclude-members: update, compute + +MultilabelJaccardIndex +^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelJaccardIndex + :noindex: + :exclude-members: update, compute + + Functional Interface ____________________ +jaccard_index +^^^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.jaccard_index :noindex: + +binary_jaccard_index +^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_jaccard_index + :noindex: + +multiclass_jaccard_index +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_jaccard_index + :noindex: + +multilabel_jaccard_index +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_jaccard_index + :noindex: diff --git a/docs/source/classification/label_ranking_average_precision.rst b/docs/source/classification/label_ranking_average_precision.rst index c9b9aa8282a..32f1b0867b5 100644 --- a/docs/source/classification/label_ranking_average_precision.rst +++ b/docs/source/classification/label_ranking_average_precision.rst @@ -13,8 +13,15 @@ ________________ .. autoclass:: torchmetrics.LabelRankingAveragePrecision :noindex: +.. autoclass:: torchmetrics.classification.MultilabelRankingAveragePrecision + :noindex: + + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.label_ranking_average_precision :noindex: + +.. autofunction:: torchmetrics.functional.classification.multilabel_ranking_average_precision + :noindex: diff --git a/docs/source/classification/label_ranking_loss.rst b/docs/source/classification/label_ranking_loss.rst index 766f657ba12..168b2c80ceb 100644 --- a/docs/source/classification/label_ranking_loss.rst +++ b/docs/source/classification/label_ranking_loss.rst @@ -13,8 +13,15 @@ ________________ .. autoclass:: torchmetrics.LabelRankingLoss :noindex: + +.. autoclass:: torchmetrics.classification.MultilabelRankingLoss + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.label_ranking_loss :noindex: + +.. autofunction:: torchmetrics.functional.classification.multilabel_ranking_loss + :noindex: diff --git a/docs/source/classification/matthews_corr_coef.rst b/docs/source/classification/matthews_corr_coef.rst index 1cfe173c24c..7b29686766b 100644 --- a/docs/source/classification/matthews_corr_coef.rst +++ b/docs/source/classification/matthews_corr_coef.rst @@ -5,18 +5,64 @@ .. include:: ../links.rst -#################### -Matthews Corr. Coef. -#################### +################################ +Matthews Correlation Coefficient +################################ Module Interface ________________ +MatthewsCorrCoef +^^^^^^^^^^^^^^^^ + .. autoclass:: torchmetrics.MatthewsCorrCoef :noindex: +BinaryMatthewsCorrCoef +^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryMatthewsCorrCoef + :noindex: + :exclude-members: update, compute + +MulticlassMatthewsCorrCoef +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassMatthewsCorrCoef + :noindex: + :exclude-members: update, compute + +MultilabelMatthewsCorrCoef +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelMatthewsCorrCoef + :noindex: + :exclude-members: update, compute + + Functional Interface ____________________ +matthews_corrcoef +^^^^^^^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.matthews_corrcoef :noindex: + +binary_matthews_corrcoef +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_matthews_corrcoef + :noindex: + +multiclass_matthews_corrcoef +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_matthews_corrcoef + :noindex: + +multilabel_matthews_corrcoef +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_matthews_corrcoef + :noindex: diff --git a/docs/source/classification/precision.rst b/docs/source/classification/precision.rst index d73eed7303b..5115c78bcc2 100644 --- a/docs/source/classification/precision.rst +++ b/docs/source/classification/precision.rst @@ -3,6 +3,8 @@ :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg :tags: Classification +.. include:: ../links.rst + ######### Precision ######### @@ -13,8 +15,44 @@ ________________ .. autoclass:: torchmetrics.Precision :noindex: +BinaryPrecision +^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryPrecision + :noindex: + +MulticlassPrecision +^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassPrecision + :noindex: + +MultilabelPrecision +^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelPrecision + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.precision :noindex: + +binary_precision +^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_precision + :noindex: + +multiclass_precision +^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_precision + :noindex: + +multilabel_precision +^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_precision + :noindex: diff --git a/docs/source/classification/precision_recall_curve.rst b/docs/source/classification/precision_recall_curve.rst index bd457727374..dbe0421ed12 100644 --- a/docs/source/classification/precision_recall_curve.rst +++ b/docs/source/classification/precision_recall_curve.rst @@ -13,8 +13,44 @@ ________________ .. autoclass:: torchmetrics.PrecisionRecallCurve :noindex: +BinaryPrecisionRecallCurve +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryPrecisionRecallCurve + :noindex: + +MulticlassPrecisionRecallCurve +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassPrecisionRecallCurve + :noindex: + +MultilabelPrecisionRecallCurve +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelPrecisionRecallCurve + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.precision_recall_curve :noindex: + +binary_precision_recall_curve +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_precision_recall_curve + :noindex: + +multiclass_precision_recall_curve +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_precision_recall_curve + :noindex: + +multilabel_precision_recall_curve +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_precision_recall_curve + :noindex: diff --git a/docs/source/classification/recall.rst b/docs/source/classification/recall.rst index 37d9b3dcc3f..771fe4ccfb6 100644 --- a/docs/source/classification/recall.rst +++ b/docs/source/classification/recall.rst @@ -13,8 +13,44 @@ ________________ .. autoclass:: torchmetrics.Recall :noindex: +BinaryRecall +^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryRecall + :noindex: + +MulticlassRecall +^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassRecall + :noindex: + +MultilabelRecall +^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelRecall + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.recall :noindex: + +binary_recall +^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_recall + :noindex: + +multiclass_recall +^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_recall + :noindex: + +multilabel_recall +^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_recall + :noindex: diff --git a/docs/source/classification/recall_at_fixed_precision.rst b/docs/source/classification/recall_at_fixed_precision.rst new file mode 100644 index 00000000000..7a43aa23064 --- /dev/null +++ b/docs/source/classification/recall_at_fixed_precision.rst @@ -0,0 +1,50 @@ +.. customcarditem:: + :header: Recall At Fixed Precision + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Classification + +######################### +Recall At Fixed Precision +######################### + +Module Interface +________________ + +BinaryRecallAtFixedPrecision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryRecallAtFixedPrecision + :noindex: + +MulticlassRecallAtFixedPrecision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassRecallAtFixedPrecision + :noindex: + +MultilabelRecallAtFixedPrecision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelRecallAtFixedPrecision + :noindex: + +Functional Interface +____________________ + +binary_recall_at_fixed_precision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_recall_at_fixed_precision + :noindex: + +multiclass_recall_at_fixed_precision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_recall_at_fixed_precision + :noindex: + +multilabel_recall_at_fixed_precision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_recall_at_fixed_precision + :noindex: diff --git a/docs/source/classification/roc.rst b/docs/source/classification/roc.rst index 6b3aaea4add..eef65dd845e 100644 --- a/docs/source/classification/roc.rst +++ b/docs/source/classification/roc.rst @@ -13,8 +13,44 @@ ________________ .. autoclass:: torchmetrics.ROC :noindex: +BinaryROC +^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryROC + :noindex: + +MulticlassROC +^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassROC + :noindex: + +MultilabelROC +^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelROC + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.roc :noindex: + +binary_roc +^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_roc + :noindex: + +multiclass_roc +^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_roc + :noindex: + +multilabel_roc +^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_roc + :noindex: diff --git a/docs/source/classification/specificity.rst b/docs/source/classification/specificity.rst index 4d0aef5eda4..19299e0c863 100644 --- a/docs/source/classification/specificity.rst +++ b/docs/source/classification/specificity.rst @@ -13,8 +13,45 @@ ________________ .. autoclass:: torchmetrics.Specificity :noindex: +BinarySpecificity +^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinarySpecificity + :noindex: + +MulticlassSpecificity +^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassSpecificity + :noindex: + +MultilabelSpecificity +^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelSpecificity + :noindex: + + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.specificity :noindex: + +binary_specificity +^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_specificity + :noindex: + +multiclass_specificity +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_specificity + :noindex: + +multilabel_specificity +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_specificity + :noindex: diff --git a/docs/source/classification/stat_scores.rst b/docs/source/classification/stat_scores.rst index 809c3106948..000027dbbea 100644 --- a/docs/source/classification/stat_scores.rst +++ b/docs/source/classification/stat_scores.rst @@ -12,11 +12,56 @@ Stat Scores Module Interface ________________ +StatScores +^^^^^^^^^^ + .. autoclass:: torchmetrics.StatScores :noindex: +BinaryStatScores +^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryStatScores + :noindex: + :exclude-members: update, compute + +MulticlassStatScores +^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassStatScores + :noindex: + :exclude-members: update, compute + +MultilabelStatScores +^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelStatScores + :noindex: + :exclude-members: update, compute + Functional Interface ____________________ +stat_scores +^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.stat_scores :noindex: + +binary_stat_scores +^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_stat_scores + :noindex: + +multiclass_stat_scores +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_stat_scores + :noindex: + +multilabel_stat_scores +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_stat_scores + :noindex: diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst index 2438703e267..7c01ccda043 100644 --- a/docs/source/pages/overview.rst +++ b/docs/source/pages/overview.rst @@ -183,7 +183,7 @@ the following limitations: - :ref:`image/peak_signal_noise_ratio:Peak Signal-to-Noise Ratio (PSNR)` - :ref:`image/structural_similarity:Structural Similarity Index Measure (SSIM)` - - :ref:`classification/kl_divergence:KL Divergence` + - :ref:`regression/kl_divergence:KL Divergence` You can always check the precision/dtype of the metric by checking the `.dtype` property. diff --git a/docs/source/classification/kl_divergence.rst b/docs/source/regression/kl_divergence.rst similarity index 100% rename from docs/source/classification/kl_divergence.rst rename to docs/source/regression/kl_divergence.rst diff --git a/requirements/classification_test.txt b/requirements/classification_test.txt new file mode 100644 index 00000000000..318026a665c --- /dev/null +++ b/requirements/classification_test.txt @@ -0,0 +1 @@ +netcal # calibration_error diff --git a/requirements/devel.txt b/requirements/devel.txt index 757c79a82ae..05aa5e61b9a 100644 --- a/requirements/devel.txt +++ b/requirements/devel.txt @@ -15,3 +15,4 @@ -r text_test.txt -r audio_test.txt -r detection_test.txt +-r classification_test.txt diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 483a4d786b7..bb77854d816 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -39,7 +39,6 @@ HammingDistance, HingeLoss, JaccardIndex, - KLDivergence, LabelRankingAveragePrecision, LabelRankingLoss, MatthewsCorrCoef, @@ -63,6 +62,7 @@ from torchmetrics.regression import ( # noqa: E402 CosineSimilarity, ExplainedVariance, + KLDivergence, MeanAbsoluteError, MeanAbsolutePercentageError, MeanSquaredError, diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 70ae4d5179c..07fce8e7bb6 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -11,30 +11,106 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.classification.accuracy import Accuracy # noqa: F401 +from torchmetrics.classification.confusion_matrix import ( # noqa: F401 isort:skip + BinaryConfusionMatrix, + ConfusionMatrix, + MulticlassConfusionMatrix, + MultilabelConfusionMatrix, +) +from torchmetrics.classification.precision_recall_curve import ( # noqa: F401 isort:skip + PrecisionRecallCurve, + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from torchmetrics.classification.stat_scores import ( # noqa: F401 isort:skip + BinaryStatScores, + MulticlassStatScores, + MultilabelStatScores, + StatScores, +) + +from torchmetrics.classification.accuracy import ( # noqa: F401 + Accuracy, + BinaryAccuracy, + MulticlassAccuracy, + MultilabelAccuracy, +) from torchmetrics.classification.auc import AUC # noqa: F401 -from torchmetrics.classification.auroc import AUROC # noqa: F401 -from torchmetrics.classification.avg_precision import AveragePrecision # noqa: F401 +from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MulticlassAUROC, MultilabelAUROC # noqa: F401 +from torchmetrics.classification.average_precision import ( # noqa: F401 + AveragePrecision, + BinaryAveragePrecision, + MulticlassAveragePrecision, + MultilabelAveragePrecision, +) from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision # noqa: F401 from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve # noqa: F401 from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401 -from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401 -from torchmetrics.classification.cohen_kappa import CohenKappa # noqa: F401 -from torchmetrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401 +from torchmetrics.classification.calibration_error import ( # noqa: F401 + BinaryCalibrationError, + CalibrationError, + MulticlassCalibrationError, +) +from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa # noqa: F401 from torchmetrics.classification.dice import Dice # noqa: F401 -from torchmetrics.classification.f_beta import F1Score, FBetaScore # noqa: F401 -from torchmetrics.classification.hamming import HammingDistance # noqa: F401 -from torchmetrics.classification.hinge import HingeLoss # noqa: F401 -from torchmetrics.classification.jaccard import JaccardIndex # noqa: F401 -from torchmetrics.classification.kl_divergence import KLDivergence # noqa: F401 -from torchmetrics.classification.matthews_corrcoef import MatthewsCorrCoef # noqa: F401 -from torchmetrics.classification.precision_recall import Precision, Recall # noqa: F401 -from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401 +from torchmetrics.classification.exact_match import MultilabelExactMatch # noqa: F401 +from torchmetrics.classification.f_beta import ( # noqa: F401 + BinaryF1Score, + BinaryFBetaScore, + F1Score, + FBetaScore, + MulticlassF1Score, + MulticlassFBetaScore, + MultilabelF1Score, + MultilabelFBetaScore, +) +from torchmetrics.classification.hamming import ( # noqa: F401 + BinaryHammingDistance, + HammingDistance, + MulticlassHammingDistance, + MultilabelHammingDistance, +) +from torchmetrics.classification.hinge import BinaryHingeLoss, HingeLoss, MulticlassHingeLoss # noqa: F401 +from torchmetrics.classification.jaccard import ( # noqa: F401 + BinaryJaccardIndex, + JaccardIndex, + MulticlassJaccardIndex, + MultilabelJaccardIndex, +) +from torchmetrics.classification.matthews_corrcoef import ( # noqa: F401 + BinaryMatthewsCorrCoef, + MatthewsCorrCoef, + MulticlassMatthewsCorrCoef, + MultilabelMatthewsCorrCoef, +) +from torchmetrics.classification.precision_recall import ( # noqa: F401 + BinaryPrecision, + BinaryRecall, + MulticlassPrecision, + MulticlassRecall, + MultilabelPrecision, + MultilabelRecall, + Precision, + Recall, +) from torchmetrics.classification.ranking import ( # noqa: F401 CoverageError, LabelRankingAveragePrecision, LabelRankingLoss, + MultilabelCoverageError, + MultilabelRankingAveragePrecision, + MultilabelRankingLoss, +) +from torchmetrics.classification.recall_at_fixed_precision import ( # noqa: F401 + BinaryRecallAtFixedPrecision, + MulticlassRecallAtFixedPrecision, + MultilabelRecallAtFixedPrecision, +) +from torchmetrics.classification.roc import ROC, BinaryROC, MulticlassROC, MultilabelROC # noqa: F401 +from torchmetrics.classification.specificity import ( # noqa: F401 + BinarySpecificity, + MulticlassSpecificity, + MultilabelSpecificity, + Specificity, ) -from torchmetrics.classification.roc import ROC # noqa: F401 -from torchmetrics.classification.specificity import Specificity # noqa: F401 -from torchmetrics.classification.stat_scores import StatScores # noqa: F401 diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 0c717acee82..074be2d8b5b 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -13,23 +13,326 @@ # limitations under the License. from typing import Any, Optional +import torch from torch import Tensor, tensor +from typing_extensions import Literal from torchmetrics.functional.classification.accuracy import ( _accuracy_compute, + _accuracy_reduce, _accuracy_update, _check_subset_validity, _mode, _subset_accuracy_compute, _subset_accuracy_update, ) +from torchmetrics.metric import Metric from torchmetrics.utilities.enums import AverageMethod, DataType +from torchmetrics.utilities.prints import rank_zero_warn -from torchmetrics.classification.stat_scores import StatScores # isort:skip +from torchmetrics.classification.stat_scores import ( # isort:skip + StatScores, + BinaryStatScores, + MulticlassStatScores, + MultilabelStatScores, +) + + +class BinaryAccuracy(BinaryStatScores): + r"""Computes `Accuracy`_ for binary tasks: + + .. math:: + \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.classification import BinaryAccuracy + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> metric = BinaryAccuracy() + >>> metric(preds, target) + tensor(0.6667) + + Example (preds is float tensor): + >>> from torchmetrics.classification import BinaryAccuracy + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> metric = BinaryAccuracy() + >>> metric(preds, target) + tensor(0.6667) + + Example (multidim tensors): + >>> from torchmetrics.classification import BinaryAccuracy + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = BinaryAccuracy(multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.3333, 0.1667]) + """ + is_differentiable = False + higher_is_better = True + full_state_update: bool = False + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _accuracy_reduce(tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average) + + +class MulticlassAccuracy(MulticlassStatScores): + r"""Computes `Accuracy`_ for multiclass tasks: + + .. math:: + \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.classification import MulticlassAccuracy + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassAccuracy(num_classes=3) + >>> metric(preds, target) + tensor(0.8333) + >>> metric = MulticlassAccuracy(num_classes=3, average=None) + >>> metric(preds, target) + tensor([0.5000, 1.0000, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MulticlassAccuracy + >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> metric = MulticlassAccuracy(num_classes=3) + >>> metric(preds, target) + tensor(0.8333) + >>> metric = MulticlassAccuracy(num_classes=3, average=None) + >>> metric(preds, target) + tensor([0.5000, 1.0000, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MulticlassAccuracy + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> metric = MulticlassAccuracy(num_classes=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.5000, 0.2778]) + >>> metric = MulticlassAccuracy(num_classes=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[1.0000, 0.0000, 0.5000], + [0.0000, 0.3333, 0.5000]]) + """ + is_differentiable = False + higher_is_better = True + full_state_update: bool = False + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _accuracy_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average) + + +class MultilabelAccuracy(MultilabelStatScores): + r"""Computes `Accuracy`_ for multilabel tasks: + + .. math:: + \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.classification import MultilabelAccuracy + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelAccuracy(num_labels=3) + >>> metric(preds, target) + tensor(0.6667) + >>> metric = MultilabelAccuracy(num_labels=3, average=None) + >>> metric(preds, target) + tensor([1.0000, 0.5000, 0.5000]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MultilabelAccuracy + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelAccuracy(num_labels=3) + >>> metric(preds, target) + tensor(0.6667) + >>> metric = MultilabelAccuracy(num_labels=3, average=None) + >>> metric(preds, target) + tensor([1.0000, 0.5000, 0.5000]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MultilabelAccuracy + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = MultilabelAccuracy(num_labels=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.3333, 0.1667]) + >>> metric = MultilabelAccuracy(num_labels=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[0.5000, 0.5000, 0.0000], + [0.0000, 0.0000, 0.5000]]) + + """ + is_differentiable = False + higher_is_better = True + full_state_update: bool = False + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _accuracy_reduce( + tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True + ) class Accuracy(StatScores): r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + Computes Accuracy_: .. math:: @@ -160,16 +463,64 @@ class Accuracy(StatScores): correct: Tensor total: Tensor + def __new__( + cls, + threshold: float = 0.5, + num_classes: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + subset_accuracy: bool = False, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryAccuracy(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassAccuracy(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelAccuracy(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, threshold: float = 0.5, num_classes: Optional[int] = None, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, subset_accuracy: bool = False, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, **kwargs: Any, ) -> None: allowed_average = ["micro", "macro", "weighted", "samples", "none", None] diff --git a/src/torchmetrics/classification/auc.py b/src/torchmetrics/classification/auc.py index 975ad64dd4e..6d6167e1bed 100644 --- a/src/torchmetrics/classification/auc.py +++ b/src/torchmetrics/classification/auc.py @@ -15,9 +15,9 @@ from torch import Tensor -from torchmetrics.functional.classification.auc import _auc_compute, _auc_update from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.compute import _auc_compute, _auc_format_inputs from torchmetrics.utilities.data import dim_zero_cat @@ -28,6 +28,9 @@ class AUC(Metric): Forward accepts two input tensors that should be 1D and have the same number of elements + .. note:: + This metric has been deprecated in v0.10 and will be removed in v0.11. + Args: reorder: AUC expects its first input to be sorted. If this is not the case, setting this argument to ``True`` will use a stable sorting algorithm to @@ -47,6 +50,11 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) + rank_zero_warn( + "`torchmetrics.classification.AUC` has been deprecated in v0.10 and will be removed in v0.11." + "A functional version is still available in `torchmetrics.utilities.compute`", + DeprecationWarning, + ) self.reorder = reorder @@ -65,7 +73,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore preds: Predictions from model (probabilities, or labels) target: Ground truth labels """ - x, y = _auc_update(preds, target) + x, y = _auc_format_inputs(preds, target) self.x.append(x) self.y.append(y) diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index 4a754c7c849..b1d45b6298f 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -11,12 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional +from typing import Any, List, Optional, Union import torch from torch import Tensor +from typing_extensions import Literal -from torchmetrics.functional.classification.auroc import _auroc_compute, _auroc_update +from torchmetrics.classification.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from torchmetrics.functional.classification.auroc import ( + _auroc_compute, + _auroc_update, + _binary_auroc_arg_validation, + _binary_auroc_compute, + _multiclass_auroc_arg_validation, + _multiclass_auroc_compute, + _multilabel_auroc_arg_validation, + _multilabel_auroc_compute, +) from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat @@ -24,8 +39,301 @@ from torchmetrics.utilities.imports import _TORCH_LOWER_1_6 +class BinaryAUROC(BinaryPrecisionRecallCurve): + r""" + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for binary tasks. The AUROC score + summarizes the ROC curve into an single number that describes the performance of a model for multiple + thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 + corresponds to random guessing. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + max_fpr: If not ``None``, calculates standardized partial AUC over the range ``[0, max_fpr]``. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + A single scalar with the auroc score + + Example: + >>> from torchmetrics.classification import BinaryAUROC + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> metric = BinaryAUROC(thresholds=None) + >>> metric(preds, target) + tensor(0.5000) + >>> metric = BinaryAUROC(thresholds=5) + >>> metric(preds, target) + tensor(0.5000) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + max_fpr: Optional[float] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs) + if validate_args: + _binary_auroc_arg_validation(max_fpr, thresholds, ignore_index) + self.max_fpr = max_fpr + + def compute(self) -> Tensor: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _binary_auroc_compute(state, self.thresholds, self.max_fpr) + + +class MulticlassAUROC(MulticlassPrecisionRecallCurve): + r""" + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks. The AUROC score + summarizes the ROC curve into an single number that describes the performance of a model for multiple + thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 + corresponds to random guessing. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over classes. Should be one of the following: + + - ``macro``: Calculate score for each class and average them + - ``weighted``: Calculates score for each class and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each class and applies no reduction + + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with auroc score per class. + If `average="macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics.classification import MulticlassAUROC + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> metric = MulticlassAUROC(num_classes=5, average="macro", thresholds=None) + >>> metric(preds, target) + tensor(0.5333) + >>> metric = MulticlassAUROC(num_classes=5, average=None, thresholds=None) + >>> metric(preds, target) + tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000]) + >>> metric = MulticlassAUROC(num_classes=5, average="macro", thresholds=5) + >>> metric(preds, target) + tensor(0.5333) + >>> metric = MulticlassAUROC(num_classes=5, average=None, thresholds=5) + >>> metric(preds, target) + tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000]) + + """ + + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + ) + if validate_args: + _multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index) + self.average = average + self.validate_args = validate_args + + def compute(self) -> Tensor: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multiclass_auroc_compute(state, self.num_classes, self.average, self.thresholds) + + +class MultilabelAUROC(MultilabelPrecisionRecallCurve): + r""" + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multilabel tasks. The AUROC score + summarizes the ROC curve into an single number that describes the performance of a model for multiple + thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 + corresponds to random guessing. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + num_labels: Integer specifing the number of labels + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum score over all labels + - ``macro``: Calculate score for each label and average them + - ``weighted``: Calculates score for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each label and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with auroc score per class. + If `average="micro|macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics.classification import MultilabelAUROC + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... [1, 1, 1]]) + >>> metric = MultilabelAUROC(num_labels=3, average="macro", thresholds=None) + >>> metric(preds, target) + tensor(0.6528) + >>> metric = MultilabelAUROC(num_labels=3, average=None, thresholds=None) + >>> metric(preds, target) + tensor([0.6250, 0.5000, 0.8333]) + >>> metric = MultilabelAUROC(num_labels=3, average="macro", thresholds=5) + >>> metric(preds, target) + tensor(0.6528) + >>> metric = MultilabelAUROC(num_labels=3, average=None, thresholds=5) + >>> metric(preds, target) + tensor([0.6250, 0.5000, 0.8333]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + ) + if validate_args: + _multilabel_auroc_arg_validation(num_labels, average, thresholds, ignore_index) + self.average = average + self.validate_args = validate_args + + def compute(self) -> Tensor: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multilabel_auroc_compute(state, self.num_labels, self.average, self.thresholds, self.ignore_index) + + class AUROC(Metric): - r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_). + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_). Works for both binary, multilabel and multiclass problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach. @@ -103,12 +411,54 @@ class AUROC(Metric): preds: List[Tensor] target: List[Tensor] + def __new__( + cls, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + max_fpr: Optional[float] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryAUROC(max_fpr, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassAUROC(num_classes, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelAUROC(num_labels, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, pos_label: Optional[int] = None, - average: Optional[str] = "macro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", max_fpr: Optional[float] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py new file mode 100644 index 00000000000..03510463b4c --- /dev/null +++ b/src/torchmetrics/classification/average_precision.py @@ -0,0 +1,482 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Optional, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.classification.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from torchmetrics.functional.classification.average_precision import ( + _average_precision_compute, + _average_precision_update, + _binary_average_precision_compute, + _multiclass_average_precision_arg_validation, + _multiclass_average_precision_compute, + _multilabel_average_precision_arg_validation, + _multilabel_average_precision_compute, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.data import dim_zero_cat + + +class BinaryAveragePrecision(BinaryPrecisionRecallCurve): + r""" + Computes the average precision (AP) score for binary tasks. The AP score summarizes a precision-recall curve + as an weighted mean of precisions at each threshold, with the difference in recall from the previous threshold + as weight: + + .. math:: + AP = \sum{n} (R_n - R_{n-1}) P_n + + where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is + equivalent to the area under the precision-recall curve (AUPRC). + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + A single scalar with the average precision score + + Example: + >>> from torchmetrics.classification import BinaryAveragePrecision + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> metric = BinaryAveragePrecision(thresholds=None) + >>> metric(preds, target) + tensor(0.5833) + >>> metric = BinaryAveragePrecision(thresholds=5) + >>> metric(preds, target) + tensor(0.6667) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def compute(self) -> Tensor: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _binary_average_precision_compute(state, self.thresholds) + + +class MulticlassAveragePrecision(MulticlassPrecisionRecallCurve): + r""" + Computes the average precision (AP) score for binary tasks. The AP score summarizes a precision-recall curve + as an weighted mean of precisions at each threshold, with the difference in recall from the previous threshold + as weight: + + .. math:: + AP = \sum{n} (R_n - R_{n-1}) P_n + + where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is + equivalent to the area under the precision-recall curve (AUPRC). + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over classes. Should be one of the following: + + - ``macro``: Calculate score for each class and average them + - ``weighted``: Calculates score for each class and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each class and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with AP score per class. + If `average="macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics.classification import MulticlassAveragePrecision + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> metric = MulticlassAveragePrecision(num_classes=5, average="macro", thresholds=None) + >>> metric(preds, target) + tensor(0.6250) + >>> metric = MulticlassAveragePrecision(num_classes=5, average=None, thresholds=None) + >>> metric(preds, target) + tensor([1.0000, 1.0000, 0.2500, 0.2500, nan]) + >>> metric = MulticlassAveragePrecision(num_classes=5, average="macro", thresholds=5) + >>> metric(preds, target) + tensor(0.5000) + >>> metric = MulticlassAveragePrecision(num_classes=5, average=None, thresholds=5) + >>> metric(preds, target) + tensor([1.0000, 1.0000, 0.2500, 0.2500, -0.0000]) + + """ + + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + ) + if validate_args: + _multiclass_average_precision_arg_validation(num_classes, average, thresholds, ignore_index) + self.average = average + self.validate_args = validate_args + + def compute(self) -> Tensor: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multiclass_average_precision_compute(state, self.num_classes, self.average, self.thresholds) + + +class MultilabelAveragePrecision(MultilabelPrecisionRecallCurve): + r""" + Computes the average precision (AP) score for binary tasks. The AP score summarizes a precision-recall curve + as an weighted mean of precisions at each threshold, with the difference in recall from the previous threshold + as weight: + + .. math:: + AP = \sum{n} (R_n - R_{n-1}) P_n + + where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is + equivalent to the area under the precision-recall curve (AUPRC). + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + num_labels: Integer specifing the number of labels + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum score over all labels + - ``macro``: Calculate score for each label and average them + - ``weighted``: Calculates score for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each label and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with AP score per class. + If `average="micro|macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics.classification import MultilabelAveragePrecision + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... [1, 1, 1]]) + >>> metric = MultilabelAveragePrecision(num_labels=3, average="macro", thresholds=None) + >>> metric(preds, target) + tensor(0.7500) + >>> metric = MultilabelAveragePrecision(num_labels=3, average=None, thresholds=None) + >>> metric(preds, target) + tensor([0.7500, 0.5833, 0.9167]) + >>> metric = MultilabelAveragePrecision(num_labels=3, average="macro", thresholds=5) + >>> metric(preds, target) + tensor(0.7778) + >>> metric = MultilabelAveragePrecision(num_labels=3, average=None, thresholds=5) + >>> metric(preds, target) + tensor([0.7500, 0.6667, 0.9167]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + ) + if validate_args: + _multilabel_average_precision_arg_validation(num_labels, average, thresholds, ignore_index) + self.average = average + self.validate_args = validate_args + + def compute(self) -> Tensor: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multilabel_average_precision_compute( + state, self.num_labels, self.average, self.thresholds, self.ignore_index + ) + + +class AveragePrecision(Metric): + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the average precision score, which summarises the precision recall curve into one number. Works for + both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one- + vs-the-rest approach. + + Forward accepts + + - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor + with probabilities, where C is the number of classes. + + - ``target`` (long tensor): ``(N, ...)`` with integer labels + + Args: + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translated to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range ``[0, num_classes-1]`` + average: + defines the reduction that is applied in the case of multiclass and multilabel input. + Should be one of the following: + + - ``'macro'`` [default]: Calculate the metric for each class separately, and average the + metrics across classes (with equal weights for each class). + - ``'micro'``: Calculate the metric globally, across all samples and classes. Cannot be + used with multiclass input. + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics across classes, weighting each class by its support. + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (binary case): + >>> from torchmetrics import AveragePrecision + >>> pred = torch.tensor([0, 0.1, 0.8, 0.4]) + >>> target = torch.tensor([0, 1, 1, 1]) + >>> average_precision = AveragePrecision(pos_label=1) + >>> average_precision(pred, target) + tensor(1.) + + Example (multiclass case): + >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> average_precision = AveragePrecision(num_classes=5, average=None) + >>> average_precision(pred, target) + [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] + """ + + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + preds: List[Tensor] + target: List[Tensor] + + def __new__( + cls, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryAveragePrecision(**kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassAveragePrecision(num_classes, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelAveragePrecision(num_labels, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + + def __init__( + self, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + self.num_classes = num_classes + self.pos_label = pos_label + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average}" f" but got {average}") + self.average = average + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + rank_zero_warn( + "Metric `AveragePrecision` will save all targets and predictions in buffer." + " For large datasets this may lead to large memory footprint." + ) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + preds, target, num_classes, pos_label = _average_precision_update( + preds, target, self.num_classes, self.pos_label, self.average + ) + self.preds.append(preds) + self.target.append(target) + self.num_classes = num_classes + self.pos_label = pos_label + + def compute(self) -> Union[Tensor, List[Tensor]]: + """Compute the average precision score. + + Returns: + tensor with average precision. If multiclass return list of such tensors, one for each class + """ + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + if not self.num_classes: + raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}") + return _average_precision_compute(preds, target, self.num_classes, self.pos_label, self.average) diff --git a/src/torchmetrics/classification/avg_precision.py b/src/torchmetrics/classification/avg_precision.py deleted file mode 100644 index 6cf94d13cd4..00000000000 --- a/src/torchmetrics/classification/avg_precision.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, List, Optional, Union - -import torch -from torch import Tensor - -from torchmetrics.functional.classification.average_precision import ( - _average_precision_compute, - _average_precision_update, -) -from torchmetrics.metric import Metric -from torchmetrics.utilities import rank_zero_warn -from torchmetrics.utilities.data import dim_zero_cat - - -class AveragePrecision(Metric): - """Computes the average precision score, which summarises the precision recall curve into one number. Works for - both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one- - vs-the-rest approach. - - Forward accepts - - - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor - with probabilities, where C is the number of classes. - - - ``target`` (long tensor): ``(N, ...)`` with integer labels - - Args: - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translated to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range ``[0, num_classes-1]`` - average: - defines the reduction that is applied in the case of multiclass and multilabel input. - Should be one of the following: - - - ``'macro'`` [default]: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'micro'``: Calculate the metric globally, across all samples and classes. Cannot be - used with multiclass input. - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support. - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - - Example (binary case): - >>> from torchmetrics import AveragePrecision - >>> pred = torch.tensor([0, 0.1, 0.8, 0.4]) - >>> target = torch.tensor([0, 1, 1, 1]) - >>> average_precision = AveragePrecision(pos_label=1) - >>> average_precision(pred, target) - tensor(1.) - - Example (multiclass case): - >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05, 0.05], - ... [0.05, 0.05, 0.05, 0.75, 0.05]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> average_precision = AveragePrecision(num_classes=5, average=None) - >>> average_precision(pred, target) - [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] - """ - - is_differentiable: bool = False - higher_is_better: Optional[bool] = None - full_state_update: bool = False - preds: List[Tensor] - target: List[Tensor] - - def __init__( - self, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - average: Optional[str] = "macro", - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - self.num_classes = num_classes - self.pos_label = pos_label - allowed_average = ("micro", "macro", "weighted", "none", None) - if average not in allowed_average: - raise ValueError(f"Expected argument `average` to be one of {allowed_average}" f" but got {average}") - self.average = average - - self.add_state("preds", default=[], dist_reduce_fx="cat") - self.add_state("target", default=[], dist_reduce_fx="cat") - - rank_zero_warn( - "Metric `AveragePrecision` will save all targets and predictions in buffer." - " For large datasets this may lead to large memory footprint." - ) - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values - """ - preds, target, num_classes, pos_label = _average_precision_update( - preds, target, self.num_classes, self.pos_label, self.average - ) - self.preds.append(preds) - self.target.append(target) - self.num_classes = num_classes - self.pos_label = pos_label - - def compute(self) -> Union[Tensor, List[Tensor]]: - """Compute the average precision score. - - Returns: - tensor with average precision. If multiclass return list of such tensors, one for each class - """ - preds = dim_zero_cat(self.preds) - target = dim_zero_cat(self.target) - if not self.num_classes: - raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}") - return _average_precision_compute(preds, target, self.num_classes, self.pos_label, self.average) diff --git a/src/torchmetrics/classification/binned_precision_recall.py b/src/torchmetrics/classification/binned_precision_recall.py index 21be67c0475..d7253527ae3 100644 --- a/src/torchmetrics/classification/binned_precision_recall.py +++ b/src/torchmetrics/classification/binned_precision_recall.py @@ -19,6 +19,7 @@ from torchmetrics.functional.classification.average_precision import _average_precision_compute_with_precision_recall from torchmetrics.metric import Metric from torchmetrics.utilities.data import METRIC_EPS, to_onehot +from torchmetrics.utilities.prints import rank_zero_warn def _recall_at_precision( @@ -49,6 +50,10 @@ class BinnedPrecisionRecallCurve(Metric): Computation is performed in constant-memory by computing precision and recall for ``thresholds`` buckets/thresholds (evenly distributed between 0 and 1). + .. warn: + This metric has been deprecated in v0.10 and will be removed in v0.11. + Instead use `PrecisionRecallCurve` metric with the `thresholds` argument set accordingly. + Forward accepts - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor @@ -122,6 +127,11 @@ def __init__( thresholds: Union[int, Tensor, List[float]] = 100, **kwargs: Any, ) -> None: + rank_zero_warn( + "Metric `BinnedPrecisionRecallCurve` has been deprecated in v0.10 and will be completly removed in v0.11." + " Instead, use the refactored version of `PrecisionRecallCurve` by specifying the `thresholds` argument.", + DeprecationWarning, + ) super().__init__(**kwargs) self.num_classes = num_classes @@ -187,6 +197,10 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve): Computation is performed in constant-memory by computing precision and recall for ``thresholds`` buckets/thresholds (evenly distributed between 0 and 1). + .. warn: + This metric has been deprecated in v0.10 and will be removed in v0.11. + Instead use `AveragePrecision` metric with the `thresholds` argument set accordingly. + Forward accepts - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor @@ -225,6 +239,19 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve): [tensor(1.0000), tensor(1.0000), tensor(0.2500), tensor(0.2500), tensor(-0.)] """ + def __init__( + self, + num_classes: int, + thresholds: Union[int, Tensor, List[float]] = 100, + **kwargs: Any, + ) -> None: + rank_zero_warn( + "Metric `BinnedAveragePrecision` has been deprecated in v0.10 and will be completly removed in v0.11." + " Instead, use the refactored version of `AveragePrecision` by specifying the `thresholds` argument.", + DeprecationWarning, + ) + super().__init__(num_classes=num_classes, thresholds=thresholds, **kwargs) + def compute(self) -> Union[List[Tensor], Tensor]: # type: ignore precisions, recalls, _ = super().compute() return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes, average=None) @@ -236,6 +263,10 @@ class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve): Computation is performed in constant-memory by computing precision and recall for ``thresholds`` buckets/thresholds (evenly distributed between 0 and 1). + .. warn: + This metric has been deprecated in v0.10 and will be removed in v0.11. + Instead use `RecallAtFixedPrecision` metric with the `thresholds` argument set accordingly. + Forward accepts - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor @@ -283,6 +314,11 @@ def __init__( thresholds: Union[int, Tensor, List[float]] = 100, **kwargs: Any, ) -> None: + rank_zero_warn( + "Metric `BinnedRecallAtFixedPrecision` has been deprecated in v0.10 and will be completly removed in v0.11." + " Instead, use the refactored version of `RecallAtFixedPrecision` by specifying the `thresholds` argument.", + DeprecationWarning, + ) super().__init__(num_classes=num_classes, thresholds=thresholds, **kwargs) self.min_precision = min_precision diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index f9b7aa23dde..92f6b7dc64e 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -11,18 +11,227 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List +from typing import Any, List, Optional import torch from torch import Tensor +from typing_extensions import Literal -from torchmetrics.functional.classification.calibration_error import _ce_compute, _ce_update +from torchmetrics.functional.classification.calibration_error import ( + _binary_calibration_error_arg_validation, + _binary_calibration_error_tensor_validation, + _binary_calibration_error_update, + _binary_confusion_matrix_format, + _ce_compute, + _ce_update, + _multiclass_calibration_error_arg_validation, + _multiclass_calibration_error_tensor_validation, + _multiclass_calibration_error_update, + _multiclass_confusion_matrix_format, +) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.prints import rank_zero_warn + + +class BinaryCalibrationError(Metric): + r"""`Computes the Top-label Calibration Error`_ for binary tasks. The expected calibration error can be used to + quantify how well a given model is calibrated e.g. how well the predicted output probabilities of the model matches + the actual probabilities of the ground truth distribution. + + Three different norms are implemented, each corresponding to variations on the calibration error metric. + + .. math:: + \text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)} + + .. math:: + \text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)} + + .. math:: + \text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)} + + Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, :math:`c_i` is the average confidence of + predictions in bin :math:`i`, and :math:`b_i` is the fraction of data points in bin :math:`i`. Bins are constructed + in an uniform way in the [0,1] range. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + n_bins: Number of bins to use when computing the metric. + norm: Norm used to compare empirical and expected probability bins. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> from torchmetrics.classification import BinaryCalibrationError + >>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75]) + >>> target = torch.tensor([0, 0, 1, 1, 1]) + >>> metric = BinaryCalibrationError(n_bins=2, norm='l1') + >>> metric(preds, target) + tensor(0.2900) + >>> metric = BinaryCalibrationError(n_bins=2, norm='l2') + >>> metric(preds, target) + tensor(0.2918) + >>> metric = BinaryCalibrationError(n_bins=2, norm='max') + >>> metric(preds, target) + tensor(0.3167) + """ + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + + def __init__( + self, + n_bins: int = 15, + norm: Literal["l1", "l2", "max"] = "l1", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _binary_calibration_error_arg_validation(n_bins, norm, ignore_index) + self.validate_args = validate_args + self.n_bins = n_bins + self.norm = norm + self.ignore_index = ignore_index + self.add_state("confidences", [], dist_reduce_fx="cat") + self.add_state("accuracies", [], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _binary_calibration_error_tensor_validation(preds, target, self.ignore_index) + preds, target = _binary_confusion_matrix_format( + preds, target, threshold=0.0, ignore_index=self.ignore_index, convert_to_labels=False + ) + confidences, accuracies = _binary_calibration_error_update(preds, target) + self.confidences.append(confidences) + self.accuracies.append(accuracies) + + def compute(self) -> Tensor: + confidences = dim_zero_cat(self.confidences) + accuracies = dim_zero_cat(self.accuracies) + return _ce_compute(confidences, accuracies, self.n_bins, norm=self.norm) + + +class MulticlassCalibrationError(Metric): + r"""`Computes the Top-label Calibration Error`_ for multiclass tasks. The expected calibration error can be used to + quantify how well a given model is calibrated e.g. how well the predicted output probabilities of the model matches + the actual probabilities of the ground truth distribution. + + Three different norms are implemented, each corresponding to variations on the calibration error metric. + + .. math:: + \text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)} + + .. math:: + \text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)} + + .. math:: + \text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)} + + Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, :math:`c_i` is the average confidence of + predictions in bin :math:`i`, and :math:`b_i` is the fraction of data points in bin :math:`i`. Bins are constructed + in an uniform way in the [0,1] range. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of classes + n_bins: Number of bins to use when computing the metric. + norm: Norm used to compare empirical and expected probability bins. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> from torchmetrics.classification import MulticlassCalibrationError + >>> preds = torch.tensor([[0.25, 0.20, 0.55], + ... [0.55, 0.05, 0.40], + ... [0.10, 0.30, 0.60], + ... [0.90, 0.05, 0.05]]) + >>> target = torch.tensor([0, 1, 2, 0]) + >>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l1') + >>> metric(preds, target) + tensor(0.2000) + >>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l2') + >>> metric(preds, target) + tensor(0.2082) + >>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='max') + >>> metric(preds, target) + tensor(0.2333) + """ + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + n_bins: int = 15, + norm: Literal["l1", "l2", "max"] = "l1", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multiclass_calibration_error_arg_validation(num_classes, n_bins, norm, ignore_index) + self.validate_args = validate_args + self.num_classes = num_classes + self.n_bins = n_bins + self.norm = norm + self.ignore_index = ignore_index + self.add_state("confidences", [], dist_reduce_fx="cat") + self.add_state("accuracies", [], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multiclass_calibration_error_tensor_validation(preds, target, self.num_classes, self.ignore_index) + preds, target = _multiclass_confusion_matrix_format( + preds, target, ignore_index=self.ignore_index, convert_to_labels=False + ) + confidences, accuracies = _multiclass_calibration_error_update(preds, target) + self.confidences.append(confidences) + self.accuracies.append(accuracies) + + def compute(self) -> Tensor: + confidences = dim_zero_cat(self.confidences) + accuracies = dim_zero_cat(self.accuracies) + return _ce_compute(confidences, accuracies, self.n_bins, norm=self.norm) class CalibrationError(Metric): - r"""`Computes the Top-label Calibration Error`_ + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + `Computes the Top-label Calibration Error`_ Three different norms are implemented, each corresponding to variations on the calibration error metric. L1 norm (Expected Calibration Error) @@ -62,6 +271,37 @@ class CalibrationError(Metric): confidences: List[Tensor] accuracies: List[Tensor] + def __new__( + cls, + n_bins: int = 15, + norm: str = "l1", + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + kwargs.update(dict(n_bins=n_bins, norm=norm, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryCalibrationError(**kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassCalibrationError(num_classes, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, n_bins: int = 15, diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 7146ac14480..149890515b8 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -15,13 +15,183 @@ import torch from torch import Tensor +from typing_extensions import Literal -from torchmetrics.functional.classification.cohen_kappa import _cohen_kappa_compute, _cohen_kappa_update +from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix +from torchmetrics.functional.classification.cohen_kappa import ( + _binary_cohen_kappa_arg_validation, + _cohen_kappa_compute, + _cohen_kappa_reduce, + _cohen_kappa_update, + _multiclass_cohen_kappa_arg_validation, +) from torchmetrics.metric import Metric +from torchmetrics.utilities.prints import rank_zero_warn + + +class BinaryCohenKappa(BinaryConfusionMatrix): + r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement for binary + tasks. It is defined as + + .. math:: + \kappa = (p_o - p_e) / (1 - p_e) + + where :math:`p_o` is the empirical probability of agreement and :math:`p_e` is + the expected agreement when both annotators assign labels randomly. Note that + :math:`p_e` is estimated using a per-annotator empirical prior over the + class labels. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + weights: Weighting type to calculate the score. Choose from: + + - ``None`` or ``'none'``: no weighting + - ``'linear'``: linear weighting + - ``'quadratic'``: quadratic weighting + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.classification import BinaryCohenKappa + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> metric = BinaryCohenKappa() + >>> metric(preds, target) + tensor(0.5000) + + Example (preds is float tensor): + >>> from torchmetrics.classification import BinaryCohenKappa + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) + >>> metric = BinaryCohenKappa() + >>> metric(preds, target) + tensor(0.5000) + + """ + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + weights: Optional[Literal["linear", "quadratic", "none"]] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(threshold, ignore_index, normalize=None, validate_args=False, **kwargs) + if validate_args: + _binary_cohen_kappa_arg_validation(threshold, ignore_index, weights) + self.weights = weights + self.validate_args = validate_args + + def compute(self) -> Tensor: + return _cohen_kappa_reduce(self.confmat, self.weights) + + +class MulticlassCohenKappa(MulticlassConfusionMatrix): + r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement for multiclass + tasks. It is defined as + + .. math:: + \kappa = (p_o - p_e) / (1 - p_e) + + where :math:`p_o` is the empirical probability of agreement and :math:`p_e` is + the expected agreement when both annotators assign labels randomly. Note that + :math:`p_e` is estimated using a per-annotator empirical prior over the + class labels. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of classes + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + weights: Weighting type to calculate the score. Choose from: + + - ``None`` or ``'none'``: no weighting + - ``'linear'``: linear weighting + - ``'quadratic'``: quadratic weighting + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (pred is integer tensor): + >>> from torchmetrics.classification import MulticlassCohenKappa + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassCohenKappa(num_classes=3) + >>> metric(preds, target) + tensor(0.6364) + + Example (pred is float tensor): + >>> from torchmetrics.classification import MulticlassCohenKappa + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> metric = MulticlassCohenKappa(num_classes=3) + >>> metric(preds, target) + tensor(0.6364) + + """ + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + ignore_index: Optional[int] = None, + weights: Optional[Literal["linear", "quadratic", "none"]] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(num_classes, ignore_index, normalize=None, validate_args=False, **kwargs) + if validate_args: + _multiclass_cohen_kappa_arg_validation(num_classes, ignore_index, weights) + self.weights = weights + self.validate_args = validate_args + + def compute(self) -> Tensor: + return _cohen_kappa_reduce(self.confmat, self.weights) class CohenKappa(Metric): - r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement. It is defined as + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Calculates `Cohen's kappa score`_ that measures inter-annotator agreement. It is defined as .. math:: \kappa = (p_o - p_e) / (1 - p_e) @@ -72,6 +242,37 @@ class labels. full_state_update: bool = False confmat: Tensor + def __new__( + cls, + num_classes: Optional[int] = None, + weights: Optional[str] = None, + threshold: float = 0.5, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + kwargs.update(dict(weights=weights, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryCohenKappa(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassCohenKappa(num_classes, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, num_classes: int, diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index a847b04044b..89909c2d45c 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -15,13 +15,319 @@ import torch from torch import Tensor +from typing_extensions import Literal -from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update +from torchmetrics.functional.classification.confusion_matrix import ( + _binary_confusion_matrix_arg_validation, + _binary_confusion_matrix_compute, + _binary_confusion_matrix_format, + _binary_confusion_matrix_tensor_validation, + _binary_confusion_matrix_update, + _confusion_matrix_compute, + _confusion_matrix_update, + _multiclass_confusion_matrix_arg_validation, + _multiclass_confusion_matrix_compute, + _multiclass_confusion_matrix_format, + _multiclass_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_update, + _multilabel_confusion_matrix_arg_validation, + _multilabel_confusion_matrix_compute, + _multilabel_confusion_matrix_format, + _multilabel_confusion_matrix_tensor_validation, + _multilabel_confusion_matrix_update, +) from torchmetrics.metric import Metric +from torchmetrics.utilities.prints import rank_zero_warn + + +class BinaryConfusionMatrix(Metric): + r""" + Computes the `confusion matrix`_ for binary tasks. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.classification import BinaryConfusionMatrix + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> metric = BinaryConfusionMatrix() + >>> metric(preds, target) + tensor([[2, 0], + [1, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import BinaryConfusionMatrix + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) + >>> metric = BinaryConfusionMatrix() + >>> metric(preds, target) + tensor([[2, 0], + [1, 1]]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize) + self.threshold = threshold + self.ignore_index = ignore_index + self.normalize = normalize + self.validate_args = validate_args + + self.add_state("confmat", torch.zeros(2, 2, dtype=torch.long), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ + if self.validate_args: + _binary_confusion_matrix_tensor_validation(preds, target, self.ignore_index) + preds, target = _binary_confusion_matrix_format(preds, target, self.threshold, self.ignore_index) + confmat = _binary_confusion_matrix_update(preds, target) + self.confmat += confmat + + def compute(self) -> Tensor: + """Computes confusion matrix. + + Returns an [2,2] matrix. + """ + return _binary_confusion_matrix_compute(self.confmat, self.normalize) + + +class MulticlassConfusionMatrix(Metric): + r""" + Computes the `confusion matrix`_ for multiclass tasks. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of classes + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (pred is integer tensor): + >>> from torchmetrics.classification import MulticlassConfusionMatrix + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassConfusionMatrix(num_classes=3) + >>> metric(preds, target) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + + Example (pred is float tensor): + >>> from torchmetrics.classification import MulticlassConfusionMatrix + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> metric = MulticlassConfusionMatrix(num_classes=3) + >>> metric(preds, target) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["none", "true", "pred", "all"]] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) + self.num_classes = num_classes + self.ignore_index = ignore_index + self.normalize = normalize + self.validate_args = validate_args + + self.add_state("confmat", torch.zeros(num_classes, num_classes, dtype=torch.long), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ + if self.validate_args: + _multiclass_confusion_matrix_tensor_validation(preds, target, self.num_classes, self.ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index) + confmat = _multiclass_confusion_matrix_update(preds, target, self.num_classes) + self.confmat += confmat + + def compute(self) -> Tensor: + """Computes confusion matrix. + + Returns an [num_classes, num_classes] matrix. + """ + return _multiclass_confusion_matrix_compute(self.confmat, self.normalize) + + +class MultilabelConfusionMatrix(Metric): + r""" + Computes the `confusion matrix`_ for multilabel tasks. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.classification import MultilabelConfusionMatrix + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelConfusionMatrix(num_labels=3) + >>> metric(preds, target) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MultilabelConfusionMatrix + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelConfusionMatrix(num_labels=3) + >>> metric(preds, target) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["none", "true", "pred", "all"]] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize) + self.num_labels = num_labels + self.threshold = threshold + self.ignore_index = ignore_index + self.normalize = normalize + self.validate_args = validate_args + + self.add_state("confmat", torch.zeros(num_labels, 2, 2, dtype=torch.long), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ + if self.validate_args: + _multilabel_confusion_matrix_tensor_validation(preds, target, self.num_labels, self.ignore_index) + preds, target = _multilabel_confusion_matrix_format( + preds, target, self.num_labels, self.threshold, self.ignore_index + ) + confmat = _multilabel_confusion_matrix_update(preds, target, self.num_labels) + self.confmat += confmat + + def compute(self) -> Tensor: + """Computes confusion matrix. + + Returns an [num_labels,2,2] matrix. + """ + return _multilabel_confusion_matrix_compute(self.confmat, self.normalize) class ConfusionMatrix(Metric): - r"""Computes the `confusion matrix`_. + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the `confusion matrix`_. Works with binary, multiclass, and multilabel data. Accepts probabilities or logits from a model output or integer class values in prediction. Works with multi-dimensional preds and target, but it should be noted that @@ -90,6 +396,42 @@ class ConfusionMatrix(Metric): full_state_update: bool = False confmat: Tensor + def __new__( + cls, + num_classes: Optional[int] = None, + normalize: Optional[str] = None, + threshold: float = 0.5, + multilabel: bool = False, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + kwargs.update(dict(normalize=normalize, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryConfusionMatrix(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassConfusionMatrix(num_classes, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelConfusionMatrix(num_labels, threshold, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, num_classes: int, diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index 7f4684a27ee..8a8b58736da 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -14,6 +14,7 @@ from typing import Any, Optional from torch import Tensor +from typing_extensions import Literal from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.dice import _dice_compute @@ -124,7 +125,7 @@ def __init__( zero_division: int = 0, num_classes: Optional[int] = None, threshold: float = 0.5, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = "global", ignore_index: Optional[int] = None, top_k: Optional[int] = None, diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py new file mode 100644 index 00000000000..5577b6d9a83 --- /dev/null +++ b/src/torchmetrics/classification/exact_match.py @@ -0,0 +1,164 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.classification.exact_match import ( + _multilabel_exact_scores_compute, + _multilabel_exact_scores_update, +) +from torchmetrics.functional.classification.stat_scores import ( + _multilabel_stat_scores_arg_validation, + _multilabel_stat_scores_format, + _multilabel_stat_scores_tensor_validation, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat + + +class MultilabelExactMatch(Metric): + r"""Computes Exact match (also known as subset accuracy) for multilabel tasks. Exact Match is a stricter + version of accuracy where all labels have to match exactly for the sample to be correctly classified. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.classification import MultilabelExactMatch + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelExactMatch(num_labels=3) + >>> metric(preds, target) + tensor(0.5000) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MultilabelExactMatch + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelExactMatch(num_labels=3) + >>> metric(preds, target) + tensor(0.5000) + + Example (multidim tensors): + >>> from torchmetrics.classification import MultilabelExactMatch + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = MultilabelExactMatch(num_labels=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0., 0.]) + """ + + is_differentiable = False + higher_is_better = True + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multilabel_stat_scores_arg_validation( + num_labels, threshold, average=None, multidim_average=multidim_average, ignore_index=ignore_index + ) + self.num_labels = num_labels + self.threshold = threshold + self.multidim_average = multidim_average + self.ignore_index = ignore_index + self.validate_args = validate_args + + self.add_state( + "correct", + torch.zeros(1, dtype=torch.long) if self.multidim_average == "global" else [], + dist_reduce_fx="sum" if self.multidim_average == "global" else "cat", + ) + self.add_state( + "total", + torch.zeros(1, dtype=torch.long), + dist_reduce_fx="sum" if self.multidim_average == "global" else "mean", + ) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ + if self.validate_args: + _multilabel_stat_scores_tensor_validation( + preds, target, self.num_labels, self.multidim_average, self.ignore_index + ) + preds, target = _multilabel_stat_scores_format( + preds, target, self.num_labels, self.threshold, self.ignore_index + ) + correct, total = _multilabel_exact_scores_update(preds, target, self.num_labels, self.multidim_average) + if self.multidim_average == "samplewise": + self.correct.append(correct) + self.total = total + else: + self.correct += correct + self.total += total + + def compute(self) -> Tensor: + correct = dim_zero_cat(self.correct) + return _multilabel_exact_scores_compute(correct, self.total) diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 88eb2eeff38..40fd9c2859f 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -13,15 +13,719 @@ # limitations under the License. from typing import Any, Optional +import torch from torch import Tensor +from typing_extensions import Literal -from torchmetrics.classification.stat_scores import StatScores -from torchmetrics.functional.classification.f_beta import _fbeta_compute +from torchmetrics.classification.stat_scores import ( + BinaryStatScores, + MulticlassStatScores, + MultilabelStatScores, + StatScores, +) +from torchmetrics.functional.classification.f_beta import ( + _binary_fbeta_score_arg_validation, + _fbeta_compute, + _fbeta_reduce, + _multiclass_fbeta_score_arg_validation, + _multilabel_fbeta_score_arg_validation, +) +from torchmetrics.metric import Metric from torchmetrics.utilities.enums import AverageMethod +from torchmetrics.utilities.prints import rank_zero_warn + + +class BinaryFBetaScore(BinaryStatScores): + r""" + Computes `F-score`_ metric for binary tasks: + + .. math:: + F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} + {(\beta^2 * \text{precision}) + \text{recall}} + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.classification import BinaryFBetaScore + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> metric = BinaryFBetaScore(beta=2.0) + >>> metric(preds, target) + tensor(0.6667) + + Example (preds is float tensor): + >>> from torchmetrics.classification import BinaryFBetaScore + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> metric = BinaryFBetaScore(beta=2.0) + >>> metric(preds, target) + tensor(0.6667) + + Example (multidim tensors): + >>> from torchmetrics.classification import BinaryFBetaScore + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = BinaryFBetaScore(beta=2.0, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.5882, 0.0000]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = False + + def __init__( + self, + beta: float, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + threshold=threshold, + multidim_average=multidim_average, + ignore_index=ignore_index, + validate_args=False, + **kwargs, + ) + if validate_args: + _binary_fbeta_score_arg_validation(beta, threshold, multidim_average, ignore_index) + self.validate_args = validate_args + self.beta = beta + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _fbeta_reduce(tp, fp, tn, fn, self.beta, average="binary", multidim_average=self.multidim_average) + + +class MulticlassFBetaScore(MulticlassStatScores): + r""" + Computes `F-score`_ metric for multiclass tasks: + + .. math:: + F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} + {(\beta^2 * \text{precision}) + \text{recall}} + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + top_k: + + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.classification import MulticlassFBetaScore + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3) + >>> metric(preds, target) + tensor(0.7963) + >>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3, average=None) + >>> metric(preds, target) + tensor([0.5556, 0.8333, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MulticlassFBetaScore + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3) + >>> metric(preds, target) + tensor(0.7963) + >>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3, average=None) + >>> metric(preds, target) + tensor([0.5556, 0.8333, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MulticlassFBetaScore + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.4697, 0.2706]) + >>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[0.9091, 0.0000, 0.5000], + [0.0000, 0.3571, 0.4545]]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = False + + def __init__( + self, + beta: float, + num_classes: int, + top_k: int = 1, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_classes=num_classes, + top_k=top_k, + average=average, + multidim_average=multidim_average, + ignore_index=ignore_index, + validate_args=False, + **kwargs, + ) + if validate_args: + _multiclass_fbeta_score_arg_validation(beta, num_classes, top_k, average, multidim_average, ignore_index) + self.validate_args = validate_args + self.beta = beta + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _fbeta_reduce(tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average) + + +class MultilabelFBetaScore(MultilabelStatScores): + r""" + Computes `F-score`_ metric for multilabel tasks: + + .. math:: + F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} + {(\beta^2 * \text{precision}) + \text{recall}} + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.classification import MultilabelFBetaScore + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3) + >>> metric(preds, target) + tensor(0.6111) + >>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3, average=None) + >>> metric(preds, target) + tensor([1.0000, 0.0000, 0.8333]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MultilabelFBetaScore + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3) + >>> metric(preds, target) + tensor(0.6111) + >>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3, average=None) + >>> metric(preds, target) + tensor([1.0000, 0.0000, 0.8333]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MultilabelFBetaScore + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = MultilabelFBetaScore(num_labels=3, beta=2.0, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.5556, 0.0000]) + >>> metric = MultilabelFBetaScore(num_labels=3, beta=2.0, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[0.8333, 0.8333, 0.0000], + [0.0000, 0.0000, 0.0000]]) + + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = False + + def __init__( + self, + beta: float, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_labels=num_labels, + threshold=threshold, + average=average, + multidim_average=multidim_average, + ignore_index=ignore_index, + validate_args=False, + **kwargs, + ) + if validate_args: + _multilabel_fbeta_score_arg_validation(beta, num_labels, threshold, average, multidim_average, ignore_index) + self.validate_args = validate_args + self.beta = beta + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _fbeta_reduce(tp, fp, tn, fn, self.beta, average=self.average, multidim_average=self.multidim_average) + + +class BinaryF1Score(BinaryFBetaScore): + r""" + Computes F-1 score for binary tasks: + + .. math:: + F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}} + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.classification import BinaryF1Score + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> metric = BinaryF1Score() + >>> metric(preds, target) + tensor(0.6667) + + Example (preds is float tensor): + >>> from torchmetrics.classification import BinaryF1Score + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> metric = BinaryF1Score() + >>> metric(preds, target) + tensor(0.6667) + + Example (multidim tensors): + >>> from torchmetrics.classification import BinaryF1Score + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = BinaryF1Score(multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.5000, 0.0000]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = False + + def __init__( + self, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + beta=1.0, + threshold=threshold, + multidim_average=multidim_average, + ignore_index=ignore_index, + validate_args=validate_args, + **kwargs, + ) + + +class MulticlassF1Score(MulticlassFBetaScore): + r""" + Computes F-1 score for multiclass tasks: + + .. math:: + F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}} + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.classification import MulticlassF1Score + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassF1Score(num_classes=3) + >>> metric(preds, target) + tensor(0.7778) + >>> metric = MulticlassF1Score(num_classes=3, average=None) + >>> metric(preds, target) + tensor([0.6667, 0.6667, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MulticlassF1Score + >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> preds = preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> metric = MulticlassF1Score(num_classes=3) + >>> metric(preds, target) + tensor(0.7778) + >>> metric = MulticlassF1Score(num_classes=3, average=None) + >>> metric(preds, target) + tensor([0.6667, 0.6667, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MulticlassF1Score + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> metric = MulticlassF1Score(num_classes=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.4333, 0.2667]) + >>> metric = MulticlassF1Score(num_classes=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[0.8000, 0.0000, 0.5000], + [0.0000, 0.4000, 0.4000]]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + top_k: int = 1, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + beta=1.0, + num_classes=num_classes, + top_k=top_k, + average=average, + multidim_average=multidim_average, + ignore_index=ignore_index, + validate_args=validate_args, + **kwargs, + ) + + +class MultilabelF1Score(MultilabelFBetaScore): + r""" + Computes F-1 score for multilabel tasks: + + .. math:: + F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}} + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)``` + + Example (preds is int tensor): + >>> from torchmetrics.classification import MultilabelF1Score + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelF1Score(num_labels=3) + >>> metric(preds, target) + tensor(0.5556) + >>> metric = MultilabelF1Score(num_labels=3, average=None) + >>> metric(preds, target) + tensor([1.0000, 0.0000, 0.6667]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MultilabelF1Score + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelF1Score(num_labels=3) + >>> metric(preds, target) + tensor(0.5556) + >>> metric = MultilabelF1Score(num_labels=3, average=None) + >>> metric(preds, target) + tensor([1.0000, 0.0000, 0.6667]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MultilabelF1Score + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = MultilabelF1Score(num_labels=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.4444, 0.0000]) + >>> metric = MultilabelF1Score(num_labels=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[0.6667, 0.6667, 0.0000], + [0.0000, 0.0000, 0.0000]]) + + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + beta=1.0, + num_labels=num_labels, + threshold=threshold, + average=average, + multidim_average=multidim_average, + ignore_index=ignore_index, + validate_args=validate_args, + **kwargs, + ) class FBetaScore(StatScores): - r"""Computes `F-score`_, specifically: + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `F-score`_, specifically: .. math:: F_\beta = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} @@ -120,12 +824,56 @@ class FBetaScore(StatScores): """ full_state_update: bool = False + def __new__( + cls, + num_classes: Optional[int] = None, + beta: float = 1.0, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryFBetaScore(beta, threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, beta: float = 1.0, threshold: float = 0.5, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, top_k: Optional[int] = None, @@ -161,7 +909,15 @@ def compute(self) -> Tensor: class F1Score(FBetaScore): - """Computes F1 metric. + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes F1 metric. F1 metrics correspond to a harmonic mean of the precision and recall scores. Works with binary, multiclass, and multilabel data. Accepts logits or probabilities from a model @@ -251,11 +1007,54 @@ class F1Score(FBetaScore): higher_is_better: bool = True full_state_update: bool = False + def __new__( + cls, + num_classes: Optional[int] = None, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryF1Score(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassF1Score(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelF1Score(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, threshold: float = 0.5, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, top_k: Optional[int] = None, diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index be9a0bf430e..143bea2035a 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -11,17 +11,323 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import torch from torch import Tensor, tensor +from typing_extensions import Literal -from torchmetrics.functional.classification.hamming import _hamming_distance_compute, _hamming_distance_update +from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores +from torchmetrics.functional.classification.hamming import ( + _hamming_distance_compute, + _hamming_distance_reduce, + _hamming_distance_update, +) from torchmetrics.metric import Metric +from torchmetrics.utilities.prints import rank_zero_warn + + +class BinaryHammingDistance(BinaryStatScores): + r"""Computes the average `Hamming distance`_ (also known as Hamming loss) for binary tasks: + + .. math:: + \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that + tensor. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.classification import BinaryHammingDistance + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> metric = BinaryHammingDistance() + >>> metric(preds, target) + tensor(0.3333) + + Example (preds is float tensor): + >>> from torchmetrics.classification import BinaryHammingDistance + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> metric = BinaryHammingDistance() + >>> metric(preds, target) + tensor(0.3333) + + Example (multidim tensors): + >>> from torchmetrics.classification import BinaryHammingDistance + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = BinaryHammingDistance(multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.6667, 0.8333]) + """ + + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _hamming_distance_reduce(tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average) + + +class MulticlassHammingDistance(MulticlassStatScores): + r"""Computes the average `Hamming distance`_ (also known as Hamming loss) for multiclass tasks: + + .. math:: + \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that + tensor. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.classification import MulticlassHammingDistance + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassHammingDistance(num_classes=3) + >>> metric(preds, target) + tensor(0.1667) + >>> metric = MulticlassHammingDistance(num_classes=3, average=None) + >>> metric(preds, target) + tensor([0.5000, 0.0000, 0.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MulticlassHammingDistance + >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> metric = MulticlassHammingDistance(num_classes=3) + >>> metric(preds, target) + tensor(0.1667) + >>> metric = MulticlassHammingDistance(num_classes=3, average=None) + >>> metric(preds, target) + tensor([0.5000, 0.0000, 0.0000]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MulticlassHammingDistance + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> metric = MulticlassHammingDistance(num_classes=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.5000, 0.7222]) + >>> metric = MulticlassHammingDistance(num_classes=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[0.0000, 1.0000, 0.5000], + [1.0000, 0.6667, 0.5000]]) + """ + + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _hamming_distance_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average) + + +class MultilabelHammingDistance(MultilabelStatScores): + r"""Computes the average `Hamming distance`_ (also known as Hamming loss) for multilabel tasks: + + .. math:: + \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that + tensor. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.classification import MultilabelHammingDistance + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelHammingDistance(num_labels=3) + >>> metric(preds, target) + tensor(0.3333) + >>> metric = MultilabelHammingDistance(num_labels=3, average=None) + >>> metric(preds, target) + tensor([0.0000, 0.5000, 0.5000]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MultilabelHammingDistance + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelHammingDistance(num_labels=3) + >>> metric(preds, target) + tensor(0.3333) + >>> metric = MultilabelHammingDistance(num_labels=3, average=None) + >>> metric(preds, target) + tensor([0.0000, 0.5000, 0.5000]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MultilabelHammingDistance + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = MultilabelHammingDistance(num_labels=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.6667, 0.8333]) + >>> metric = MultilabelHammingDistance(num_labels=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[0.5000, 0.5000, 1.0000], + [1.0000, 1.0000, 0.5000]]) + + """ + + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _hamming_distance_reduce( + tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True + ) class HammingDistance(Metric): - r"""Computes the average `Hamming distance`_ (also known as Hamming loss) between targets and predictions: + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the average `Hamming distance`_ (also known as Hamming loss) between targets and predictions: .. math:: \text{Hamming distance} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y_{il}}) @@ -62,6 +368,47 @@ class HammingDistance(Metric): correct: Tensor total: Tensor + def __new__( + cls, + threshold: float = 0.5, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryHammingDistance(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassHammingDistance(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelHammingDistance(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, threshold: float = 0.5, diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index 9fb128a6f73..c5598be5d8c 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -13,14 +13,208 @@ # limitations under the License. from typing import Any, Optional, Union -from torch import Tensor, tensor +import torch +from torch import Tensor +from typing_extensions import Literal -from torchmetrics.functional.classification.hinge import MulticlassMode, _hinge_compute, _hinge_update +from torchmetrics.functional.classification.hinge import ( + MulticlassMode, + _binary_confusion_matrix_format, + _binary_hinge_loss_arg_validation, + _binary_hinge_loss_tensor_validation, + _binary_hinge_loss_update, + _hinge_compute, + _hinge_loss_compute, + _hinge_update, + _multiclass_confusion_matrix_format, + _multiclass_hinge_loss_arg_validation, + _multiclass_hinge_loss_tensor_validation, + _multiclass_hinge_loss_update, +) from torchmetrics.metric import Metric +from torchmetrics.utilities.prints import rank_zero_warn + + +class BinaryHingeLoss(Metric): + r"""Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for binary tasks. It is + defined as: + + .. math:: + \text{Hinge loss} = \max(0, 1 - y \times \hat{y}) + + Where :math:`y \in {-1, 1}` is the target, and :math:`\hat{y} \in \mathbb{R}` is the prediction. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + squared: + If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> from torchmetrics.classification import BinaryHingeLoss + >>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75]) + >>> target = torch.tensor([0, 0, 1, 1, 1]) + >>> metric = BinaryHingeLoss() + >>> metric(preds, target) + tensor(0.6900) + >>> metric = BinaryHingeLoss(squared=True) + >>> metric(preds, target) + tensor(0.6905) + """ + is_differentiable: bool = True + higher_is_better: bool = False + full_state_update: bool = False + + def __init__( + self, + squared: bool = False, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _binary_hinge_loss_arg_validation(squared, ignore_index) + self.validate_args = validate_args + self.squared = squared + self.ignore_index = ignore_index + + self.add_state("measures", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _binary_hinge_loss_tensor_validation(preds, target, self.ignore_index) + preds, target = _binary_confusion_matrix_format( + preds, target, threshold=0.0, ignore_index=self.ignore_index, convert_to_labels=False + ) + measures, total = _binary_hinge_loss_update(preds, target, self.squared) + self.measures += measures + self.total += total + + def compute(self) -> Tensor: + return _hinge_loss_compute(self.measures, self.total) + + +class MulticlassHingeLoss(Metric): + r"""Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for multiclass tasks + + The metric can be computed in two ways. Either, the definition by Crammer and Singer is used: + + .. math:: + \text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right) + + Where :math:`y \in {0, ..., \mathrm{C}}` is the target class (where :math:`\mathrm{C}` is the number of classes), + and :math:`\hat{y} \in \mathbb{R}^\mathrm{C}` is the predicted output per class. Alternatively, the metric can + also be computed in one-vs-all approach, where each class is valued against all other classes in a binary fashion. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of classes + squared: + If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss. + multiclass_mode: + Determines how to compute the metric + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> from torchmetrics.classification import MulticlassHingeLoss + >>> preds = torch.tensor([[0.25, 0.20, 0.55], + ... [0.55, 0.05, 0.40], + ... [0.10, 0.30, 0.60], + ... [0.90, 0.05, 0.05]]) + >>> target = torch.tensor([0, 1, 2, 0]) + >>> metric = MulticlassHingeLoss(num_classes=3) + >>> metric(preds, target) + tensor(0.9125) + >>> metric = MulticlassHingeLoss(num_classes=3, squared=True) + >>> metric(preds, target) + tensor(1.1131) + >>> metric = MulticlassHingeLoss(num_classes=3, multiclass_mode='one-vs-all') + >>> metric(preds, target) + tensor([0.8750, 1.1250, 1.1000]) + """ + is_differentiable: bool = True + higher_is_better: bool = False + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + squared: bool = False, + multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multiclass_hinge_loss_arg_validation(num_classes, squared, multiclass_mode, ignore_index) + self.validate_args = validate_args + self.num_classes = num_classes + self.squared = squared + self.multiclass_mode = multiclass_mode + self.ignore_index = ignore_index + + self.add_state( + "measures", + default=torch.tensor(0.0) + if self.multiclass_mode == "crammer-singer" + else torch.zeros( + num_classes, + ), + dist_reduce_fx="sum", + ) + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multiclass_hinge_loss_tensor_validation(preds, target, self.num_classes, self.ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index, convert_to_labels=False) + measures, total = _multiclass_hinge_loss_update(preds, target, self.squared, self.multiclass_mode) + self.measures += measures + self.total += total + + def compute(self) -> Tensor: + return _hinge_loss_compute(self.measures, self.total) class HingeLoss(Metric): - r"""Computes the mean `Hinge loss`_, typically used for Support Vector Machines (SVMs). + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the mean `Hinge loss`_, typically used for Support Vector Machines (SVMs). In the binary case it is defined as: @@ -93,6 +287,38 @@ class HingeLoss(Metric): measure: Tensor total: Tensor + def __new__( + cls, + squared: bool = False, + multiclass_mode: Literal["crammer-singer", "one-vs-all"] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryHingeLoss(squared, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert multiclass_mode is not None + return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, squared: bool = False, @@ -101,8 +327,8 @@ def __init__( ) -> None: super().__init__(**kwargs) - self.add_state("measure", default=tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + self.add_state("measure", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") if multiclass_mode not in (None, MulticlassMode.CRAMMER_SINGER, MulticlassMode.ONE_VS_ALL): raise ValueError( diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 354a73c762d..b89ac7cfd29 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -15,13 +15,257 @@ import torch from torch import Tensor +from typing_extensions import Literal +from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix from torchmetrics.classification.confusion_matrix import ConfusionMatrix -from torchmetrics.functional.classification.jaccard import _jaccard_from_confmat +from torchmetrics.functional.classification.jaccard import ( + _jaccard_from_confmat, + _jaccard_index_reduce, + _multiclass_jaccard_index_arg_validation, + _multilabel_jaccard_index_arg_validation, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.prints import rank_zero_warn + + +class BinaryJaccardIndex(BinaryConfusionMatrix): + r"""Calculates the Jaccard index for binary tasks. The `Jaccard index`_ (also known as + the intersetion over union or jaccard similarity coefficient) is an statistic that can be + used to determine the similarity and diversity of a sample set. It is defined as the size + of the intersection divided by the union of the sample sets: + + .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.classification import BinaryJaccardIndex + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> metric = BinaryJaccardIndex() + >>> metric(preds, target) + tensor(0.5000) + + Example (preds is float tensor): + >>> from torchmetrics.classification import BinaryJaccardIndex + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) + >>> metric = BinaryJaccardIndex() + >>> metric(preds, target) + tensor(0.5000) + """ + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + threshold=threshold, ignore_index=ignore_index, normalize=None, validate_args=validate_args, **kwargs + ) + + def compute(self) -> Tensor: + return _jaccard_index_reduce(self.confmat, average="binary") + + +class MulticlassJaccardIndex(MulticlassConfusionMatrix): + r"""Calculates the Jaccard index for multiclass tasks. The `Jaccard index`_ (also known as + the intersetion over union or jaccard similarity coefficient) is an statistic that can be + used to determine the similarity and diversity of a sample set. It is defined as the size + of the intersection divided by the union of the sample sets: + + .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of classes + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (pred is integer tensor): + >>> from torchmetrics.classification import MulticlassJaccardIndex + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassJaccardIndex(num_classes=3) + >>> metric(preds, target) + tensor(0.6667) + + Example (pred is float tensor): + >>> from torchmetrics.classification import MulticlassJaccardIndex + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> metric = MulticlassJaccardIndex(num_classes=3) + >>> metric(preds, target) + tensor(0.6667) + """ + + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_classes=num_classes, ignore_index=ignore_index, normalize=None, validate_args=False, **kwargs + ) + if validate_args: + _multiclass_jaccard_index_arg_validation(num_classes, ignore_index, average) + self.validate_args = validate_args + self.average = average + + def compute(self) -> Tensor: + return _jaccard_index_reduce(self.confmat, average=self.average) + + +class MultilabelJaccardIndex(MultilabelConfusionMatrix): + r"""Calculates the Jaccard index for multilabel tasks. The `Jaccard index`_ (also known as + the intersetion over union or jaccard similarity coefficient) is an statistic that can be + used to determine the similarity and diversity of a sample set. It is defined as the size + of the intersection divided by the union of the sample sets: + + .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.classification import MultilabelJaccardIndex + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelJaccardIndex(num_labels=3) + >>> metric(preds, target) + tensor(0.5000) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MultilabelJaccardIndex + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelJaccardIndex(num_labels=3) + >>> metric(preds, target) + tensor(0.5000) + + """ + + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_labels=num_labels, + threshold=threshold, + ignore_index=ignore_index, + normalize=None, + validate_args=False, + **kwargs, + ) + if validate_args: + _multilabel_jaccard_index_arg_validation(num_labels, threshold, ignore_index, average) + self.validate_args = validate_args + self.average = average + + def compute(self) -> Tensor: + return _jaccard_index_reduce(self.confmat, average=self.average) class JaccardIndex(ConfusionMatrix): - r"""Computes Intersection over union, or `Jaccard index`_: + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes Intersection over union, or `Jaccard index`_: .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} @@ -80,10 +324,47 @@ class JaccardIndex(ConfusionMatrix): higher_is_better: bool = True full_state_update: bool = False + def __new__( + cls, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + threshold: float = 0.5, + multilabel: bool = False, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryJaccardIndex(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassJaccardIndex(num_classes, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelJaccardIndex(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, num_classes: int, - average: Optional[str] = "macro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, absent_score: float = 0.0, threshold: float = 0.5, diff --git a/src/torchmetrics/classification/kl_divergence.py b/src/torchmetrics/classification/kl_divergence.py index 0dbb17f0e92..b6144bf6edf 100644 --- a/src/torchmetrics/classification/kl_divergence.py +++ b/src/torchmetrics/classification/kl_divergence.py @@ -13,16 +13,13 @@ # limitations under the License. from typing import Any -import torch -from torch import Tensor from typing_extensions import Literal -from torchmetrics.functional.classification.kl_divergence import _kld_compute, _kld_update -from torchmetrics.metric import Metric -from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.regression.kl_divergence import KLDivergence as _KLDivergence +from torchmetrics.utilities.prints import rank_zero_warn -class KLDivergence(Metric): +class KLDivergence(_KLDivergence): r"""Computes the `KL divergence`_: .. math:: @@ -46,6 +43,8 @@ class KLDivergence(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + .. note:: + This metric have been moved to the regression package in v0.10 and this version will be removed in v0.11. Raises: TypeError: @@ -65,10 +64,6 @@ class KLDivergence(Metric): tensor(0.0853) """ - is_differentiable: bool = True - higher_is_better: bool = False - full_state_update: bool = False - total: Tensor def __init__( self, @@ -76,30 +71,9 @@ def __init__( reduction: Literal["mean", "sum", "none", None] = "mean", **kwargs: Any, ) -> None: - super().__init__(**kwargs) - if not isinstance(log_prob, bool): - raise TypeError(f"Expected argument `log_prob` to be bool but got {log_prob}") - self.log_prob = log_prob - - allowed_reduction = ["mean", "sum", "none", None] - if reduction not in allowed_reduction: - raise ValueError(f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}") - self.reduction = reduction - - if self.reduction in ["mean", "sum"]: - self.add_state("measures", torch.tensor(0.0), dist_reduce_fx="sum") - else: - self.add_state("measures", [], dist_reduce_fx="cat") - self.add_state("total", torch.tensor(0), dist_reduce_fx="sum") - - def update(self, p: Tensor, q: Tensor) -> None: # type: ignore - measures, total = _kld_update(p, q, self.log_prob) - if self.reduction is None or self.reduction == "none": - self.measures.append(measures) - else: - self.measures += measures.sum() - self.total += total - - def compute(self) -> Tensor: - measures = dim_zero_cat(self.measures) if self.reduction is None or self.reduction == "none" else self.measures - return _kld_compute(measures, self.total, self.reduction) + super().__init__(log_prob, reduction, **kwargs) + rank_zero_warn( + "`torchmetrics.classification.KLDivergence` have been moved to `torchmetrics.regression.KLDivergence`" + " from v0.10 and this version will be removed in v0.11. Please update import paths.", + DeprecationWarning, + ) diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index db11dd40306..7c290f8740b 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -11,21 +11,230 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import torch from torch import Tensor +from typing_extensions import Literal +from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix from torchmetrics.functional.classification.matthews_corrcoef import ( _matthews_corrcoef_compute, + _matthews_corrcoef_reduce, _matthews_corrcoef_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.prints import rank_zero_warn + + +class BinaryMatthewsCorrCoef(BinaryConfusionMatrix): + r""" + Calculates `Matthews correlation coefficient`_ for binary tasks. This metric measures + the general correlation or quality of a classification. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.classification import BinaryMatthewsCorrCoef + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> metric = BinaryMatthewsCorrCoef() + >>> metric(preds, target) + tensor(0.5774) + + Example (preds is float tensor): + >>> from torchmetrics.classification import BinaryMatthewsCorrCoef + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) + >>> metric = BinaryMatthewsCorrCoef() + >>> metric(preds, target) + tensor(0.5774) + + """ + + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(threshold, ignore_index, normalize=None, validate_args=validate_args, **kwargs) + + def compute(self) -> Tensor: + return _matthews_corrcoef_reduce(self.confmat) + + +class MulticlassMatthewsCorrCoef(MulticlassConfusionMatrix): + r"""Calculates `Matthews correlation coefficient`_ for multiclass tasks. This metric measures + the general correlation or quality of a classification. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of classes + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (pred is integer tensor): + >>> from torchmetrics.classification import MulticlassMatthewsCorrCoef + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassMatthewsCorrCoef(num_classes=3) + >>> metric(preds, target) + tensor(0.7000) + + Example (pred is float tensor): + >>> from torchmetrics.classification import MulticlassMatthewsCorrCoef + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> metric = MulticlassMatthewsCorrCoef(num_classes=3) + >>> metric(preds, target) + tensor(0.7000) + + """ + + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(num_classes, ignore_index, normalize=None, validate_args=validate_args, **kwargs) + + def compute(self) -> Tensor: + return _matthews_corrcoef_reduce(self.confmat) + + +class MultilabelMatthewsCorrCoef(MultilabelConfusionMatrix): + r"""Calculates `Matthews correlation coefficient`_ for multilabel tasks. This metric measures + the general correlation or quality of a classification. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.classification import MultilabelMatthewsCorrCoef + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelMatthewsCorrCoef(num_labels=3) + >>> metric(preds, target) + tensor(0.3333) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MultilabelMatthewsCorrCoef + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelMatthewsCorrCoef(num_labels=3) + >>> metric(preds, target) + tensor(0.3333) + + """ + + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(num_labels, threshold, ignore_index, normalize=None, validate_args=validate_args, **kwargs) + + def compute(self) -> Tensor: + return _matthews_corrcoef_reduce(self.confmat) class MatthewsCorrCoef(Metric): - r"""Calculates `Matthews correlation coefficient`_ that measures the general correlation - or quality of a classification. + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Calculates `Matthews correlation coefficient`_ that measures the general correlation + or quality of a classification. In the binary case it is defined as: @@ -69,6 +278,40 @@ class MatthewsCorrCoef(Metric): full_state_update: bool = False confmat: Tensor + def __new__( + cls, + num_classes: int, + threshold: float = 0.5, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryMatthewsCorrCoef(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassMatthewsCorrCoef(num_classes, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, num_classes: int, diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 38737148b6a..e5b605a741f 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -13,15 +13,608 @@ # limitations under the License. from typing import Any, Optional +import torch from torch import Tensor - -from torchmetrics.classification.stat_scores import StatScores -from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute +from typing_extensions import Literal + +from torchmetrics.classification.stat_scores import ( + BinaryStatScores, + MulticlassStatScores, + MultilabelStatScores, + StatScores, +) +from torchmetrics.functional.classification.precision_recall import ( + _precision_compute, + _precision_recall_reduce, + _recall_compute, +) +from torchmetrics.metric import Metric from torchmetrics.utilities.enums import AverageMethod +from torchmetrics.utilities.prints import rank_zero_warn + + +class BinaryPrecision(BinaryStatScores): + r"""Computes `Precision`_ for binary tasks: + + .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} + + Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.classification import BinaryPrecision + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> metric = BinaryPrecision() + >>> metric(preds, target) + tensor(0.6667) + + Example (preds is float tensor): + >>> from torchmetrics.classification import BinaryPrecision + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> metric = BinaryPrecision() + >>> metric(preds, target) + tensor(0.6667) + + Example (multidim tensors): + >>> from torchmetrics.classification import BinaryPrecision + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = BinaryPrecision(multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.4000, 0.0000]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = False + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _precision_recall_reduce( + "precision", tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average + ) + + +class MulticlassPrecision(MulticlassStatScores): + r"""Computes `Precision`_ for multiclass tasks + + .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} + + Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.classification import MulticlassPrecision + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassPrecision(num_classes=3) + >>> metric(preds, target) + tensor(0.8333) + >>> metric = MulticlassPrecision(num_classes=3, average=None) + >>> metric(preds, target) + tensor([1.0000, 0.5000, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MulticlassPrecision + >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> metric = MulticlassPrecision(num_classes=3) + >>> metric(preds, target) + tensor(0.8333) + >>> metric = MulticlassPrecision(num_classes=3, average=None) + >>> metric(preds, target) + tensor([1.0000, 0.5000, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MulticlassPrecision + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> metric = MulticlassPrecision(num_classes=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.3889, 0.2778]) + >>> metric = MulticlassPrecision(num_classes=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[0.6667, 0.0000, 0.5000], + [0.0000, 0.5000, 0.3333]]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = False + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _precision_recall_reduce( + "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average + ) + + +class MultilabelPrecision(MultilabelStatScores): + r"""Computes `Precision`_ for multilabel tasks + + .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} + + Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.classification import MultilabelPrecision + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelPrecision(num_labels=3) + >>> metric(preds, target) + tensor(0.5000) + >>> metric = MultilabelPrecision(num_labels=3, average=None) + >>> metric(preds, target) + tensor([1.0000, 0.0000, 0.5000]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MultilabelPrecision + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelPrecision(num_labels=3) + >>> metric(preds, target) + tensor(0.5000) + >>> metric = MultilabelPrecision(num_labels=3, average=None) + >>> metric(preds, target) + tensor([1.0000, 0.0000, 0.5000]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MultilabelPrecision + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = MultilabelPrecision(num_labels=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.3333, 0.0000]) + >>> metric = MultilabelPrecision(num_labels=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[0.5000, 0.5000, 0.0000], + [0.0000, 0.0000, 0.0000]]) + + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = False + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _precision_recall_reduce( + "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average + ) + + +class BinaryRecall(BinaryStatScores): + r"""Computes `Recall`_ for binary tasks: + + .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} + + Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and + false negatives respecitively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.classification import BinaryRecall + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> metric = BinaryRecall() + >>> metric(preds, target) + tensor(0.6667) + + Example (preds is float tensor): + >>> from torchmetrics.classification import BinaryRecall + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> metric = BinaryRecall() + >>> metric(preds, target) + tensor(0.6667) + + Example (multidim tensors): + >>> from torchmetrics.classification import BinaryRecall + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = BinaryRecall(multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.6667, 0.0000]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = False + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _precision_recall_reduce( + "recall", tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average + ) + + +class MulticlassRecall(MulticlassStatScores): + r"""Computes `Recall`_ for multiclass tasks: + + .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} + + Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and + false negatives respecitively. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.classification import MulticlassRecall + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassRecall(num_classes=3) + >>> metric(preds, target) + tensor(0.8333) + >>> metric = MulticlassRecall(num_classes=3, average=None) + >>> metric(preds, target) + tensor([0.5000, 1.0000, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MulticlassRecall + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> metric = MulticlassRecall(num_classes=3) + >>> metric(preds, target) + tensor(0.8333) + >>> metric = MulticlassRecall(num_classes=3, average=None) + >>> metric(preds, target) + tensor([0.5000, 1.0000, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MulticlassRecall + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> metric = MulticlassRecall(num_classes=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.5000, 0.2778]) + >>> metric = MulticlassRecall(num_classes=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[1.0000, 0.0000, 0.5000], + [0.0000, 0.3333, 0.5000]]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = False + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _precision_recall_reduce( + "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average + ) + + +class MultilabelRecall(MultilabelStatScores): + r"""Computes `Recall`_ for multilabel tasks: + + .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} + + Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and + false negatives respecitively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.classification import MultilabelRecall + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelRecall(num_labels=3) + >>> metric(preds, target) + tensor(0.6667) + >>> metric = MultilabelRecall(num_labels=3, average=None) + >>> metric(preds, target) + tensor([1., 0., 1.]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MultilabelRecall + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelRecall(num_labels=3) + >>> metric(preds, target) + tensor(0.6667) + >>> metric = MultilabelRecall(num_labels=3, average=None) + >>> metric(preds, target) + tensor([1., 0., 1.]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MultilabelRecall + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = MultilabelRecall(num_labels=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.6667, 0.0000]) + >>> metric = MultilabelRecall(num_labels=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[1., 1., 0.], + [0., 0., 0.]]) + + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = False + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _precision_recall_reduce( + "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average + ) class Precision(StatScores): - r"""Computes `Precision`_: + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Precision`_: .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} @@ -113,11 +706,54 @@ class Precision(StatScores): higher_is_better = True full_state_update: bool = False + def __new__( + cls, + threshold: float = 0.5, + num_classes: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryPrecision(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassPrecision(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelPrecision(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, threshold: float = 0.5, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, top_k: Optional[int] = None, @@ -160,7 +796,15 @@ def compute(self) -> Tensor: class Recall(StatScores): - r"""Computes `Recall`_: + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Recall`_: .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} @@ -252,11 +896,54 @@ class Recall(StatScores): higher_is_better: bool = True full_state_update: bool = False + def __new__( + cls, + threshold: float = 0.5, + num_classes: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryRecall(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassRecall(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelRecall(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, threshold: float = 0.5, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, top_k: Optional[int] = None, diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index ee4e29aecbc..8a2619b662e 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -15,8 +15,25 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.functional.classification.precision_recall_curve import ( + _adjust_threshold_arg, + _binary_precision_recall_curve_arg_validation, + _binary_precision_recall_curve_compute, + _binary_precision_recall_curve_format, + _binary_precision_recall_curve_tensor_validation, + _binary_precision_recall_curve_update, + _multiclass_precision_recall_curve_arg_validation, + _multiclass_precision_recall_curve_compute, + _multiclass_precision_recall_curve_format, + _multiclass_precision_recall_curve_tensor_validation, + _multiclass_precision_recall_curve_update, + _multilabel_precision_recall_curve_arg_validation, + _multilabel_precision_recall_curve_compute, + _multilabel_precision_recall_curve_format, + _multilabel_precision_recall_curve_tensor_validation, + _multilabel_precision_recall_curve_update, _precision_recall_curve_compute, _precision_recall_curve_update, ) @@ -25,8 +42,386 @@ from torchmetrics.utilities.data import dim_zero_cat +class BinaryPrecisionRecallCurve(Metric): + r""" + Computes the precision-recall curve for binary tasks. The curve consist of multiple pairs of precision and + recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of 3 tensors containing: + + - precision: an 1d tensor of size (n_thresholds+1, ) with precision values + - recall: an 1d tensor of size (n_thresholds+1, ) with recall values + - thresholds: an 1d tensor of size (n_thresholds, ) with increasing threshold values + + Example: + >>> from torchmetrics.classification import BinaryPrecisionRecallCurve + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> metric = BinaryPrecisionRecallCurve(thresholds=None) + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([0.6667, 0.5000, 0.0000, 1.0000]), + tensor([1.0000, 0.5000, 0.0000, 0.0000]), + tensor([0.5000, 0.7000, 0.8000])) + >>> metric = BinaryPrecisionRecallCurve(thresholds=5) + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000]), + tensor([1., 1., 1., 0., 0., 0.]), + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) + + self.ignore_index = ignore_index + self.validate_args = validate_args + + if thresholds is None: + self.thresholds = thresholds + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + else: + self.register_buffer("thresholds", _adjust_threshold_arg(thresholds)) + self.add_state("confmat", default=torch.zeros(thresholds, 2, 2, dtype=torch.long), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _binary_precision_recall_curve_tensor_validation(preds, target, self.ignore_index) + preds, target, _ = _binary_precision_recall_curve_format(preds, target, self.thresholds, self.ignore_index) + state = _binary_precision_recall_curve_update(preds, target, self.thresholds) + if isinstance(state, Tensor): + self.confmat += state + else: + self.preds.append(state[0]) + self.target.append(state[1]) + + def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _binary_precision_recall_curve_compute(state, self.thresholds) + + +class MulticlassPrecisionRecallCurve(Metric): + r""" + Computes the precision-recall curve for multiclass tasks. The curve consist of multiple pairs of precision and + recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + num_classes: Integer specifing the number of classes + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - precision: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with precision values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with precision values is returned. + - recall: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with recall values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with recall values is returned. + - thresholds: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds, ) + with increasing threshold values (length may differ between classes). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes. + + Example: + >>> from torchmetrics.classification import MulticlassPrecisionRecallCurve + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> metric = MulticlassPrecisionRecallCurve(num_classes=5, thresholds=None) + >>> precision, recall, thresholds = metric(preds, target) + >>> precision # doctest: +NORMALIZE_WHITESPACE + [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), + tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] + >>> recall + [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] + >>> thresholds + [tensor(0.7500), tensor(0.7500), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor(0.0500)] + >>> metric = MulticlassPrecisionRecallCurve(num_classes=5, thresholds=5) + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([[0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000], + [0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000], + [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000], + [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]), + tensor([[1., 1., 1., 1., 0., 0.], + [1., 1., 1., 1., 0., 0.], + [1., 0., 0., 0., 0., 0.], + [1., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0.]]), + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + + self.num_classes = num_classes + self.ignore_index = ignore_index + self.validate_args = validate_args + + if thresholds is None: + self.thresholds = thresholds + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + else: + self.register_buffer("thresholds", _adjust_threshold_arg(thresholds)) + self.add_state( + "confmat", default=torch.zeros(thresholds, num_classes, 2, 2, dtype=torch.long), dist_reduce_fx="sum" + ) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multiclass_precision_recall_curve_tensor_validation(preds, target, self.num_classes, self.ignore_index) + preds, target, _ = _multiclass_precision_recall_curve_format( + preds, target, self.num_classes, self.thresholds, self.ignore_index + ) + state = _multiclass_precision_recall_curve_update(preds, target, self.num_classes, self.thresholds) + if isinstance(state, Tensor): + self.confmat += state + else: + self.preds.append(state[0]) + self.target.append(state[1]) + + def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds) + + +class MultilabelPrecisionRecallCurve(Metric): + r""" + Computes the precision-recall curve for multilabel tasks. The curve consist of multiple pairs of precision and + recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - precision: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with precision values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with precision values is returned. + - recall: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with recall values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with recall values is returned. + - thresholds: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds, ) + with increasing threshold values (length may differ between labels). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels. + + Example: + >>> from torchmetrics.classification import MultilabelPrecisionRecallCurve + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... [1, 1, 1]]) + >>> metric = MultilabelPrecisionRecallCurve(num_labels=3, thresholds=None) + >>> precision, recall, thresholds = metric(preds, target) + >>> precision # doctest: +NORMALIZE_WHITESPACE + [tensor([0.5000, 0.5000, 1.0000, 1.0000]), tensor([0.6667, 0.5000, 0.0000, 1.0000]), + tensor([0.7500, 1.0000, 1.0000, 1.0000])] + >>> recall # doctest: +NORMALIZE_WHITESPACE + [tensor([1.0000, 0.5000, 0.5000, 0.0000]), tensor([1.0000, 0.5000, 0.0000, 0.0000]), + tensor([1.0000, 0.6667, 0.3333, 0.0000])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([0.0500, 0.4500, 0.7500]), tensor([0.5500, 0.6500, 0.7500]), + tensor([0.0500, 0.3500, 0.7500])] + >>> metric = MultilabelPrecisionRecallCurve(num_labels=3, thresholds=5) + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([[0.5000, 0.5000, 1.0000, 1.0000, 0.0000, 1.0000], + [0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000], + [0.7500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000]]), + tensor([[1.0000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000], + [1.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000], + [1.0000, 0.6667, 0.3333, 0.3333, 0.0000, 0.0000]]), + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + + self.num_labels = num_labels + self.ignore_index = ignore_index + self.validate_args = validate_args + + if thresholds is None: + self.thresholds = thresholds + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + else: + self.register_buffer("thresholds", _adjust_threshold_arg(thresholds)) + self.add_state( + "confmat", default=torch.zeros(thresholds, num_labels, 2, 2, dtype=torch.long), dist_reduce_fx="sum" + ) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multilabel_precision_recall_curve_tensor_validation(preds, target, self.num_labels, self.ignore_index) + preds, target, _ = _multilabel_precision_recall_curve_format( + preds, target, self.num_labels, self.thresholds, self.ignore_index + ) + state = _multilabel_precision_recall_curve_update(preds, target, self.num_labels, self.thresholds) + if isinstance(state, Tensor): + self.confmat += state + else: + self.preds.append(state[0]) + self.target.append(state[1]) + + def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multilabel_precision_recall_curve_compute(state, self.num_labels, self.thresholds, self.ignore_index) + + class PrecisionRecallCurve(Metric): - """Computes precision-recall pairs for different thresholds. Works for both binary and multiclass problems. In + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes precision-recall pairs for different thresholds. Works for both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach. Forward accepts @@ -81,6 +476,41 @@ class PrecisionRecallCurve(Metric): preds: List[Tensor] target: List[Tensor] + def __new__( + cls, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryPrecisionRecallCurve(**kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassPrecisionRecallCurve(num_classes, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelPrecisionRecallCurve(num_labels, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/ranking.py b/src/torchmetrics/classification/ranking.py index d0bbe68bec1..3b7aba9ba34 100644 --- a/src/torchmetrics/classification/ranking.py +++ b/src/torchmetrics/classification/ranking.py @@ -23,10 +23,228 @@ _label_ranking_average_precision_update, _label_ranking_loss_compute, _label_ranking_loss_update, + _multilabel_confusion_matrix_arg_validation, + _multilabel_confusion_matrix_format, + _multilabel_coverage_error_update, + _multilabel_ranking_average_precision_update, + _multilabel_ranking_loss_update, + _multilabel_ranking_tensor_validation, + _ranking_reduce, ) from torchmetrics.metric import Metric +class MultilabelCoverageError(Metric): + """Computes multilabel coverage error [1]. The score measure how far we need to go through the ranked scores to + cover all true labels. The best value is equal to the average number of labels in the target tensor per sample. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_labels: Integer specifing the number of labels + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example: + >>> from torchmetrics.classification import MultilabelCoverageError + >>> _ = torch.manual_seed(42) + >>> preds = torch.rand(10, 5) + >>> target = torch.randint(2, (10, 5)) + >>> metric = MultilabelCoverageError(num_labels=5) + >>> metric(preds, target) + tensor(3.9000) + + References: + [1] Tsoumakas, G., Katakis, I., & Vlahavas, I. (2010). Mining multi-label data. In Data mining and + knowledge discovery handbook (pp. 667-685). Springer US. + """ + + higher_is_better: bool = False + is_differentiable: bool = False + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, threshold=0.0, ignore_index=ignore_index) + self.validate_args = validate_args + self.num_labels = num_labels + self.ignore_index = ignore_index + self.add_state("measure", torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multilabel_ranking_tensor_validation(preds, target, self.num_labels, self.ignore_index) + preds, target = _multilabel_confusion_matrix_format( + preds, target, self.num_labels, threshold=0.0, ignore_index=self.ignore_index, should_threshold=False + ) + measure, n_elements = _multilabel_coverage_error_update(preds, target) + self.measure += measure + self.total += n_elements + + def compute(self) -> Tensor: + return _ranking_reduce(self.measure, self.total) + + +class MultilabelRankingAveragePrecision(Metric): + """Computes label ranking average precision score for multilabel data [1]. The score is the average over each + ground truth label assigned to each sample of the ratio of true vs. total labels with lower score. Best score + is 1. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_labels: Integer specifing the number of labels + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example: + >>> from torchmetrics.classification import MultilabelRankingAveragePrecision + >>> _ = torch.manual_seed(42) + >>> preds = torch.rand(10, 5) + >>> target = torch.randint(2, (10, 5)) + >>> metric = MultilabelRankingAveragePrecision(num_labels=5) + >>> metric(preds, target) + tensor(0.7744) + + References: + [1] Tsoumakas, G., Katakis, I., & Vlahavas, I. (2010). Mining multi-label data. In Data mining and + knowledge discovery handbook (pp. 667-685). Springer US. + """ + + higher_is_better: bool = True + is_differentiable: bool = False + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, threshold=0.0, ignore_index=ignore_index) + self.validate_args = validate_args + self.num_labels = num_labels + self.ignore_index = ignore_index + self.add_state("measure", torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multilabel_ranking_tensor_validation(preds, target, self.num_labels, self.ignore_index) + preds, target = _multilabel_confusion_matrix_format( + preds, target, self.num_labels, threshold=0.0, ignore_index=self.ignore_index, should_threshold=False + ) + measure, n_elements = _multilabel_ranking_average_precision_update(preds, target) + self.measure += measure + self.total += n_elements + + def compute(self) -> Tensor: + return _ranking_reduce(self.measure, self.total) + + +class MultilabelRankingLoss(Metric): + """Computes the label ranking loss for multilabel data [1]. The score is corresponds to the average number of + label pairs that are incorrectly ordered given some predictions weighted by the size of the label set and the + number of labels not in the label set. The best score is 0. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example: + >>> from torchmetrics.classification import MultilabelRankingLoss + >>> _ = torch.manual_seed(42) + >>> preds = torch.rand(10, 5) + >>> target = torch.randint(2, (10, 5)) + >>> metric = MultilabelRankingLoss(num_labels=5) + >>> metric(preds, target) + tensor(0.4167) + + References: + [1] Tsoumakas, G., Katakis, I., & Vlahavas, I. (2010). Mining multi-label data. In Data mining and + knowledge discovery handbook (pp. 667-685). Springer US. + """ + + higher_is_better: bool = False + is_differentiable: bool = False + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, threshold=0.0, ignore_index=ignore_index) + self.validate_args = validate_args + self.num_labels = num_labels + self.ignore_index = ignore_index + self.add_state("measure", torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multilabel_ranking_tensor_validation(preds, target, self.num_labels, self.ignore_index) + preds, target = _multilabel_confusion_matrix_format( + preds, target, self.num_labels, threshold=0.0, ignore_index=self.ignore_index, should_threshold=False + ) + measure, n_elements = _multilabel_ranking_loss_update(preds, target) + self.measure += measure + self.total += n_elements + + def compute(self) -> Tensor: + return _ranking_reduce(self.measure, self.total) + + class CoverageError(Metric): """Computes multilabel coverage error [1]. The score measure how far we need to go through the ranked scores to cover all true labels. The best value is equal to the average number of labels in the target tensor per sample. diff --git a/src/torchmetrics/classification/recall_at_fixed_precision.py b/src/torchmetrics/classification/recall_at_fixed_precision.py new file mode 100644 index 00000000000..b742f6146c1 --- /dev/null +++ b/src/torchmetrics/classification/recall_at_fixed_precision.py @@ -0,0 +1,299 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from torchmetrics.classification.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from torchmetrics.functional.classification.recall_at_fixed_precision import ( + _binary_recall_at_fixed_precision_arg_validation, + _binary_recall_at_fixed_precision_compute, + _multiclass_recall_at_fixed_precision_arg_compute, + _multiclass_recall_at_fixed_precision_arg_validation, + _multilabel_recall_at_fixed_precision_arg_compute, + _multilabel_recall_at_fixed_precision_arg_validation, +) +from torchmetrics.utilities.data import dim_zero_cat + + +class BinaryRecallAtFixedPrecision(BinaryPrecisionRecallCurve): + r""" + Computes the higest possible recall value given the minimum precision thresholds provided. This is done by + first calculating the precision-recall curve for different thresholds and the find the recall for a given + precision level. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + min_precision: float value specifying minimum precision threshold. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of 2 tensors containing: + + - recall: an scalar tensor with the maximum recall for the given precision level + - threshold: an scalar tensor with the corresponding threshold level + + Example: + >>> from torchmetrics.classification import BinaryRecallAtFixedPrecision + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5, thresholds=None) + >>> metric(preds, target) + (tensor(1.), tensor(0.5000)) + >>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5, thresholds=5) + >>> metric(preds, target) + (tensor(1.), tensor(0.5000)) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + min_precision: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(thresholds, ignore_index, validate_args=False, **kwargs) + if validate_args: + _binary_recall_at_fixed_precision_arg_validation(min_precision, thresholds, ignore_index) + self.validate_args = validate_args + self.min_precision = min_precision + + def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _binary_recall_at_fixed_precision_compute(state, self.thresholds, self.min_precision) + + +class MulticlassRecallAtFixedPrecision(MulticlassPrecisionRecallCurve): + r""" + Computes the higest possible recall value given the minimum precision thresholds provided. This is done by + first calculating the precision-recall curve for different thresholds and the find the recall for a given + precision level. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + num_classes: Integer specifing the number of classes + min_precision: float value specifying minimum precision threshold. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of either 2 tensors or 2 lists containing + + - recall: an 1d tensor of size (n_classes, ) with the maximum recall for the given precision level per class + - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class + + Example: + >>> from torchmetrics.classification import MulticlassRecallAtFixedPrecision + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> metric = MulticlassRecallAtFixedPrecision(num_classes=5, min_precision=0.5, thresholds=None) + >>> metric(preds, target) + (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06])) + >>> metric = MulticlassRecallAtFixedPrecision(num_classes=5, min_precision=0.5, thresholds=5) + >>> metric(preds, target) + (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06])) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + min_precision: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + ) + if validate_args: + _multiclass_recall_at_fixed_precision_arg_validation(num_classes, min_precision, thresholds, ignore_index) + self.validate_args = validate_args + self.min_precision = min_precision + + def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multiclass_recall_at_fixed_precision_arg_compute( + state, self.num_classes, self.thresholds, self.min_precision + ) + + +class MultilabelRecallAtFixedPrecision(MultilabelPrecisionRecallCurve): + r""" + Computes the higest possible recall value given the minimum precision thresholds provided. This is done by + first calculating the precision-recall curve for different thresholds and the find the recall for a given + precision level. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + num_labels: Integer specifing the number of labels + min_precision: float value specifying minimum precision threshold. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of either 2 tensors or 2 lists containing + + - recall: an 1d tensor of size (n_classes, ) with the maximum recall for the given precision level per class + - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class + + Example: + >>> from torchmetrics.classification import MultilabelRecallAtFixedPrecision + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... [1, 1, 1]]) + >>> metric = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5, thresholds=None) + >>> metric(preds, target) + (tensor([1., 1., 1.]), tensor([0.0500, 0.5500, 0.0500])) + >>> metric = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5, thresholds=5) + >>> metric(preds, target) + (tensor([1., 1., 1.]), tensor([0.0000, 0.5000, 0.0000])) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + min_precision: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + ) + if validate_args: + _multilabel_recall_at_fixed_precision_arg_validation(num_labels, min_precision, thresholds, ignore_index) + self.validate_args = validate_args + self.min_precision = min_precision + + def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multilabel_recall_at_fixed_precision_arg_compute( + state, self.num_labels, self.thresholds, self.ignore_index, self.min_precision + ) diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index 7682dd758ac..1bb1a3f7dfb 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -15,15 +15,309 @@ import torch from torch import Tensor +from typing_extensions import Literal -from torchmetrics.functional.classification.roc import _roc_compute, _roc_update +from torchmetrics.classification.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from torchmetrics.functional.classification.roc import ( + _binary_roc_compute, + _multiclass_roc_compute, + _multilabel_roc_compute, + _roc_compute, + _roc_update, +) from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +class BinaryROC(BinaryPrecisionRecallCurve): + r""" + Computes the Receiver Operating Characteristic (ROC) for binary tasks. The curve consist of multiple + pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, + such that the tradeoff between the two values can be seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Note that outputted thresholds will be in reversed order to ensure that they corresponds to both fpr and tpr which + are sorted in reversed order during their calculation, such that they are monotome increasing. + + Args: + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of 3 tensors containing: + + - fpr: an 1d tensor of size (n_thresholds+1, ) with false positive rate values + - tpr: an 1d tensor of size (n_thresholds+1, ) with true positive rate values + - thresholds: an 1d tensor of size (n_thresholds, ) with decreasing threshold values + + Example: + >>> from torchmetrics.classification import BinaryROC + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> metric = BinaryROC(thresholds=None) + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]), + tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]), + tensor([1.0000, 0.8000, 0.7000, 0.5000, 0.0000])) + >>> metric = BinaryROC(thresholds=5) + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]), + tensor([0., 0., 1., 1., 1.]), + tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000])) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _binary_roc_compute(state, self.thresholds) + + +class MulticlassROC(MulticlassPrecisionRecallCurve): + r""" + Computes the Receiver Operating Characteristic (ROC) for binary tasks. The curve consist of multiple + pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, + such that the tradeoff between the two values can be seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Note that outputted thresholds will be in reversed order to ensure that they corresponds to both fpr and tpr which + are sorted in reversed order during their calculation, such that they are monotome increasing. + + Args: + num_classes: Integer specifing the number of classes + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - fpr: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with false positive rate values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with false positive rate values is returned. + - tpr: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with true positive rate values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with true positive rate values is returned. + - thresholds: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds, ) + with decreasing threshold values (length may differ between classes). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes. + + Example: + >>> from torchmetrics.classification import MulticlassROC + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> metric = MulticlassROC(num_classes=5, thresholds=None) + >>> fpr, tpr, thresholds = metric(preds, target) + >>> fpr # doctest: +NORMALIZE_WHITESPACE + [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), + tensor([0.0000, 0.3333, 1.0000]), tensor([0., 1.])] + >>> tpr + [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0., 0.])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), + tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.0500])] + >>> metric = MulticlassROC(num_classes=5, thresholds=5) + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([[0.0000, 0.0000, 0.0000, 0.0000, 1.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 1.0000], + [0.0000, 0.3333, 0.3333, 0.3333, 1.0000], + [0.0000, 0.3333, 0.3333, 0.3333, 1.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]), + tensor([[0., 1., 1., 1., 1.], + [0., 1., 1., 1., 1.], + [0., 0., 0., 0., 1.], + [0., 0., 0., 0., 1.], + [0., 0., 0., 0., 0.]]), + tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000])) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multiclass_roc_compute(state, self.num_classes, self.thresholds) + + +class MultilabelROC(MultilabelPrecisionRecallCurve): + r""" + Computes the Receiver Operating Characteristic (ROC) for binary tasks. The curve consist of multiple + pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, + such that the tradeoff between the two values can be seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Note that outputted thresholds will be in reversed order to ensure that they corresponds to both fpr and tpr which + are sorted in reversed order during their calculation, such that they are monotome increasing. + + Args: + num_labels: Integer specifing the number of labels + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - fpr: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with false positive rate values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with false positive rate values is returned. + - tpr: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with true positive rate values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with true positive rate values is returned. + - thresholds: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds, ) + with decreasing threshold values (length may differ between labels). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels. + + Example: + >>> from torchmetrics.classification import MultilabelROC + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... [1, 1, 1]]) + >>> metric = MultilabelROC(num_labels=3, thresholds=None) + >>> fpr, tpr, thresholds = metric(preds, target) + >>> fpr # doctest: +NORMALIZE_WHITESPACE + [tensor([0.0000, 0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]), + tensor([0., 0., 0., 1.])] + >>> tpr # doctest: +NORMALIZE_WHITESPACE + [tensor([0.0000, 0.5000, 0.5000, 1.0000]), + tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]), + tensor([0.0000, 0.3333, 0.6667, 1.0000])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([1.0000, 0.7500, 0.4500, 0.0500]), + tensor([1.0000, 0.7500, 0.6500, 0.5500, 0.0500]), + tensor([1.0000, 0.7500, 0.3500, 0.0500])] + >>> metric = MultilabelROC(num_labels=3, thresholds=5) + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([[0.0000, 0.0000, 0.0000, 0.5000, 1.0000], + [0.0000, 0.5000, 0.5000, 0.5000, 1.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]), + tensor([[0.0000, 0.5000, 0.5000, 0.5000, 1.0000], + [0.0000, 0.0000, 1.0000, 1.0000, 1.0000], + [0.0000, 0.3333, 0.3333, 0.6667, 1.0000]]), + tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000])) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def compute(self) -> Tuple[Tensor, Tensor, Tensor]: + if self.thresholds is None: + state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] + else: + state = self.confmat + return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index) + + class ROC(Metric): - """Computes the Receiver Operating Characteristic (ROC). Works for both binary, multiclass and multilabel + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the Receiver Operating Characteristic (ROC). Works for both binary, multiclass and multilabel problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach. Forward accepts @@ -106,6 +400,41 @@ class ROC(Metric): preds: List[Tensor] target: List[Tensor] + def __new__( + cls, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryROC(**kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassROC(num_classes, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelROC(num_labels, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index 56057a9cfbd..5f66c9122e7 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -15,14 +15,295 @@ import torch from torch import Tensor +from typing_extensions import Literal -from torchmetrics.classification.stat_scores import StatScores -from torchmetrics.functional.classification.specificity import _specificity_compute +from torchmetrics.classification.stat_scores import ( + BinaryStatScores, + MulticlassStatScores, + MultilabelStatScores, + StatScores, +) +from torchmetrics.functional.classification.specificity import _specificity_compute, _specificity_reduce +from torchmetrics.metric import Metric from torchmetrics.utilities.enums import AverageMethod +from torchmetrics.utilities.prints import rank_zero_warn + + +class BinarySpecificity(BinaryStatScores): + r"""Computes `Specificity`_ for binary tasks: + + .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.classification import BinarySpecificity + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> metric = BinarySpecificity() + >>> metric(preds, target) + tensor(0.6667) + + Example (preds is float tensor): + >>> from torchmetrics.classification import BinarySpecificity + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> metric = BinarySpecificity() + >>> metric(preds, target) + tensor(0.6667) + + Example (multidim tensors): + >>> from torchmetrics.classification import BinarySpecificity + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = BinarySpecificity(multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.0000, 0.3333]) + """ + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _specificity_reduce(tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average) + + +class MulticlassSpecificity(MulticlassStatScores): + r"""Computes `Specificity`_ for multiclass tasks: + + .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.classification import MulticlassSpecificity + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassSpecificity(num_classes=3) + >>> metric(preds, target) + tensor(0.8889) + >>> metric = MulticlassSpecificity(num_classes=3, average=None) + >>> metric(preds, target) + tensor([1.0000, 0.6667, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MulticlassSpecificity + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> metric = MulticlassSpecificity(num_classes=3) + >>> metric(preds, target) + tensor(0.8889) + >>> metric = MulticlassSpecificity(num_classes=3, average=None) + >>> metric(preds, target) + tensor([1.0000, 0.6667, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MulticlassSpecificity + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> metric = MulticlassSpecificity(num_classes=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.7500, 0.6556]) + >>> metric = MulticlassSpecificity(num_classes=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[0.7500, 0.7500, 0.7500], + [0.8000, 0.6667, 0.5000]]) + """ + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _specificity_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average) + + +class MultilabelSpecificity(MultilabelStatScores): + r"""Computes `Specificity`_ for multilabel tasks + + .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + - If ``multidim_average`` is set to ``global`` + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + - If ``multidim_average`` is set to ``samplewise`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.classification import MultilabelSpecificity + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelSpecificity(num_labels=3) + >>> metric(preds, target) + tensor(0.6667) + >>> metric = MultilabelSpecificity(num_labels=3, average=None) + >>> metric(preds, target) + tensor([1., 1., 0.]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MultilabelSpecificity + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelSpecificity(num_labels=3) + >>> metric(preds, target) + tensor(0.6667) + >>> metric = MultilabelSpecificity(num_labels=3, average=None) + >>> metric(preds, target) + tensor([1., 1., 0.]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MultilabelSpecificity + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = MultilabelSpecificity(num_labels=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.0000, 0.3333]) + >>> metric = MultilabelSpecificity(num_labels=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[0., 0., 0.], + [0., 0., 1.]]) + """ + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _specificity_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average) class Specificity(StatScores): - r"""Computes `Specificity`_: + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Specificity`_: .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} @@ -115,11 +396,54 @@ class Specificity(StatScores): higher_is_better: bool = True full_state_update: bool = False + def __new__( + cls, + num_classes: Optional[int] = None, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinarySpecificity(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassSpecificity(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelSpecificity(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, threshold: float = 0.5, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, top_k: Optional[int] = None, diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index eca2150d63b..fe7b2eacaa4 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -11,18 +11,495 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union import torch from torch import Tensor - -from torchmetrics.functional.classification.stat_scores import _stat_scores_compute, _stat_scores_update +from typing_extensions import Literal + +from torchmetrics.functional.classification.stat_scores import ( + _binary_stat_scores_arg_validation, + _binary_stat_scores_compute, + _binary_stat_scores_format, + _binary_stat_scores_tensor_validation, + _binary_stat_scores_update, + _multiclass_stat_scores_arg_validation, + _multiclass_stat_scores_compute, + _multiclass_stat_scores_format, + _multiclass_stat_scores_tensor_validation, + _multiclass_stat_scores_update, + _multilabel_stat_scores_arg_validation, + _multilabel_stat_scores_compute, + _multilabel_stat_scores_format, + _multilabel_stat_scores_tensor_validation, + _multilabel_stat_scores_update, + _stat_scores_compute, + _stat_scores_update, +) from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn + + +class _AbstractStatScores(Metric): + # define common functions + def _create_state(self, size: int, multidim_average: str) -> None: + """Initialize the states for the different statistics.""" + default: Union[Callable[[], list], Callable[[], Tensor]] + if multidim_average == "samplewise": + default = lambda: [] + dist_reduce_fx = "cat" + else: + default = lambda: torch.zeros(size, dtype=torch.long) + dist_reduce_fx = "sum" + self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("tn", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("fn", default(), dist_reduce_fx=dist_reduce_fx) + + def _update_state(self, tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> None: + """Update states depending on multidim_average argument.""" + if self.multidim_average == "samplewise": + self.tp.append(tp) + self.fp.append(fp) + self.tn.append(tn) + self.fn.append(fn) + else: + self.tp += tp + self.fp += fp + self.tn += tn + self.fn += fn + + def _final_state(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Final aggregation in case of list states.""" + tp = dim_zero_cat(self.tp) + fp = dim_zero_cat(self.fp) + tn = dim_zero_cat(self.tn) + fn = dim_zero_cat(self.fn) + return tp, fp, tn, fn + + +class BinaryStatScores(_AbstractStatScores): + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for binary tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.classification import BinaryStatScores + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> metric = BinaryStatScores() + >>> metric(preds, target) + tensor([2, 1, 2, 1, 3]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import BinaryStatScores + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> metric = BinaryStatScores() + >>> metric(preds, target) + tensor([2, 1, 2, 1, 3]) + + Example (multidim tensors): + >>> from torchmetrics.classification import BinaryStatScores + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = BinaryStatScores(multidim_average='samplewise') + >>> metric(preds, target) + tensor([[2, 3, 0, 1, 3], + [0, 2, 1, 3, 3]]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super(_AbstractStatScores, self).__init__(**kwargs) + if validate_args: + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + self.threshold = threshold + self.multidim_average = multidim_average + self.ignore_index = ignore_index + self.validate_args = validate_args + + self._create_state(1, multidim_average) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ + if self.validate_args: + _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index) + preds, target = _binary_stat_scores_format(preds, target, self.threshold, self.ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, self.multidim_average) + self._update_state(tp, fp, tn, fn) + + def compute(self) -> Tensor: + """Computes the final statistics. + + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape + depends on the ``multidim_average`` parameter: + + - If ``multidim_average`` is set to ``global``, the shape will be ``(5,)`` + - If ``multidim_average`` is set to ``samplewise``, the shape will be ``(N, 5)`` + """ + tp, fp, tn, fn = self._final_state() + return _binary_stat_scores_compute(tp, fp, tn, fn, self.multidim_average) + + +class MulticlassStatScores(_AbstractStatScores): + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for multiclass tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.classification import MulticlassStatScores + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassStatScores(num_classes=3, average='micro') + >>> metric(preds, target) + tensor([3, 1, 7, 1, 4]) + >>> metric = MulticlassStatScores(num_classes=3, average=None) + >>> metric(preds, target) + tensor([[1, 0, 2, 1, 2], + [1, 1, 2, 0, 1], + [1, 0, 3, 0, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MulticlassStatScores + >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> preds = preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> metric = MulticlassStatScores(num_classes=3, average='micro') + >>> metric(preds, target) + tensor([3, 1, 7, 1, 4]) + >>> metric = MulticlassStatScores(num_classes=3, average=None) + >>> metric(preds, target) + tensor([[1, 0, 2, 1, 2], + [1, 1, 2, 0, 1], + [1, 0, 3, 0, 1]]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MulticlassStatScores + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> metric = MulticlassStatScores(num_classes=3, multidim_average="samplewise", average='micro') + >>> metric(preds, target) + tensor([[3, 3, 9, 3, 6], + [2, 4, 8, 4, 6]]) + >>> metric = MulticlassStatScores(num_classes=3, multidim_average="samplewise", average=None) + >>> metric(preds, target) + tensor([[[2, 1, 3, 0, 2], + [0, 1, 3, 2, 2], + [1, 1, 3, 1, 2]], + [[0, 1, 4, 1, 1], + [1, 1, 2, 2, 3], + [1, 2, 2, 1, 2]]]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + top_k: int = 1, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super(_AbstractStatScores, self).__init__(**kwargs) + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + self.num_classes = num_classes + self.top_k = top_k + self.average = average + self.multidim_average = multidim_average + self.ignore_index = ignore_index + self.validate_args = validate_args + + self._create_state(num_classes, multidim_average) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ + if self.validate_args: + _multiclass_stat_scores_tensor_validation( + preds, target, self.num_classes, self.multidim_average, self.ignore_index + ) + preds, target = _multiclass_stat_scores_format(preds, target, self.top_k) + tp, fp, tn, fn = _multiclass_stat_scores_update( + preds, target, self.num_classes, self.top_k, self.multidim_average, self.ignore_index + ) + self._update_state(tp, fp, tn, fn) + + def compute(self) -> Tensor: + """Computes the final statistics. + + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape + depends on ``average`` and ``multidim_average`` parameters: + + - If ``multidim_average`` is set to ``global`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` + - If ``average=None/'none'``, the shape will be ``(C, 5)`` + - If ``multidim_average`` is set to ``samplewise`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` + - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + """ + tp, fp, tn, fn = self._final_state() + return _multiclass_stat_scores_compute(tp, fp, tn, fn, self.average, self.multidim_average) + + +class MultilabelStatScores(_AbstractStatScores): + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for multilabel tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.classification import MultilabelStatScores + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelStatScores(num_labels=3, average='micro') + >>> metric(preds, target) + tensor([2, 1, 2, 1, 3]) + >>> metric = MultilabelStatScores(num_labels=3, average=None) + >>> metric(preds, target) + tensor([[1, 0, 1, 0, 1], + [0, 0, 1, 1, 1], + [1, 1, 0, 0, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MultilabelStatScores + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelStatScores(num_labels=3, average='micro') + >>> metric(preds, target) + tensor([2, 1, 2, 1, 3]) + >>> metric = MultilabelStatScores(num_labels=3, average=None) + >>> metric(preds, target) + tensor([[1, 0, 1, 0, 1], + [0, 0, 1, 1, 1], + [1, 1, 0, 0, 1]]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MultilabelStatScores + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = MultilabelStatScores(num_labels=3, multidim_average='samplewise', average='micro') + >>> metric(preds, target) + tensor([[2, 3, 0, 1, 3], + [0, 2, 1, 3, 3]]) + >>> metric = MultilabelStatScores(num_labels=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[[1, 1, 0, 0, 1], + [1, 1, 0, 0, 1], + [0, 1, 0, 1, 1]], + [[0, 0, 0, 2, 2], + [0, 2, 0, 0, 0], + [0, 0, 1, 1, 1]]]) + + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super(_AbstractStatScores, self).__init__(**kwargs) + if validate_args: + _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + self.num_labels = num_labels + self.threshold = threshold + self.average = average + self.multidim_average = multidim_average + self.ignore_index = ignore_index + self.validate_args = validate_args + + self._create_state(num_labels, multidim_average) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ + if self.validate_args: + _multilabel_stat_scores_tensor_validation( + preds, target, self.num_labels, self.multidim_average, self.ignore_index + ) + preds, target = _multilabel_stat_scores_format( + preds, target, self.num_labels, self.threshold, self.ignore_index + ) + tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, self.multidim_average) + self._update_state(tp, fp, tn, fn) + + def compute(self) -> Tensor: + """Computes the final statistics. + + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape + depends on ``average`` and ``multidim_average`` parameters: + + - If ``multidim_average`` is set to ``global`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` + - If ``average=None/'none'``, the shape will be ``(C, 5)`` + - If ``multidim_average`` is set to ``samplewise`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` + - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + """ + tp, fp, tn, fn = self._final_state() + return _multilabel_stat_scores_compute(tp, fp, tn, fn, self.average, self.multidim_average) class StatScores(Metric): - r"""Computes the number of true positives, false positives, true negatives, false negatives. + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the number of true positives, false positives, true negatives, false negatives. Related to `Type I and Type II errors`_ and the `confusion matrix`_. The reduction method (how the statistics are aggregated) is controlled by the @@ -119,6 +596,49 @@ class StatScores(Metric): # tn: Union[Tensor, List[Tensor]] # fn: Union[Tensor, List[Tensor]] + def __new__( + cls, + num_classes: Optional[int] = None, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task is not None: + assert multidim_average is not None + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryStatScores(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassStatScores(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelStatScores(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + return super().__new__(cls) + def __init__( self, threshold: float = 0.5, @@ -128,10 +648,14 @@ def __init__( ignore_index: Optional[int] = None, mdmc_reduce: Optional[str] = None, multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.reduce = reduce self.mdmc_reduce = mdmc_reduce self.num_classes = num_classes @@ -176,7 +700,6 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore preds: Predictions from model (probabilities, logits or labels) target: Ground truth values """ - tp, fp, tn, fn = _stat_scores_update( preds, target, diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index f66d55ba459..8e71ed4de63 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -26,7 +26,6 @@ from torchmetrics.functional.classification.hamming import hamming_distance from torchmetrics.functional.classification.hinge import hinge_loss from torchmetrics.functional.classification.jaccard import jaccard_index -from torchmetrics.functional.classification.kl_divergence import kl_divergence from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve @@ -54,6 +53,7 @@ from torchmetrics.functional.pairwise.manhattan import pairwise_manhattan_distance from torchmetrics.functional.regression.cosine_similarity import cosine_similarity from torchmetrics.functional.regression.explained_variance import explained_variance +from torchmetrics.functional.regression.kl_divergence import kl_divergence from torchmetrics.functional.regression.log_mse import mean_squared_log_error from torchmetrics.functional.regression.mae import mean_absolute_error from torchmetrics.functional.regression.mape import mean_absolute_percentage_error diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 70f777b56e0..82932c0d6e3 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -11,27 +11,116 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.functional.classification.accuracy import accuracy # noqa: F401 +from torchmetrics.functional.classification.accuracy import ( # noqa: F401 + accuracy, + binary_accuracy, + multiclass_accuracy, + multilabel_accuracy, +) from torchmetrics.functional.classification.auc import auc # noqa: F401 -from torchmetrics.functional.classification.auroc import auroc # noqa: F401 -from torchmetrics.functional.classification.average_precision import average_precision # noqa: F401 -from torchmetrics.functional.classification.calibration_error import calibration_error # noqa: F401 -from torchmetrics.functional.classification.cohen_kappa import cohen_kappa # noqa: F401 -from torchmetrics.functional.classification.confusion_matrix import confusion_matrix # noqa: F401 +from torchmetrics.functional.classification.auroc import ( # noqa: F401 + auroc, + binary_auroc, + multiclass_auroc, + multilabel_auroc, +) +from torchmetrics.functional.classification.average_precision import ( # noqa: F401 + average_precision, + binary_average_precision, + multiclass_average_precision, + multilabel_average_precision, +) +from torchmetrics.functional.classification.calibration_error import ( # noqa: F401 + binary_calibration_error, + calibration_error, + multiclass_calibration_error, +) +from torchmetrics.functional.classification.cohen_kappa import ( # noqa: F401 + binary_cohen_kappa, + cohen_kappa, + multiclass_cohen_kappa, +) +from torchmetrics.functional.classification.confusion_matrix import ( # noqa: F401 + binary_confusion_matrix, + confusion_matrix, + multiclass_confusion_matrix, + multilabel_confusion_matrix, +) from torchmetrics.functional.classification.dice import dice, dice_score # noqa: F401 -from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score # noqa: F401 -from torchmetrics.functional.classification.hamming import hamming_distance # noqa: F401 -from torchmetrics.functional.classification.hinge import hinge_loss # noqa: F401 -from torchmetrics.functional.classification.jaccard import jaccard_index # noqa: F401 -from torchmetrics.functional.classification.kl_divergence import kl_divergence # noqa: F401 -from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401 -from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401 -from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401 +from torchmetrics.functional.classification.exact_match import multilabel_exact_match # noqa: F401 +from torchmetrics.functional.classification.f_beta import ( # noqa: F401 + binary_f1_score, + binary_fbeta_score, + f1_score, + fbeta_score, + multiclass_f1_score, + multiclass_fbeta_score, + multilabel_f1_score, + multilabel_fbeta_score, +) +from torchmetrics.functional.classification.hamming import ( # noqa: F401 + binary_hamming_distance, + hamming_distance, + multiclass_hamming_distance, + multilabel_hamming_distance, +) +from torchmetrics.functional.classification.hinge import ( # noqa: F401 + binary_hinge_loss, + hinge_loss, + multiclass_hinge_loss, +) +from torchmetrics.functional.classification.jaccard import ( # noqa: F401 + binary_jaccard_index, + jaccard_index, + multiclass_jaccard_index, + multilabel_jaccard_index, +) +from torchmetrics.functional.classification.matthews_corrcoef import ( # noqa: F401 + binary_matthews_corrcoef, + matthews_corrcoef, + multiclass_matthews_corrcoef, + multilabel_matthews_corrcoef, +) +from torchmetrics.functional.classification.precision_recall import ( # noqa: F401 + binary_precision, + binary_recall, + multiclass_precision, + multiclass_recall, + multilabel_precision, + multilabel_recall, + precision, + precision_recall, + recall, +) +from torchmetrics.functional.classification.precision_recall_curve import ( # noqa: F401 + binary_precision_recall_curve, + multiclass_precision_recall_curve, + multilabel_precision_recall_curve, + precision_recall_curve, +) from torchmetrics.functional.classification.ranking import ( # noqa: F401 coverage_error, label_ranking_average_precision, label_ranking_loss, + multilabel_coverage_error, + multilabel_ranking_average_precision, + multilabel_ranking_loss, +) +from torchmetrics.functional.classification.recall_at_fixed_precision import ( # noqa: F401 + binary_recall_at_fixed_precision, + multiclass_recall_at_fixed_precision, + multilabel_recall_at_fixed_precision, +) +from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc, roc # noqa: F401 +from torchmetrics.functional.classification.specificity import ( # noqa: F401 + binary_specificity, + multiclass_specificity, + multilabel_specificity, + specificity, +) +from torchmetrics.functional.classification.stat_scores import ( # noqa: F401 + binary_stat_scores, + multiclass_stat_scores, + multilabel_stat_scores, + stat_scores, ) -from torchmetrics.functional.classification.roc import roc # noqa: F401 -from torchmetrics.functional.classification.specificity import specificity # noqa: F401 -from torchmetrics.functional.classification.stat_scores import stat_scores # noqa: F401 diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index 888f044f96d..290eddad55e 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -15,10 +15,374 @@ import torch from torch import Tensor, tensor - -from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update +from typing_extensions import Literal + +from torchmetrics.functional.classification.stat_scores import ( + _binary_stat_scores_arg_validation, + _binary_stat_scores_format, + _binary_stat_scores_tensor_validation, + _binary_stat_scores_update, + _multiclass_stat_scores_arg_validation, + _multiclass_stat_scores_format, + _multiclass_stat_scores_tensor_validation, + _multiclass_stat_scores_update, + _multilabel_stat_scores_arg_validation, + _multilabel_stat_scores_format, + _multilabel_stat_scores_tensor_validation, + _multilabel_stat_scores_update, + _reduce_stat_scores, + _stat_scores_update, +) from torchmetrics.utilities.checks import _check_classification_inputs, _input_format_classification, _input_squeeze +from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn + + +def _accuracy_reduce( + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]], + multidim_average: Literal["global", "samplewise"] = "global", + multilabel: bool = False, +) -> Tensor: + """Reduce classification statistics into accuracy score + Args: + tp: number of true positives + fp: number of false positives + tn: number of true negatives + fn: number of false negatives + normalize: normalization method. + - `"true"` will divide by the sum of the column dimension. + - `"pred"` will divide by the sum of the row dimension. + - `"all"` will divide by the sum of the full matrix + - `"none"` or `None` will apply no reduction + multilabel: bool indicating if reduction is for multilabel tasks + + Returns: + Accuracy score + """ + if average == "binary": + return _safe_divide(tp + tn, tp + tn + fp + fn) + elif average == "micro": + tp = tp.sum(dim=0 if multidim_average == "global" else 1) + fn = fn.sum(dim=0 if multidim_average == "global" else 1) + if multilabel: + fp = fp.sum(dim=0 if multidim_average == "global" else 1) + tn = tn.sum(dim=0 if multidim_average == "global" else 1) + return _safe_divide(tp + tn, tp + tn + fp + fn) + return _safe_divide(tp, tp + fn) + else: + if multilabel: + score = _safe_divide(tp + tn, tp + tn + fp + fn) + else: + score = _safe_divide(tp, tp + fn) + if average is None or average == "none": + return score + if average == "weighted": + weights = tp + fn + else: + weights = torch.ones_like(score) + return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1) + + +def binary_accuracy( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes `Accuracy`_ for binary tasks: + + .. math:: + \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import binary_accuracy + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> binary_accuracy(preds, target) + tensor(0.6667) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import binary_accuracy + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_accuracy(preds, target) + tensor(0.6667) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import binary_accuracy + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> binary_accuracy(preds, target, multidim_average='samplewise') + tensor([0.3333, 0.1667]) + """ + if validate_args: + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) + return _accuracy_reduce(tp, fp, tn, fn, average="binary", multidim_average=multidim_average) + + +def multiclass_accuracy( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + top_k: int = 1, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes `Accuracy`_ for multiclass tasks: + + .. math:: + \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multiclass_accuracy + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_accuracy(preds, target, num_classes=3) + tensor(0.8333) + >>> multiclass_accuracy(preds, target, num_classes=3, average=None) + tensor([0.5000, 1.0000, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multiclass_accuracy + >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> multiclass_accuracy(preds, target, num_classes=3) + tensor(0.8333) + >>> multiclass_accuracy(preds, target, num_classes=3, average=None) + tensor([0.5000, 1.0000, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multiclass_accuracy + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_accuracy(preds, target, num_classes=3, multidim_average='samplewise') + tensor([0.5000, 0.2778]) + >>> multiclass_accuracy(preds, target, num_classes=3, multidim_average='samplewise', average=None) + tensor([[1.0000, 0.0000, 0.5000], + [0.0000, 0.3333, 0.5000]]) + """ + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, top_k) + tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, num_classes, top_k, multidim_average, ignore_index) + return _accuracy_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average) + + +def multilabel_accuracy( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes `Accuracy`_ for multilabel tasks: + + .. math:: + \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multilabel_accuracy + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_accuracy(preds, target, num_labels=3) + tensor(0.6667) + >>> multilabel_accuracy(preds, target, num_labels=3, average=None) + tensor([1.0000, 0.5000, 0.5000]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multilabel_accuracy + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_accuracy(preds, target, num_labels=3) + tensor(0.6667) + >>> multilabel_accuracy(preds, target, num_labels=3, average=None) + tensor([1.0000, 0.5000, 0.5000]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multilabel_accuracy + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> multilabel_accuracy(preds, target, num_labels=3, multidim_average='samplewise') + tensor([0.3333, 0.1667]) + >>> multilabel_accuracy(preds, target, num_labels=3, multidim_average='samplewise', average=None) + tensor([[0.5000, 0.5000, 0.0000], + [0.0000, 0.0000, 0.5000]]) + + """ + if validate_args: + _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) + preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) + return _accuracy_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True) def _check_subset_validity(mode: DataType) -> bool: @@ -258,7 +622,7 @@ def _subset_accuracy_compute(correct: Tensor, total: Tensor) -> Tensor: def accuracy( preds: Tensor, target: Tensor, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = "global", threshold: float = 0.5, top_k: Optional[int] = None, @@ -266,8 +630,20 @@ def accuracy( num_classes: Optional[int] = None, multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, ) -> Tensor: - r"""Computes `Accuracy`_ + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Accuracy`_ .. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) @@ -390,6 +766,34 @@ def accuracy( >>> accuracy(preds, target, top_k=2) tensor(0.6667) """ + if task is not None: + assert multidim_average is not None + if task == "binary": + return binary_accuracy(preds, target, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_accuracy( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_accuracy( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + ) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") diff --git a/src/torchmetrics/functional/classification/auc.py b/src/torchmetrics/functional/classification/auc.py index 7c439cdefe0..4d62697c9ac 100644 --- a/src/torchmetrics/functional/classification/auc.py +++ b/src/torchmetrics/functional/classification/auc.py @@ -11,99 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple - import torch from torch import Tensor - -def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: - """Updates and returns variables required to compute area under the curve. Checks if the 2 input tenser have - the same number of elements and if they are 1d. - - Args: - x: x-coordinates - y: y-coordinates - """ - - if x.ndim > 1: - x = x.squeeze() - - if y.ndim > 1: - y = y.squeeze() - - if x.ndim > 1 or y.ndim > 1: - raise ValueError( - f"Expected both `x` and `y` tensor to be 1d, but got tensors with dimension {x.ndim} and {y.ndim}" - ) - if x.numel() != y.numel(): - raise ValueError( - f"Expected the same number of elements in `x` and `y` tensor but received {x.numel()} and {y.numel()}" - ) - return x, y - - -def _auc_compute_without_check(x: Tensor, y: Tensor, direction: float) -> Tensor: - """Computes area under the curve using the trapezoidal rule. Assumes increasing or decreasing order of `x`. - - Args: - x: x-coordinates, must be either increasing or decreasing - y: y-coordinates - direction: 1 if increaing, -1 if decreasing - - Example: - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> x, y = _auc_update(x, y) - >>> _auc_compute_without_check(x, y, direction=1.0) - tensor(4.) - """ - - with torch.no_grad(): - auc_: Tensor = torch.trapz(y, x) * direction - return auc_ - - -def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor: - """Computes area under the curve using the trapezoidal rule. Checks for increasing or decreasing order of `x`. - - Args: - x: x-coordinates, must be either increasing or decreasing - y: y-coordinates - reorder: if True, will reorder the arrays to make it either increasing or decreasing - - Example: - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> x, y = _auc_update(x, y) - >>> _auc_compute(x, y) - tensor(4.) - >>> _auc_compute(x, y, reorder=True) - tensor(4.) - """ - - with torch.no_grad(): - if reorder: - # TODO: include stable=True arg when pytorch v1.9 is released - x, x_idx = torch.sort(x) - y = y[x_idx] - - dx = x[1:] - x[:-1] - if (dx < 0).any(): - if (dx <= 0).all(): - direction = -1.0 - else: - raise ValueError( - "The `x` tensor is neither increasing or decreasing. Try setting the reorder argument to `True`." - ) - else: - direction = 1.0 - return _auc_compute_without_check(x, y, direction) +from torchmetrics.utilities.compute import auc as _auc +from torchmetrics.utilities.prints import rank_zero_warn def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor: """Computes Area Under the Curve (AUC) using the trapezoidal rule. + .. note:: + This metric have been moved to `torchmetrics.utilities.compute` in v0.10 this version will be removed in v0.11. + Args: x: x-coordinates, must be either increasing or decreasing y: y-coordinates @@ -129,5 +49,9 @@ def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor: >>> auc(x, y, reorder=True) tensor(4.) """ - x, y = _auc_update(x, y) - return _auc_compute(x, y, reorder=reorder) + rank_zero_warn( + "`torchmetrics.functional.auc` has been move to `torchmetrics.utilities.compute` in v0.10" + " and will be removed in v0.11.", + DeprecationWarning, + ) + return _auc(x, y, reorder=reorder) diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index ecc1c7eac72..2b51a04075e 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -12,17 +12,420 @@ # See the License for the specific language governing permissions and # limitations under the License. import warnings -from typing import Optional, Sequence, Tuple +from typing import List, Optional, Sequence, Tuple, Union import torch from torch import Tensor, tensor - -from torchmetrics.functional.classification.auc import _auc_compute_without_check -from torchmetrics.functional.classification.roc import roc +from typing_extensions import Literal + +from torchmetrics.functional.classification.precision_recall_curve import ( + _binary_precision_recall_curve_arg_validation, + _binary_precision_recall_curve_format, + _binary_precision_recall_curve_tensor_validation, + _binary_precision_recall_curve_update, + _multiclass_precision_recall_curve_arg_validation, + _multiclass_precision_recall_curve_format, + _multiclass_precision_recall_curve_tensor_validation, + _multiclass_precision_recall_curve_update, + _multilabel_precision_recall_curve_arg_validation, + _multilabel_precision_recall_curve_format, + _multilabel_precision_recall_curve_tensor_validation, + _multilabel_precision_recall_curve_update, +) +from torchmetrics.functional.classification.roc import ( + _binary_roc_compute, + _multiclass_roc_compute, + _multilabel_roc_compute, + roc, +) from torchmetrics.utilities.checks import _input_format_classification +from torchmetrics.utilities.compute import _auc_compute_without_check, _safe_divide from torchmetrics.utilities.data import _bincount from torchmetrics.utilities.enums import AverageMethod, DataType from torchmetrics.utilities.imports import _TORCH_LOWER_1_6 +from torchmetrics.utilities.prints import rank_zero_warn + + +def _reduce_auroc( + fpr: Union[Tensor, List[Tensor]], + tpr: Union[Tensor, List[Tensor]], + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + weights: Optional[Tensor] = None, +) -> Tensor: + """Utility function for reducing multiple average precision score into one number.""" + if isinstance(fpr, Tensor): + res = _auc_compute_without_check(fpr, tpr, 1.0, axis=1) + else: + res = [_auc_compute_without_check(x, y, 1.0) for x, y in zip(fpr, tpr)] + res = torch.stack(res) + if average is None or average == "none": + return res + if torch.isnan(res).any(): + rank_zero_warn( + f"Average precision score for one or more classes was `nan`. Ignoring these classes in {average}-average", + UserWarning, + ) + idx = ~torch.isnan(res) + if average == "macro": + return res[idx].mean() + elif average == "weighted" and weights is not None: + weights = _safe_divide(weights[idx], weights[idx].sum()) + return (res[idx] * weights).sum() + else: + raise ValueError("Received an incompatible combinations of inputs to make reduction.") + + +def _binary_auroc_arg_validation( + max_fpr: Optional[float] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) + if max_fpr is not None: + if not isinstance(max_fpr, float) and 0 < max_fpr <= 1: + raise ValueError(f"Arguments `max_fpr` should be a float in range (0, 1], but got: {max_fpr}") + if _TORCH_LOWER_1_6: + raise RuntimeError( + "`max_fpr` argument requires `torch.bucketize` which" " is not available below PyTorch version 1.6" + ) + + +def _binary_auroc_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + thresholds: Optional[Tensor], + max_fpr: Optional[float] = None, + pos_label: int = 1, +) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]: + fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label) + if max_fpr is None or max_fpr == 1: + return _auc_compute_without_check(fpr, tpr, 1.0) + + _device = fpr.device if isinstance(fpr, Tensor) else fpr[0].device + max_area: Tensor = tensor(max_fpr, device=_device) + # Add a single point at max_fpr and interpolate its tpr value + stop = torch.bucketize(max_area, fpr, out_int32=True, right=True) + weight = (max_area - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1]) + interp_tpr: Tensor = torch.lerp(tpr[stop - 1], tpr[stop], weight) + tpr = torch.cat([tpr[:stop], interp_tpr.view(1)]) + fpr = torch.cat([fpr[:stop], max_area.view(1)]) + + # Compute partial AUC + partial_auc = _auc_compute_without_check(fpr, tpr, 1.0) + + # McClish correction: standardize result to be 0.5 if non-discriminant and 1 if maximal + min_area: Tensor = 0.5 * max_area**2 + return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area)) + + +def binary_auroc( + preds: Tensor, + target: Tensor, + max_fpr: Optional[float] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tuple[Tensor, Tensor, Tensor]: + r""" + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for binary tasks. The AUROC score + summarizes the ROC curve into an single number that describes the performance of a model for multiple + thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 + corresponds to random guessing. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + max_fpr: If not ``None``, calculates standardized partial AUC over the range ``[0, max_fpr]``. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + A single scalar with the auroc score + + Example: + >>> from torchmetrics.functional.classification import binary_auroc + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> binary_auroc(preds, target, thresholds=None) + tensor(0.5000) + >>> binary_auroc(preds, target, thresholds=5) + tensor(0.5000) + """ + if validate_args: + _binary_auroc_arg_validation(max_fpr, thresholds, ignore_index) + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) + preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index) + state = _binary_precision_recall_curve_update(preds, target, thresholds) + return _binary_auroc_compute(state, thresholds, max_fpr) + + +def _multiclass_auroc_arg_validation( + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + allowed_average = ("macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}") + + +def _multiclass_auroc_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Tensor] = None, +) -> Tensor: + fpr, tpr, _ = _multiclass_roc_compute(state, num_classes, thresholds) + return _reduce_auroc( + fpr, + tpr, + average, + weights=_bincount(state[1], minlength=num_classes).float() if thresholds is None else state[0][:, 1, :].sum(-1), + ) + + +def multiclass_auroc( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks. The AUROC score + summarizes the ROC curve into an single number that describes the performance of a model for multiple + thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 + corresponds to random guessing. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over classes. Should be one of the following: + + - ``macro``: Calculate score for each class and average them + - ``weighted``: Calculates score for each class and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each class and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with auroc score per class. + If `average="macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics.functional.classification import multiclass_auroc + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> multiclass_auroc(preds, target, num_classes=5, average="macro", thresholds=None) + tensor(0.5333) + >>> multiclass_auroc(preds, target, num_classes=5, average=None, thresholds=None) + tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000]) + >>> multiclass_auroc(preds, target, num_classes=5, average="macro", thresholds=5) + tensor(0.5333) + >>> multiclass_auroc(preds, target, num_classes=5, average=None, thresholds=5) + tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000]) + + """ + if validate_args: + _multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index) + _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) + preds, target, thresholds = _multiclass_precision_recall_curve_format( + preds, target, num_classes, thresholds, ignore_index + ) + state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) + return _multiclass_auroc_compute(state, num_classes, average, thresholds) + + +def _multilabel_auroc_arg_validation( + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]], + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}") + + +def _multilabel_auroc_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]], + thresholds: Optional[Tensor], + ignore_index: Optional[int] = None, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tensor]: + if average == "micro": + if isinstance(state, Tensor) and thresholds is not None: + return _binary_auroc_compute(state.sum(1), thresholds, max_fpr=None) + else: + preds = state[0].flatten() + target = state[1].flatten() + if ignore_index is not None: + idx = target == ignore_index + preds = preds[~idx] + target = target[~idx] + return _binary_auroc_compute((preds, target), thresholds, max_fpr=None) + + else: + fpr, tpr, _ = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) + return _reduce_auroc( + fpr, + tpr, + average, + weights=(state[1] == 1).sum(dim=0).float() if thresholds is None else state[0][:, 1, :].sum(-1), + ) + + +def multilabel_auroc( + preds: Tensor, + target: Tensor, + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + r""" + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multilabel tasks. The AUROC score + summarizes the ROC curve into an single number that describes the performance of a model for multiple + thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 + corresponds to random guessing. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum score over all labels + - ``macro``: Calculate score for each label and average them + - ``weighted``: Calculates score for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each label and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with auroc score per class. + If `average="micro|macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics.functional.classification import multilabel_auroc + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... [1, 1, 1]]) + >>> multilabel_auroc(preds, target, num_labels=3, average="macro", thresholds=None) + tensor(0.6528) + >>> multilabel_auroc(preds, target, num_labels=3, average=None, thresholds=None) + tensor([0.6250, 0.5000, 0.8333]) + >>> multilabel_auroc(preds, target, num_labels=3, average="macro", thresholds=5) + tensor(0.6528) + >>> multilabel_auroc(preds, target, num_labels=3, average=None, thresholds=5) + tensor([0.6250, 0.5000, 0.8333]) + """ + if validate_args: + _multilabel_auroc_arg_validation(num_labels, average, thresholds, ignore_index) + _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) + preds, target, thresholds = _multilabel_precision_recall_curve_format( + preds, target, num_labels, thresholds, ignore_index + ) + state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) + return _multilabel_auroc_compute(state, num_labels, average, thresholds, ignore_index) def _auroc_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, DataType]: @@ -198,11 +601,24 @@ def auroc( target: Tensor, num_classes: Optional[int] = None, pos_label: Optional[int] = None, - average: Optional[str] = "macro", + average: Optional[Literal["macro", "weighted", "none"]] = "macro", max_fpr: Optional[float] = None, sample_weights: Optional[Sequence] = None, -) -> Tensor: - """Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) For non-binary input, if the ``preds`` and ``target`` tensor have the same size the input will be interpretated as multilabel and if ``preds`` have one @@ -225,7 +641,6 @@ def auroc( range [0,num_classes-1] average: - - ``'micro'`` computes metric globally. Only works for multilabel problems - ``'macro'`` computes metric for each class and uniformly averages them - ``'weighted'`` computes metric for each class and does a weighted-average, where each class is weighted by their support (accounts for class imbalance) @@ -265,5 +680,27 @@ def auroc( >>> auroc(preds, target, num_classes=3) tensor(0.7778) """ + if task is not None: + if task == "binary": + return binary_auroc(preds, target, max_fpr, thresholds, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_auroc(preds, target, num_classes, average, thresholds, ignore_index, validate_args) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_auroc(preds, target, num_labels, average, thresholds, ignore_index, validate_args) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + preds, target, mode = _auroc_update(preds, target) return _auroc_compute(preds, target, mode, num_classes, pos_label, average, max_fpr, sample_weights) diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 26e47a0633b..00b1d05dc3e 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -16,12 +16,393 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.functional.classification.precision_recall_curve import ( + _binary_precision_recall_curve_arg_validation, + _binary_precision_recall_curve_compute, + _binary_precision_recall_curve_format, + _binary_precision_recall_curve_tensor_validation, + _binary_precision_recall_curve_update, + _multiclass_precision_recall_curve_arg_validation, + _multiclass_precision_recall_curve_compute, + _multiclass_precision_recall_curve_format, + _multiclass_precision_recall_curve_tensor_validation, + _multiclass_precision_recall_curve_update, + _multilabel_precision_recall_curve_arg_validation, + _multilabel_precision_recall_curve_compute, + _multilabel_precision_recall_curve_format, + _multilabel_precision_recall_curve_tensor_validation, + _multilabel_precision_recall_curve_update, _precision_recall_curve_compute, _precision_recall_curve_update, ) +from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.prints import rank_zero_warn + + +def _reduce_average_precision( + precision: Union[Tensor, List[Tensor]], + recall: Union[Tensor, List[Tensor]], + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + weights: Optional[Tensor] = None, +) -> Tensor: + """Utility function for reducing multiple average precision score into one number.""" + res = [] + if isinstance(precision, Tensor) and isinstance(recall, Tensor): + res = -torch.sum((recall[:, 1:] - recall[:, :-1]) * precision[:, :-1], 1) + else: + for p, r in zip(precision, recall): + res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1])) + res = torch.stack(res) + if average is None or average == "none": + return res + if torch.isnan(res).any(): + rank_zero_warn( + f"Average precision score for one or more classes was `nan`. Ignoring these classes in {average}-average", + UserWarning, + ) + idx = ~torch.isnan(res) + if average == "macro": + return res[idx].mean() + elif average == "weighted" and weights is not None: + weights = _safe_divide(weights[idx], weights[idx].sum()) + return (res[idx] * weights).sum() + else: + raise ValueError("Received an incompatible combinations of inputs to make reduction.") + + +def _binary_average_precision_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + thresholds: Optional[Tensor], +) -> Tensor: + precision, recall, _ = _binary_precision_recall_curve_compute(state, thresholds) + return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) + + +def binary_average_precision( + preds: Tensor, + target: Tensor, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the average precision (AP) score for binary tasks. The AP score summarizes a precision-recall curve + as an weighted mean of precisions at each threshold, with the difference in recall from the previous threshold + as weight: + + .. math:: + AP = \sum{n} (R_n - R_{n-1}) P_n + + where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is + equivalent to the area under the precision-recall curve (AUPRC). + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + A single scalar with the average precision score + + Example: + >>> from torchmetrics.functional.classification import binary_average_precision + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> binary_average_precision(preds, target, thresholds=None) + tensor(0.5833) + >>> binary_average_precision(preds, target, thresholds=5) + tensor(0.6667) + """ + + if validate_args: + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) + preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index) + state = _binary_precision_recall_curve_update(preds, target, thresholds) + return _binary_average_precision_compute(state, thresholds) + + +def _multiclass_average_precision_arg_validation( + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + allowed_average = ("macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}") + + +def _multiclass_average_precision_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Tensor] = None, +) -> Tensor: + precision, recall, _ = _multiclass_precision_recall_curve_compute(state, num_classes, thresholds) + return _reduce_average_precision( + precision, + recall, + average, + weights=_bincount(state[1], minlength=num_classes).float() if thresholds is None else state[0][:, 1, :].sum(-1), + ) + + +def multiclass_average_precision( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the average precision (AP) score for binary tasks. The AP score summarizes a precision-recall curve + as an weighted mean of precisions at each threshold, with the difference in recall from the previous threshold + as weight: + + .. math:: + AP = \sum{n} (R_n - R_{n-1}) P_n + + where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is + equivalent to the area under the precision-recall curve (AUPRC). + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over classes. Should be one of the following: + + - ``macro``: Calculate score for each class and average them + - ``weighted``: Calculates score for each class and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each class and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with AP score per class. + If `average="macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics.functional.classification import multiclass_average_precision + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> multiclass_average_precision(preds, target, num_classes=5, average="macro", thresholds=None) + tensor(0.6250) + >>> multiclass_average_precision(preds, target, num_classes=5, average=None, thresholds=None) + tensor([1.0000, 1.0000, 0.2500, 0.2500, nan]) + >>> multiclass_average_precision(preds, target, num_classes=5, average="macro", thresholds=5) + tensor(0.5000) + >>> multiclass_average_precision(preds, target, num_classes=5, average=None, thresholds=5) + tensor([1.0000, 1.0000, 0.2500, 0.2500, -0.0000]) + + """ + if validate_args: + _multiclass_average_precision_arg_validation(num_classes, average, thresholds, ignore_index) + _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) + preds, target, thresholds = _multiclass_precision_recall_curve_format( + preds, target, num_classes, thresholds, ignore_index + ) + state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) + return _multiclass_average_precision_compute(state, num_classes, average, thresholds) + + +def _multilabel_average_precision_arg_validation( + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]], + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}") + + +def _multilabel_average_precision_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]], + thresholds: Optional[Tensor], + ignore_index: Optional[int] = None, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + if average == "micro": + if isinstance(state, Tensor) and thresholds is not None: + state = state.sum(1) + else: + preds, target = state[0].flatten(), state[1].flatten() + if ignore_index is not None: + idx = target == ignore_index + preds = preds[~idx] + target = target[~idx] + state = [preds, target] + return _binary_average_precision_compute(state, thresholds) + else: + precision, recall, _ = _multilabel_precision_recall_curve_compute(state, num_labels, thresholds, ignore_index) + return _reduce_average_precision( + precision, + recall, + average, + weights=(state[1] == 1).sum(dim=0).float() if thresholds is None else state[0][:, 1, :].sum(-1), + ) + + +def multilabel_average_precision( + preds: Tensor, + target: Tensor, + num_labels: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the average precision (AP) score for binary tasks. The AP score summarizes a precision-recall curve + as an weighted mean of precisions at each threshold, with the difference in recall from the previous threshold + as weight: + + .. math:: + AP = \sum{n} (R_n - R_{n-1}) P_n + + where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is + equivalent to the area under the precision-recall curve (AUPRC). + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum score over all labels + - ``macro``: Calculate score for each label and average them + - ``weighted``: Calculates score for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates score for each label and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with AP score per class. + If `average="micro|macro"|"weighted"` then a single scalar is returned. + + Example: + >>> from torchmetrics.functional.classification import multilabel_average_precision + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... [1, 1, 1]]) + >>> multilabel_average_precision(preds, target, num_labels=3, average="macro", thresholds=None) + tensor(0.7500) + >>> multilabel_average_precision(preds, target, num_labels=3, average=None, thresholds=None) + tensor([0.7500, 0.5833, 0.9167]) + >>> multilabel_average_precision(preds, target, num_labels=3, average="macro", thresholds=5) + tensor(0.7778) + >>> multilabel_average_precision(preds, target, num_labels=3, average=None, thresholds=5) + tensor([0.7500, 0.6667, 0.9167]) + """ + if validate_args: + _multilabel_average_precision_arg_validation(num_labels, average, thresholds, ignore_index) + _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) + preds, target, thresholds = _multilabel_precision_recall_curve_format( + preds, target, num_labels, thresholds, ignore_index + ) + state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) + return _multilabel_average_precision_compute(state, num_labels, average, thresholds, ignore_index) def _average_precision_update( @@ -98,7 +479,7 @@ def _average_precision_compute( if preds.ndim == target.ndim and target.ndim > 1: weights = target.sum(dim=0).float() else: - weights = _bincount(target, minlength=num_classes).float() + weights = _bincount(target, minlength=max(num_classes, 2)).float() weights = weights / torch.sum(weights) else: weights = None @@ -178,9 +559,22 @@ def average_precision( target: Tensor, num_classes: Optional[int] = None, pos_label: Optional[int] = None, - average: Optional[str] = "macro", + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, ) -> Union[List[Tensor], Tensor]: - """Computes the average precision score. + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the average precision score. Args: preds: predictions from model (logits or probabilities) @@ -196,8 +590,6 @@ def average_precision( - ``'macro'`` [default]: Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class). - - ``'micro'``: Calculate the metric globally, across all samples and classes. Cannot be - used with multiclass input. - ``'weighted'``: Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support. - ``'none'`` or ``None``: Calculate the metric for each class separately, and return @@ -223,5 +615,30 @@ def average_precision( >>> average_precision(pred, target, num_classes=5, average=None) [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] """ + if task is not None: + if task == "binary": + return binary_average_precision(preds, target, thresholds, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_average_precision( + preds, target, num_classes, average, thresholds, ignore_index, validate_args + ) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_average_precision( + preds, target, num_labels, average, thresholds, ignore_index, validate_args + ) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes, pos_label, average) return _average_precision_compute(preds, target, num_classes, pos_label, average) diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index 5f08cf73400..6bf276f15b2 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -11,14 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Optional, Tuple, Union import torch from torch import Tensor - +from typing_extensions import Literal + +from torchmetrics.functional.classification.confusion_matrix import ( + _binary_confusion_matrix_format, + _binary_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_format, + _multiclass_confusion_matrix_tensor_validation, +) from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_8 +from torchmetrics.utilities.prints import rank_zero_warn def _binning_with_loop( @@ -61,6 +69,7 @@ def _binning_bucketize( Returns: tuple with binned accuracy, binned confidence and binned probabilities """ + accuracies = accuracies.to(dtype=confidences.dtype) acc_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype) conf_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype) count_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype) @@ -82,7 +91,7 @@ def _binning_bucketize( def _ce_compute( confidences: Tensor, accuracies: Tensor, - bin_boundaries: Tensor, + bin_boundaries: Union[Tensor, int], norm: str = "l1", debias: bool = False, ) -> Tensor: @@ -102,13 +111,17 @@ def _ce_compute( Returns: Tensor: Calibration error scalar. """ + if isinstance(bin_boundaries, int): + bin_boundaries = torch.linspace(0, 1, bin_boundaries + 1, dtype=torch.float, device=confidences.device) + if norm not in {"l1", "l2", "max"}: raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. ") - if _TORCH_GREATER_EQUAL_1_8: - acc_bin, conf_bin, prop_bin = _binning_bucketize(confidences, accuracies, bin_boundaries) - else: - acc_bin, conf_bin, prop_bin = _binning_with_loop(confidences, accuracies, bin_boundaries) + with torch.no_grad(): + if _TORCH_GREATER_EQUAL_1_8: + acc_bin, conf_bin, prop_bin = _binning_bucketize(confidences, accuracies, bin_boundaries) + else: + acc_bin, conf_bin, prop_bin = _binning_with_loop(confidences, accuracies, bin_boundaries) if norm == "l1": ce = torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin) @@ -126,6 +139,214 @@ def _ce_compute( return ce +def _binary_calibration_error_arg_validation( + n_bins: int, + norm: Literal["l1", "l2", "max"] = "l1", + ignore_index: Optional[int] = None, +) -> None: + if not isinstance(n_bins, int) or n_bins < 1: + raise ValueError(f"Expected argument `n_bins` to be an integer larger than 0, but got {n_bins}") + allowed_norm = ("l1", "l2", "max") + if norm not in allowed_norm: + raise ValueError(f"Expected argument `norm` to be one of {allowed_norm}, but got {norm}.") + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + + +def _binary_calibration_error_tensor_validation( + preds: Tensor, target: Tensor, ignore_index: Optional[int] = None +) -> None: + _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) + if not preds.is_floating_point(): + raise ValueError( + "Expected argument `preds` to be floating tensor with probabilities/logits" + f" but got tensor with dtype {preds.dtype}" + ) + + +def _binary_calibration_error_update(preds: Tensor, target: Tensor) -> Tensor: + confidences, accuracies = preds, target + return confidences, accuracies + + +def binary_calibration_error( + preds: Tensor, + target: Tensor, + n_bins: int = 15, + norm: Literal["l1", "l2", "max"] = "l1", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""`Computes the Top-label Calibration Error`_ for binary tasks. The expected calibration error can be used to + quantify how well a given model is calibrated e.g. how well the predicted output probabilities of the model matches + the actual probabilities of the ground truth distribution. + + Three different norms are implemented, each corresponding to variations on the calibration error metric. + + .. math:: + \text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)} + + .. math:: + \text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)} + + .. math:: + \text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)} + + Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, :math:`c_i` is the average confidence of + predictions in bin :math:`i`, and :math:`b_i` is the fraction of data points in bin :math:`i`. Bins are constructed + in an uniform way in the [0,1] range. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + n_bins: Number of bins to use when computing the metric. + norm: Norm used to compare empirical and expected probability bins. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example: + >>> from torchmetrics.functional.classification import binary_calibration_error + >>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75]) + >>> target = torch.tensor([0, 0, 1, 1, 1]) + >>> binary_calibration_error(preds, target, n_bins=2, norm='l1') + tensor(0.2900) + >>> binary_calibration_error(preds, target, n_bins=2, norm='l2') + tensor(0.2918) + >>> binary_calibration_error(preds, target, n_bins=2, norm='max') + tensor(0.3167) + """ + if validate_args: + _binary_calibration_error_arg_validation(n_bins, norm, ignore_index) + _binary_calibration_error_tensor_validation(preds, target, ignore_index) + preds, target = _binary_confusion_matrix_format( + preds, target, threshold=0.0, ignore_index=ignore_index, convert_to_labels=False + ) + confidences, accuracies = _binary_calibration_error_update(preds, target) + return _ce_compute(confidences, accuracies, n_bins, norm) + + +def _multiclass_calibration_error_arg_validation( + num_classes: int, + n_bins: int, + norm: Literal["l1", "l2", "max"] = "l1", + ignore_index: Optional[int] = None, +) -> None: + if not isinstance(num_classes, int) or num_classes < 2: + raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") + if not isinstance(n_bins, int) or n_bins < 1: + raise ValueError(f"Expected argument `n_bins` to be an integer larger than 0, but got {n_bins}") + allowed_norm = ("l1", "l2", "max") + if norm not in allowed_norm: + raise ValueError(f"Expected argument `norm` to be one of {allowed_norm}, but got {norm}.") + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + + +def _multiclass_calibration_error_tensor_validation( + preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None +) -> None: + _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) + if not preds.is_floating_point(): + raise ValueError( + "Expected argument `preds` to be floating tensor with probabilities/logits" + f" but got tensor with dtype {preds.dtype}" + ) + + +def _multiclass_calibration_error_update( + preds: Tensor, + target: Tensor, +) -> Tensor: + if not torch.all((0 <= preds) * (preds <= 1)): + preds = preds.softmax(1) + confidences, predictions = preds.max(dim=1) + accuracies = predictions.eq(target) + return confidences.float(), accuracies.float() + + +def multiclass_calibration_error( + preds: Tensor, + target: Tensor, + num_classes: int, + n_bins: int = 15, + norm: Literal["l1", "l2", "max"] = "l1", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""`Computes the Top-label Calibration Error`_ for multiclass tasks. The expected calibration error can be used to + quantify how well a given model is calibrated e.g. how well the predicted output probabilities of the model matches + the actual probabilities of the ground truth distribution. + + Three different norms are implemented, each corresponding to variations on the calibration error metric. + + .. math:: + \text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)} + + .. math:: + \text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)} + + .. math:: + \text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)} + + Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, :math:`c_i` is the average confidence of + predictions in bin :math:`i`, and :math:`b_i` is the fraction of data points in bin :math:`i`. Bins are constructed + in an uniform way in the [0,1] range. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + n_bins: Number of bins to use when computing the metric. + norm: Norm used to compare empirical and expected probability bins. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example: + >>> from torchmetrics.functional.classification import multiclass_calibration_error + >>> preds = torch.tensor([[0.25, 0.20, 0.55], + ... [0.55, 0.05, 0.40], + ... [0.10, 0.30, 0.60], + ... [0.90, 0.05, 0.05]]) + >>> target = torch.tensor([0, 1, 2, 0]) + >>> multiclass_calibration_error(preds, target, num_classes=3, n_bins=3, norm='l1') + tensor(0.2000) + >>> multiclass_calibration_error(preds, target, num_classes=3, n_bins=3, norm='l2') + tensor(0.2082) + >>> multiclass_calibration_error(preds, target, num_classes=3, n_bins=3, norm='max') + tensor(0.2333) + """ + if validate_args: + _multiclass_calibration_error_arg_validation(num_classes, n_bins, norm, ignore_index) + _multiclass_calibration_error_tensor_validation(preds, target, num_classes, ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index, convert_to_labels=False) + confidences, accuracies = _multiclass_calibration_error_update(preds, target) + return _ce_compute(confidences, accuracies, n_bins, norm) + + def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: """Given a predictions and targets tensor, computes the confidences of the top-1 prediction and records their correctness. @@ -165,8 +386,25 @@ def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: return confidences.float(), accuracies.float() -def calibration_error(preds: Tensor, target: Tensor, n_bins: int = 15, norm: str = "l1") -> Tensor: - r"""`Computes the Top-label Calibration Error`_ +def calibration_error( + preds: Tensor, + target: Tensor, + n_bins: int = 15, + norm: Literal["l1", "l2", "max"] = "l1", + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + `Computes the Top-label Calibration Error`_ Three different norms are implemented, each corresponding to variations on the calibration error metric. @@ -199,6 +437,24 @@ def calibration_error(preds: Tensor, target: Tensor, n_bins: int = 15, norm: str norm: Norm used to compare empirical and expected probability bins. Defaults to "l1", or Expected Calibration Error. """ + if task is not None: + assert norm is not None + if task == "binary": + return binary_calibration_error(preds, target, n_bins, norm, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_calibration_error(preds, target, num_classes, n_bins, norm, ignore_index, validate_args) + raise ValueError(f"Expected argument `task` to either be `'binary'`, `'multiclass'` but got {task}") + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + if norm not in ("l1", "l2", "max"): raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. ") diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index 01623e08fab..0c3962a7d74 100644 --- a/src/torchmetrics/functional/classification/cohen_kappa.py +++ b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -15,8 +15,222 @@ import torch from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.classification.confusion_matrix import ( + _binary_confusion_matrix_arg_validation, + _binary_confusion_matrix_format, + _binary_confusion_matrix_tensor_validation, + _binary_confusion_matrix_update, + _confusion_matrix_compute, + _confusion_matrix_update, + _multiclass_confusion_matrix_arg_validation, + _multiclass_confusion_matrix_format, + _multiclass_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_update, +) +from torchmetrics.utilities.prints import rank_zero_warn + + +def _cohen_kappa_reduce(confmat: Tensor, weights: Optional[Literal["linear", "quadratic", "none"]] = None) -> Tensor: + """Reduce an un-normalized confusion matrix of shape (n_classes, n_classes) into the cohen kappa score.""" + confmat = confmat.float() if not confmat.is_floating_point() else confmat + n_classes = confmat.shape[0] + sum0 = confmat.sum(dim=0, keepdim=True) + sum1 = confmat.sum(dim=1, keepdim=True) + expected = sum1 @ sum0 / sum0.sum() # outer product + + if weights is None or weights == "none": + w_mat = torch.ones_like(confmat).flatten() + w_mat[:: n_classes + 1] = 0 + w_mat = w_mat.reshape(n_classes, n_classes) + elif weights in ("linear", "quadratic"): + w_mat = torch.zeros_like(confmat) + w_mat += torch.arange(n_classes, dtype=w_mat.dtype, device=w_mat.device) + if weights == "linear": + w_mat = torch.abs(w_mat - w_mat.T) + else: + w_mat = torch.pow(w_mat - w_mat.T, 2.0) + else: + raise ValueError( + f"Received {weights} for argument ``weights`` but should be either" " None, 'linear' or 'quadratic'" + ) + k = torch.sum(w_mat * confmat) / torch.sum(w_mat * expected) + return 1 - k + + +def _binary_cohen_kappa_arg_validation( + threshold: float = 0.5, + ignore_index: Optional[int] = None, + weights: Optional[Literal["linear", "quadratic", "none"]] = None, +) -> None: + """Validate non tensor input. + + - ``threshold`` has to be a float in the [0,1] range + - ``ignore_index`` has to be None or int + - ``weights`` has to be "linear" | "quadratic" | "none" | None + """ + _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize=None) + allowed_weights = ("linear", "quadratic", "none", None) + if weights not in allowed_weights: + raise ValueError(f"Expected argument `weight` to be one of {allowed_weights}, but got {weights}.") + + +def binary_cohen_kappa( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + weights: Optional[Literal["linear", "quadratic", "none"]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement for binary + tasks. It is defined as + + .. math:: + \kappa = (p_o - p_e) / (1 - p_e) + + where :math:`p_o` is the empirical probability of agreement and :math:`p_e` is + the expected agreement when both annotators assign labels randomly. Note that + :math:`p_e` is estimated using a per-annotator empirical prior over the + class labels. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + threshold: Threshold for transforming probability to binary (0,1) predictions + weights: Weighting type to calculate the score. Choose from: + + - ``None`` or ``'none'``: no weighting + - ``'linear'``: linear weighting + - ``'quadratic'``: quadratic weighting + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import binary_cohen_kappa + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> binary_cohen_kappa(preds, target) + tensor(0.5000) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import binary_cohen_kappa + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) + >>> binary_cohen_kappa(preds, target) + tensor(0.5000) + + """ + if validate_args: + _binary_cohen_kappa_arg_validation(threshold, ignore_index, weights) + _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) + preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index) + confmat = _binary_confusion_matrix_update(preds, target) + return _cohen_kappa_reduce(confmat, weights) + + +def _multiclass_cohen_kappa_arg_validation( + num_classes: int, + ignore_index: Optional[int] = None, + weights: Optional[Literal["linear", "quadratic", "none"]] = None, +) -> None: + """Validate non tensor input. + + - ``num_classes`` has to be a int larger than 1 + - ``ignore_index`` has to be None or int + - ``weights`` has to be "linear" | "quadratic" | "none" | None + """ + _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize=None) + allowed_weights = ("linear", "quadratic", "none", None) + if weights not in allowed_weights: + raise ValueError(f"Expected argument `weight` to be one of {allowed_weights}, but got {weights}.") + + +def multiclass_cohen_kappa( + preds: Tensor, + target: Tensor, + num_classes: int, + weights: Optional[Literal["linear", "quadratic", "none"]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement for multiclass + tasks. It is defined as + + .. math:: + \kappa = (p_o - p_e) / (1 - p_e) + + where :math:`p_o` is the empirical probability of agreement and :math:`p_e` is + the expected agreement when both annotators assign labels randomly. Note that + :math:`p_e` is estimated using a per-annotator empirical prior over the + class labels. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + weights: Weighting type to calculate the score. Choose from: + + - ``None`` or ``'none'``: no weighting + - ``'linear'``: linear weighting + - ``'quadratic'``: quadratic weighting + + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (pred is integer tensor): + >>> from torchmetrics.functional.classification import multiclass_cohen_kappa + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_cohen_kappa(preds, target, num_classes=3) + tensor(0.6364) + + Example (pred is float tensor): + >>> from torchmetrics.functional.classification import multiclass_cohen_kappa + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> multiclass_cohen_kappa(preds, target, num_classes=3) + tensor(0.6364) + + """ + if validate_args: + _multiclass_cohen_kappa_arg_validation(num_classes, ignore_index, weights) + _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index) + confmat = _multiclass_confusion_matrix_update(preds, target, num_classes) + return _cohen_kappa_reduce(confmat, weights) -from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update _cohen_kappa_update = _confusion_matrix_update @@ -71,10 +285,22 @@ def cohen_kappa( preds: Tensor, target: Tensor, num_classes: int, - weights: Optional[str] = None, + weights: Optional[Literal["linear", "quadratic", "none"]] = None, threshold: float = 0.5, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, ) -> Tensor: - r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement. + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + + Calculates `Cohen's kappa score`_ that measures inter-annotator agreement. It is defined as @@ -106,5 +332,23 @@ class labels. >>> cohen_kappa(preds, target, num_classes=2) tensor(0.5000) """ + if task is not None: + if task == "binary": + return binary_cohen_kappa(preds, target, threshold, weights, ignore_index, validate_args) + if task == "multiclass": + return multiclass_cohen_kappa(preds, target, num_classes, weights, ignore_index, validate_args) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + confmat = _cohen_kappa_update(preds, target, num_classes, threshold) return _cohen_kappa_compute(confmat, weights) diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 362276dc146..d6d98a24b4b 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -11,15 +11,588 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Tuple import torch from torch import Tensor +from typing_extensions import Literal -from torchmetrics.utilities import rank_zero_warn -from torchmetrics.utilities.checks import _input_format_classification -from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification +from torchmetrics.utilities.data import _bincount, _movedim from torchmetrics.utilities.enums import DataType +from torchmetrics.utilities.prints import rank_zero_warn + + +def _confusion_matrix_reduce( + confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None +) -> Tensor: + """Reduce an un-normalized confusion matrix + Args: + confmat: un-normalized confusion matrix + normalize: normalization method. + - `"true"` will divide by the sum of the column dimension. + - `"pred"` will divide by the sum of the row dimension. + - `"all"` will divide by the sum of the full matrix + - `"none"` or `None` will apply no reduction + + Returns: + Normalized confusion matrix + """ + allowed_normalize = ("true", "pred", "all", "none", None) + if normalize not in allowed_normalize: + raise ValueError(f"Argument `normalize` needs to one of the following: {allowed_normalize}") + if normalize is not None and normalize != "none": + confmat = confmat.float() if not confmat.is_floating_point() else confmat + if normalize == "true": + confmat = confmat / confmat.sum(axis=-1, keepdim=True) + elif normalize == "pred": + confmat = confmat / confmat.sum(axis=-2, keepdim=True) + elif normalize == "all": + confmat = confmat / confmat.sum(axis=[-2, -1], keepdim=True) + + nan_elements = confmat[torch.isnan(confmat)].nelement() + if nan_elements: + confmat[torch.isnan(confmat)] = 0 + rank_zero_warn(f"{nan_elements} NaN values found in confusion matrix have been replaced with zeros.") + return confmat + + +def _binary_confusion_matrix_arg_validation( + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, +) -> None: + """Validate non tensor input. + + - ``threshold`` has to be a float in the [0,1] range + - ``ignore_index`` has to be None or int + - ``normalize`` has to be "true" | "pred" | "all" | "none" | None + """ + if not (isinstance(threshold, float) and (0 <= threshold <= 1)): + raise ValueError(f"Expected argument `threshold` to be a float in the [0,1] range, but got {threshold}.") + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + allowed_normalize = ("true", "pred", "all", "none", None) + if normalize not in allowed_normalize: + raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") + + +def _binary_confusion_matrix_tensor_validation( + preds: Tensor, target: Tensor, ignore_index: Optional[int] = None +) -> None: + """Validate tensor input. + + - tensors have to be of same shape + - all values in target tensor that are not ignored have to be in {0, 1} + - if pred tensor is not floating point, then all values also have to be in {0, 1} + """ + # Check that they have same shape + _check_same_shape(preds, target) + + # Check that target only contains {0,1} values or value in ignore_index + unique_values = torch.unique(target) + if ignore_index is None: + check = torch.any((unique_values != 0) & (unique_values != 1)) + else: + check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) + if check: + raise RuntimeError( + f"Detected the following values in `target`: {unique_values} but expected only" + f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." + ) + + # If preds is label tensor, also check that it only contains {0,1} values + if not preds.is_floating_point(): + unique_values = torch.unique(preds) + if torch.any((unique_values != 0) & (unique_values != 1)): + raise RuntimeError( + f"Detected the following values in `preds`: {unique_values} but expected only" + " the following values [0,1] since preds is a label tensor." + ) + + +def _binary_confusion_matrix_format( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + convert_to_labels: bool = True, +) -> Tuple[Tensor, Tensor]: + """Convert all input to label format. + + - Remove all datapoints that should be ignored + - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range + - If preds tensor is floating point, thresholds afterwards + """ + preds = preds.flatten() + target = target.flatten() + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + + if preds.is_floating_point(): + if not torch.all((0 <= preds) * (preds <= 1)): + # preds is logits, convert with sigmoid + preds = preds.sigmoid() + if convert_to_labels: + preds = preds > threshold + + return preds, target + + +def _binary_confusion_matrix_update(preds: Tensor, target: Tensor) -> Tensor: + """Computes the bins to update the confusion matrix with.""" + unique_mapping = (target * 2 + preds).to(torch.long) + bins = _bincount(unique_mapping, minlength=4) + return bins.reshape(2, 2) + + +def _binary_confusion_matrix_compute( + confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None +) -> Tensor: + """Reduces the confusion matrix to it's final form. + + Normalization technique can be chosen by ``normalize``. + """ + return _confusion_matrix_reduce(confmat, normalize) + + +def binary_confusion_matrix( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the `confusion matrix`_ for binary tasks. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + threshold: Threshold for transforming probability to binary (0,1) predictions + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + A ``[2, 2]`` tensor + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import binary_confusion_matrix + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> binary_confusion_matrix(preds, target) + tensor([[2, 0], + [1, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import binary_confusion_matrix + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) + >>> binary_confusion_matrix(preds, target) + tensor([[2, 0], + [1, 1]]) + """ + if validate_args: + _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize) + _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) + preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index) + confmat = _binary_confusion_matrix_update(preds, target) + return _binary_confusion_matrix_compute(confmat, normalize) + + +def _multiclass_confusion_matrix_arg_validation( + num_classes: int, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, +) -> None: + """Validate non tensor input. + + - ``num_classes`` has to be a int larger than 1 + - ``ignore_index`` has to be None or int + - ``normalize`` has to be "true" | "pred" | "all" | "none" | None + """ + if not isinstance(num_classes, int) or num_classes < 2: + raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + allowed_normalize = ("true", "pred", "all", "none", None) + if normalize not in allowed_normalize: + raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") + + +def _multiclass_confusion_matrix_tensor_validation( + preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None +) -> None: + """Validate tensor input. + + - if target has one more dimension than preds, then all dimensions except for preds.shape[1] should match + exactly. preds.shape[1] should have size equal to number of classes + - if preds and target have same number of dims, then all dimensions should match + - all values in target tensor that are not ignored have to be {0, ..., num_classes - 1} + - if pred tensor is not floating point, then all values also have to be in {0, ..., num_classes - 1} + """ + if preds.ndim == target.ndim + 1: + if not preds.is_floating_point(): + raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") + if preds.shape[1] != num_classes: + raise ValueError( + "If `preds` have one dimension more than `target`, `preds.shape[1]` should be" + " equal to number of classes." + ) + if preds.shape[2:] != target.shape[1:]: + raise ValueError( + "If `preds` have one dimension more than `target`, the shape of `preds` should be" + " (N, C, ...), and the shape of `target` should be (N, ...)." + ) + elif preds.ndim == target.ndim: + if preds.shape != target.shape: + raise ValueError( + "The `preds` and `target` should have the same shape,", + f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", + ) + else: + raise ValueError( + "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)" + " and `preds` should be (N, C, ...)." + ) + + num_unique_values = len(torch.unique(target)) + if ignore_index is None: + check = num_unique_values > num_classes + else: + check = num_unique_values > num_classes + 1 + if check: + raise RuntimeError( + "Detected more unique values in `target` than `num_classes`. Expected only " + f"{num_classes if ignore_index is None else num_classes + 1} but found " + f"{num_unique_values} in `target`." + ) + + if not preds.is_floating_point(): + num_unique_values = len(torch.unique(preds)) + if num_unique_values > num_classes: + raise RuntimeError( + "Detected more unique values in `preds` than `num_classes`. Expected only " + f"{num_classes} but found {num_unique_values} in `preds`." + ) + + +def _multiclass_confusion_matrix_format( + preds: Tensor, + target: Tensor, + ignore_index: Optional[int] = None, + convert_to_labels: bool = True, +) -> Tuple[Tensor, Tensor]: + """Convert all input to label format. + + - Applies argmax if preds have one more dimension than target + - Remove all datapoints that should be ignored + """ + # Apply argmax if we have one more dimension + if preds.ndim == target.ndim + 1 and convert_to_labels: + preds = preds.argmax(dim=1) + + if convert_to_labels: + preds = preds.flatten() + else: + preds = _movedim(preds, 1, -1).reshape(-1, preds.shape[1]) + target = target.flatten() + + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + + return preds, target + + +def _multiclass_confusion_matrix_update(preds: Tensor, target: Tensor, num_classes: int) -> Tensor: + """Computes the bins to update the confusion matrix with.""" + unique_mapping = (target * num_classes + preds).to(torch.long) + bins = _bincount(unique_mapping, minlength=num_classes**2) + return bins.reshape(num_classes, num_classes) + + +def _multiclass_confusion_matrix_compute( + confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None +) -> Tensor: + """Reduces the confusion matrix to it's final form. + + Normalization technique can be chosen by ``normalize``. + """ + return _confusion_matrix_reduce(confmat, normalize) + + +def multiclass_confusion_matrix( + preds: Tensor, + target: Tensor, + num_classes: int, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the `confusion matrix`_ for multiclass tasks. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + A ``[num_classes, num_classes]`` tensor + + Example (pred is integer tensor): + >>> from torchmetrics.functional.classification import multiclass_confusion_matrix + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_confusion_matrix(preds, target, num_classes=3) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + + Example (pred is float tensor): + >>> from torchmetrics.functional.classification import multiclass_confusion_matrix + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> multiclass_confusion_matrix(preds, target, num_classes=3) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + """ + if validate_args: + _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) + _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index) + confmat = _multiclass_confusion_matrix_update(preds, target, num_classes) + return _multiclass_confusion_matrix_compute(confmat, normalize) + + +def _multilabel_confusion_matrix_arg_validation( + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, +) -> None: + """Validate non tensor input. + + - ``num_labels`` should be an int larger than 1 + - ``threshold`` has to be a float in the [0,1] range + - ``ignore_index`` has to be None or int + - ``normalize`` has to be "true" | "pred" | "all" | "none" | None + """ + if not isinstance(num_labels, int) or num_labels < 2: + raise ValueError(f"Expected argument `num_labels` to be an integer larger than 1, but got {num_labels}") + if not (isinstance(threshold, float) and (0 <= threshold <= 1)): + raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.") + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + allowed_normalize = ("true", "pred", "all", "none", None) + if normalize not in allowed_normalize: + raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") + + +def _multilabel_confusion_matrix_tensor_validation( + preds: Tensor, target: Tensor, num_labels: int, ignore_index: Optional[int] = None +) -> None: + """Validate tensor input. + + - tensors have to be of same shape + - the second dimension of both tensors need to be equal to the number of labels + - all values in target tensor that are not ignored have to be in {0, 1} + - if pred tensor is not floating point, then all values also have to be in {0, 1} + """ + # Check that they have same shape + _check_same_shape(preds, target) + + if preds.shape[1] != num_labels: + raise ValueError( + "Expected both `target.shape[1]` and `preds.shape[1]` to be equal to the number of labels" + f" but got {preds.shape[1]} and expected {num_labels}" + ) + + # Check that target only contains [0,1] values or value in ignore_index + unique_values = torch.unique(target) + if ignore_index is None: + check = torch.any((unique_values != 0) & (unique_values != 1)) + else: + check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) + if check: + raise RuntimeError( + f"Detected the following values in `target`: {unique_values} but expected only" + f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." + ) + + # If preds is label tensor, also check that it only contains [0,1] values + if not preds.is_floating_point(): + unique_values = torch.unique(preds) + if torch.any((unique_values != 0) & (unique_values != 1)): + raise RuntimeError( + f"Detected the following values in `preds`: {unique_values} but expected only" + " the following values [0,1] since preds is a label tensor." + ) + + +def _multilabel_confusion_matrix_format( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + should_threshold: bool = True, +) -> Tuple[Tensor, Tensor]: + """Convert all input to label format. + + - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range + - If preds tensor is floating point, thresholds afterwards + - Mask all elements that should be ignored with negative numbers for later filtration + """ + if preds.is_floating_point(): + if not torch.all((0 <= preds) * (preds <= 1)): + preds = preds.sigmoid() + if should_threshold: + preds = preds > threshold + preds = _movedim(preds, 1, -1).reshape(-1, num_labels) + target = _movedim(target, 1, -1).reshape(-1, num_labels) + + if ignore_index is not None: + preds = preds.clone() + target = target.clone() + # Make sure that when we map, it will always result in a negative number that we can filter away + # Each label correspond to a 2x2 matrix = 4 elements per label + idx = target == ignore_index + preds[idx] = -4 * num_labels + target[idx] = -4 * num_labels + + return preds, target + + +def _multilabel_confusion_matrix_update(preds: Tensor, target: Tensor, num_labels: int) -> Tensor: + """Computes the bins to update the confusion matrix with.""" + unique_mapping = ((2 * target + preds) + 4 * torch.arange(num_labels, device=preds.device)).flatten() + unique_mapping = unique_mapping[unique_mapping >= 0] + bins = _bincount(unique_mapping, minlength=4 * num_labels) + return bins.reshape(num_labels, 2, 2) + + +def _multilabel_confusion_matrix_compute( + confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None +) -> Tensor: + """Reduces the confusion matrix to it's final form. + + Normalization technique can be chosen by ``normalize``. + """ + return _confusion_matrix_reduce(confmat, normalize) + + +def multilabel_confusion_matrix( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the `confusion matrix`_ for multilabel tasks. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + A ``[num_labels, 2, 2]`` tensor + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multilabel_confusion_matrix + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_confusion_matrix(preds, target, num_labels=3) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multilabel_confusion_matrix + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_confusion_matrix(preds, target, num_labels=3) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + """ + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize) + _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index) + preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index) + confmat = _multilabel_confusion_matrix_update(preds, target, num_labels) + return _multilabel_confusion_matrix_compute(confmat, normalize) def _confusion_matrix_update( @@ -117,11 +690,23 @@ def confusion_matrix( preds: Tensor, target: Tensor, num_classes: int, - normalize: Optional[str] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, threshold: float = 0.5, multilabel: bool = False, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, ) -> Tensor: r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the `confusion matrix`_. Works with binary, multiclass, and multilabel data. Accepts probabilities or logits from a model output or integer class values in prediction. Works with multi-dimensional preds and target, but it should be noted that @@ -182,5 +767,28 @@ def confusion_matrix( [[0, 1], [0, 1]]]) """ + if task is not None: + if task == "binary": + return binary_confusion_matrix(preds, target, threshold, normalize, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_confusion_matrix(preds, target, num_classes, normalize, ignore_index, validate_args) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_confusion_matrix( + preds, target, num_labels, threshold, normalize, ignore_index, validate_args + ) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) confmat = _confusion_matrix_update(preds, target, num_classes, threshold, multilabel) return _confusion_matrix_compute(confmat, normalize) diff --git a/src/torchmetrics/functional/classification/exact_match.py b/src/torchmetrics/functional/classification/exact_match.py new file mode 100644 index 00000000000..36a024f3b4b --- /dev/null +++ b/src/torchmetrics/functional/classification/exact_match.py @@ -0,0 +1,135 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.classification.stat_scores import ( + _multilabel_stat_scores_arg_validation, + _multilabel_stat_scores_format, + _multilabel_stat_scores_tensor_validation, +) +from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.data import _movedim + + +def _multilabel_exact_scores_update( + preds: Tensor, target: Tensor, num_labels: int, multidim_average: Literal["global", "samplewise"] = "global" +) -> Tuple[Tensor, Tensor]: + """Computes the statistics.""" + if multidim_average == "global": + preds = _movedim(preds, 1, -1).reshape(-1, num_labels) + target = _movedim(target, 1, -1).reshape(-1, num_labels) + + correct = ((preds == target).sum(1) == num_labels).sum(dim=-1) + total = torch.tensor(preds.shape[0 if multidim_average == "global" else 2], device=correct.device) + return correct, total + + +def _multilabel_exact_scores_compute( + correct: Tensor, + total: Tensor, +) -> Tensor: + """Final reduction for exact match.""" + return _safe_divide(correct, total) + + +def multilabel_exact_match( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes Exact match (also known as subset accuracy) for multilabel tasks. Exact Match is a stricter + version of accuracy where all labels have to match exactly for the sample to be correctly classified. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multilabel_exact_match + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_exact_match(preds, target, num_labels=3) + tensor(0.5000) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multilabel_exact_match + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_exact_match(preds, target, num_labels=3) + tensor(0.5000) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multilabel_exact_match + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> multilabel_exact_match(preds, target, num_labels=3, multidim_average='samplewise') + tensor([0., 0.]) + + """ + average = None + if validate_args: + _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) + preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + correct, total = _multilabel_exact_scores_update(preds, target, num_labels, multidim_average) + return _multilabel_exact_scores_compute(correct, total) diff --git a/src/torchmetrics/functional/classification/f_beta.py b/src/torchmetrics/functional/classification/f_beta.py index e523b4b533d..f5e7ae8a3be 100644 --- a/src/torchmetrics/functional/classification/f_beta.py +++ b/src/torchmetrics/functional/classification/f_beta.py @@ -15,16 +15,695 @@ import torch from torch import Tensor +from typing_extensions import Literal -from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update +from torchmetrics.functional.classification.stat_scores import ( + _binary_stat_scores_arg_validation, + _binary_stat_scores_format, + _binary_stat_scores_tensor_validation, + _binary_stat_scores_update, + _multiclass_stat_scores_arg_validation, + _multiclass_stat_scores_format, + _multiclass_stat_scores_tensor_validation, + _multiclass_stat_scores_update, + _multilabel_stat_scores_arg_validation, + _multilabel_stat_scores_format, + _multilabel_stat_scores_tensor_validation, + _multilabel_stat_scores_update, + _reduce_stat_scores, + _stat_scores_update, +) +from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.enums import AverageMethod as AvgMethod from torchmetrics.utilities.enums import MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn -def _safe_divide(num: Tensor, denom: Tensor) -> Tensor: - """prevent zero division.""" - denom[denom == 0.0] = 1 - return num / denom +def _fbeta_reduce( + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + beta: float, + average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]], + multidim_average: Literal["global", "samplewise"] = "global", +) -> Tensor: + beta2 = beta**2 + if average == "binary": + return _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp) + elif average == "micro": + tp = tp.sum(dim=0 if multidim_average == "global" else 1) + fn = fn.sum(dim=0 if multidim_average == "global" else 1) + fp = fp.sum(dim=0 if multidim_average == "global" else 1) + return _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp) + else: + fbeta_score = _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp) + if average is None or average == "none": + return fbeta_score + if average == "weighted": + weights = tp + fn + else: + weights = torch.ones_like(fbeta_score) + return _safe_divide(weights * fbeta_score, weights.sum(-1, keepdim=True)).sum(-1) + + +def _binary_fbeta_score_arg_validation( + beta: float, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, +) -> None: + if not (isinstance(beta, float) and beta > 0): + raise ValueError(f"Expected argument `beta` to be a float larger than 0, but got {beta}.") + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + + +def binary_fbeta_score( + preds: Tensor, + target: Tensor, + beta: float, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes `F-score`_ metric for binary tasks: + + .. math:: + F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} + {(\beta^2 * \text{precision}) + \text{recall}} + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import binary_fbeta_score + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> binary_fbeta_score(preds, target, beta=2.0) + tensor(0.6667) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import binary_fbeta_score + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_fbeta_score(preds, target, beta=2.0) + tensor(0.6667) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import binary_fbeta_score + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> binary_fbeta_score(preds, target, beta=2.0, multidim_average='samplewise') + tensor([0.5882, 0.0000]) + """ + if validate_args: + _binary_fbeta_score_arg_validation(beta, threshold, multidim_average, ignore_index) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) + return _fbeta_reduce(tp, fp, tn, fn, beta, average="binary", multidim_average=multidim_average) + + +def _multiclass_fbeta_score_arg_validation( + beta: float, + num_classes: int, + top_k: int = 1, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, +) -> None: + if not (isinstance(beta, float) and beta > 0): + raise ValueError(f"Expected argument `beta` to be a float larger than 0, but got {beta}.") + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + + +def multiclass_fbeta_score( + preds: Tensor, + target: Tensor, + beta: float, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + top_k: int = 1, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes `F-score`_ metric for multiclass tasks: + + .. math:: + F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} + {(\beta^2 * \text{precision}) + \text{recall}} + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multiclass_fbeta_score + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3) + tensor(0.7963) + >>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, average=None) + tensor([0.5556, 0.8333, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multiclass_fbeta_score + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3) + tensor(0.7963) + >>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, average=None) + tensor([0.5556, 0.8333, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multiclass_fbeta_score + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, multidim_average='samplewise') + tensor([0.4697, 0.2706]) + >>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, multidim_average='samplewise', average=None) + tensor([[0.9091, 0.0000, 0.5000], + [0.0000, 0.3571, 0.4545]]) + """ + if validate_args: + _multiclass_fbeta_score_arg_validation(beta, num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, top_k) + tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, num_classes, top_k, multidim_average, ignore_index) + return _fbeta_reduce(tp, fp, tn, fn, beta, average=average, multidim_average=multidim_average) + + +def _multilabel_fbeta_score_arg_validation( + beta: float, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, +) -> None: + if not (isinstance(beta, float) and beta > 0): + raise ValueError(f"Expected argument `beta` to be a float larger than 0, but got {beta}.") + _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + + +def multilabel_fbeta_score( + preds: Tensor, + target: Tensor, + beta: float, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes `F-score`_ metric for multilabel tasks: + + .. math:: + F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} + {(\beta^2 * \text{precision}) + \text{recall}} + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multilabel_fbeta_score + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3) + tensor(0.6111) + >>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3, average=None) + tensor([1.0000, 0.0000, 0.8333]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multilabel_fbeta_score + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3) + tensor(0.6111) + >>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3, average=None) + tensor([1.0000, 0.0000, 0.8333]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multilabel_fbeta_score + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> multilabel_fbeta_score(preds, target, num_labels=3, beta=2.0, multidim_average='samplewise') + tensor([0.5556, 0.0000]) + >>> multilabel_fbeta_score(preds, target, num_labels=3, beta=2.0, multidim_average='samplewise', average=None) + tensor([[0.8333, 0.8333, 0.0000], + [0.0000, 0.0000, 0.0000]]) + + """ + if validate_args: + _multilabel_fbeta_score_arg_validation(beta, num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) + preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) + return _fbeta_reduce(tp, fp, tn, fn, beta, average=average, multidim_average=multidim_average) + + +def binary_f1_score( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes F-1 score for binary tasks: + + .. math:: + F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}} + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import binary_f1_score + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> binary_f1_score(preds, target) + tensor(0.6667) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import binary_f1_score + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_f1_score(preds, target) + tensor(0.6667) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import binary_f1_score + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> binary_f1_score(preds, target, multidim_average='samplewise') + tensor([0.5000, 0.0000]) + """ + return binary_fbeta_score( + preds=preds, + target=target, + beta=1.0, + threshold=threshold, + multidim_average=multidim_average, + ignore_index=ignore_index, + validate_args=validate_args, + ) + + +def multiclass_f1_score( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + top_k: int = 1, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes F-1 score for multiclass tasks: + + .. math:: + F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}} + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multiclass_f1_score + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_f1_score(preds, target, num_classes=3) + tensor(0.7778) + >>> multiclass_f1_score(preds, target, num_classes=3, average=None) + tensor([0.6667, 0.6667, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multiclass_f1_score + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> multiclass_f1_score(preds, target, num_classes=3) + tensor(0.7778) + >>> multiclass_f1_score(preds, target, num_classes=3, average=None) + tensor([0.6667, 0.6667, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multiclass_f1_score + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_f1_score(preds, target, num_classes=3, multidim_average='samplewise') + tensor([0.4333, 0.2667]) + >>> multiclass_f1_score(preds, target, num_classes=3, multidim_average='samplewise', average=None) + tensor([[0.8000, 0.0000, 0.5000], + [0.0000, 0.4000, 0.4000]]) + """ + return multiclass_fbeta_score( + preds=preds, + target=target, + beta=1.0, + num_classes=num_classes, + average=average, + top_k=top_k, + multidim_average=multidim_average, + ignore_index=ignore_index, + validate_args=validate_args, + ) + + +def multilabel_f1_score( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes F-1 score for multilabel tasks: + + .. math:: + F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}} + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multilabel_f1_score + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_f1_score(preds, target, num_labels=3) + tensor(0.5556) + >>> multilabel_f1_score(preds, target, num_labels=3, average=None) + tensor([1.0000, 0.0000, 0.6667]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multilabel_f1_score + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_f1_score(preds, target, num_labels=3) + tensor(0.5556) + >>> multilabel_f1_score(preds, target, num_labels=3, average=None) + tensor([1.0000, 0.0000, 0.6667]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multilabel_f1_score + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> multilabel_f1_score(preds, target, num_labels=3, multidim_average='samplewise') + tensor([0.4444, 0.0000]) + >>> multilabel_f1_score(preds, target, num_labels=3, multidim_average='samplewise', average=None) + tensor([[0.6667, 0.6667, 0.0000], + [0.0000, 0.0000, 0.0000]]) + + """ + return multilabel_fbeta_score( + preds=preds, + target=target, + beta=1.0, + num_labels=num_labels, + threshold=threshold, + average=average, + multidim_average=multidim_average, + ignore_index=ignore_index, + validate_args=validate_args, + ) def _fbeta_compute( @@ -112,15 +791,27 @@ def fbeta_score( preds: Tensor, target: Tensor, beta: float = 1.0, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, ) -> Tensor: r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes f_beta metric. .. math:: @@ -215,6 +906,33 @@ def fbeta_score( tensor(0.3333) """ + if task is not None: + assert multidim_average is not None + if task == "binary": + return binary_fbeta_score(preds, target, beta, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_fbeta_score( + preds, target, beta, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_fbeta_score( + preds, target, beta, num_labels, threshold, average, multidim_average, ignore_index, validate_args + ) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) allowed_average = list(AvgMethod) if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -248,15 +966,27 @@ def f1_score( preds: Tensor, target: Tensor, beta: float = 1.0, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, ) -> Tensor: - """Computes F1 metric. F1 metrics correspond to a equally weighted average of the precision and recall scores. + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes F1 metric. F1 metrics correspond to a equally weighted average of the precision and recall scores. Works with binary, multiclass, and multilabel data. Accepts probabilities or logits from a model output or integer class values in prediction. @@ -349,6 +1079,33 @@ def f1_score( >>> f1_score(preds, target, num_classes=3) tensor(0.3333) """ + if task is not None: + assert multidim_average is not None + if task == "binary": + return binary_f1_score(preds, target, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_f1_score( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_f1_score( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + ) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) return fbeta_score( preds, target, 1.0, average, mdmc_average, ignore_index, num_classes, threshold, top_k, multiclass ) diff --git a/src/torchmetrics/functional/classification/hamming.py b/src/torchmetrics/functional/classification/hamming.py index 44fedd577ae..b05f48aa871 100644 --- a/src/torchmetrics/functional/classification/hamming.py +++ b/src/torchmetrics/functional/classification/hamming.py @@ -11,12 +11,380 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch from torch import Tensor +from typing_extensions import Literal +from torchmetrics.functional.classification.stat_scores import ( + _binary_stat_scores_arg_validation, + _binary_stat_scores_format, + _binary_stat_scores_tensor_validation, + _binary_stat_scores_update, + _multiclass_stat_scores_arg_validation, + _multiclass_stat_scores_format, + _multiclass_stat_scores_tensor_validation, + _multiclass_stat_scores_update, + _multilabel_stat_scores_arg_validation, + _multilabel_stat_scores_format, + _multilabel_stat_scores_tensor_validation, + _multilabel_stat_scores_update, +) from torchmetrics.utilities.checks import _input_format_classification +from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.prints import rank_zero_warn + + +def _hamming_distance_reduce( + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]], + multidim_average: Literal["global", "samplewise"] = "global", + multilabel: bool = False, +) -> Tensor: + """Reduce classification statistics into hamming distance + Args: + tp: number of true positives + fp: number of false positives + tn: number of true negatives + fn: number of false negatives + normalize: normalization method. + + - `"true"` will divide by the sum of the column dimension. + - `"pred"` will divide by the sum of the row dimension. + - `"all"` will divide by the sum of the full matrix + - `"none"` or `None` will apply no reduction + + multilabel: bool indicating if reduction is for multilabel tasks + + Returns: + Accuracy score + """ + if average == "binary": + return 1 - _safe_divide(tp + tn, tp + fp + tn + fn) + elif average == "micro": + tp = tp.sum(dim=0 if multidim_average == "global" else 1) + fn = fn.sum(dim=0 if multidim_average == "global" else 1) + if multilabel: + fp = fp.sum(dim=0 if multidim_average == "global" else 1) + tn = tn.sum(dim=0 if multidim_average == "global" else 1) + return 1 - _safe_divide(tp + tn, tp + tn + fp + fn) + return 1 - _safe_divide(tp, tp + fn) + else: + if multilabel: + score = 1 - _safe_divide(tp + tn, tp + tn + fp + fn) + else: + score = 1 - _safe_divide(tp, tp + fn) + if average is None or average == "none": + return score + if average == "weighted": + weights = tp + fn + else: + weights = torch.ones_like(score) + return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1) + + +def binary_hamming_distance( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes the average `Hamming distance`_ (also known as Hamming loss) for binary tasks: + + .. math:: + \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that + tensor. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import binary_hamming_distance + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> binary_hamming_distance(preds, target) + tensor(0.3333) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import binary_hamming_distance + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_hamming_distance(preds, target) + tensor(0.3333) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import binary_hamming_distance + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> binary_hamming_distance(preds, target, multidim_average='samplewise') + tensor([0.6667, 0.8333]) + """ + if validate_args: + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) + return _hamming_distance_reduce(tp, fp, tn, fn, average="binary", multidim_average=multidim_average) + + +def multiclass_hamming_distance( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + top_k: int = 1, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes the average `Hamming distance`_ (also known as Hamming loss) for multiclass tasks: + + .. math:: + \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that + tensor. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multiclass_hamming_distance + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_hamming_distance(preds, target, num_classes=3) + tensor(0.1667) + >>> multiclass_hamming_distance(preds, target, num_classes=3, average=None) + tensor([0.5000, 0.0000, 0.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multiclass_hamming_distance + >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> multiclass_hamming_distance(preds, target, num_classes=3) + tensor(0.1667) + >>> multiclass_hamming_distance(preds, target, num_classes=3, average=None) + tensor([0.5000, 0.0000, 0.0000]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multiclass_hamming_distance + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_hamming_distance(preds, target, num_classes=3, multidim_average='samplewise') + tensor([0.5000, 0.7222]) + >>> multiclass_hamming_distance(preds, target, num_classes=3, multidim_average='samplewise', average=None) + tensor([[0.0000, 1.0000, 0.5000], + [1.0000, 0.6667, 0.5000]]) + """ + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, top_k) + tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, num_classes, top_k, multidim_average, ignore_index) + return _hamming_distance_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average) + + +def multilabel_hamming_distance( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes the average `Hamming distance`_ (also known as Hamming loss) for multilabel tasks: + + .. math:: + \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that + tensor. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multilabel_hamming_distance + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_hamming_distance(preds, target, num_labels=3) + tensor(0.3333) + >>> multilabel_hamming_distance(preds, target, num_labels=3, average=None) + tensor([0.0000, 0.5000, 0.5000]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multilabel_hamming_distance + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_hamming_distance(preds, target, num_labels=3) + tensor(0.3333) + >>> multilabel_hamming_distance(preds, target, num_labels=3, average=None) + tensor([0.0000, 0.5000, 0.5000]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multilabel_hamming_distance + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> multilabel_hamming_distance(preds, target, num_labels=3, multidim_average='samplewise') + tensor([0.6667, 0.8333]) + >>> multilabel_hamming_distance(preds, target, num_labels=3, multidim_average='samplewise', average=None) + tensor([[0.5000, 0.5000, 1.0000], + [1.0000, 1.0000, 0.5000]]) + + """ + if validate_args: + _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) + preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) + return _hamming_distance_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True) def _hamming_distance_update( @@ -59,8 +427,27 @@ def _hamming_distance_compute(correct: Tensor, total: Union[int, Tensor]) -> Ten return 1 - correct.float() / total -def hamming_distance(preds: Tensor, target: Tensor, threshold: float = 0.5) -> Tensor: +def hamming_distance( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + top_k: int = 1, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + Computes the average `Hamming distance`_ (also known as Hamming loss) between targets and predictions: @@ -91,6 +478,33 @@ def hamming_distance(preds: Tensor, target: Tensor, threshold: float = 0.5) -> T >>> hamming_distance(preds, target) tensor(0.2500) """ + if task is not None: + assert multidim_average is not None + if task == "binary": + return binary_hamming_distance(preds, target, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_hamming_distance( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_hamming_distance( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + ) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) correct, total = _hamming_distance_update(preds, target, threshold) return _hamming_distance_compute(correct, total) diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index 4d57dcc58f1..72440df27b2 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -15,10 +15,232 @@ import torch from torch import Tensor, tensor - +from typing_extensions import Literal + +from torchmetrics.functional.classification.confusion_matrix import ( + _binary_confusion_matrix_format, + _binary_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_format, + _multiclass_confusion_matrix_tensor_validation, +) from torchmetrics.utilities.checks import _input_squeeze from torchmetrics.utilities.data import to_onehot from torchmetrics.utilities.enums import DataType, EnumStr +from torchmetrics.utilities.prints import rank_zero_warn + + +def _hinge_loss_compute(measure: Tensor, total: Tensor) -> Tensor: + return measure / total + + +def _binary_hinge_loss_arg_validation(squared: bool, ignore_index: Optional[int] = None) -> None: + if not isinstance(squared, bool): + raise ValueError(f"Expected argument `squared` to be an bool but got {squared}") + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + + +def _binary_hinge_loss_tensor_validation(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> None: + _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) + if not preds.is_floating_point(): + raise ValueError( + "Expected argument `preds` to be floating tensor with probabilities/logits" + f" but got tensor with dtype {preds.dtype}" + ) + + +def _binary_hinge_loss_update( + preds: Tensor, + target: Tensor, + squared: bool, +) -> Tuple[Tensor, Tensor]: + + target = target.bool() + margin = torch.zeros_like(preds) + margin[target] = preds[target] + margin[~target] = -preds[~target] + + measures = 1 - margin + measures = torch.clamp(measures, 0) + + if squared: + measures = measures.pow(2) + + total = tensor(target.shape[0], device=target.device) + return measures.sum(dim=0), total + + +def binary_hinge_loss( + preds: Tensor, + target: Tensor, + squared: bool = False, + ignore_index: Optional[int] = None, + validate_args: bool = False, +) -> Tensor: + r"""Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for binary tasks. It is + defined as: + + .. math:: + \text{Hinge loss} = \max(0, 1 - y \times \hat{y}) + + Where :math:`y \in {-1, 1}` is the target, and :math:`\hat{y} \in \mathbb{R}` is the prediction. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + squared: + If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example: + >>> from torchmetrics.functional.classification import binary_hinge_loss + >>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75]) + >>> target = torch.tensor([0, 0, 1, 1, 1]) + >>> binary_hinge_loss(preds, target) + tensor(0.6900) + >>> binary_hinge_loss(preds, target, squared=True) + tensor(0.6905) + """ + if validate_args: + _binary_hinge_loss_arg_validation(squared, ignore_index) + _binary_hinge_loss_tensor_validation(preds, target, ignore_index) + preds, target = _binary_confusion_matrix_format( + preds, target, threshold=0.0, ignore_index=ignore_index, convert_to_labels=False + ) + measures, total = _binary_hinge_loss_update(preds, target, squared) + return _hinge_loss_compute(measures, total) + + +def _multiclass_hinge_loss_arg_validation( + num_classes: int, + squared: bool = False, + multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer", + ignore_index: Optional[int] = None, +) -> None: + _binary_hinge_loss_arg_validation(squared, ignore_index) + if not isinstance(num_classes, int) or num_classes < 2: + raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") + allowed_mm = ("crammer-singer", "one-vs-all") + if multiclass_mode not in allowed_mm: + raise ValueError(f"Expected argument `multiclass_mode` to be one of {allowed_mm}, but got {multiclass_mode}.") + + +def _multiclass_hinge_loss_tensor_validation( + preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None +) -> None: + _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) + if not preds.is_floating_point(): + raise ValueError( + "Expected argument `preds` to be floating tensor with probabilities/logits" + f" but got tensor with dtype {preds.dtype}" + ) + + +def _multiclass_hinge_loss_update( + preds: Tensor, + target: Tensor, + squared: bool, + multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer", +) -> Tuple[Tensor, Tensor]: + if not torch.all((0 <= preds) * (preds <= 1)): + preds = preds.softmax(1) + + target = to_onehot(target, max(2, preds.shape[1])).bool() + if multiclass_mode == "crammer-singer": + margin = preds[target] + margin -= torch.max(preds[~target].view(preds.shape[0], -1), dim=1)[0] + else: + target = target.bool() + margin = torch.zeros_like(preds) + margin[target] = preds[target] + margin[~target] = -preds[~target] + + measures = 1 - margin + measures = torch.clamp(measures, 0) + + if squared: + measures = measures.pow(2) + + total = tensor(target.shape[0], device=target.device) + return measures.sum(dim=0), total + + +def multiclass_hinge_loss( + preds: Tensor, + target: Tensor, + num_classes: int, + squared: bool = False, + multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer", + ignore_index: Optional[int] = None, + validate_args: bool = False, +) -> Tensor: + r"""Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for multiclass tasks + + The metric can be computed in two ways. Either, the definition by Crammer and Singer is used: + + .. math:: + \text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right) + + Where :math:`y \in {0, ..., \mathrm{C}}` is the target class (where :math:`\mathrm{C}` is the number of classes), + and :math:`\hat{y} \in \mathbb{R}^\mathrm{C}` is the predicted output per class. Alternatively, the metric can + also be computed in one-vs-all approach, where each class is valued against all other classes in a binary fashion. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + squared: + If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss. + multiclass_mode: + Determines how to compute the metric + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example: + >>> from torchmetrics.functional.classification import multiclass_hinge_loss + >>> preds = torch.tensor([[0.25, 0.20, 0.55], + ... [0.55, 0.05, 0.40], + ... [0.10, 0.30, 0.60], + ... [0.90, 0.05, 0.05]]) + >>> target = torch.tensor([0, 1, 2, 0]) + >>> multiclass_hinge_loss(preds, target, num_classes=3) + tensor(0.9125) + >>> multiclass_hinge_loss(preds, target, num_classes=3, squared=True) + tensor(1.1131) + >>> multiclass_hinge_loss(preds, target, num_classes=3, multiclass_mode='one-vs-all') + tensor([0.8750, 1.1250, 1.1000]) + """ + if validate_args: + _multiclass_hinge_loss_arg_validation(num_classes, squared, multiclass_mode, ignore_index) + _multiclass_hinge_loss_tensor_validation(preds, target, num_classes, ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index, convert_to_labels=False) + measures, total = _multiclass_hinge_loss_update(preds, target, squared, multiclass_mode) + return _hinge_loss_compute(measures, total) class MulticlassMode(EnumStr): @@ -158,9 +380,20 @@ def hinge_loss( preds: Tensor, target: Tensor, squared: bool = False, - multiclass_mode: Optional[Union[str, MulticlassMode]] = None, + multiclass_mode: Optional[Literal["crammer-singer", "one-vs-all"]] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, ) -> Tensor: r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs). In the binary case it is defined as: @@ -227,5 +460,26 @@ def hinge_loss( >>> hinge_loss(preds, target, multiclass_mode="one-vs-all") tensor([2.2333, 1.5000, 1.2333]) """ + if task is not None: + if task == "binary": + return binary_hinge_loss(preds, target, squared, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert multiclass_mode is not None + return multiclass_hinge_loss( + preds, target, num_classes, squared, multiclass_mode, ignore_index, validate_args + ) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) measure, total = _hinge_update(preds, target, squared=squared, multiclass_mode=multiclass_mode) return _hinge_compute(measure, total) diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index e834d0691c5..6369a020951 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -15,8 +15,288 @@ import torch from torch import Tensor +from typing_extensions import Literal -from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update +from torchmetrics.functional.classification.confusion_matrix import ( + _binary_confusion_matrix_arg_validation, + _binary_confusion_matrix_format, + _binary_confusion_matrix_tensor_validation, + _binary_confusion_matrix_update, + _confusion_matrix_update, + _multiclass_confusion_matrix_arg_validation, + _multiclass_confusion_matrix_format, + _multiclass_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_update, + _multilabel_confusion_matrix_arg_validation, + _multilabel_confusion_matrix_format, + _multilabel_confusion_matrix_tensor_validation, + _multilabel_confusion_matrix_update, +) +from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.prints import rank_zero_warn + + +def _jaccard_index_reduce( + confmat: Tensor, + average: Optional[Literal["micro", "macro", "weighted", "none", "binary"]], +) -> Tensor: + """Perform reduction of an un-normalized confusion matrix into jaccard score. + + Args: + confmat: tensor with un-normalized confusionmatrix + average: reduction method + + - ``'binary'``: binary reduction, expects a 2x2 matrix + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics across classes (with equal weights for each class). + - ``'micro'``: Calculate the metric globally, across all samples and classes. + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics across classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + """ + allowed_average = ["binary", "micro", "macro", "weighted", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + confmat = confmat.float() + if average == "binary": + return confmat[1, 1] / (confmat[0, 1] + confmat[1, 0] + confmat[1, 1]) + else: + if confmat.ndim == 3: # multilabel + num = confmat[:, 1, 1] + denom = confmat[:, 1, 1] + confmat[:, 0, 1] + confmat[:, 1, 0] + else: # multiclass + num = torch.diag(confmat) + denom = confmat.sum(0) + confmat.sum(1) - num + + if average == "micro": + num = num.sum() + denom = denom.sum() + + jaccard = _safe_divide(num, denom) + + if average is None or average == "none": + return jaccard + if average == "weighted": + weights = confmat[:, 1, 1] + confmat[:, 1, 0] if confmat.ndim == 3 else confmat.sum(1) + else: + weights = torch.ones_like(jaccard) + return ((weights * jaccard) / weights.sum()).sum() + + +def binary_jaccard_index( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Calculates the Jaccard index for binary tasks. The `Jaccard index`_ (also known as + the intersetion over union or jaccard similarity coefficient) is an statistic that can be + used to determine the similarity and diversity of a sample set. It is defined as the size + of the intersection divided by the union of the sample sets: + + .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import binary_jaccard_index + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> binary_jaccard_index(preds, target) + tensor(0.5000) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import binary_jaccard_index + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) + >>> binary_jaccard_index(preds, target) + tensor(0.5000) + """ + if validate_args: + _binary_confusion_matrix_arg_validation(threshold, ignore_index) + _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) + preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index) + confmat = _binary_confusion_matrix_update(preds, target) + return _jaccard_index_reduce(confmat, average="binary") + + +def _multiclass_jaccard_index_arg_validation( + num_classes: int, + ignore_index: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = None, +) -> None: + _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index) + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average}, but got {average}.") + + +def multiclass_jaccard_index( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Calculates the Jaccard index for multiclass tasks. The `Jaccard index`_ (also known as + the intersetion over union or jaccard similarity coefficient) is an statistic that can be + used to determine the similarity and diversity of a sample set. It is defined as the size + of the intersection divided by the union of the sample sets: + + .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (pred is integer tensor): + >>> from torchmetrics.functional.classification import multiclass_jaccard_index + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_jaccard_index(preds, target, num_classes=3) + tensor(0.6667) + + Example (pred is float tensor): + >>> from torchmetrics.functional.classification import multiclass_jaccard_index + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> multiclass_jaccard_index(preds, target, num_classes=3) + tensor(0.6667) + """ + if validate_args: + _multiclass_jaccard_index_arg_validation(num_classes, ignore_index, average) + _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index) + confmat = _multiclass_confusion_matrix_update(preds, target, num_classes) + return _jaccard_index_reduce(confmat, average=average) + + +def _multilabel_jaccard_index_arg_validation( + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", +) -> None: + _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index) + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average}, but got {average}.") + + +def multilabel_jaccard_index( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Calculates the Jaccard index for multilabel tasks. The `Jaccard index`_ (also known as + the intersetion over union or jaccard similarity coefficient) is an statistic that can be + used to determine the similarity and diversity of a sample set. It is defined as the size + of the intersection divided by the union of the sample sets: + + .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multilabel_jaccard_index + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_jaccard_index(preds, target, num_labels=3) + tensor(0.5000) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multilabel_jaccard_index + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_jaccard_index(preds, target, num_labels=3) + tensor(0.5000) + + """ + if validate_args: + _multilabel_jaccard_index_arg_validation(num_labels, threshold, ignore_index) + _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index) + preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index) + confmat = _multilabel_confusion_matrix_update(preds, target, num_labels) + return _jaccard_index_reduce(confmat, average=average) def _jaccard_from_confmat( @@ -95,12 +375,23 @@ def jaccard_index( preds: Tensor, target: Tensor, num_classes: int, - average: Optional[str] = "macro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, absent_score: float = 0.0, threshold: float = 0.5, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + validate_args: bool = True, ) -> Tensor: - r"""Computes `Jaccard index`_ + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Jaccard index`_ .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} @@ -159,6 +450,26 @@ def jaccard_index( >>> jaccard_index(pred, target, num_classes=2) tensor(0.9660) """ - + if task is not None: + if task == "binary": + return binary_jaccard_index(preds, target, threshold, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_jaccard_index(preds, target, num_classes, average, ignore_index, validate_args) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_jaccard_index(preds, target, num_labels, threshold, average, ignore_index, validate_args) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) confmat = _confusion_matrix_update(preds, target, num_classes, threshold) return _jaccard_from_confmat(confmat, num_classes, average, ignore_index, absent_score) diff --git a/src/torchmetrics/functional/classification/kl_divergence.py b/src/torchmetrics/functional/classification/kl_divergence.py index 4a4da4b8b07..69853be74f4 100644 --- a/src/torchmetrics/functional/classification/kl_divergence.py +++ b/src/torchmetrics/functional/classification/kl_divergence.py @@ -11,70 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Tuple - -import torch from torch import Tensor from typing_extensions import Literal -from torchmetrics.utilities.checks import _check_same_shape -from torchmetrics.utilities.compute import _safe_xlogy - - -def _kld_update(p: Tensor, q: Tensor, log_prob: bool) -> Tuple[Tensor, int]: - """Updates and returns KL divergence scores for each observation and the total number of observations. Checks - same shape and 2D nature of the input tensors else raises ValueError. - - Args: - p: data distribution with shape ``[N, d]`` - q: prior or approximate distribution with shape ``[N, d]`` - log_prob: bool indicating if input is log-probabilities or probabilities. If given as probabilities, - will normalize to make sure the distributes sum to 1 - """ - _check_same_shape(p, q) - if p.ndim != 2 or q.ndim != 2: - raise ValueError(f"Expected both p and q distribution to be 2D but got {p.ndim} and {q.ndim} respectively") - - total = p.shape[0] - if log_prob: - measures = torch.sum(p.exp() * (p - q), axis=-1) - else: - p = p / p.sum(axis=-1, keepdim=True) - q = q / q.sum(axis=-1, keepdim=True) - measures = _safe_xlogy(p, p / q).sum(axis=-1) - - return measures, total - - -def _kld_compute(measures: Tensor, total: Tensor, reduction: Literal["mean", "sum", "none", None] = "mean") -> Tensor: - """Computes the KL divergenece based on the type of reduction. - - Args: - measures: Tensor of KL divergence scores for each observation - total: Number of observations - reduction: - Determines how to reduce over the ``N``/batch dimension: - - - ``'mean'`` [default]: Averages score across samples - - ``'sum'``: Sum score across samples - - ``'none'`` or ``None``: Returns score per sample - - Example: - >>> p = torch.tensor([[0.36, 0.48, 0.16]]) - >>> q = torch.tensor([[1/3, 1/3, 1/3]]) - >>> measures, total = _kld_update(p, q, log_prob=False) - >>> _kld_compute(measures, total) - tensor(0.0853) - """ - - if reduction == "sum": - return measures.sum() - if reduction == "mean": - return measures.sum() / total - if reduction is None or reduction == "none": - return measures - return measures / total +from torchmetrics.functional.regression.kl_divergence import kl_divergence as _kl_divergence +from torchmetrics.utilities.prints import rank_zero_warn def kl_divergence( @@ -89,6 +30,9 @@ def kl_divergence( over data and :math:`Q` is often a prior or approximation of :math:`P`. It should be noted that the KL divergence is a non-symetrical metric i.e. :math:`D_{KL}(P||Q) \neq D_{KL}(Q||P)`. + .. note:: + This metric have been moved to the regression package in v0.10 and this version will be removed in v0.11. + Args: p: data distribution with shape ``[N, d]`` q: prior or approximate distribution with shape ``[N, d]`` @@ -108,5 +52,10 @@ def kl_divergence( >>> kl_divergence(p, q) tensor(0.0853) """ - measures, total = _kld_update(p, q, log_prob) - return _kld_compute(measures, total, reduction) + rank_zero_warn( + "`torchmetrics.functional.classification.kl_divergence` have been moved to" + "`torchmetrics.functional.regression.kl_divergence` from v0.10 and this version will be removed in v0.11." + "Please update import paths.", + DeprecationWarning, + ) + return _kl_divergence(p, q, log_prob, reduction) diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index d7956ffab6d..f2ac89a2aab 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -11,10 +11,229 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + import torch from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.classification.confusion_matrix import ( + _binary_confusion_matrix_arg_validation, + _binary_confusion_matrix_format, + _binary_confusion_matrix_tensor_validation, + _binary_confusion_matrix_update, + _confusion_matrix_update, + _multiclass_confusion_matrix_arg_validation, + _multiclass_confusion_matrix_format, + _multiclass_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_update, + _multilabel_confusion_matrix_arg_validation, + _multilabel_confusion_matrix_format, + _multilabel_confusion_matrix_tensor_validation, + _multilabel_confusion_matrix_update, +) +from torchmetrics.utilities.prints import rank_zero_warn + + +def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: + """Reduce an un-normalized confusion matrix of shape (n_classes, n_classes) into the matthews corrcoef + score.""" + # convert multilabel into binary + confmat = confmat.sum(0) if confmat.ndim == 3 else confmat + + tk = confmat.sum(dim=-1).float() + pk = confmat.sum(dim=-2).float() + c = torch.trace(confmat).float() + s = confmat.sum().float() + + cov_ytyp = c * s - sum(tk * pk) + cov_ypyp = s**2 - sum(pk * pk) + cov_ytyt = s**2 - sum(tk * tk) + + denom = cov_ypyp * cov_ytyt + if denom == 0: + return torch.tensor(0, dtype=confmat.dtype, device=confmat.device) + else: + return cov_ytyp / torch.sqrt(denom) + + +def binary_matthews_corrcoef( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Calculates `Matthews correlation coefficient`_ for binary tasks. This metric measures + the general correlation or quality of a classification. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import binary_matthews_corrcoef + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> binary_matthews_corrcoef(preds, target) + tensor(0.5774) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import binary_matthews_corrcoef + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) + >>> binary_matthews_corrcoef(preds, target) + tensor(0.5774) + + """ + if validate_args: + _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize=None) + _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) + preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index) + confmat = _binary_confusion_matrix_update(preds, target) + return _matthews_corrcoef_reduce(confmat) + + +def multiclass_matthews_corrcoef( + preds: Tensor, + target: Tensor, + num_classes: int, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Calculates `Matthews correlation coefficient`_ for multiclass tasks. This metric measures + the general correlation or quality of a classification. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of classes + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (pred is integer tensor): + >>> from torchmetrics.functional.classification import multiclass_matthews_corrcoef + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_matthews_corrcoef(preds, target, num_classes=3) + tensor(0.7000) + + Example (pred is float tensor): + >>> from torchmetrics.functional.classification import multiclass_matthews_corrcoef + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> multiclass_matthews_corrcoef(preds, target, num_classes=3) + tensor(0.7000) + + """ + if validate_args: + _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize=None) + _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index) + confmat = _multiclass_confusion_matrix_update(preds, target, num_classes) + return _matthews_corrcoef_reduce(confmat) + + +def multilabel_matthews_corrcoef( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Calculates `Matthews correlation coefficient`_ for multilabel tasks. This metric measures + the general correlation or quality of a classification. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multilabel_matthews_corrcoef + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_matthews_corrcoef(preds, target, num_labels=3) + tensor(0.3333) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multilabel_matthews_corrcoef + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_matthews_corrcoef(preds, target, num_labels=3) + tensor(0.3333) + + """ + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize=None) + _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index) + preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index) + confmat = _multilabel_confusion_matrix_update(preds, target, num_labels) + return _matthews_corrcoef_reduce(confmat) -from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update _matthews_corrcoef_update = _confusion_matrix_update @@ -53,8 +272,19 @@ def matthews_corrcoef( target: Tensor, num_classes: int, threshold: float = 0.5, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, ) -> Tensor: r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + Calculates `Matthews correlation coefficient`_ that measures the general correlation or quality of a classification. In the binary case it is defined as: @@ -82,5 +312,26 @@ def matthews_corrcoef( tensor(0.5774) """ + if task is not None: + if task == "binary": + return binary_matthews_corrcoef(preds, target, threshold, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_matthews_corrcoef(preds, target, num_classes, ignore_index, validate_args) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_matthews_corrcoef(preds, target, num_labels, threshold, ignore_index, validate_args) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) confmat = _matthews_corrcoef_update(preds, target, num_classes, threshold) return _matthews_corrcoef_compute(confmat) diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index 573814c7f44..a4696bb7a06 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -15,9 +15,643 @@ import torch from torch import Tensor - -from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update +from typing_extensions import Literal + +from torchmetrics.functional.classification.stat_scores import ( + _binary_stat_scores_arg_validation, + _binary_stat_scores_format, + _binary_stat_scores_tensor_validation, + _binary_stat_scores_update, + _multiclass_stat_scores_arg_validation, + _multiclass_stat_scores_format, + _multiclass_stat_scores_tensor_validation, + _multiclass_stat_scores_update, + _multilabel_stat_scores_arg_validation, + _multilabel_stat_scores_format, + _multilabel_stat_scores_tensor_validation, + _multilabel_stat_scores_update, + _reduce_stat_scores, + _stat_scores_update, +) +from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn + + +def _precision_recall_reduce( + stat: Literal["precision", "recall"], + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]], + multidim_average: Literal["global", "samplewise"] = "global", +) -> Tensor: + different_stat = fp if stat == "precision" else fn # this is what differs between the two scores + if average == "binary": + return _safe_divide(tp, tp + different_stat) + elif average == "micro": + tp = tp.sum(dim=0 if multidim_average == "global" else 1) + fn = fn.sum(dim=0 if multidim_average == "global" else 1) + different_stat = different_stat.sum(dim=0 if multidim_average == "global" else 1) + return _safe_divide(tp, tp + different_stat) + else: + score = _safe_divide(tp, tp + different_stat) + if average is None or average == "none": + return score + if average == "weighted": + weights = tp + fn + else: + weights = torch.ones_like(score) + return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1) + + +def binary_precision( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes `Precision`_ for binary tasks: + + .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} + + Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import binary_precision + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> binary_precision(preds, target) + tensor(0.6667) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import binary_precision + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_precision(preds, target) + tensor(0.6667) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import binary_precision + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> binary_precision(preds, target, multidim_average='samplewise') + tensor([0.4000, 0.0000]) + """ + if validate_args: + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) + return _precision_recall_reduce("precision", tp, fp, tn, fn, average="binary", multidim_average=multidim_average) + + +def multiclass_precision( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + top_k: int = 1, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes `Precision`_ for multiclass tasks + + .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} + + Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multiclass_precision + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_precision(preds, target, num_classes=3) + tensor(0.8333) + >>> multiclass_precision(preds, target, num_classes=3, average=None) + tensor([1.0000, 0.5000, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multiclass_precision + >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> multiclass_precision(preds, target, num_classes=3) + tensor(0.8333) + >>> multiclass_precision(preds, target, num_classes=3, average=None) + tensor([1.0000, 0.5000, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multiclass_precision + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_precision(preds, target, num_classes=3, multidim_average='samplewise') + tensor([0.3889, 0.2778]) + >>> multiclass_precision(preds, target, num_classes=3, multidim_average='samplewise', average=None) + tensor([[0.6667, 0.0000, 0.5000], + [0.0000, 0.5000, 0.3333]]) + """ + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, top_k) + tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, num_classes, top_k, multidim_average, ignore_index) + return _precision_recall_reduce("precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average) + + +def multilabel_precision( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes `Precision`_ for multilabel tasks + + .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} + + Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multilabel_precision + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_precision(preds, target, num_labels=3) + tensor(0.5000) + >>> multilabel_precision(preds, target, num_labels=3, average=None) + tensor([1.0000, 0.0000, 0.5000]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multilabel_precision + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_precision(preds, target, num_labels=3) + tensor(0.5000) + >>> multilabel_precision(preds, target, num_labels=3, average=None) + tensor([1.0000, 0.0000, 0.5000]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multilabel_precision + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> multilabel_precision(preds, target, num_labels=3, multidim_average='samplewise') + tensor([0.3333, 0.0000]) + >>> multilabel_precision(preds, target, num_labels=3, multidim_average='samplewise', average=None) + tensor([[0.5000, 0.5000, 0.0000], + [0.0000, 0.0000, 0.0000]]) + + """ + if validate_args: + _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) + preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) + return _precision_recall_reduce("precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average) + + +def binary_recall( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes `Recall`_ for binary tasks: + + .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} + + Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and + false negatives respecitively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import binary_recall + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> binary_recall(preds, target) + tensor(0.6667) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import binary_recall + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_recall(preds, target) + tensor(0.6667) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import binary_recall + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> binary_recall(preds, target, multidim_average='samplewise') + tensor([0.6667, 0.0000]) + """ + if validate_args: + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) + return _precision_recall_reduce("recall", tp, fp, tn, fn, average="binary", multidim_average=multidim_average) + + +def multiclass_recall( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + top_k: int = 1, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes `Recall`_ for multiclass tasks: + + .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} + + Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and + false negatives respecitively. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multiclass_recall + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_recall(preds, target, num_classes=3) + tensor(0.8333) + >>> multiclass_recall(preds, target, num_classes=3, average=None) + tensor([0.5000, 1.0000, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multiclass_recall + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> multiclass_recall(preds, target, num_classes=3) + tensor(0.8333) + >>> multiclass_recall(preds, target, num_classes=3, average=None) + tensor([0.5000, 1.0000, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multiclass_recall + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_recall(preds, target, num_classes=3, multidim_average='samplewise') + tensor([0.5000, 0.2778]) + >>> multiclass_recall(preds, target, num_classes=3, multidim_average='samplewise', average=None) + tensor([[1.0000, 0.0000, 0.5000], + [0.0000, 0.3333, 0.5000]]) + """ + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, top_k) + tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, num_classes, top_k, multidim_average, ignore_index) + return _precision_recall_reduce("recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average) + + +def multilabel_recall( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes `Recall`_ for multilabel tasks: + + .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} + + Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and + false negatives respecitively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multilabel_recall + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_recall(preds, target, num_labels=3) + tensor(0.6667) + >>> multilabel_recall(preds, target, num_labels=3, average=None) + tensor([1., 0., 1.]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multilabel_recall + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_recall(preds, target, num_labels=3) + tensor(0.6667) + >>> multilabel_recall(preds, target, num_labels=3, average=None) + tensor([1., 0., 1.]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multilabel_recall + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> multilabel_recall(preds, target, num_labels=3, multidim_average='samplewise') + tensor([0.6667, 0.0000]) + >>> multilabel_recall(preds, target, num_labels=3, multidim_average='samplewise', average=None) + tensor([[1., 1., 0.], + [0., 0., 0.]]) + + """ + if validate_args: + _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) + preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) + return _precision_recall_reduce("recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average) def _precision_compute( @@ -75,15 +709,27 @@ def _precision_compute( def precision( preds: Tensor, target: Tensor, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, ) -> Tensor: - r"""Computes `Precision`_ + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Precision`_ .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} @@ -183,6 +829,33 @@ def precision( tensor(0.2500) """ + if task is not None: + assert multidim_average is not None + if task == "binary": + return binary_precision(preds, target, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_precision( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_precision( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + ) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -267,15 +940,27 @@ def _recall_compute( def recall( preds: Tensor, target: Tensor, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, ) -> Tensor: - r"""Computes `Recall`_ + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Recall`_ .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} @@ -376,6 +1061,33 @@ def recall( tensor(0.2500) """ + if task is not None: + assert multidim_average is not None + if task == "binary": + return binary_recall(preds, target, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_recall( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_recall( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + ) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) allowed_average = ("micro", "macro", "weighted", "samples", "none", None) if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -409,7 +1121,7 @@ def recall( def precision_recall( preds: Tensor, target: Tensor, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index 67dddac607b..8bf982f5c01 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -11,13 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import List, Optional, Sequence, Tuple, Union import torch from torch import Tensor, tensor from torch.nn import functional as F +from typing_extensions import Literal from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.data import _bincount def _binary_clf_curve( @@ -26,39 +31,743 @@ def _binary_clf_curve( sample_weights: Optional[Sequence] = None, pos_label: int = 1, ) -> Tuple[Tensor, Tensor, Tensor]: - """adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py.""" - if sample_weights is not None and not isinstance(sample_weights, Tensor): - sample_weights = tensor(sample_weights, device=preds.device, dtype=torch.float) + """Calculates the tps and false positives for all unique thresholds in the preds tensor. Adapted from + https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_ranking.py. + + Args: + preds: 1d tensor with predictions + target: 1d tensor with true values + sample_weights: a 1d tensor with a weight per sample + pos_label: interger determining what the positive class in target tensor is + + Returns: + fps: 1d tensor with false positives for different thresholds + tps: 1d tensor with true positives for different thresholds + thresholds: the unique thresholds use for calculating fps and tps + """ + with torch.no_grad(): + if sample_weights is not None and not isinstance(sample_weights, Tensor): + sample_weights = tensor(sample_weights, device=preds.device, dtype=torch.float) + + # remove class dimension if necessary + if preds.ndim > target.ndim: + preds = preds[:, 0] + desc_score_indices = torch.argsort(preds, descending=True) + + preds = preds[desc_score_indices] + target = target[desc_score_indices] + + if sample_weights is not None: + weight = sample_weights[desc_score_indices] + else: + weight = 1.0 + + # pred typically has many tied values. Here we extract + # the indices associated with the distinct values. We also + # concatenate a value for the end of the curve. + distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0] + threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1) + target = (target == pos_label).to(torch.long) + tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] + + if sample_weights is not None: + # express fps as a cumsum to ensure fps is increasing even in + # the presence of floating point errors + fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs] + else: + fps = 1 + threshold_idxs - tps + + return fps, tps, preds[threshold_idxs] + + +def _adjust_threshold_arg( + thresholds: Optional[Union[int, List[float], Tensor]] = None, device: Optional[torch.device] = None +) -> Optional[Tensor]: + """Utility function for converting the threshold arg for list and int to tensor format.""" + if isinstance(thresholds, int): + thresholds = torch.linspace(0, 1, thresholds, device=device) + if isinstance(thresholds, list): + thresholds = torch.tensor(thresholds, device=device) + return thresholds + + +def _binary_precision_recall_curve_arg_validation( + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + """Validate non tensor input. + + - ``threshold`` has to be None | a 1d tensor | a list of floats in the [0,1] range | an int + - ``ignore_index`` has to be None or int + """ + if thresholds is not None and not isinstance(thresholds, (list, int, Tensor)): + raise ValueError( + "Expected argument `thresholds` to either be an integer, list of floats or" + f" tensor of floats, but got {thresholds}" + ) + else: + if isinstance(thresholds, int) and thresholds < 2: + raise ValueError( + f"If argument `thresholds` is an integer, expected it to be larger than 1, but got {thresholds}" + ) + if isinstance(thresholds, list) and not all(isinstance(t, float) and 0 <= t <= 1 for t in thresholds): + raise ValueError( + "If argument `thresholds` is a list, expected all elements to be floats in the [0,1] range," + f" but got {thresholds}" + ) + if isinstance(thresholds, Tensor) and not thresholds.ndim == 1: + raise ValueError("If argument `thresholds` is an tensor, expected the tensor to be 1d") + + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + + +def _binary_precision_recall_curve_tensor_validation( + preds: Tensor, target: Tensor, ignore_index: Optional[int] = None +) -> None: + """Validate tensor input. + + - tensors have to be of same shape + - all values in target tensor that are not ignored have to be in {0, 1} + - that the pred tensor is floating point + """ + _check_same_shape(preds, target) + + if not preds.is_floating_point(): + raise ValueError( + "Expected argument `preds` to be an floating tensor with probability/logit scores," + f" but got tensor with dtype {preds.dtype}" + ) + + # Check that target only contains {0,1} values or value in ignore_index + unique_values = torch.unique(target) + if ignore_index is None: + check = torch.any((unique_values != 0) & (unique_values != 1)) + else: + check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) + if check: + raise RuntimeError( + f"Detected the following values in `target`: {unique_values} but expected only" + f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." + ) + + +def _binary_precision_recall_curve_format( + preds: Tensor, + target: Tensor, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + """Convert all input to the right format. + + - flattens additional dimensions + - Remove all datapoints that should be ignored + - Applies sigmoid if pred tensor not in [0,1] range + - Format thresholds arg to be a tensor + """ + preds = preds.flatten() + target = target.flatten() + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + + if not torch.all((0 <= preds) * (preds <= 1)): + preds = preds.sigmoid() + + thresholds = _adjust_threshold_arg(thresholds, preds.device) + return preds, target, thresholds + + +def _binary_precision_recall_curve_update( + preds: Tensor, + target: Tensor, + thresholds: Optional[Tensor], +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Returns the state to calculate the pr-curve with. + + If thresholds is `None` the direct preds and targets are used. If thresholds is not `None` we compute a multi + threshold confusion matrix. + """ + if thresholds is None: + return preds, target + len_t = len(thresholds) + preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0)).long() # num_samples x num_thresholds + unique_mapping = preds_t + 2 * target.unsqueeze(-1) + 4 * torch.arange(len_t, device=target.device) + bins = _bincount(unique_mapping.flatten(), minlength=4 * len_t) + return bins.reshape(len_t, 2, 2) - # remove class dimension if necessary - if preds.ndim > target.ndim: - preds = preds[:, 0] - desc_score_indices = torch.argsort(preds, descending=True) - preds = preds[desc_score_indices] - target = target[desc_score_indices] +def _binary_precision_recall_curve_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + thresholds: Optional[Tensor], + pos_label: int = 1, +) -> Tuple[Tensor, Tensor, Tensor]: + """Computes the final pr-curve. - if sample_weights is not None: - weight = sample_weights[desc_score_indices] + If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is + original input, then we dynamically compute the binary classification curve. + """ + if isinstance(state, Tensor): + tps = state[:, 1, 1] + fps = state[:, 0, 1] + fns = state[:, 1, 0] + precision = _safe_divide(tps, tps + fps) + recall = _safe_divide(tps, tps + fns) + precision = torch.cat([precision, torch.ones(1, dtype=precision.dtype, device=precision.device)]) + recall = torch.cat([recall, torch.zeros(1, dtype=recall.dtype, device=recall.device)]) + return precision, recall, thresholds else: - weight = 1.0 - - # pred typically has many tied values. Here we extract - # the indices associated with the distinct values. We also - # concatenate a value for the end of the curve. - distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0] - threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1) - target = (target == pos_label).to(torch.long) - tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] - - if sample_weights is not None: - # express fps as a cumsum to ensure fps is increasing even in - # the presence of floating point errors - fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs] + fps, tps, thresholds = _binary_clf_curve(state[0], state[1], pos_label=pos_label) + precision = tps / (tps + fps) + recall = tps / tps[-1] + + # stop when full recall attained and reverse the outputs so recall is decreasing + last_ind = torch.where(tps == tps[-1])[0][0] + sl = slice(0, last_ind.item() + 1) + + # need to call reversed explicitly, since including that to slice would + # introduce negative strides that are not yet supported in pytorch + precision = torch.cat([reversed(precision[sl]), torch.ones(1, dtype=precision.dtype, device=precision.device)]) + recall = torch.cat([reversed(recall[sl]), torch.zeros(1, dtype=recall.dtype, device=recall.device)]) + thresholds = reversed(thresholds[sl]).detach().clone() # type: ignore + + return precision, recall, thresholds + + +def binary_precision_recall_curve( + preds: Tensor, + target: Tensor, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tuple[Tensor, Tensor, Tensor]: + r""" + Computes the precision-recall curve for binary tasks. The curve consist of multiple pairs of precision and + recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of 3 tensors containing: + + - precision: an 1d tensor of size (n_thresholds+1, ) with precision values + - recall: an 1d tensor of size (n_thresholds+1, ) with recall values + - thresholds: an 1d tensor of size (n_thresholds, ) with increasing threshold values + + Example: + >>> from torchmetrics.functional.classification import binary_precision_recall_curve + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> binary_precision_recall_curve(preds, target, thresholds=None) # doctest: +NORMALIZE_WHITESPACE + (tensor([0.6667, 0.5000, 0.0000, 1.0000]), + tensor([1.0000, 0.5000, 0.0000, 0.0000]), + tensor([0.5000, 0.7000, 0.8000])) + >>> binary_precision_recall_curve(preds, target, thresholds=5) # doctest: +NORMALIZE_WHITESPACE + (tensor([0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000]), + tensor([1., 1., 1., 0., 0., 0.]), + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])) + """ + if validate_args: + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) + preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index) + state = _binary_precision_recall_curve_update(preds, target, thresholds) + return _binary_precision_recall_curve_compute(state, thresholds) + + +def _multiclass_precision_recall_curve_arg_validation( + num_classes: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + """Validate non tensor input. + + - ``num_classes`` has to be an int larger than 1 + - ``threshold`` has to be None | a 1d tensor | a list of floats in the [0,1] range | an int + - ``ignore_index`` has to be None or int + """ + if not isinstance(num_classes, int) or num_classes < 2: + raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) + + +def _multiclass_precision_recall_curve_tensor_validation( + preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None +) -> None: + """Validate tensor input. + + - target should have one more dimension than preds and all dimensions except for preds.shape[1] should match + exactly. preds.shape[1] should have size equal to number of classes + - all values in target tensor that are not ignored have to be in {0, 1} + """ + if not preds.ndim == target.ndim + 1: + raise ValueError( + f"Expected `preds` to have one more dimension than `target` but got {preds.ndim} and {target.ndim}" + ) + if not preds.is_floating_point(): + raise ValueError(f"Expected `preds` to be a float tensor, but got {preds.dtype}") + if preds.shape[1] != num_classes: + raise ValueError( + "Expected `preds.shape[1]` to be equal to the number of classes but" + f" got {preds.shape[1]} and {num_classes}." + ) + if preds.shape[0] != target.shape[0] or preds.shape[2:] != target.shape[1:]: + raise ValueError( + "Expected the shape of `preds` should be (N, C, ...) and the shape of `target` should be (N, ...)" + f" but got {preds.shape} and {target.shape}" + ) + + num_unique_values = len(torch.unique(target)) + if ignore_index is None: + check = num_unique_values > num_classes else: - fps = 1 + threshold_idxs - tps + check = num_unique_values > num_classes + 1 + if check: + raise RuntimeError( + "Detected more unique values in `target` than `num_classes`. Expected only " + f"{num_classes if ignore_index is None else num_classes + 1} but found " + f"{num_unique_values} in `target`." + ) + + +def _multiclass_precision_recall_curve_format( + preds: Tensor, + target: Tensor, + num_classes: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + """Convert all input to the right format. + + - flattens additional dimensions + - Remove all datapoints that should be ignored + - Applies softmax if pred tensor not in [0,1] range + - Format thresholds arg to be a tensor + """ + preds = preds.transpose(0, 1).reshape(num_classes, -1).T + target = target.flatten() + + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + + if not torch.all((0 <= preds) * (preds <= 1)): + preds = preds.softmax(1) - return fps, tps, preds[threshold_idxs] + thresholds = _adjust_threshold_arg(thresholds, preds.device) + return preds, target, thresholds + + +def _multiclass_precision_recall_curve_update( + preds: Tensor, + target: Tensor, + num_classes: int, + thresholds: Optional[Tensor], +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Returns the state to calculate the pr-curve with. + + If thresholds is `None` the direct preds and targets are used. If thresholds is not `None` we compute a multi + threshold confusion matrix. + """ + if thresholds is None: + return preds, target + len_t = len(thresholds) + # num_samples x num_classes x num_thresholds + preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0).unsqueeze(0)).long() + target_t = torch.nn.functional.one_hot(target, num_classes=num_classes) + unique_mapping = preds_t + 2 * target_t.unsqueeze(-1) + unique_mapping += 4 * torch.arange(num_classes, device=preds.device).unsqueeze(0).unsqueeze(-1) + unique_mapping += 4 * num_classes * torch.arange(len_t, device=preds.device) + bins = _bincount(unique_mapping.flatten(), minlength=4 * num_classes * len_t) + return bins.reshape(len_t, num_classes, 2, 2) + + +def _multiclass_precision_recall_curve_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_classes: int, + thresholds: Optional[Tensor], +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + """Computes the final pr-curve. + + If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is + original input, then we dynamically compute the binary classification curve in an iterative way. + """ + if isinstance(state, Tensor): + tps = state[:, :, 1, 1] + fps = state[:, :, 0, 1] + fns = state[:, :, 1, 0] + precision = _safe_divide(tps, tps + fps) + recall = _safe_divide(tps, tps + fns) + precision = torch.cat([precision, torch.ones(1, num_classes, dtype=precision.dtype, device=precision.device)]) + recall = torch.cat([recall, torch.zeros(1, num_classes, dtype=recall.dtype, device=recall.device)]) + return precision.T, recall.T, thresholds + else: + precision, recall, thresholds = [], [], [] + for i in range(num_classes): + res = _binary_precision_recall_curve_compute([state[0][:, i], state[1]], thresholds=None, pos_label=i) + precision.append(res[0]) + recall.append(res[1]) + thresholds.append(res[2]) + return precision, recall, thresholds + + +def multiclass_precision_recall_curve( + preds: Tensor, + target: Tensor, + num_classes: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + r""" + Computes the precision-recall curve for multiclass tasks. The curve consist of multiple pairs of precision and + recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - precision: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with precision values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with precision values is returned. + - recall: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with recall values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with recall values is returned. + - thresholds: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds, ) + with increasing threshold values (length may differ between classes). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes. + + Example: + >>> from torchmetrics.functional.classification import multiclass_precision_recall_curve + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> precision, recall, thresholds = multiclass_precision_recall_curve( + ... preds, target, num_classes=5, thresholds=None + ... ) + >>> precision # doctest: +NORMALIZE_WHITESPACE + [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), + tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] + >>> recall + [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] + >>> thresholds + [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] + >>> multiclass_precision_recall_curve( + ... preds, target, num_classes=5, thresholds=5 + ... ) # doctest: +NORMALIZE_WHITESPACE + (tensor([[0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000], + [0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000], + [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000], + [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]), + tensor([[1., 1., 1., 1., 0., 0.], + [1., 1., 1., 1., 0., 0.], + [1., 0., 0., 0., 0., 0.], + [1., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0.]]), + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])) + """ + if validate_args: + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) + preds, target, thresholds = _multiclass_precision_recall_curve_format( + preds, target, num_classes, thresholds, ignore_index + ) + state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) + return _multiclass_precision_recall_curve_compute(state, num_classes, thresholds) + + +def _multilabel_precision_recall_curve_arg_validation( + num_labels: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + """Validate non tensor input. + + - ``num_labels`` has to be an int larger than 1 + - ``threshold`` has to be None | a 1d tensor | a list of floats in the [0,1] range | an int + - ``ignore_index`` has to be None or int + """ + _multiclass_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + + +def _multilabel_precision_recall_curve_tensor_validation( + preds: Tensor, target: Tensor, num_labels: int, ignore_index: Optional[int] = None +) -> None: + """Validate tensor input. + + - tensors have to be of same shape + - preds.shape[1] is equal to the number of labels + - all values in target tensor that are not ignored have to be in {0, 1} + - that the pred tensor is floating point + """ + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) + if preds.shape[1] != num_labels: + raise ValueError( + "Expected both `target.shape[1]` and `preds.shape[1]` to be equal to the number of labels" + f" but got {preds.shape[1]} and expected {num_labels}" + ) + + +def _multilabel_precision_recall_curve_format( + preds: Tensor, + target: Tensor, + num_labels: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + """Convert all input to the right format. + + - flattens additional dimensions + - Mask all datapoints that should be ignored with negative values + - Applies sigmoid if pred tensor not in [0,1] range + - Format thresholds arg to be a tensor + """ + preds = preds.transpose(0, 1).reshape(num_labels, -1).T + target = target.transpose(0, 1).reshape(num_labels, -1).T + if not torch.all((0 <= preds) * (preds <= 1)): + preds = preds.sigmoid() + + thresholds = _adjust_threshold_arg(thresholds, preds.device) + if ignore_index is not None and thresholds is not None: + preds = preds.clone() + target = target.clone() + # Make sure that when we map, it will always result in a negative number that we can filter away + idx = target == ignore_index + preds[idx] = -4 * num_labels * (len(thresholds) if thresholds is not None else 1) + target[idx] = -4 * num_labels * (len(thresholds) if thresholds is not None else 1) + + return preds, target, thresholds + + +def _multilabel_precision_recall_curve_update( + preds: Tensor, + target: Tensor, + num_labels: int, + thresholds: Optional[Tensor], +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Returns the state to calculate the pr-curve with. + + If thresholds is `None` the direct preds and targets are used. If thresholds is not `None` we compute a multi + threshold confusion matrix. + """ + if thresholds is None: + return preds, target + len_t = len(thresholds) + # num_samples x num_labels x num_thresholds + preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0).unsqueeze(0)).long() + unique_mapping = preds_t + 2 * target.unsqueeze(-1) + unique_mapping += 4 * torch.arange(num_labels, device=preds.device).unsqueeze(0).unsqueeze(-1) + unique_mapping += 4 * num_labels * torch.arange(len_t, device=preds.device) + unique_mapping = unique_mapping[unique_mapping >= 0] + bins = _bincount(unique_mapping, minlength=4 * num_labels * len_t) + return bins.reshape(len_t, num_labels, 2, 2) + + +def _multilabel_precision_recall_curve_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_labels: int, + thresholds: Optional[Tensor], + ignore_index: Optional[int] = None, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + """Computes the final pr-curve. + + If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is + original input, then we dynamically compute the binary classification curve in an iterative way. + """ + if isinstance(state, Tensor): + tps = state[:, :, 1, 1] + fps = state[:, :, 0, 1] + fns = state[:, :, 1, 0] + precision = _safe_divide(tps, tps + fps) + recall = _safe_divide(tps, tps + fns) + precision = torch.cat([precision, torch.ones(1, num_labels, dtype=precision.dtype, device=precision.device)]) + recall = torch.cat([recall, torch.zeros(1, num_labels, dtype=recall.dtype, device=recall.device)]) + return precision.T, recall.T, thresholds + else: + precision, recall, thresholds = [], [], [] + for i in range(num_labels): + preds = state[0][:, i] + target = state[1][:, i] + if ignore_index is not None: + idx = target == ignore_index + preds = preds[~idx] + target = target[~idx] + res = _binary_precision_recall_curve_compute([preds, target], thresholds=None, pos_label=1) + precision.append(res[0]) + recall.append(res[1]) + thresholds.append(res[2]) + return precision, recall, thresholds + + +def multilabel_precision_recall_curve( + preds: Tensor, + target: Tensor, + num_labels: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + r""" + Computes the precision-recall curve for multilabel tasks. The curve consist of multiple pairs of precision and + recall values evaluated at different thresholds, such that the tradeoff between the two values can been seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - precision: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with precision values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with precision values is returned. + - recall: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with recall values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with recall values is returned. + - thresholds: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds, ) + with increasing threshold values (length may differ between labels). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels. + + Example: + >>> from torchmetrics.functional.classification import multilabel_precision_recall_curve + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... [1, 1, 1]]) + >>> precision, recall, thresholds = multilabel_precision_recall_curve( + ... preds, target, num_labels=3, thresholds=None + ... ) + >>> precision # doctest: +NORMALIZE_WHITESPACE + [tensor([0.5000, 0.5000, 1.0000, 1.0000]), tensor([0.6667, 0.5000, 0.0000, 1.0000]), + tensor([0.7500, 1.0000, 1.0000, 1.0000])] + >>> recall # doctest: +NORMALIZE_WHITESPACE + [tensor([1.0000, 0.5000, 0.5000, 0.0000]), tensor([1.0000, 0.5000, 0.0000, 0.0000]), + tensor([1.0000, 0.6667, 0.3333, 0.0000])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([0.0500, 0.4500, 0.7500]), tensor([0.5500, 0.6500, 0.7500]), + tensor([0.0500, 0.3500, 0.7500])] + >>> multilabel_precision_recall_curve( + ... preds, target, num_labels=3, thresholds=5 + ... ) # doctest: +NORMALIZE_WHITESPACE + (tensor([[0.5000, 0.5000, 1.0000, 1.0000, 0.0000, 1.0000], + [0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000], + [0.7500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000]]), + tensor([[1.0000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000], + [1.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000], + [1.0000, 0.6667, 0.3333, 0.3333, 0.0000, 0.0000]]), + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])) + """ + if validate_args: + _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) + preds, target, thresholds = _multilabel_precision_recall_curve_format( + preds, target, num_labels, thresholds, ignore_index + ) + state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) + return _multilabel_precision_recall_curve_compute(state, num_labels, thresholds, ignore_index) def _precision_recall_curve_update( @@ -266,8 +975,21 @@ def precision_recall_curve( num_classes: Optional[int] = None, pos_label: Optional[int] = None, sample_weights: Optional[Sequence] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - """Computes precision-recall pairs for different thresholds. + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes precision-recall pairs for different thresholds. Args: preds: predictions from model (probabilities). @@ -327,5 +1049,28 @@ def precision_recall_curve( >>> thresholds [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] """ + if task is not None: + if task == "binary": + return binary_precision_recall_curve(preds, target, thresholds, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_precision_recall_curve( + preds, target, num_classes, thresholds, ignore_index, validate_args + ) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_precision_recall_curve(preds, target, num_labels, thresholds, ignore_index, validate_args) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label) return _precision_recall_curve_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/src/torchmetrics/functional/classification/ranking.py b/src/torchmetrics/functional/classification/ranking.py index dabd5163b2f..635d6ac1b8a 100644 --- a/src/torchmetrics/functional/classification/ranking.py +++ b/src/torchmetrics/functional/classification/ranking.py @@ -16,6 +16,12 @@ import torch from torch import Tensor +from torchmetrics.functional.classification.confusion_matrix import ( + _multilabel_confusion_matrix_arg_validation, + _multilabel_confusion_matrix_format, + _multilabel_confusion_matrix_tensor_validation, +) + def _rank_data(x: Tensor) -> Tensor: """Rank data based on values.""" @@ -26,6 +32,232 @@ def _rank_data(x: Tensor) -> Tensor: return ranks[inverse] +def _ranking_reduce(score: Tensor, n_elements: int) -> Tensor: + return score / n_elements + + +def _multilabel_ranking_tensor_validation( + preds: Tensor, target: Tensor, num_labels: int, ignore_index: Optional[int] = None +) -> None: + _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index) + if not preds.is_floating_point(): + raise ValueError(f"Expected preds tensor to be floating point, but received input with dtype {preds.dtype}") + + +def _multilabel_coverage_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: + """Accumulate state for coverage error.""" + offset = torch.zeros_like(preds) + offset[target == 0] = preds.min().abs() + 10 # Any number >1 works + preds_mod = preds + offset + preds_min = preds_mod.min(dim=1)[0] + coverage = (preds >= preds_min[:, None]).sum(dim=1).to(torch.float32) + return coverage.sum(), coverage.numel() + + +def multilabel_coverage_error( + preds: Tensor, + target: Tensor, + num_labels: int, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + """Computes multilabel coverage error [1]. The score measure how far we need to go through the ranked scores to + cover all true labels. The best value is equal to the average number of labels in the target tensor per sample. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example: + >>> from torchmetrics.functional.classification import multilabel_coverage_error + >>> _ = torch.manual_seed(42) + >>> preds = torch.rand(10, 5) + >>> target = torch.randint(2, (10, 5)) + >>> multilabel_coverage_error(preds, target, num_labels=5) + tensor(3.9000) + + References: + [1] Tsoumakas, G., Katakis, I., & Vlahavas, I. (2010). Mining multi-label data. In Data mining and + knowledge discovery handbook (pp. 667-685). Springer US. + """ + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, threshold=0.0, ignore_index=ignore_index) + _multilabel_ranking_tensor_validation(preds, target, num_labels, ignore_index) + preds, target = _multilabel_confusion_matrix_format( + preds, target, num_labels, threshold=0.0, ignore_index=ignore_index, should_threshold=False + ) + coverage, total = _multilabel_coverage_error_update(preds, target) + return _ranking_reduce(coverage, total) + + +def _multilabel_ranking_average_precision_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: + """Accumulate state for label ranking average precision.""" + # Invert so that the highest score receives rank 1 + neg_preds = -preds + + score = torch.tensor(0.0, device=neg_preds.device) + n_preds, n_labels = neg_preds.shape + for i in range(n_preds): + relevant = target[i] == 1 + ranking = _rank_data(neg_preds[i][relevant]).float() + if len(ranking) > 0 and len(ranking) < n_labels: + rank = _rank_data(neg_preds[i])[relevant].float() + score_idx = (ranking / rank).mean() + else: + score_idx = 1.0 + score += score_idx + return score, n_preds + + +def multilabel_ranking_average_precision( + preds: Tensor, + target: Tensor, + num_labels: int, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + """Computes label ranking average precision score for multilabel data [1]. The score is the average over each + ground truth label assigned to each sample of the ratio of true vs. total labels with lower score. Best score + is 1. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example: + >>> from torchmetrics.functional.classification import multilabel_ranking_average_precision + >>> _ = torch.manual_seed(42) + >>> preds = torch.rand(10, 5) + >>> target = torch.randint(2, (10, 5)) + >>> multilabel_ranking_average_precision(preds, target, num_labels=5) + tensor(0.7744) + + References: + [1] Tsoumakas, G., Katakis, I., & Vlahavas, I. (2010). Mining multi-label data. In Data mining and + knowledge discovery handbook (pp. 667-685). Springer US. + """ + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, threshold=0.0, ignore_index=ignore_index) + _multilabel_ranking_tensor_validation(preds, target, num_labels, ignore_index) + preds, target = _multilabel_confusion_matrix_format( + preds, target, num_labels, threshold=0.0, ignore_index=ignore_index, should_threshold=False + ) + score, n_elements = _multilabel_ranking_average_precision_update(preds, target) + return _ranking_reduce(score, n_elements) + + +def _multilabel_ranking_loss_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: + """Accumulate state for label ranking loss. + + Args: + preds: tensor with predictions + target: tensor with ground truth labels + sample_weight: optional tensor with weight for each sample + """ + n_preds, n_labels = preds.shape + relevant = target == 1 + n_relevant = relevant.sum(dim=1) + + # Ignore instances where number of true labels is 0 or n_labels + mask = (n_relevant > 0) & (n_relevant < n_labels) + preds = preds[mask] + relevant = relevant[mask] + n_relevant = n_relevant[mask] + + # Nothing is relevant + if len(preds) == 0: + return torch.tensor(0.0, device=preds.device), 1 + + inverse = preds.argsort(dim=1).argsort(dim=1) + per_label_loss = ((n_labels - inverse) * relevant).to(torch.float32) + correction = 0.5 * n_relevant * (n_relevant + 1) + denom = n_relevant * (n_labels - n_relevant) + loss = (per_label_loss.sum(dim=1) - correction) / denom + return loss.sum(), n_preds + + +def multilabel_ranking_loss( + preds: Tensor, + target: Tensor, + num_labels: int, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + """Computes the label ranking loss for multilabel data [1]. The score is corresponds to the average number of + label pairs that are incorrectly ordered given some predictions weighted by the size of the label set and the + number of labels not in the label set. The best score is 0. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example: + >>> from torchmetrics.functional.classification import multilabel_ranking_loss + >>> _ = torch.manual_seed(42) + >>> preds = torch.rand(10, 5) + >>> target = torch.randint(2, (10, 5)) + >>> multilabel_ranking_loss(preds, target, num_labels=5) + tensor(0.4167) + + References: + [1] Tsoumakas, G., Katakis, I., & Vlahavas, I. (2010). Mining multi-label data. In Data mining and + knowledge discovery handbook (pp. 667-685). Springer US. + """ + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, threshold=0.0, ignore_index=ignore_index) + _multilabel_ranking_tensor_validation(preds, target, num_labels, ignore_index) + preds, target = _multilabel_confusion_matrix_format( + preds, target, num_labels, threshold=0.0, ignore_index=ignore_index, should_threshold=False + ) + loss, n_elements = _multilabel_ranking_loss_update(preds, target) + return _ranking_reduce(loss, n_elements) + + def _check_ranking_input(preds: Tensor, target: Tensor, sample_weight: Optional[Tensor] = None) -> Tensor: """Check that ranking input have the correct dimensions.""" if preds.ndim != 2 or target.ndim != 2: diff --git a/src/torchmetrics/functional/classification/recall_at_fixed_precision.py b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py new file mode 100644 index 00000000000..2cb4d6124b8 --- /dev/null +++ b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py @@ -0,0 +1,366 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from torchmetrics.functional.classification.precision_recall_curve import ( + _binary_precision_recall_curve_arg_validation, + _binary_precision_recall_curve_compute, + _binary_precision_recall_curve_format, + _binary_precision_recall_curve_tensor_validation, + _binary_precision_recall_curve_update, + _multiclass_precision_recall_curve_arg_validation, + _multiclass_precision_recall_curve_compute, + _multiclass_precision_recall_curve_format, + _multiclass_precision_recall_curve_tensor_validation, + _multiclass_precision_recall_curve_update, + _multilabel_precision_recall_curve_arg_validation, + _multilabel_precision_recall_curve_compute, + _multilabel_precision_recall_curve_format, + _multilabel_precision_recall_curve_tensor_validation, + _multilabel_precision_recall_curve_update, +) + + +def _recall_at_precision( + precision: Tensor, + recall: Tensor, + thresholds: Tensor, + min_precision: float, +) -> Tuple[Tensor, Tensor]: + try: + max_recall, _, best_threshold = max( + (r, p, t) for p, r, t in zip(precision, recall, thresholds) if p >= min_precision + ) + + except ValueError: + max_recall = torch.tensor(0.0, device=recall.device, dtype=recall.dtype) + best_threshold = torch.tensor(0) + + if max_recall == 0.0: + best_threshold = torch.tensor(1e6, device=thresholds.device, dtype=thresholds.dtype) + + return max_recall, best_threshold + + +def _binary_recall_at_fixed_precision_arg_validation( + min_precision: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) + if not isinstance(min_precision, float) and not (0 <= min_precision <= 1): + raise ValueError( + f"Expected argument `min_precision` to be an float in the [0,1] range, but got {min_precision}" + ) + + +def _binary_recall_at_fixed_precision_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + thresholds: Optional[Tensor], + min_precision: float, + pos_label: int = 1, +) -> Tuple[Tensor, Tensor]: + precision, recall, thresholds = _binary_precision_recall_curve_compute(state, thresholds, pos_label) + return _recall_at_precision(precision, recall, thresholds, min_precision) + + +def binary_recall_at_fixed_precision( + preds: Tensor, + target: Tensor, + min_precision: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tuple[Tensor, Tensor]: + r""" + Computes the higest possible recall value given the minimum precision thresholds provided. This is done by + first calculating the precision-recall curve for different thresholds and the find the recall for a given + precision level. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + min_precision: float value specifying minimum precision threshold. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of 2 tensors containing: + + - recall: an scalar tensor with the maximum recall for the given precision level + - threshold: an scalar tensor with the corresponding threshold level + + Example: + >>> from torchmetrics.functional.classification import binary_recall_at_fixed_precision + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> binary_recall_at_fixed_precision(preds, target, min_precision=0.5, thresholds=None) + (tensor(1.), tensor(0.5000)) + >>> binary_recall_at_fixed_precision(preds, target, min_precision=0.5, thresholds=5) + (tensor(1.), tensor(0.5000)) + """ + if validate_args: + _binary_recall_at_fixed_precision_arg_validation(min_precision, thresholds, ignore_index) + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) + preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index) + state = _binary_precision_recall_curve_update(preds, target, thresholds) + return _binary_recall_at_fixed_precision_compute(state, thresholds, min_precision) + + +def _multiclass_recall_at_fixed_precision_arg_validation( + num_classes: int, + min_precision: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + if not isinstance(min_precision, float) and not (0 <= min_precision <= 1): + raise ValueError( + f"Expected argument `min_precision` to be an float in the [0,1] range, but got {min_precision}" + ) + + +def _multiclass_recall_at_fixed_precision_arg_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_classes: int, + thresholds: Optional[Tensor], + min_precision: float, +) -> Tuple[Tensor, Tensor]: + precision, recall, thresholds = _multiclass_precision_recall_curve_compute(state, num_classes, thresholds) + if isinstance(state, Tensor): + res = [_recall_at_precision(p, r, thresholds, min_precision) for p, r in zip(precision, recall)] + else: + res = [_recall_at_precision(p, r, t, min_precision) for p, r, t in zip(precision, recall, thresholds)] + recall = torch.stack([r[0] for r in res]) + thresholds = torch.stack([r[1] for r in res]) + return recall, thresholds + + +def multiclass_recall_at_fixed_precision( + preds: Tensor, + target: Tensor, + num_classes: int, + min_precision: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tuple[Tensor, Tensor]: + r""" + Computes the higest possible recall value given the minimum precision thresholds provided. This is done by + first calculating the precision-recall curve for different thresholds and the find the recall for a given + precision level. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + min_precision: float value specifying minimum precision threshold. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of either 2 tensors or 2 lists containing + + - recall: an 1d tensor of size (n_classes, ) with the maximum recall for the given precision level per class + - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class + + Example: + >>> from torchmetrics.functional.classification import multiclass_recall_at_fixed_precision + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> multiclass_recall_at_fixed_precision(preds, target, num_classes=5, min_precision=0.5, thresholds=None) + (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06])) + >>> multiclass_recall_at_fixed_precision(preds, target, num_classes=5, min_precision=0.5, thresholds=5) + (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06])) + """ + if validate_args: + _multiclass_recall_at_fixed_precision_arg_validation(num_classes, min_precision, thresholds, ignore_index) + _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) + preds, target, thresholds = _multiclass_precision_recall_curve_format( + preds, target, num_classes, thresholds, ignore_index + ) + state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) + return _multiclass_recall_at_fixed_precision_arg_compute(state, num_classes, thresholds, min_precision) + + +def _multilabel_recall_at_fixed_precision_arg_validation( + num_labels: int, + min_precision: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + if not isinstance(min_precision, float) and not (0 <= min_precision <= 1): + raise ValueError( + f"Expected argument `min_precision` to be an float in the [0,1] range, but got {min_precision}" + ) + + +def _multilabel_recall_at_fixed_precision_arg_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_labels: int, + thresholds: Optional[Tensor], + ignore_index: Optional[int], + min_precision: float, +) -> Tuple[Tensor, Tensor]: + precision, recall, thresholds = _multilabel_precision_recall_curve_compute( + state, num_labels, thresholds, ignore_index + ) + if isinstance(state, Tensor): + res = [_recall_at_precision(p, r, thresholds, min_precision) for p, r in zip(precision, recall)] + else: + res = [_recall_at_precision(p, r, t, min_precision) for p, r, t in zip(precision, recall, thresholds)] + recall = torch.stack([r[0] for r in res]) + thresholds = torch.stack([r[1] for r in res]) + return recall, thresholds + + +def multilabel_recall_at_fixed_precision( + preds: Tensor, + target: Tensor, + num_labels: int, + min_precision: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tuple[Tensor, Tensor]: + r""" + Computes the higest possible recall value given the minimum precision thresholds provided. This is done by + first calculating the precision-recall curve for different thresholds and the find the recall for a given + precision level. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + min_precision: float value specifying minimum precision threshold. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of either 2 tensors or 2 lists containing + + - recall: an 1d tensor of size (n_classes, ) with the maximum recall for the given precision level per class + - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class + + Example: + >>> from torchmetrics.functional.classification import multilabel_recall_at_fixed_precision + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... [1, 1, 1]]) + >>> multilabel_recall_at_fixed_precision(preds, target, num_labels=3, min_precision=0.5, thresholds=None) + (tensor([1., 1., 1.]), tensor([0.0500, 0.5500, 0.0500])) + >>> multilabel_recall_at_fixed_precision(preds, target, num_labels=3, min_precision=0.5, thresholds=5) + (tensor([1., 1., 1.]), tensor([0.0000, 0.5000, 0.0000])) + """ + if validate_args: + _multilabel_recall_at_fixed_precision_arg_validation(num_labels, min_precision, thresholds, ignore_index) + _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) + preds, target, thresholds = _multilabel_precision_recall_curve_format( + preds, target, num_labels, thresholds, ignore_index + ) + state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) + return _multilabel_recall_at_fixed_precision_arg_compute(state, num_labels, thresholds, ignore_index, min_precision) diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index 0e5ce0b58f0..fe01fd869a9 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -15,12 +15,413 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.functional.classification.precision_recall_curve import ( _binary_clf_curve, + _binary_precision_recall_curve_arg_validation, + _binary_precision_recall_curve_format, + _binary_precision_recall_curve_tensor_validation, + _binary_precision_recall_curve_update, + _multiclass_precision_recall_curve_arg_validation, + _multiclass_precision_recall_curve_format, + _multiclass_precision_recall_curve_tensor_validation, + _multiclass_precision_recall_curve_update, + _multilabel_precision_recall_curve_arg_validation, + _multilabel_precision_recall_curve_format, + _multilabel_precision_recall_curve_tensor_validation, + _multilabel_precision_recall_curve_update, _precision_recall_curve_update, ) from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.compute import _safe_divide + + +def _binary_roc_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + thresholds: Optional[Tensor], + pos_label: int = 1, +) -> Tuple[Tensor, Tensor, Tensor]: + if isinstance(state, Tensor) and thresholds is not None: + tps = state[:, 1, 1] + fps = state[:, 0, 1] + fns = state[:, 1, 0] + tns = state[:, 0, 0] + tpr = _safe_divide(tps, tps + fns).flip(0) + fpr = _safe_divide(fps, fps + tns).flip(0) + thresholds = thresholds.flip(0) + else: + fps, tps, thresholds = _binary_clf_curve(preds=state[0], target=state[1], pos_label=pos_label) + # Add an extra threshold position to make sure that the curve starts at (0, 0) + tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps]) + fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps]) + thresholds = torch.cat([torch.ones(1, dtype=thresholds.dtype, device=thresholds.device), thresholds]) + + if fps[-1] <= 0: + rank_zero_warn( + "No negative samples in targets, false positive value should be meaningless." + " Returning zero tensor in false positive score", + UserWarning, + ) + fpr = torch.zeros_like(thresholds) + else: + fpr = fps / fps[-1] + + if tps[-1] <= 0: + rank_zero_warn( + "No positive samples in targets, true positive value should be meaningless." + " Returning zero tensor in true positive score", + UserWarning, + ) + tpr = torch.zeros_like(thresholds) + else: + tpr = tps / tps[-1] + + return fpr, tpr, thresholds + + +def binary_roc( + preds: Tensor, + target: Tensor, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tuple[Tensor, Tensor, Tensor]: + r""" + Computes the Receiver Operating Characteristic (ROC) for binary tasks. The curve consist of multiple + pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, + such that the tradeoff between the two values can be seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Note that outputted thresholds will be in reversed order to ensure that they corresponds to both fpr and tpr which + are sorted in reversed order during their calculation, such that they are monotome increasing. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of 3 tensors containing: + + - fpr: an 1d tensor of size (n_thresholds+1, ) with false positive rate values + - tpr: an 1d tensor of size (n_thresholds+1, ) with true positive rate values + - thresholds: an 1d tensor of size (n_thresholds, ) with decreasing threshold values + + Example: + >>> from torchmetrics.functional.classification import binary_roc + >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> binary_roc(preds, target, thresholds=None) # doctest: +NORMALIZE_WHITESPACE + (tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]), + tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]), + tensor([1.0000, 0.8000, 0.7000, 0.5000, 0.0000])) + >>> binary_roc(preds, target, thresholds=5) # doctest: +NORMALIZE_WHITESPACE + (tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]), + tensor([0., 0., 1., 1., 1.]), + tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000])) + """ + if validate_args: + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) + preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index) + state = _binary_precision_recall_curve_update(preds, target, thresholds) + return _binary_roc_compute(state, thresholds) + + +def _multiclass_roc_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_classes: int, + thresholds: Optional[Tensor], +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + if isinstance(state, Tensor) and thresholds is not None: + tps = state[:, :, 1, 1] + fps = state[:, :, 0, 1] + fns = state[:, :, 1, 0] + tns = state[:, :, 0, 0] + tpr = _safe_divide(tps, tps + fns).flip(0).T + fpr = _safe_divide(fps, fps + tns).flip(0).T + thresholds = thresholds.flip(0) + else: + fpr, tpr, thresholds = [], [], [] + for i in range(num_classes): + res = _binary_roc_compute([state[0][:, i], state[1]], thresholds=None, pos_label=i) + fpr.append(res[0]) + tpr.append(res[1]) + thresholds.append(res[2]) + return fpr, tpr, thresholds + + +def multiclass_roc( + preds: Tensor, + target: Tensor, + num_classes: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + r""" + Computes the Receiver Operating Characteristic (ROC) for binary tasks. The curve consist of multiple + pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, + such that the tradeoff between the two values can be seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Note that outputted thresholds will be in reversed order to ensure that they corresponds to both fpr and tpr which + are sorted in reversed order during their calculation, such that they are monotome increasing. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - fpr: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with false positive rate values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with false positive rate values is returned. + - tpr: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds+1, ) + with true positive rate values (length may differ between classes). If `thresholds` is set to something else, + then a single 2d tensor of size (n_classes, n_thresholds+1) with true positive rate values is returned. + - thresholds: if `thresholds=None` a list for each class is returned with an 1d tensor of size (n_thresholds, ) + with decreasing threshold values (length may differ between classes). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes. + + Example: + >>> from torchmetrics.functional.classification import multiclass_roc + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> fpr, tpr, thresholds = multiclass_roc( + ... preds, target, num_classes=5, thresholds=None + ... ) + >>> fpr # doctest: +NORMALIZE_WHITESPACE + [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), + tensor([0.0000, 0.3333, 1.0000]), tensor([0., 1.])] + >>> tpr + [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0., 0.])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), + tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.0500])] + >>> multiclass_roc( + ... preds, target, num_classes=5, thresholds=5 + ... ) # doctest: +NORMALIZE_WHITESPACE + (tensor([[0.0000, 0.0000, 0.0000, 0.0000, 1.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 1.0000], + [0.0000, 0.3333, 0.3333, 0.3333, 1.0000], + [0.0000, 0.3333, 0.3333, 0.3333, 1.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]), + tensor([[0., 1., 1., 1., 1.], + [0., 1., 1., 1., 1.], + [0., 0., 0., 0., 1.], + [0., 0., 0., 0., 1.], + [0., 0., 0., 0., 0.]]), + tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000])) + """ + if validate_args: + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) + preds, target, thresholds = _multiclass_precision_recall_curve_format( + preds, target, num_classes, thresholds, ignore_index + ) + state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) + return _multiclass_roc_compute(state, num_classes, thresholds) + + +def _multilabel_roc_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_labels: int, + thresholds: Optional[Tensor], + ignore_index: Optional[int] = None, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + if isinstance(state, Tensor) and thresholds is not None: + tps = state[:, :, 1, 1] + fps = state[:, :, 0, 1] + fns = state[:, :, 1, 0] + tns = state[:, :, 0, 0] + tpr = _safe_divide(tps, tps + fns).flip(0).T + fpr = _safe_divide(fps, fps + tns).flip(0).T + thresholds = thresholds.flip(0) + else: + fpr, tpr, thresholds = [], [], [] + for i in range(num_labels): + preds = state[0][:, i] + target = state[1][:, i] + if ignore_index is not None: + idx = target == ignore_index + preds = preds[~idx] + target = target[~idx] + res = _binary_roc_compute([preds, target], thresholds=None, pos_label=1) + fpr.append(res[0]) + tpr.append(res[1]) + thresholds.append(res[2]) + return fpr, tpr, thresholds + + +def multilabel_roc( + preds: Tensor, + target: Tensor, + num_labels: int, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + r""" + Computes the Receiver Operating Characteristic (ROC) for binary tasks. The curve consist of multiple + pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, + such that the tradeoff between the two values can be seen. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Note that outputted thresholds will be in reversed order to ensure that they corresponds to both fpr and tpr which + are sorted in reversed order during their calculation, such that they are monotome increasing. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of either 3 tensors or 3 lists containing + + - fpr: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with false positive rate values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with false positive rate values is returned. + - tpr: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds+1, ) + with true positive rate values (length may differ between labels). If `thresholds` is set to something else, + then a single 2d tensor of size (n_labels, n_thresholds+1) with true positive rate values is returned. + - thresholds: if `thresholds=None` a list for each label is returned with an 1d tensor of size (n_thresholds, ) + with decreasing threshold values (length may differ between labels). If `threshold` is set to something else, + then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels. + + Example: + >>> from torchmetrics.functional.classification import multilabel_roc + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... [1, 1, 1]]) + >>> fpr, tpr, thresholds = multilabel_roc( + ... preds, target, num_labels=3, thresholds=None + ... ) + >>> fpr # doctest: +NORMALIZE_WHITESPACE + [tensor([0.0000, 0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]), + tensor([0., 0., 0., 1.])] + >>> tpr # doctest: +NORMALIZE_WHITESPACE + [tensor([0.0000, 0.5000, 0.5000, 1.0000]), + tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]), + tensor([0.0000, 0.3333, 0.6667, 1.0000])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([1.0000, 0.7500, 0.4500, 0.0500]), + tensor([1.0000, 0.7500, 0.6500, 0.5500, 0.0500]), + tensor([1.0000, 0.7500, 0.3500, 0.0500])] + >>> multilabel_roc( + ... preds, target, num_labels=3, thresholds=5 + ... ) # doctest: +NORMALIZE_WHITESPACE + (tensor([[0.0000, 0.0000, 0.0000, 0.5000, 1.0000], + [0.0000, 0.5000, 0.5000, 0.5000, 1.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]), + tensor([[0.0000, 0.5000, 0.5000, 0.5000, 1.0000], + [0.0000, 0.0000, 1.0000, 1.0000, 1.0000], + [0.0000, 0.3333, 0.3333, 0.6667, 1.0000]]), + tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000])) + """ + if validate_args: + _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) + preds, target, thresholds = _multilabel_precision_recall_curve_format( + preds, target, num_labels, thresholds, ignore_index + ) + state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) + return _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) def _roc_update( @@ -200,8 +601,21 @@ def roc( num_classes: Optional[int] = None, pos_label: Optional[int] = None, sample_weights: Optional[Sequence] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - """Computes the Receiver Operating Characteristic (ROC). Works with both binary, multiclass and multilabel + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the Receiver Operating Characteristic (ROC). Works with both binary, multiclass and multilabel input. .. note:: @@ -278,5 +692,26 @@ def roc( tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]), tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])] """ + if task is not None: + if task == "binary": + return binary_roc(preds, target, thresholds, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_roc(preds, target, num_classes, thresholds, ignore_index, validate_args) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes, pos_label) return _roc_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/src/torchmetrics/functional/classification/specificity.py b/src/torchmetrics/functional/classification/specificity.py index 193f01fe395..11c5d0a592f 100644 --- a/src/torchmetrics/functional/classification/specificity.py +++ b/src/torchmetrics/functional/classification/specificity.py @@ -15,9 +15,345 @@ import torch from torch import Tensor - -from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update +from typing_extensions import Literal + +from torchmetrics.functional.classification.stat_scores import ( + _binary_stat_scores_arg_validation, + _binary_stat_scores_format, + _binary_stat_scores_tensor_validation, + _binary_stat_scores_update, + _multiclass_stat_scores_arg_validation, + _multiclass_stat_scores_format, + _multiclass_stat_scores_tensor_validation, + _multiclass_stat_scores_update, + _multilabel_stat_scores_arg_validation, + _multilabel_stat_scores_format, + _multilabel_stat_scores_tensor_validation, + _multilabel_stat_scores_update, + _reduce_stat_scores, + _stat_scores_update, +) +from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn + + +def _specificity_reduce( + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]], + multidim_average: Literal["global", "samplewise"] = "global", +) -> Tensor: + if average == "binary": + return _safe_divide(tn, tn + fp) + elif average == "micro": + tn = tn.sum(dim=0 if multidim_average == "global" else 1) + fp = fp.sum(dim=0 if multidim_average == "global" else 1) + return _safe_divide(tn, tn + fp) + else: + specificity_score = _safe_divide(tn, tn + fp) + if average is None or average == "none": + return specificity_score + if average == "weighted": + weights = tp + fn + else: + weights = torch.ones_like(specificity_score) + return _safe_divide(weights * specificity_score, weights.sum(-1, keepdim=True)).sum(-1) + + +def binary_specificity( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes `Specificity`_ for binary tasks: + + .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import binary_specificity + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> binary_specificity(preds, target) + tensor(0.6667) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import binary_specificity + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_specificity(preds, target) + tensor(0.6667) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import binary_specificity + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> binary_specificity(preds, target, multidim_average='samplewise') + tensor([0.0000, 0.3333]) + """ + if validate_args: + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) + return _specificity_reduce(tp, fp, tn, fn, average="binary", multidim_average=multidim_average) + + +def multiclass_specificity( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + top_k: int = 1, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes `Specificity`_ for multiclass tasks: + + .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multiclass_specificity + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_specificity(preds, target, num_classes=3) + tensor(0.8889) + >>> multiclass_specificity(preds, target, num_classes=3, average=None) + tensor([1.0000, 0.6667, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multiclass_specificity + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> multiclass_specificity(preds, target, num_classes=3) + tensor(0.8889) + >>> multiclass_specificity(preds, target, num_classes=3, average=None) + tensor([1.0000, 0.6667, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multiclass_specificity + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_specificity(preds, target, num_classes=3, multidim_average='samplewise') + tensor([0.7500, 0.6556]) + >>> multiclass_specificity(preds, target, num_classes=3, multidim_average='samplewise', average=None) + tensor([[0.7500, 0.7500, 0.7500], + [0.8000, 0.6667, 0.5000]]) + """ + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, top_k) + tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, num_classes, top_k, multidim_average, ignore_index) + return _specificity_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average) + + +def multilabel_specificity( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes `Specificity`_ for multilabel tasks + + .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multilabel_specificity + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_specificity(preds, target, num_labels=3) + tensor(0.6667) + >>> multilabel_specificity(preds, target, num_labels=3, average=None) + tensor([1., 1., 0.]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multilabel_specificity + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_specificity(preds, target, num_labels=3) + tensor(0.6667) + >>> multilabel_specificity(preds, target, num_labels=3, average=None) + tensor([1., 1., 0.]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multilabel_specificity + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> multilabel_specificity(preds, target, num_labels=3, multidim_average='samplewise') + tensor([0.0000, 0.3333]) + >>> multilabel_specificity(preds, target, num_labels=3, multidim_average='samplewise', average=None) + tensor([[0., 0., 0.], + [0., 0., 1.]]) + """ + if validate_args: + _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) + preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) + return _specificity_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average) def _specificity_compute( @@ -70,15 +406,27 @@ def _specificity_compute( def specificity( preds: Tensor, target: Tensor, - average: Optional[str] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", mdmc_average: Optional[str] = None, ignore_index: Optional[int] = None, num_classes: Optional[int] = None, threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, ) -> Tensor: - r"""Computes `Specificity`_ + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Specificity`_ .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} @@ -177,7 +525,33 @@ def specificity( >>> specificity(preds, target, average='micro') tensor(0.6250) """ - + if task is not None: + assert multidim_average is not None + if task == "binary": + return binary_specificity(preds, target, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_specificity( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_specificity( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + ) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) allowed_average = ("micro", "macro", "weighted", "samples", "none", None) if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index b3cb7786e49..fbd12823073 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -15,9 +15,800 @@ import torch from torch import Tensor, tensor +from typing_extensions import Literal -from torchmetrics.utilities.checks import _input_format_classification +from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification +from torchmetrics.utilities.data import _bincount, _movedim, select_topk from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn + + +def _binary_stat_scores_arg_validation( + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, +) -> None: + """Validate non tensor input. + + - ``threshold`` has to be a float in the [0,1] range + - ``multidim_average`` has to be either "global" or "samplewise" + - ``ignore_index`` has to be None or int + """ + if not (isinstance(threshold, float) and (0 <= threshold <= 1)): + raise ValueError(f"Expected argument `threshold` to be a float in the [0,1] range, but got {threshold}.") + allowed_multidim_average = ("global", "samplewise") + if multidim_average not in allowed_multidim_average: + raise ValueError( + f"Expected argument `multidim_average` to be one of {allowed_multidim_average}, but got {multidim_average}" + ) + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + + +def _binary_stat_scores_tensor_validation( + preds: Tensor, + target: Tensor, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, +) -> None: + """Validate tensor input. + + - tensors have to be of same shape + - all values in target tensor that are not ignored have to be in {0, 1} + - if pred tensor is not floating point, then all values also have to be in {0, 1} + - if ``multidim_average`` is set to ``samplewise`` preds tensor needs to be atleast 2 dimensional + """ + # Check that they have same shape + _check_same_shape(preds, target) + + # Check that target only contains [0,1] values or value in ignore_index + unique_values = torch.unique(target) + if ignore_index is None: + check = torch.any((unique_values != 0) & (unique_values != 1)) + else: + check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) + if check: + raise RuntimeError( + f"Detected the following values in `target`: {unique_values} but expected only" + f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." + ) + + # If preds is label tensor, also check that it only contains [0,1] values + if not preds.is_floating_point(): + unique_values = torch.unique(preds) + if torch.any((unique_values != 0) & (unique_values != 1)): + raise RuntimeError( + f"Detected the following values in `preds`: {unique_values} but expected only" + " the following values [0,1] since `preds` is a label tensor." + ) + + if multidim_average != "global" and preds.ndim < 2: + raise ValueError("Expected input to be atleast 2D when multidim_average is set to `samplewise`") + + +def _binary_stat_scores_format( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + ignore_index: Optional[int] = None, +) -> Tuple[Tensor, Tensor]: + """Convert all input to label format. + + - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range + - If preds tensor is floating point, thresholds afterwards + - Mask all datapoints that should be ignored with negative values + """ + if preds.is_floating_point(): + if not torch.all((0 <= preds) * (preds <= 1)): + # preds is logits, convert with sigmoid + preds = preds.sigmoid() + preds = preds > threshold + + preds = preds.reshape(preds.shape[0], -1) + target = target.reshape(target.shape[0], -1) + + if ignore_index is not None: + idx = target == ignore_index + target = target.clone() + target[idx] = -1 + + return preds, target + + +def _binary_stat_scores_update( + preds: Tensor, + target: Tensor, + multidim_average: Literal["global", "samplewise"] = "global", +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Computes the statistics.""" + sum_dim = [0, 1] if multidim_average == "global" else 1 + tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze() + fn = ((target != preds) & (target == 1)).sum(sum_dim).squeeze() + fp = ((target != preds) & (target == 0)).sum(sum_dim).squeeze() + tn = ((target == preds) & (target == 0)).sum(sum_dim).squeeze() + return tp, fp, tn, fn + + +def _binary_stat_scores_compute( + tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: Literal["global", "samplewise"] = "global" +) -> Tensor: + """Stack statistics and compute support also.""" + return torch.stack([tp, fp, tn, fn, tp + fn], dim=0 if multidim_average == "global" else 1).squeeze() + + +def binary_stat_scores( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for binary tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape + depends on the ``multidim_average`` parameter: + + - If ``multidim_average`` is set to ``global``, the shape will be ``(5,)`` + - If ``multidim_average`` is set to ``samplewise``, the shape will be ``(N, 5)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import binary_stat_scores + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> binary_stat_scores(preds, target) + tensor([2, 1, 2, 1, 3]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import binary_stat_scores + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_stat_scores(preds, target) + tensor([2, 1, 2, 1, 3]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import binary_stat_scores + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> binary_stat_scores(preds, target, multidim_average='samplewise') + tensor([[2, 3, 0, 1, 3], + [0, 2, 1, 3, 3]]) + """ + if validate_args: + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) + return _binary_stat_scores_compute(tp, fp, tn, fn, multidim_average) + + +def _multiclass_stat_scores_arg_validation( + num_classes: int, + top_k: int = 1, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, +) -> None: + """Validate non tensor input. + + - ``num_classes`` has to be a int larger than 1 + - ``top_k`` has to be an int larger than 0 but no larger than number of classes + - ``average`` has to be "micro" | "macro" | "weighted" | "none" + - ``multidim_average`` has to be either "global" or "samplewise" + - ``ignore_index`` has to be None or int + """ + if not isinstance(num_classes, int) or num_classes < 2: + raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") + if not isinstance(top_k, int) and top_k < 1: + raise ValueError(f"Expected argument `top_k` to be an integer larger than or equal to 1, but got {top_k}") + if top_k > num_classes: + raise ValueError( + f"Expected argument `top_k` to be smaller or equal to `num_classes` but got {top_k} and {num_classes}" + ) + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average}, but got {average}") + allowed_multidim_average = ("global", "samplewise") + if multidim_average not in allowed_multidim_average: + raise ValueError( + f"Expected argument `multidim_average` to be one of {allowed_multidim_average}, but got {multidim_average}" + ) + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + + +def _multiclass_stat_scores_tensor_validation( + preds: Tensor, + target: Tensor, + num_classes: int, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, +) -> None: + """Validate tensor input. + + - if target has one more dimension than preds, then all dimensions except for preds.shape[1] should match + exactly. preds.shape[1] should have size equal to number of classes + - if preds and target have same number of dims, then all dimensions should match + - if ``multidim_average`` is set to ``samplewise`` preds tensor needs to be atleast 2 dimensional in the + int case and 3 dimensional in the float case + - all values in target tensor that are not ignored have to be {0, ..., num_classes - 1} + - if pred tensor is not floating point, then all values also have to be in {0, ..., num_classes - 1} + """ + if preds.ndim == target.ndim + 1: + if not preds.is_floating_point(): + raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") + if preds.shape[1] != num_classes: + raise ValueError( + "If `preds` have one dimension more than `target`, `preds.shape[1]` should be" + " equal to number of classes." + ) + if preds.shape[2:] != target.shape[1:]: + raise ValueError( + "If `preds` have one dimension more than `target`, the shape of `preds` should be" + " (N, C, ...), and the shape of `target` should be (N, ...)." + ) + if multidim_average != "global" and preds.ndim < 3: + raise ValueError( + "If `preds` have one dimension more than `target`, the shape of `preds` should " + " atleast 3D when multidim_average is set to `samplewise`" + ) + + elif preds.ndim == target.ndim: + if preds.shape != target.shape: + raise ValueError( + "The `preds` and `target` should have the same shape,", + f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", + ) + if multidim_average != "global" and preds.ndim < 2: + raise ValueError( + "When `preds` and `target` have the same shape, the shape of `preds` should " + " atleast 2D when multidim_average is set to `samplewise`" + ) + else: + raise ValueError( + "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)" + " and `preds` should be (N, C, ...)." + ) + + num_unique_values = len(torch.unique(target)) + if ignore_index is None: + check = num_unique_values > num_classes + else: + check = num_unique_values > num_classes + 1 + if check: + raise RuntimeError( + "Detected more unique values in `target` than `num_classes`. Expected only " + f"{num_classes if ignore_index is None else num_classes + 1} but found" + f"{num_unique_values} in `target`." + ) + + if not preds.is_floating_point(): + unique_values = torch.unique(preds) + if len(unique_values) > num_classes: + raise RuntimeError( + "Detected more unique values in `preds` than `num_classes`. Expected only " + f"{num_classes} but found {len(unique_values)} in `preds`." + ) + + +def _multiclass_stat_scores_format( + preds: Tensor, + target: Tensor, + top_k: int = 1, +) -> Tuple[Tensor, Tensor]: + """Convert all input to label format except if ``top_k`` is not 1. + + - Applies argmax if preds have one more dimension than target + - Flattens additional dimensions + """ + # Apply argmax if we have one more dimension + if preds.ndim == target.ndim + 1 and top_k == 1: + preds = preds.argmax(dim=1) + if top_k != 1: + preds = preds.reshape(*preds.shape[:2], -1) + else: + preds = preds.reshape(preds.shape[0], -1) + target = target.reshape(target.shape[0], -1) + return preds, target + + +def _multiclass_stat_scores_update( + preds: Tensor, + target: Tensor, + num_classes: int, + top_k: int = 1, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Computes the statistics. + + - If ``multidim_average`` is equal to samplewise or ``top_k`` is not 1, we transform both preds and + target into one hot format. + - Else we calculate statistics by first calculating the confusion matrix and afterwards deriving the + statistics from that + - Remove all datapoints that should be ignored. Depending on if ``ignore_index`` is in the set of labels + or outside we have do use different augmentation stategies when one hot encoding. + """ + if multidim_average == "samplewise" or top_k != 1: + ignore_in = 0 <= ignore_index <= num_classes - 1 if ignore_index is not None else None + if ignore_index is not None and not ignore_in: + preds = preds.clone() + target = target.clone() + idx = target == ignore_index + preds[idx] = num_classes + target[idx] = num_classes + + if top_k > 1: + preds_oh = _movedim(select_topk(preds, topk=top_k, dim=1), 1, -1) + else: + preds_oh = torch.nn.functional.one_hot( + preds, num_classes + 1 if ignore_index is not None and not ignore_in else num_classes + ) + target_oh = torch.nn.functional.one_hot( + target, num_classes + 1 if ignore_index is not None and not ignore_in else num_classes + ) + if ignore_index is not None: + if 0 <= ignore_index <= num_classes - 1: + target_oh[target == ignore_index, :] = -1 + else: + preds_oh = preds_oh[..., :-1] + target_oh = target_oh[..., :-1] + target_oh[target == num_classes, :] = -1 + sum_dim = [0, 1] if multidim_average == "global" else [1] + tp = ((target_oh == preds_oh) & (target_oh == 1)).sum(sum_dim) + fn = ((target_oh != preds_oh) & (target_oh == 1)).sum(sum_dim) + fp = ((target_oh != preds_oh) & (target_oh == 0)).sum(sum_dim) + tn = ((target_oh == preds_oh) & (target_oh == 0)).sum(sum_dim) + return tp, fp, tn, fn + else: + preds = preds.flatten() + target = target.flatten() + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + unique_mapping = (target * num_classes + preds).to(torch.long) + bins = _bincount(unique_mapping, minlength=num_classes**2) + confmat = bins.reshape(num_classes, num_classes) + tp = confmat.diag() + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + return tp, fp, tn, fn + + +def _multiclass_stat_scores_compute( + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", +) -> Tensor: + """Stack statistics and compute support also. + + Applies average strategy afterwards. + """ + res = torch.stack([tp, fp, tn, fn, tp + fn], dim=-1) + sum_dim = 0 if multidim_average == "global" else 1 + if average == "micro": + return res.sum(sum_dim) + elif average == "macro": + return res.float().mean(sum_dim) + elif average == "weighted": + weight = tp + fn + if multidim_average == "global": + return (res * (weight / weight.sum()).reshape(*weight.shape, 1)).sum(sum_dim) + else: + return (res * (weight / weight.sum(-1, keepdim=True)).reshape(*weight.shape, 1)).sum(sum_dim) + elif average is None or average == "none": + return res + + +def multiclass_stat_scores( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + top_k: int = 1, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for multiclass tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape + depends on ``average`` and ``multidim_average`` parameters: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` + - If ``average=None/'none'``, the shape will be ``(C, 5)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` + - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multiclass_stat_scores + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_stat_scores(preds, target, num_classes=3, average='micro') + tensor([3, 1, 7, 1, 4]) + >>> multiclass_stat_scores(preds, target, num_classes=3, average=None) + tensor([[1, 0, 2, 1, 2], + [1, 1, 2, 0, 1], + [1, 0, 3, 0, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multiclass_stat_scores + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> multiclass_stat_scores(preds, target, num_classes=3, average='micro') + tensor([3, 1, 7, 1, 4]) + >>> multiclass_stat_scores(preds, target, num_classes=3, average=None) + tensor([[1, 0, 2, 1, 2], + [1, 1, 2, 0, 1], + [1, 0, 3, 0, 1]]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multiclass_stat_scores + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average='micro') + tensor([[3, 3, 9, 3, 6], + [2, 4, 8, 4, 6]]) + >>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average=None) + tensor([[[2, 1, 3, 0, 2], + [0, 1, 3, 2, 2], + [1, 1, 3, 1, 2]], + [[0, 1, 4, 1, 1], + [1, 1, 2, 2, 3], + [1, 2, 2, 1, 2]]]) + """ + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, top_k) + tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, num_classes, top_k, multidim_average, ignore_index) + return _multiclass_stat_scores_compute(tp, fp, tn, fn, average, multidim_average) + + +def _multilabel_stat_scores_arg_validation( + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, +) -> None: + """Validate non tensor input. + + - ``num_labels`` should be an int larger than 1 + - ``threshold`` has to be a float in the [0,1] range + - ``average`` has to be "micro" | "macro" | "weighted" | "none" + - ``multidim_average`` has to be either "global" or "samplewise" + - ``ignore_index`` has to be None or int + """ + if not isinstance(num_labels, int) or num_labels < 2: + raise ValueError(f"Expected argument `num_labels` to be an integer larger than 1, but got {num_labels}") + if not (isinstance(threshold, float) and (0 <= threshold <= 1)): + raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.") + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average}, but got {average}") + allowed_multidim_average = ("global", "samplewise") + if multidim_average not in allowed_multidim_average: + raise ValueError( + f"Expected argument `multidim_average` to be one of {allowed_multidim_average}, but got {multidim_average}" + ) + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + + +def _multilabel_stat_scores_tensor_validation( + preds: Tensor, + target: Tensor, + num_labels: int, + multidim_average: str, + ignore_index: Optional[int] = None, +) -> None: + """Validate tensor input. + + - tensors have to be of same shape + - the second dimension of both tensors need to be equal to the number of labels + - all values in target tensor that are not ignored have to be in {0, 1} + - if pred tensor is not floating point, then all values also have to be in {0, 1} + - if ``multidim_average`` is set to ``samplewise`` preds tensor needs to be atleast 3 dimensional + """ + # Check that they have same shape + _check_same_shape(preds, target) + + if preds.shape[1] != num_labels: + raise ValueError( + "Expected both `target.shape[1]` and `preds.shape[1]` to be equal to the number of labels" + f" but got {preds.shape[1]} and expected {num_labels}" + ) + + # Check that target only contains [0,1] values or value in ignore_index + unique_values = torch.unique(target) + if ignore_index is None: + check = torch.any((unique_values != 0) & (unique_values != 1)) + else: + check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) + if check: + raise RuntimeError( + f"Detected the following values in `target`: {unique_values} but expected only" + f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." + ) + + # If preds is label tensor, also check that it only contains [0,1] values + if not preds.is_floating_point(): + unique_values = torch.unique(preds) + if torch.any((unique_values != 0) & (unique_values != 1)): + raise RuntimeError( + f"Detected the following values in `preds`: {unique_values} but expected only" + " the following values [0,1] since preds is a label tensor." + ) + + if multidim_average != "global" and preds.ndim < 3: + raise ValueError("Expected input to be atleast 3D when multidim_average is set to `samplewise`") + + +def _multilabel_stat_scores_format( + preds: Tensor, target: Tensor, num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None +) -> Tuple[Tensor, Tensor]: + """Convert all input to label format. + + - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range + - If preds tensor is floating point, thresholds afterwards + - Mask all elements that should be ignored with negative numbers for later filtration + """ + if preds.is_floating_point(): + if not torch.all((0 <= preds) * (preds <= 1)): + preds = preds.sigmoid() + preds = preds > threshold + preds = preds.reshape(*preds.shape[:2], -1) + target = target.reshape(*target.shape[:2], -1) + + if ignore_index is not None: + idx = target == ignore_index + target = target.clone() + target[idx] = -1 + + return preds, target + + +def _multilabel_stat_scores_update( + preds: Tensor, target: Tensor, multidim_average: Literal["global", "samplewise"] = "global" +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Computes the statistics.""" + sum_dim = [0, -1] if multidim_average == "global" else [-1] + tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze() + fn = ((target != preds) & (target == 1)).sum(sum_dim).squeeze() + fp = ((target != preds) & (target == 0)).sum(sum_dim).squeeze() + tn = ((target == preds) & (target == 0)).sum(sum_dim).squeeze() + return tp, fp, tn, fn + + +def _multilabel_stat_scores_compute( + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", +) -> Tensor: + """Stack statistics and compute support also. + + Applies average strategy afterwards. + """ + res = torch.stack([tp, fp, tn, fn, tp + fn], dim=-1) + sum_dim = 0 if multidim_average == "global" else 1 + if average == "micro": + return res.sum(sum_dim) + elif average == "macro": + return res.float().mean(sum_dim) + elif average == "weighted": + w = tp + fn + return (res * (w / w.sum()).reshape(*w.shape, 1)).sum(sum_dim) + elif average is None or average == "none": + return res + + +def multilabel_stat_scores( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for multilabel tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape + depends on ``average`` and ``multidim_average`` parameters: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` + - If ``average=None/'none'``, the shape will be ``(C, 5)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` + - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional.classification import multilabel_stat_scores + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_stat_scores(preds, target, num_labels=3, average='micro') + tensor([2, 1, 2, 1, 3]) + >>> multilabel_stat_scores(preds, target, num_labels=3, average=None) + tensor([[1, 0, 1, 0, 1], + [0, 0, 1, 1, 1], + [1, 1, 0, 0, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multilabel_stat_scores + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_stat_scores(preds, target, num_labels=3, average='micro') + tensor([2, 1, 2, 1, 3]) + >>> multilabel_stat_scores(preds, target, num_labels=3, average=None) + tensor([[1, 0, 1, 0, 1], + [0, 0, 1, 1, 1], + [1, 1, 0, 0, 1]]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multilabel_stat_scores + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average='micro') + tensor([[2, 3, 0, 1, 3], + [0, 2, 1, 3, 3]]) + >>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average=None) + tensor([[[1, 1, 0, 0, 1], + [1, 1, 0, 0, 1], + [0, 1, 0, 1, 1]], + [[0, 0, 0, 2, 2], + [0, 2, 0, 0, 0], + [0, 0, 1, 1, 1]]]) + + """ + if validate_args: + _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) + preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) + return _multilabel_stat_scores_compute(tp, fp, tn, fn, average, multidim_average) def _del_column(data: Tensor, idx: int) -> Tensor: @@ -299,8 +1090,22 @@ def stat_scores( threshold: float = 0.5, multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, ) -> Tensor: - r"""Computes the number of true positives, false positives, true negatives, false negatives. + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + + Computes the number of true positives, false positives, true negatives, false negatives. Related to `Type I and Type II errors`_ and the `confusion matrix`_. The reduction method (how the statistics are aggregated) is controlled by the @@ -416,6 +1221,33 @@ def stat_scores( tensor([2, 2, 6, 2, 4]) """ + if task is not None: + assert multidim_average is not None + if task == "binary": + return binary_stat_scores(preds, target, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_stat_scores( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + ) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_stat_scores( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + ) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) if reduce not in ["micro", "macro", "samples"]: raise ValueError(f"The `reduce` {reduce} is not valid.") diff --git a/src/torchmetrics/functional/regression/__init__.py b/src/torchmetrics/functional/regression/__init__.py index 31ca5cb6f88..01c4af475b7 100644 --- a/src/torchmetrics/functional/regression/__init__.py +++ b/src/torchmetrics/functional/regression/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from torchmetrics.functional.regression.cosine_similarity import cosine_similarity # noqa: F401 from torchmetrics.functional.regression.explained_variance import explained_variance # noqa: F401 +from torchmetrics.functional.regression.kl_divergence import kl_divergence # noqa: F401 from torchmetrics.functional.regression.log_mse import mean_squared_log_error # noqa: F401 from torchmetrics.functional.regression.mae import mean_absolute_error # noqa: F401 from torchmetrics.functional.regression.mape import mean_absolute_percentage_error # noqa: F401 diff --git a/src/torchmetrics/functional/regression/kl_divergence.py b/src/torchmetrics/functional/regression/kl_divergence.py new file mode 100644 index 00000000000..4a4da4b8b07 --- /dev/null +++ b/src/torchmetrics/functional/regression/kl_divergence.py @@ -0,0 +1,112 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.compute import _safe_xlogy + + +def _kld_update(p: Tensor, q: Tensor, log_prob: bool) -> Tuple[Tensor, int]: + """Updates and returns KL divergence scores for each observation and the total number of observations. Checks + same shape and 2D nature of the input tensors else raises ValueError. + + Args: + p: data distribution with shape ``[N, d]`` + q: prior or approximate distribution with shape ``[N, d]`` + log_prob: bool indicating if input is log-probabilities or probabilities. If given as probabilities, + will normalize to make sure the distributes sum to 1 + """ + _check_same_shape(p, q) + if p.ndim != 2 or q.ndim != 2: + raise ValueError(f"Expected both p and q distribution to be 2D but got {p.ndim} and {q.ndim} respectively") + + total = p.shape[0] + if log_prob: + measures = torch.sum(p.exp() * (p - q), axis=-1) + else: + p = p / p.sum(axis=-1, keepdim=True) + q = q / q.sum(axis=-1, keepdim=True) + measures = _safe_xlogy(p, p / q).sum(axis=-1) + + return measures, total + + +def _kld_compute(measures: Tensor, total: Tensor, reduction: Literal["mean", "sum", "none", None] = "mean") -> Tensor: + """Computes the KL divergenece based on the type of reduction. + + Args: + measures: Tensor of KL divergence scores for each observation + total: Number of observations + reduction: + Determines how to reduce over the ``N``/batch dimension: + + - ``'mean'`` [default]: Averages score across samples + - ``'sum'``: Sum score across samples + - ``'none'`` or ``None``: Returns score per sample + + Example: + >>> p = torch.tensor([[0.36, 0.48, 0.16]]) + >>> q = torch.tensor([[1/3, 1/3, 1/3]]) + >>> measures, total = _kld_update(p, q, log_prob=False) + >>> _kld_compute(measures, total) + tensor(0.0853) + """ + + if reduction == "sum": + return measures.sum() + if reduction == "mean": + return measures.sum() / total + if reduction is None or reduction == "none": + return measures + return measures / total + + +def kl_divergence( + p: Tensor, q: Tensor, log_prob: bool = False, reduction: Literal["mean", "sum", "none", None] = "mean" +) -> Tensor: + r"""Computes `KL divergence`_ + + .. math:: + D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q{x}} + + Where :math:`P` and :math:`Q` are probability distributions where :math:`P` usually represents a distribution + over data and :math:`Q` is often a prior or approximation of :math:`P`. It should be noted that the KL divergence + is a non-symetrical metric i.e. :math:`D_{KL}(P||Q) \neq D_{KL}(Q||P)`. + + Args: + p: data distribution with shape ``[N, d]`` + q: prior or approximate distribution with shape ``[N, d]`` + log_prob: bool indicating if input is log-probabilities or probabilities. If given as probabilities, + will normalize to make sure the distributes sum to 1 + reduction: + Determines how to reduce over the ``N``/batch dimension: + + - ``'mean'`` [default]: Averages score across samples + - ``'sum'``: Sum score across samples + - ``'none'`` or ``None``: Returns score per sample + + Example: + >>> import torch + >>> p = torch.tensor([[0.36, 0.48, 0.16]]) + >>> q = torch.tensor([[1/3, 1/3, 1/3]]) + >>> kl_divergence(p, q) + tensor(0.0853) + """ + measures, total = _kld_update(p, q, log_prob) + return _kld_compute(measures, total, reduction) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 2dd96e5e8f0..32e1f0c4312 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -337,7 +337,7 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None: if reduce_fn == dim_zero_sum: reduced = global_state + local_state elif reduce_fn == dim_zero_mean: - reduced = ((self._update_count - 1) * global_state + local_state) / self._update_count + reduced = ((self._update_count - 1) * global_state + local_state).float() / self._update_count elif reduce_fn == dim_zero_max: reduced = torch.max(global_state, local_state) elif reduce_fn == dim_zero_min: @@ -392,10 +392,9 @@ def wrapped_func(*args: Any, **kwargs: Any) -> None: except RuntimeError as err: if "Expected all tensors to be on" in str(err): raise RuntimeError( - "Encountered different devices in metric calculation" - " (see stacktrace for details)." - "This could be due to the metric class not being on the same device as input." - f"Instead of `metric={self.__class__.__name__}(...)` try to do" + "Encountered different devices in metric calculation (see stacktrace for details)." + " This could be due to the metric class not being on the same device as input." + f" Instead of `metric={self.__class__.__name__}(...)` try to do" f" `metric={self.__class__.__name__}(...).to(device)` where" " device corresponds to the device of the input." ) from err diff --git a/src/torchmetrics/regression/__init__.py b/src/torchmetrics/regression/__init__.py index 4a6f666a4dd..5515e38019c 100644 --- a/src/torchmetrics/regression/__init__.py +++ b/src/torchmetrics/regression/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from torchmetrics.regression.cosine_similarity import CosineSimilarity # noqa: F401 from torchmetrics.regression.explained_variance import ExplainedVariance # noqa: F401 +from torchmetrics.regression.kl_divergence import KLDivergence # noqa: F401 from torchmetrics.regression.log_mse import MeanSquaredLogError # noqa: F401 from torchmetrics.regression.mae import MeanAbsoluteError # noqa: F401 from torchmetrics.regression.mape import MeanAbsolutePercentageError # noqa: F401 diff --git a/src/torchmetrics/regression/kl_divergence.py b/src/torchmetrics/regression/kl_divergence.py new file mode 100644 index 00000000000..9764a407f52 --- /dev/null +++ b/src/torchmetrics/regression/kl_divergence.py @@ -0,0 +1,105 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.regression.kl_divergence import _kld_compute, _kld_update +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat + + +class KLDivergence(Metric): + r"""Computes the `KL divergence`_: + + .. math:: + D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q{x}} + + Where :math:`P` and :math:`Q` are probability distributions where :math:`P` usually represents a distribution + over data and :math:`Q` is often a prior or approximation of :math:`P`. It should be noted that the KL divergence + is a non-symetrical metric i.e. :math:`D_{KL}(P||Q) \neq D_{KL}(Q||P)`. + + Args: + p: data distribution with shape ``[N, d]`` + q: prior or approximate distribution with shape ``[N, d]`` + log_prob: bool indicating if input is log-probabilities or probabilities. If given as probabilities, + will normalize to make sure the distributes sum to 1. + reduction: + Determines how to reduce over the ``N``/batch dimension: + + - ``'mean'`` [default]: Averages score across samples + - ``'sum'``: Sum score across samples + - ``'none'`` or ``None``: Returns score per sample + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + + Raises: + TypeError: + If ``log_prob`` is not an ``bool``. + ValueError: + If ``reduction`` is not one of ``'mean'``, ``'sum'``, ``'none'`` or ``None``. + + .. note:: + Half precision is only support on GPU for this metric + + Example: + >>> import torch + >>> from torchmetrics.functional import kl_divergence + >>> p = torch.tensor([[0.36, 0.48, 0.16]]) + >>> q = torch.tensor([[1/3, 1/3, 1/3]]) + >>> kl_divergence(p, q) + tensor(0.0853) + + """ + is_differentiable: bool = True + higher_is_better: bool = False + full_state_update: bool = False + total: Tensor + + def __init__( + self, + log_prob: bool = False, + reduction: Literal["mean", "sum", "none", None] = "mean", + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if not isinstance(log_prob, bool): + raise TypeError(f"Expected argument `log_prob` to be bool but got {log_prob}") + self.log_prob = log_prob + + allowed_reduction = ["mean", "sum", "none", None] + if reduction not in allowed_reduction: + raise ValueError(f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}") + self.reduction = reduction + + if self.reduction in ["mean", "sum"]: + self.add_state("measures", torch.tensor(0.0), dist_reduce_fx="sum") + else: + self.add_state("measures", [], dist_reduce_fx="cat") + self.add_state("total", torch.tensor(0), dist_reduce_fx="sum") + + def update(self, p: Tensor, q: Tensor) -> None: # type: ignore + measures, total = _kld_update(p, q, self.log_prob) + if self.reduction is None or self.reduction == "none": + self.measures.append(measures) + else: + self.measures += measures.sum() + self.total += total + + def compute(self) -> Tensor: + measures = dim_zero_cat(self.measures) if self.reduction is None or self.reduction == "none" else self.measures + return _kld_compute(measures, self.total, self.reduction) diff --git a/src/torchmetrics/utilities/checks.py b/src/torchmetrics/utilities/checks.py index 54d15b3d455..dfcb2922147 100644 --- a/src/torchmetrics/utilities/checks.py +++ b/src/torchmetrics/utilities/checks.py @@ -32,7 +32,9 @@ def _check_for_empty_tensors(preds: Tensor, target: Tensor) -> bool: def _check_same_shape(preds: Tensor, target: Tensor) -> None: """Check that predictions and target have the same shape, else raise error.""" if preds.shape != target.shape: - raise RuntimeError("Predictions and targets are expected to have the same shape") + raise RuntimeError( + f"Predictions and targets are expected to have the same shape, but got {preds.shape} and {target.shape}." + ) def _basic_input_validation( diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index f496baff818..cc9dd06af7b 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -11,9 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple + import torch from torch import Tensor +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_9 + def _safe_matmul(x: Tensor, y: Tensor) -> Tensor: """Safe calculation of matrix multiplication. @@ -38,3 +42,74 @@ def _safe_xlogy(x: Tensor, y: Tensor) -> Tensor: res = x * torch.log(y) res[x == 0] = 0.0 return res + + +def _safe_divide(num: Tensor, denom: Tensor) -> Tensor: + """Safe division, by preventing division by zero. + + Additionally casts to float if input is not already to secure backwards compatibility. + """ + denom[denom == 0.0] = 1 + num = num if num.is_floating_point() else num.float() + denom = denom if denom.is_floating_point() else denom.float() + return num / denom + + +def _auc_format_inputs(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: + """Checks that auc input is correct.""" + x = x.squeeze() if x.ndim > 1 else x + y = y.squeeze() if y.ndim > 1 else y + + if x.ndim > 1 or y.ndim > 1: + raise ValueError( + f"Expected both `x` and `y` tensor to be 1d, but got tensors with dimension {x.ndim} and {y.ndim}" + ) + if x.numel() != y.numel(): + raise ValueError( + f"Expected the same number of elements in `x` and `y` tensor but received {x.numel()} and {y.numel()}" + ) + return x, y + + +def _auc_compute_without_check(x: Tensor, y: Tensor, direction: float, axis: int = -1) -> Tensor: + """Computes area under the curve using the trapezoidal rule. + + Assumes increasing or decreasing order of `x`. + """ + with torch.no_grad(): + auc_: Tensor = torch.trapz(y, x, dim=axis) * direction + return auc_ + + +def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor: + with torch.no_grad(): + if reorder: + x, x_idx = torch.sort(x, stable=True) if _TORCH_GREATER_EQUAL_1_9 else torch.sort(x) + y = y[x_idx] + + dx = x[1:] - x[:-1] + if (dx < 0).any(): + if (dx <= 0).all(): + direction = -1.0 + else: + raise ValueError( + "The `x` tensor is neither increasing or decreasing. Try setting the reorder argument to `True`." + ) + else: + direction = 1.0 + return _auc_compute_without_check(x, y, direction) + + +def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor: + """Computes Area Under the Curve (AUC) using the trapezoidal rule. + + Args: + x: x-coordinates, must be either increasing or decreasing + y: y-coordinates + reorder: if True, will reorder the arrays to make it either increasing or decreasing + + Return: + Tensor containing AUC score + """ + x, y = _auc_format_inputs(x, y) + return _auc_compute(x, y, reorder=reorder) diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index f9cd329fb11..92df978e96c 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -261,15 +261,15 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor: Returns: Number of occurrences for each unique element in x """ - if x.is_cuda and deterministic() or _TORCH_GREATER_EQUAL_1_12 and x.is_mps: - if minlength is None: - minlength = len(torch.unique(x)) + if minlength is None: + minlength = len(torch.unique(x)) + if deterministic() or _TORCH_GREATER_EQUAL_1_12 and x.is_mps: output = torch.zeros(minlength, device=x.device, dtype=torch.long) for i in range(minlength): output[i] = (x == i).sum() return output - else: - return torch.bincount(x, minlength=minlength) + z = torch.zeros(minlength, device=x.device, dtype=x.dtype) + return z.index_add_(0, x, torch.ones_like(x)) def allclose(tensor1: Tensor, tensor2: Tensor) -> bool: @@ -277,3 +277,11 @@ def allclose(tensor1: Tensor, tensor2: Tensor) -> bool: if tensor1.dtype != tensor2.dtype: tensor2 = tensor2.to(dtype=tensor1.dtype) return torch.allclose(tensor1, tensor2) + + +def _movedim(tensor: Tensor, dim1: int, dim2: int) -> tensor: + if _TORCH_GREATER_EQUAL_1_7: + return torch.movedim(tensor, dim1, dim2) + if dim2 >= 0: + dim2 += 1 + return tensor.unsqueeze(dim2).transpose(dim2, dim1).squeeze(dim1) diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index 8e0f890cd63..834b3ff63db 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -102,6 +102,7 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool] _TORCH_GREATER_EQUAL_1_6: Optional[bool] = _compare_version("torch", operator.ge, "1.6.0") _TORCH_GREATER_EQUAL_1_7: Optional[bool] = _compare_version("torch", operator.ge, "1.7.0") _TORCH_GREATER_EQUAL_1_8: Optional[bool] = _compare_version("torch", operator.ge, "1.8.0") +_TORCH_GREATER_EQUAL_1_9: Optional[bool] = _compare_version("torch", operator.ge, "1.9.0") _TORCH_GREATER_EQUAL_1_10: Optional[bool] = _compare_version("torch", operator.ge, "1.10.0") _TORCH_GREATER_EQUAL_1_11: Optional[bool] = _compare_version("torch", operator.ge, "1.11.0") _TORCH_GREATER_EQUAL_1_12: Optional[bool] = _compare_version("torch", operator.ge, "1.12.0") diff --git a/tests/unittests/classification/inputs.py b/tests/unittests/classification/inputs.py index ff88b452638..a17d525287a 100644 --- a/tests/unittests/classification/inputs.py +++ b/tests/unittests/classification/inputs.py @@ -13,13 +13,24 @@ # limitations under the License. from collections import namedtuple +import pytest import torch +from torch import Tensor from unittests.helpers import seed_all from unittests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES seed_all(1) + +def _inv_sigmoid(x: Tensor) -> Tensor: + return (x / (1 - x)).log() + + +def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor: + return torch.nn.functional.log_softmax(x, dim) + + Input = namedtuple("Input", ["preds", "target"]) _input_binary_prob = Input( @@ -60,6 +71,140 @@ target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), ) +_binary_cases = ( + pytest.param( + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + ), + id="input[single dim-labels]", + ), + pytest.param( + Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))), + id="input[single dim-probs]", + ), + pytest.param( + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + ), + id="input[single dim-logits]", + ), + pytest.param( + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + id="input[multi dim-labels]", + ), + pytest.param( + Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + id="input[multi dim-probs]", + ), + pytest.param( + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + id="input[multi dim-logits]", + ), +) + + +_multiclass_cases = ( + pytest.param( + Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + ), + id="input[single dim-labels]", + ), + pytest.param( + Input( + preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES).softmax(-1), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + ), + id="input[single dim-probs]", + ), + pytest.param( + Input( + preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), -1), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + ), + id="input[single dim-logits]", + ), + pytest.param( + Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + id="input[multi dim-labels]", + ), + pytest.param( + Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM).softmax(-2), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + id="input[multi dim-probs]", + ), + pytest.param( + Input( + preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), -2), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + id="input[multi dim-logits]", + ), +) + + +_multilabel_cases = ( + pytest.param( + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + ), + id="input[single dim-labels]", + ), + pytest.param( + Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + ), + id="input[single dim-probs]", + ), + pytest.param( + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + ), + id="input[single dim-logits]", + ), + pytest.param( + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + ), + id="input[multi dim-labels]", + ), + pytest.param( + Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + ), + id="input[multi dim-probs]", + ), + pytest.param( + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + ), + id="input[multi dim-logits]", + ), +) + # Generate edge multilabel edge case, where nothing matches (scores are undefined) __temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) __temp_target = abs(__temp_preds - 1) diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index 4377aed1b40..accaa395fad 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -11,440 +11,472 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import numpy as np import pytest import torch +from scipy.special import expit as sigmoid from sklearn.metrics import accuracy_score as sk_accuracy -from torch import tensor - -from torchmetrics import Accuracy -from torchmetrics.functional import accuracy -from torchmetrics.utilities.checks import _input_format_classification -from torchmetrics.utilities.enums import AverageMethod, DataType -from unittests.classification import MetricWrapper -from unittests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob -from unittests.classification.inputs import _input_multiclass as _input_mcls -from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel as _input_mlb -from unittests.classification.inputs import _input_multilabel_logits as _input_mlb_logits -from unittests.classification.inputs import _input_multilabel_multidim as _input_mlmd -from unittests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from unittests.classification.inputs import _negmetric_noneavg +from sklearn.metrics import confusion_matrix as sk_confusion_matrix + +from torchmetrics.classification.accuracy import BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy +from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy, multilabel_accuracy +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_accuracy(preds, target, subset_accuracy): - sk_preds, sk_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD) - sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - - if mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy: - sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1)) - sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2]) - elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy: - return np.all(sk_preds == sk_target, axis=(1, 2)).mean() - elif mode == DataType.MULTILABEL and not subset_accuracy: - sk_preds, sk_target = sk_preds.reshape(-1), sk_target.reshape(-1) - - return sk_accuracy(y_true=sk_target, y_pred=sk_preds) - - -@pytest.mark.parametrize( - "preds, target, subset_accuracy, mdmc_average", - [ - (_input_binary_logits.preds, _input_binary_logits.target, False, None), - (_input_binary_prob.preds, _input_binary_prob.target, False, None), - (_input_binary.preds, _input_binary.target, False, None), - (_input_mlb_prob.preds, _input_mlb_prob.target, True, None), - (_input_mlb_logits.preds, _input_mlb_logits.target, False, None), - (_input_mlb_prob.preds, _input_mlb_prob.target, False, None), - (_input_mlb.preds, _input_mlb.target, True, None), - (_input_mlb.preds, _input_mlb.target, False, "global"), - (_input_mcls_prob.preds, _input_mcls_prob.target, False, None), - (_input_mcls_logits.preds, _input_mcls_logits.target, False, None), - (_input_mcls.preds, _input_mcls.target, False, None), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, False, "global"), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, True, None), - (_input_mdmc.preds, _input_mdmc.target, False, "global"), - (_input_mdmc.preds, _input_mdmc.target, True, None), - (_input_mlmd_prob.preds, _input_mlmd_prob.target, True, None), - (_input_mlmd_prob.preds, _input_mlmd_prob.target, False, None), - (_input_mlmd.preds, _input_mlmd.target, True, None), - (_input_mlmd.preds, _input_mlmd.target, False, "global"), - ], -) -class TestAccuracies(MetricTester): +def _sk_accuracy(target, preds): + score = sk_accuracy(target, preds) + return score if not np.isnan(score) else 0.0 + + +def _sk_accuracy_binary(preds, target, ignore_index, multidim_average): + if multidim_average == "global": + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + else: + preds = preds.numpy() + target = target.numpy() + + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + + if multidim_average == "global": + target, preds = remove_ignore_index(target, preds, ignore_index) + return _sk_accuracy(target, preds) + else: + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + res.append(_sk_accuracy(true, pred)) + return np.stack(res) + + +@pytest.mark.parametrize("input", _binary_cases) +class TestBinaryAccuracy(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("ddp", [False, True]) - @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accuracy, mdmc_average): + def test_binary_accuracy(self, ddp, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=Accuracy, - sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy), - dist_sync_on_step=dist_sync_on_step, - metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy, "mdmc_average": mdmc_average}, + metric_class=BinaryAccuracy, + sk_metric=partial(_sk_accuracy_binary, ignore_index=ignore_index, multidim_average=multidim_average), + metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) - def test_accuracy_fn(self, preds, target, subset_accuracy, mdmc_average): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + def test_binary_accuracy_functional(self, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + self.run_functional_metric_test( - preds, - target, - metric_functional=accuracy, - sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy), - metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy}, + preds=preds, + target=target, + metric_functional=binary_accuracy, + sk_metric=partial(_sk_accuracy_binary, ignore_index=ignore_index, multidim_average=multidim_average), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + }, ) - def test_accuracy_differentiability(self, preds, target, subset_accuracy, mdmc_average): + def test_binary_accuracy_differentiability(self, input): + preds, target = input self.run_differentiability_test( preds=preds, target=target, - metric_module=Accuracy, - metric_functional=accuracy, - metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy, "mdmc_average": mdmc_average}, + metric_module=BinaryAccuracy, + metric_functional=binary_accuracy, + metric_args={"threshold": THRESHOLD}, ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_accuracy_half_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryAccuracy, + metric_functional=binary_accuracy, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) -_l1to4 = [0.1, 0.2, 0.3, 0.4] -_l1to4t3 = np.array([_l1to4, _l1to4, _l1to4]) -_l1to4t3_mcls = [_l1to4t3.T, _l1to4t3.T, _l1to4t3.T] - -# The preds in these examples always put highest probability on class 3, second highest on class 2, -# third highest on class 1, and lowest on class 0 -_topk_preds_mcls = tensor([_l1to4t3, _l1to4t3]).float() -_topk_target_mcls = tensor([[1, 2, 3], [2, 1, 0]]) - -# This is like for MC case, but one sample in each batch is sabotaged with 0 class prediction :) -_topk_preds_mdmc = tensor([_l1to4t3_mcls, _l1to4t3_mcls]).float() -_topk_target_mdmc = tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]]) - -# Multilabel -_ml_t1 = [0.8, 0.2, 0.8, 0.2] -_ml_t2 = [_ml_t1, _ml_t1] -_ml_ta2 = [[1, 0, 1, 1], [0, 1, 1, 0]] -_av_preds_ml = tensor([_ml_t2, _ml_t2]).float() -_av_target_ml = tensor([_ml_ta2, _ml_ta2]) - -# Inputs with negative target values to be ignored -Input = namedtuple("Input", ["preds", "target", "ignore_index", "result"]) -_binary_with_neg_tgt = Input( - preds=torch.tensor([0, 1, 0]), target=torch.tensor([0, 1, -1]), ignore_index=-1, result=torch.tensor(1.0) -) -_multiclass_logits_with_neg_tgt = Input( - preds=torch.tensor([[0.8, 0.1], [0.2, 0.7], [0.5, 0.5]]), - target=torch.tensor([0, 1, -1]), - ignore_index=-1, - result=torch.tensor(1.0), -) -_multidim_multiclass_with_neg_tgt = Input( - preds=torch.tensor([[0, 0], [1, 1], [0, 0]]), - target=torch.tensor([[0, 0], [-1, 1], [1, -1]]), - ignore_index=-1, - result=torch.tensor(0.75), -) -_multidim_multiclass_logits_with_neg_tgt = Input( - preds=torch.tensor([[[0.8, 0.7], [0.2, 0.4]], [[0.1, 0.2], [0.9, 0.8]], [[0.7, 0.9], [0.2, 0.4]]]), - target=torch.tensor([[0, 0], [-1, 1], [1, -1]]), - ignore_index=-1, - result=torch.tensor(0.75), -) - - -# Replace with a proper sk_metric test once sklearn 0.24 hits :) -@pytest.mark.parametrize( - "preds, target, exp_result, k, subset_accuracy", - [ - (_topk_preds_mcls, _topk_target_mcls, 1 / 6, 1, False), - (_topk_preds_mcls, _topk_target_mcls, 3 / 6, 2, False), - (_topk_preds_mcls, _topk_target_mcls, 5 / 6, 3, False), - (_topk_preds_mcls, _topk_target_mcls, 1 / 6, 1, True), - (_topk_preds_mcls, _topk_target_mcls, 3 / 6, 2, True), - (_topk_preds_mcls, _topk_target_mcls, 5 / 6, 3, True), - (_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, False), - (_topk_preds_mdmc, _topk_target_mdmc, 8 / 18, 2, False), - (_topk_preds_mdmc, _topk_target_mdmc, 13 / 18, 3, False), - (_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, True), - (_topk_preds_mdmc, _topk_target_mdmc, 2 / 6, 2, True), - (_topk_preds_mdmc, _topk_target_mdmc, 3 / 6, 3, True), - (_av_preds_ml, _av_target_ml, 5 / 8, None, False), - (_av_preds_ml, _av_target_ml, 0, None, True), - ], -) -def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy): - topk = Accuracy(top_k=k, subset_accuracy=subset_accuracy, mdmc_average="global") - - for batch in range(preds.shape[0]): - topk(preds[batch], target[batch]) - - assert topk.compute() == exp_result - - # Test functional - total_samples = target.shape[0] * target.shape[1] - - preds = preds.view(total_samples, 4, -1) - target = target.view(total_samples, -1) - - assert accuracy(preds, target, top_k=k, subset_accuracy=subset_accuracy) == exp_result - - -# Only MC and MDMC with probs input type should be accepted for top_k -@pytest.mark.parametrize( - "preds, target", - [ - (_input_binary_prob.preds, _input_binary_prob.target), - (_input_binary.preds, _input_binary.target), - (_input_mlb_prob.preds, _input_mlb_prob.target), - (_input_mlb.preds, _input_mlb.target), - (_input_mcls.preds, _input_mcls.target), - (_input_mdmc.preds, _input_mdmc.target), - (_input_mlmd_prob.preds, _input_mlmd_prob.target), - (_input_mlmd.preds, _input_mlmd.target), - ], -) -def test_topk_accuracy_wrong_input_types(preds, target): - topk = Accuracy(top_k=1) - - with pytest.raises(ValueError): - topk(preds[0], target[0]) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_accuracy_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryAccuracy, + metric_functional=binary_accuracy, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) - with pytest.raises(ValueError): - accuracy(preds[0], target[0], top_k=1) +def _sk_accuracy_multiclass(preds, target, ignore_index, multidim_average, average): + if preds.ndim == target.ndim + 1: + preds = torch.argmax(preds, 1) + if multidim_average == "global": + preds = preds.numpy().flatten() + target = target.numpy().flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + if average == "micro": + return _sk_accuracy(target, preds) + confmat = sk_confusion_matrix(target, preds, labels=list(range(NUM_CLASSES))) + acc_per_class = confmat.diagonal() / confmat.sum(axis=1) + acc_per_class[np.isnan(acc_per_class)] = 0.0 + if average == "macro": + return acc_per_class.mean() + elif average == "weighted": + weights = confmat.sum(1) + return ((weights * acc_per_class) / weights.sum()).sum() + else: + return acc_per_class + else: + preds = preds.numpy() + target = target.numpy() + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + if average == "micro": + res.append(_sk_accuracy(true, pred)) + else: + confmat = sk_confusion_matrix(true, pred, labels=list(range(NUM_CLASSES))) + acc_per_class = confmat.diagonal() / confmat.sum(axis=1) + acc_per_class[np.isnan(acc_per_class)] = 0.0 + if average == "macro": + res.append(acc_per_class.mean()) + elif average == "weighted": + weights = confmat.sum(1) + score = ((weights * acc_per_class) / weights.sum()).sum() + res.append(0.0 if np.isnan(score) else score) + else: + res.append(acc_per_class) + return np.stack(res, 0) + + +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassAccuracy(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_accuracy(self, ddp, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") -@pytest.mark.parametrize( - "average, mdmc_average, num_classes, inputs, ignore_index, top_k, threshold", - [ - ("unknown", None, None, _input_binary, None, None, 0.5), - ("micro", "unknown", None, _input_binary, None, None, 0.5), - ("macro", None, None, _input_binary, None, None, 0.5), - ("micro", None, None, _input_mdmc_prob, None, None, 0.5), - ("micro", None, None, _input_binary_prob, 0, None, 0.5), - ("micro", None, None, _input_mcls_prob, NUM_CLASSES, None, 0.5), - ("micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES, None, 0.5), - (None, None, None, _input_mcls_prob, None, 0, 0.5), - (None, None, None, _input_mcls_prob, None, None, 1.5), - ], -) -def test_wrong_params(average, mdmc_average, num_classes, inputs, ignore_index, top_k, threshold): - preds, target = inputs.preds, inputs.target - - with pytest.raises(ValueError): - acc = Accuracy( - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, - threshold=threshold, - top_k=top_k, - ) - acc(preds[0], target[0]) - acc.compute() - - with pytest.raises(ValueError): - accuracy( - preds[0], - target[0], - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, - threshold=threshold, - top_k=top_k, + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassAccuracy, + sk_metric=partial( + _sk_accuracy_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multiclass_accuracy_functional(self, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") -@pytest.mark.parametrize( - "preds_mc, target_mc, preds_ml, target_ml", - [ - ( - tensor([0, 1, 1, 1]), - tensor([2, 2, 1, 1]), - tensor([[0.8, 0.2, 0.8, 0.7], [0.6, 0.4, 0.6, 0.5]]), - tensor([[1, 0, 1, 1], [0, 0, 1, 0]]), + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_accuracy, + sk_metric=partial( + _sk_accuracy_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, ) - ], -) -def test_different_modes(preds_mc, target_mc, preds_ml, target_ml): - acc = Accuracy() - acc(preds_mc, target_mc) - with pytest.raises(ValueError, match="^[You cannot use]"): - acc(preds_ml, target_ml) - - -_bin_t1 = [0.7, 0.6, 0.2, 0.1] -_av_preds_bin = tensor([_bin_t1, _bin_t1]).float() -_av_target_bin = tensor([[1, 0, 0, 0], [0, 1, 1, 0]]) - -@pytest.mark.parametrize( - "preds, target, num_classes, exp_result, average, mdmc_average", - [ - (_topk_preds_mcls, _topk_target_mcls, 4, 1 / 4, "macro", None), - (_topk_preds_mcls, _topk_target_mcls, 4, 1 / 6, "weighted", None), - (_topk_preds_mcls, _topk_target_mcls, 4, [0.0, 0.0, 0.0, 1.0], "none", None), - (_topk_preds_mcls, _topk_target_mcls, 4, 1 / 6, "samples", None), - (_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 24, "macro", "samplewise"), - (_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 6, "weighted", "samplewise"), - (_topk_preds_mdmc, _topk_target_mdmc, 4, [0.0, 0.0, 0.0, 1 / 6], "none", "samplewise"), - (_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 6, "samples", "samplewise"), - (_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 6, "samples", "global"), - (_av_preds_ml, _av_target_ml, 4, 5 / 8, "macro", None), - (_av_preds_ml, _av_target_ml, 4, 0.70000005, "weighted", None), - (_av_preds_ml, _av_target_ml, 4, [1 / 2, 1 / 2, 1.0, 1 / 2], "none", None), - (_av_preds_ml, _av_target_ml, 4, 5 / 8, "samples", None), - ], -) -def test_average_accuracy(preds, target, num_classes, exp_result, average, mdmc_average): - acc = Accuracy(num_classes=num_classes, average=average, mdmc_average=mdmc_average) - - for batch in range(preds.shape[0]): - acc(preds[batch], target[batch]) + def test_multiclass_accuracy_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassAccuracy, + metric_functional=multiclass_accuracy, + metric_args={"num_classes": NUM_CLASSES}, + ) - assert (acc.compute() == tensor(exp_result)).all() + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_accuracy_half_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassAccuracy, + metric_functional=multiclass_accuracy, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) - # Test functional - total_samples = target.shape[0] * target.shape[1] + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_accuracy_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassAccuracy, + metric_functional=multiclass_accuracy, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) - preds = preds.view(total_samples, num_classes, -1) - target = target.view(total_samples, -1) - acc_score = accuracy(preds, target, num_classes=num_classes, average=average, mdmc_average=mdmc_average) - assert (acc_score == tensor(exp_result)).all() +_mc_k_target = torch.tensor([0, 1, 2]) +_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) @pytest.mark.parametrize( - "preds, target, num_classes, exp_result, average, multiclass", + "k, preds, target, average, expected", [ - (_av_preds_bin, _av_target_bin, 2, 19 / 30, "macro", True), - (_av_preds_bin, _av_target_bin, 2, 5 / 8, "weighted", True), - (_av_preds_bin, _av_target_bin, 2, [3 / 5, 2 / 3], "none", True), - (_av_preds_bin, _av_target_bin, 2, 5 / 8, "samples", True), + (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3)), + (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(3 / 3)), ], ) -def test_average_accuracy_bin(preds, target, num_classes, exp_result, average, multiclass): - acc = Accuracy(num_classes=num_classes, average=average, multiclass=multiclass) - - for batch in range(preds.shape[0]): - acc(preds[batch], target[batch]) +def test_top_k(k, preds, target, average, expected): + """A simple test to check that top_k works as expected.""" + class_metric = MulticlassAccuracy(top_k=k, average=average, num_classes=3) + class_metric.update(preds, target) + assert torch.isclose(class_metric.compute(), expected) + assert torch.isclose(multiclass_accuracy(preds, target, top_k=k, average=average, num_classes=3), expected) + + +def _sk_accuracy_multilabel(preds, target, ignore_index, multidim_average, average): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = preds.reshape(*preds.shape[:2], -1) + target = target.reshape(*target.shape[:2], -1) + + if multidim_average == "global": + if average == "micro": + preds = preds.flatten() + target = target.flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + return _sk_accuracy(target, preds) + + accuracy, weights = [], [] + for i in range(preds.shape[1]): + pred, true = preds[:, i].flatten(), target[:, i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + accuracy.append(_sk_accuracy(true, pred)) + weights.append(confmat[1, 1] + confmat[1, 0]) + res = np.stack(accuracy, axis=0) + + if average == "macro": + return res.mean(0) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + else: + accuracy, weights = [], [] + for i in range(preds.shape[0]): + if average == "micro": + pred, true = preds[i].flatten(), target[i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + accuracy.append(_sk_accuracy(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + weights.append(confmat[1, 1] + confmat[1, 0]) + else: + scores, w = [], [] + for j in range(preds.shape[1]): + pred, true = preds[i, j], target[i, j] + true, pred = remove_ignore_index(true, pred, ignore_index) + scores.append(_sk_accuracy(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + w.append(confmat[1, 1] + confmat[1, 0]) + accuracy.append(np.stack(scores)) + weights.append(np.stack(w)) + if average == "micro": + return np.array(accuracy) + res = np.stack(accuracy, 0) + if average == "macro": + return res.mean(-1) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + + +@pytest.mark.parametrize("input", _multilabel_cases) +class TestMultilabelAccuracy(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multilabel_accuracy(self, ddp, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") - assert (acc.compute() == tensor(exp_result)).all() - - # Test functional - total_samples = target.shape[0] * target.shape[1] + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MultilabelAccuracy, + sk_metric=partial( + _sk_accuracy_multilabel, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) - preds = preds.view(total_samples, -1) - target = target.view(total_samples, -1) - acc_score = accuracy(preds, target, num_classes=num_classes, average=average, multiclass=multiclass) - assert (acc_score == tensor(exp_result)).all() + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multilabel_accuracy_functional(self, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multilabel_accuracy, + sk_metric=partial( + _sk_accuracy_multilabel, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) -@pytest.mark.parametrize("metric_class, metric_fn", [(Accuracy, accuracy)]) -@pytest.mark.parametrize( - "ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))] -) -def test_class_not_present(metric_class, metric_fn, ignore_index, expected): - """This tests that when metric is computed per class and a given class is not present in both the `preds` and - `target`, the resulting score is `nan`.""" - preds = torch.tensor([0, 0, 0]) - target = torch.tensor([0, 0, 0]) - num_classes = 2 - - # test functional - result_fn = metric_fn(preds, target, average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index) - assert torch.allclose(expected, result_fn, equal_nan=True) - - # test class - cl_metric = metric_class(average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index) - cl_metric(preds, target) - result_cl = cl_metric.compute() - assert torch.allclose(expected, result_cl, equal_nan=True) - - -@pytest.mark.parametrize("average", ["micro", "macro", "weighted"]) -def test_same_input(average): - preds = _input_miss_class.preds - target = _input_miss_class.target - preds_flat = torch.cat(list(preds), dim=0) - target_flat = torch.cat(list(target), dim=0) - - mc = Accuracy(num_classes=NUM_CLASSES, average=average) - for i in range(NUM_BATCHES): - mc.update(preds[i], target[i]) - class_res = mc.compute() - func_res = accuracy(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average) - sk_res = sk_accuracy(target_flat, preds_flat) - - assert torch.allclose(class_res, torch.tensor(sk_res).float()) - assert torch.allclose(func_res, torch.tensor(sk_res).float()) + def test_multilabel_accuracy_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MultilabelAccuracy, + metric_functional=multilabel_accuracy, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_accuracy_half_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelAccuracy, + metric_functional=multilabel_accuracy, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) -@pytest.mark.parametrize( - "preds, target, ignore_index, result", - [ - ( - _binary_with_neg_tgt.preds, - _binary_with_neg_tgt.target, - _binary_with_neg_tgt.ignore_index, - _binary_with_neg_tgt.result, - ), - ( - _multiclass_logits_with_neg_tgt.preds, - _multiclass_logits_with_neg_tgt.target, - _multiclass_logits_with_neg_tgt.ignore_index, - _multiclass_logits_with_neg_tgt.result, - ), - ( - _multidim_multiclass_with_neg_tgt.preds, - _multidim_multiclass_with_neg_tgt.target, - _multidim_multiclass_with_neg_tgt.ignore_index, - _multidim_multiclass_with_neg_tgt.result, - ), - ( - _multidim_multiclass_logits_with_neg_tgt.preds, - _multidim_multiclass_logits_with_neg_tgt.target, - _multidim_multiclass_logits_with_neg_tgt.ignore_index, - _multidim_multiclass_logits_with_neg_tgt.result, - ), - ], -) -def test_negative_ignore_index(preds, target, ignore_index, result): - # We deduct -1 for an ignored index - num_classes = len(target.unique()) - 1 - - # Test class - acc = Accuracy(num_classes=num_classes, ignore_index=ignore_index) - acc_score = acc(preds, target) - assert torch.allclose(acc_score, result) - # Test functional metrics - acc_score = accuracy(preds, target, num_classes=num_classes, ignore_index=ignore_index) - assert torch.allclose(acc_score, result) - - # If the ignore index is not set properly, we expect to see an error - ignore_index = None - # Test class - acc = Accuracy(num_classes=num_classes, ignore_index=ignore_index) - with pytest.raises(ValueError, match="^[The `target` has to be a non-negative tensor.]"): - acc_score = acc(preds, target) - - # Test functional - with pytest.raises(ValueError, match="^[The `target` has to be a non-negative tensor.]"): - acc_score = accuracy(preds, target, num_classes=num_classes, ignore_index=ignore_index) - - -def test_negmetric_noneavg(noneavg=_negmetric_noneavg): - acc = MetricWrapper(Accuracy(average="none", num_classes=noneavg["pred1"].shape[1])) - result1 = acc(noneavg["pred1"], noneavg["target1"]) - assert torch.allclose(noneavg["res1"], result1, equal_nan=True) - result2 = acc(noneavg["pred2"], noneavg["target2"]) - assert torch.allclose(noneavg["res2"], result2, equal_nan=True) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_accuracy_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelAccuracy, + metric_functional=multilabel_accuracy, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index 691150f4639..45b47c561f7 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -13,228 +13,349 @@ # limitations under the License. from functools import partial +import numpy as np import pytest import torch +from scipy.special import expit as sigmoid +from scipy.special import softmax from sklearn.metrics import roc_auc_score as sk_roc_auc_score -from torchmetrics.classification.auroc import AUROC -from torchmetrics.functional import auroc -from torchmetrics.utilities.imports import _TORCH_LOWER_1_6 -from unittests.classification.inputs import _input_binary_prob -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from torchmetrics.classification.auroc import BinaryAUROC, MulticlassAUROC, MultilabelAUROC +from torchmetrics.functional.classification.auroc import binary_auroc, multiclass_auroc, multilabel_auroc +from torchmetrics.functional.classification.roc import binary_roc +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_LOWER_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, MetricTester +from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_auroc_binary_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): - # todo: `multi_class` is unused - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_roc_auc_score(y_true=sk_target, y_score=sk_preds, average=average, max_fpr=max_fpr) - - -def _sk_auroc_multiclass_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -def _sk_auroc_multidim_multiclass_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -def _sk_auroc_multilabel_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.reshape(-1, num_classes).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -def _sk_auroc_multilabel_multidim_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -@pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_auroc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_auroc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_auroc_multidim_multiclass_prob, NUM_CLASSES), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_auroc_multilabel_prob, NUM_CLASSES), - (_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_auroc_multilabel_multidim_prob, NUM_CLASSES), - ], -) -class TestAUROC(MetricTester): - @pytest.mark.parametrize("average", ["macro", "weighted", "micro"]) +def _sk_auroc_binary(preds, target, max_fpr=None, ignore_index=None): + preds = preds.flatten().numpy() + target = target.flatten().numpy() + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_roc_auc_score(target, preds, max_fpr=max_fpr) + + +@pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinaryAUROC(MetricTester): + @pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, ddp, dist_sync_on_step): - # max_fpr different from None is not support in multi class - if max_fpr is not None and num_classes != 1: - pytest.skip("max_fpr parameter not support for multi class or multi label") + def test_binary_auroc(self, input, ddp, max_fpr, ignore_index): + if max_fpr is not None and _TORCH_LOWER_1_6: + pytest.skip("requires torch v1.6 or higher to test max_fpr argument") + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryAUROC, + sk_metric=partial(_sk_auroc_binary, max_fpr=max_fpr, ignore_index=ignore_index), + metric_args={ + "max_fpr": max_fpr, + "thresholds": None, + "ignore_index": ignore_index, + }, + ) - # max_fpr only supported for torch v1.6 or higher + @pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_binary_auroc_functional(self, input, max_fpr, ignore_index): if max_fpr is not None and _TORCH_LOWER_1_6: pytest.skip("requires torch v1.6 or higher to test max_fpr argument") + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_auroc, + sk_metric=partial(_sk_auroc_binary, max_fpr=max_fpr, ignore_index=ignore_index), + metric_args={ + "max_fpr": max_fpr, + "thresholds": None, + "ignore_index": ignore_index, + }, + ) - # average='micro' only supported for multilabel - if average == "micro" and preds.ndim > 2 and preds.ndim == target.ndim + 1: - pytest.skip("micro argument only support for multilabel input") + def test_binary_auroc_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryAUROC, + metric_functional=binary_auroc, + metric_args={"thresholds": None}, + ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_auroc_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryAUROC, + metric_functional=binary_auroc, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_auroc_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryAUROC, + metric_functional=binary_auroc, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_binary_auroc_threshold_arg(self, input, threshold_fn): + preds, target = input + + for pred, true in zip(preds, target): + _, _, t = binary_roc(pred, true, thresholds=None) + ap1 = binary_auroc(pred, true, thresholds=None) + ap2 = binary_auroc(pred, true, thresholds=threshold_fn(t.flip(0))) + assert torch.allclose(ap1, ap2) + + +def _sk_auroc_multiclass(preds, target, average="macro", ignore_index=None): + preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = softmax(preds, 1) + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_roc_auc_score(target, preds, average=average, multi_class="ovr", labels=list(range(NUM_CLASSES))) + + +@pytest.mark.parametrize( + "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +) +class TestMulticlassAUROC(MetricTester): + @pytest.mark.parametrize("average", ["macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_auroc(self, input, average, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=AUROC, - sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes, "average": average, "max_fpr": max_fpr}, + metric_class=MulticlassAUROC, + sk_metric=partial(_sk_auroc_multiclass, average=average, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, ) - @pytest.mark.parametrize("average", ["macro", "weighted", "micro"]) - def test_auroc_functional(self, preds, target, sk_metric, num_classes, average, max_fpr): - # max_fpr different from None is not support in multi class - if max_fpr is not None and num_classes != 1: - pytest.skip("max_fpr parameter not support for multi class or multi label") + @pytest.mark.parametrize("average", ["macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_multiclass_auroc_functional(self, input, average, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_auroc, + sk_metric=partial(_sk_auroc_multiclass, average=average, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, + ) - # max_fpr only supported for torch v1.6 or higher - if max_fpr is not None and _TORCH_LOWER_1_6: - pytest.skip("requires torch v1.6 or higher to test max_fpr argument") + def test_multiclass_auroc_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassAUROC, + metric_functional=multiclass_auroc, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + ) - # average='micro' only supported for multilabel - if average == "micro" and preds.ndim > 2 and preds.ndim == target.ndim + 1: - pytest.skip("micro argument only support for multilabel input") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_auroc_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassAUROC, + metric_functional=multiclass_auroc, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) - self.run_functional_metric_test( - preds, - target, - metric_functional=auroc, - sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr), - metric_args={"num_classes": num_classes, "average": average, "max_fpr": max_fpr}, + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_auroc_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassAUROC, + metric_functional=multiclass_auroc, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, ) - def test_auroc_differentiability(self, preds, target, sk_metric, num_classes, max_fpr): - # max_fpr different from None is not support in multi class - if max_fpr is not None and num_classes != 1: - pytest.skip("max_fpr parameter not support for multi class or multi label") + @pytest.mark.parametrize("average", ["macro", "weighted", None]) + def test_multiclass_auroc_threshold_arg(self, input, average): + preds, target = input + if (preds < 0).any(): + preds = preds.softmax(dim=-1) + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 2)) + 1e-6 # rounding will simulate binning + ap1 = multiclass_auroc(pred, true, num_classes=NUM_CLASSES, average=average, thresholds=None) + ap2 = multiclass_auroc( + pred, true, num_classes=NUM_CLASSES, average=average, thresholds=torch.linspace(0, 1, 100) + ) + assert torch.allclose(ap1, ap2) + + +def _sk_auroc_multilabel(preds, target, average="macro", ignore_index=None): + if ignore_index is None: + if preds.ndim > 2: + target = target.transpose(2, 1).reshape(-1, NUM_CLASSES) + preds = preds.transpose(2, 1).reshape(-1, NUM_CLASSES) + target = target.numpy() + preds = preds.numpy() + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + return sk_roc_auc_score(target, preds, average=average, max_fpr=None) + if average == "micro": + return _sk_auroc_binary(preds.flatten(), target.flatten(), max_fpr=None, ignore_index=ignore_index) + res = [] + for i in range(NUM_CLASSES): + res.append(_sk_auroc_binary(preds[:, i], target[:, i], max_fpr=None, ignore_index=ignore_index)) + if average == "macro": + return np.array(res)[~np.isnan(res)].mean() + if average == "weighted": + weights = ((target == 1).sum([0, 2]) if target.ndim == 3 else (target == 1).sum(0)).numpy() + weights = weights / sum(weights) + return (np.array(res) * weights)[~np.isnan(res)].sum() + return res - # max_fpr only supported for torch v1.6 or higher - if max_fpr is not None and _TORCH_LOWER_1_6: - pytest.skip("requires torch v1.6 or higher to test max_fpr argument") +@pytest.mark.parametrize( + "input", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5]) +) +class TestMultilabelAUROC(MetricTester): + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multilabel_auroc(self, input, ddp, average, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MultilabelAUROC, + sk_metric=partial(_sk_auroc_multilabel, average=average, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_multilabel_auroc_functional(self, input, average, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multilabel_auroc, + sk_metric=partial(_sk_auroc_multilabel, average=average, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, + ) + + def test_multiclass_auroc_differentiability(self, input): + preds, target = input self.run_differentiability_test( preds=preds, target=target, - metric_module=AUROC, - metric_functional=auroc, - metric_args={"num_classes": num_classes, "max_fpr": max_fpr}, + metric_module=MultilabelAUROC, + metric_functional=multilabel_auroc, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_auroc_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelAUROC, + metric_functional=multilabel_auroc, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_auroc_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelAUROC, + metric_functional=multilabel_auroc, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) -def test_error_on_different_mode(): - """test that an error is raised if the user pass in data of different modes (binary, multi-label, multi- - class)""" - metric = AUROC() - # pass in multi-class data - metric.update(torch.randn(10, 5).softmax(dim=-1), torch.randint(0, 5, (10,))) - with pytest.raises(ValueError, match=r"The mode of data.* should be constant.*"): - # pass in multi-label data - metric.update(torch.rand(10, 5), torch.randint(0, 2, (10, 5))) - - -def test_error_multiclass_no_num_classes(): - with pytest.raises( - ValueError, match="Detected input to `multiclass` but you did not provide `num_classes` argument" - ): - _ = auroc(torch.randn(20, 3).softmax(dim=-1), torch.randint(3, (20,))) - - -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_weighted_with_empty_classes(device): - """Tests that weighted multiclass AUROC calculation yields the same results if a new but empty class exists. - - Tests that the proper warnings and errors are raised - """ - if not torch.cuda.is_available() and device == "cuda": - pytest.skip("Test requires gpu to run") - - preds = torch.tensor( - [ - [0.90, 0.05, 0.05], - [0.05, 0.90, 0.05], - [0.05, 0.05, 0.90], - [0.85, 0.05, 0.10], - [0.10, 0.10, 0.80], - ] - ).to(device) - target = torch.tensor([0, 1, 1, 2, 2]).to(device) - num_classes = 3 - _auroc = auroc(preds, target, average="weighted", num_classes=num_classes) - - # Add in a class with zero observations at second to last index - preds = torch.cat( - (preds[:, : num_classes - 1], torch.rand_like(preds[:, 0:1]), preds[:, num_classes - 1 :]), axis=1 - ) - # Last class (2) gets moved to 3 - target[target == num_classes - 1] = num_classes - with pytest.warns(UserWarning, match="Class 2 had 0 observations, omitted from AUROC calculation"): - _auroc_empty_class = auroc(preds, target, average="weighted", num_classes=num_classes + 1) - assert _auroc == _auroc_empty_class - - target = torch.zeros_like(target) - with pytest.raises(ValueError, match="Found 1 non-empty class in `multiclass` AUROC calculation"): - _ = auroc(preds, target, average="weighted", num_classes=num_classes + 1) - - -def test_warnings_on_missing_class(): - """Test that a warning is given if either the positive or negative class is missing.""" - metric = AUROC() - # no positive samples - warning = ( - "No positive samples in targets, true positive value should be meaningless." - " Returning zero tensor in true positive score" - ) - with pytest.warns(UserWarning, match=warning): - score = metric(torch.randn(10).sigmoid(), torch.zeros(10).int()) - assert score == 0 - - warning = ( - "No negative samples in targets, false positive value should be meaningless." - " Returning zero tensor in false positive score" - ) - with pytest.warns(UserWarning, match=warning): - score = metric(torch.randn(10).sigmoid(), torch.ones(10).int()) - assert score == 0 + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multilabel_auroc_threshold_arg(self, input, average): + preds, target = input + if (preds < 0).any(): + preds = sigmoid(preds) + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 1)) + 1e-6 # rounding will simulate binning + ap1 = multilabel_auroc(pred, true, num_labels=NUM_CLASSES, average=average, thresholds=None) + ap2 = multilabel_auroc( + pred, true, num_labels=NUM_CLASSES, average=average, thresholds=torch.linspace(0, 1, 100) + ) + assert torch.allclose(ap1, ap2) diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py index 7938486b2c0..8a331b214da 100644 --- a/tests/unittests/classification/test_average_precision.py +++ b/tests/unittests/classification/test_average_precision.py @@ -15,157 +15,351 @@ import numpy as np import pytest +import torch +from scipy.special import expit as sigmoid +from scipy.special import softmax from sklearn.metrics import average_precision_score as sk_average_precision_score -from torch import tensor - -from torchmetrics.classification.avg_precision import AveragePrecision -from torchmetrics.functional import average_precision -from unittests.classification.inputs import _input_binary_prob -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel + +from torchmetrics.classification.average_precision import ( + BinaryAveragePrecision, + MulticlassAveragePrecision, + MultilabelAveragePrecision, +) +from torchmetrics.functional.classification.average_precision import ( + binary_average_precision, + multiclass_average_precision, + multilabel_average_precision, +) +from torchmetrics.functional.classification.precision_recall_curve import binary_precision_recall_curve +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_8 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, MetricTester +from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_average_precision_score(y_true, probas_pred, num_classes=1, average=None): - if num_classes == 1: - return sk_average_precision_score(y_true, probas_pred) +def _sk_average_precision_binary(preds, target, ignore_index=None): + preds = preds.flatten().numpy() + target = target.flatten().numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_average_precision_score(target, preds) - res = [] - for i in range(num_classes): - y_true_temp = np.zeros_like(y_true) - y_true_temp[y_true == i] = 1 - res.append(sk_average_precision_score(y_true_temp, probas_pred[:, i])) +@pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinaryAveragePrecision(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_average_precision(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryAveragePrecision, + sk_metric=partial(_sk_average_precision_binary, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_average_precision_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_average_precision, + sk_metric=partial(_sk_average_precision_binary, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + def test_binary_average_precision_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryAveragePrecision, + metric_functional=binary_average_precision, + metric_args={"thresholds": None}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_average_precision_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryAveragePrecision, + metric_functional=binary_average_precision, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_average_precision_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryAveragePrecision, + metric_functional=binary_average_precision, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_binary_average_precision_threshold_arg(self, input, threshold_fn): + preds, target = input + + for pred, true in zip(preds, target): + _, _, t = binary_precision_recall_curve(pred, true, thresholds=None) + ap1 = binary_average_precision(pred, true, thresholds=None) + ap2 = binary_average_precision(pred, true, thresholds=threshold_fn(t)) + assert torch.allclose(ap1, ap2) + + +def _sk_average_precision_multiclass(preds, target, average="macro", ignore_index=None): + preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = softmax(preds, 1) + target, preds = remove_ignore_index(target, preds, ignore_index) + + res = [] + for i in range(NUM_CLASSES): + y_true_temp = np.zeros_like(target) + y_true_temp[target == i] = 1 + res.append(sk_average_precision_score(y_true_temp, preds[:, i])) if average == "macro": - return np.array(res).mean() + return np.array(res)[~np.isnan(res)].mean() if average == "weighted": - weights = np.bincount(y_true) if y_true.max() > 1 else y_true.sum(axis=0) + weights = np.bincount(target) weights = weights / sum(weights) - return (np.array(res) * weights).sum() - + return (np.array(res) * weights)[~np.isnan(res)].sum() return res -def _sk_avg_prec_binary_prob(preds, target, num_classes=1, average=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average) +@pytest.mark.parametrize( + "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +) +class TestMulticlassAveragePrecision(MetricTester): + @pytest.mark.parametrize("average", ["macro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_average_precision(self, input, average, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassAveragePrecision, + sk_metric=partial(_sk_average_precision_multiclass, average=average, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, + ) + @pytest.mark.parametrize("average", ["macro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_multiclass_average_precision_functional(self, input, average, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_average_precision, + sk_metric=partial(_sk_average_precision_multiclass, average=average, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, + ) -def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1, average=None): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() + def test_multiclass_average_precision_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassAveragePrecision, + metric_functional=multiclass_average_precision, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + ) - return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_average_precision_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassAveragePrecision, + metric_functional=multiclass_average_precision, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_average_precision_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassAveragePrecision, + metric_functional=multiclass_average_precision, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) -def _sk_avg_prec_multilabel_prob(preds, target, num_classes=1, average=None): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1, num_classes).numpy() - return sk_average_precision_score(sk_target, sk_preds, average=average) + @pytest.mark.parametrize("average", ["macro", "weighted", None]) + def test_multiclass_average_precision_threshold_arg(self, input, average): + preds, target = input + if (preds < 0).any(): + preds = preds.softmax(dim=-1) + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 2)) + 1e-6 # rounding will simulate binning + ap1 = multiclass_average_precision(pred, true, num_classes=NUM_CLASSES, average=average, thresholds=None) + ap2 = multiclass_average_precision( + pred, true, num_classes=NUM_CLASSES, average=average, thresholds=torch.linspace(0, 1, 100) + ) + assert torch.allclose(ap1, ap2) -def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1, average=None): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average) +def _sk_average_precision_multilabel(preds, target, average="macro", ignore_index=None): + if average == "micro": + return _sk_average_precision_binary(preds.flatten(), target.flatten(), ignore_index) + res = [] + for i in range(NUM_CLASSES): + res.append(_sk_average_precision_binary(preds[:, i], target[:, i], ignore_index)) + if average == "macro": + return np.array(res)[~np.isnan(res)].mean() + if average == "weighted": + weights = ((target == 1).sum([0, 2]) if target.ndim == 3 else (target == 1).sum(0)).numpy() + weights = weights / sum(weights) + return (np.array(res) * weights)[~np.isnan(res)].sum() + return res @pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_avg_prec_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_avg_prec_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_avg_prec_multidim_multiclass_prob, NUM_CLASSES), - (_input_multilabel.preds, _input_multilabel.target, _sk_avg_prec_multilabel_prob, NUM_CLASSES), - ], + "input", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5]) ) -class TestAveragePrecision(MetricTester): +class TestMultilabelAveragePrecision(MetricTester): @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_average_precision(self, preds, target, sk_metric, num_classes, average, ddp, dist_sync_on_step): - if target.max() > 1 and average == "micro": - pytest.skip("average=micro and multiclass input cannot be used together") - + def test_multilabel_average_precision(self, input, ddp, average, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=AveragePrecision, - sk_metric=partial(sk_metric, num_classes=num_classes, average=average), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes, "average": average}, + metric_class=MultilabelAveragePrecision, + sk_metric=partial(_sk_average_precision_multilabel, average=average, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, ) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) - def test_average_precision_functional(self, preds, target, sk_metric, num_classes, average): - if target.max() > 1 and average == "micro": - pytest.skip("average=micro and multiclass input cannot be used together") - + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_multilabel_average_precision_functional(self, input, average, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( preds=preds, target=target, - metric_functional=average_precision, - sk_metric=partial(sk_metric, num_classes=num_classes, average=average), - metric_args={"num_classes": num_classes, "average": average}, + metric_functional=multilabel_average_precision, + sk_metric=partial(_sk_average_precision_multilabel, average=average, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "average": average, + "ignore_index": ignore_index, + }, ) - def test_average_precision_differentiability(self, preds, sk_metric, target, num_classes): + def test_multiclass_average_precision_differentiability(self, input): + preds, target = input self.run_differentiability_test( preds=preds, target=target, - metric_module=AveragePrecision, - metric_functional=average_precision, - metric_args={"num_classes": num_classes}, + metric_module=MultilabelAveragePrecision, + metric_functional=multilabel_average_precision, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_average_precision_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelAveragePrecision, + metric_functional=multilabel_average_precision, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_average_precision_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelAveragePrecision, + metric_functional=multilabel_average_precision, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) -@pytest.mark.parametrize( - ["scores", "target", "expected_score"], - [ - # Check the average_precision_score of a constant predictor is - # the TPR - # Generate a dataset with 25% of positives - # And a constant score - # The precision is then the fraction of positive whatever the recall - # is, as there is only one threshold: - (tensor([1, 1, 1, 1]), tensor([0, 0, 0, 1]), 0.25), - # With threshold 0.8 : 1 TP and 2 TN and one FN - (tensor([0.6, 0.7, 0.8, 9]), tensor([1, 0, 0, 1]), 0.75), - ], -) -def test_average_precision(scores, target, expected_score): - assert average_precision(scores, target) == expected_score - - -def test_average_precision_warnings_and_errors(): - """Test that the correct errors and warnings gets raised.""" - - # check average argument - with pytest.raises(ValueError, match="Expected argument `average` to be one .*"): - AveragePrecision(num_classes=5, average="samples") - - # check that micro average cannot be used with multilabel input - pred = tensor( - [ - [0.75, 0.05, 0.05, 0.05, 0.05], - [0.05, 0.75, 0.05, 0.05, 0.05], - [0.05, 0.05, 0.75, 0.05, 0.05], - [0.05, 0.05, 0.05, 0.75, 0.05], - ] - ) - target = tensor([0, 1, 3, 2]) - average_precision = AveragePrecision(num_classes=5, average="micro") - with pytest.raises(ValueError, match="Cannot use `micro` average with multi-class input"): - average_precision(pred, target) - - # check that warning is thrown when average=macro and nan is encoutered in individual scores - average_precision = AveragePrecision(num_classes=5, average="macro") - with pytest.warns(UserWarning, match="Average precision score for one or more classes was `nan`.*"): - average_precision(pred, target) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multilabel_average_precision_threshold_arg(self, input, average): + preds, target = input + if (preds < 0).any(): + preds = sigmoid(preds) + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 1)) + 1e-6 # rounding will simulate binning + ap1 = multilabel_average_precision(pred, true, num_labels=NUM_CLASSES, average=average, thresholds=None) + ap2 = multilabel_average_precision( + pred, true, num_labels=NUM_CLASSES, average=average, thresholds=torch.linspace(0, 1, 100) + ) + assert torch.allclose(ap1, ap2) diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index 52263f2ce11..4060ec590a9 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -1,122 +1,219 @@ -import functools -import re +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial import numpy as np import pytest -from scipy.special import softmax as _softmax - -from torchmetrics import CalibrationError -from torchmetrics.functional import calibration_error -from torchmetrics.utilities.checks import _input_format_classification -from torchmetrics.utilities.enums import DataType -from unittests.classification.inputs import _input_binary_logits, _input_binary_prob -from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from unittests.helpers import seed_all +import torch +from netcal.metrics import ECE, MCE +from scipy.special import expit as sigmoid +from scipy.special import softmax -# TODO: replace this with official sklearn implementation after next sklearn release -from unittests.helpers.reference_metrics import _calibration_error as sk_calib -from unittests.helpers.testers import THRESHOLD, MetricTester +from torchmetrics.classification.calibration_error import BinaryCalibrationError, MulticlassCalibrationError +from torchmetrics.functional.classification.calibration_error import ( + binary_calibration_error, + multiclass_calibration_error, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_9 +from unittests.classification.inputs import _binary_cases, _multiclass_cases +from unittests.helpers import seed_all +from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_calibration(preds, target, n_bins, norm, debias=False): - _, _, mode = _input_format_classification(preds, target, threshold=THRESHOLD) - sk_preds, sk_target = preds.numpy(), target.numpy() - if mode == DataType.BINARY: - if not np.logical_and(0 <= sk_preds, sk_preds <= 1).all(): - sk_preds = 1.0 / (1 + np.exp(-sk_preds)) # sigmoid transform - if mode == DataType.MULTICLASS: - if not np.logical_and(0 <= sk_preds, sk_preds <= 1).all(): - sk_preds = _softmax(sk_preds, axis=1) - # binary label is whether or not the predicted class is correct - sk_target = np.equal(np.argmax(sk_preds, axis=1), sk_target) - sk_preds = np.max(sk_preds, axis=1) - elif mode == DataType.MULTIDIM_MULTICLASS: - # reshape from shape (N, C, ...) to (N*EXTRA_DIMS, C) - sk_preds = np.transpose(sk_preds, axes=(0, 2, 1)) - sk_preds = sk_preds.reshape(np.prod(sk_preds.shape[:-1]), sk_preds.shape[-1]) - # reshape from shape (N, ...) to (N*EXTRA_DIMS,) - # binary label is whether or not the predicted class is correct - sk_target = np.equal(np.argmax(sk_preds, axis=1), sk_target.flatten()) - sk_preds = np.max(sk_preds, axis=1) - return sk_calib(y_true=sk_target, y_prob=sk_preds, norm=norm, n_bins=n_bins, reduce_bias=debias) - - -@pytest.mark.parametrize("n_bins", [10, 15, 20]) -@pytest.mark.parametrize("norm", ["l1", "l2", "max"]) -@pytest.mark.parametrize( - "preds, target", - [ - (_input_binary_prob.preds, _input_binary_prob.target), - (_input_binary_logits.preds, _input_binary_logits.target), - (_input_mcls_prob.preds, _input_mcls_prob.target), - (_input_mcls_logits.preds, _input_mcls_logits.target), - (_input_mdmc_prob.preds, _input_mdmc_prob.target), - ], -) -class TestCE(MetricTester): +def _sk_binary_calibration_error(preds, target, n_bins, norm, ignore_index): + preds = preds.numpy().flatten() + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + metric = ECE if norm == "l1" else MCE + return metric(n_bins).measure(preds, target) + + +@pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinaryCalibrationError(MetricTester): + @pytest.mark.parametrize("n_bins", [10, 15, 20]) + @pytest.mark.parametrize("norm", ["l1", "max"]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_ce(self, preds, target, n_bins, ddp, dist_sync_on_step, norm): + def test_binary_calibration_error(self, input, ddp, n_bins, norm, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=CalibrationError, - sk_metric=functools.partial(_sk_calibration, n_bins=n_bins, norm=norm), - dist_sync_on_step=dist_sync_on_step, - metric_args={"n_bins": n_bins, "norm": norm}, + metric_class=BinaryCalibrationError, + sk_metric=partial(_sk_binary_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index), + metric_args={ + "n_bins": n_bins, + "norm": norm, + "ignore_index": ignore_index, + }, ) - def test_ce_functional(self, preds, target, n_bins, norm): + @pytest.mark.parametrize("n_bins", [10, 15, 20]) + @pytest.mark.parametrize("norm", ["l1", "max"]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_calibration_error_functional(self, input, n_bins, norm, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( - preds, - target, - metric_functional=calibration_error, - sk_metric=functools.partial(_sk_calibration, n_bins=n_bins, norm=norm), - metric_args={"n_bins": n_bins, "norm": norm}, + preds=preds, + target=target, + metric_functional=binary_calibration_error, + sk_metric=partial(_sk_binary_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index), + metric_args={ + "n_bins": n_bins, + "norm": norm, + "ignore_index": ignore_index, + }, + ) + + def test_binary_calibration_error_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryCalibrationError, + metric_functional=binary_calibration_error, ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_calibration_error_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryCalibrationError, + metric_functional=binary_calibration_error, + dtype=dtype, + ) -@pytest.mark.parametrize("preds, targets", [(_input_mlb_prob.preds, _input_mlb_prob.target)]) -def test_invalid_input(preds, targets): - for p, t in zip(preds, targets): - with pytest.raises( - ValueError, - match=re.escape( - f"Calibration error is not well-defined for data with size {p.size()} and targets {t.size()}." - ), - ): - calibration_error(p, t) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_calibration_error_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryCalibrationError, + metric_functional=binary_calibration_error, + dtype=dtype, + ) -@pytest.mark.parametrize( - "preds, target", - [ - (_input_binary_prob.preds, _input_binary_prob.target), - (_input_mcls_prob.preds, _input_mcls_prob.target), - (_input_mdmc_prob.preds, _input_mdmc_prob.target), - ], -) -def test_invalid_norm(preds, target): - with pytest.raises(ValueError, match="Norm l3 is not supported. Please select from l1, l2, or max. "): - calibration_error(preds, target, norm="l3") +def _sk_multiclass_calibration_error(preds, target, n_bins, norm, ignore_index): + preds = preds.numpy() + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = softmax(preds, 1) + preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) + target, preds = remove_ignore_index(target, preds, ignore_index) + metric = ECE if norm == "l1" else MCE + return metric(n_bins).measure(preds, target) -@pytest.mark.parametrize("n_bins", [-10, -1, "fsd"]) @pytest.mark.parametrize( - "preds, targets", - [ - (_input_binary_prob.preds, _input_binary_prob.target), - (_input_mcls_prob.preds, _input_mcls_prob.target), - (_input_mdmc_prob.preds, _input_mdmc_prob.target), - ], + "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) ) -def test_invalid_bins(preds, targets, n_bins): - for p, t in zip(preds, targets): - with pytest.raises(ValueError, match=f"Expected argument `n_bins` to be a int larger than 0 but got {n_bins}"): - calibration_error(p, t, n_bins=n_bins) +class TestMulticlassCalibrationError(MetricTester): + @pytest.mark.parametrize("n_bins", [15, 20]) + @pytest.mark.parametrize("norm", ["l1", "max"]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_calibration_error(self, input, ddp, n_bins, norm, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassCalibrationError, + sk_metric=partial(_sk_multiclass_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "n_bins": n_bins, + "norm": norm, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("n_bins", [15, 20]) + @pytest.mark.parametrize("norm", ["l1", "max"]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multiclass_calibration_error_functional(self, input, n_bins, norm, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_calibration_error, + sk_metric=partial(_sk_multiclass_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "n_bins": n_bins, + "norm": norm, + "ignore_index": ignore_index, + }, + ) + + def test_multiclass_calibration_error_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassCalibrationError, + metric_functional=multiclass_calibration_error, + metric_args={"num_classes": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_calibration_error_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_9: + pytest.xfail(reason="torch.max in metric not supported before pytorch v1.9 for cpu + half") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.softmax in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassCalibrationError, + metric_functional=multiclass_calibration_error, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_calibration_error_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassCalibrationError, + metric_functional=multiclass_calibration_error, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) diff --git a/tests/unittests/classification/test_cohen_kappa.py b/tests/unittests/classification/test_cohen_kappa.py index d0a2b1c9c9c..39955efdeb0 100644 --- a/tests/unittests/classification/test_cohen_kappa.py +++ b/tests/unittests/classification/test_cohen_kappa.py @@ -1,133 +1,215 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from functools import partial import numpy as np import pytest import torch +from scipy.special import expit as sigmoid from sklearn.metrics import cohen_kappa_score as sk_cohen_kappa -from torchmetrics.classification.cohen_kappa import CohenKappa -from torchmetrics.functional.classification.cohen_kappa import cohen_kappa -from unittests.classification.inputs import _input_binary, _input_binary_prob -from unittests.classification.inputs import _input_multiclass as _input_mcls -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel as _input_mlb -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, MulticlassCohenKappa +from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, multiclass_cohen_kappa +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_cohen_kappa_binary_prob(preds, target, weights=None): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() +def _sk_cohen_kappa_binary(preds, target, weights=None, ignore_index=None): + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_cohen_kappa(y1=target, y2=preds, weights=weights) - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) +@pytest.mark.parametrize("input", _binary_cases) +class TestBinaryCohenKappa(MetricTester): + atol = 1e-5 -def _sk_cohen_kappa_binary(preds, target, weights=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) - - -def _sk_cohen_kappa_multilabel_prob(preds, target, weights=None): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) - - -def _sk_cohen_kappa_multilabel(preds, target, weights=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) - - -def _sk_cohen_kappa_multiclass_prob(preds, target, weights=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) - - -def _sk_cohen_kappa_multiclass(preds, target, weights=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) + @pytest.mark.parametrize("weights", ["linear", "quadratic", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_cohen_kappa(self, input, ddp, weights, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryCohenKappa, + sk_metric=partial(_sk_cohen_kappa_binary, weights=weights, ignore_index=ignore_index), + metric_args={ + "threshold": THRESHOLD, + "weights": weights, + "ignore_index": ignore_index, + }, + ) + @pytest.mark.parametrize("weights", ["linear", "quadratic", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_confusion_matrix_functional(self, input, weights, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_cohen_kappa, + sk_metric=partial(_sk_cohen_kappa_binary, weights=weights, ignore_index=ignore_index), + metric_args={ + "threshold": THRESHOLD, + "weights": weights, + "ignore_index": ignore_index, + }, + ) -def _sk_cohen_kappa_multidim_multiclass_prob(preds, target, weights=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() + def test_binary_cohen_kappa_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryCohenKappa, + metric_functional=binary_cohen_kappa, + metric_args={"threshold": THRESHOLD}, + ) - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_cohen_kappa_dtypes_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryCohenKappa, + metric_functional=binary_cohen_kappa, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_confusion_matrix_dtypes_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryCohenKappa, + metric_functional=binary_cohen_kappa, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) -def _sk_cohen_kappa_multidim_multiclass(preds, target, weights=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) +def _sk_cohen_kappa_multiclass(preds, target, weights, ignore_index=None): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + preds = np.argmax(preds, axis=1) + preds = preds.flatten() + target = target.flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_cohen_kappa(y1=target, y2=preds, weights=weights) -@pytest.mark.parametrize("weights", ["linear", "quadratic", None]) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_cohen_kappa_binary_prob, 2), - (_input_binary.preds, _input_binary.target, _sk_cohen_kappa_binary, 2), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cohen_kappa_multilabel_prob, 2), - (_input_mlb.preds, _input_mlb.target, _sk_cohen_kappa_multilabel, 2), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cohen_kappa_multiclass_prob, NUM_CLASSES), - (_input_mcls.preds, _input_mcls.target, _sk_cohen_kappa_multiclass, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cohen_kappa_multidim_multiclass_prob, NUM_CLASSES), - (_input_mdmc.preds, _input_mdmc.target, _sk_cohen_kappa_multidim_multiclass, NUM_CLASSES), - ], -) -class TestCohenKappa(MetricTester): +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassCohenKappa(MetricTester): atol = 1e-5 + @pytest.mark.parametrize("weights", ["linear", "quadratic", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_cohen_kappa(self, weights, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + def test_multiclass_cohen_kappa(self, input, ddp, weights, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=CohenKappa, - sk_metric=partial(sk_metric, weights=weights), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "weights": weights}, + metric_class=MulticlassCohenKappa, + sk_metric=partial(_sk_cohen_kappa_multiclass, weights=weights, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "weights": weights, + "ignore_index": ignore_index, + }, ) - def test_cohen_kappa_functional(self, weights, preds, target, sk_metric, num_classes): + @pytest.mark.parametrize("weights", ["linear", "quadratic", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multiclass_confusion_matrix_functional(self, input, weights, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( - preds, - target, - metric_functional=cohen_kappa, - sk_metric=partial(sk_metric, weights=weights), - metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "weights": weights}, + preds=preds, + target=target, + metric_functional=multiclass_cohen_kappa, + sk_metric=partial(_sk_cohen_kappa_multiclass, weights=weights, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "weights": weights, + "ignore_index": ignore_index, + }, ) - def test_cohen_kappa_differentiability(self, preds, target, sk_metric, weights, num_classes): + def test_multiclass_cohen_kappa_differentiability(self, input): + preds, target = input self.run_differentiability_test( preds=preds, target=target, - metric_module=CohenKappa, - metric_functional=cohen_kappa, - metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "weights": weights}, + metric_module=MulticlassCohenKappa, + metric_functional=multiclass_cohen_kappa, + metric_args={"num_classes": NUM_CLASSES}, ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_cohen_kappa_dtypes_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassCohenKappa, + metric_functional=multiclass_cohen_kappa, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) -def test_warning_on_wrong_weights(tmpdir): - preds = torch.randint(3, size=(20,)) - target = torch.randint(3, size=(20,)) - - with pytest.raises(ValueError, match=".* ``weights`` but should be either None, 'linear' or 'quadratic'"): - cohen_kappa(preds, target, num_classes=3, weights="unknown_arg") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_confusion_matrix_dtypes_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassCohenKappa, + metric_functional=multiclass_cohen_kappa, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index f54ca95c7d1..2d58377e206 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -12,194 +12,314 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Any, Dict import numpy as np import pytest import torch +from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix -from sklearn.metrics import multilabel_confusion_matrix as sk_multilabel_confusion_matrix - -from torchmetrics import JaccardIndex -from torchmetrics.classification.confusion_matrix import ConfusionMatrix -from torchmetrics.functional.classification.confusion_matrix import confusion_matrix -from unittests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob -from unittests.classification.inputs import _input_multiclass as _input_mcls -from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel as _input_mlb -from unittests.classification.inputs import _input_multilabel_logits as _input_mlb_logits -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob + +from torchmetrics.classification.confusion_matrix import ( + BinaryConfusionMatrix, + MulticlassConfusionMatrix, + MultilabelConfusionMatrix, +) +from torchmetrics.functional.classification.confusion_matrix import ( + binary_confusion_matrix, + multiclass_confusion_matrix, + multilabel_confusion_matrix, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_cm_binary_prob(preds, target, normalize=None): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -def _sk_cm_binary(preds, target, normalize=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -def _sk_cm_multilabel_prob(preds, target, normalize=None): - sk_preds = (preds.numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.numpy() - - cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds) - if normalize is not None: - if normalize == "true": - cm = cm / cm.sum(axis=1, keepdims=True) - elif normalize == "pred": - cm = cm / cm.sum(axis=0, keepdims=True) - elif normalize == "all": - cm = cm / cm.sum() - cm[np.isnan(cm)] = 0 - return cm - +def _sk_confusion_matrix_binary(preds, target, normalize=None, ignore_index=None): + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_confusion_matrix(y_true=target, y_pred=preds, labels=[0, 1], normalize=normalize) -def _sk_cm_multilabel(preds, target, normalize=None): - sk_preds = preds.numpy() - sk_target = target.numpy() - cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds) - if normalize is not None: - if normalize == "true": - cm = cm / cm.sum(axis=1, keepdims=True) - elif normalize == "pred": - cm = cm / cm.sum(axis=0, keepdims=True) - elif normalize == "all": - cm = cm / cm.sum() - cm[np.isnan(cm)] = 0 - return cm +@pytest.mark.parametrize("input", _binary_cases) +class TestBinaryConfusionMatrix(MetricTester): + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_confusion_matrix(self, input, ddp, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryConfusionMatrix, + sk_metric=partial(_sk_confusion_matrix_binary, normalize=normalize, ignore_index=ignore_index), + metric_args={ + "threshold": THRESHOLD, + "normalize": normalize, + "ignore_index": ignore_index, + }, + ) + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_confusion_matrix_functional(self, input, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_confusion_matrix, + sk_metric=partial(_sk_confusion_matrix_binary, normalize=normalize, ignore_index=ignore_index), + metric_args={ + "threshold": THRESHOLD, + "normalize": normalize, + "ignore_index": ignore_index, + }, + ) -def _sk_cm_multiclass_prob(preds, target, normalize=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() + def test_binary_confusion_matrix_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryConfusionMatrix, + metric_functional=binary_confusion_matrix, + metric_args={"threshold": THRESHOLD}, + ) - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_confusion_matrix_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryConfusionMatrix, + metric_functional=binary_confusion_matrix, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_confusion_matrix_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryConfusionMatrix, + metric_functional=binary_confusion_matrix, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) -def _sk_cm_multiclass(preds, target, normalize=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) +def _sk_confusion_matrix_multiclass(preds, target, normalize=None, ignore_index=None): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + preds = np.argmax(preds, axis=1) + preds = preds.flatten() + target = target.flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_confusion_matrix(y_true=target, y_pred=preds, normalize=normalize, labels=list(range(NUM_CLASSES))) -def _sk_cm_multidim_multiclass_prob(preds, target, normalize=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassConfusionMatrix(MetricTester): + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_confusion_matrix(self, input, ddp, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassConfusionMatrix, + sk_metric=partial(_sk_confusion_matrix_multiclass, normalize=normalize, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "normalize": normalize, + "ignore_index": ignore_index, + }, + ) - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multiclass_confusion_matrix_functional(self, input, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_confusion_matrix, + sk_metric=partial(_sk_confusion_matrix_multiclass, normalize=normalize, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "normalize": normalize, + "ignore_index": ignore_index, + }, + ) + def test_multiclass_confusion_matrix_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassConfusionMatrix, + metric_functional=multiclass_confusion_matrix, + metric_args={"num_classes": NUM_CLASSES}, + ) -def _sk_cm_multidim_multiclass(preds, target, normalize=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_confusion_matrix_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassConfusionMatrix, + metric_functional=multiclass_confusion_matrix, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_confusion_matrix_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassConfusionMatrix, + metric_functional=multiclass_confusion_matrix, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) -@pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes, multilabel", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2, False), - (_input_binary_logits.preds, _input_binary_logits.target, _sk_cm_binary_prob, 2, False), - (_input_binary.preds, _input_binary.target, _sk_cm_binary, 2, False), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, NUM_CLASSES, True), - (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_cm_multilabel_prob, NUM_CLASSES, True), - (_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, NUM_CLASSES, True), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES, False), - (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_cm_multiclass_prob, NUM_CLASSES, False), - (_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES, False), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES, False), - (_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES, False), - ], -) -class TestConfusionMatrix(MetricTester): +def _sk_confusion_matrix_multilabel(preds, target, normalize=None, ignore_index=None): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) + target = np.moveaxis(target, 1, -1).reshape((-1, target.shape[1])) + confmat = [] + for i in range(preds.shape[1]): + pred, true = preds[:, i], target[:, i] + true, pred = remove_ignore_index(true, pred, ignore_index) + confmat.append(sk_confusion_matrix(true, pred, normalize=normalize, labels=[0, 1])) + return np.stack(confmat, axis=0) + + +@pytest.mark.parametrize("input", _multilabel_cases) +class TestMultilabelConfusionMatrix(MetricTester): + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_confusion_matrix( - self, normalize, preds, target, sk_metric, num_classes, multilabel, ddp, dist_sync_on_step - ): + def test_multilabel_confusion_matrix(self, input, ddp, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=ConfusionMatrix, - sk_metric=partial(sk_metric, normalize=normalize), - dist_sync_on_step=dist_sync_on_step, + metric_class=MultilabelConfusionMatrix, + sk_metric=partial(_sk_confusion_matrix_multilabel, normalize=normalize, ignore_index=ignore_index), metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, + "num_labels": NUM_CLASSES, "normalize": normalize, - "multilabel": multilabel, + "ignore_index": ignore_index, }, ) - def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes, multilabel): + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multilabel_confusion_matrix_functional(self, input, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( preds=preds, target=target, - metric_functional=confusion_matrix, - sk_metric=partial(sk_metric, normalize=normalize), + metric_functional=multilabel_confusion_matrix, + sk_metric=partial(_sk_confusion_matrix_multilabel, normalize=normalize, ignore_index=ignore_index), metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, + "num_labels": NUM_CLASSES, "normalize": normalize, - "multilabel": multilabel, + "ignore_index": ignore_index, }, ) - def test_confusion_matrix_differentiability(self, normalize, preds, target, sk_metric, num_classes, multilabel): + def test_multilabel_confusion_matrix_differentiability(self, input): + preds, target = input self.run_differentiability_test( preds=preds, target=target, - metric_module=ConfusionMatrix, - metric_functional=confusion_matrix, - metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, - "normalize": normalize, - "multilabel": multilabel, - }, + metric_module=MultilabelConfusionMatrix, + metric_functional=multilabel_confusion_matrix, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_confusion_matrix_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelConfusionMatrix, + metric_functional=multilabel_confusion_matrix, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_confusion_matrix_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelConfusionMatrix, + metric_functional=multilabel_confusion_matrix, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) -def test_warning_on_nan(tmpdir): + +def test_warning_on_nan(): preds = torch.randint(3, size=(20,)) target = torch.randint(3, size=(20,)) with pytest.warns( UserWarning, - match=".* nan values found in confusion matrix have been replaced with zeros.", + match=".* NaN values found in confusion matrix have been replaced with zeros.", ): - confusion_matrix(preds, target, num_classes=5, normalize="true") - - -@pytest.mark.parametrize( - "metric_args", - [ - {"num_classes": 1, "normalize": "true"}, - {"num_classes": 1, "normalize": "pred"}, - {"num_classes": 1, "normalize": "all"}, - {"num_classes": 1, "normalize": "none"}, - {"num_classes": 1, "normalize": None}, - ], -) -def test_provide_superclass_kwargs(metric_args: Dict[str, Any]): - """Test instantiating subclasses with superclass arguments as kwargs.""" - JaccardIndex(**metric_args) + multiclass_confusion_matrix(preds, target, num_classes=5, normalize="true") diff --git a/tests/unittests/classification/test_exact_match.py b/tests/unittests/classification/test_exact_match.py new file mode 100644 index 00000000000..863100e4ad0 --- /dev/null +++ b/tests/unittests/classification/test_exact_match.py @@ -0,0 +1,151 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import numpy as np +import pytest +import torch +from scipy.special import expit as sigmoid + +from torchmetrics.classification.exact_match import MultilabelExactMatch +from torchmetrics.functional.classification.exact_match import multilabel_exact_match +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _multilabel_cases +from unittests.helpers import seed_all +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index + +seed_all(42) + + +def _sk_exact_match_multilabel(preds, target, ignore_index, multidim_average): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = preds.reshape(*preds.shape[:2], -1) + target = target.reshape(*target.shape[:2], -1) + + if ignore_index is not None: + target = np.copy(target) + target[target == ignore_index] = -1 + + if multidim_average == "global": + preds = np.moveaxis(preds, 1, -1).reshape(-1, NUM_CLASSES) + target = np.moveaxis(target, 1, -1).reshape(-1, NUM_CLASSES) + correct = ((preds == target).sum(1) == NUM_CLASSES).sum() + total = preds.shape[0] + else: + correct = ((preds == target).sum(1) == NUM_CLASSES).sum(1) + total = preds.shape[2] + return correct / total + + +@pytest.mark.parametrize("input", _multilabel_cases) +class TestMultilabelExactMatch(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + def test_multilabel_exact_match(self, ddp, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MultilabelExactMatch, + sk_metric=partial( + _sk_exact_match_multilabel, + ignore_index=ignore_index, + multidim_average=multidim_average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + }, + ) + + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + def test_multilabel_exact_match_functional(self, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multilabel_exact_match, + sk_metric=partial( + _sk_exact_match_multilabel, + ignore_index=ignore_index, + multidim_average=multidim_average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + }, + ) + + def test_multilabel_exact_match_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MultilabelExactMatch, + metric_functional=multilabel_exact_match, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_exact_match_half_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelExactMatch, + metric_functional=multilabel_exact_match, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_exact_match_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelExactMatch, + metric_functional=multilabel_exact_match, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index d4ae4801cd4..868b3f9613c 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -12,391 +12,311 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable, Optional import numpy as np import pytest import torch -from sklearn.metrics import f1_score, fbeta_score +from scipy.special import expit as sigmoid +from sklearn.metrics import confusion_matrix as sk_confusion_matrix +from sklearn.metrics import f1_score as sk_f1_score +from sklearn.metrics import fbeta_score as sk_fbeta_score from torch import Tensor -from torchmetrics import F1Score, FBetaScore, Metric -from torchmetrics.functional import f1_score as f1_score_pl -from torchmetrics.functional import fbeta_score as fbeta_score_pl -from torchmetrics.utilities.checks import _input_format_classification -from torchmetrics.utilities.enums import AverageMethod -from unittests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob -from unittests.classification.inputs import _input_multiclass as _input_mcls -from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel as _input_mlb -from unittests.classification.inputs import _input_multilabel_logits as _input_mlb_logits -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from torchmetrics.classification.f_beta import ( + BinaryF1Score, + BinaryFBetaScore, + MulticlassF1Score, + MulticlassFBetaScore, + MultilabelF1Score, + MultilabelFBetaScore, +) +from torchmetrics.functional.classification.f_beta import ( + binary_f1_score, + binary_fbeta_score, + multiclass_f1_score, + multiclass_fbeta_score, + multilabel_f1_score, + multilabel_fbeta_score, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average=None): - if average == "none": - average = None - if num_classes == 1: - average = "binary" - - labels = list(range(num_classes)) - try: - labels.remove(ignore_index) - except ValueError: - pass - - sk_preds, sk_target, _ = _input_format_classification( - preds, target, THRESHOLD, num_classes=num_classes, multiclass=multiclass - ) - sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels) - - if len(labels) != num_classes and not average: - sk_scores = np.insert(sk_scores, ignore_index, np.nan) - - return sk_scores - - -def _sk_fbeta_f1_multidim_multiclass( - preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average -): - preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass - ) - - if mdmc_average == "global": - preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) - target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) - - return _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, False, ignore_index) - if mdmc_average == "samplewise": - scores = [] - - for i in range(preds.shape[0]): - pred_i = preds[i, ...].T - target_i = target[i, ...].T - scores_i = _sk_fbeta_f1(pred_i, target_i, sk_fn, num_classes, average, False, ignore_index) +def _sk_fbeta_score_binary(preds, target, sk_fn, ignore_index, multidim_average): + if multidim_average == "global": + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + else: + preds = preds.numpy() + target = target.numpy() - scores.append(np.expand_dims(scores_i, 0)) + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) - return np.concatenate(scores).mean(axis=0) + if multidim_average == "global": + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_fn(target, preds) + else: + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + res.append(sk_fn(true, pred)) + return np.stack(res) +@pytest.mark.parametrize("input", _binary_cases) @pytest.mark.parametrize( - "metric_class, metric_fn", + "module, functional, compare", [ - (partial(FBetaScore, beta=2.0), partial(fbeta_score_pl, beta=2.0)), - (F1Score, f1_score_pl), + (BinaryF1Score, binary_f1_score, sk_f1_score), + (partial(BinaryFBetaScore, beta=2.0), partial(binary_fbeta_score, beta=2.0), partial(sk_fbeta_score, beta=2.0)), ], + ids=["f1", "fbeta"], ) -@pytest.mark.parametrize( - "average, mdmc_average, num_classes, ignore_index, match_str", - [ - ("wrong", None, None, None, "`average`"), - ("micro", "wrong", None, None, "`mdmc"), - ("macro", None, None, None, "number of classes"), - ("macro", None, 1, 0, "ignore_index"), - ], -) -def test_wrong_params(metric_class, metric_fn, average, mdmc_average, num_classes, ignore_index, match_str): - with pytest.raises(ValueError, match=match_str): - metric_class( - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, - ) +class TestBinaryFBetaScore(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("ddp", [False, True]) + def test_binary_fbeta_score(self, ddp, input, module, functional, compare, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") - with pytest.raises(ValueError, match=match_str): - metric_fn( - _input_binary.preds[0], - _input_binary.target[0], - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=module, + sk_metric=partial( + _sk_fbeta_score_binary, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average + ), + metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + def test_binary_fbeta_score_functional(self, input, module, functional, compare, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") -@pytest.mark.parametrize( - "metric_class, metric_fn", - [ - (partial(FBetaScore, beta=2.0), partial(fbeta_score_pl, beta=2.0)), - (F1Score, f1_score_pl), - ], -) -def test_zero_division(metric_class, metric_fn): - """Test that zero_division works correctly (currently should just set to 0).""" - - preds = torch.tensor([1, 2, 1, 1]) - target = torch.tensor([2, 0, 2, 1]) - - cl_metric = metric_class(average="none", num_classes=3) - cl_metric(preds, target) - - result_cl = cl_metric.compute() - result_fn = metric_fn(preds, target, average="none", num_classes=3) - - assert result_cl[0] == result_fn[0] == 0 - - -@pytest.mark.parametrize( - "metric_class, metric_fn", - [ - (partial(FBetaScore, beta=2.0), partial(fbeta_score_pl, beta=2.0)), - (F1Score, f1_score_pl), - ], -) -def test_no_support(metric_class, metric_fn): - """This tests a rare edge case, where there is only one class present. - - in target, and ignore_index is set to exactly that class - and the - average method is equal to 'weighted'. - - This would mean that the sum of weights equals zero, and would, without - taking care of this case, return NaN. However, the reduction function - should catch that and set the metric to equal the value of zero_division - in this case (zero_division is for now not configurable and equals 0). - """ - - preds = torch.tensor([1, 1, 0, 0]) - target = torch.tensor([0, 0, 0, 0]) - - cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0) - cl_metric(preds, target) - - result_cl = cl_metric.compute() - result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=0) - - assert result_cl == result_fn == 0 - + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=functional, + sk_metric=partial( + _sk_fbeta_score_binary, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average + ), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + }, + ) -@pytest.mark.parametrize( - "metric_class, metric_fn", - [ - (partial(FBetaScore, beta=2.0), partial(fbeta_score_pl, beta=2.0)), - (F1Score, f1_score_pl), - ], -) -@pytest.mark.parametrize( - "ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))] -) -def test_class_not_present(metric_class, metric_fn, ignore_index, expected): - """This tests that when metric is computed per class and a given class is not present in both the `preds` and - `target`, the resulting score is `nan`.""" - preds = torch.tensor([0, 0, 0]) - target = torch.tensor([0, 0, 0]) - num_classes = 2 + def test_binary_fbeta_score_differentiability(self, input, module, functional, compare): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"threshold": THRESHOLD}, + ) - # test functional - result_fn = metric_fn(preds, target, average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index) - assert torch.allclose(expected, result_fn, equal_nan=True) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_fbeta_score_half_cpu(self, input, module, functional, compare, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) - # test class - cl_metric = metric_class(average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index) - cl_metric(preds, target) - result_cl = cl_metric.compute() - assert torch.allclose(expected, result_cl, equal_nan=True) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_fbeta_score_half_gpu(self, input, module, functional, compare, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) +def _sk_fbeta_score_multiclass(preds, target, sk_fn, ignore_index, multidim_average, average): + if preds.ndim == target.ndim + 1: + preds = torch.argmax(preds, 1) + if multidim_average == "global": + preds = preds.numpy().flatten() + target = target.numpy().flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_fn(target, preds, average=average) + else: + preds = preds.numpy() + target = target.numpy() + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + res.append(sk_fn(true, pred, average=average, labels=list(range(NUM_CLASSES)))) + return np.stack(res, 0) + + +@pytest.mark.parametrize("input", _multiclass_cases) @pytest.mark.parametrize( - "metric_class, metric_fn, sk_fn", - [ - (partial(FBetaScore, beta=2.0), partial(fbeta_score_pl, beta=2.0), partial(fbeta_score, beta=2.0)), - (F1Score, f1_score_pl, f1_score), - ], -) -@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) -@pytest.mark.parametrize("ignore_index", [None, 0]) -@pytest.mark.parametrize( - "preds, target, num_classes, multiclass, mdmc_average, sk_wrapper", + "module, functional, compare", [ - (_input_binary_logits.preds, _input_binary_logits.target, 1, None, None, _sk_fbeta_f1), - (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_fbeta_f1), - (_input_binary.preds, _input_binary.target, 1, False, None, _sk_fbeta_f1), - (_input_mlb_logits.preds, _input_mlb_logits.target, NUM_CLASSES, None, None, _sk_fbeta_f1), - (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_fbeta_f1), - (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_fbeta_f1), - (_input_mcls_logits.preds, _input_mcls_logits.target, NUM_CLASSES, None, None, _sk_fbeta_f1), - (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_fbeta_f1), - (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_fbeta_f1), - (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_fbeta_f1_multidim_multiclass), + (MulticlassF1Score, multiclass_f1_score, sk_f1_score), ( - _input_mdmc_prob.preds, - _input_mdmc_prob.target, - NUM_CLASSES, - None, - "global", - _sk_fbeta_f1_multidim_multiclass, - ), - (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_fbeta_f1_multidim_multiclass), - ( - _input_mdmc_prob.preds, - _input_mdmc_prob.target, - NUM_CLASSES, - None, - "samplewise", - _sk_fbeta_f1_multidim_multiclass, + partial(MulticlassFBetaScore, beta=2.0), + partial(multiclass_fbeta_score, beta=2.0), + partial(sk_fbeta_score, beta=2.0), ), ], + ids=["f1", "fbeta"], ) -class TestFBeta(MetricTester): +class TestMulticlassFBetaScore(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_fbeta_f1( - self, - ddp: bool, - dist_sync_on_step: bool, - preds: Tensor, - target: Tensor, - sk_wrapper: Callable, - metric_class: Metric, - metric_fn: Callable, - sk_fn: Callable, - multiclass: Optional[bool], - num_classes: Optional[int], - average: str, - mdmc_average: Optional[str], - ignore_index: Optional[int], + def test_multiclass_fbeta_score( + self, ddp, input, module, functional, compare, ignore_index, multidim_average, average ): - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=metric_class, + metric_class=module, sk_metric=partial( - sk_wrapper, - sk_fn=sk_fn, - average=average, - num_classes=num_classes, - multiclass=multiclass, + _sk_fbeta_score_multiclass, + sk_fn=compare, ignore_index=ignore_index, - mdmc_average=mdmc_average, + multidim_average=multidim_average, + average=average, ), - dist_sync_on_step=dist_sync_on_step, metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "multiclass": multiclass, "ignore_index": ignore_index, - "mdmc_average": mdmc_average, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, }, ) - def test_fbeta_f1_functional( - self, - preds: Tensor, - target: Tensor, - sk_wrapper: Callable, - metric_class: Metric, - metric_fn: Callable, - sk_fn: Callable, - multiclass: Optional[bool], - num_classes: Optional[int], - average: str, - mdmc_average: Optional[str], - ignore_index: Optional[int], + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multiclass_fbeta_score_functional( + self, input, module, functional, compare, ignore_index, multidim_average, average ): - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") self.run_functional_metric_test( - preds, - target, - metric_functional=metric_fn, + preds=preds, + target=target, + metric_functional=functional, sk_metric=partial( - sk_wrapper, - sk_fn=sk_fn, - average=average, - num_classes=num_classes, - multiclass=multiclass, + _sk_fbeta_score_multiclass, + sk_fn=compare, ignore_index=ignore_index, - mdmc_average=mdmc_average, + multidim_average=multidim_average, + average=average, ), metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "multiclass": multiclass, "ignore_index": ignore_index, - "mdmc_average": mdmc_average, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, }, ) - def test_fbeta_f1_differentiability( - self, - preds: Tensor, - target: Tensor, - sk_wrapper: Callable, - metric_class: Metric, - metric_fn: Callable, - sk_fn: Callable, - multiclass: Optional[bool], - num_classes: Optional[int], - average: str, - mdmc_average: Optional[str], - ignore_index: Optional[int], - ): - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") + def test_multiclass_fbeta_score_differentiability(self, input, module, functional, compare): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_classes": NUM_CLASSES}, + ) - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_fbeta_score_half_cpu(self, input, module, functional, compare, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) - self.run_differentiability_test( - preds, - target, - metric_functional=metric_fn, - metric_module=metric_class, - metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "multiclass": multiclass, - "ignore_index": ignore_index, - "mdmc_average": mdmc_average, - }, + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_fbeta_score_half_gpu(self, input, module, functional, compare, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, ) _mc_k_target = torch.tensor([0, 1, 2]) _mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) -_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) -_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) @pytest.mark.parametrize( "metric_class, metric_fn", [ - (partial(FBetaScore, beta=2.0), partial(fbeta_score_pl, beta=2.0)), - (F1Score, fbeta_score_pl), + (partial(MulticlassFBetaScore, beta=2.0), partial(multiclass_fbeta_score, beta=2.0)), + (MulticlassF1Score, multiclass_f1_score), ], ) @pytest.mark.parametrize( @@ -404,8 +324,6 @@ def test_fbeta_f1_differentiability( [ (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)), (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(5 / 6), torch.tensor(2 / 3)), - (1, _ml_k_preds, _ml_k_target, "micro", torch.tensor(0.0), torch.tensor(0.0)), - (2, _ml_k_preds, _ml_k_target, "micro", torch.tensor(5 / 18), torch.tensor(2 / 9)), ], ) def test_top_k( @@ -418,10 +336,7 @@ def test_top_k( expected_fbeta: Tensor, expected_f1: Tensor, ): - """A simple test to check that top_k works as expected. - - Just a sanity check, the tests in StatScores should already guarantee the corectness of results. - """ + """A simple test to check that top_k works as expected.""" class_metric = metric_class(top_k=k, average=average, num_classes=3) class_metric.update(preds, target) @@ -434,29 +349,203 @@ def test_top_k( assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) -@pytest.mark.parametrize("ignore_index", [None, 2]) -@pytest.mark.parametrize("average", ["micro", "macro", "weighted"]) +def _sk_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average): + if average == "micro": + preds = preds.flatten() + target = target.flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_fn(target, preds) + + fbeta_score, weights = [], [] + for i in range(preds.shape[1]): + pred, true = preds[:, i].flatten(), target[:, i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + fbeta_score.append(sk_fn(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + weights.append(confmat[1, 1] + confmat[1, 0]) + res = np.stack(fbeta_score, axis=0) + + if average == "macro": + return res.mean(0) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + + +def _sk_fbeta_score_multilabel_local(preds, target, sk_fn, ignore_index, average): + fbeta_score, weights = [], [] + for i in range(preds.shape[0]): + if average == "micro": + pred, true = preds[i].flatten(), target[i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + fbeta_score.append(sk_fn(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + weights.append(confmat[1, 1] + confmat[1, 0]) + else: + scores, w = [], [] + for j in range(preds.shape[1]): + pred, true = preds[i, j], target[i, j] + true, pred = remove_ignore_index(true, pred, ignore_index) + scores.append(sk_fn(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + w.append(confmat[1, 1] + confmat[1, 0]) + fbeta_score.append(np.stack(scores)) + weights.append(np.stack(w)) + if average == "micro": + return np.array(fbeta_score) + res = np.stack(fbeta_score, 0) + if average == "macro": + return res.mean(-1) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + + +def _sk_fbeta_score_multilabel(preds, target, sk_fn, ignore_index, multidim_average, average): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = preds.reshape(*preds.shape[:2], -1) + target = target.reshape(*target.shape[:2], -1) + if ignore_index is None and multidim_average == "global": + return sk_fn( + target.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), + preds.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), + average=average, + ) + elif multidim_average == "global": + return _sk_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average) + return _sk_fbeta_score_multilabel_local(preds, target, sk_fn, ignore_index, average) + + +@pytest.mark.parametrize("input", _multilabel_cases) @pytest.mark.parametrize( - "metric_class, metric_functional, sk_fn", + "module, functional, compare", [ - (partial(FBetaScore, beta=2.0), partial(fbeta_score_pl, beta=2.0), partial(fbeta_score, beta=2.0)), - (F1Score, f1_score_pl, f1_score), + (MultilabelF1Score, multilabel_f1_score, sk_f1_score), + ( + partial(MultilabelFBetaScore, beta=2.0), + partial(multilabel_fbeta_score, beta=2.0), + partial(sk_fbeta_score, beta=2.0), + ), ], + ids=["f1", "fbeta"], ) -def test_same_input(metric_class, metric_functional, sk_fn, average, ignore_index): - preds = _input_miss_class.preds - target = _input_miss_class.target - preds_flat = torch.cat(list(preds), dim=0) - target_flat = torch.cat(list(target), dim=0) - - mc = metric_class(num_classes=NUM_CLASSES, average=average, ignore_index=ignore_index) - for i in range(NUM_BATCHES): - mc.update(preds[i], target[i]) - class_res = mc.compute() - func_res = metric_functional( - preds_flat, target_flat, num_classes=NUM_CLASSES, average=average, ignore_index=ignore_index - ) - sk_res = sk_fn(target_flat, preds_flat, average=average, zero_division=0) - - assert torch.allclose(class_res, torch.tensor(sk_res).float()) - assert torch.allclose(func_res, torch.tensor(sk_res).float()) +class TestMultilabelFBetaScore(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multilabel_fbeta_score( + self, ddp, input, module, functional, compare, ignore_index, multidim_average, average + ): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=module, + sk_metric=partial( + _sk_fbeta_score_multilabel, + sk_fn=compare, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) + + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multilabel_fbeta_score_functional( + self, input, module, functional, compare, ignore_index, multidim_average, average + ): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=functional, + sk_metric=partial( + _sk_fbeta_score_multilabel, + sk_fn=compare, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) + + def test_multilabel_fbeta_score_differentiability(self, input, module, functional, compare): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_fbeta_score_half_cpu(self, input, module, functional, compare, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_fbeta_score_half_gpu(self, input, module, functional, compare, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming_distance.py index 11fab91d66b..e4c64704818 100644 --- a/tests/unittests/classification/test_hamming_distance.py +++ b/tests/unittests/classification/test_hamming_distance.py @@ -11,96 +11,476 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial + +import numpy as np import pytest +import torch +from scipy.special import expit as sigmoid +from sklearn.metrics import confusion_matrix as sk_confusion_matrix from sklearn.metrics import hamming_loss as sk_hamming_loss -from torchmetrics import HammingDistance -from torchmetrics.functional import hamming_distance -from torchmetrics.utilities.checks import _input_format_classification -from unittests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob -from unittests.classification.inputs import _input_multiclass as _input_mcls -from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel as _input_mlb -from unittests.classification.inputs import _input_multilabel_logits as _input_mlb_logits -from unittests.classification.inputs import _input_multilabel_multidim as _input_mlmd -from unittests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from torchmetrics.classification.hamming import ( + BinaryHammingDistance, + MulticlassHammingDistance, + MultilabelHammingDistance, +) +from torchmetrics.functional.classification.hamming import ( + binary_hamming_distance, + multiclass_hamming_distance, + multilabel_hamming_distance, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import THRESHOLD, MetricTester +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_hamming_loss(preds, target): - sk_preds, sk_target, _ = _input_format_classification(preds, target, threshold=THRESHOLD) - sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - sk_preds, sk_target = sk_preds.reshape(sk_preds.shape[0], -1), sk_target.reshape(sk_target.shape[0], -1) - - return sk_hamming_loss(y_true=sk_target, y_pred=sk_preds) - - -@pytest.mark.parametrize( - "preds, target", - [ - (_input_binary_logits.preds, _input_binary_logits.target), - (_input_binary_prob.preds, _input_binary_prob.target), - (_input_binary.preds, _input_binary.target), - (_input_mlb_logits.preds, _input_mlb_logits.target), - (_input_mlb_prob.preds, _input_mlb_prob.target), - (_input_mlb.preds, _input_mlb.target), - (_input_mcls_logits.preds, _input_mcls_logits.target), - (_input_mcls_prob.preds, _input_mcls_prob.target), - (_input_mcls.preds, _input_mcls.target), - (_input_mdmc_prob.preds, _input_mdmc_prob.target), - (_input_mdmc.preds, _input_mdmc.target), - (_input_mlmd_prob.preds, _input_mlmd_prob.target), - (_input_mlmd.preds, _input_mlmd.target), - ], -) -class TestHammingDistance(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target): +def _sk_hamming_loss(target, preds): + score = sk_hamming_loss(target, preds) + return score if not np.isnan(score) else 1.0 + + +def _sk_hamming_distance_binary(preds, target, ignore_index, multidim_average): + if multidim_average == "global": + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + else: + preds = preds.numpy() + target = target.numpy() + + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + + if multidim_average == "global": + target, preds = remove_ignore_index(target, preds, ignore_index) + return _sk_hamming_loss(target, preds) + else: + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + res.append(_sk_hamming_loss(true, pred)) + return np.stack(res) + + +@pytest.mark.parametrize("input", _binary_cases) +class TestBinaryHammingDistance(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("ddp", [False, True]) + def test_binary_hamming_distance(self, ddp, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=HammingDistance, - sk_metric=_sk_hamming_loss, - dist_sync_on_step=dist_sync_on_step, - metric_args={"threshold": THRESHOLD}, + metric_class=BinaryHammingDistance, + sk_metric=partial( + _sk_hamming_distance_binary, ignore_index=ignore_index, multidim_average=multidim_average + ), + metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) - def test_hamming_distance_fn(self, preds, target): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + def test_binary_hamming_distance_functional(self, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + self.run_functional_metric_test( preds=preds, target=target, - metric_functional=hamming_distance, - sk_metric=_sk_hamming_loss, - metric_args={"threshold": THRESHOLD}, + metric_functional=binary_hamming_distance, + sk_metric=partial( + _sk_hamming_distance_binary, ignore_index=ignore_index, multidim_average=multidim_average + ), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + }, ) - def test_hamming_distance_differentiability(self, preds, target): + def test_binary_hamming_distance_differentiability(self, input): + preds, target = input self.run_differentiability_test( preds=preds, target=target, - metric_module=HammingDistance, - metric_functional=hamming_distance, + metric_module=BinaryHammingDistance, + metric_functional=binary_hamming_distance, + metric_args={"threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_hamming_distance_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryHammingDistance, + metric_functional=binary_hamming_distance, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_hamming_distance_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryHammingDistance, + metric_functional=binary_hamming_distance, metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) + + +def _sk_hamming_distance_multiclass_global(preds, target, ignore_index, average): + preds = preds.numpy().flatten() + target = target.numpy().flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + if average == "micro": + return _sk_hamming_loss(target, preds) + confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) + hamming_per_class = 1 - confmat.diagonal() / confmat.sum(axis=1) + hamming_per_class[np.isnan(hamming_per_class)] = 1.0 + if average == "macro": + return hamming_per_class.mean() + elif average == "weighted": + weights = confmat.sum(1) + return ((weights * hamming_per_class) / weights.sum()).sum() + return hamming_per_class + + +def _sk_hamming_distance_multiclass_local(preds, target, ignore_index, average): + preds = preds.numpy() + target = target.numpy() + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + if average == "micro": + res.append(_sk_hamming_loss(true, pred)) + else: + confmat = sk_confusion_matrix(true, pred, labels=list(range(NUM_CLASSES))) + hamming_per_class = 1 - confmat.diagonal() / confmat.sum(axis=1) + hamming_per_class[np.isnan(hamming_per_class)] = 1.0 + if average == "macro": + res.append(hamming_per_class.mean()) + elif average == "weighted": + weights = confmat.sum(1) + score = ((weights * hamming_per_class) / weights.sum()).sum() + res.append(0.0 if np.isnan(score) else score) + else: + res.append(hamming_per_class) + return np.stack(res, 0) + + +def _sk_hamming_distance_multiclass(preds, target, ignore_index, multidim_average, average): + if preds.ndim == target.ndim + 1: + preds = torch.argmax(preds, 1) + if multidim_average == "global": + return _sk_hamming_distance_multiclass_global(preds, target, ignore_index, average) + return _sk_hamming_distance_multiclass_local(preds, target, ignore_index, average) + + +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassHammingDistance(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_hamming_distance(self, ddp, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassHammingDistance, + sk_metric=partial( + _sk_hamming_distance_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, + ) + + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multiclass_hamming_distance_functional(self, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_hamming_distance, + sk_metric=partial( + _sk_hamming_distance_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, ) + def test_multiclass_hamming_distance_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassHammingDistance, + metric_functional=multiclass_hamming_distance, + metric_args={"num_classes": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_hamming_distance_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassHammingDistance, + metric_functional=multiclass_hamming_distance, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_hamming_distance_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassHammingDistance, + metric_functional=multiclass_hamming_distance, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + +def _sk_hamming_distance_multilabel_global(preds, target, ignore_index, average): + if average == "micro": + preds = preds.flatten() + target = target.flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + return _sk_hamming_loss(target, preds) + + hamming, weights = [], [] + for i in range(preds.shape[1]): + pred, true = preds[:, i].flatten(), target[:, i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + hamming.append(_sk_hamming_loss(true, pred)) + weights.append(confmat[1, 1] + confmat[1, 0]) + res = np.stack(hamming, axis=0) + + if average == "macro": + return res.mean(0) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + + +def _sk_hamming_distance_multilabel_local(preds, target, ignore_index, average): + hamming, weights = [], [] + for i in range(preds.shape[0]): + if average == "micro": + pred, true = preds[i].flatten(), target[i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + hamming.append(_sk_hamming_loss(true, pred)) + else: + scores, w = [], [] + for j in range(preds.shape[1]): + pred, true = preds[i, j], target[i, j] + true, pred = remove_ignore_index(true, pred, ignore_index) + scores.append(_sk_hamming_loss(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + w.append(confmat[1, 1] + confmat[1, 0]) + hamming.append(np.stack(scores)) + weights.append(np.stack(w)) + if average == "micro": + return np.array(hamming) + res = np.stack(hamming, 0) + if average == "macro": + return res.mean(-1) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + + +def _sk_hamming_distance_multilabel(preds, target, ignore_index, multidim_average, average): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = preds.reshape(*preds.shape[:2], -1) + target = target.reshape(*target.shape[:2], -1) + + if multidim_average == "global": + return _sk_hamming_distance_multilabel_global(preds, target, ignore_index, average) + return _sk_hamming_distance_multilabel_local(preds, target, ignore_index, average) + + +@pytest.mark.parametrize("input", _multilabel_cases) +class TestMultilabelHammingDistance(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multilabel_hamming_distance(self, ddp, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MultilabelHammingDistance, + sk_metric=partial( + _sk_hamming_distance_multilabel, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) -@pytest.mark.parametrize("threshold", [1.5]) -def test_wrong_params(threshold): - preds, target = _input_mcls_prob.preds, _input_mcls_prob.target + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multilabel_hamming_distance_functional(self, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") - with pytest.raises(ValueError): - ham_dist = HammingDistance(threshold=threshold) - ham_dist(preds, target) - ham_dist.compute() + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multilabel_hamming_distance, + sk_metric=partial( + _sk_hamming_distance_multilabel, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) - with pytest.raises(ValueError): - hamming_distance(preds, target, threshold=threshold) + def test_multilabel_hamming_distance_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MultilabelHammingDistance, + metric_functional=multilabel_hamming_distance, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_hamming_distance_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelHammingDistance, + metric_functional=multilabel_hamming_distance, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_hamming_distance_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelHammingDistance, + metric_functional=multilabel_hamming_distance, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) diff --git a/tests/unittests/classification/test_hinge.py b/tests/unittests/classification/test_hinge.py index f495c14d36d..a0bf4904b55 100644 --- a/tests/unittests/classification/test_hinge.py +++ b/tests/unittests/classification/test_hinge.py @@ -16,141 +16,195 @@ import numpy as np import pytest import torch +from scipy.special import expit as sigmoid +from scipy.special import softmax from sklearn.metrics import hinge_loss as sk_hinge from sklearn.preprocessing import OneHotEncoder -from torchmetrics import HingeLoss -from torchmetrics.functional import hinge_loss -from torchmetrics.functional.classification.hinge import MulticlassMode -from unittests.classification.inputs import Input -from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester +from torchmetrics.classification.hinge import BinaryHingeLoss, MulticlassHingeLoss +from torchmetrics.functional.classification.hinge import binary_hinge_loss, multiclass_hinge_loss +from unittests.classification.inputs import _binary_cases, _multiclass_cases +from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index torch.manual_seed(42) -_input_binary = Input( - preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) -) - -_input_binary_single = Input(preds=torch.randn((NUM_BATCHES, 1)), target=torch.randint(high=2, size=(NUM_BATCHES, 1))) - -_input_multiclass = Input( - preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), -) +def _sk_binary_hinge_loss(preds, target, ignore_index): + preds = preds.numpy().flatten() + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) -def _sk_hinge(preds, target, squared, multiclass_mode): - sk_preds, sk_target = preds.numpy(), target.numpy() - - if multiclass_mode == MulticlassMode.ONE_VS_ALL: - enc = OneHotEncoder() - enc.fit(sk_target.reshape(-1, 1)) - sk_target = enc.transform(sk_target.reshape(-1, 1)).toarray() - - if sk_preds.ndim == 1 or multiclass_mode == MulticlassMode.ONE_VS_ALL: - sk_target = 2 * sk_target - 1 - - if squared or sk_target.max() != 1 or sk_target.min() != -1: - # Squared not an option in sklearn and infers classes incorrectly with single element, so adapted from source - if sk_preds.ndim == 1 or multiclass_mode == MulticlassMode.ONE_VS_ALL: - margin = sk_target * sk_preds - else: - mask = np.ones_like(sk_preds, dtype=bool) - mask[np.arange(sk_target.shape[0]), sk_target] = False - margin = sk_preds[~mask] - margin -= np.max(sk_preds[mask].reshape(sk_target.shape[0], -1), axis=1) - measures = 1 - margin - measures = np.clip(measures, 0, None) - - if squared: - measures = measures**2 - return measures.mean(axis=0) - if multiclass_mode == MulticlassMode.ONE_VS_ALL: - result = np.zeros(sk_preds.shape[1]) - for i in range(result.shape[0]): - result[i] = sk_hinge(y_true=sk_target[:, i], pred_decision=sk_preds[:, i]) - return result - - return sk_hinge(y_true=sk_target, pred_decision=sk_preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + target = 2 * target - 1 + return sk_hinge(target, preds) -@pytest.mark.parametrize( - "preds, target, squared, multiclass_mode", - [ - (_input_binary.preds, _input_binary.target, False, None), - (_input_binary.preds, _input_binary.target, True, None), - (_input_binary_single.preds, _input_binary_single.target, False, None), - (_input_binary_single.preds, _input_binary_single.target, True, None), - (_input_multiclass.preds, _input_multiclass.target, False, MulticlassMode.CRAMMER_SINGER), - (_input_multiclass.preds, _input_multiclass.target, True, MulticlassMode.CRAMMER_SINGER), - (_input_multiclass.preds, _input_multiclass.target, False, MulticlassMode.ONE_VS_ALL), - (_input_multiclass.preds, _input_multiclass.target, True, MulticlassMode.ONE_VS_ALL), - ], -) -class TestHinge(MetricTester): +@pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinaryHingeLoss(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_hinge_class(self, ddp, dist_sync_on_step, preds, target, squared, multiclass_mode): + def test_binary_hinge_loss(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=HingeLoss, - sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode), - dist_sync_on_step=dist_sync_on_step, + metric_class=BinaryHingeLoss, + sk_metric=partial(_sk_binary_hinge_loss, ignore_index=ignore_index), metric_args={ - "squared": squared, - "multiclass_mode": multiclass_mode, + "ignore_index": ignore_index, }, ) - def test_hinge_fn(self, preds, target, squared, multiclass_mode): + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_binary_hinge_loss_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( preds=preds, target=target, - metric_functional=partial(hinge_loss, squared=squared, multiclass_mode=multiclass_mode), - sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode), + metric_functional=binary_hinge_loss, + sk_metric=partial(_sk_binary_hinge_loss, ignore_index=ignore_index), + metric_args={ + "ignore_index": ignore_index, + }, ) - def test_hinge_differentiability(self, preds, target, squared, multiclass_mode): + def test_binary_hinge_loss_differentiability(self, input): + preds, target = input self.run_differentiability_test( preds=preds, target=target, - metric_module=HingeLoss, - metric_functional=partial(hinge_loss, squared=squared, multiclass_mode=multiclass_mode), + metric_module=BinaryHingeLoss, + metric_functional=binary_hinge_loss, ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_hinge_loss_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half: + pytest.xfail(reason="torch.clamp does not support cpu + half") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryHingeLoss, + metric_functional=binary_hinge_loss, + dtype=dtype, + ) -_input_multi_target = Input(preds=torch.randn(BATCH_SIZE), target=torch.randint(high=2, size=(BATCH_SIZE, 2))) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_hinge_loss_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryHingeLoss, + metric_functional=binary_hinge_loss, + dtype=dtype, + ) -_input_binary_different_sizes = Input( - preds=torch.randn(BATCH_SIZE * 2), target=torch.randint(high=2, size=(BATCH_SIZE,)) -) -_input_multi_different_sizes = Input( - preds=torch.randn(BATCH_SIZE * 2, NUM_CLASSES), target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,)) -) +def _sk_multiclass_hinge_loss(preds, target, multiclass_mode, ignore_index): + preds = preds.numpy() + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = softmax(preds, 1) + preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) + target, preds = remove_ignore_index(target, preds, ignore_index) -_input_extra_dim = Input( - preds=torch.randn(BATCH_SIZE, NUM_CLASSES, 2), target=torch.randint(high=2, size=(BATCH_SIZE,)) -) + if multiclass_mode == "one-vs-all": + enc = OneHotEncoder() + enc.fit(target.reshape(-1, 1)) + target = enc.transform(target.reshape(-1, 1)).toarray() + target = 2 * target - 1 + result = np.zeros(preds.shape[1]) + for i in range(result.shape[0]): + result[i] = sk_hinge(y_true=target[:, i], pred_decision=preds[:, i]) + return result + else: + return sk_hinge(target, preds) @pytest.mark.parametrize( - "preds, target, multiclass_mode", - [ - (_input_multi_target.preds, _input_multi_target.target, None), - (_input_binary_different_sizes.preds, _input_binary_different_sizes.target, None), - (_input_multi_different_sizes.preds, _input_multi_different_sizes.target, None), - (_input_extra_dim.preds, _input_extra_dim.target, None), - (_input_multiclass.preds[0], _input_multiclass.target[0], "invalid_mode"), - ], + "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) ) -def test_bad_inputs_fn(preds, target, multiclass_mode): - with pytest.raises(ValueError): - _ = hinge_loss(preds, target, multiclass_mode=multiclass_mode) +class TestMulticlassHingeLoss(MetricTester): + @pytest.mark.parametrize("multiclass_mode", ["crammer-singer", "one-vs-all"]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_hinge_loss(self, input, ddp, multiclass_mode, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassHingeLoss, + sk_metric=partial(_sk_multiclass_hinge_loss, multiclass_mode=multiclass_mode, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "multiclass_mode": multiclass_mode, + "ignore_index": ignore_index, + }, + ) + @pytest.mark.parametrize("multiclass_mode", ["crammer-singer", "one-vs-all"]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_multiclass_hinge_loss_functional(self, input, multiclass_mode, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_hinge_loss, + sk_metric=partial(_sk_multiclass_hinge_loss, multiclass_mode=multiclass_mode, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "multiclass_mode": multiclass_mode, + "ignore_index": ignore_index, + }, + ) -def test_bad_inputs_class(): - with pytest.raises(ValueError): - HingeLoss(multiclass_mode="invalid_mode") + def test_multiclass_hinge_loss_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassHingeLoss, + metric_functional=multiclass_hinge_loss, + metric_args={"num_classes": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_hinge_loss_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half: + pytest.xfail(reason="torch.clamp does not support cpu + half") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassHingeLoss, + metric_functional=multiclass_hinge_loss, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_hinge_loss_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassHingeLoss, + metric_functional=multiclass_hinge_loss, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index 945001a9220..4b6ad133211 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -16,224 +16,302 @@ import numpy as np import pytest import torch -from sklearn.metrics import jaccard_score as sk_jaccard_score -from torch import Tensor, tensor - -from torchmetrics.classification.jaccard import JaccardIndex -from torchmetrics.functional import jaccard_index -from unittests.classification.inputs import _input_binary, _input_binary_prob -from unittests.classification.inputs import _input_multiclass as _input_mcls -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel as _input_mlb -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester - - -def _sk_jaccard_binary_prob(preds, target, average=None): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -def _sk_jaccard_binary(preds, target, average=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -def _sk_jaccard_multilabel_prob(preds, target, average=None): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() +from scipy.special import expit as sigmoid +from sklearn.metrics import confusion_matrix as sk_confusion_matrix +from sklearn.metrics import jaccard_score as sk_jaccard_index + +from torchmetrics.classification.jaccard import BinaryJaccardIndex, MulticlassJaccardIndex, MultilabelJaccardIndex +from torchmetrics.functional.classification.jaccard import ( + binary_jaccard_index, + multiclass_jaccard_index, + multilabel_jaccard_index, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) +def _sk_jaccard_index_binary(preds, target, ignore_index=None): + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_jaccard_index(y_true=target, y_pred=preds) -def _sk_jaccard_multilabel(preds, target, average=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) +@pytest.mark.parametrize("input", _binary_cases) +class TestBinaryJaccardIndex(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_jaccard_index(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryJaccardIndex, + sk_metric=partial(_sk_jaccard_index_binary, ignore_index=ignore_index), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + }, + ) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_jaccard_index_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_jaccard_index, + sk_metric=partial(_sk_jaccard_index_binary, ignore_index=ignore_index), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + }, + ) -def _sk_jaccard_multiclass_prob(preds, target, average=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() + def test_binary_jaccard_index_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryJaccardIndex, + metric_functional=binary_jaccard_index, + metric_args={"threshold": THRESHOLD}, + ) - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_jaccard_index_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryJaccardIndex, + metric_functional=binary_jaccard_index, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_jaccard_index_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryJaccardIndex, + metric_functional=binary_jaccard_index, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) -def _sk_jaccard_multiclass(preds, target, average=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) +def _sk_jaccard_index_multiclass(preds, target, ignore_index=None, average="macro"): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + preds = np.argmax(preds, axis=1) + preds = preds.flatten() + target = target.flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_jaccard_index(y_true=target, y_pred=preds, average=average) -def _sk_jaccard_multidim_multiclass_prob(preds, target, average=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassJaccardIndex(MetricTester): + @pytest.mark.parametrize("average", ["macro", "micro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_jaccard_index(self, input, ddp, ignore_index, average): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassJaccardIndex, + sk_metric=partial(_sk_jaccard_index_multiclass, ignore_index=ignore_index, average=average), + metric_args={ + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + "average": average, + }, + ) - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + @pytest.mark.parametrize("average", ["macro", "micro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multiclass_jaccard_index_functional(self, input, ignore_index, average): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_jaccard_index, + sk_metric=partial(_sk_jaccard_index_multiclass, ignore_index=ignore_index, average=average), + metric_args={ + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + "average": average, + }, + ) + def test_multiclass_jaccard_index_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassJaccardIndex, + metric_functional=multiclass_jaccard_index, + metric_args={"num_classes": NUM_CLASSES}, + ) -def _sk_jaccard_multidim_multiclass(preds, target, average=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_jaccard_index_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassJaccardIndex, + metric_functional=multiclass_jaccard_index, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_jaccard_index_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassJaccardIndex, + metric_functional=multiclass_jaccard_index, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) -@pytest.mark.parametrize("average", [None, "macro", "micro", "weighted"]) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_jaccard_binary_prob, 2), - (_input_binary.preds, _input_binary.target, _sk_jaccard_binary, 2), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_jaccard_multilabel_prob, 2), - (_input_mlb.preds, _input_mlb.target, _sk_jaccard_multilabel, 2), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_jaccard_multiclass_prob, NUM_CLASSES), - (_input_mcls.preds, _input_mcls.target, _sk_jaccard_multiclass, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_jaccard_multidim_multiclass_prob, NUM_CLASSES), - (_input_mdmc.preds, _input_mdmc.target, _sk_jaccard_multidim_multiclass, NUM_CLASSES), - ], -) -class TestJaccardIndex(MetricTester): +def _sk_jaccard_index_multilabel(preds, target, ignore_index=None, average="macro"): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) + target = np.moveaxis(target, 1, -1).reshape((-1, target.shape[1])) + if ignore_index is None: + return sk_jaccard_index(y_true=target, y_pred=preds, average=average) + else: + if average == "micro": + return _sk_jaccard_index_binary(torch.tensor(preds), torch.tensor(target), ignore_index) + scores, weights = [], [] + for i in range(preds.shape[1]): + pred, true = preds[:, i], target[:, i] + true, pred = remove_ignore_index(true, pred, ignore_index) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + scores.append(sk_jaccard_index(true, pred)) + weights.append(confmat[1, 0] + confmat[1, 1]) + scores = np.stack(scores, axis=0) + weights = np.stack(weights, axis=0) + if average is None or average == "none": + return scores + elif average == "macro": + return scores.mean() + return ((scores * weights) / weights.sum()).sum() + + +@pytest.mark.parametrize("input", _multilabel_cases) +class TestMultilabelJaccardIndex(MetricTester): + @pytest.mark.parametrize("average", ["macro", "micro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None]) # , -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_jaccard(self, average, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): - # average = "macro" if reduction == "elementwise_mean" else None # convert tags + def test_multilabel_jaccard_index(self, input, ddp, ignore_index, average): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=JaccardIndex, - sk_metric=partial(sk_metric, average=average), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average}, + metric_class=MultilabelJaccardIndex, + sk_metric=partial(_sk_jaccard_index_multilabel, ignore_index=ignore_index, average=average), + metric_args={ + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, + "average": average, + }, ) - def test_jaccard_functional(self, average, preds, target, sk_metric, num_classes): - # average = "macro" if reduction == "elementwise_mean" else None # convert tags + @pytest.mark.parametrize("average", ["macro", "micro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multilabel_jaccard_index_functional(self, input, ignore_index, average): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( - preds, - target, - metric_functional=jaccard_index, - sk_metric=partial(sk_metric, average=average), - metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average}, + preds=preds, + target=target, + metric_functional=multilabel_jaccard_index, + sk_metric=partial(_sk_jaccard_index_multilabel, ignore_index=ignore_index, average=average), + metric_args={ + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, + "average": average, + }, ) - def test_jaccard_differentiability(self, average, preds, target, sk_metric, num_classes): + def test_multilabel_jaccard_index_differentiability(self, input): + preds, target = input self.run_differentiability_test( preds=preds, target=target, - metric_module=JaccardIndex, - metric_functional=jaccard_index, - metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average}, + metric_module=MultilabelJaccardIndex, + metric_functional=multilabel_jaccard_index, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_jaccard_index_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelJaccardIndex, + metric_functional=multilabel_jaccard_index, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) -@pytest.mark.parametrize( - ["half_ones", "average", "ignore_index", "expected"], - [ - (False, "none", None, Tensor([1, 1, 1])), - (False, "macro", None, Tensor([1])), - (False, "none", 0, Tensor([1, 1])), - (True, "none", None, Tensor([0.5, 0.5, 0.5])), - (True, "macro", None, Tensor([0.5])), - (True, "none", 0, Tensor([2 / 3, 1 / 2])), - ], -) -def test_jaccard(half_ones, average, ignore_index, expected): - preds = (torch.arange(120) % 3).view(-1, 1) - target = (torch.arange(120) % 3).view(-1, 1) - if half_ones: - preds[:60] = 1 - jaccard_val = jaccard_index( - preds=preds, - target=target, - average=average, - num_classes=3, - ignore_index=ignore_index, - # reduction=reduction, - ) - assert torch.allclose(jaccard_val, expected, atol=1e-9) - - -# test `absent_score` -@pytest.mark.parametrize( - ["pred", "target", "ignore_index", "absent_score", "num_classes", "expected"], - [ - # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid - # scores the function can return ([0., 1.] range, inclusive). - # 2 classes, class 0 is correct everywhere, class 1 is absent. - ([0], [0], None, -1.0, 2, [1.0, -1.0]), - ([0, 0], [0, 0], None, -1.0, 2, [1.0, -1.0]), - # absent_score not applied if only class 0 is present and it's the only class. - ([0], [0], None, -1.0, 1, [1.0]), - # 2 classes, class 1 is correct everywhere, class 0 is absent. - ([1], [1], None, -1.0, 2, [-1.0, 1.0]), - ([1, 1], [1, 1], None, -1.0, 2, [-1.0, 1.0]), - # When 0 index ignored, class 0 does not get a score (not even the absent_score). - ([1], [1], 0, -1.0, 2, [1.0]), - # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score. - ([0, 2], [0, 2], None, -1.0, 3, [1.0, -1.0, 1.0]), - ([2, 0], [2, 0], None, -1.0, 3, [1.0, -1.0, 1.0]), - # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score. - ([0, 1], [0, 1], None, -1.0, 3, [1.0, 1.0, -1.0]), - ([1, 0], [1, 0], None, -1.0, 3, [1.0, 1.0, -1.0]), - # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class - # 2 is absent. - ([0, 1], [0, 0], None, -1.0, 3, [0.5, 0.0, -1.0]), - # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class - # 2 is absent. - ([0, 0], [0, 1], None, -1.0, 3, [0.5, 0.0, -1.0]), - # Sanity checks with absent_score of 1.0. - ([0, 2], [0, 2], None, 1.0, 3, [1.0, 1.0, 1.0]), - ([0, 2], [0, 2], 0, 1.0, 3, [1.0, 1.0]), - ], -) -def test_jaccard_absent_score(pred, target, ignore_index, absent_score, num_classes, expected): - jaccard_val = jaccard_index( - preds=tensor(pred), - target=tensor(target), - average=None, - ignore_index=ignore_index, - absent_score=absent_score, - num_classes=num_classes, - # reduction="none", - ) - assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val)) - - -# example data taken from -# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py -@pytest.mark.parametrize( - ["pred", "target", "ignore_index", "num_classes", "average", "expected"], - [ - # Ignoring an index outside of [0, num_classes-1] should have no effect. - ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, "none", [1, 1 / 2, 2 / 3]), - ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, "none", [1, 1 / 2, 2 / 3]), - ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, "none", [1, 1 / 2, 2 / 3]), - # Ignoring a valid index drops only that index from the result. - ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "none", [1 / 2, 2 / 3]), - ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, "none", [1, 2 / 3]), - ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, "none", [1, 1]), - # When reducing to mean or sum, the ignored index does not contribute to the output. - ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "macro", [7 / 12]), - # ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "sum", [7 / 6]), - ], -) -def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, average, expected): - jaccard_val = jaccard_index( - preds=tensor(pred), - target=tensor(target), - average=average, - ignore_index=ignore_index, - num_classes=num_classes, - # reduction=reduction, - ) - assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val)) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_jaccard_index_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelJaccardIndex, + metric_functional=multilabel_jaccard_index, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index 2502fb25b8a..93b3371c0f4 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -11,139 +11,293 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial + import numpy as np import pytest import torch +from scipy.special import expit as sigmoid from sklearn.metrics import matthews_corrcoef as sk_matthews_corrcoef -from torchmetrics.classification.matthews_corrcoef import MatthewsCorrCoef -from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef -from unittests.classification.inputs import _input_binary, _input_binary_prob -from unittests.classification.inputs import _input_multiclass as _input_mcls -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel as _input_mlb -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from torchmetrics.classification.matthews_corrcoef import ( + BinaryMatthewsCorrCoef, + MulticlassMatthewsCorrCoef, + MultilabelMatthewsCorrCoef, +) +from torchmetrics.functional.classification.matthews_corrcoef import ( + binary_matthews_corrcoef, + multiclass_matthews_corrcoef, + multilabel_matthews_corrcoef, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_matthews_corrcoef_binary_prob(preds, target): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) - +def _sk_matthews_corrcoef_binary(preds, target, ignore_index=None): + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_matthews_corrcoef(y_true=target, y_pred=preds) -def _sk_matthews_corrcoef_binary(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) - - -def _sk_matthews_corrcoef_multilabel_prob(preds, target): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) +@pytest.mark.parametrize("input", _binary_cases) +class TestBinaryMatthewsCorrCoef(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_matthews_corrcoef(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryMatthewsCorrCoef, + sk_metric=partial(_sk_matthews_corrcoef_binary, ignore_index=ignore_index), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + }, + ) -def _sk_matthews_corrcoef_multilabel(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_matthews_corrcoef_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_matthews_corrcoef, + sk_metric=partial(_sk_matthews_corrcoef_binary, ignore_index=ignore_index), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + }, + ) - return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) + def test_binary_matthews_corrcoef_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryMatthewsCorrCoef, + metric_functional=binary_matthews_corrcoef, + metric_args={"threshold": THRESHOLD}, + ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_matthews_corrcoef_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryMatthewsCorrCoef, + metric_functional=binary_matthews_corrcoef, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) -def _sk_matthews_corrcoef_multiclass_prob(preds, target): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_matthews_corrcoef_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryMatthewsCorrCoef, + metric_functional=binary_matthews_corrcoef, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) - return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) +def _sk_matthews_corrcoef_multiclass(preds, target, ignore_index=None): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + preds = np.argmax(preds, axis=1) + preds = preds.flatten() + target = target.flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_matthews_corrcoef(y_true=target, y_pred=preds) -def _sk_matthews_corrcoef_multiclass(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassMatthewsCorrCoef(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_matthews_corrcoef(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassMatthewsCorrCoef, + sk_metric=partial(_sk_matthews_corrcoef_multiclass, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multiclass_matthews_corrcoef_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_matthews_corrcoef, + sk_metric=partial(_sk_matthews_corrcoef_multiclass, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) -def _sk_matthews_corrcoef_multidim_multiclass_prob(preds, target): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() + def test_multiclass_matthews_corrcoef_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassMatthewsCorrCoef, + metric_functional=multiclass_matthews_corrcoef, + metric_args={"num_classes": NUM_CLASSES}, + ) - return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_matthews_corrcoef_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassMatthewsCorrCoef, + metric_functional=multiclass_matthews_corrcoef, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_matthews_corrcoef_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassMatthewsCorrCoef, + metric_functional=multiclass_matthews_corrcoef, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) -def _sk_matthews_corrcoef_multidim_multiclass(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) +def _sk_matthews_corrcoef_multilabel(preds, target, ignore_index=None): + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_matthews_corrcoef(y_true=target, y_pred=preds) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_matthews_corrcoef_binary_prob, 2), - (_input_binary.preds, _input_binary.target, _sk_matthews_corrcoef_binary, 2), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_matthews_corrcoef_multilabel_prob, 2), - (_input_mlb.preds, _input_mlb.target, _sk_matthews_corrcoef_multilabel, 2), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_matthews_corrcoef_multiclass_prob, NUM_CLASSES), - (_input_mcls.preds, _input_mcls.target, _sk_matthews_corrcoef_multiclass, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_matthews_corrcoef_multidim_multiclass_prob, NUM_CLASSES), - (_input_mdmc.preds, _input_mdmc.target, _sk_matthews_corrcoef_multidim_multiclass, NUM_CLASSES), - ], -) -class TestMatthewsCorrCoef(MetricTester): +@pytest.mark.parametrize("input", _multilabel_cases) +class TestMultilabelMatthewsCorrCoef(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_matthews_corrcoef(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + def test_multilabel_matthews_corrcoef(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=MatthewsCorrCoef, - sk_metric=sk_metric, - dist_sync_on_step=dist_sync_on_step, + metric_class=MultilabelMatthewsCorrCoef, + sk_metric=partial(_sk_matthews_corrcoef_multilabel, ignore_index=ignore_index), metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, }, ) - def test_matthews_corrcoef_functional(self, preds, target, sk_metric, num_classes): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multilabel_matthews_corrcoef_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( - preds, - target, - metric_functional=matthews_corrcoef, - sk_metric=sk_metric, + preds=preds, + target=target, + metric_functional=multilabel_matthews_corrcoef, + sk_metric=partial(_sk_matthews_corrcoef_multilabel, ignore_index=ignore_index), metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, }, ) - def test_matthews_corrcoef_differentiability(self, preds, target, sk_metric, num_classes): + def test_multilabel_matthews_corrcoef_differentiability(self, input): + preds, target = input self.run_differentiability_test( preds=preds, target=target, - metric_module=MatthewsCorrCoef, - metric_functional=matthews_corrcoef, - metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, - }, + metric_module=MultilabelMatthewsCorrCoef, + metric_functional=multilabel_matthews_corrcoef, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_matthews_corrcoef_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelMatthewsCorrCoef, + metric_functional=multilabel_matthews_corrcoef, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_matthews_corrcoef_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelMatthewsCorrCoef, + metric_functional=multilabel_matthews_corrcoef, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, ) -def test_zero_case(): +def test_zero_case_in_multiclass(): """Cases where the denominator in the matthews corrcoef is 0, the score should return 0.""" # Example where neither 1 or 2 is present in the target tensor - out = matthews_corrcoef(torch.tensor([0, 1, 2]), torch.tensor([0, 0, 0]), 3) + out = multiclass_matthews_corrcoef(torch.tensor([0, 1, 2]), torch.tensor([0, 0, 0]), 3) assert out == 0.0 diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index 32af07430d1..75a2873c7fd 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -12,383 +12,312 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable, Optional import numpy as np import pytest import torch -from sklearn.metrics import precision_score, recall_score +from scipy.special import expit as sigmoid +from sklearn.metrics import confusion_matrix as sk_confusion_matrix +from sklearn.metrics import precision_score as sk_precision_score +from sklearn.metrics import recall_score as sk_recall_score from torch import Tensor, tensor -from torchmetrics import Metric, Precision, Recall -from torchmetrics.functional import precision, precision_recall, recall -from torchmetrics.utilities.checks import _input_format_classification -from torchmetrics.utilities.enums import AverageMethod -from unittests.classification import MetricWrapper -from unittests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob -from unittests.classification.inputs import _input_multiclass as _input_mcls -from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel as _input_mlb -from unittests.classification.inputs import _input_multilabel_logits as _input_mlb_logits -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from unittests.classification.inputs import _negmetric_noneavg +from torchmetrics.classification.precision_recall import ( + BinaryPrecision, + BinaryRecall, + MulticlassPrecision, + MulticlassRecall, + MultilabelPrecision, + MultilabelRecall, +) +from torchmetrics.functional.classification.precision_recall import ( + binary_precision, + binary_recall, + multiclass_precision, + multiclass_recall, + multilabel_precision, + multilabel_recall, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_prec_recall(preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average=None): - # todo: `mdmc_average` is unused - if average == "none": - average = None - if num_classes == 1: - average = "binary" - - labels = list(range(num_classes)) - try: - labels.remove(ignore_index) - except ValueError: - pass - - sk_preds, sk_target, _ = _input_format_classification( - preds, target, THRESHOLD, num_classes=num_classes, multiclass=multiclass - ) - sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - - sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels) - - if len(labels) != num_classes and not average: - sk_scores = np.insert(sk_scores, ignore_index, np.nan) - - return sk_scores - - -def _sk_prec_recall_multidim_multiclass( - preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average -): - preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass - ) - - if mdmc_average == "global": - preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) - target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) - - return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, ignore_index) - if mdmc_average == "samplewise": - scores = [] - - for i in range(preds.shape[0]): - pred_i = preds[i, ...].T - target_i = target[i, ...].T - scores_i = _sk_prec_recall(pred_i, target_i, sk_fn, num_classes, average, False, ignore_index) +def _sk_precision_recall_binary(preds, target, sk_fn, ignore_index, multidim_average): + if multidim_average == "global": + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + else: + preds = preds.numpy() + target = target.numpy() - scores.append(np.expand_dims(scores_i, 0)) + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) - return np.concatenate(scores).mean(axis=0) + if multidim_average == "global": + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_fn(target, preds) + else: + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + res.append(sk_fn(true, pred)) + return np.stack(res) -@pytest.mark.parametrize("metric, fn_metric", [(Precision, precision), (Recall, recall)]) +@pytest.mark.parametrize("input", _binary_cases) @pytest.mark.parametrize( - "average, mdmc_average, num_classes, ignore_index, match_str", + "module, functional, compare", [ - ("wrong", None, None, None, "`average`"), - ("micro", "wrong", None, None, "`mdmc"), - ("macro", None, None, None, "number of classes"), - ("macro", None, 1, 0, "ignore_index"), + (BinaryPrecision, binary_precision, sk_precision_score), + (BinaryRecall, binary_recall, sk_recall_score), ], + ids=["precision", "recall"], ) -def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str): - with pytest.raises(ValueError, match=match_str): - metric( - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, - ) - - with pytest.raises(ValueError, match=match_str): - fn_metric( - _input_binary.preds[0], - _input_binary.target[0], - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, - ) +class TestBinaryPrecisionRecall(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("ddp", [False, True]) + def test_binary_precision_recall(self, ddp, input, module, functional, compare, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") - with pytest.raises(ValueError, match=match_str): - precision_recall( - _input_binary.preds[0], - _input_binary.target[0], - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=module, + sk_metric=partial( + _sk_precision_recall_binary, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average + ), + metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + def test_binary_precision_recall_functional( + self, input, module, functional, compare, ignore_index, multidim_average + ): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") -@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) -def test_zero_division(metric_class, metric_fn): - """Test that zero_division works correctly (currently should just set to 0).""" - - preds = tensor([0, 2, 1, 1]) - target = tensor([2, 1, 2, 1]) - - cl_metric = metric_class(average="none", num_classes=3) - cl_metric(preds, target) - - result_cl = cl_metric.compute() - result_fn = metric_fn(preds, target, average="none", num_classes=3) - - assert result_cl[0] == result_fn[0] == 0 - - -@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) -def test_no_support(metric_class, metric_fn): - """This tests a rare edge case, where there is only one class present. - - in target, and ignore_index is set to exactly that class - and the - average method is equal to 'weighted'. - - This would mean that the sum of weights equals zero, and would, without - taking care of this case, return NaN. However, the reduction function - should catch that and set the metric to equal the value of zero_division - in this case (zero_division is for now not configurable and equals 0). - """ - - preds = tensor([1, 1, 0, 0]) - target = tensor([0, 0, 0, 0]) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=functional, + sk_metric=partial( + _sk_precision_recall_binary, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average + ), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + }, + ) - cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0) - cl_metric(preds, target) + def test_binary_precision_recall_differentiability(self, input, module, functional, compare): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"threshold": THRESHOLD}, + ) - result_cl = cl_metric.compute() - result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=0) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_precision_recall_half_cpu(self, input, module, functional, compare, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) - assert result_cl == result_fn == 0 + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_precision_recall_half_gpu(self, input, module, functional, compare, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) +def _sk_precision_recall_multiclass(preds, target, sk_fn, ignore_index, multidim_average, average): + if preds.ndim == target.ndim + 1: + preds = torch.argmax(preds, 1) + if multidim_average == "global": + preds = preds.numpy().flatten() + target = target.numpy().flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_fn(target, preds, average=average) + else: + preds = preds.numpy() + target = target.numpy() + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + res.append(sk_fn(true, pred, average=average, labels=list(range(NUM_CLASSES)))) + return np.stack(res, 0) + + +@pytest.mark.parametrize("input", _multiclass_cases) @pytest.mark.parametrize( - "metric_class, metric_fn, sk_fn", [(Recall, recall, recall_score), (Precision, precision, precision_score)] -) -@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) -@pytest.mark.parametrize("ignore_index", [None, 0]) -@pytest.mark.parametrize( - "preds, target, num_classes, multiclass, mdmc_average, sk_wrapper", + "module, functional, compare", [ - (_input_binary_logits.preds, _input_binary_logits.target, 1, None, None, _sk_prec_recall), - (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall), - (_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall), - (_input_mlb_logits.preds, _input_mlb_logits.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_prec_recall), - (_input_mcls_logits.preds, _input_mcls_logits.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_multidim_multiclass), - ( - _input_mdmc_prob.preds, - _input_mdmc_prob.target, - NUM_CLASSES, - None, - "global", - _sk_prec_recall_multidim_multiclass, - ), - (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_multidim_multiclass), - ( - _input_mdmc_prob.preds, - _input_mdmc_prob.target, - NUM_CLASSES, - None, - "samplewise", - _sk_prec_recall_multidim_multiclass, - ), + (MulticlassPrecision, multiclass_precision, sk_precision_score), + (MulticlassRecall, multiclass_recall, sk_recall_score), ], + ids=["precision", "recall"], ) -class TestPrecisionRecall(MetricTester): - @pytest.mark.parametrize("ddp", [False, True]) - @pytest.mark.parametrize("dist_sync_on_step", [False]) - def test_precision_recall_class( - self, - ddp: bool, - dist_sync_on_step: bool, - preds: Tensor, - target: Tensor, - sk_wrapper: Callable, - metric_class: Metric, - metric_fn: Callable, - sk_fn: Callable, - multiclass: Optional[bool], - num_classes: Optional[int], - average: str, - mdmc_average: Optional[str], - ignore_index: Optional[int], +class TestMulticlassPrecisionRecall(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_precision_recall( + self, ddp, input, module, functional, compare, ignore_index, multidim_average, average ): - # todo: `metric_fn` is unused - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=metric_class, + metric_class=module, sk_metric=partial( - sk_wrapper, - sk_fn=sk_fn, - average=average, - num_classes=num_classes, - multiclass=multiclass, + _sk_precision_recall_multiclass, + sk_fn=compare, ignore_index=ignore_index, - mdmc_average=mdmc_average, + multidim_average=multidim_average, + average=average, ), - dist_sync_on_step=dist_sync_on_step, metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "multiclass": multiclass, "ignore_index": ignore_index, - "mdmc_average": mdmc_average, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, }, ) - def test_precision_recall_fn( - self, - preds: Tensor, - target: Tensor, - sk_wrapper: Callable, - metric_class: Metric, - metric_fn: Callable, - sk_fn: Callable, - multiclass: Optional[bool], - num_classes: Optional[int], - average: str, - mdmc_average: Optional[str], - ignore_index: Optional[int], + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multiclass_precision_recall_functional( + self, input, module, functional, compare, ignore_index, multidim_average, average ): - # todo: `metric_class` is unused - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") self.run_functional_metric_test( - preds, - target, - metric_functional=metric_fn, + preds=preds, + target=target, + metric_functional=functional, sk_metric=partial( - sk_wrapper, - sk_fn=sk_fn, - average=average, - num_classes=num_classes, - multiclass=multiclass, + _sk_precision_recall_multiclass, + sk_fn=compare, ignore_index=ignore_index, - mdmc_average=mdmc_average, + multidim_average=multidim_average, + average=average, ), metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "multiclass": multiclass, "ignore_index": ignore_index, - "mdmc_average": mdmc_average, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, }, ) - def test_precision_recall_differentiability( - self, - preds: Tensor, - target: Tensor, - sk_wrapper: Callable, - metric_class: Metric, - metric_fn: Callable, - sk_fn: Callable, - multiclass: Optional[bool], - num_classes: Optional[int], - average: str, - mdmc_average: Optional[str], - ignore_index: Optional[int], - ): - # todo: `metric_class` is unused - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") - + def test_multiclass_precision_recall_differentiability(self, input, module, functional, compare): + preds, target = input self.run_differentiability_test( preds=preds, target=target, - metric_module=metric_class, - metric_functional=metric_fn, - metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "multiclass": multiclass, - "ignore_index": ignore_index, - "mdmc_average": mdmc_average, - }, + metric_module=module, + metric_functional=functional, + metric_args={"num_classes": NUM_CLASSES}, ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_precision_recall_half_cpu(self, input, module, functional, compare, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) -@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) -def test_precision_recall_joint(average): - """A simple test of the joint precision_recall metric. - - No need to test this thorougly, as it is just a combination of precision and recall, which are already tested - thoroughly. - """ - - precision_result = precision( - _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES - ) - recall_result = recall( - _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES - ) - - prec_recall_result = precision_recall( - _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES - ) - - assert torch.equal(precision_result, prec_recall_result[0]) - assert torch.equal(recall_result, prec_recall_result[1]) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_precision_recall_half_gpu(self, input, module, functional, compare, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) _mc_k_target = tensor([0, 1, 2]) _mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) -_ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) -_ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) -@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) +@pytest.mark.parametrize( + "metric_class, metric_fn", [(MulticlassPrecision, multiclass_precision), (MulticlassRecall, multiclass_recall)] +) @pytest.mark.parametrize( "k, preds, target, average, expected_prec, expected_recall", [ (1, _mc_k_preds, _mc_k_target, "micro", tensor(2 / 3), tensor(2 / 3)), (2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2), tensor(1.0)), - (1, _ml_k_preds, _ml_k_target, "micro", tensor(0.0), tensor(0.0)), - (2, _ml_k_preds, _ml_k_target, "micro", tensor(1 / 6), tensor(1 / 3)), ], ) def test_top_k( @@ -401,15 +330,12 @@ def test_top_k( expected_prec: Tensor, expected_recall: Tensor, ): - """A simple test to check that top_k works as expected. - - Just a sanity check, the tests in StatScores should already guarantee the correctness of results. - """ + """A simple test to check that top_k works as expected.""" class_metric = metric_class(top_k=k, average=average, num_classes=3) class_metric.update(preds, target) - if metric_class.__name__ == "Precision": + if metric_class.__name__ == "MulticlassPrecision": result = expected_prec else: result = expected_recall @@ -418,53 +344,199 @@ def test_top_k( assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) -@pytest.mark.parametrize("metric_class, metric_fn", [(Precision, precision), (Recall, recall)]) +def _sk_precision_recall_multilabel_global(preds, target, sk_fn, ignore_index, average): + if average == "micro": + preds = preds.flatten() + target = target.flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_fn(target, preds) + + precision_recall, weights = [], [] + for i in range(preds.shape[1]): + pred, true = preds[:, i].flatten(), target[:, i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + precision_recall.append(sk_fn(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + weights.append(confmat[1, 1] + confmat[1, 0]) + res = np.stack(precision_recall, axis=0) + + if average == "macro": + return res.mean(0) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + + +def _sk_precision_recall_multilabel_local(preds, target, sk_fn, ignore_index, average): + precision_recall, weights = [], [] + for i in range(preds.shape[0]): + if average == "micro": + pred, true = preds[i].flatten(), target[i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + precision_recall.append(sk_fn(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + weights.append(confmat[1, 1] + confmat[1, 0]) + else: + scores, w = [], [] + for j in range(preds.shape[1]): + pred, true = preds[i, j], target[i, j] + true, pred = remove_ignore_index(true, pred, ignore_index) + scores.append(sk_fn(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + w.append(confmat[1, 1] + confmat[1, 0]) + precision_recall.append(np.stack(scores)) + weights.append(np.stack(w)) + if average == "micro": + return np.array(precision_recall) + res = np.stack(precision_recall, 0) + if average == "macro": + return res.mean(-1) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + + +def _sk_precision_recall_multilabel(preds, target, sk_fn, ignore_index, multidim_average, average): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = preds.reshape(*preds.shape[:2], -1) + target = target.reshape(*target.shape[:2], -1) + if ignore_index is None and multidim_average == "global": + return sk_fn( + target.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), + preds.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), + average=average, + ) + elif multidim_average == "global": + return _sk_precision_recall_multilabel_global(preds, target, sk_fn, ignore_index, average) + return _sk_precision_recall_multilabel_local(preds, target, sk_fn, ignore_index, average) + + +@pytest.mark.parametrize("input", _multilabel_cases) @pytest.mark.parametrize( - "ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))] + "module, functional, compare", + [ + (MultilabelPrecision, multilabel_precision, sk_precision_score), + (MultilabelRecall, multilabel_recall, sk_recall_score), + ], + ids=["precision", "recall"], ) -def test_class_not_present(metric_class, metric_fn, ignore_index, expected): - """This tests that when metric is computed per class and a given class is not present in both the `preds` and - `target`, the resulting score is `nan`.""" - preds = torch.tensor([0, 0, 0]) - target = torch.tensor([0, 0, 0]) - num_classes = 2 +class TestMultilabelPrecisionRecall(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multilabel_precision_recall( + self, ddp, input, module, functional, compare, ignore_index, multidim_average, average + ): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=module, + sk_metric=partial( + _sk_precision_recall_multilabel, + sk_fn=compare, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) - # test functional - result_fn = metric_fn(preds, target, average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index) - assert torch.allclose(expected, result_fn, equal_nan=True) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multilabel_precision_recall_functional( + self, input, module, functional, compare, ignore_index, multidim_average, average + ): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") - # test class - cl_metric = metric_class(average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index) - cl_metric(preds, target) - result_cl = cl_metric.compute() - assert torch.allclose(expected, result_cl, equal_nan=True) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=functional, + sk_metric=partial( + _sk_precision_recall_multilabel, + sk_fn=compare, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) + + def test_multilabel_precision_recall_differentiability(self, input, module, functional, compare): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_precision_recall_half_cpu(self, input, module, functional, compare, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) -@pytest.mark.parametrize("average", ["micro", "macro", "weighted"]) -@pytest.mark.parametrize( - "metric_class, metric_functional, sk_fn", [(Precision, precision, precision_score), (Recall, recall, recall_score)] -) -def test_same_input(metric_class, metric_functional, sk_fn, average): - preds = _input_miss_class.preds - target = _input_miss_class.target - preds_flat = torch.cat(list(preds), dim=0) - target_flat = torch.cat(list(target), dim=0) - - mc = metric_class(num_classes=NUM_CLASSES, average=average) - for i in range(NUM_BATCHES): - mc.update(preds[i], target[i]) - class_res = mc.compute() - func_res = metric_functional(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average) - sk_res = sk_fn(target_flat, preds_flat, average=average, zero_division=1) - - assert torch.allclose(class_res, torch.tensor(sk_res).float()) - assert torch.allclose(func_res, torch.tensor(sk_res).float()) - - -@pytest.mark.parametrize("metric_cls", [Precision, Recall]) -def test_noneavg(metric_cls, noneavg=_negmetric_noneavg): - prec = MetricWrapper(metric_cls(average="none", num_classes=noneavg["pred1"].shape[1])) - result1 = prec(noneavg["pred1"], noneavg["target1"]) - assert torch.allclose(noneavg["res1"], result1, equal_nan=True) - result2 = prec(noneavg["pred2"], noneavg["target2"]) - assert torch.allclose(noneavg["res2"], result2, equal_nan=True) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_precision_recall_half_gpu(self, input, module, functional, compare, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=module, + metric_functional=functional, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py index c56441a7f4c..9bebfc9c40b 100644 --- a/tests/unittests/classification/test_precision_recall_curve.py +++ b/tests/unittests/classification/test_precision_recall_curve.py @@ -16,131 +16,338 @@ import numpy as np import pytest import torch +from scipy.special import expit as sigmoid +from scipy.special import softmax from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve -from torch import Tensor, tensor - -from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve -from torchmetrics.functional import precision_recall_curve -from torchmetrics.functional.classification.precision_recall_curve import _binary_clf_curve -from unittests.classification.inputs import _input_binary_prob -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob + +from torchmetrics.classification.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from torchmetrics.functional.classification.precision_recall_curve import ( + binary_precision_recall_curve, + multiclass_precision_recall_curve, + multilabel_precision_recall_curve, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_8 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, MetricTester +from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_precision_recall_curve(y_true, probas_pred, num_classes=1): - """Adjusted comparison function that can also handles multiclass.""" - if num_classes == 1: - return sk_precision_recall_curve(y_true, probas_pred) +def _sk_precision_recall_curve_binary(preds, target, ignore_index=None): + preds = preds.flatten().numpy() + target = target.flatten().numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_precision_recall_curve(target, preds) + + +@pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinaryPrecisionRecallCurve(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_precision_recall_curve(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryPrecisionRecallCurve, + sk_metric=partial(_sk_precision_recall_curve_binary, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_precision_recall_curve_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_precision_recall_curve, + sk_metric=partial(_sk_precision_recall_curve_binary, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + def test_binary_precision_recall_curve_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryPrecisionRecallCurve, + metric_functional=binary_precision_recall_curve, + metric_args={"thresholds": None}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_precision_recall_curve_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryPrecisionRecallCurve, + metric_functional=binary_precision_recall_curve, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_precision_recall_curve_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryPrecisionRecallCurve, + metric_functional=binary_precision_recall_curve, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_binary_precision_recall_curve_threshold_arg(self, input, threshold_fn): + preds, target = input + + for pred, true in zip(preds, target): + p1, r1, t1 = binary_precision_recall_curve(pred, true, thresholds=None) + p2, r2, t2 = binary_precision_recall_curve(pred, true, thresholds=threshold_fn(t1)) + + assert torch.allclose(p1, p2) + assert torch.allclose(r1, r2) + assert torch.allclose(t1, t2) + + +def _sk_precision_recall_curve_multiclass(preds, target, ignore_index=None): + preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = softmax(preds, 1) + target, preds = remove_ignore_index(target, preds, ignore_index) precision, recall, thresholds = [], [], [] - for i in range(num_classes): - y_true_temp = np.zeros_like(y_true) - y_true_temp[y_true == i] = 1 - res = sk_precision_recall_curve(y_true_temp, probas_pred[:, i]) + for i in range(NUM_CLASSES): + target_temp = np.zeros_like(target) + target_temp[target == i] = 1 + res = sk_precision_recall_curve(target_temp, preds[:, i]) precision.append(res[0]) recall.append(res[1]) thresholds.append(res[2]) - return precision, recall, thresholds + # return precision, recall, thresholds + return [np.nan_to_num(x, nan=0.0) for x in [precision, recall, thresholds]] -def _sk_prec_rc_binary_prob(preds, target, num_classes=1): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() +@pytest.mark.parametrize( + "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +) +class TestMulticlassPrecisionRecallCurve(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_precision_recall_curve(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassPrecisionRecallCurve, + sk_metric=partial(_sk_precision_recall_curve_multiclass, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) - return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_multiclass_precision_recall_curve_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_precision_recall_curve, + sk_metric=partial(_sk_precision_recall_curve_multiclass, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) + def test_multiclass_precision_recall_curve_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassPrecisionRecallCurve, + metric_functional=multiclass_precision_recall_curve, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + ) -def _sk_prec_rc_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_precision_recall_curve_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassPrecisionRecallCurve, + metric_functional=multiclass_precision_recall_curve, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) - return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_precision_recall_curve_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassPrecisionRecallCurve, + metric_functional=multiclass_precision_recall_curve, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_multiclass_precision_recall_curve_threshold_arg(self, input, threshold_fn): + preds, target = input + for pred, true in zip(preds, target): + p1, r1, t1 = multiclass_precision_recall_curve(pred, true, num_classes=NUM_CLASSES, thresholds=None) + for i, t in enumerate(t1): + p2, r2, t2 = multiclass_precision_recall_curve( + pred, true, num_classes=NUM_CLASSES, thresholds=threshold_fn(t) + ) -def _sk_prec_rc_multidim_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + assert torch.allclose(p1[i], p2[i]) + assert torch.allclose(r1[i], r2[i]) + assert torch.allclose(t1[i], t2) + + +def _sk_precision_recall_curve_multilabel(preds, target, ignore_index=None): + precision, recall, thresholds = [], [], [] + for i in range(NUM_CLASSES): + res = _sk_precision_recall_curve_binary(preds[:, i], target[:, i], ignore_index) + precision.append(res[0]) + recall.append(res[1]) + thresholds.append(res[2]) + return precision, recall, thresholds @pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_prec_rc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_prec_rc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_prec_rc_multidim_multiclass_prob, NUM_CLASSES), - ], + "input", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5]) ) -class TestPrecisionRecallCurve(MetricTester): +class TestMultilabelPrecisionRecallCurve(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + def test_multilabel_precision_recall_curve(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=PrecisionRecallCurve, - sk_metric=partial(sk_metric, num_classes=num_classes), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes}, + metric_class=MultilabelPrecisionRecallCurve, + sk_metric=partial(_sk_precision_recall_curve_multilabel, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, + }, ) - def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_classes): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multilabel_precision_recall_curve_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( - preds, - target, - metric_functional=precision_recall_curve, - sk_metric=partial(sk_metric, num_classes=num_classes), - metric_args={"num_classes": num_classes}, + preds=preds, + target=target, + metric_functional=multilabel_precision_recall_curve, + sk_metric=partial(_sk_precision_recall_curve_multilabel, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, + }, ) - def test_precision_recall_curve_differentiability(self, preds, target, sk_metric, num_classes): + def test_multiclass_precision_recall_curve_differentiability(self, input): + preds, target = input self.run_differentiability_test( - preds, - target, - metric_module=PrecisionRecallCurve, - metric_functional=precision_recall_curve, - metric_args={"num_classes": num_classes}, + preds=preds, + target=target, + metric_module=MultilabelPrecisionRecallCurve, + metric_functional=multilabel_precision_recall_curve, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_precision_recall_curve_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelPrecisionRecallCurve, + metric_functional=multilabel_precision_recall_curve, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) -@pytest.mark.parametrize( - ["pred", "target", "expected_p", "expected_r", "expected_t"], - [([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1.0, 1.0], [1, 0.5, 0.5, 0.5, 0.0], [1, 2, 3, 4])], -) -def test_pr_curve(pred, target, expected_p, expected_r, expected_t): - p, r, t = precision_recall_curve(tensor(pred), tensor(target)) - assert p.size() == r.size() - assert p.size(0) == t.size(0) + 1 - - assert torch.allclose(p, tensor(expected_p).to(p)) - assert torch.allclose(r, tensor(expected_r).to(r)) - assert torch.allclose(t, tensor(expected_t).to(t)) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_precision_recall_curve_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelPrecisionRecallCurve, + metric_functional=multilabel_precision_recall_curve, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_multilabel_precision_recall_curve_threshold_arg(self, input, threshold_fn): + preds, target = input + for pred, true in zip(preds, target): + p1, r1, t1 = multilabel_precision_recall_curve(pred, true, num_labels=NUM_CLASSES, thresholds=None) + for i, t in enumerate(t1): + p2, r2, t2 = multilabel_precision_recall_curve( + pred, true, num_labels=NUM_CLASSES, thresholds=threshold_fn(t) + ) -@pytest.mark.parametrize( - "sample_weight, pos_label, exp_shape", - [(1, 1.0, 42), (None, 1.0, 42)], -) -def test_binary_clf_curve(sample_weight, pos_label, exp_shape): - # TODO: move back the pred and target to test func arguments - # if you fix the array inside the function, you'd also have fix the shape, - # because when the array changes, you also have to fix the shape - seed_all(0) - pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100 - target = tensor([0, 1] * 50, dtype=torch.int) - if sample_weight is not None: - sample_weight = torch.ones_like(pred) * sample_weight - - fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) - - assert isinstance(tps, Tensor) - assert isinstance(fps, Tensor) - assert isinstance(thresh, Tensor) - assert tps.shape == (exp_shape,) - assert fps.shape == (exp_shape,) - assert thresh.shape == (exp_shape,) + assert torch.allclose(p1[i], p2[i]) + assert torch.allclose(r1[i], r2[i]) + assert torch.allclose(t1[i], t2) diff --git a/tests/unittests/classification/test_ranking.py b/tests/unittests/classification/test_ranking.py index 070d43e47e5..280c0a26865 100644 --- a/tests/unittests/classification/test_ranking.py +++ b/tests/unittests/classification/test_ranking.py @@ -11,91 +11,138 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial + +import numpy as np import pytest import torch +from scipy.special import expit as sigmoid from sklearn.metrics import coverage_error as sk_coverage_error from sklearn.metrics import label_ranking_average_precision_score as sk_label_ranking from sklearn.metrics import label_ranking_loss as sk_label_ranking_loss -from torchmetrics.classification.ranking import CoverageError, LabelRankingAveragePrecision, LabelRankingLoss +from torchmetrics.classification.ranking import ( + MultilabelCoverageError, + MultilabelRankingAveragePrecision, + MultilabelRankingLoss, +) from torchmetrics.functional.classification.ranking import ( - coverage_error, - label_ranking_average_precision, - label_ranking_loss, + multilabel_coverage_error, + multilabel_ranking_average_precision, + multilabel_ranking_loss, ) -from unittests.classification.inputs import _input_multilabel_logits as _input_mlb_logits -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_9 +from unittests.classification.inputs import _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index seed_all(42) -def _sk_coverage_error(preds, target, sample_weight=None): - if sample_weight is not None: - sample_weight = sample_weight.numpy() - return sk_coverage_error(target.numpy(), preds.numpy(), sample_weight=sample_weight) - - -def _sk_label_ranking(preds, target, sample_weight=None): - if sample_weight is not None: - sample_weight = sample_weight.numpy() - return sk_label_ranking(target.numpy(), preds.numpy(), sample_weight=sample_weight) - - -def _sk_label_ranking_loss(preds, target, sample_weight=None): - if sample_weight is not None: - sample_weight = sample_weight.numpy() - return sk_label_ranking_loss(target.numpy(), preds.numpy(), sample_weight=sample_weight) +def _sk_ranking(preds, target, fn, ignore_index): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) + target = np.moveaxis(target, 1, -1).reshape((-1, target.shape[1])) + if ignore_index is not None: + idx = target == ignore_index + target[idx] = -1 + return fn(target, preds) @pytest.mark.parametrize( "metric, functional_metric, sk_metric", [ - (CoverageError, coverage_error, _sk_coverage_error), - (LabelRankingAveragePrecision, label_ranking_average_precision, _sk_label_ranking), - (LabelRankingLoss, label_ranking_loss, _sk_label_ranking_loss), + (MultilabelCoverageError, multilabel_coverage_error, sk_coverage_error), + (MultilabelRankingAveragePrecision, multilabel_ranking_average_precision, sk_label_ranking), + (MultilabelRankingLoss, multilabel_ranking_loss, sk_label_ranking_loss), ], ) @pytest.mark.parametrize( - "preds, target", - [ - (_input_mlb_logits.preds, _input_mlb_logits.target), - (_input_mlb_prob.preds, _input_mlb_prob.target), - ], + "input", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5]) ) -@pytest.mark.parametrize("sample_weight", [None, torch.rand(NUM_BATCHES, BATCH_SIZE)]) -class TestRanking(MetricTester): - @pytest.mark.parametrize("ddp", [False, True]) - @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_ranking_class( - self, ddp, dist_sync_on_step, preds, target, metric, functional_metric, sk_metric, sample_weight - ): +class TestMultilabelRanking(MetricTester): + @pytest.mark.parametrize("ignore_index", [None]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multilabel_ranking(self, input, metric, functional_metric, sk_metric, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=metric, - sk_metric=sk_metric, - dist_sync_on_step=dist_sync_on_step, - fragment_kwargs=True, - sample_weight=sample_weight, + sk_metric=partial(_sk_ranking, fn=sk_metric, ignore_index=ignore_index), + metric_args={ + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, + }, ) - def test_ranking_functional(self, preds, target, metric, functional_metric, sk_metric, sample_weight): + @pytest.mark.parametrize("ignore_index", [None]) + def test_multilabel_ranking_functional(self, input, metric, functional_metric, sk_metric, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( - preds, - target, + preds=preds, + target=target, metric_functional=functional_metric, - sk_metric=sk_metric, - fragment_kwargs=True, - sample_weight=sample_weight, + sk_metric=partial(_sk_ranking, fn=sk_metric, ignore_index=ignore_index), + metric_args={ + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, + }, ) - def test_ranking_differentiability(self, preds, target, metric, functional_metric, sk_metric, sample_weight): + def test_multilabel_ranking_differentiability(self, input, metric, functional_metric, sk_metric): + preds, target = input self.run_differentiability_test( preds=preds, target=target, metric_module=metric, metric_functional=functional_metric, + metric_args={"num_labels": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_ranking_dtype_cpu(self, input, metric, functional_metric, sk_metric, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if dtype == torch.half and functional_metric == multilabel_ranking_average_precision: + pytest.xfail( + reason="multilabel_ranking_average_precision requires torch.unique which is not implemented for half" + ) + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_9 and functional_metric == multilabel_coverage_error: + pytest.xfail( + reason="multilabel_coverage_error requires torch.min which is only implemented for half" + " in v1.9 or higher of torch." + ) + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=metric, + metric_functional=functional_metric, + metric_args={"num_labels": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_ranking_dtype_gpu(self, input, metric, functional_metric, sk_metric, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=metric, + metric_functional=functional_metric, + metric_args={"num_labels": NUM_CLASSES}, + dtype=dtype, ) diff --git a/tests/unittests/classification/test_recall_at_fixed_precision.py b/tests/unittests/classification/test_recall_at_fixed_precision.py new file mode 100644 index 00000000000..3032f757cde --- /dev/null +++ b/tests/unittests/classification/test_recall_at_fixed_precision.py @@ -0,0 +1,389 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import numpy as np +import pytest +import torch +from scipy.special import expit as sigmoid +from scipy.special import softmax +from sklearn.metrics import precision_recall_curve as _sk_precision_recall_curve + +from torchmetrics.classification.recall_at_fixed_precision import ( + BinaryRecallAtFixedPrecision, + MulticlassRecallAtFixedPrecision, + MultilabelRecallAtFixedPrecision, +) +from torchmetrics.functional.classification.recall_at_fixed_precision import ( + binary_recall_at_fixed_precision, + multiclass_recall_at_fixed_precision, + multilabel_recall_at_fixed_precision, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_8 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.helpers import seed_all +from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index + +seed_all(42) + + +def recall_at_precision_x_multilabel(predictions, targets, min_precision): + precision, recall, thresholds = _sk_precision_recall_curve(targets, predictions) + + try: + tuple_all = [(r, p, t) for p, r, t in zip(precision, recall, thresholds) if p >= min_precision] + max_recall, _, best_threshold = max(tuple_all) + except ValueError: + max_recall, best_threshold = 0, 1e6 + + return float(max_recall), float(best_threshold) + + +def _sk_recall_at_fixed_precision_binary(preds, target, min_precision, ignore_index=None): + preds = preds.flatten().numpy() + target = target.flatten().numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + return recall_at_precision_x_multilabel(preds, target, min_precision) + + +@pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinaryRecallAtFixedPrecision(MetricTester): + @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.85]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_recall_at_fixed_precision(self, input, ddp, min_precision, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryRecallAtFixedPrecision, + sk_metric=partial( + _sk_recall_at_fixed_precision_binary, min_precision=min_precision, ignore_index=ignore_index + ), + metric_args={ + "min_precision": min_precision, + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_recall_at_fixed_precision_functional(self, input, min_precision, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_recall_at_fixed_precision, + sk_metric=partial( + _sk_recall_at_fixed_precision_binary, min_precision=min_precision, ignore_index=ignore_index + ), + metric_args={ + "min_precision": min_precision, + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + def test_binary_recall_at_fixed_precision_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryRecallAtFixedPrecision, + metric_functional=binary_recall_at_fixed_precision, + metric_args={"min_precision": 0.5, "thresholds": None}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_recall_at_fixed_precision_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryRecallAtFixedPrecision, + metric_functional=binary_recall_at_fixed_precision, + metric_args={"min_precision": 0.5, "thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_recall_at_fixed_precision_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryRecallAtFixedPrecision, + metric_functional=binary_recall_at_fixed_precision, + metric_args={"min_precision": 0.5, "thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8]) + def test_binary_recall_at_fixed_precision_threshold_arg(self, input, min_precision): + preds, target = input + + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 1)) + 1e-6 # rounding will simulate binning + r1, _ = binary_recall_at_fixed_precision(pred, true, min_precision=min_precision, thresholds=None) + r2, _ = binary_recall_at_fixed_precision( + pred, true, min_precision=min_precision, thresholds=torch.linspace(0, 1, 100) + ) + assert torch.allclose(r1, r2) + + +def _sk_recall_at_fixed_precision_multiclass(preds, target, min_precision, ignore_index=None): + preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = softmax(preds, 1) + target, preds = remove_ignore_index(target, preds, ignore_index) + + recall, thresholds = [], [] + for i in range(NUM_CLASSES): + target_temp = np.zeros_like(target) + target_temp[target == i] = 1 + res = recall_at_precision_x_multilabel(preds[:, i], target_temp, min_precision) + recall.append(res[0]) + thresholds.append(res[1]) + return recall, thresholds + + +@pytest.mark.parametrize( + "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +) +class TestMulticlassRecallAtFixedPrecision(MetricTester): + @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_recall_at_fixed_precision(self, input, ddp, min_precision, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassRecallAtFixedPrecision, + sk_metric=partial( + _sk_recall_at_fixed_precision_multiclass, min_precision=min_precision, ignore_index=ignore_index + ), + metric_args={ + "min_precision": min_precision, + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multiclass_recall_at_fixed_precision_functional(self, input, min_precision, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_recall_at_fixed_precision, + sk_metric=partial( + _sk_recall_at_fixed_precision_multiclass, min_precision=min_precision, ignore_index=ignore_index + ), + metric_args={ + "min_precision": min_precision, + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) + + def test_multiclass_recall_at_fixed_precision_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassRecallAtFixedPrecision, + metric_functional=multiclass_recall_at_fixed_precision, + metric_args={"min_precision": 0.5, "thresholds": None, "num_classes": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_recall_at_fixed_precision_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassRecallAtFixedPrecision, + metric_functional=multiclass_recall_at_fixed_precision, + metric_args={"min_precision": 0.5, "thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_recall_at_fixed_precision_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassRecallAtFixedPrecision, + metric_functional=multiclass_recall_at_fixed_precision, + metric_args={"min_precision": 0.5, "thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8]) + def test_multiclass_recall_at_fixed_precision_threshold_arg(self, input, min_precision): + preds, target = input + if (preds < 0).any(): + preds = preds.softmax(dim=-1) + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 1)) + 1e-6 # rounding will simulate binning + r1, _ = multiclass_recall_at_fixed_precision( + pred, true, num_classes=NUM_CLASSES, min_precision=min_precision, thresholds=None + ) + r2, _ = multiclass_recall_at_fixed_precision( + pred, true, num_classes=NUM_CLASSES, min_precision=min_precision, thresholds=torch.linspace(0, 1, 100) + ) + assert all(torch.allclose(r1[i], r2[i]) for i in range(len(r1))) + + +def _sk_recall_at_fixed_precision_multilabel(preds, target, min_precision, ignore_index=None): + recall, thresholds = [], [] + for i in range(NUM_CLASSES): + res = _sk_recall_at_fixed_precision_binary(preds[:, i], target[:, i], min_precision, ignore_index) + recall.append(res[0]) + thresholds.append(res[1]) + return recall, thresholds + + +@pytest.mark.parametrize( + "input", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5]) +) +class TestMultilabelRecallAtFixedPrecision(MetricTester): + @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multilabel_recall_at_fixed_precision(self, input, ddp, min_precision, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MultilabelRecallAtFixedPrecision, + sk_metric=partial( + _sk_recall_at_fixed_precision_multilabel, min_precision=min_precision, ignore_index=ignore_index + ), + metric_args={ + "min_precision": min_precision, + "thresholds": None, + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multilabel_recall_at_fixed_precision_functional(self, input, min_precision, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multilabel_recall_at_fixed_precision, + sk_metric=partial( + _sk_recall_at_fixed_precision_multilabel, min_precision=min_precision, ignore_index=ignore_index + ), + metric_args={ + "min_precision": min_precision, + "thresholds": None, + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) + + def test_multiclass_recall_at_fixed_precision_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MultilabelRecallAtFixedPrecision, + metric_functional=multilabel_recall_at_fixed_precision, + metric_args={"min_precision": 0.5, "thresholds": None, "num_labels": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_recall_at_fixed_precision_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8: + pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelRecallAtFixedPrecision, + metric_functional=multilabel_recall_at_fixed_precision, + metric_args={"min_precision": 0.5, "thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_recall_at_fixed_precision_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelRecallAtFixedPrecision, + metric_functional=multilabel_recall_at_fixed_precision, + metric_args={"min_precision": 0.5, "thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8]) + def test_multilabel_recall_at_fixed_precision_threshold_arg(self, input, min_precision): + preds, target = input + if (preds < 0).any(): + preds = sigmoid(preds) + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 1)) + 1e-6 # rounding will simulate binning + r1, _ = multilabel_recall_at_fixed_precision( + pred, true, num_labels=NUM_CLASSES, min_precision=min_precision, thresholds=None + ) + r2, _ = multilabel_recall_at_fixed_precision( + pred, true, num_labels=NUM_CLASSES, min_precision=min_precision, thresholds=torch.linspace(0, 1, 100) + ) + assert all(torch.allclose(r1[i], r2[i]) for i in range(len(r1))) diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index b5b787e9221..c3f21c26089 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -16,152 +16,327 @@ import numpy as np import pytest import torch +from scipy.special import expit as sigmoid +from scipy.special import softmax from sklearn.metrics import roc_curve as sk_roc_curve -from torch import tensor - -from torchmetrics.classification.roc import ROC -from torchmetrics.functional import roc -from unittests.classification.inputs import _input_binary_prob -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob + +from torchmetrics.classification.roc import BinaryROC, MulticlassROC, MultilabelROC +from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, MetricTester +from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_roc_curve(y_true, probas_pred, num_classes: int = 1, multilabel: bool = False): - """Adjusted comparison function that can also handles multiclass.""" - if num_classes == 1: - return sk_roc_curve(y_true, probas_pred, drop_intermediate=False) +def _sk_roc_binary(preds, target, ignore_index=None): + preds = preds.flatten().numpy() + target = target.flatten().numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + fpr, tpr, thresholds = sk_roc_curve(target, preds, drop_intermediate=False) + thresholds[0] = 1.0 + return [np.nan_to_num(x, nan=0.0) for x in [fpr, tpr, thresholds]] + + +@pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinaryROC(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_roc(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryROC, + sk_metric=partial(_sk_roc_binary, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_roc_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_roc, + sk_metric=partial(_sk_roc_binary, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + def test_binary_roc_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryROC, + metric_functional=binary_roc, + metric_args={"thresholds": None}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_roc_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryROC, + metric_functional=binary_roc, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_roc_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryROC, + metric_functional=binary_roc, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_binary_roc_threshold_arg(self, input, threshold_fn): + preds, target = input + for pred, true in zip(preds, target): + p1, r1, t1 = binary_roc(pred, true, thresholds=None) + p2, r2, t2 = binary_roc(pred, true, thresholds=threshold_fn(t1.flip(0))) + assert torch.allclose(p1, p2) + assert torch.allclose(r1, r2) + assert torch.allclose(t1, t2) + + +def _sk_roc_multiclass(preds, target, ignore_index=None): + preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = softmax(preds, 1) + target, preds = remove_ignore_index(target, preds, ignore_index) fpr, tpr, thresholds = [], [], [] - for i in range(num_classes): - if multilabel: - y_true_temp = y_true[:, i] - else: - y_true_temp = np.zeros_like(y_true) - y_true_temp[y_true == i] = 1 - - res = sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False) + for i in range(NUM_CLASSES): + target_temp = np.zeros_like(target) + target_temp[target == i] = 1 + res = sk_roc_curve(target_temp, preds[:, i], drop_intermediate=False) + res[2][0] = 1.0 + fpr.append(res[0]) tpr.append(res[1]) thresholds.append(res[2]) - return fpr, tpr, thresholds - - -def _sk_roc_binary_prob(preds, target, num_classes=1): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + return [np.nan_to_num(x, nan=0.0) for x in [fpr, tpr, thresholds]] - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) +@pytest.mark.parametrize( + "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +) +class TestMulticlassROC(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_roc(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassROC, + sk_metric=partial(_sk_roc_multiclass, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) -def _sk_roc_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multiclass_roc_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_roc, + sk_metric=partial(_sk_roc_multiclass, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + def test_multiclass_roc_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassROC, + metric_functional=multiclass_roc, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_roc_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassROC, + metric_functional=multiclass_roc, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) -def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_roc_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassROC, + metric_functional=multiclass_roc, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_multiclass_roc_threshold_arg(self, input, threshold_fn): + preds, target = input + for pred, true in zip(preds, target): + p1, r1, t1 = multiclass_roc(pred, true, num_classes=NUM_CLASSES, thresholds=None) + for i, t in enumerate(t1): + p2, r2, t2 = multiclass_roc(pred, true, num_classes=NUM_CLASSES, thresholds=threshold_fn(t.flip(0))) -def _sk_roc_multilabel_prob(preds, target, num_classes=1): - sk_preds = preds.numpy() - sk_target = target.numpy() - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True) + assert torch.allclose(p1[i], p2[i]) + assert torch.allclose(r1[i], r2[i]) + assert torch.allclose(t1[i], t2) -def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True) +def _sk_roc_multilabel(preds, target, ignore_index=None): + fpr, tpr, thresholds = [], [], [] + for i in range(NUM_CLASSES): + res = _sk_roc_binary(preds[:, i], target[:, i], ignore_index) + fpr.append(res[0]) + tpr.append(res[1]) + thresholds.append(res[2]) + return fpr, tpr, thresholds @pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES), - (_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_roc_multilabel_multidim_prob, NUM_CLASSES), - ], + "input", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5]) ) -class TestROC(MetricTester): +class TestMultilabelROC(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_roc(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + def test_multilabel_roc(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=ROC, - sk_metric=partial(sk_metric, num_classes=num_classes), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes}, + metric_class=MultilabelROC, + sk_metric=partial(_sk_roc_multilabel, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, + }, ) - def test_roc_functional(self, preds, target, sk_metric, num_classes): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multilabel_roc_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( - preds, - target, - metric_functional=roc, - sk_metric=partial(sk_metric, num_classes=num_classes), - metric_args={"num_classes": num_classes}, + preds=preds, + target=target, + metric_functional=multilabel_roc, + sk_metric=partial(_sk_roc_multilabel, ignore_index=ignore_index), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, + }, ) - def test_roc_differentiability(self, preds, target, sk_metric, num_classes): + def test_multiclass_roc_differentiability(self, input): + preds, target = input self.run_differentiability_test( - preds, - target, - metric_module=ROC, - metric_functional=roc, - metric_args={"num_classes": num_classes}, + preds=preds, + target=target, + metric_module=MultilabelROC, + metric_functional=multilabel_roc, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_roc_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if dtype == torch.half and not ((0 < preds) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelROC, + metric_functional=multilabel_roc, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_roc_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelROC, + metric_functional=multilabel_roc, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) -@pytest.mark.parametrize( - ["pred", "target", "expected_tpr", "expected_fpr"], - [ - ([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]), - ([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]), - ([1, 1], [1, 0], [0, 1], [0, 1]), - ([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]), - ([0.5, 0.5], [0, 1], [0, 1], [0, 1]), - ], -) -def test_roc_curve(pred, target, expected_tpr, expected_fpr): - fpr, tpr, thresh = roc(tensor(pred), tensor(target)) - - assert fpr.shape == tpr.shape - assert fpr.size(0) == thresh.size(0) - assert torch.allclose(fpr, tensor(expected_fpr).to(fpr)) - assert torch.allclose(tpr, tensor(expected_tpr).to(tpr)) - - -def test_warnings_on_missing_class(): - """Test that a warning is given if either the positive or negative class is missing.""" - metric = ROC() - # no positive samples - warning = ( - "No positive samples in targets, true positive value should be meaningless." - " Returning zero tensor in true positive score" - ) - with pytest.warns(UserWarning, match=warning): - _, tpr, _ = metric(torch.randn(10).sigmoid(), torch.zeros(10)) - assert all(tpr == 0) - - warning = ( - "No negative samples in targets, false positive value should be meaningless." - " Returning zero tensor in false positive score" - ) - with pytest.warns(UserWarning, match=warning): - fpr, _, _ = metric(torch.randn(10).sigmoid(), torch.ones(10)) - assert all(fpr == 0) + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_multilabel_roc_threshold_arg(self, input, threshold_fn): + preds, target = input + for pred, true in zip(preds, target): + p1, r1, t1 = multilabel_roc(pred, true, num_labels=NUM_CLASSES, thresholds=None) + for i, t in enumerate(t1): + p2, r2, t2 = multilabel_roc(pred, true, num_labels=NUM_CLASSES, thresholds=threshold_fn(t.flip(0))) + + assert torch.allclose(p1[i], p2[i]) + assert torch.allclose(r1[i], r2[i]) + assert torch.allclose(t1[i], t2) diff --git a/tests/unittests/classification/test_specificity.py b/tests/unittests/classification/test_specificity.py index 3771b397072..0681fc98a78 100644 --- a/tests/unittests/classification/test_specificity.py +++ b/tests/unittests/classification/test_specificity.py @@ -11,402 +11,518 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from functools import partial -from typing import Callable, Optional import numpy as np import pytest import torch -from sklearn.metrics import multilabel_confusion_matrix +from scipy.special import expit as sigmoid +from sklearn.metrics import confusion_matrix as sk_confusion_matrix from torch import Tensor, tensor -from torchmetrics import Metric, Specificity -from torchmetrics.functional import specificity -from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores -from torchmetrics.utilities.checks import _input_format_classification -from torchmetrics.utilities.enums import AverageMethod -from unittests.classification.inputs import _input_binary, _input_binary_prob -from unittests.classification.inputs import _input_multiclass as _input_mcls -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel as _input_mlb -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from torchmetrics.classification.specificity import BinarySpecificity, MulticlassSpecificity, MultilabelSpecificity +from torchmetrics.functional.classification.specificity import ( + binary_specificity, + multiclass_specificity, + multilabel_specificity, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index seed_all(42) -def _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k): - preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k - ) - sk_preds, sk_target = preds.numpy(), target.numpy() - - if reduce != "macro" and ignore_index is not None and preds.shape[1] > 1: - sk_preds = np.delete(sk_preds, ignore_index, 1) - sk_target = np.delete(sk_target, ignore_index, 1) - - if preds.shape[1] == 1 and reduce == "samples": - sk_target = sk_target.T - sk_preds = sk_preds.T - - sk_stats = multilabel_confusion_matrix( - sk_target, sk_preds, samplewise=(reduce == "samples") and preds.shape[1] != 1 - ) - - if preds.shape[1] == 1 and reduce != "samples": - sk_stats = sk_stats[[1]].reshape(-1, 4)[:, [3, 1, 0, 2]] +def _calc_specificity(tn, fp): + """safely calculate specificity.""" + denom = tn + fp + if np.isscalar(tn): + denom = 1.0 if denom == 0 else denom else: - sk_stats = sk_stats.reshape(-1, 4)[:, [3, 1, 0, 2]] - - if reduce == "micro": - sk_stats = sk_stats.sum(axis=0, keepdims=True) - - sk_stats = np.concatenate([sk_stats, sk_stats[:, [3]] + sk_stats[:, [0]]], 1) - - if reduce == "micro": - sk_stats = sk_stats[0] + denom[denom == 0] = 1.0 + return tn / denom - if reduce == "macro" and ignore_index is not None and preds.shape[1]: - sk_stats[ignore_index, :] = -1 - if reduce == "micro": - _, fp, tn, _, _ = sk_stats +def _sk_specificity_binary(preds, target, ignore_index, multidim_average): + if multidim_average == "global": + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() else: - _, fp, tn, _ = sk_stats[:, 0], sk_stats[:, 1], sk_stats[:, 2], sk_stats[:, 3] - return fp, tn - - -def _sk_spec(preds, target, reduce, num_classes, multiclass, ignore_index, top_k=None, mdmc_reduce=None, stats=None): - - if stats: - fp, tn = stats + preds = preds.numpy() + target = target.numpy() + + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + + if multidim_average == "global": + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + tn, fp, _, _ = sk_confusion_matrix(y_true=target, y_pred=preds, labels=[0, 1]).ravel() + return _calc_specificity(tn, fp) else: - stats = _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k) - fp, tn = stats - - fp, tn = tensor(fp), tensor(tn) - spec = _reduce_stat_scores( - numerator=tn, - denominator=tn + fp, - weights=None if reduce != "weighted" else tn + fp, - average=reduce, - mdmc_average=mdmc_reduce, - ) - if reduce in [None, "none"] and ignore_index is not None and preds.shape[1] > 1: - spec = spec.numpy() - spec = np.insert(spec, ignore_index, math.nan) - spec = tensor(spec) - - return spec - - -def _sk_spec_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k=None): - preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k - ) - - if mdmc_reduce == "global": - preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) - target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) - return _sk_spec(preds, target, reduce, num_classes, False, ignore_index, top_k, mdmc_reduce) - fp, tn = [], [] - stats = [] - - for i in range(preds.shape[0]): - pred_i = preds[i, ...].T - target_i = target[i, ...].T - fp_i, tn_i = _sk_stats_score(pred_i, target_i, reduce, num_classes, False, ignore_index, top_k) - fp.append(fp_i) - tn.append(tn_i) + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + if ignore_index is not None: + idx = true == ignore_index + true = true[~idx] + pred = pred[~idx] + tn, fp, _, _ = sk_confusion_matrix(y_true=true, y_pred=pred, labels=[0, 1]).ravel() + res.append(_calc_specificity(tn, fp)) + return np.stack(res) + + +@pytest.mark.parametrize("input", _binary_cases) +class TestBinarySpecificity(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("ddp", [False, True]) + def test_binary_specificity(self, ddp, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") - stats.append(fp) - stats.append(tn) - return _sk_spec(preds[0], target[0], reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce, stats) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinarySpecificity, + sk_metric=partial(_sk_specificity_binary, ignore_index=ignore_index, multidim_average=multidim_average), + metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, + ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + def test_binary_specificity_functional(self, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") -@pytest.mark.parametrize("metric, fn_metric", [(Specificity, specificity)]) -@pytest.mark.parametrize( - "average, mdmc_average, num_classes, ignore_index, match_str", - [ - ("wrong", None, None, None, "`average`"), - ("micro", "wrong", None, None, "`mdmc"), - ("macro", None, None, None, "number of classes"), - ("macro", None, 1, 0, "ignore_index"), - ], -) -def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str): - with pytest.raises(ValueError, match=match_str): - metric( - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_specificity, + sk_metric=partial(_sk_specificity_binary, ignore_index=ignore_index, multidim_average=multidim_average), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + }, ) - with pytest.raises(ValueError, match=match_str): - fn_metric( - _input_binary.preds[0], - _input_binary.target[0], - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, + def test_binary_specificity_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinarySpecificity, + metric_functional=binary_specificity, + metric_args={"threshold": THRESHOLD}, ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_specificity_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinarySpecificity, + metric_functional=binary_specificity, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) -@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) -def test_zero_division(metric_class, metric_fn): - """Test that zero_division works correctly (currently should just set to 0).""" - - preds = tensor([1, 2, 1, 1]) - target = tensor([0, 0, 0, 0]) - - cl_metric = metric_class(average="none", num_classes=3) - cl_metric(preds, target) - - result_cl = cl_metric.compute() - result_fn = metric_fn(preds, target, average="none", num_classes=3) - - assert result_cl[0] == result_fn[0] == 0 - - -@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) -def test_no_support(metric_class, metric_fn): - """This tests a rare edge case, where there is only one class present. - - in target, and ignore_index is set to exactly that class - and the - average method is equal to 'weighted'. - - This would mean that the sum of weights equals zero, and would, without - taking care of this case, return NaN. However, the reduction function - should catch that and set the metric to equal the value of zero_division - in this case (zero_division is for now not configurable and equals 0). - """ - - preds = tensor([1, 1, 0, 0]) - target = tensor([0, 0, 0, 0]) - - cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=1) - cl_metric(preds, target) - - result_cl = cl_metric.compute() - result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=1) - - assert result_cl == result_fn == 0 + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_specificity_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinarySpecificity, + metric_functional=binary_specificity, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) -@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) -@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) -@pytest.mark.parametrize("ignore_index", [None, 0]) -@pytest.mark.parametrize( - "preds, target, num_classes, multiclass, mdmc_average, sk_wrapper", - [ - (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_spec), - (_input_binary.preds, _input_binary.target, 1, False, None, _sk_spec), - (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_spec), - (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_spec), - (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_spec), - (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_spec), - (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_spec_mdim_mcls), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global", _sk_spec_mdim_mcls), - (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_spec_mdim_mcls), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise", _sk_spec_mdim_mcls), - ], -) -class TestSpecificity(MetricTester): - @pytest.mark.parametrize("ddp", [False, True]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_specificity_class( - self, - ddp: bool, - dist_sync_on_step: bool, - preds: Tensor, - target: Tensor, - sk_wrapper: Callable, - metric_class: Metric, - metric_fn: Callable, - multiclass: Optional[bool], - num_classes: Optional[int], - average: str, - mdmc_average: Optional[str], - ignore_index: Optional[int], - ): - # todo: `metric_fn` is unused - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") +def _sk_specificity_multiclass_global(preds, target, ignore_index, average): + preds = preds.numpy().flatten() + target = target.numpy().flatten() + + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) + tp = np.diag(confmat) + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + + if average == "micro": + return _calc_specificity(tn.sum(), fp.sum()) + + res = _calc_specificity(tn, fp) + if average == "macro": + return res.mean(0) + elif average == "weighted": + w = tp + fn + return (res * (w / w.sum()).reshape(-1, 1)).sum(0) + elif average is None or average == "none": + return res + + +def _sk_specificity_multiclass_local(preds, target, ignore_index, average): + preds = preds.numpy() + target = target.numpy() + + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + + if ignore_index is not None: + idx = true == ignore_index + true = true[~idx] + pred = pred[~idx] + confmat = sk_confusion_matrix(y_true=true, y_pred=pred, labels=list(range(NUM_CLASSES))) + tp = np.diag(confmat) + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + if average == "micro": + res.append(_calc_specificity(tn.sum(), fp.sum())) + + r = _calc_specificity(tn, fp) + if average == "macro": + res.append(r.mean(0)) + elif average == "weighted": + w = tp + fn + res.append((r * (w / w.sum()).reshape(-1, 1)).sum(0)) + elif average is None or average == "none": + res.append(r) + return np.stack(res, 0) + + +def _sk_specificity_multiclass(preds, target, ignore_index, multidim_average, average): + if preds.ndim == target.ndim + 1: + preds = torch.argmax(preds, 1) + if multidim_average == "global": + return _sk_specificity_multiclass_global(preds, target, ignore_index, average) + return _sk_specificity_multiclass_local(preds, target, ignore_index, average) + + +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassSpecificity(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_specificity(self, ddp, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=metric_class, + metric_class=MulticlassSpecificity, sk_metric=partial( - sk_wrapper, - reduce=average, - num_classes=num_classes, - multiclass=multiclass, + _sk_specificity_multiclass, ignore_index=ignore_index, - mdmc_reduce=mdmc_average, + multidim_average=multidim_average, + average=average, ), - dist_sync_on_step=dist_sync_on_step, metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "multiclass": multiclass, "ignore_index": ignore_index, - "mdmc_average": mdmc_average, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, }, ) - def test_specificity_fn( - self, - preds: Tensor, - target: Tensor, - sk_wrapper: Callable, - metric_class: Metric, - metric_fn: Callable, - multiclass: Optional[bool], - num_classes: Optional[int], - average: str, - mdmc_average: Optional[str], - ignore_index: Optional[int], - ): - # todo: `metric_class` is unused - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multiclass_specificity_functional(self, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") self.run_functional_metric_test( - preds, - target, - metric_functional=metric_fn, + preds=preds, + target=target, + metric_functional=multiclass_specificity, sk_metric=partial( - sk_wrapper, - reduce=average, - num_classes=num_classes, - multiclass=multiclass, + _sk_specificity_multiclass, ignore_index=ignore_index, - mdmc_reduce=mdmc_average, + multidim_average=multidim_average, + average=average, ), metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "multiclass": multiclass, "ignore_index": ignore_index, - "mdmc_average": mdmc_average, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, }, ) - def test_accuracy_differentiability( - self, - preds: Tensor, - target: Tensor, - sk_wrapper: Callable, - metric_class: Metric, - metric_fn: Callable, - multiclass: Optional[bool], - num_classes: Optional[int], - average: str, - mdmc_average: Optional[str], - ignore_index: Optional[int], - ): - - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") - + def test_multiclass_specificity_differentiability(self, input): + preds, target = input self.run_differentiability_test( preds=preds, target=target, - metric_module=metric_class, - metric_functional=metric_fn, - metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "multiclass": multiclass, - "ignore_index": ignore_index, - "mdmc_average": mdmc_average, - }, + metric_module=MulticlassSpecificity, + metric_functional=multiclass_specificity, + metric_args={"num_classes": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_specificity_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassSpecificity, + metric_functional=multiclass_specificity, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_specificity_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassSpecificity, + metric_functional=multiclass_specificity, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, ) _mc_k_target = tensor([0, 1, 2]) _mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) -_ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) -_ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) -@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) @pytest.mark.parametrize( "k, preds, target, average, expected_spec", [ (1, _mc_k_preds, _mc_k_target, "micro", tensor(5 / 6)), (2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2)), - (1, _ml_k_preds, _ml_k_target, "micro", tensor(1 / 2)), - (2, _ml_k_preds, _ml_k_target, "micro", tensor(1 / 6)), ], ) -def test_top_k( - metric_class, - metric_fn, - k: int, - preds: Tensor, - target: Tensor, - average: str, - expected_spec: Tensor, -): - """A simple test to check that top_k works as expected. - - Just a sanity check, the tests in StatScores should already guarantee the correctness of results. - """ - - class_metric = metric_class(top_k=k, average=average, num_classes=3) +def test_top_k(k: int, preds: Tensor, target: Tensor, average: str, expected_spec: Tensor): + """A simple test to check that top_k works as expected.""" + class_metric = MulticlassSpecificity(top_k=k, average=average, num_classes=3) class_metric.update(preds, target) assert torch.equal(class_metric.compute(), expected_spec) - assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), expected_spec) + assert torch.equal(multiclass_specificity(preds, target, top_k=k, average=average, num_classes=3), expected_spec) + + +def _sk_specificity_multilabel_global(preds, target, ignore_index, average): + tns, fps = [], [] + for i in range(preds.shape[1]): + p, t = preds[:, i].flatten(), target[:, i].flatten() + if ignore_index is not None: + idx = t == ignore_index + t = t[~idx] + p = p[~idx] + tn, fp, fn, tp = sk_confusion_matrix(t, p, labels=[0, 1]).ravel() + tns.append(tn) + fps.append(fp) + + tn = np.array(tns) + fp = np.array(fps) + if average == "micro": + return _calc_specificity(tn.sum(), fp.sum()) + + res = _calc_specificity(tn, fp) + if average == "macro": + return res.mean(0) + elif average == "weighted": + w = res[:, 0] + res[:, 3] + return (res * (w / w.sum()).reshape(-1, 1)).sum(0) + elif average is None or average == "none": + return res + + +def _sk_specificity_multilabel_local(preds, target, ignore_index, average): + specificity = [] + for i in range(preds.shape[0]): + tns, fps = [], [] + for j in range(preds.shape[1]): + pred, true = preds[i, j], target[i, j] + if ignore_index is not None: + idx = true == ignore_index + true = true[~idx] + pred = pred[~idx] + tn, fp, _, _ = sk_confusion_matrix(true, pred, labels=[0, 1]).ravel() + tns.append(tn) + fps.append(fp) + tn = np.array(tns) + fp = np.array(fps) + if average == "micro": + specificity.append(_calc_specificity(tn.sum(), fp.sum())) + else: + specificity.append(_calc_specificity(tn, fp)) + + res = np.stack(specificity, 0) + if average == "micro" or average is None or average == "none": + return res + elif average == "macro": + return res.mean(-1) + elif average == "weighted": + w = res[:, 0, :] + res[:, 3, :] + return (res * (w / w.sum())[:, np.newaxis]).sum(-1) + elif average is None or average == "none": + return np.moveaxis(res, 1, -1) + + +def _sk_specificity_multilabel(preds, target, ignore_index, multidim_average, average): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = preds.reshape(*preds.shape[:2], -1) + target = target.reshape(*target.shape[:2], -1) + if multidim_average == "global": + return _sk_specificity_multilabel_global(preds, target, ignore_index, average) + return _sk_specificity_multilabel_local(preds, target, ignore_index, average) + + +@pytest.mark.parametrize("input", _multilabel_cases) +class TestMultilabelSpecificity(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multilabel_specificity(self, ddp, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MultilabelSpecificity, + sk_metric=partial( + _sk_specificity_multilabel, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multilabel_specificity_functional(self, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") -@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) -@pytest.mark.parametrize( - "ignore_index, expected", [(None, torch.tensor([0.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))] -) -def test_class_not_present(metric_class, metric_fn, ignore_index, expected): - """This tests that when metric is computed per class and a given class is not present in both the `preds` and - `target`, the resulting score is `nan`.""" - preds = torch.tensor([0, 0, 0]) - target = torch.tensor([0, 0, 0]) - num_classes = 2 - - # test functional - result_fn = metric_fn(preds, target, average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index) - assert torch.allclose(expected, result_fn, equal_nan=True) - - # test class - cl_metric = metric_class(average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index) - cl_metric(preds, target) - result_cl = cl_metric.compute() - assert torch.allclose(expected, result_cl, equal_nan=True) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multilabel_specificity, + sk_metric=partial( + _sk_specificity_multilabel, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) + + def test_multilabel_specificity_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MultilabelSpecificity, + metric_functional=multilabel_specificity, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_specificity_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelSpecificity, + metric_functional=multilabel_specificity, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_specificity_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelSpecificity, + metric_functional=multilabel_specificity, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 75e52f66b4a..4331020adf5 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -12,334 +12,470 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Any, Callable, Dict, Optional import numpy as np import pytest import torch -from sklearn.metrics import multilabel_confusion_matrix -from torch import Tensor, tensor - -from torchmetrics import Accuracy, Dice, FBetaScore, Precision, Recall, Specificity, StatScores -from torchmetrics.functional import stat_scores -from torchmetrics.utilities.checks import _input_format_classification -from unittests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob, _input_multiclass -from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel as _input_mcls -from unittests.classification.inputs import _input_multilabel_logits as _input_mlb_logits -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from scipy.special import expit as sigmoid +from sklearn.metrics import confusion_matrix as sk_confusion_matrix + +from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores +from torchmetrics.functional.classification.stat_scores import ( + binary_stat_scores, + multiclass_stat_scores, + multilabel_stat_scores, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, MetricTester +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_stat_scores(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, threshold, mdmc_reduce=None): - # todo: `mdmc_reduce` is unused - preds, target, _ = _input_format_classification( - preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k - ) - sk_preds, sk_target = preds.numpy(), target.numpy() +def _sk_stat_scores_binary(preds, target, ignore_index, multidim_average): + if multidim_average == "global": + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + else: + preds = preds.numpy() + target = target.numpy() + + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + + if multidim_average == "global": + target, preds = remove_ignore_index(target, preds, ignore_index) + tn, fp, fn, tp = sk_confusion_matrix(y_true=target, y_pred=preds, labels=[0, 1]).ravel() + return np.array([tp, fp, tn, fn, tp + fn]) + else: + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + tn, fp, fn, tp = sk_confusion_matrix(y_true=true, y_pred=pred, labels=[0, 1]).ravel() + res.append(np.array([tp, fp, tn, fn, tp + fn])) + return np.stack(res) + + +@pytest.mark.parametrize("input", _binary_cases) +class TestBinaryStatScores(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("ddp", [False, True]) + def test_binary_stat_scores(self, ddp, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") - if reduce != "macro" and ignore_index is not None and preds.shape[1] > 1: - sk_preds = np.delete(sk_preds, ignore_index, 1) - sk_target = np.delete(sk_target, ignore_index, 1) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryStatScores, + sk_metric=partial(_sk_stat_scores_binary, ignore_index=ignore_index, multidim_average=multidim_average), + metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, + ) - if preds.shape[1] == 1 and reduce == "samples": - sk_target = sk_target.T - sk_preds = sk_preds.T + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + def test_binary_stat_scores_functional(self, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") - sk_stats = multilabel_confusion_matrix( - sk_target, sk_preds, samplewise=(reduce == "samples") and preds.shape[1] != 1 - ) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_stat_scores, + sk_metric=partial(_sk_stat_scores_binary, ignore_index=ignore_index, multidim_average=multidim_average), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + }, + ) - if preds.shape[1] == 1 and reduce != "samples": - sk_stats = sk_stats[[1]].reshape(-1, 4)[:, [3, 1, 0, 2]] - else: - sk_stats = sk_stats.reshape(-1, 4)[:, [3, 1, 0, 2]] + def test_binary_stat_scores_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryStatScores, + metric_functional=binary_stat_scores, + metric_args={"threshold": THRESHOLD}, + ) - if reduce == "micro": - sk_stats = sk_stats.sum(axis=0, keepdims=True) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_stat_scores_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryStatScores, + metric_functional=binary_stat_scores, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) - sk_stats = np.concatenate([sk_stats, sk_stats[:, [3]] + sk_stats[:, [0]]], 1) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_stat_scores_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryStatScores, + metric_functional=binary_stat_scores, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) - if reduce == "micro": - sk_stats = sk_stats[0] - if reduce == "macro" and ignore_index is not None and preds.shape[1]: - sk_stats[ignore_index, :] = -1 +def _sk_stat_scores_multiclass_global(preds, target, ignore_index, average): + preds = preds.numpy().flatten() + target = target.numpy().flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) + tp = np.diag(confmat) + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + + res = np.stack([tp, fp, tn, fn, tp + fn], 1) + if average == "micro": + return res.sum(0) + elif average == "macro": + return res.mean(0) + elif average == "weighted": + w = tp + fn + return (res * (w / w.sum()).reshape(-1, 1)).sum(0) + elif average is None or average == "none": + return res + + +def _sk_stat_scores_multiclass_local(preds, target, ignore_index, average): + preds = preds.numpy() + target = target.numpy() + + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + confmat = sk_confusion_matrix(y_true=true, y_pred=pred, labels=list(range(NUM_CLASSES))) + tp = np.diag(confmat) + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + r = np.stack([tp, fp, tn, fn, tp + fn], 1) + if average == "micro": + res.append(r.sum(0)) + elif average == "macro": + res.append(r.mean(0)) + elif average == "weighted": + w = tp + fn + res.append((r * (w / w.sum()).reshape(-1, 1)).sum(0)) + elif average is None or average == "none": + res.append(r) + return np.stack(res, 0) + + +def _sk_stat_scores_multiclass(preds, target, ignore_index, multidim_average, average): + if preds.ndim == target.ndim + 1: + preds = torch.argmax(preds, 1) + if multidim_average == "global": + return _sk_stat_scores_multiclass_global(preds, target, ignore_index, average) + return _sk_stat_scores_multiclass_local(preds, target, ignore_index, average) + + +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassStatScores(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_stat_scores(self, ddp, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") - return sk_stats + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassStatScores, + sk_metric=partial( + _sk_stat_scores_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, + ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multiclass_stat_scores_functional(self, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") -def _sk_stat_scores_mdim_mcls( - preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k, threshold -): - preds, target, _ = _input_format_classification( - preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k - ) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_stat_scores, + sk_metric=partial( + _sk_stat_scores_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, + ) - if mdmc_reduce == "global": - preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) - target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) + def test_multiclass_stat_scores_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassStatScores, + metric_functional=multiclass_stat_scores, + metric_args={"num_classes": NUM_CLASSES}, + ) - return _sk_stat_scores(preds, target, reduce, None, False, ignore_index, top_k, threshold) - if mdmc_reduce == "samplewise": - scores = [] + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_stat_scores_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassStatScores, + metric_functional=multiclass_stat_scores, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) - for i in range(preds.shape[0]): - pred_i = preds[i, ...].T - target_i = target[i, ...].T - scores_i = _sk_stat_scores(pred_i, target_i, reduce, None, False, ignore_index, top_k, threshold) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_stat_scores_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassStatScores, + metric_functional=multiclass_stat_scores, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) - scores.append(np.expand_dims(scores_i, 0)) - return np.concatenate(scores) +_mc_k_target = torch.tensor([0, 1, 2]) +_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) @pytest.mark.parametrize( - "reduce, mdmc_reduce, num_classes, inputs, ignore_index", + "k, preds, target, average, expected", [ - ["unknown", None, None, _input_binary, None], - ["micro", "unknown", None, _input_binary, None], - ["macro", None, None, _input_binary, None], - ["micro", None, None, _input_mdmc_prob, None], - ["micro", None, None, _input_binary_prob, 0], - ["micro", None, None, _input_mcls_prob, NUM_CLASSES], - ["micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES], + (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor([2, 1, 5, 1, 3])), + (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor([3, 3, 3, 0, 3])), + (1, _mc_k_preds, _mc_k_target, None, torch.tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])), + (2, _mc_k_preds, _mc_k_target, None, torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])), ], ) -def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index): - """Test a combination of parameters that are invalid and should raise an error. - - This includes invalid ``reduce`` and ``mdmc_reduce`` parameter values, not setting ``num_classes`` when - ``reduce='macro'`, not setting ``mdmc_reduce`` when inputs are multi-dim multi-class``, setting ``ignore_index`` - when inputs are binary, as well as setting ``ignore_index`` to a value higher than the number of classes. - """ - with pytest.raises(ValueError): - stat_scores( - inputs.preds[0], inputs.target[0], reduce, mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index - ) +def test_top_k_multiclass(k, preds, target, average, expected): + """A simple test to check that top_k works as expected.""" + class_metric = MulticlassStatScores(top_k=k, average=average, num_classes=3) + class_metric.update(preds, target) - with pytest.raises(ValueError): - sts = StatScores(reduce=reduce, mdmc_reduce=mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index) - sts(inputs.preds[0], inputs.target[0]) + assert torch.allclose(class_metric.compute().long(), expected.T) + assert torch.allclose( + multiclass_stat_scores(preds, target, top_k=k, average=average, num_classes=3).long(), expected.T + ) -@pytest.mark.parametrize("ignore_index", [None, 0]) -@pytest.mark.parametrize("reduce", ["micro", "macro", "samples"]) -@pytest.mark.parametrize( - "preds, target, sk_fn, mdmc_reduce, num_classes, multiclass, top_k, threshold", - [ - (_input_binary_logits.preds, _input_binary_logits.target, _sk_stat_scores, None, 1, None, None, 0.0), - (_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None, 0.5), - (_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None, 0.5), - (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.5), - (_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None, 0.5), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5), - (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.0), - (_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0), - (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None, 0.0), - ( - _input_mdmc_prob.preds, - _input_mdmc_prob.target, - _sk_stat_scores_mdim_mcls, - "samplewise", - NUM_CLASSES, - None, - None, - 0.0, - ), - (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None, 0.0), - ( - _input_mdmc_prob.preds, - _input_mdmc_prob.target, - _sk_stat_scores_mdim_mcls, - "global", - NUM_CLASSES, - None, - None, - 0.0, - ), - ], -) -class TestStatScores(MetricTester): - # DDP tests temporarily disabled due to hanging issues - @pytest.mark.parametrize("ddp", [False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - @pytest.mark.parametrize("dtype", [torch.float, torch.double]) - def test_stat_scores_class( - self, - ddp: bool, - dist_sync_on_step: bool, - dtype: torch.dtype, - sk_fn: Callable, - preds: Tensor, - target: Tensor, - reduce: str, - mdmc_reduce: Optional[str], - num_classes: Optional[int], - multiclass: Optional[bool], - ignore_index: Optional[int], - top_k: Optional[int], - threshold: Optional[float], - ): - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if preds.is_floating_point(): - preds = preds.to(dtype) - if target.is_floating_point(): - target = target.to(dtype) +def _sk_stat_scores_multilabel(preds, target, ignore_index, multidim_average, average): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = preds.reshape(*preds.shape[:2], -1) + target = target.reshape(*target.shape[:2], -1) + if multidim_average == "global": + stat_scores = [] + for i in range(preds.shape[1]): + pred, true = preds[:, i].flatten(), target[:, i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + tn, fp, fn, tp = sk_confusion_matrix(true, pred, labels=[0, 1]).ravel() + stat_scores.append(np.array([tp, fp, tn, fn, tp + fn])) + res = np.stack(stat_scores, axis=0) + + if average == "micro": + return res.sum(0) + elif average == "macro": + return res.mean(0) + elif average == "weighted": + w = res[:, 0] + res[:, 3] + return (res * (w / w.sum()).reshape(-1, 1)).sum(0) + elif average is None or average == "none": + return res + else: + stat_scores = [] + for i in range(preds.shape[0]): + scores = [] + for j in range(preds.shape[1]): + pred, true = preds[i, j], target[i, j] + true, pred = remove_ignore_index(true, pred, ignore_index) + tn, fp, fn, tp = sk_confusion_matrix(true, pred, labels=[0, 1]).ravel() + scores.append(np.array([tp, fp, tn, fn, tp + fn])) + stat_scores.append(np.stack(scores, 1)) + res = np.stack(stat_scores, 0) + if average == "micro": + return res.sum(-1) + elif average == "macro": + return res.mean(-1) + elif average == "weighted": + w = res[:, 0, :] + res[:, 3, :] + return (res * (w / w.sum())[:, np.newaxis]).sum(-1) + elif average is None or average == "none": + return np.moveaxis(res, 1, -1) + + +@pytest.mark.parametrize("input", _multilabel_cases) +class TestMultilabelStatScores(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multilabel_stat_scores(self, ddp, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=StatScores, + metric_class=MultilabelStatScores, sk_metric=partial( - sk_fn, - reduce=reduce, - mdmc_reduce=mdmc_reduce, - num_classes=num_classes, - multiclass=multiclass, + _sk_stat_scores_multilabel, ignore_index=ignore_index, - top_k=top_k, - threshold=threshold, + multidim_average=multidim_average, + average=average, ), - dist_sync_on_step=dist_sync_on_step, metric_args={ - "num_classes": num_classes, - "reduce": reduce, - "mdmc_reduce": mdmc_reduce, - "threshold": threshold, - "multiclass": multiclass, + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, "ignore_index": ignore_index, - "top_k": top_k, + "multidim_average": multidim_average, + "average": average, }, ) - def test_stat_scores_fn( - self, - sk_fn: Callable, - preds: Tensor, - target: Tensor, - reduce: str, - mdmc_reduce: Optional[str], - num_classes: Optional[int], - multiclass: Optional[bool], - ignore_index: Optional[int], - top_k: Optional[int], - threshold: Optional[float], - ): - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multilabel_stat_scores_functional(self, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") self.run_functional_metric_test( - preds, - target, - metric_functional=stat_scores, + preds=preds, + target=target, + metric_functional=multilabel_stat_scores, sk_metric=partial( - sk_fn, - reduce=reduce, - mdmc_reduce=mdmc_reduce, - num_classes=num_classes, - multiclass=multiclass, + _sk_stat_scores_multilabel, ignore_index=ignore_index, - top_k=top_k, - threshold=threshold, + multidim_average=multidim_average, + average=average, ), metric_args={ - "num_classes": num_classes, - "reduce": reduce, - "mdmc_reduce": mdmc_reduce, - "threshold": threshold, - "multiclass": multiclass, + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, "ignore_index": ignore_index, - "top_k": top_k, + "multidim_average": multidim_average, + "average": average, }, ) - def test_stat_scores_differentiability( - self, - sk_fn: Callable, - preds: Tensor, - target: Tensor, - reduce: str, - mdmc_reduce: Optional[str], - num_classes: Optional[int], - multiclass: Optional[bool], - ignore_index: Optional[int], - top_k: Optional[int], - threshold: Optional[float], - ): - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - + def test_multilabel_stat_scores_differentiability(self, input): + preds, target = input self.run_differentiability_test( - preds, - target, - metric_module=StatScores, - metric_functional=stat_scores, - metric_args={ - "num_classes": num_classes, - "reduce": reduce, - "mdmc_reduce": mdmc_reduce, - "threshold": threshold, - "multiclass": multiclass, - "ignore_index": ignore_index, - "top_k": top_k, - }, + preds=preds, + target=target, + metric_module=MultilabelStatScores, + metric_functional=multilabel_stat_scores, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_stat_scores_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelStatScores, + metric_functional=multilabel_stat_scores, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) -_mc_k_target = tensor([0, 1, 2]) -_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) -_ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) -_ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) - - -@pytest.mark.parametrize( - "k, preds, target, reduce, expected", - [ - (1, _mc_k_preds, _mc_k_target, "micro", tensor([2, 1, 5, 1, 3])), - (2, _mc_k_preds, _mc_k_target, "micro", tensor([3, 3, 3, 0, 3])), - (1, _ml_k_preds, _ml_k_target, "micro", tensor([0, 3, 3, 3, 3])), - (2, _ml_k_preds, _ml_k_target, "micro", tensor([1, 5, 1, 2, 3])), - (1, _mc_k_preds, _mc_k_target, "macro", tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])), - (2, _mc_k_preds, _mc_k_target, "macro", tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])), - (1, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 0, 0], [1, 0, 2], [1, 1, 1], [1, 2, 0], [1, 2, 0]])), - (2, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 1, 0], [2, 0, 3], [0, 1, 0], [1, 1, 0], [1, 2, 0]])), - ], -) -def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Tensor): - """A simple test to check that top_k works as expected.""" - - class_metric = StatScores(top_k=k, reduce=reduce, num_classes=3) - class_metric.update(preds, target) - - assert torch.equal(class_metric.compute(), expected.T) - assert torch.equal(stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3), expected.T) - - -@pytest.mark.parametrize( - "metric_args", - [ - {"reduce": "micro"}, - {"num_classes": 1, "reduce": "macro"}, - {"reduce": "samples"}, - {"mdmc_reduce": None}, - {"mdmc_reduce": "samplewise"}, - {"mdmc_reduce": "global"}, - ], -) -@pytest.mark.parametrize("metric_cls", [Accuracy, Dice, FBetaScore, Precision, Recall, Specificity]) -def test_provide_superclass_kwargs(metric_cls: StatScores, metric_args: Dict[str, Any]): - """Test instantiating subclasses with superclass arguments as kwargs.""" - metric_cls(**metric_args) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_stat_scores_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelStatScores, + metric_functional=multilabel_stat_scores, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py index 92008ef5a4f..d7fa26b571d 100644 --- a/tests/unittests/helpers/testers.py +++ b/tests/unittests/helpers/testers.py @@ -14,8 +14,9 @@ import os import pickle import sys +from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import pytest @@ -300,12 +301,13 @@ def _functional_test( _assert_allclose(tm_result, sk_result, atol=atol) -def _assert_half_support( +def _assert_dtype_support( metric_module: Optional[Metric], metric_functional: Optional[Callable], preds: Tensor, target: Tensor, device: str = "cpu", + dtype: torch.dtype = torch.half, **kwargs_update, ): """Test if an metric can be used with half precision tensors. @@ -319,10 +321,10 @@ def _assert_half_support( kwargs_update: Additional keyword arguments that will be passed with preds and target when running update on the metric. """ - y_hat = preds[0].half().to(device) if preds[0].is_floating_point() else preds[0].to(device) - y = target[0].half().to(device) if target[0].is_floating_point() else target[0].to(device) + y_hat = preds[0].to(dtype=dtype, device=device) if preds[0].is_floating_point() else preds[0].to(device) + y = target[0].to(dtype=dtype, device=device) if target[0].is_floating_point() else target[0].to(device) kwargs_update = { - k: (v[0].half() if v.is_floating_point() else v[0]).to(device) if isinstance(v, Tensor) else v + k: (v[0].to(dtype=dtype) if v.is_floating_point() else v[0]).to(device) if isinstance(v, Tensor) else v for k, v in kwargs_update.items() } if metric_module is not None: @@ -402,7 +404,7 @@ def run_class_metric_test( target: Union[Tensor, List[Dict]], metric_class: Metric, sk_metric: Callable, - dist_sync_on_step: bool, + dist_sync_on_step: bool = False, metric_args: dict = None, check_dist_sync_on_step: bool = True, check_batch: bool = True, @@ -482,6 +484,7 @@ def run_precision_test_cpu( metric_module: Optional[Metric] = None, metric_functional: Optional[Callable] = None, metric_args: Optional[dict] = None, + dtype: torch.dtype = torch.half, **kwargs_update, ): """Test if a metric can be used with half precision tensors on cpu @@ -495,12 +498,13 @@ def run_precision_test_cpu( target when running update on the metric. """ metric_args = metric_args or {} - _assert_half_support( + _assert_dtype_support( metric_module(**metric_args) if metric_module is not None else None, - metric_functional, + partial(metric_functional, **metric_args) if metric_functional is not None else None, preds, target, device="cpu", + dtype=dtype, **kwargs_update, ) @@ -511,6 +515,7 @@ def run_precision_test_gpu( metric_module: Optional[Metric] = None, metric_functional: Optional[Callable] = None, metric_args: Optional[dict] = None, + dtype: torch.dtype = torch.half, **kwargs_update, ): """Test if a metric can be used with half precision tensors on gpu @@ -524,12 +529,13 @@ def run_precision_test_gpu( target when running update on the metric. """ metric_args = metric_args or {} - _assert_half_support( + _assert_dtype_support( metric_module(**metric_args) if metric_module is not None else None, - metric_functional, + partial(metric_functional, **metric_args) if metric_functional is not None else None, preds, target, device="cuda", + dtype=dtype, **kwargs_update, ) @@ -619,3 +625,31 @@ def compute(self): class DummyMetricMultiOutput(DummyMetricSum): def compute(self): return [self.x, self.x] + + +def inject_ignore_index(x: Tensor, ignore_index: int) -> Tensor: + """Utility function for injecting the ignore index value into a tensor randomly.""" + if any(x.flatten() == ignore_index): # ignore index is a class label + return x + classes = torch.unique(x) + idx = torch.randperm(x.numel()) + x = deepcopy(x) + # randomly set either element {9, 10} to the ignore index value + skip = torch.randint(9, 11, (1,)).item() + x.view(-1)[idx[::skip]] = ignore_index + # if we accedently removed a class completly in a batch, reintroduce it again + for batch in x: + new_classes = torch.unique(batch) + class_not_in = [c not in new_classes for c in classes] + if any(class_not_in): + missing_class = int(np.where(class_not_in)[0][0]) + batch[torch.where(batch == ignore_index)[0][0]] = missing_class + return x + + +def remove_ignore_index(target: Tensor, preds: Tensor, ignore_index: Optional[int]) -> Tuple[Tensor, Tensor]: + """Utility function for removing samples that are equal to the ignore_index in comparison functions.""" + if ignore_index is not None: + idx = target == ignore_index + target, preds = deepcopy(target[~idx]), deepcopy(preds[~idx]) + return target, preds diff --git a/tests/unittests/classification/test_kl_divergence.py b/tests/unittests/regression/test_kl_divergence.py similarity index 97% rename from tests/unittests/classification/test_kl_divergence.py rename to tests/unittests/regression/test_kl_divergence.py index e8905473440..8a0bba13c2b 100644 --- a/tests/unittests/classification/test_kl_divergence.py +++ b/tests/unittests/regression/test_kl_divergence.py @@ -21,8 +21,8 @@ from scipy.stats import entropy from torch import Tensor -from torchmetrics.classification import KLDivergence -from torchmetrics.functional import kl_divergence +from torchmetrics.functional.regression.kl_divergence import kl_divergence +from torchmetrics.regression.kl_divergence import KLDivergence from unittests.helpers import seed_all from unittests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, MetricTester diff --git a/tests/unittests/retrieval/helpers.py b/tests/unittests/retrieval/helpers.py index 2561d0fb5c7..4c41479df80 100644 --- a/tests/unittests/retrieval/helpers.py +++ b/tests/unittests/retrieval/helpers.py @@ -485,7 +485,7 @@ def run_precision_test_cpu( metric_module: Metric, metric_functional: Callable, ): - def metric_functional_ignore_indexes(preds, target, indexes): + def metric_functional_ignore_indexes(preds, target, indexes, empty_target_action): return metric_functional(preds, target) super().run_precision_test_cpu( @@ -508,7 +508,7 @@ def run_precision_test_gpu( if not torch.cuda.is_available(): pytest.skip("Test requires GPU") - def metric_functional_ignore_indexes(preds, target, indexes): + def metric_functional_ignore_indexes(preds, target, indexes, empty_target_action): return metric_functional(preds, target) super().run_precision_test_gpu( diff --git a/tests/unittests/classification/test_auc.py b/tests/unittests/utilities/test_auc.py similarity index 77% rename from tests/unittests/classification/test_auc.py rename to tests/unittests/utilities/test_auc.py index 9b3b33f6194..c6c12c0cf7f 100644 --- a/tests/unittests/classification/test_auc.py +++ b/tests/unittests/utilities/test_auc.py @@ -19,8 +19,7 @@ from sklearn.metrics import auc as _sk_auc from torch import tensor -from torchmetrics.classification.auc import AUC -from torchmetrics.functional import auc +from torchmetrics.utilities.compute import auc from unittests.helpers import seed_all from unittests.helpers.testers import NUM_BATCHES, MetricTester @@ -55,30 +54,12 @@ def sk_auc(x, y, reorder=False): @pytest.mark.parametrize("x, y", _examples) class TestAUC(MetricTester): - @pytest.mark.parametrize("ddp", [False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_auc(self, x, y, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp=ddp, - preds=x, - target=y, - metric_class=AUC, - sk_metric=sk_auc, - dist_sync_on_step=dist_sync_on_step, - ) - @pytest.mark.parametrize("reorder", [True, False]) def test_auc_functional(self, x, y, reorder): self.run_functional_metric_test( x, y, metric_functional=auc, sk_metric=partial(sk_auc, reorder=reorder), metric_args={"reorder": reorder} ) - @pytest.mark.parametrize("reorder", [True, False]) - def test_auc_differentiability(self, x, y, reorder): - self.run_differentiability_test( - preds=x, target=y, metric_module=AUC, metric_functional=auc, metric_args={"reorder": reorder} - ) - @pytest.mark.parametrize("unsqueeze_x", (True, False)) @pytest.mark.parametrize("unsqueeze_y", (True, False)) diff --git a/tests/unittests/test_utilities.py b/tests/unittests/utilities/test_utilities.py similarity index 91% rename from tests/unittests/test_utilities.py rename to tests/unittests/utilities/test_utilities.py index 9f5a5ccc222..d88ac3d9eb3 100644 --- a/tests/unittests/test_utilities.py +++ b/tests/unittests/utilities/test_utilities.py @@ -20,6 +20,7 @@ from torchmetrics.utilities.checks import _allclose_recursive from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, to_categorical, to_onehot from torchmetrics.utilities.distributed import class_reduce, reduce +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7 def test_prints(): @@ -113,7 +114,7 @@ def test_bincount(): """test that bincount works in deterministic setting on GPU.""" torch.use_deterministic_algorithms(True) - x = torch.randint(100, size=(100,)) + x = torch.randint(10, size=(100,)) # uses custom implementation res1 = _bincount(x, minlength=10) @@ -157,3 +158,14 @@ def test_check_full_state_update_fn(capsys, metric_class, expected): def test_recursive_allclose(input, expected): res = _allclose_recursive(*input) assert res == expected + + +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="test requires access to `torch.movedim`") +@pytest.mark.parametrize("dim1, dim2", [(1, 3), (1, -1)]) +def test_movedim(dim1, dim2): + x = torch.randn(5, 4, 3, 2, 1) + res1 = torch.movedim(x, dim1, dim2) + if dim2 >= 0: + dim2 += 1 + res2 = x.unsqueeze(dim2).transpose(dim2, dim1).squeeze(dim1) + assert torch.allclose(res1, res2)